1+ from uuid import UUID
12import math
23import sys
34import logging
4- from typing import Dict , Tuple , Optional , Sequence
5+ from typing import Dict , Tuple , Optional , Sequence , Type , List
56from functools import lru_cache , wraps
67from concurrent .futures import ThreadPoolExecutor
78import threading
89from abc import abstractmethod
910
10- from .database_types import AbstractDatabase , ColType , Integer , Decimal , Float , UnknownColType
11- from data_diff .sql import DbPath , SqlOrStr , Compiler , Explain , Select
11+ from data_diff .utils import is_uuid , safezip
12+ from .database_types import (
13+ ColType_UUID ,
14+ AbstractDatabase ,
15+ ColType ,
16+ Integer ,
17+ Decimal ,
18+ Float ,
19+ PrecisionType ,
20+ TemporalType ,
21+ UnknownColType ,
22+ Text ,
23+ )
24+ from data_diff .sql import DbPath , SqlOrStr , Compiler , Explain , Select , TableName
1225
1326logger = logging .getLogger ("database" )
1427
@@ -62,7 +75,7 @@ class Database(AbstractDatabase):
6275 Instantiated using :meth:`~data_diff.connect_to_uri`
6376 """
6477
65- DATETIME_TYPES : Dict [str , type ] = {}
78+ TYPE_CLASSES : Dict [str , type ] = {}
6679 default_schema : str = None
6780
6881 @property
@@ -93,7 +106,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
93106 assert len (res ) == 1 , (sql_code , res )
94107 return res [0 ]
95108 elif getattr (res_type , "__origin__" , None ) is list and len (res_type .__args__ ) == 1 :
96- if res_type .__args__ == (int ,):
109+ if res_type .__args__ == (int ,) or res_type . __args__ == ( str ,) :
97110 return [_one (row ) for row in res ]
98111 elif res_type .__args__ == (Tuple ,):
99112 return [tuple (row ) for row in res ]
@@ -109,8 +122,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
109122 # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
110123 return math .floor (math .log (2 ** p , 10 ))
111124
125+ def _parse_type_repr (self , type_repr : str ) -> Optional [Type [ColType ]]:
126+ return self .TYPE_CLASSES .get (type_repr )
127+
112128 def _parse_type (
113129 self ,
130+ table_path : DbPath ,
114131 col_name : str ,
115132 type_repr : str ,
116133 datetime_precision : int = None ,
@@ -119,36 +136,38 @@ def _parse_type(
119136 ) -> ColType :
120137 """ """
121138
122- cls = self .DATETIME_TYPES .get (type_repr )
123- if cls :
139+ cls = self ._parse_type_repr (type_repr )
140+ if not cls :
141+ return UnknownColType (type_repr )
142+
143+ if issubclass (cls , TemporalType ):
124144 return cls (
125145 precision = datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION ,
126146 rounds = self .ROUNDS_ON_PREC_LOSS ,
127147 )
128148
129- cls = self .NUMERIC_TYPES .get (type_repr )
130- if cls :
131- if issubclass (cls , Integer ):
132- # Some DBs have a constant numeric_scale, so they don't report it.
133- # We fill in the constant, so we need to ignore it for integers.
134- return cls (precision = 0 )
135-
136- elif issubclass (cls , Decimal ):
137- if numeric_scale is None :
138- raise ValueError (
139- f"{ self .name } : Unexpected numeric_scale is NULL, for column { col_name } of type { type_repr } ."
140- )
141- return cls (precision = numeric_scale )
149+ elif issubclass (cls , Integer ):
150+ return cls ()
151+
152+ elif issubclass (cls , Decimal ):
153+ if numeric_scale is None :
154+ raise ValueError (
155+ f"{ self .name } : Unexpected numeric_scale is NULL, for column { '.' .join (table_path )} .{ col_name } of type { type_repr } ."
156+ )
157+ return cls (precision = numeric_scale )
142158
143- assert issubclass (cls , Float )
159+ elif issubclass (cls , Float ):
144160 # assert numeric_scale is None
145161 return cls (
146162 precision = self ._convert_db_precision_to_digits (
147163 numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
148164 )
149165 )
150166
151- return UnknownColType (type_repr )
167+ elif issubclass (cls , Text ):
168+ return cls ()
169+
170+ raise TypeError (f"Parsing { type_repr } returned an unknown type '{ cls } '." )
152171
153172 def select_table_schema (self , path : DbPath ) -> str :
154173 schema , table = self ._normalize_table_path (path )
@@ -167,8 +186,34 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
167186 accept = {i .lower () for i in filter_columns }
168187 rows = [r for r in rows if r [0 ].lower () in accept ]
169188
189+ col_dict : Dict [str , ColType ] = {row [0 ]: self ._parse_type (path , * row ) for row in rows }
190+
191+ self ._refine_coltypes (path , col_dict )
192+
170193 # Return a dict of form {name: type} after normalization
171- return {row [0 ]: self ._parse_type (* row ) for row in rows }
194+ return col_dict
195+
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
    """Refine the types in the column dict by sampling the table's values.

    Every column currently typed as ``Text`` is sampled (at most 16 rows,
    each value passed through ``normalize_uuid``). A column whose non-NULL
    samples are all UUIDs is re-typed in place as ``ColType_UUID``. A column
    mixing UUID and non-UUID values keeps its type and a warning is logged.

    Args:
        table_path: Path of the table to sample.
        col_dict: Mapping of column name to column type; mutated in place.

    Returns:
        None. The result is the mutation of ``col_dict``.
    """
    text_columns = [name for name, coltype in col_dict.items() if isinstance(coltype, Text)]
    if not text_columns:
        return

    fields = [self.normalize_uuid(c, ColType_UUID()) for c in text_columns]
    samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
    if not samples_by_row:
        # No rows to sample: zip(*[]) would produce zero columns, which cannot
        # be paired with the (non-empty) list of text columns below. There is
        # nothing to learn from an empty table, so keep the declared types.
        return

    # Transpose rows into per-column tuples, in the same order as text_columns.
    samples_by_col = list(zip(*samples_by_row))
    for col_name, samples in safezip(text_columns, samples_by_col):
        # NULLs carry no information about the column's content, so they are
        # neither handed to is_uuid nor counted as "non-UUID" values.
        # NOTE(review): is_uuid is defined elsewhere; confirm how it treats None.
        present = [s for s in samples if s is not None]
        uuid_samples = [s for s in present if is_uuid(s)]
        if not uuid_samples:
            continue

        if len(uuid_samples) != len(present):
            logger.warning(
                f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
            )
        else:
            assert col_name in col_dict
            col_dict[col_name] = ColType_UUID()
172217
173218 # @lru_cache()
174219 # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
@@ -186,6 +231,15 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
def parse_table_name(self, name: str) -> DbPath:
    """Parse a table name string into a database path.

    Delegates to the module-level ``parse_table_name`` function. Inside a
    method body the bare name resolves to module scope, not to this method,
    so this is not a recursive call.

    Args:
        name: The table name to parse.

    Returns:
        Whatever the module-level ``parse_table_name`` returns for ``name``.
    """
    path = parse_table_name(name)
    return path
188233
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    """Render the row-limiting clause of a query.

    Args:
        offset: Number of rows to skip. Any truthy value is rejected here;
            ``None`` and ``0`` both mean "no offset".
        limit: Maximum number of rows to return, or ``None`` for no limit.

    Returns:
        ``"LIMIT <limit>"``, or an empty string when ``limit`` is ``None``.

    Raises:
        NotImplementedError: If a non-zero ``offset`` is requested.
    """
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    if limit is None:
        # Without this guard the default arguments would render "LIMIT None",
        # which is not valid SQL.
        return ""

    return f"LIMIT {limit}"
239+
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
    """Build the SQL expression that normalizes a UUID-valued expression.

    Args:
        value: SQL expression (for example a column name) to normalize.
        coltype: The UUID column type; accepted but not used by this
            implementation.

    Returns:
        ``value`` wrapped in a ``TRIM(...)`` call.
    """
    return "TRIM({})".format(value)
242+
189243
190244class ThreadedDatabase (Database ):
191245 """Access the database through singleton threads.
0 commit comments