1- import re
2-
31from .database_types import *
4- from .base import Database , import_helper
5- from .base import (
6- MD5_HEXDIGITS ,
7- CHECKSUM_HEXDIGITS ,
8- TIMESTAMP_PRECISION_POS ,
9- DEFAULT_DATETIME_PRECISION ,
10- )
2+ from .presto import Presto
3+ from .base import import_helper
4+ from .base import TIMESTAMP_PRECISION_POS
115
126
137@import_helper ("trino" )
@@ -17,49 +11,12 @@ def import_trino():
1711 return trino
1812
1913
20- class Trino (Database ):
21- default_schema = "public"
22- TYPE_CLASSES = {
23- # Timestamps
24- "timestamp with time zone" : TimestampTZ ,
25- "timestamp without time zone" : Timestamp ,
26- "timestamp" : Timestamp ,
27- # Numbers
28- "integer" : Integer ,
29- "bigint" : Integer ,
30- "real" : Float ,
31- "double" : Float ,
32- # Text
33- "varchar" : Text ,
34- }
35- ROUNDS_ON_PREC_LOSS = True
36-
14+ class Trino (Presto ):
3715 def __init__ (self , ** kw ):
3816 trino = import_trino ()
3917
4018 self ._conn = trino .dbapi .connect (** kw )
4119
42- def quote (self , s : str ):
43- return f'"{ s } "'
44-
45- def md5_to_int (self , s : str ) -> str :
46- return f"cast(from_base(substr(to_hex(md5(to_utf8({ s } ))), { 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS } ), 16) as decimal(38, 0))"
47-
48- def to_string (self , s : str ):
49- return f"cast({ s } as varchar)"
50-
51- def _query (self , sql_code : str ) -> list :
52- """Uses the standard SQL cursor interface"""
53- c = self ._conn .cursor ()
54- c .execute (sql_code )
55- if sql_code .lower ().startswith ("select" ):
56- return c .fetchall ()
57- if re .match (r"(insert|create|truncate|drop)" , sql_code , re .IGNORECASE ):
58- return c .fetchone ()
59-
60- def close (self ):
61- self ._conn .close ()
62-
6320 def normalize_timestamp (self , value : str , coltype : TemporalType ) -> str :
6421 if coltype .rounds :
6522 s = f"date_format(cast({ value } as timestamp({ coltype .precision } )), '%Y-%m-%d %H:%i:%S.%f')"
@@ -70,52 +27,5 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
7027 f"RPAD(RPAD({ s } , { TIMESTAMP_PRECISION_POS + coltype .precision } , '.'), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
7128 )
7229
73- def normalize_number (self , value : str , coltype : FractionalType ) -> str :
74- return self .to_string (f"cast({ value } as decimal(38,{ coltype .precision } ))" )
75-
76- def select_table_schema (self , path : DbPath ) -> str :
77- schema , table = self ._normalize_table_path (path )
78-
79- return (
80- f"SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision FROM INFORMATION_SCHEMA.COLUMNS "
81- f"WHERE table_name = '{ table } ' AND table_schema = '{ schema } '"
82- )
83-
84- def _parse_type (
85- self ,
86- table_path : DbPath ,
87- col_name : str ,
88- type_repr : str ,
89- datetime_precision : int = None ,
90- numeric_precision : int = None ,
91- ) -> ColType :
92- timestamp_regexps = {
93- r"timestamp\((\d)\)" : Timestamp ,
94- r"timestamp\((\d)\) with time zone" : TimestampTZ ,
95- }
96- for regexp , t_cls in timestamp_regexps .items ():
97- m = re .match (regexp + "$" , type_repr )
98- if m :
99- datetime_precision = int (m .group (1 ))
100- return t_cls (
101- precision = datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION ,
102- rounds = self .ROUNDS_ON_PREC_LOSS ,
103- )
104-
105- number_regexps = {r"decimal\((\d+),(\d+)\)" : Decimal }
106- for regexp , n_cls in number_regexps .items ():
107- m = re .match (regexp + "$" , type_repr )
108- if m :
109- prec , scale = map (int , m .groups ())
110- return n_cls (scale )
111-
112- string_regexps = {r"varchar\((\d+)\)" : Text , r"char\((\d+)\)" : Text }
113- for regexp , n_cls in string_regexps .items ():
114- m = re .match (regexp + "$" , type_repr )
115- if m :
116- return n_cls ()
117-
118- return super ()._parse_type (table_path , col_name , type_repr , datetime_precision , numeric_precision )
119-
12030 def normalize_uuid (self , value : str , coltype : ColType_UUID ) -> str :
12131 return f"TRIM({ value } )"
0 commit comments