Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 101707b

Browse files
committed
Fix: Only parse relevant columns. Only warn on relevant columns.
1 parent b88972a commit 101707b

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

data_diff/database.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from abc import ABC, abstractmethod
66
from runtype import dataclass
77
import logging
8-
from typing import Tuple, Optional, List
8+
from typing import Sequence, Tuple, Optional, List
99
from concurrent.futures import ThreadPoolExecutor
1010
import threading
1111
from typing import Dict
@@ -131,10 +131,6 @@ def __post_init__(self):
131131
class UnknownColType(ColType):
132132
text: str
133133

134-
def __post_init__(self):
135-
logger.warn(f"Column of type '{self.text}' has no compatibility handling. "
136-
"If encoding/formatting differs between databases, it may result in false positives.")
137-
138134

139135
class AbstractDatabase(ABC):
140136
@abstractmethod
@@ -163,7 +159,7 @@ def select_table_schema(self, path: DbPath) -> str:
163159
...
164160

165161
@abstractmethod
166-
def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
162+
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
167163
"Query the table for its schema for table in 'path', and return {column: type}"
168164
...
169165

@@ -241,6 +237,10 @@ class Database(AbstractDatabase):
241237
DATETIME_TYPES = {}
242238
default_schema = None
243239

240+
@property
241+
def name(self):
242+
return type(self).__name__
243+
244244
def query(self, sql_ast: SqlOrStr, res_type: type):
245245
"Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
246246

@@ -321,12 +321,16 @@ def select_table_schema(self, path: DbPath) -> str:
321321
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
322322
)
323323

324-
def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
324+
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
325325
rows = self.query(self.select_table_schema(path), list)
326326
if not rows:
327-
raise RuntimeError(f"{self.__class__.__name__}: Table '{'.'.join(path)}' does not exist, or has no columns")
327+
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
328+
329+
if filter_columns is not None:
330+
accept = {i.lower() for i in filter_columns}
331+
rows = [r for r in rows if r[0].lower() in accept]
328332

329-
# Return a dict of form {name: type} after canonizaation
333+
# Return a dict of form {name: type} after normalization
330334
return {row[0]: self._parse_type(*row[1:]) for row in rows}
331335

332336
# @lru_cache()
@@ -339,7 +343,7 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
339343
return self.default_schema, path[0]
340344
elif len(path) != 2:
341345
raise ValueError(
342-
f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
346+
f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
343347
)
344348

345349
return path
@@ -407,6 +411,7 @@ class Postgres(ThreadedDatabase):
407411
"decimal": Decimal,
408412
"integer": Integer,
409413
"numeric": Decimal,
414+
"bigint": Integer,
410415
}
411416
ROUNDS_ON_PREC_LOSS = True
412417

data_diff/diff_tables.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from runtype import dataclass
1313

1414
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max
15-
from .database import Database, NumericType, PrecisionType, ColType
15+
from .database import Database, NumericType, PrecisionType, ColType, UnknownColType
1616

1717
logger = logging.getLogger("diff_tables")
1818

@@ -142,7 +142,8 @@ def with_schema(self) -> "TableSegment":
142142
"Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema."
143143
if self._schema:
144144
return self
145-
schema = self.database.query_table_schema(self.table_path)
145+
146+
schema = self.database.query_table_schema(self.table_path, self._relevant_columns)
146147
if self.case_sensitive:
147148
schema = Schema_CaseSensitive(schema)
148149
else:
@@ -381,6 +382,13 @@ def _validate_and_adjust_columns(self, table1, table2):
381382
table1._schema[c] = col1.replace(precision=lowest.precision)
382383
table2._schema[c] = col2.replace(precision=lowest.precision)
383384

385+
for t in [table1, table2]:
386+
for c in t._relevant_columns:
387+
ctype = t._schema[c]
388+
if isinstance(ctype, UnknownColType):
389+
logger.warn(f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. "
390+
"If encoding/formatting differs between databases, it may result in false positives.")
391+
384392
def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
385393
assert table1.is_bounded and table2.is_bounded
386394

0 commit comments

Comments
 (0)