1- from typing import Callable , List
2- from datetime import datetime
31import unittest
2+ from datetime import datetime
3+ from typing import Callable , List , Tuple
44
5- from .common import str_to_checksum , TEST_MYSQL_CONN_STRING
6- from .common import str_to_checksum , test_each_database_in_list , get_conn , random_table_suffix
7-
8- from sqeleton .queries import table , current_timestamp
5+ import pytz
96
10- from sqeleton import databases as dbs
117from sqeleton import connect
12-
8+ from sqeleton import databases as dbs
9+ from sqeleton .queries import table , current_timestamp , NormalizeAsString
10+ from .common import TEST_MYSQL_CONN_STRING
11+ from .common import str_to_checksum , test_each_database_in_list , get_conn , random_table_suffix
12+ from sqeleton .abcs .database_types import TimestampTZ
1313
1414TEST_DATABASES = {
1515 dbs .MySQL ,
@@ -81,6 +81,43 @@ def test_current_timestamp(self):
8181 res = db .query (current_timestamp (), datetime )
8282 assert isinstance (res , datetime ), (res , type (res ))
8383
84+ def test_correct_timezone (self ):
85+ name = "tbl_" + random_table_suffix ()
86+ db = get_conn (self .db_cls )
87+ tbl = table (name , schema = {
88+ "id" : int , "created_at" : TimestampTZ (9 ), "updated_at" : TimestampTZ (9 )
89+ })
90+
91+ db .query (tbl .create ())
92+
93+ tz = pytz .timezone ('Europe/Berlin' )
94+
95+ now = datetime .now (tz )
96+ if isinstance (db , dbs .Presto ):
97+ ms = now .microsecond // 1000 * 1000 # Presto max precision is 3
98+ now = now .replace (microsecond = ms )
99+
100+ db .query (table (name ).insert_row (1 , now , now ))
101+ db .query (db .dialect .set_timezone_to_utc ())
102+
103+ t = db .table (name ).query_schema ()
104+ t .schema ["created_at" ] = t .schema ["created_at" ].replace (precision = t .schema ["created_at" ].precision )
105+
106+ tbl = table (name , schema = t .schema )
107+
108+ results = db .query (tbl .select (NormalizeAsString (tbl [c ]) for c in ["created_at" , "updated_at" ]), List [Tuple ])
109+
110+ created_at = results [0 ][1 ]
111+ updated_at = results [0 ][1 ]
112+
113+ utc = now .astimezone (pytz .UTC )
114+ expected = utc .__format__ ("%Y-%m-%d %H:%M:%S.%f" )
115+
116+
117+ self .assertEqual (created_at , expected )
118+ self .assertEqual (updated_at , expected )
119+
120+ db .query (tbl .drop ())
84121
85122@test_each_database
86123class TestThreePartIds (unittest .TestCase ):
@@ -104,3 +141,4 @@ def test_three_part_support(self):
104141 d = db .query_table_schema (part .path )
105142 assert len (d ) == 1
106143 db .query (part .drop ())
144+
0 commit comments