Skip to content

Commit bb43e5f

Browse files
authored
Feat: add support for na_values and keep_default_na in csv_settings (#3872)
1 parent d566301 commit bb43e5f

File tree

4 files changed

+110
-3
lines changed

4 files changed

+110
-3
lines changed

docs/reference/model_configuration.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,5 +274,8 @@ Options specified within the `kind` property's `csv_settings` property (override
274274
| `skipinitialspace` | Skip spaces after delimiter. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N |
275275
| `lineterminator` | Character used to denote a line break. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
276276
| `encoding` | Encoding to use for UTF when reading/writing (ex. 'utf-8'). More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
277+
| `na_values` | An array of values that should be recognized as NA/NaN. In order to specify such an array per column, a mapping in the form of `(col1 = (v1, v2, ...), col2 = ...)` can be passed instead. These values can be integers, strings, booleans or NULL, and they are converted to their corresponding Python values. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | array[value] \| array[array[key = value]] | N |
278+
| `keep_default_na` | Whether or not to include the default NaN values when parsing the data. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N |
279+
277280

278281
Python model kind `name` enum value: [ModelKindName.SEED](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#ModelKindName)

sqlmesh/core/model/kind.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,9 +640,10 @@ def to_expression(
640640

641641
@property
642642
def data_hash_values(self) -> t.List[t.Optional[str]]:
643+
csv_setting_values = (self.csv_settings or CsvSettings()).dict().values()
643644
return [
644645
*super().data_hash_values,
645-
*(self.csv_settings or CsvSettings()).dict().values(),
646+
*(v if isinstance(v, (str, type(None))) else str(v) for v in csv_setting_values),
646647
]
647648

648649
@property

sqlmesh/core/model/seed.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
import typing as t
45
import zlib
56
from io import StringIO
@@ -8,12 +9,18 @@
89
import pandas as pd
910
from sqlglot import exp
1011
from sqlglot.dialects.dialect import UNESCAPED_SEQUENCES
12+
from sqlglot.helper import seq_get
1113
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1214

1315
from sqlmesh.core.model.common import parse_bool
1416
from sqlmesh.utils.pandas import columns_to_types_from_df
1517
from sqlmesh.utils.pydantic import PydanticModel, field_validator
1618

19+
logger = logging.getLogger(__name__)
20+
21+
NaHashables = t.List[t.Union[int, str, bool, t.Literal[None]]]
22+
NaValues = t.Union[NaHashables, t.Dict[str, NaHashables]]
23+
1724

1825
class CsvSettings(PydanticModel):
1926
"""Settings for CSV seeds."""
@@ -25,8 +32,10 @@ class CsvSettings(PydanticModel):
2532
skipinitialspace: t.Optional[bool] = None
2633
lineterminator: t.Optional[str] = None
2734
encoding: t.Optional[str] = None
35+
na_values: t.Optional[NaValues] = None
36+
keep_default_na: t.Optional[bool] = None
2837

29-
@field_validator("doublequote", "skipinitialspace", mode="before")
38+
@field_validator("doublequote", "skipinitialspace", "keep_default_na", mode="before")
3039
@classmethod
3140
def _bool_validator(cls, v: t.Any) -> t.Optional[bool]:
3241
if v is None:
@@ -46,6 +55,36 @@ def _str_validator(cls, v: t.Any) -> t.Optional[str]:
4655
v = v.this
4756
return UNESCAPED_SEQUENCES.get(v, v)
4857

58+
@field_validator("na_values", mode="before")
59+
@classmethod
60+
def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]:
61+
if v is None or not isinstance(v, exp.Expression):
62+
return v
63+
64+
try:
65+
if isinstance(v, exp.Paren) or not isinstance(v, (exp.Tuple, exp.Array)):
66+
v = exp.Tuple(expressions=[v.unnest()])
67+
68+
expressions = v.expressions
69+
if isinstance(seq_get(expressions, 0), (exp.PropertyEQ, exp.EQ)):
70+
return {
71+
e.left.name: [
72+
rhs_val.to_py()
73+
for rhs_val in (
74+
[e.right.unnest()]
75+
if isinstance(e.right, exp.Paren)
76+
else e.right.expressions
77+
)
78+
]
79+
for e in expressions
80+
}
81+
82+
return [e.to_py() for e in expressions]
83+
except ValueError as e:
84+
logger.warning(f"Failed to coerce na_values '{v}', proceeding with defaults. {str(e)}")
85+
86+
return None
87+
4988

5089
class CsvSeedReader:
5190
def __init__(self, content: str, dialect: str, settings: CsvSettings):

tests/core/test_model.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,8 @@ def test_seed_csv_settings():
897897
csv_settings (
898898
quotechar = '''',
899899
escapechar = '\\',
900+
keep_default_na = false,
901+
na_values = (id = [1, '2', false, null], alias = ('foo'))
900902
),
901903
),
902904
columns (
@@ -910,7 +912,39 @@ def test_seed_csv_settings():
910912
model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))
911913

912914
assert isinstance(model.kind, SeedKind)
913-
assert model.kind.csv_settings == CsvSettings(quotechar="'", escapechar="\\")
915+
assert model.kind.csv_settings == CsvSettings(
916+
quotechar="'",
917+
escapechar="\\",
918+
na_values={"id": [1, "2", False, None], "alias": ["foo"]},
919+
keep_default_na=False,
920+
)
921+
assert model.kind.data_hash_values == [
922+
"SEED",
923+
"'",
924+
"\\",
925+
"{'id': [1, '2', False, None], 'alias': ['foo']}",
926+
"False",
927+
]
928+
929+
expressions = d.parse(
930+
"""
931+
MODEL (
932+
name db.seed,
933+
kind SEED (
934+
path '../seeds/waiter_names.csv',
935+
csv_settings (
936+
na_values = ('#N/A', 'other')
937+
),
938+
),
939+
);
940+
"""
941+
)
942+
943+
model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))
944+
945+
assert isinstance(model.kind, SeedKind)
946+
assert model.kind.csv_settings == CsvSettings(na_values=["#N/A", "other"])
947+
assert model.kind.data_hash_values == ["SEED", "['#N/A', 'other']"]
914948

915949

916950
def test_seed_marker_substitution():
@@ -7755,3 +7789,33 @@ def get_current_date(evaluator):
77557789
FROM "discount_promotion_dates" AS "discount_promotion_dates"
77567790
""",
77577791
)
7792+
7793+
7794+
def test_seed_dont_coerce_na_into_null(tmp_path):
7795+
model_csv_path = (tmp_path / "model.csv").absolute()
7796+
7797+
with open(model_csv_path, "w", encoding="utf-8") as fd:
7798+
fd.write("code\nNA")
7799+
7800+
expressions = d.parse(
7801+
f"""
7802+
MODEL (
7803+
name db.seed,
7804+
kind SEED (
7805+
path '{str(model_csv_path)}',
7806+
csv_settings (
7807+
-- override NaN handling, such that no value can be coerced into NaN
7808+
keep_default_na = false,
7809+
na_values = (),
7810+
),
7811+
),
7812+
);
7813+
"""
7814+
)
7815+
7816+
model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))
7817+
7818+
assert isinstance(model.kind, SeedKind)
7819+
assert model.seed is not None
7820+
assert len(model.seed.content) > 0
7821+
assert next(model.render(context=None)).to_dict() == {"code": {0: "NA"}}

0 commit comments

Comments
 (0)