Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit ab55f21

Browse files
committed
data-diff now uses database A's now instead of cli's now. (issue #284)
1 parent 9289570 commit ab55f21

File tree

13 files changed

+124
-90
lines changed

13 files changed

+124
-90
lines changed

data_diff/__main__.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import deepcopy
2+
from datetime import datetime
23
import sys
34
import time
45
import json
@@ -15,8 +16,9 @@
1516
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
1617
from .table_segment import TableSegment
1718
from .sqeleton.schema import create_schema
19+
from .sqeleton.queries.api import current_timestamp
1820
from .databases import connect
19-
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
21+
from .parse_time import parse_time_before, UNITS_STR, ParseError
2022
from .config import apply_config_from_file
2123
from .tracking import disable_tracking
2224
from . import __version__
@@ -299,17 +301,6 @@ def _main(
299301

300302
start = time.monotonic()
301303

302-
try:
303-
options = dict(
304-
min_update=max_age and parse_time_before_now(max_age),
305-
max_update=min_age and parse_time_before_now(min_age),
306-
case_sensitive=case_sensitive,
307-
where=where,
308-
)
309-
except ParseError as e:
310-
logging.error(f"Error while parsing age expression: {e}")
311-
return
312-
313304
if database1 is None or database2 is None:
314305
logging.error(
315306
f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information."
@@ -326,6 +317,19 @@ def _main(
326317
logging.error(e)
327318
return
328319

320+
now: datetime = db1.query(current_timestamp(), datetime)
321+
now = now.replace(tzinfo=None)
322+
try:
323+
options = dict(
324+
min_update=max_age and parse_time_before(now, max_age),
325+
max_update=min_age and parse_time_before(now, min_age),
326+
case_sensitive=case_sensitive,
327+
where=where,
328+
)
329+
except ParseError as e:
330+
logging.error(f"Error while parsing age expression: {e}")
331+
return
332+
329333
dbs = db1, db2
330334

331335
if interactive:

data_diff/joindiff_tables.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def sample(table_expr):
6262

6363
def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
6464
db = c.database
65+
c = c.replace(root=False) # we're compiling fragments, not full queries
6566
if isinstance(db, BigQuery):
6667
return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
6768
elif isinstance(db, Presto):

data_diff/parse_time.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,5 @@ def parse_time_delta(t: str):
7070
return timedelta(**time_dict)
7171

7272

73-
def parse_time_before_now(t: str):
74-
return datetime.now() - parse_time_delta(t)
73+
def parse_time_before(time: datetime, delta: str):
74+
return time - parse_time_delta(delta)

data_diff/sqeleton/abcs/database_types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ def to_string(self, s: str) -> str:
184184
def random(self) -> str:
185185
"Provide SQL for generating a random number betweein 0..1"
186186

187+
@abstractmethod
188+
def current_timestamp(self) -> str:
189+
"Provide SQL for returning the current timestamp, aka now"
190+
187191
@abstractmethod
188192
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
189193
"Provide SQL fragment for limit and offset inside a select"
@@ -199,6 +203,7 @@ def timestamp_value(self, t: datetime) -> str:
199203
"Provide SQL for the given timestamp value"
200204
...
201205

206+
202207
@abstractmethod
203208
def parse_type(
204209
self,

data_diff/sqeleton/databases/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ def timestamp_value(self, t: DbTime) -> str:
142142
return f"'{t.isoformat()}'"
143143

144144
def random(self) -> str:
145-
return "RANDOM()"
145+
return "random()"
146+
147+
def current_timestamp(self) -> str:
148+
return "current_timestamp()"
146149

147150
def explain_as_text(self, query: str) -> str:
148151
return f"EXPLAIN {query}"

data_diff/sqeleton/queries/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,8 @@ def insert_rows_in_batches(db, table: TablePath, rows, *, columns=None, batch_si
9595
db.query(table.insert_rows(batch, columns=columns))
9696

9797

98+
def current_timestamp():
99+
return CurrentTimestamp()
100+
101+
98102
commit = Commit()

data_diff/sqeleton/queries/ast_classes.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..abcs import Compilable
99
from ..schema import Schema
1010

11-
from .compiler import Compiler, cv_params
11+
from .compiler import Compiler, cv_params, Root
1212
from .base import SKIP, CompileError, DbPath, args_as_tuple
1313

1414

@@ -47,7 +47,7 @@ def cast_to(self, to):
4747

4848

4949
@dataclass
50-
class Code(ExprNode):
50+
class Code(ExprNode, Root):
5151
code: str
5252

5353
def compile(self, c: Compiler) -> str:
@@ -434,7 +434,7 @@ def compile(self, c: Compiler) -> str:
434434

435435

436436
@dataclass
437-
class Join(ExprNode, ITable):
437+
class Join(ExprNode, ITable, Root):
438438
source_tables: Sequence[ITable]
439439
op: str = None
440440
on_exprs: Sequence[Expr] = None
@@ -499,7 +499,7 @@ def compile(self, parent_c: Compiler) -> str:
499499

500500

501501
@dataclass
502-
class GroupBy(ExprNode, ITable):
502+
class GroupBy(ExprNode, ITable, Root):
503503
table: ITable
504504
keys: Sequence[Expr] = None # IKey?
505505
values: Sequence[Expr] = None
@@ -540,7 +540,7 @@ def compile(self, c: Compiler) -> str:
540540

541541

542542
@dataclass
543-
class TableOp(ExprNode, ITable):
543+
class TableOp(ExprNode, ITable, Root):
544544
op: str
545545
table1: ITable
546546
table2: ITable
@@ -571,7 +571,7 @@ def compile(self, parent_c: Compiler) -> str:
571571

572572

573573
@dataclass
574-
class Select(ExprNode, ITable):
574+
class Select(ExprNode, ITable, Root):
575575
table: Expr = None
576576
columns: Sequence[Expr] = None
577577
where_exprs: Sequence[Expr] = None
@@ -771,7 +771,7 @@ def compile_for_insert(self, c: Compiler):
771771

772772

773773
@dataclass
774-
class Explain(ExprNode):
774+
class Explain(ExprNode, Root):
775775
select: Select
776776

777777
type = str
@@ -780,10 +780,16 @@ def compile(self, c: Compiler) -> str:
780780
return c.dialect.explain_as_text(c.compile(self.select))
781781

782782

783+
class CurrentTimestamp(ExprNode):
784+
type = datetime
785+
786+
def compile(self, c: Compiler) -> str:
787+
return c.dialect.current_timestamp()
788+
783789
# DDL
784790

785791

786-
class Statement(Compilable):
792+
class Statement(Compilable, Root):
787793
type = None
788794

789795

data_diff/sqeleton/queries/compiler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
cv_params = contextvars.ContextVar("params")
1313

14+
class Root:
15+
"Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)"
16+
1417

1518
@dataclass
1619
class Compiler(AbstractCompiler):
@@ -33,6 +36,10 @@ def compile(self, elem, params=None) -> str:
3336
if params:
3437
cv_params.set(params)
3538

39+
if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
40+
from .ast_classes import Select
41+
elem = Select(columns=[elem])
42+
3643
res = self._compile(elem)
3744
if self.root and self._subqueries:
3845
subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items())

tests/common.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from data_diff import tracking
1414
from data_diff import connect
1515
from data_diff.sqeleton.queries import table
16+
from data_diff.table_segment import TableSegment
1617
from data_diff.sqeleton.databases import Database
1718
from data_diff.query_utils import drop_table
1819

@@ -86,10 +87,13 @@ def get_git_revision_short_hash() -> str:
8687
_database_instances = {}
8788

8889

89-
def get_conn(cls: type) -> Database:
90-
if cls not in _database_instances:
91-
_database_instances[cls] = connect(CONN_STRINGS[cls], N_THREADS)
92-
return _database_instances[cls]
90+
def get_conn(cls: type, shared: bool =True) -> Database:
91+
if shared:
92+
if cls not in _database_instances:
93+
_database_instances[cls] = get_conn(cls, shared=False)
94+
return _database_instances[cls]
95+
96+
return connect(CONN_STRINGS[cls], N_THREADS)
9397

9498

9599
def _print_used_dbs():
@@ -134,11 +138,12 @@ class DiffTestCase(unittest.TestCase):
134138
db_cls = None
135139
src_schema = None
136140
dst_schema = None
141+
shared_connection = True
137142

138143
def setUp(self):
139144
assert self.db_cls, self.db_cls
140145

141-
self.connection = get_conn(self.db_cls)
146+
self.connection = get_conn(self.db_cls, self.shared_connection)
142147

143148
table_suffix = random_table_suffix()
144149
self.table_src_name = f"src{table_suffix}"
@@ -150,11 +155,11 @@ def setUp(self):
150155
drop_table(self.connection, self.table_src_path)
151156
drop_table(self.connection, self.table_dst_path)
152157

158+
self.src_table = table(self.table_src_path, schema=self.src_schema)
159+
self.dst_table = table(self.table_dst_path, schema=self.dst_schema)
153160
if self.src_schema:
154-
self.src_table = table(self.table_src_path, schema=self.src_schema)
155161
self.connection.query(self.src_table.create())
156162
if self.dst_schema:
157-
self.dst_table = table(self.table_dst_path, schema=self.dst_schema)
158163
self.connection.query(self.dst_table.create())
159164

160165
return super().setUp()
@@ -175,3 +180,8 @@ def _test_per_database(cls):
175180
return _parameterized_class_per_conn(databases)(cls)
176181

177182
return _test_per_database
183+
184+
def table_segment(database, table_path, key_columns, *args, **kw):
185+
if isinstance(key_columns, str):
186+
key_columns = (key_columns,)
187+
return TableSegment(database, table_path, key_columns, *args, **kw)

tests/sqeleton/test_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def is_distinct_from(self, a: str, b: str) -> str:
3333
def random(self) -> str:
3434
return "random()"
3535

36+
def current_timestamp(self) -> str:
37+
return "now()"
38+
3639
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
3740
x = offset and f"OFFSET {offset}", limit and f"LIMIT {limit}"
3841
return " ".join(filter(None, x))

0 commit comments

Comments
 (0)