Skip to content

Commit 75be063

Browse files
zxqfd555 authored and Manul from Pathway committed
improve static mode parsing for certain types in postgres input (#9969)
GitOrigin-RevId: b923ecece38ba91c60c67826b532a0f722587977
1 parent c5b6624 commit 75be063

3 files changed

Lines changed: 300 additions & 70 deletions

File tree

integration_tests/db_connectors/test_postgres.py

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,11 +436,11 @@ def on_change(
436436
2025, 4, 23, 10, 13, tzinfo=datetime.timezone.utc
437437
)
438438
assert row["i"].value == SimpleObject("test")
439-
assert row["j"] == pd.Timedelta("4 days 2 seconds 123 us").value // 1_000
439+
assert row["j"] == pd.Timedelta("4 days 2 seconds 123 us")
440440
assert row["k"] == ("abc", "def", "ghi")
441441
assert row["l"] == (
442-
pd.Timedelta("4 days 2 seconds 123 us").value // 1_000,
443-
pd.Timedelta("1 days 2 seconds 3 us").value // 1_000,
442+
pd.Timedelta("4 days 2 seconds 123 us"),
443+
pd.Timedelta("1 days 2 seconds 3 us"),
444444
)
445445
assert row["m"] == (("a", "b"), ("c", "d"))
446446

@@ -2117,6 +2117,95 @@ def stream_target():
21172117
assert oor_id not in ids_out, f"Out-of-range row {oor_id} must be skipped"
21182118

21192119

2120+
@pytest.mark.parametrize("pg_type", ["BIGINT", "INTEGER"])
2121+
def test_static_int_read_as_duration(tmp_path, postgres, pg_type):
2122+
"""Integral columns declared as pw.Duration must be returned as Duration values
2123+
(nanoseconds in jsonlines output) in static read mode."""
2124+
2125+
class InputSchema(pw.Schema):
2126+
id: int = pw.column_definition(primary_key=True)
2127+
duration_col: pw.Duration
2128+
2129+
output_path = tmp_path / "output.jsonl"
2130+
table_name = postgres.random_table_name()
2131+
2132+
postgres.execute_sql(
2133+
f"""
2134+
CREATE TABLE {table_name} (
2135+
id {pg_type} PRIMARY KEY,
2136+
duration_col {pg_type} NOT NULL
2137+
);
2138+
"""
2139+
)
2140+
# 10 seconds expressed in microseconds, as the output connector writes pw.Duration
2141+
microseconds = 10_000_000
2142+
postgres.execute_sql(
2143+
f"INSERT INTO {table_name} (id, duration_col) VALUES (1, {microseconds});"
2144+
)
2145+
2146+
table = pw.io.postgres.read(
2147+
postgres_settings=POSTGRES_SETTINGS,
2148+
table_name=table_name,
2149+
schema=InputSchema,
2150+
mode="static",
2151+
)
2152+
pw.io.jsonlines.write(table, output_path)
2153+
run()
2154+
2155+
rows = []
2156+
with open(output_path) as f:
2157+
for line in f:
2158+
rows.append(json.loads(line))
2159+
2160+
assert len(rows) == 1
2161+
# pw.Duration is serialized to jsonlines as nanoseconds
2162+
assert rows[0]["duration_col"] == microseconds * 1_000
2163+
2164+
2165+
def test_static_bigint_array_read_as_duration_list(tmp_path, postgres):
2166+
"""BIGINT[] columns declared as list[pw.Duration] must be returned as lists of
2167+
Duration values (each element in nanoseconds in jsonlines) in static read mode."""
2168+
2169+
class InputSchema(pw.Schema):
2170+
id: int = pw.column_definition(primary_key=True)
2171+
durations: list[pw.Duration]
2172+
2173+
output_path = tmp_path / "output.jsonl"
2174+
table_name = postgres.random_table_name()
2175+
2176+
postgres.execute_sql(
2177+
f"""
2178+
CREATE TABLE {table_name} (
2179+
id BIGINT PRIMARY KEY,
2180+
durations BIGINT[] NOT NULL
2181+
);
2182+
"""
2183+
)
2184+
# Three durations in microseconds: 1s, 2s, 3s
2185+
postgres.execute_sql(
2186+
f"INSERT INTO {table_name} (id, durations)"
2187+
f" VALUES (1, ARRAY[1000000, 2000000, 3000000]::BIGINT[]);"
2188+
)
2189+
2190+
table = pw.io.postgres.read(
2191+
postgres_settings=POSTGRES_SETTINGS,
2192+
table_name=table_name,
2193+
schema=InputSchema,
2194+
mode="static",
2195+
)
2196+
pw.io.jsonlines.write(table, output_path)
2197+
run()
2198+
2199+
rows = []
2200+
with open(output_path) as f:
2201+
for line in f:
2202+
rows.append(json.loads(line))
2203+
2204+
assert len(rows) == 1
2205+
# Each Duration element serialized as nanoseconds
2206+
assert rows[0]["durations"] == [1_000_000_000, 2_000_000_000, 3_000_000_000]
2207+
2208+
21202209
def test_no_publication(tmp_path, postgres):
21212210
class InputSchema(pw.Schema):
21222211
value: str

integration_tests/db_connectors/test_postgres_parsing.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, get_args
1+
from typing import Any, Union, get_args, get_origin
22

33
import numpy as np
44
import pandas as pd
@@ -155,11 +155,15 @@ def streaming_target():
155155
output_table_name,
156156
init_mode="create_if_not_exists",
157157
)
158+
observer, type_errors = _make_type_check_observer(ItemType)
159+
pw.io.python.write(table, observer)
158160
wait_result_with_checker(
159161
PostgresRowCountChecker(postgres, output_table_name, 2 * len(items)), 60
160162
)
161163
streaming_thread.join()
162164

