Skip to content

Commit ec498a5

Browse files
committed
Fix: Coercion of strings to integers when converting CSV to agate tables (#2918)
1 parent 4f1d735 commit ec498a5

File tree

2 files changed: 34 additions and 13 deletions.

sqlmesh/dbt/seed.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -55,21 +55,22 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
5555
)
5656

5757

58-
class Integer(agate.data_types.DataType):
59-
def cast(self, d: str) -> t.Optional[int]:
60-
if d is None:
61-
return d
62-
try:
63-
return int(d)
64-
except ValueError:
65-
raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d)
66-
67-
def jsonify(self, d: str) -> str:
58+
class Integer(agate_helper.Integer):
59+
def cast(self, d: t.Any) -> t.Optional[int]:
60+
if isinstance(d, str):
61+
# dbt's implementation doesn't support coercion of strings to integers.
62+
if d.strip().lower() in self.null_values:
63+
return None
64+
try:
65+
return int(d)
66+
except ValueError:
67+
raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d)
68+
return super().cast(d)
69+
70+
def jsonify(self, d: t.Any) -> str:
6871
return d
6972

7073

71-
# The dbt version has a bug in which they check whether the type of the input value
72-
# is int, while the input value is actually always a string.
7374
agate_helper.Integer = Integer # type: ignore
7475

7576

tests/dbt/test_transformation.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import agate
2+
from datetime import datetime
13
import json
24
import logging
35
import typing as t
@@ -34,7 +36,7 @@
3436
from sqlmesh.dbt.model import Materialization, ModelConfig
3537
from sqlmesh.dbt.project import Project
3638
from sqlmesh.dbt.relation import Policy
37-
from sqlmesh.dbt.seed import SeedConfig
39+
from sqlmesh.dbt.seed import SeedConfig, Integer
3840
from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig
3941
from sqlmesh.dbt.test import TestConfig
4042
from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError
@@ -402,6 +404,7 @@ def test_seed_column_inference(tmp_path):
402404
fd.write("int_col,double_col,datetime_col,date_col,boolean_col,text_col\n")
403405
fd.write("1,1.2,2021-01-01 00:00:00,2021-01-01,true,foo\n")
404406
fd.write("2,2.3,2021-01-02 00:00:00,2021-01-02,false,bar\n")
407+
fd.write("null,,null,,,null\n")
405408

406409
seed = SeedConfig(
407410
name="test_model",
@@ -423,6 +426,23 @@ def test_seed_column_inference(tmp_path):
423426
}
424427

425428

429+
def test_agate_integer_cast():
430+
agate_integer = Integer(null_values=("null", ""))
431+
assert agate_integer.cast("1") == 1
432+
assert agate_integer.cast(1) == 1
433+
assert agate_integer.cast("null") is None
434+
assert agate_integer.cast("") is None
435+
436+
with pytest.raises(agate.exceptions.CastError):
437+
agate_integer.cast("1.2")
438+
439+
with pytest.raises(agate.exceptions.CastError):
440+
agate_integer.cast(1.2)
441+
442+
with pytest.raises(agate.exceptions.CastError):
443+
agate_integer.cast(datetime.now())
444+
445+
426446
@pytest.mark.xdist_group("dbt_manifest")
427447
def test_model_dialect(sushi_test_project: Project, assert_exp_eq):
428448
model_config = ModelConfig(

0 commit comments

Comments
 (0)