Skip to content

Commit ff1b6fe

Browse files
vchang and georgesittas authored
Fix!: Handle maps during test generation (#2364)
* Fix: Handle maps during test generation * Bump sqlglot to 23.3.0 * Refactor generate_test tests * Fix typo in comment * Remove unused variables * Add more tests * Update sqlmesh/core/test/definition.py Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com> --------- Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
1 parent 58adc10 commit ff1b6fe

File tree

3 files changed

+92
-42
lines changed

3 files changed

+92
-42
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@
4646
"requests",
4747
"rich[jupyter]",
4848
"ruamel.yaml",
49-
"sqlglot[rs]~=23.2.0",
49+
"sqlglot[rs]~=23.3.0",
5050
],
5151
extras_require={
5252
"bigquery": [

sqlmesh/core/test/definition.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -592,4 +592,11 @@ def _normalize_dataframe(value: t.Any) -> t.Any:
592592
"""Normalize data in a pandas dataframe so ruamel and sqlglot can deal with it."""
593593
if isinstance(value, np.ndarray):
594594
return [_normalize_dataframe(v) for v in value]
595+
if isinstance(value, dict):
596+
if "key" in value and "value" in value:
597+
# Maps returned by DuckDB have the following structure: {'key': ['key1', 'key2', 'key3'], 'value': [10, 20, 30]}
598+
# so we convert to {'key1': 10, 'key2': 20, 'key3': 30}
599+
# TODO: handle more dialects here
600+
return {k: _normalize_dataframe(v) for k, v in zip(value["key"], value["value"])}
601+
return {k: _normalize_dataframe(v) for k, v in value.items()}
595602
return value

tests/core/test_test.py

Lines changed: 84 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -921,7 +921,15 @@ def test_successes(sushi_context: Context) -> None:
921921
assert "test_customer_revenue_by_day" in successful_tests
922922

923923

924-
def test_test_generation_with_array(tmp_path: Path) -> None:
924+
def test_test_generation_with_data_structures(tmp_path: Path) -> None:
925+
def create_test(query: str):
926+
context.create_test(
927+
"sqlmesh_example.foo",
928+
input_queries={"sqlmesh_example.bar": query},
929+
overwrite=True,
930+
)
931+
return load_yaml(context.path / c.TESTS / "test_foo.yaml")
932+
925933
init_example_project(tmp_path, dialect="duckdb")
926934

927935
config = Config(
@@ -930,57 +938,92 @@ def test_test_generation_with_array(tmp_path: Path) -> None:
930938
)
931939
foo_sql_file = tmp_path / "models" / "foo.sql"
932940
foo_sql_file.write_text(
933-
"MODEL (name sqlmesh_example.foo); SELECT array_col FROM sqlmesh_example.bar;"
941+
"MODEL (name sqlmesh_example.foo); SELECT col FROM sqlmesh_example.bar;"
934942
)
935943
bar_sql_file = tmp_path / "models" / "bar.sql"
936-
bar_sql_file.write_text(
937-
"MODEL (name sqlmesh_example.bar); SELECT array_col FROM external_table;"
938-
)
944+
bar_sql_file.write_text("MODEL (name sqlmesh_example.bar); SELECT col FROM external_table;")
939945

940946
context = Context(paths=tmp_path, config=config)
941947

942-
input_queries = {"sqlmesh_example.bar": "SELECT ['value1', 'value2'] AS array_col"}
943-
944-
context.create_test(
945-
"sqlmesh_example.foo",
946-
input_queries=input_queries,
947-
overwrite=True,
948-
variables={"start": "2020-01-01", "end": "2024-01-01"},
949-
)
950-
951-
test = load_yaml(context.path / c.TESTS / "test_foo.yaml")
952-
953-
assert len(test) == 1
954-
assert "test_foo" in test
955-
assert "vars" in test["test_foo"]
956-
assert test["test_foo"]["inputs"] == {
957-
"sqlmesh_example.bar": [{"array_col": ["value1", "value2"]}]
958-
}
959-
assert test["test_foo"]["outputs"] == {"query": [{"array_col": ["value1", "value2"]}]}
948+
# Array of strings
949+
test = create_test("SELECT ['value1', 'value2'] AS col")
950+
expected_value: t.Any = [{"col": ["value1", "value2"]}]
951+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
952+
assert test["test_foo"]["outputs"] == {"query": expected_value}
960953

961954
# Array of arrays
962-
input_queries = {
963-
"sqlmesh_example.bar": "SELECT [['value1'], ['value2', 'value3']] AS array_col"
964-
}
955+
test = create_test("SELECT [['value1'], ['value2', 'value3']] AS col")
956+
expected_value = [{"col": [["value1"], ["value2", "value3"]]}]
957+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
958+
assert test["test_foo"]["outputs"] == {"query": expected_value}
959+
960+
# Array of maps
961+
test = create_test("SELECT [MAP {'key': 'value1'}, MAP {'key': 'value2'}] AS col")
962+
expected_value = [{"col": [{"key": "value1"}, {"key": "value2"}]}]
963+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
964+
assert test["test_foo"]["outputs"] == {"query": expected_value}
965+
966+
# Array of structs
967+
test = create_test("SELECT [{'key': 'value1'}, {'key': 'value2'}] AS col")
968+
expected_value = [{"col": [{"key": "value1"}, {"key": "value2"}]}]
969+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
970+
assert test["test_foo"]["outputs"] == {"query": expected_value}
971+
972+
# Map of strings
973+
test = create_test("SELECT MAP {'key1': 'value1', 'key2': 'value2'} AS col")
974+
expected_value = [{"col": {"key1": "value1", "key2": "value2"}}]
975+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
976+
assert test["test_foo"]["outputs"] == {"query": expected_value}
977+
978+
# Struct of strings
979+
test = create_test("SELECT {'key1': 'value1', 'key2': 'value2'} AS col")
980+
expected_value = [{"col": {"key1": "value1", "key2": "value2"}}]
981+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
982+
assert test["test_foo"]["outputs"] == {"query": expected_value}
983+
984+
# Map of arrays
985+
test = create_test("SELECT MAP {'key1': ['value1'], 'key2': ['value2']} AS col")
986+
expected_value = [{"col": {"key1": ["value1"], "key2": ["value2"]}}]
987+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
988+
assert test["test_foo"]["outputs"] == {"query": expected_value}
989+
990+
# Struct of arrays
991+
test = create_test("SELECT {'key1': ['value1'], 'key2': ['value2']} AS col")
992+
expected_value = [{"col": {"key1": ["value1"], "key2": ["value2"]}}]
993+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
994+
assert test["test_foo"]["outputs"] == {"query": expected_value}
995+
996+
# Map of maps
997+
test = create_test(
998+
"SELECT MAP {'key1': MAP {'subkey1': 'value1'}, 'key2': MAP {'subkey2': 'value2'}} AS col"
999+
)
1000+
expected_value = [{"col": {"key1": {"subkey1": "value1"}, "key2": {"subkey2": "value2"}}}]
1001+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
1002+
assert test["test_foo"]["outputs"] == {"query": expected_value}
9651003

966-
context.create_test(
967-
"sqlmesh_example.foo",
968-
input_queries=input_queries,
969-
overwrite=True,
970-
variables={"start": "2020-01-01", "end": "2024-01-01"},
1004+
# Map of structs
1005+
test = create_test(
1006+
"SELECT MAP {'key1': {'subkey': 'value1'}, 'key2': {'subkey': 'value2'}} AS col"
9711007
)
1008+
expected_value = [{"col": {"key1": {"subkey": "value1"}, "key2": {"subkey": "value2"}}}]
1009+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
1010+
assert test["test_foo"]["outputs"] == {"query": expected_value}
9721011

973-
test = load_yaml(context.path / c.TESTS / "test_foo.yaml")
1012+
# Struct of structs
1013+
test = create_test(
1014+
"SELECT {'key1': {'subkey1': 'value1'}, 'key2': {'subkey2': 'value2'}} AS col"
1015+
)
1016+
expected_value = [{"col": {"key1": {"subkey1": "value1"}, "key2": {"subkey2": "value2"}}}]
1017+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
1018+
assert test["test_foo"]["outputs"] == {"query": expected_value}
9741019

975-
assert len(test) == 1
976-
assert "test_foo" in test
977-
assert "vars" in test["test_foo"]
978-
assert test["test_foo"]["inputs"] == {
979-
"sqlmesh_example.bar": [{"array_col": [["value1"], ["value2", "value3"]]}]
980-
}
981-
assert test["test_foo"]["outputs"] == {
982-
"query": [{"array_col": [["value1"], ["value2", "value3"]]}]
983-
}
1020+
# Struct of maps
1021+
test = create_test(
1022+
"SELECT {'key1': MAP {'subkey1': 'value1'}, 'key2': MAP {'subkey2': 'value2'}} AS col"
1023+
)
1024+
expected_value = [{"col": {"key1": {"subkey1": "value1"}, "key2": {"subkey2": "value2"}}}]
1025+
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected_value}
1026+
assert test["test_foo"]["outputs"] == {"query": expected_value}
9841027

9851028

9861029
def test_test_generation_with_timestamp(tmp_path: Path) -> None:

0 commit comments

Comments (0)