Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit cda5180

Browse files
committed
Added infrastructure to support PR #20 (TimestampTZ repr)
1 parent e5b34f3 commit cda5180

File tree

6 files changed

+38
-7
lines changed

6 files changed

+38
-7
lines changed

sqeleton/abcs/database_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from runtype import dataclass
77

8-
from ..utils import ArithAlphanumeric, ArithUUID, Self
8+
from ..utils import ArithAlphanumeric, ArithUUID, Self, Unknown
99

1010

1111
DbPath = Tuple[str, ...]
@@ -20,7 +20,7 @@ class ColType:
2020
@dataclass
2121
class PrecisionType(ColType):
2222
precision: int
23-
rounds: bool
23+
rounds: Union[bool, Unknown] = Unknown
2424

2525

2626
class Boolean(ColType):

sqeleton/databases/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
String_VaryingAlphanum,
3131
TemporalType,
3232
UnknownColType,
33+
TimestampTZ,
3334
Text,
3435
DbTime,
3536
DbPath,
@@ -202,6 +203,8 @@ def constant_values(self, rows) -> str:
202203
def type_repr(self, t) -> str:
203204
if isinstance(t, str):
204205
return t
206+
elif isinstance(t, TimestampTZ):
207+
return f"TIMESTAMP({min(t.precision, DEFAULT_DATETIME_PRECISION)})"
205208
return {
206209
int: "INT",
207210
str: "VARCHAR",

sqeleton/databases/postgresql.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def set_timezone_to_utc(self) -> str:
9797
def current_timestamp(self) -> str:
9898
return "current_timestamp"
9999

100+
def type_repr(self, t) -> str:
101+
if isinstance(t, TimestampTZ):
102+
return f"timestamp ({t.precision}) with time zone"
103+
return super().type_repr(t)
104+
100105

101106
class PostgreSQL(ThreadedDatabase):
102107
dialect = PostgresqlDialect()

sqeleton/databases/snowflake.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def set_timezone_to_utc(self) -> str:
139139
def optimizer_hints(self, hints: str) -> str:
140140
raise NotImplementedError("Optimizer hints not yet implemented in snowflake")
141141

142+
def type_repr(self, t) -> str:
143+
if isinstance(t, TimestampTZ):
144+
return f"timestamp_tz({t.precision})"
145+
return super().type_repr(t)
146+
142147

143148
class Snowflake(Database):
144149
dialect = Dialect()

sqeleton/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,20 @@ def match_like(pattern: str, strs: Sequence[str]) -> Iterable[str]:
322322
for s in strs:
323323
if reo.match(s):
324324
yield s
325+
326+
327+
328+
class UnknownMeta(type):
329+
def __instancecheck__(self, instance):
330+
return instance is Unknown
331+
332+
def __repr__(self):
333+
return "Unknown"
334+
335+
336+
class Unknown(metaclass=UnknownMeta):
337+
def __nonzero__(self):
338+
raise TypeError()
339+
340+
def __new__(class_, *args, **kwargs):
341+
raise RuntimeError("Unknown is a singleton")

tests/test_database.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sqeleton.queries import table, current_timestamp, NormalizeAsString
1010
from .common import TEST_MYSQL_CONN_STRING
1111
from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix
12+
from sqeleton.abcs.database_types import TimestampTZ
1213

1314
TEST_DATABASES = {
1415
dbs.MySQL,
@@ -83,22 +84,22 @@ def test_current_timestamp(self):
8384
def test_correct_timezone(self):
8485
name = "tbl_" + random_table_suffix()
8586
db = get_conn(self.db_cls)
86-
tbl = table(db.parse_table_name(name), schema={
87-
"id": int, "created_at": "timestamp_tz(9)", "updated_at": "timestamp_tz(9)"
87+
tbl = table(name, schema={
88+
"id": int, "created_at": TimestampTZ(9), "updated_at": TimestampTZ(9)
8889
})
8990

9091
db.query(tbl.create())
9192

9293
tz = pytz.timezone('Europe/Berlin')
9394

9495
now = datetime.now(tz)
95-
db.query(table(db.parse_table_name(name)).insert_row("1", now, now))
96+
db.query(table(name).insert_row("1", now, now))
9697
db.query(db.dialect.set_timezone_to_utc())
9798

9899
t = db.table(name).query_schema()
99-
t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision, rounds=True)
100+
t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision)
100101

101-
tbl = table(db.parse_table_name(name), schema=t.schema)
102+
tbl = table(name, schema=t.schema)
102103

103104
results = db.query(tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple])
104105

0 commit comments

Comments
 (0)