Skip to content

Commit fea8910

Browse files
treyspizeigerman
authored andcommitted
Fix: do not include a schema def in CTAS unless all column types are known (#2063)
* Do not include a schema def in CTAS unless all col types known * Make columns_to_types all known helper * Make tests cover all cases
1 parent ab3c677 commit fea8910

File tree

8 files changed

+81
-25
lines changed

8 files changed

+81
-25
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
)
4040
from sqlmesh.core.model.kind import TimeColumn
4141
from sqlmesh.core.schema_diff import SchemaDiffer
42-
from sqlmesh.utils import double_escape, random_id
42+
from sqlmesh.utils import columns_to_types_all_known, double_escape, random_id
4343
from sqlmesh.utils.connection_pool import create_connection_pool
4444
from sqlmesh.utils.date import TimeLike, make_inclusive, to_ts
4545
from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError
@@ -604,13 +604,14 @@ def _create_table_from_source_queries(
604604
# types, and for evaluation methods like `LogicalReplaceQueryMixin.replace_query()`
605605
# calls and SCD Type 2 model calls.
606606
schema = None
607+
columns_to_types_known = columns_to_types and columns_to_types_all_known(columns_to_types)
607608
if (
608609
column_descriptions
609-
and columns_to_types
610+
and columns_to_types_known
610611
and self.COMMENT_CREATION_TABLE.is_in_schema_def_ctas
611612
and self.comments_enabled
612613
):
613-
schema = self._build_schema_exp(table, columns_to_types, column_descriptions)
614+
schema = self._build_schema_exp(table, columns_to_types, column_descriptions) # type: ignore
614615

615616
with self.transaction(condition=len(source_queries) > 1):
616617
for i, source_query in enumerate(source_queries):

sqlmesh/core/model/definition.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from sqlmesh.core.model.meta import ModelMeta
3434
from sqlmesh.core.model.seed import CsvSeedReader, Seed, create_seed
3535
from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer
36-
from sqlmesh.utils import str_to_bool
36+
from sqlmesh.utils import columns_to_types_all_known, str_to_bool
3737
from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime, to_ds, to_ts
3838
from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error
3939
from sqlmesh.utils.hashing import hash_data
@@ -575,10 +575,7 @@ def annotated(self) -> bool:
575575
}
576576
if not columns_to_types:
577577
return False
578-
return all(
579-
not column_type.is_type(exp.DataType.Type.UNKNOWN, exp.DataType.Type.NULL)
580-
for column_type in columns_to_types.values()
581-
)
578+
return columns_to_types_all_known(columns_to_types)
582579

583580
@property
584581
def sorted_python_env(self) -> t.List[t.Tuple[str, Executable]]:

sqlmesh/utils/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from functools import lru_cache, reduce, wraps
2020
from pathlib import Path
2121

22+
from sqlglot import exp
2223
from sqlglot.dialects.dialect import Dialects
2324

2425
logger = logging.getLogger(__name__)
@@ -305,3 +306,13 @@ def groupby(
305306
for item in items:
306307
grouped[func(item)].append(item)
307308
return grouped
309+
310+
311+
def columns_to_types_all_known(columns_to_types: t.Dict[str, exp.DataType]) -> bool:
312+
"""
313+
Checks that all column types are known and not NULL.
314+
"""
315+
return all(
316+
not column_type.is_type(exp.DataType.Type.UNKNOWN, exp.DataType.Type.NULL)
317+
for column_type in columns_to_types.values()
318+
)

tests/core/engine_adapter/test_base.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,15 +400,24 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture)
400400

401401
adapter.create_table(
402402
"test_table",
403-
{"a": "int", "b": "int"},
403+
{"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")},
404404
table_description="test description",
405405
column_descriptions={"a": "a description"},
406406
)
407407

408408
adapter.ctas(
409409
"test_table",
410410
parse_one("SELECT a, b FROM source_table"),
411-
{"a": "int", "b": "int"},
411+
{"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")},
412+
table_description="test description",
413+
column_descriptions={"a": "a description"},
414+
)
415+
416+
# CTAS call should not include schema definition if UNKNOWN data type present
417+
adapter.ctas(
418+
"test_table",
419+
parse_one("SELECT a, b FROM source_table"),
420+
{"a": exp.DataType.build("UNKNOWN"), "b": exp.DataType.build("INT")},
412421
table_description="test description",
413422
column_descriptions={"a": "a description"},
414423
)
@@ -431,8 +440,10 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture)
431440

