Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 49e8a4f

Browse files
committed
List tables from schema, mid-work (WIP)
1 parent: 940c2ce; commit: 49e8a4f

File tree

10 files changed: +135 -23 lines changed

10 files changed: +135 -23 lines changed

data_diff/sqeleton/abcs/mixins.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,9 +97,13 @@ class AbstractMixin_Schema(ABC):
9797
TODO: Move AbstractDatabase.query_table_schema() and friends over here
9898
"""
9999

100+
def table_information(self) -> Compilable:
101+
"Query to return a table of schema information about existing tables"
102+
raise NotImplementedError()
103+
100104
@abstractmethod
101-
def list_tables(self, like: Compilable = None) -> Compilable:
102-
"""Query to select the list of tables in the schema.
105+
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
106+
"""Query to select the list of tables in the schema. (query return type: table[str])
103107
104108
If 'like' is specified, the value is applied to the table name, using the 'like' operator.
105109
"""

data_diff/sqeleton/databases/base.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
import decimal
1212

1313
from ..utils import is_uuid, safezip
14-
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code
14+
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this
1515
from ..abcs.database_types import (
1616
AbstractDatabase,
1717
AbstractDialect,
@@ -30,6 +30,8 @@
3030
DbPath,
3131
Boolean,
3232
)
33+
from ..abcs.mixins import Compilable
34+
from ..abcs.mixins import AbstractMixin_Schema
3335

3436
logger = logging.getLogger("database")
3537

@@ -101,6 +103,22 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
101103
return callback(sql_code)
102104

103105

106+
class Mixin_Schema(AbstractMixin_Schema):
107+
def table_information(self) -> Compilable:
108+
return table("information_schema", "tables")
109+
110+
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
111+
return (
112+
self.table_information()
113+
.where(
114+
this.table_schema == table_schema,
115+
this.table_name.like(like) if like is not None else SKIP,
116+
this.table_type == "BASE TABLE",
117+
)
118+
.select(this.table_name)
119+
)
120+
121+
104122
class BaseDialect(AbstractDialect):
105123
SUPPORTS_PRIMARY_KEY = False
106124
TYPE_CLASSES: Dict[str, type] = {}
@@ -354,7 +372,9 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
354372
return
355373

356374
fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]
357-
samples_by_row = self.query(table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list)
375+
samples_by_row = self.query(
376+
table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
377+
)
358378
if not samples_by_row:
359379
raise ValueError(f"Table {table_path} is empty.")
360380

data_diff/sqeleton/databases/bigquery.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,9 @@
1111
TemporalType,
1212
Boolean,
1313
)
14-
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
14+
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
15+
from ..abcs import Compilable
16+
from ..queries import this, table, SKIP
1517
from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query
1618
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
1719

@@ -51,7 +53,20 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
5153
return self.to_string(f"cast({value} as int)")
5254

5355

54-
class Dialect(BaseDialect):
56+
class Mixin_Schema(AbstractMixin_Schema):
57+
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
58+
return (
59+
table(table_schema, "INFORMATION_SCHEMA", "TABLES")
60+
.where(
61+
this.table_schema == table_schema,
62+
this.table_name.like(like) if like is not None else SKIP,
63+
this.table_type == "BASE TABLE",
64+
)
65+
.select(this.table_name)
66+
)
67+
68+
69+
class Dialect(BaseDialect, Mixin_Schema):
5570
name = "BigQuery"
5671
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
5772
TYPE_CLASSES = {

data_diff/sqeleton/databases/duckdb.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
ThreadLocalInterpreter,
2525
TIMESTAMP_PRECISION_POS,
2626
)
27-
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS
27+
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema
2828

2929

3030
@import_helper("duckdb")
@@ -54,7 +54,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
5454
return self.to_string(f"{value}::INTEGER")
5555

5656

57-
class Dialect(BaseDialect):
57+
class Dialect(BaseDialect, Mixin_Schema):
5858
name = "DuckDB"
5959
ROUNDS_ON_PREC_LOSS = False
6060
SUPPORTS_PRIMARY_KEY = True

data_diff/sqeleton/databases/mysql.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
ConnectError,
1818
BaseDialect,
1919
)
20-
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS
20+
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema
2121

2222

2323
@import_helper("mysql")
@@ -47,7 +47,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
4747
return f"TRIM(CAST({value} AS char))"
4848

4949

50-
class Dialect(BaseDialect):
50+
class Dialect(BaseDialect, Mixin_Schema):
5151
name = "MySQL"
5252
ROUNDS_ON_PREC_LOSS = True
5353
SUPPORTS_PRIMARY_KEY = True

data_diff/sqeleton/databases/postgresql.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,12 +11,7 @@
1111
Boolean,
1212
)
1313
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
14-
from .base import (
15-
BaseDialect,
16-
ThreadedDatabase,
17-
import_helper,
18-
ConnectError,
19-
)
14+
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema
2015
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS
2116

2217
SESSION_TIME_ZONE = None # Changed by the tests
@@ -53,7 +48,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
5348
return self.to_string(f"{value}::int")
5449

5550

