@@ -163,8 +163,8 @@ def assert_equal(
             actual_types, errors="ignore"
         )

-        actual = actual.replace({None: np.nan})
-        expected = expected.replace({None: np.nan})
+        actual = actual.replace({np.nan: None})
+        expected = expected.replace({np.nan: None})

         def _to_hashable(x: t.Any) -> t.Any:
             if isinstance(x, (list, np.ndarray)):
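Illustrative note (not part of the patch, sample data invented): flipping the replace direction means assert_equal normalizes missing values to plain Python None before comparing, instead of mapping None to NaN. Since NaN compares unequal to both None and itself, the old direction let NaN leak into record-level comparisons:

import numpy as np
import pandas as pd

actual = pd.DataFrame({"value": [1.0, np.nan]})

# Without normalization, float('nan') survives into the row dicts.
print(actual.to_dict(orient="records"))
# [{'value': 1.0}, {'value': nan}]

# With the new NaN -> None direction, missing values compare as None.
print(actual.replace({np.nan: None}).to_dict(orient="records"))
# [{'value': 1.0}, {'value': None}]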
@@ -544,9 +544,10 @@ def generate_test(
     inputs = {
         models[dep]
         .name: pandas_timestamp_to_pydatetime(
-            engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_dataframe)),
+            engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)),
             models[dep].columns_to_types,
         )
+        .replace({np.nan: None})
         .to_dict(orient="records")
         for dep, query in input_queries.items()
     }
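Illustrative note (column names and values made up): the added .replace({np.nan: None}) before .to_dict(orient="records") keeps NaN out of the generated fixture rows, so missing cells end up as None, which should serialize as a YAML null rather than a float NaN:

import numpy as np
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "name": ["a", np.nan]})

print(df.to_dict(orient="records"))
# [{'id': 1, 'name': 'a'}, {'id': 2, 'name': nan}]   <- NaN would leak into the fixture

print(df.replace({np.nan: None}).to_dict(orient="records"))
# [{'id': 1, 'name': 'a'}, {'id': 2, 'name': None}]  <- None serializes cleanly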
@@ -588,10 +589,14 @@ def generate_test(
                 cte_query = cte_query.with_(prev.alias, prev.this)

             cte_output = t.cast(SqlModelTest, test)._execute(cte_query)
-            ctes[cte.alias] = pandas_timestamp_to_pydatetime(
-                cte_output.apply(lambda col: col.map(_normalize_dataframe)),
-                cte_query.named_selects,
-            ).to_dict(orient="records")
+            ctes[cte.alias] = (
+                pandas_timestamp_to_pydatetime(
+                    cte_output.apply(lambda col: col.map(_normalize_df_value)),
+                    cte_query.named_selects,
+                )
+                .replace({np.nan: None})
+                .to_dict(orient="records")
+            )

             previous_ctes.append(cte)

@@ -602,9 +607,13 @@ def generate_test(
     else:
         output = t.cast(PythonModelTest, test)._execute_model()

-    outputs["query"] = pandas_timestamp_to_pydatetime(
-        output.apply(lambda col: col.map(_normalize_dataframe)), model.columns_to_types
-    ).to_dict(orient="records")
+    outputs["query"] = (
+        pandas_timestamp_to_pydatetime(
+            output.apply(lambda col: col.map(_normalize_df_value)), model.columns_to_types
+        )
+        .replace({np.nan: None})
+        .to_dict(orient="records")
+    )

     test.tearDown()

@@ -662,14 +671,14 @@ def _raise_error(msg: str, path: Path | None = None) -> None:
     raise TestError(msg)


-def _normalize_dataframe(value: t.Any) -> t.Any:
+def _normalize_df_value(value: t.Any) -> t.Any:
     """Normalize data in a pandas dataframe so ruamel and sqlglot can deal with it."""
     if isinstance(value, (list, np.ndarray)):
-        return [_normalize_dataframe(v) for v in value]
+        return [_normalize_df_value(v) for v in value]
     if isinstance(value, dict):
         if "key" in value and "value" in value:
             # Maps returned by DuckDB look like: {'key': ['key1', 'key2'], 'value': [10, 20]}
             # so we convert to {'key1': 10, 'key2': 20} (TODO: handle more dialects here)
-            return {k: _normalize_dataframe(v) for k, v in zip(value["key"], value["value"])}
-        return {k: _normalize_dataframe(v) for k, v in value.items()}
+            return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])}
+        return {k: _normalize_df_value(v) for k, v in value.items()}
     return value
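Rough usage sketch for the renamed helper (values invented; assumes _normalize_df_value from the hunk above is in scope): it recurses through lists, arrays, and dicts, and flattens DuckDB-style MAP structs into plain dicts:

# A DuckDB MAP comes back as parallel key/value lists and is flattened.
print(_normalize_df_value({"key": ["key1", "key2"], "value": [10, 20]}))
# {'key1': 10, 'key2': 20}

# Nested containers are normalized recursively as well.
print(_normalize_df_value([{"key": ["a"], "value": [1]}, [1, 2]]))
# [{'a': 1}, [1, 2]]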