432441
sql_calls = to_sql_calls(adapter)
433442
assert sql_calls == [
434-
"""CREATE TABLE IF NOT EXISTS "test_table" ("a" int COMMENT 'a description', "b" int) COMMENT='test description'""",
435-
"""CREATE TABLE IF NOT EXISTS "test_table" ("a" int COMMENT 'a description', "b" int) COMMENT='test description' AS SELECT "a", "b" FROM "source_table\"""",
443+
"""CREATE TABLE IF NOT EXISTS "test_table" ("a" INT COMMENT 'a description', "b" INT) COMMENT='test description'""",
444+
"""CREATE TABLE IF NOT EXISTS "test_table" ("a" INT COMMENT 'a description', "b" INT) COMMENT='test description' AS SELECT "a", "b" FROM "source_table\"""",
445+
"""CREATE TABLE IF NOT EXISTS "test_table" COMMENT='test description' AS SELECT "a", "b" FROM "source_table\"""",
446+
"""COMMENT ON COLUMN "test_table"."a" IS 'a description'""",
436447
"""CREATE OR REPLACE VIEW "test_view" COMMENT='test description' AS SELECT "a", "b" FROM "source_table\"""",
437448
"""COMMENT ON TABLE "test_table" IS 'test description'""",
438449
"""COMMENT ON COLUMN "test_table"."a" IS 'a description'""",
@@ -1334,7 +1345,9 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable):
13341345

13351346
def test_replace_query(make_mocked_engine_adapter: t.Callable):
13361347
adapter = make_mocked_engine_adapter(EngineAdapter)
1337-
adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"})
1348+
adapter.replace_query(
1349+
"test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("INT")}
1350+
)
13381351

13391352
# TODO: Shouldn't we enforce that `a` is casted to an int?
13401353
assert to_sql_calls(adapter) == [
@@ -1347,7 +1360,9 @@ def test_replace_query_pandas(make_mocked_engine_adapter: t.Callable):
13471360
adapter.DEFAULT_BATCH_SIZE = 1
13481361

13491362
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
1350-
adapter.replace_query("test_table", df, {"a": "int", "b": "int"})
1363+
adapter.replace_query(
1364+
"test_table", df, {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}
1365+
)
13511366

13521367
assert to_sql_calls(adapter) == [
13531368
'CREATE OR REPLACE TABLE "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (VALUES (1, 4)) AS "t"("a", "b")',

tests/core/engine_adapter/test_bigquery.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def test_replace_query(make_mocked_engine_adapter: t.Callable, mocker: MockerFix
218218
execute_mock = mocker.patch(
219219
"sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.execute"
220220
)
221-
adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"})
221+
adapter.replace_query(
222+
"test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("INT")}
223+
)
222224

223225
sql_calls = _to_sql_calls(execute_mock)
224226
assert sql_calls == ["CREATE OR REPLACE TABLE `test_table` AS SELECT `a` FROM `tbl`"]
@@ -591,15 +593,15 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture)
591593

592594
adapter.create_table(
593595
"test_table",
594-
{"a": "int", "b": "int"},
596+
{"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")},
595597
table_description="test description",
596598
column_descriptions={"a": "a description"},
597599
)
598600

599601
adapter.ctas(
600602
"test_table",
601603
parse_one("SELECT a, b FROM source_table"),
602-
{"a": "int", "b": "int"},
604+
{"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")},
603605
table_description="test description",
604606
column_descriptions={"a": "a description"},
605607
)
@@ -617,8 +619,8 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture)
617619

618620
sql_calls = _to_sql_calls(execute_mock)
619621
assert sql_calls == [
620-
"CREATE TABLE IF NOT EXISTS `test_table` (`a` int OPTIONS (description='a description'), `b` int) OPTIONS (description='test description')",
621-
"CREATE TABLE IF NOT EXISTS `test_table` (`a` int OPTIONS (description='a description'), `b` int) OPTIONS (description='test description') AS SELECT `a`, `b` FROM `source_table`",
622+
"CREATE TABLE IF NOT EXISTS `test_table` (`a` INT64 OPTIONS (description='a description'), `b` INT64) OPTIONS (description='test description')",
623+
"CREATE TABLE IF NOT EXISTS `test_table` (`a` INT64 OPTIONS (description='a description'), `b` INT64) OPTIONS (description='test description') AS SELECT `a`, `b` FROM `source_table`",
622624
"CREATE OR REPLACE VIEW `test_table` OPTIONS (description='test description') AS SELECT `a`, `b` FROM `source_table`",
623625
"ALTER TABLE `test_table` SET OPTIONS(description = 'test description')",
624626
]

