Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e6a1b1c

Browse files
authored
Merge pull request #202 from datafold/aug11
[Tests] now using connect() instead of connect_to_uri(); refactor
2 parents 27a756e + 6054dbc commit e6a1b1c

File tree

5 files changed

+38
-56
lines changed

5 files changed

+38
-56
lines changed

data_diff/diff_tables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ def __post_init__(self):
9191
raise ValueError(f"Error: min_key expected to be smaller than max_key! ({self.min_key} >= {self.max_key})")
9292

9393
if self.min_update is not None and self.max_update is not None and self.min_update >= self.max_update:
94-
raise ValueError(f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})")
94+
raise ValueError(
95+
f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
96+
)
9597

9698
@property
9799
def _update_column(self):

data_diff/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from abc import ABC, abstractmethod
66
from urllib.parse import urlparse
77
from uuid import UUID
8+
import operator
89
import string
910

1011
alphanums = string.digits + string.ascii_lowercase
@@ -221,3 +222,19 @@ def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]:
221222
for s in strs:
222223
if reo.match(s):
223224
yield s
225+
226+
227+
def accumulate(iterable, func=operator.add, *, initial=None):
    """Yield running totals of ``iterable``, combined pairwise with ``func``.

    Backport of ``itertools.accumulate`` that also accepts the ``initial``
    keyword on Python 3.7. Adapted from the reference implementation at
    https://docs.python.org/3/library/itertools.html#itertools.accumulate

    Args:
        iterable: Source of the elements to fold.
        func: Two-argument callable invoked as ``func(total, element)``.
            Defaults to ``operator.add``.
        initial: Optional starting total. When it is not ``None`` it is
            yielded first, so the output has one more item than the input.

    Yields:
        The running total after each step. Nothing is yielded when
        ``iterable`` is empty and ``initial`` is ``None``.
    """
    it = iter(iterable)
    total = initial
    if initial is None:
        # No seed supplied: the first element becomes the first running total.
        try:
            total = next(it)
        except StopIteration:
            # Empty input and no seed: there is nothing to yield.
            return
    yield total
    for element in it:
        total = func(total, element)
        yield total

tests/test_database_types.py

Lines changed: 14 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from datetime import datetime, timedelta, timezone
99
import logging
1010
from decimal import Decimal
11+
from itertools import islice, repeat, chain
12+
1113
from parameterized import parameterized
1214

1315
from data_diff import databases as db
1416
from data_diff.databases import postgresql, oracle
15-
from data_diff.utils import number_to_human
17+
from data_diff.utils import number_to_human, accumulate
1618
from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD
1719
from .common import (
1820
CONN_STRINGS,
@@ -25,7 +27,7 @@
2527
)
2628

2729

28-
CONNS = {k: db.connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}
30+
CONNS = {k: db.connect.connect(v, N_THREADS) for k, v in CONN_STRINGS.items()}
2931

3032
CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
3133
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = "UTC"
@@ -290,54 +292,28 @@ def __init__(self, max):
290292
self.max = max
291293

292294
def __iter__(self):
293-
iter = DateTimeFaker(self.max)
294-
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
295-
iter.i = 0
296-
return iter
295+
initial = datetime(2000, 1, 1, 0, 0, 0, 0)
296+
step = timedelta(seconds=3, microseconds=571)
297+
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
297298

298299
def __len__(self):
299300
return self.max
300301

301-
def __next__(self) -> datetime:
302-
if self.i < len(self.MANUAL_FAKES):
303-
fake = self.MANUAL_FAKES[self.i]
304-
self.i += 1
305-
return fake
306-
elif self.i < self.max:
307-
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
308-
self.i += 1
309-
return self.prev
310-
else:
311-
raise StopIteration
312-
313302

314303
class IntFaker:
315-
MANUAL_FAKES = [127, -3, -9, 37, 15, 127]
304+
MANUAL_FAKES = [127, -3, -9, 37, 15, 0]
316305

317306
def __init__(self, max):
318307
self.max = max
319308

320309
def __iter__(self):
321-
iter = IntFaker(self.max)
322-
iter.prev = -128
323-
iter.i = 0
324-
return iter
310+
initial = -128
311+
step = 1
312+
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
325313

326314
def __len__(self):
327315
return self.max
328316

329-
def __next__(self) -> int:
330-
if self.i < len(self.MANUAL_FAKES):
331-
fake = self.MANUAL_FAKES[self.i]
332-
self.i += 1
333-
return fake
334-
elif self.i < self.max:
335-
self.prev += 1
336-
self.i += 1
337-
return self.prev
338-
else:
339-
raise StopIteration
340-
341317

342318
class FloatFaker:
343319
MANUAL_FAKES = [
@@ -363,26 +339,13 @@ def __init__(self, max):
363339
self.max = max
364340

365341
def __iter__(self):
366-
iter = FloatFaker(self.max)
367-
iter.prev = -10.0001
368-
iter.i = 0
369-
return iter
342+
initial = -10.0001
343+
step = 0.00571
344+
return islice(chain(self.MANUAL_FAKES, accumulate(repeat(step), initial=initial)), self.max)
370345

371346
def __len__(self):
372347
return self.max
373348

374-
def __next__(self) -> float:
375-
if self.i < len(self.MANUAL_FAKES):
376-
fake = self.MANUAL_FAKES[self.i]
377-
self.i += 1
378-
return fake
379-
elif self.i < self.max:
380-
self.prev += 0.00571
381-
self.i += 1
382-
return self.prev
383-
else:
384-
raise StopIteration
385-
386349

387350
class UUID_Faker:
388351
def __init__(self, max):

tests/test_diff_tables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import preql
77
import arrow # comes with preql
88

9-
from data_diff.databases import connect_to_uri
9+
from data_diff.databases.connect import connect
1010
from data_diff.diff_tables import TableDiffer, TableSegment, split_space
1111
from data_diff import databases as db
1212
from data_diff.utils import ArithAlphanumeric
@@ -21,7 +21,7 @@
2121
)
2222

2323
DATABASE_URIS = {k.__name__: v for k, v in CONN_STRINGS.items()}
24-
DATABASE_INSTANCES = {k.__name__: connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}
24+
DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()}
2525

2626
TEST_DATABASES = {x.__name__ for x in (db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery)}
2727

tests/test_postgresql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import unittest
22

3-
from data_diff.databases.connect import connect_to_uri
3+
from data_diff.databases.connect import connect
44
from data_diff import TableSegment, TableDiffer
55
from .common import TEST_POSTGRESQL_CONN_STRING, random_table_suffix
66

77

88
class TestWithConnection(unittest.TestCase):
99
def setUp(self) -> None:
10-
self.connection = connect_to_uri(TEST_POSTGRESQL_CONN_STRING)
10+
self.connection = connect(TEST_POSTGRESQL_CONN_STRING)
1111

1212
self.connection.query('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";', None)
1313

0 commit comments

Comments
 (0)