diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index b8180bc011..145e29a96c 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -50,6 +50,22 @@ def warn(self, msg: str) -> str: return "" +def try_or_compiler_error( + message_if_exception: str, func: t.Callable, *args: t.Any, **kwargs: t.Any +) -> t.Any: + try: + return func(*args, **kwargs) + except Exception: + if DBT_VERSION >= (1, 4, 0): + from dbt.exceptions import CompilationError + + raise CompilationError(message_if_exception) + else: + from dbt.exceptions import CompilationException # type: ignore + + raise CompilationException(message_if_exception) + + class Api: def __init__(self, dialect: t.Optional[str]) -> None: if dialect: @@ -411,6 +427,7 @@ def debug() -> str: "sqlmesh_incremental": True, "tojson": to_json, "toyaml": to_yaml, + "try_or_compiler_error": try_or_compiler_error, "zip": do_zip, "zip_strict": lambda *args: list(zip(*args)), } diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index e519713d26..304ac57731 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -1592,6 +1592,29 @@ def test_exceptions(sushi_test_project: Project): context.render('{{ exceptions.raise_compiler_error("Error") }}') +@pytest.mark.xdist_group("dbt_manifest") +def test_try_or_compiler_error(sushi_test_project: Project): + context = sushi_test_project.context + + result = context.render( + '{{ try_or_compiler_error("Error message", modules.datetime.datetime.strptime, "2023-01-15", "%Y-%m-%d") }}' + ) + assert "2023-01-15" in result + + with pytest.raises(CompilationError, match="Invalid date format"): + context.render( + '{{ try_or_compiler_error("Invalid date format", modules.datetime.datetime.strptime, "invalid", "%Y-%m-%d") }}' + ) + + # built-in macro calling try_or_compiler_error works + result = context.render( + '{{ dbt.dates_in_range("2023-01-01", "2023-01-03", "%Y-%m-%d", "%Y-%m-%d") }}' + ) + assert "2023-01-01" in result + assert "2023-01-02" in result + assert "2023-01-03" in result + + @pytest.mark.xdist_group("dbt_manifest") def test_modules(sushi_test_project: Project): context = sushi_test_project.context