56-
class PostgresqlDialect(BaseDialect):
51+
class PostgresqlDialect(BaseDialect, Mixin_Schema):
5752
name = "PostgreSQL"
5853
ROUNDS_ON_PREC_LOSS = True
5954
SUPPORTS_PRIMARY_KEY = True

data_diff/sqeleton/databases/presto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
Boolean,
2020
)
2121
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
22-
from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter
22+
from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter, Mixin_Schema
2323
from .base import (
2424
MD5_HEXDIGITS,
2525
CHECKSUM_HEXDIGITS,
@@ -69,7 +69,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
6969
return self.to_string(f"cast ({value} as int)")
7070

7171

72-
class Dialect(BaseDialect):
72+
class Dialect(BaseDialect, Mixin_Schema):
7373
name = "Presto"
7474
ROUNDS_ON_PREC_LOSS = True
7575
TYPE_CLASSES = {

data_diff/sqeleton/databases/snowflake.py

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,9 @@
1212
DbPath,
1313
Boolean,
1414
)
15-
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
15+
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
16+
from ..abcs import Compilable
17+
from data_diff.sqeleton.queries import table, this, SKIP
1618
from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter
1719

1820

@@ -46,7 +48,23 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
4648
return self.to_string(f"{value}::int")
4749

4850

49-
class Dialect(BaseDialect):
51+
class Mixin_Schema(AbstractMixin_Schema):
52+
def table_information(self) -> Compilable:
53+
return table("INFORMATION_SCHEMA", "TABLES")
54+
55+
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
56+
return (
57+
self.table_information()
58+
.where(
59+
this.TABLE_SCHEMA == table_schema,
60+
this.TABLE_NAME.like(like) if like is not None else SKIP,
61+
this.TABLE_TYPE == "BASE TABLE",
62+
)
63+
.select(table_name=this.TABLE_NAME)
64+
)
65+
66+
67+
class Dialect(BaseDialect, Mixin_Schema):
5068
name = "Snowflake"
5169
ROUNDS_ON_PREC_LOSS = False
5270
TYPE_CLASSES = {
@@ -72,6 +90,9 @@ def quote(self, s: str):
7290
def to_string(self, s: str):
7391
return f"cast({s} as string)"
7492

93+
def table_information(self) -> Compilable:
94+
return table("INFORMATION_SCHEMA", "TABLES")
95+
7596

7697
class Snowflake(Database):
7798
dialect = Dialect()

data_diff/sqeleton/databases/vertica.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,9 @@
2424
Boolean,
2525
ColType_UUID,
2626
)
27-
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
27+
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
28+
from ..abcs import Compilable
29+
from ..queries import table, this, SKIP
2830

2931

3032
@import_helper("vertica")
@@ -60,7 +62,22 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
6062
return self.to_string(f"cast ({value} as int)")
6163

6264

63-
class Dialect(BaseDialect):
65+
class Mixin_Schema(AbstractMixin_Schema):
66+
def table_information(self) -> Compilable:
67+
return table("v_catalog", "tables")
68+
69+
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
70+
return (
71+
self.table_information()
72+
.where(
73+
this.table_schema == table_schema,
74+
this.table_name.like(like) if like is not None else SKIP,
75+
)
76+
.select(this.table_name)
77+
)
78+
79+
80+
class Dialect(BaseDialect, Mixin_Schema):
6481
name = "Vertica"
6582
ROUNDS_ON_PREC_LOSS = True
6683

tests/sqeleton/test_database.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,33 @@
1+
from typing import Callable, List
12
import unittest
23

34
from ..common import str_to_checksum, TEST_MYSQL_CONN_STRING
5+
from ..common import str_to_checksum, test_each_database_in_list, TestPerDatabase, get_conn, random_table_suffix
6+
# from data_diff.sqeleton import databases as db
7+
# from data_diff.sqeleton import connect
8+
9+
from data_diff.sqeleton.queries import table
10+
11+
from data_diff import databases as dbs
412
from data_diff.databases import connect
513

614

15+
TEST_DATABASES = {
16+
dbs.MySQL,
17+
dbs.PostgreSQL,
18+
# dbs.Oracle,
19+
# dbs.Redshift,
20+
dbs.Snowflake,
21+
dbs.DuckDB,
22+
dbs.BigQuery,
23+
dbs.Presto,
24+
dbs.Trino,
25+
dbs.Vertica,
26+
}
27+
28+
test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)
29+
30+
731
class TestDatabase(unittest.TestCase):
832
def setUp(self):
933
self.mysql = connect(TEST_MYSQL_CONN_STRING)
@@ -25,3 +49,19 @@ def test_bad_uris(self):
2549
self.assertRaises(ValueError, connect, "postgresql:///bla/foo")
2650
self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1")
2751
self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup")
52+
53+
54+
@test_each_database
55+
class TestSchema(TestPerDatabase):
56+
def test_table_list(self):
57+
name = self.table_src_name
58+
db = self.connection
59+
tbl = table(db.parse_table_name(name), schema={'id': int})
60+
q = db.dialect.list_tables(db.default_schema, name)
61+
assert not db.query(q)
62+
63+
db.query(tbl.create())
64+
assert db.query(q, List[str] ) == [name]
65+
66+
db.query( tbl.drop() )
67+
assert not db.query(q)

0 commit comments

Comments
 (0)