Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 11e5dca

Browse files
authored
Merge pull request #30 from datafold/bjoernhaeuser-patch-1
PR #20 with added infrastructure (TimestampTZ repr)
2 parents 3815892 + 04018a1 commit 11e5dca

File tree

9 files changed

+95
-14
lines changed

9 files changed

+95
-14
lines changed

sqeleton/abcs/database_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from runtype import dataclass
77

8-
from ..utils import ArithAlphanumeric, ArithUUID, Self
8+
from ..utils import ArithAlphanumeric, ArithUUID, Self, Unknown
99

1010

1111
DbPath = Tuple[str, ...]
@@ -20,7 +20,7 @@ class ColType:
2020
@dataclass
2121
class PrecisionType(ColType):
2222
precision: int
23-
rounds: bool
23+
rounds: Union[bool, Unknown] = Unknown
2424

2525

2626
class Boolean(ColType):

sqeleton/databases/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
String_VaryingAlphanum,
3131
TemporalType,
3232
UnknownColType,
33+
TimestampTZ,
3334
Text,
3435
DbTime,
3536
DbPath,
@@ -202,6 +203,8 @@ def constant_values(self, rows) -> str:
202203
def type_repr(self, t) -> str:
203204
if isinstance(t, str):
204205
return t
206+
elif isinstance(t, TimestampTZ):
207+
return f"TIMESTAMP({min(t.precision, DEFAULT_DATETIME_PRECISION)})"
205208
return {
206209
int: "INT",
207210
str: "VARCHAR",

sqeleton/databases/postgresql.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def set_timezone_to_utc(self) -> str:
9797
def current_timestamp(self) -> str:
9898
return "current_timestamp"
9999

100+
def type_repr(self, t) -> str:
    """Return the SQL type text for column type ``t``.

    ``TimestampTZ`` values are rendered as
    ``timestamp (<precision>) with time zone``; anything else is delegated
    to the parent class's ``type_repr``.
    """
    if not isinstance(t, TimestampTZ):
        return super().type_repr(t)
    return f"timestamp ({t.precision}) with time zone"
104+
100105

101106
class PostgreSQL(ThreadedDatabase):
102107
dialect = PostgresqlDialect()

sqeleton/databases/presto.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ def set_timezone_to_utc(self) -> str:
141141
def current_timestamp(self) -> str:
142142
return "current_timestamp"
143143

144+
def type_repr(self, t) -> str:
    """Return the SQL type text for column type ``t``.

    ``TimestampTZ`` values are rendered as ``timestamp with time zone``
    (no precision is emitted); anything else is delegated to the parent
    class's ``type_repr``.
    """
    if isinstance(t, TimestampTZ):
        # Plain literal: nothing is interpolated, so no f-string prefix.
        return "timestamp with time zone"
    return super().type_repr(t)
148+
144149

145150
class Presto(Database):
146151
dialect = Dialect()

sqeleton/databases/redshift.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import List, Dict
2-
from ..abcs.database_types import Float, TemporalType, FractionalType, DbPath
2+
from ..abcs.database_types import Float, TemporalType, FractionalType, DbPath, TimestampTZ
33
from ..abcs.mixins import AbstractMixin_MD5
44
from .postgresql import (
55
PostgreSQL,
@@ -57,6 +57,11 @@ def concat(self, items: List[str]) -> str:
5757
def is_distinct_from(self, a: str, b: str) -> str:
5858
return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})"
5959

60+
def type_repr(self, t) -> str:
    """Return the SQL type text for column type ``t``.

    ``TimestampTZ`` values are rendered as ``timestamptz`` (no precision is
    emitted); anything else is delegated to the parent class's
    ``type_repr``.
    """
    if isinstance(t, TimestampTZ):
        # Plain literal: nothing is interpolated, so no f-string prefix.
        return "timestamptz"
    return super().type_repr(t)
64+
6065

6166
class Redshift(PostgreSQL):
6267
dialect = Dialect()

sqeleton/databases/snowflake.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def md5_as_int(self, s: str) -> str:
4848
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
4949
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
5050
if coltype.rounds:
51-
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))"
51+
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))"
5252
else:
53-
timestamp = f"cast({value} as timestamp({coltype.precision}))"
53+
timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))"
5454

5555
return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
5656

@@ -139,6 +139,11 @@ def set_timezone_to_utc(self) -> str:
139139
def optimizer_hints(self, hints: str) -> str:
140140
raise NotImplementedError("Optimizer hints not yet implemented in snowflake")
141141

142+
def type_repr(self, t) -> str:
    """Return the SQL type text for column type ``t``.

    ``TimestampTZ`` values are rendered as ``timestamp_tz(<precision>)``;
    anything else is delegated to the parent class's ``type_repr``.
    """
    if not isinstance(t, TimestampTZ):
        return super().type_repr(t)
    return f"timestamp_tz({t.precision})"
146+
142147

143148
class Snowflake(Database):
144149
dialect = Dialect()

