Skip to content

Commit aa4a437

Browse files
authored
Fix: convert NaN and NaT to None for test generation (#2437)
* Fix: convert NaN and NaT to None for test generation * style fixup * Refactor test generation tests to use parametrize * Convert np.nan to None again * Fix test
1 parent 69bc602 commit aa4a437

File tree

3 files changed

+156
-172
lines changed

3 files changed

+156
-172
lines changed

sqlmesh/core/test/definition.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def assert_equal(
163163
actual_types, errors="ignore"
164164
)
165165

166-
actual = actual.replace({None: np.nan})
167-
expected = expected.replace({None: np.nan})
166+
actual = actual.replace({np.nan: None})
167+
expected = expected.replace({np.nan: None})
168168

169169
def _to_hashable(x: t.Any) -> t.Any:
170170
if isinstance(x, (list, np.ndarray)):
@@ -544,9 +544,10 @@ def generate_test(
544544
inputs = {
545545
models[dep]
546546
.name: pandas_timestamp_to_pydatetime(
547-
engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_dataframe)),
547+
engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)),
548548
models[dep].columns_to_types,
549549
)
550+
.replace({np.nan: None})
550551
.to_dict(orient="records")
551552
for dep, query in input_queries.items()
552553
}
@@ -588,10 +589,14 @@ def generate_test(
588589
cte_query = cte_query.with_(prev.alias, prev.this)
589590

590591
cte_output = t.cast(SqlModelTest, test)._execute(cte_query)
591-
ctes[cte.alias] = pandas_timestamp_to_pydatetime(
592-
cte_output.apply(lambda col: col.map(_normalize_dataframe)),
593-
cte_query.named_selects,
594-
).to_dict(orient="records")
592+
ctes[cte.alias] = (
593+
pandas_timestamp_to_pydatetime(
594+
cte_output.apply(lambda col: col.map(_normalize_df_value)),
595+
cte_query.named_selects,
596+
)
597+
.replace({np.nan: None})
598+
.to_dict(orient="records")
599+
)
595600

596601
previous_ctes.append(cte)
597602

@@ -602,9 +607,13 @@ def generate_test(
602607
else:
603608
output = t.cast(PythonModelTest, test)._execute_model()
604609

605-
outputs["query"] = pandas_timestamp_to_pydatetime(
606-
output.apply(lambda col: col.map(_normalize_dataframe)), model.columns_to_types
607-
).to_dict(orient="records")
610+
outputs["query"] = (
611+
pandas_timestamp_to_pydatetime(
612+
output.apply(lambda col: col.map(_normalize_df_value)), model.columns_to_types
613+
)
614+
.replace({np.nan: None})
615+
.to_dict(orient="records")
616+
)
608617

609618
test.tearDown()
610619

@@ -662,14 +671,14 @@ def _raise_error(msg: str, path: Path | None = None) -> None:
662671
raise TestError(msg)
663672

664673

665-
def _normalize_dataframe(value: t.Any) -> t.Any:
674+
def _normalize_df_value(value: t.Any) -> t.Any:
666675
"""Normalize data in a pandas dataframe so ruamel and sqlglot can deal with it."""
667676
if isinstance(value, (list, np.ndarray)):
668-
return [_normalize_dataframe(v) for v in value]
677+
return [_normalize_df_value(v) for v in value]
669678
if isinstance(value, dict):
670679
if "key" in value and "value" in value:
671680
# Maps returned by DuckDB look like: {'key': ['key1', 'key2'], 'value': [10, 20]}
672681
# so we convert to {'key1': 10, 'key2': 20} (TODO: handle more dialects here)
673-
return {k: _normalize_dataframe(v) for k, v in zip(value["key"], value["value"])}
674-
return {k: _normalize_dataframe(v) for k, v in value.items()}
682+
return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])}
683+
return {k: _normalize_df_value(v) for k, v in value.items()}
675684
return value

tests/core/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def test_cloud_composer_scheduler_config(tmp_path_factory):
353353
"'^dev$': dev_catalog\n '[': other_catalog",
354354
{},
355355
"duckdb",
356-
"`\[` is not a valid regular expression.",
356+
"`\\[` is not a valid regular expression.",
357357
),
358358
(
359359
"'^prod$': prod_catalog",

0 commit comments

Comments
 (0)