11import math
22import sys
33import logging
4- from typing import Dict , Tuple , Optional , Sequence
4+ from typing import Dict , Tuple , Optional , Sequence , Type
55from functools import lru_cache , wraps
66from concurrent .futures import ThreadPoolExecutor
77import threading
88from abc import abstractmethod
99
10- from .database_types import AbstractDatabase , ColType , Integer , Decimal , Float , UnknownColType
10+ from .database_types import (
11+ AbstractDatabase ,
12+ ColType ,
13+ Integer ,
14+ Decimal ,
15+ Float ,
16+ PrecisionType ,
17+ TemporalType ,
18+ UnknownColType ,
19+ )
1120from data_diff .sql import DbPath , SqlOrStr , Compiler , Explain , Select
1221
1322logger = logging .getLogger ("database" )
@@ -62,7 +71,7 @@ class Database(AbstractDatabase):
6271 Instanciated using :meth:`~data_diff.connect_to_uri`
6372 """
6473
65- DATETIME_TYPES : Dict [str , type ] = {}
74+ TYPE_CLASSES : Dict [str , type ] = {}
6675 default_schema : str = None
6776
6877 @property
@@ -109,6 +118,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
109118 # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
110119 return math .floor (math .log (2 ** p , 10 ))
111120
121+ def _parse_type_repr (self , type_repr : str ) -> Optional [Type [ColType ]]:
122+ return self .TYPE_CLASSES .get (type_repr )
123+
112124 def _parse_type (
113125 self ,
114126 col_name : str ,
@@ -119,36 +131,35 @@ def _parse_type(
119131 ) -> ColType :
120132 """ """
121133
122- cls = self .DATETIME_TYPES .get (type_repr )
123- if cls :
134+ cls = self ._parse_type_repr (type_repr )
135+ if not cls :
136+ return UnknownColType (type_repr )
137+
138+ if issubclass (cls , TemporalType ):
124139 return cls (
125140 precision = datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION ,
126141 rounds = self .ROUNDS_ON_PREC_LOSS ,
127142 )
128143
129- cls = self .NUMERIC_TYPES .get (type_repr )
130- if cls :
131- if issubclass (cls , Integer ):
132- # Some DBs have a constant numeric_scale, so they don't report it.
133- # We fill in the constant, so we need to ignore it for integers.
134- return cls (precision = 0 )
135-
136- elif issubclass (cls , Decimal ):
137- if numeric_scale is None :
138- raise ValueError (
139- f"{ self .name } : Unexpected numeric_scale is NULL, for column { col_name } of type { type_repr } ."
140- )
141- return cls (precision = numeric_scale )
142-
143- assert issubclass (cls , Float )
144+ elif issubclass (cls , Integer ):
145+ return cls ()
146+
147+ elif issubclass (cls , Decimal ):
148+ if numeric_scale is None :
149+ raise ValueError (
150+ f"{ self .name } : Unexpected numeric_scale is NULL, for column { col_name } of type { type_repr } ."
151+ )
152+ return cls (precision = numeric_scale )
153+
154+ elif issubclass (cls , Float ):
144155 # assert numeric_scale is None
145156 return cls (
146157 precision = self ._convert_db_precision_to_digits (
147158 numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
148159 )
149160 )
150161
151- return UnknownColType ( type_repr )
162+ raise TypeError ( f"Parsing { type_repr } returned an unknown type ' { cls } '." )
152163
153164 def select_table_schema (self , path : DbPath ) -> str :
154165 schema , table = self ._normalize_table_path (path )
0 commit comments