sqeleton/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,20 @@ def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]:
322322
for s in strs:
323323
if reo.match(s):
324324
yield s
325+
326+
327+
328+
class UnknownMeta(type):
    """Metaclass that lets the ``Unknown`` class object serve as its own sentinel value."""

    def __instancecheck__(self, instance):
        # ``isinstance(x, Unknown)`` holds only for the ``Unknown`` class
        # object itself, so ``Unknown`` can be used both as a type in
        # annotations and as the value that satisfies that type.
        return instance is Unknown

    def __repr__(self):
        return "Unknown"

    def __bool__(self):
        # The sentinel is a class object, so its truth value is looked up on
        # the metaclass. (Python 3 ignores ``__nonzero__``, and a hook on
        # ``Unknown`` itself would only apply to instances, which cannot
        # exist.) Refuse truth-testing so an unresolved value is never
        # silently treated as True.
        raise TypeError("Unknown has no truth value; test with `is Unknown` instead")


class Unknown(metaclass=UnknownMeta):
    """Sentinel meaning "value not known".

    Use the class object itself as the value; it is never instantiated.

    Raises:
        RuntimeError: on any attempt to instantiate the class.
        TypeError: on any attempt to truth-test the class object
            (``bool(Unknown)``, ``if Unknown: ...``).
    """

    def __new__(class_, *args, **kwargs):
        raise RuntimeError("Unknown is a singleton")

tests/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import logging
88
import subprocess
99

10+
import sqeleton
1011
from parameterized import parameterized_class
1112

1213
from sqeleton import databases as db
1314
from sqeleton import connect
15+
from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue
1416
from sqeleton.queries import table
1517
from sqeleton.databases import Database
1618
from sqeleton.query_utils import drop_table
@@ -83,7 +85,8 @@ def get_conn(cls: type, shared: bool = True) -> Database:
8385
_database_instances[cls] = get_conn(cls, shared=False)
8486
return _database_instances[cls]
8587

86-
return connect(CONN_STRINGS[cls], N_THREADS)
88+
con = sqeleton.connect.load_mixins(AbstractMixin_NormalizeValue)
89+
return con(CONN_STRINGS[cls], N_THREADS)
8790

8891

8992
def _print_used_dbs():

tests/test_database.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from typing import Callable, List
2-
from datetime import datetime
31
import unittest
2+
from datetime import datetime
3+
from typing import Callable, List, Tuple
44

5-
from .common import str_to_checksum, TEST_MYSQL_CONN_STRING
6-
from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix
7-
8-
from sqeleton.queries import table, current_timestamp
5+
import pytz
96

10-
from sqeleton import databases as dbs
117
from sqeleton import connect
12-
8+
from sqeleton import databases as dbs
9+
from sqeleton.queries import table, current_timestamp, NormalizeAsString
10+
from .common import TEST_MYSQL_CONN_STRING
11+
from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix
12+
from sqeleton.abcs.database_types import TimestampTZ
1313

1414
TEST_DATABASES = {
1515
dbs.MySQL,
@@ -81,6 +81,43 @@ def test_current_timestamp(self):
8181
res = db.query(current_timestamp(), datetime)
8282
assert isinstance(res, datetime), (res, type(res))
8383

84+
def test_correct_timezone(self):
    """Check that timezone-aware timestamps are normalized to UTC text.

    Inserts one row whose two ``TimestampTZ`` columns both hold the same
    Europe/Berlin "now", switches the session to UTC, reads both columns
    back through ``NormalizeAsString`` and compares each one with the UTC
    rendering of that instant. The table is dropped even when an assertion
    or query fails.
    """
    name = "tbl_" + random_table_suffix()
    db = get_conn(self.db_cls)
    tbl = table(name, schema={"id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9)})

    db.query(tbl.create())
    try:
        tz = pytz.timezone("Europe/Berlin")

        now = datetime.now(tz)
        if isinstance(db, dbs.Presto):
            ms = now.microsecond // 1000 * 1000  # Presto max precision is 3
            now = now.replace(microsecond=ms)

        db.query(table(name).insert_row(1, now, now))
        db.query(db.dialect.set_timezone_to_utc())

        # Rebuild the table expression from the schema the database reports,
        # so normalization works on the queried column types.
        t = db.table(name).query_schema()
        tbl = table(name, schema=t.schema)

        results = db.query(
            tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]),
            List[Tuple],
        )

        # Columns come back in select order: created_at first, updated_at second.
        created_at = results[0][0]
        updated_at = results[0][1]

        utc = now.astimezone(pytz.UTC)
        expected = utc.strftime("%Y-%m-%d %H:%M:%S.%f")

        self.assertEqual(created_at, expected)
        self.assertEqual(updated_at, expected)
    finally:
        # Always clean up so a failed run does not leave the table behind.
        db.query(tbl.drop())
84121

85122
@test_each_database
86123
class TestThreePartIds(unittest.TestCase):
@@ -104,3 +141,4 @@ def test_three_part_support(self):
104141
d = db.query_table_schema(part.path)
105142
assert len(d) == 1
106143
db.query(part.drop())
144+

0 commit comments

Comments
 (0)