Skip to content

Commit 68bdbb7

Browse files
authored
Feat!: Infer column types in the same manner as dbt does when converting dbt seeds (#2328)
1 parent 326263c commit 68bdbb7

File tree

3 files changed

+75
-2
lines changed

3 files changed

+75
-2
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"beautifulsoup4",
6464
"black==24.1.1",
6565
"cryptography~=41.0.7",
66+
"dbt-common",
6667
"dbt-core",
6768
"dbt-duckdb>=1.7.1",
6869
"Faker",
@@ -113,6 +114,7 @@
113114
],
114115
"dbt": [
115116
"dbt-core<2",
117+
"dbt-common",
116118
],
117119
"gcppostgres": [
118120
"cloud-sql-python-connector[pg8000]",

sqlmesh/dbt/seed.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
import typing as t
44

5+
import agate
6+
from dbt_common.clients import agate_helper
7+
from sqlglot import exp
8+
59
from sqlmesh.core.model import Model, SeedKind, create_seed_model
610
from sqlmesh.dbt.basemodel import BaseModelConfig
711

@@ -18,11 +22,51 @@ class SeedConfig(BaseModelConfig):
1822
General propreties, General configs, and For seeds sections.
1923
"""
2024

25+
delimiter: str = ","
26+
2127
def to_sqlmesh(self, context: DbtContext) -> Model:
2228
"""Converts the dbt seed into a SQLMesh model."""
29+
seed_path = self.path.absolute().as_posix()
30+
kwargs = self.sqlmesh_model_kwargs(context)
31+
if kwargs.get("columns") is None:
32+
agate_table = agate_helper.from_csv(seed_path, [], delimiter=self.delimiter)
33+
kwargs["columns"] = {
34+
name: AGATE_TYPE_MAPPING[tpe.__class__]
35+
for name, tpe in zip(agate_table.column_names, agate_table.column_types)
36+
}
37+
2338
return create_seed_model(
2439
self.canonical_name(context),
25-
SeedKind(path=self.path.absolute().as_posix()),
40+
SeedKind(path=seed_path),
2641
dialect=context.dialect,
27-
**self.sqlmesh_model_kwargs(context),
42+
**kwargs,
2843
)
44+
45+
46+
class Integer(agate.data_types.DataType):
47+
def cast(self, d: str) -> t.Optional[int]:
48+
if d is None:
49+
return d
50+
try:
51+
return int(d)
52+
except ValueError:
53+
raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d)
54+
55+
def jsonify(self, d: str) -> str:
56+
return d
57+
58+
59+
# The dbt version has a bug in which they check whether the type of the input value
60+
# is int, while the input value is actually always a string.
61+
agate_helper.Integer = Integer # type: ignore
62+
63+
64+
AGATE_TYPE_MAPPING = {
65+
agate_helper.Integer: exp.DataType.build("int"),
66+
agate_helper.Number: exp.DataType.build("double"),
67+
agate_helper.ISODateTime: exp.DataType.build("datetime"),
68+
agate.Date: exp.DataType.build("date"),
69+
agate.DateTime: exp.DataType.build("datetime"),
70+
agate.Boolean: exp.DataType.build("boolean"),
71+
agate.Text: exp.DataType.build("text"),
72+
}

tests/dbt/test_transformation.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,33 @@ def test_seed_columns():
294294
assert sqlmesh_seed.column_descriptions == expected_column_descriptions
295295

296296

297+
def test_seed_column_inference(tmp_path):
298+
seed_csv = tmp_path / "seed.csv"
299+
with open(seed_csv, "w") as fd:
300+
fd.write("int_col,double_col,datetime_col,date_col,boolean_col,text_col\n")
301+
fd.write("1,1.2,2021-01-01 00:00:00,2021-01-01,true,foo\n")
302+
fd.write("2,2.3,2021-01-02 00:00:00,2021-01-02,false,bar\n")
303+
304+
seed = SeedConfig(
305+
name="test_model",
306+
package="package",
307+
path=Path(seed_csv),
308+
)
309+
310+
context = DbtContext()
311+
context.project_name = "Foo"
312+
context.target = DuckDbConfig(name="target", schema="test")
313+
sqlmesh_seed = seed.to_sqlmesh(context)
314+
assert sqlmesh_seed.columns_to_types == {
315+
"int_col": exp.DataType.build("int"),
316+
"double_col": exp.DataType.build("double"),
317+
"datetime_col": exp.DataType.build("datetime"),
318+
"date_col": exp.DataType.build("date"),
319+
"boolean_col": exp.DataType.build("boolean"),
320+
"text_col": exp.DataType.build("text"),
321+
}
322+
323+
297324
@pytest.mark.xdist_group("dbt_manifest")
298325
@pytest.mark.parametrize(
299326
"model_fqn", ['"memory"."sushi"."waiters"', '"memory"."sushi"."waiter_names"']

0 commit comments

Comments
 (0)