tests/core/engine_adapter/test_databricks.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
import pytest
66
from pytest_mock import MockFixture
7-
from sqlglot import parse_one
7+
from sqlglot import exp, parse_one
88

99
from sqlmesh.core.engine_adapter import DatabricksEngineAdapter
1010
from tests.core.engine_adapter import to_sql_calls
@@ -18,7 +18,9 @@ def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapte
1818
return_value=False,
1919
)
2020
adapter = make_mocked_engine_adapter(DatabricksEngineAdapter)
21-
adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"})
21+
adapter.replace_query(
22+
"test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("INT")}
23+
)
2224

2325
assert to_sql_calls(adapter) == [
2426
"CREATE TABLE IF NOT EXISTS `test_table` AS SELECT `a` FROM `tbl`",
@@ -47,7 +49,9 @@ def test_replace_query_pandas_not_exists(
4749
)
4850
adapter = make_mocked_engine_adapter(DatabricksEngineAdapter)
4951
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
50-
adapter.replace_query("test_table", df, {"a": "int", "b": "int"})
52+
adapter.replace_query(
53+
"test_table", df, {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}
54+
)
5155

5256
assert to_sql_calls(adapter) == [
5357
"CREATE TABLE IF NOT EXISTS `test_table` AS SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM VALUES (1, 4), (2, 5), (3, 6) AS `t`(`a`, `b`)",

tests/core/engine_adapter/test_spark.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def test_replace_query_not_exists(mocker: MockerFixture, make_mocked_engine_adap
171171
return_value=False,
172172
)
173173
adapter = make_mocked_engine_adapter(SparkEngineAdapter)
174-
adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"})
174+
adapter.replace_query(
175+
"test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("INT")}
176+
)
175177

176178
assert to_sql_calls(adapter) == [
177179
"CREATE TABLE IF NOT EXISTS `test_table` AS SELECT `a` FROM `tbl`",
@@ -876,13 +878,13 @@ def test_replace_query_with_wap_self_reference(
876878
adapter.replace_query(
877879
"catalog.schema.table.branch_wap_12345",
878880
parse_one("SELECT 1 as a FROM catalog.schema.table.branch_wap_12345"),
879-
columns_to_types={"a": "int"},
881+
columns_to_types={"a": exp.DataType.build("INT")},
880882
storage_format="ICEBERG",
881883
)
882884

883885
sql_calls = to_sql_calls(adapter)
884886
assert sql_calls == [
885-
"CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`table` (`a` int)",
887+
"CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`table` (`a` INT)",
886888
"CREATE SCHEMA IF NOT EXISTS `schema`",
887889
"CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh` USING ICEBERG AS SELECT `a` FROM `catalog`.`schema`.`table`.`branch_wap_12345`",
888890
"INSERT OVERWRITE TABLE `catalog`.`schema`.`table`.`branch_wap_12345` (`a`) SELECT 1 AS `a` FROM `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh`",

tests/utils/test_helpers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from sqlglot import expressions
2+
3+
from sqlmesh.utils import columns_to_types_all_known
4+
5+
6+
def test_columns_to_types_all_known() -> None:
7+
assert (
8+
columns_to_types_all_known(
9+
{"a": expressions.DataType.build("INT"), "b": expressions.DataType.build("INT")}
10+
)
11+
== True
12+
)
13+
assert (
14+
columns_to_types_all_known(
15+
{"a": expressions.DataType.build("UNKNOWN"), "b": expressions.DataType.build("INT")}
16+
)
17+
== False
18+
)
19+
assert (
20+
columns_to_types_all_known(
21+
{"a": expressions.DataType.build("NULL"), "b": expressions.DataType.build("INT")}
22+
)
23+
== False
24+
)

0 commit comments

Comments
 (0)