Skip to content

Commit c742e6d

Browse files
authored
Fix: Convert seed's dataframe columns to the datetime object when their type is explicitly set to date / time (#983)
1 parent 4f08255 commit c742e6d

File tree

4 files changed

+79
-8
lines changed

4 files changed

+79
-8
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,11 @@ def _create_table_from_df(
148148
of the table if `replace` is true.
149149
"""
150150
assert isinstance(df, pd.DataFrame)
151-
table = self.__get_bq_table(table_name, columns_to_types or columns_to_types_from_df(df))
151+
if columns_to_types is None:
152+
columns_to_types = columns_to_types_from_df(df)
153+
table = self.__get_bq_table(table_name, columns_to_types)
152154
self.client.create_table(table, exists_ok=exists)
153-
self.__load_pandas_to_table(table, df, replace=replace)
155+
self.__load_pandas_to_table(table, df, columns_to_types, replace=replace)
154156

155157
def _insert_append_pandas_df(
156158
self,
@@ -162,13 +164,16 @@ def _insert_append_pandas_df(
162164
"""
163165
Appends to a table from a pandas dataframe. Will create the table if it doesn't exist.
164166
"""
165-
table = self.__get_bq_table(table_name, columns_to_types or columns_to_types_from_df(df))
166-
self.__load_pandas_to_table(table, df, replace=False)
167+
if columns_to_types is None:
168+
columns_to_types = columns_to_types_from_df(df)
169+
table = self.__get_bq_table(table_name, columns_to_types)
170+
self.__load_pandas_to_table(table, df, columns_to_types, replace=False)
167171

168172
def __load_pandas_to_table(
169173
self,
170174
table: bigquery.Table,
171175
df: pd.DataFrame,
176+
columns_to_types: t.Dict[str, exp.DataType],
172177
replace: bool = False,
173178
) -> BigQueryQueryResult:
174179
"""
@@ -177,7 +182,7 @@ def __load_pandas_to_table(
177182
"""
178183
from google.cloud import bigquery
179184

180-
job_config = bigquery.job.LoadJobConfig()
185+
job_config = bigquery.job.LoadJobConfig(schema=self.__get_bq_schema(columns_to_types))
181186
if replace:
182187
job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE
183188
result = self.client.load_table_from_dataframe(df, table, job_config=job_config).result()
@@ -258,7 +263,7 @@ def _insert_overwrite_by_condition(
258263

259264
temp_bq_table = self.__get_temp_bq_table(table, columns_to_types)
260265
self.client.create_table(temp_bq_table, exists_ok=False)
261-
result = self.__load_pandas_to_table(temp_bq_table, df, replace=False)
266+
result = self.__load_pandas_to_table(temp_bq_table, df, columns_to_types, replace=False)
262267
if result.errors:
263268
raise SQLMeshError(result.errors)
264269

sqlmesh/core/model/definition.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
from textwrap import indent
1212

13+
import pandas as pd
1314
from astor import to_source
1415
from pydantic import Field
1516
from sqlglot import diff, exp
@@ -27,6 +28,7 @@
2728
from sqlmesh.core.model.meta import ModelMeta
2829
from sqlmesh.core.model.seed import Seed, create_seed
2930
from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer
31+
from sqlmesh.utils import str_to_bool
3032
from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime
3133
from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error
3234
from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references
@@ -814,7 +816,24 @@ def render(
814816
**kwargs: t.Any,
815817
) -> t.Generator[QueryOrDF, None, None]:
816818
self._ensure_hydrated()
817-
yield from self.seed.read(batch_size=self.kind.batch_size)
819+
820+
date_or_time_columns = []
821+
bool_columns = []
822+
string_columns = []
823+
for name, tpe in (self.columns_to_types_ or {}).items():
824+
if tpe.this in exp.DataType.TEMPORAL_TYPES:
825+
date_or_time_columns.append(name)
826+
elif tpe.is_type("boolean"):
827+
bool_columns.append(name)
828+
elif tpe.this in exp.DataType.TEXT_TYPES:
829+
string_columns.append(name)
830+
831+
for df in self.seed.read(batch_size=self.kind.batch_size):
832+
for column in date_or_time_columns:
833+
df[column] = pd.to_datetime(df[column])
834+
df[bool_columns] = df[bool_columns].apply(lambda i: str_to_bool(str(i)))
835+
df[string_columns] = df[string_columns].astype(str)
836+
yield df
818837

819838
def text_diff(self, other: Model) -> str:
820839
if not isinstance(other, SeedModel):

tests/core/engine_adapter/test_bigquery.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def test_replace_query_pandas(mocker: MockerFixture):
127127
load_result.result.return_value = AttributeDict({"errors": None})
128128
client_mock.load_table_from_dataframe.return_value = load_result
129129
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
130-
adapter.replace_query("test_table", df, {"a": "int", "b": "int"})
130+
adapter.replace_query(
131+
"test_table", df, {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}
132+
)
131133

132134
assert execute_mock.call_args_list == []
133135
assert client_mock.method_calls[0] == [

tests/core/test_model.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import date
12
from pathlib import Path
23

34
import pytest
@@ -588,6 +589,50 @@ def test_seed_pre_statements_only():
588589
assert not model.post_statements
589590

590591

592+
def test_seed_model_custom_types(tmp_path):
593+
model_csv_path = (tmp_path / "model.csv").absolute()
594+
595+
with open(model_csv_path, "w") as fd:
596+
fd.write(
597+
"""key,ds,b_a,b_b,i,i_str
598+
123,2022-01-01,false,0,321,321
599+
"""
600+
)
601+
602+
model = create_seed_model(
603+
"test_db.test_model",
604+
SeedKind(path=str(model_csv_path)),
605+
columns={
606+
"key": "string",
607+
"ds": "date",
608+
"b_a": "boolean",
609+
"b_b": "boolean",
610+
"i": "int",
611+
"i_str": "text",
612+
},
613+
)
614+
615+
df = next(model.render(context=None))
616+
617+
assert df["ds"].dtype == "datetime64[ns]"
618+
assert df["ds"].iloc[0].date() == date(2022, 1, 1)
619+
620+
assert df["key"].dtype == "object"
621+
assert df["key"].iloc[0] == "123"
622+
623+
assert df["b_a"].dtype == "bool"
624+
assert not df["b_a"].iloc[0]
625+
626+
assert df["b_b"].dtype == "bool"
627+
assert not df["b_b"].iloc[0]
628+
629+
assert df["i"].dtype == "int64"
630+
assert df["i"].iloc[0] == 321
631+
632+
assert df["i_str"].dtype == "object"
633+
assert df["i_str"].iloc[0] == "321"
634+
635+
591636
def test_audits():
592637
expressions = parse(
593638
"""

0 commit comments

Comments
 (0)