diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py
index 9a072ea..7d5d719 100644
--- a/dataframely/columns/decimal.py
+++ b/dataframely/columns/decimal.py
@@ -98,7 +98,15 @@ def dtype(self) -> pl.DataType:
         return pl.Decimal(self.precision, self.scale)
 
     def validate_dtype(self, dtype: PolarsDataType) -> bool:
-        return dtype.is_decimal()
+        """Check that ``dtype`` is a decimal matching this column's scale/precision.
+
+        The precision is only compared when ``self.precision`` is not ``None``.
+        """
+        return (
+            isinstance(dtype, pl.Decimal)
+            and dtype.scale == self.scale
+            and (self.precision is None or dtype.precision == self.precision)
+        )
 
     def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
         if self.scale and not self.precision:
diff --git a/tests/column_types/test_decimal.py b/tests/column_types/test_decimal.py
index 11719e4..ec4dc03 100644
--- a/tests/column_types/test_decimal.py
+++ b/tests/column_types/test_decimal.py
@@ -66,7 +66,7 @@ def test_invalid_args(kwargs: dict[str, Any]) -> None:
 
 
 @pytest.mark.parametrize(
-    "dtype", [pl.Decimal, pl.Decimal(12), pl.Decimal(None, 8), pl.Decimal(6, 2)]
+    "dtype", [pl.Decimal, pl.Decimal(12), pl.Decimal(None, 0), pl.Decimal(6, 0)]
 )
 def test_any_decimal_dtype_passes(dtype: DataTypeClass) -> None:
     df = pl.DataFrame(schema={"a": dtype})
@@ -171,3 +171,60 @@ def test_validate_range(
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
     expected = pl.LazyFrame(valid)
     assert_frame_equal(actual, expected)
+
+
+@pytest.mark.parametrize(
+    ("schema_precision", "schema_scale", "dtype", "should_pass"),
+    [
+        # Exact match should pass
+        (38, 10, pl.Decimal(38, 10), True),
+        # Wrong scale should fail
+        (38, 10, pl.Decimal(38, 3), False),
+        # Wrong precision should fail
+        (10, 2, pl.Decimal(38, 2), False),
+        # Both wrong should fail
+        (10, 2, pl.Decimal(38, 5), False),
+        # precision=None should accept any precision with matching scale
+        (None, 5, pl.Decimal(10, 5), True),
+        (None, 5, pl.Decimal(20, 5), True),
+        # precision=None with wrong scale should fail
+        (None, 5, pl.Decimal(10, 3), False),
+    ],
+)
+def test_precision_scale_validation(
+    schema_precision: int | None,
+    schema_scale: int,
+    dtype: pl.DataType,
+    should_pass: bool,
+) -> None:
+    """Only decimal dtypes matching the schema's precision/scale are valid."""
+    class TestSchema(dy.Schema):
+        a = dy.Decimal(precision=schema_precision, scale=schema_scale)
+
+    df = pl.DataFrame(schema={"a": dtype})
+    assert TestSchema.is_valid(df) == should_pass
+
+
+@pytest.mark.parametrize(
+    ("schema_precision", "schema_scale", "input_dtype"),
+    [
+        (38, 10, pl.Decimal(38, 3)),
+        (10, 2, pl.Decimal(38, 2)),
+        (10, 5, pl.Decimal(20, 3)),
+    ],
+)
+def test_precision_scale_casting(
+    schema_precision: int,
+    schema_scale: int,
+    input_dtype: pl.DataType,
+) -> None:
+    """Validating with ``cast=True`` yields the schema's precision and scale."""
+    class TestSchema(dy.Schema):
+        a = dy.Decimal(precision=schema_precision, scale=schema_scale)
+
+    df_input = pl.DataFrame({"a": [decimal.Decimal("12.34")]}).with_columns(
+        pl.col("a").cast(input_dtype)
+    )
+    df_validated = TestSchema.validate(df_input, cast=True)
+    assert df_validated.schema["a"].precision == schema_precision  # type: ignore[attr-defined]
+    assert df_validated.schema["a"].scale == schema_scale  # type: ignore[attr-defined]