diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4c5767bf9..3351d65b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@
 - Fix `workflow_job` Python model submission method failing with dictionary attribute error ([#1360](https://github.com/databricks/dbt-databricks/issues/1360))
 - Fix `TestWorkflowJob` functional test that was unreachable on all profiles due to incorrect skip list, wrong model fixture, and invalid `max_retries` parameter ([#1360](https://github.com/databricks/dbt-databricks/issues/1360))
 - Fix column order mismatch in microbatch and replace_where incremental strategies by using INSERT BY NAME syntax ([#1338](https://github.com/databricks/dbt-databricks/issues/1338))
+- Fix `is_string()`, `is_number()`, `is_float()`, `is_integer()`, and `is_numeric()` returning `False` for Databricks/Spark column types by overriding them in `DatabricksColumn` with the correct Spark type names ([#1380](https://github.com/databricks/dbt-databricks/issues/1380))
 - Fix `dbt run --empty` failing with inline `ref()` / `source()` aliases ([dbt-labs/dbt-adapters#660](https://github.com/dbt-labs/dbt-adapters/issues/660))
 
 ### Under the Hood
diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py
index 6aa32e970..350e54117 100644
--- a/dbt/adapters/databricks/column.py
+++ b/dbt/adapters/databricks/column.py
@@ -28,6 +28,53 @@ def create(cls, name: str, label_or_dtype: str) -> "DatabricksColumn":
         column_type = cls.translate_type(label_or_dtype)
         return cls(name, column_type)
 
+    def is_string(self) -> bool:
+        return self.dtype.lower() in {
+            "string",
+            "varchar",
+            "char",
+            "text",
+            "character varying",
+            "character",
+            "nchar",
+            "nvarchar",
+        }
+
+    def is_number(self) -> bool:
+        return self.dtype.lower() in {
+            "tinyint",
+            "smallint",
+            "int",
+            "integer",
+            "bigint",
+            "long",
+            "float",
+            "double",
+            "decimal",
+            "numeric",
+            "real",
+        } or self.dtype.lower().startswith("decimal(")
+
+    def is_float(self) -> bool:
+        return self.dtype.lower() in {
+            "float",
+            "double",
+            "real",
+        }
+
+    def is_integer(self) -> bool:
+        return self.dtype.lower() in {
+            "tinyint",
+            "smallint",
+            "int",
+            "integer",
+            "bigint",
+            "long",
+        }
+
+    def is_numeric(self) -> bool:
+        return self.is_number()
+
     @classmethod
     def from_json_metadata(cls, json_metadata: str) -> list["DatabricksColumn"]:
         """
diff --git a/tests/unit/test_column.py b/tests/unit/test_column.py
index 7c8ea6edd..387dad8aa 100644
--- a/tests/unit/test_column.py
+++ b/tests/unit/test_column.py
@@ -290,3 +290,80 @@ def test_parse_type_from_json_primitive_types(self):
         type_info = {"name": type_name}
         result = DatabricksColumn._parse_type_from_json(type_info)
         assert result == type_name
+
+
+class TestTypeClassification:
+    @pytest.mark.parametrize(
+        "dtype",
+        [
+            "string",
+            "varchar",
+            "char",
+            "text",
+            "character varying",
+            "character",
+            "nchar",
+            "nvarchar",
+            "STRING",
+            "VARCHAR",
+            "CHAR",
+        ],
+    )
+    def test_is_string_true(self, dtype):
+        assert DatabricksColumn("col", dtype).is_string() is True
+
+    @pytest.mark.parametrize("dtype", ["int", "bigint", "double", "decimal(10,2)"])
+    def test_is_string_false(self, dtype):
+        assert DatabricksColumn("col", dtype).is_string() is False
+
+    @pytest.mark.parametrize(
+        "dtype",
+        ["tinyint", "smallint", "int", "integer", "bigint", "long", "INT", "BIGINT"],
+    )
+    def test_is_integer_true(self, dtype):
+        assert DatabricksColumn("col", dtype).is_integer() is True
+
+    @pytest.mark.parametrize("dtype", ["float", "double", "string", "decimal(10,2)"])
+    def test_is_integer_false(self, dtype):
+        assert DatabricksColumn("col", dtype).is_integer() is False
+
+    @pytest.mark.parametrize("dtype", ["float", "double", "real", "FLOAT", "DOUBLE"])
+    def test_is_float_true(self, dtype):
+        assert DatabricksColumn("col", dtype).is_float() is True
+
+    @pytest.mark.parametrize("dtype", ["int", "bigint", "string", "decimal(10,2)"])
+    def test_is_float_false(self, dtype):
+        assert DatabricksColumn("col", dtype).is_float() is False
+
+    @pytest.mark.parametrize(
+        "dtype",
+        [
+            "tinyint",
+            "smallint",
+            "int",
+            "integer",
+            "bigint",
+            "long",
+            "float",
+            "double",
+            "decimal",
+            "numeric",
+            "real",
+            "decimal(10,2)",
+            "decimal(38,0)",
+            "DECIMAL(10,2)",
+        ],
+    )
+    def test_is_number_true(self, dtype):
+        assert DatabricksColumn("col", dtype).is_number() is True
+
+    @pytest.mark.parametrize("dtype", ["string", "varchar", "boolean", "date", "timestamp"])
+    def test_is_number_false(self, dtype):
+        assert DatabricksColumn("col", dtype).is_number() is False
+
+    def test_is_numeric_delegates_to_is_number(self):
+        col = DatabricksColumn("col", "bigint")
+        assert col.is_numeric() == col.is_number()
+
+    def test_is_numeric_false_for_string(self):
+        assert DatabricksColumn("col", "string").is_numeric() is False