165+
assert not type_errors, "\n".join(type_errors)
166+
163167
output_rows = postgres.get_table_contents(
164168
output_table_name, InputSchemaWithPkey.column_names()
165169
)
@@ -176,6 +180,60 @@ def streaming_target():
176180
_compare_input_and_output(ItemType, input_rows, output_rows)
177181

178182

183+
def _get_expected_python_type(ItemType: type) -> type | tuple:
184+
import types as builtin_types
185+
186+
origin = get_origin(ItemType)
187+
188+
if origin is Union or (
189+
hasattr(builtin_types, "UnionType") and origin is builtin_types.UnionType
190+
):
191+
args = get_args(ItemType)
192+
non_none_args = [a for a in args if a is not type(None)]
193+
inner = _get_expected_python_type(non_none_args[0])
194+
if isinstance(inner, tuple):
195+
return inner + (type(None),)
196+
return (inner, type(None))
197+
198+
if origin is not None:
199+
return origin
200+
201+
if ItemType is pw.Duration:
202+
return pd.Timedelta
203+
if ItemType in (pw.DateTimeNaive, pw.DateTimeUtc):
204+
return pd.Timestamp
205+
206+
return ItemType
207+
208+
209+
def _make_type_check_observer(
210+
ItemType: type,
211+
) -> tuple[pw.io.python.ConnectorObserver, list[str]]:
212+
type_errors: list[str] = []
213+
expected_type = _get_expected_python_type(ItemType)
214+
215+
class TypeCheckObserver(pw.io.python.ConnectorObserver):
216+
def on_change(self, key, row, time, is_addition):
217+
if is_addition:
218+
value = row["item"]
219+
if not isinstance(value, expected_type):
220+
# tuple is acceptable when the schema type is list
221+
if isinstance(value, tuple) and (
222+
expected_type is list
223+
or (isinstance(expected_type, tuple) and list in expected_type)
224+
):
225+
return
226+
type_errors.append(
227+
f"item value {value!r} has type {type(value)}, "
228+
f"expected {expected_type}"
229+
)
230+
231+
def on_end(self):
232+
pass
233+
234+
return TypeCheckObserver(), type_errors
235+
236+
179237
def _test_postgres_static(
180238
postgres, ItemType: type, items: list[Any], create_table: Any = _create_table
181239
) -> str:
@@ -205,8 +263,13 @@ def _test_postgres_static(
205263
pw.io.postgres.write(
206264
table, POSTGRES_SETTINGS, output_table_name, init_mode="create_if_not_exists"
207265
)
266+
267+
observer, type_errors = _make_type_check_observer(ItemType)
268+
pw.io.python.write(table, observer)
208269
run()
209270

271+
assert not type_errors, "\n".join(type_errors)
272+
210273
output_rows = postgres.get_table_contents(
211274
output_table_name, InputSchemaWithPkey.column_names()
212275
)

0 commit comments

Comments (0)