Skip to content

Commit a3a1303

Browse files
authored
Merge branch 'main' into add_fabric_warehouse
2 parents 1c8fc5c + 0c70406 commit a3a1303

File tree

7 files changed

+178
-10
lines changed

7 files changed

+178
-10
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dependencies = [
2323
"requests",
2424
"rich[jupyter]",
2525
"ruamel.yaml",
26-
"sqlglot[rs]~=27.6.0",
26+
"sqlglot[rs]~=27.7.0",
2727
"tenacity",
2828
"time-machine",
2929
"json-stream"
@@ -275,6 +275,9 @@ filterwarnings = [
275275
]
276276
retry_delay = 10
277277

278+
[tool.ruff]
279+
line-length = 100
280+
278281
[tool.ruff.lint]
279282
select = [
280283
"F401",

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,13 @@ def create_schema(
292292
raise
293293
logger.warning("Failed to create schema '%s': %s", schema_name, e)
294294

295+
def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]:
296+
table = exp.to_table(table_name)
297+
if len(table.parts) == 3 and "." in table.name:
298+
self.execute(exp.select("*").from_(table).limit(0))
299+
return self._query_job._query_results.schema
300+
return self._get_table(table).schema
301+
295302
def columns(
296303
self, table_name: TableName, include_pseudo_columns: bool = False
297304
) -> t.Dict[str, exp.DataType]:

sqlmesh/core/snapshot/evaluator.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,11 +1073,22 @@ def _cleanup_snapshot(
10731073

10741074
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
10751075
for is_table_deployable, table_name in table_names:
1076-
evaluation_strategy.delete(
1077-
table_name,
1078-
is_table_deployable=is_table_deployable,
1079-
physical_schema=snapshot.physical_schema,
1080-
)
1076+
try:
1077+
evaluation_strategy.delete(
1078+
table_name,
1079+
is_table_deployable=is_table_deployable,
1080+
physical_schema=snapshot.physical_schema,
1081+
)
1082+
except Exception:
1083+
# Use `get_data_object` to check if the table exists instead of `table_exists` since the former
1084+
# is based on `INFORMATION_SCHEMA` and avoids touching the table directly.
1085+
# This is important when the table name is malformed for some reason and running any statement
1086+
# that touches the table would result in an error.
1087+
if adapter.get_data_object(table_name) is not None:
1088+
raise
1089+
logger.warning(
1090+
"Skipping cleanup of table '%s' because it does not exist", table_name
1091+
)
10811092

10821093
if on_complete is not None:
10831094
on_complete(table_name)

sqlmesh/dbt/adapter.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,31 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.Lis
291291
return relations
292292

293293
def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]:
294-
from dbt.adapters.base.column import Column
295-
296294
mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation)))
295+
296+
if self.project_dialect == "bigquery":
297+
# dbt.adapters.bigquery.column.BigQueryColumn has a different constructor signature
298+
# We need to use BigQueryColumn.create_from_field() to create the column instead
299+
if (
300+
hasattr(self.column_type, "create_from_field")
301+
and callable(getattr(self.column_type, "create_from_field"))
302+
and hasattr(self.engine_adapter, "get_bq_schema")
303+
and callable(getattr(self.engine_adapter, "get_bq_schema"))
304+
):
305+
return [
306+
self.column_type.create_from_field(field) # type: ignore
307+
for field in self.engine_adapter.get_bq_schema(mapped_table) # type: ignore
308+
]
309+
from dbt.adapters.base.column import Column
310+
311+
return [
312+
Column.from_description(
313+
name=name, raw_data_type=dtype.sql(dialect=self.project_dialect)
314+
)
315+
for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items()
316+
]
297317
return [
298-
Column.from_description(
318+
self.column_type.from_description(
299319
name=name, raw_data_type=dtype.sql(dialect=self.project_dialect)
300320
)
301321
for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items()

tests/core/engine_adapter/integration/test_integration_bigquery.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,39 @@ def test_compare_nested_values_in_table_diff(ctx: TestContext):
341341
ctx.engine_adapter.drop_table(target_table)
342342

343343

344+
def test_get_bq_schema(ctx: TestContext, engine_adapter: BigQueryEngineAdapter):
345+
from google.cloud.bigquery import SchemaField
346+
347+
table = ctx.table("test")
348+
349+
engine_adapter.execute(f"""
350+
CREATE TABLE {table.sql(dialect=ctx.dialect)} (
351+
id STRING NOT NULL,
352+
user_data STRUCT<id STRING NOT NULL, name STRING NOT NULL, address STRING>,
353+
tags ARRAY<STRING>,
354+
score NUMERIC,
355+
created_at DATETIME
356+
)
357+
""")
358+
359+
bg_schema = engine_adapter.get_bq_schema(table)
360+
assert len(bg_schema) == 5
361+
assert bg_schema[0] == SchemaField(name="id", field_type="STRING", mode="REQUIRED")
362+
assert bg_schema[1] == SchemaField(
363+
name="user_data",
364+
field_type="RECORD",
365+
mode="NULLABLE",
366+
fields=[
367+
SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
368+
SchemaField(name="name", field_type="STRING", mode="REQUIRED"),
369+
SchemaField(name="address", field_type="STRING", mode="NULLABLE"),
370+
],
371+
)
372+
assert bg_schema[2] == SchemaField(name="tags", field_type="STRING", mode="REPEATED")
373+
assert bg_schema[3] == SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE")
374+
assert bg_schema[4] == SchemaField(name="created_at", field_type="DATETIME", mode="NULLABLE")
375+
376+
344377
def test_column_types(ctx: TestContext):
345378
model_name = ctx.table("test")
346379
sqlmesh = ctx.create_context()

tests/core/test_snapshot_evaluator.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,14 @@ def create_and_cleanup(name: str, dev_table_only: bool):
438438
return snapshot
439439

440440
snapshot = create_and_cleanup("catalog.test_schema.test_model", True)
441+
adapter_mock.get_data_object.assert_not_called()
441442
adapter_mock.drop_table.assert_called_once_with(
442443
f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev"
443444
)
444445
adapter_mock.reset_mock()
445446

446447
snapshot = create_and_cleanup("test_schema.test_model", False)
448+
adapter_mock.get_data_object.assert_not_called()
447449
adapter_mock.drop_table.assert_has_calls(
448450
[
449451
call(
@@ -455,6 +457,7 @@ def create_and_cleanup(name: str, dev_table_only: bool):
455457
adapter_mock.reset_mock()
456458

457459
snapshot = create_and_cleanup("test_model", False)
460+
adapter_mock.get_data_object.assert_not_called()
458461
adapter_mock.drop_table.assert_has_calls(
459462
[
460463
call(f"sqlmesh__default.test_model__{snapshot.fingerprint.to_version()}__dev"),
@@ -463,6 +466,59 @@ def create_and_cleanup(name: str, dev_table_only: bool):
463466
)
464467

465468

469+
def test_cleanup_fails(adapter_mock, make_snapshot):
470+
adapter_mock.drop_table.side_effect = RuntimeError("test_error")
471+
472+
evaluator = SnapshotEvaluator(adapter_mock)
473+
474+
model = SqlModel(
475+
name="catalog.test_schema.test_model",
476+
kind=IncrementalByTimeRangeKind(time_column="a"),
477+
storage_format="parquet",
478+
query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"),
479+
)
480+
481+
snapshot = make_snapshot(model)
482+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True)
483+
snapshot.version = "test_version"
484+
485+
evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env"))
486+
with pytest.raises(NodeExecutionFailedError) as exc_info:
487+
evaluator.cleanup(
488+
[SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True)]
489+
)
490+
491+
assert str(exc_info.value.__cause__) == "test_error"
492+
493+
494+
def test_cleanup_skip_missing_table(adapter_mock, make_snapshot):
495+
adapter_mock.get_data_object.return_value = None
496+
adapter_mock.drop_table.side_effect = RuntimeError("fail")
497+
498+
evaluator = SnapshotEvaluator(adapter_mock)
499+
500+
model = SqlModel(
501+
name="catalog.test_schema.test_model",
502+
kind=IncrementalByTimeRangeKind(time_column="a"),
503+
storage_format="parquet",
504+
query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"),
505+
)
506+
507+
snapshot = make_snapshot(model)
508+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True)
509+
snapshot.version = "test_version"
510+
511+
evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env"))
512+
evaluator.cleanup([SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True)])
513+
514+
adapter_mock.get_data_object.assert_called_once_with(
515+
f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev"
516+
)
517+
adapter_mock.drop_table.assert_called_once_with(
518+
f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev"
519+
)
520+
521+
466522
def test_cleanup_external_model(mocker: MockerFixture, adapter_mock, make_snapshot):
467523
evaluator = SnapshotEvaluator(adapter_mock)
468524

tests/dbt/test_adapter.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sqlmesh.dbt.adapter import ParsetimeAdapter
1818
from sqlmesh.dbt.project import Project
1919
from sqlmesh.dbt.relation import Policy
20-
from sqlmesh.dbt.target import SnowflakeConfig
20+
from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig
2121
from sqlmesh.utils.errors import ConfigError
2222
from sqlmesh.utils.jinja import JinjaMacroRegistry
2323

@@ -68,6 +68,44 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla
6868
)
6969

7070

71+
def test_bigquery_get_columns_in_relation(
72+
sushi_test_project: Project,
73+
runtime_renderer: t.Callable,
74+
mocker: MockerFixture,
75+
):
76+
from dbt.adapters.bigquery import BigQueryColumn
77+
from google.cloud.bigquery import SchemaField
78+
79+
context = sushi_test_project.context
80+
context.target = BigQueryConfig(name="test", schema="test", database="test")
81+
82+
adapter_mock = mocker.MagicMock()
83+
adapter_mock.default_catalog = "test"
84+
adapter_mock.dialect = "bigquery"
85+
table_schema = [
86+
SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
87+
SchemaField(
88+
name="user_data",
89+
field_type="RECORD",
90+
mode="NULLABLE",
91+
fields=[
92+
SchemaField(name="id", field_type="STRING", mode="REQUIRED"),
93+
SchemaField(name="name", field_type="STRING", mode="REQUIRED"),
94+
SchemaField(name="address", field_type="STRING", mode="NULLABLE"),
95+
],
96+
),
97+
SchemaField(name="tags", field_type="STRING", mode="REPEATED"),
98+
SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE"),
99+
SchemaField(name="created_at", field_type="TIMESTAMP", mode="NULLABLE"),
100+
]
101+
adapter_mock.get_bq_schema.return_value = table_schema
102+
renderer = runtime_renderer(context, engine_adapter=adapter_mock, dialect="bigquery")
103+
assert renderer(
104+
"{%- set relation = api.Relation.create(database='test', schema='test', identifier='test_table') -%}"
105+
"{{ adapter.get_columns_in_relation(relation) }}"
106+
) == str([BigQueryColumn.create_from_field(field) for field in table_schema])
107+
108+
71109
@pytest.mark.cicdonly
72110
def test_normalization(
73111
sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture

0 commit comments

Comments (0)