diff --git a/pyproject.toml b/pyproject.toml index e4a5609d93..e6d0a33549 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -485,7 +485,7 @@ url = 'https://gridtools.github.io/pypi/' [tool.uv.sources] atlas4py = {index = "test.pypi"} dace = [ - {git = "https://github.com/GridTools/dace", branch = "romanc/stree-v2", group = "dace-cartesian"}, + {git = "https://github.com/GridTools/dace", branch = "romanc/math-functions", group = "dace-cartesian"}, {index = "gridtools", group = "dace-next"} ] diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py index 2329128d70..6a7c71047e 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py @@ -226,38 +226,38 @@ def visit_NativeFunction(self, node: common.NativeFunction, **_kwargs: Any) -> s common.NativeFunction.ABS: "abs", common.NativeFunction.MIN: "min", common.NativeFunction.MAX: "max", - common.NativeFunction.MOD: "fmod", + common.NativeFunction.MOD: "dace.math.fmod", common.NativeFunction.SIN: "dace.math.sin", common.NativeFunction.COS: "dace.math.cos", common.NativeFunction.TAN: "dace.math.tan", - common.NativeFunction.ARCSIN: "asin", - common.NativeFunction.ARCCOS: "acos", - common.NativeFunction.ARCTAN: "atan", + common.NativeFunction.ARCSIN: "dace.math.asin", + common.NativeFunction.ARCCOS: "dace.math.acos", + common.NativeFunction.ARCTAN: "dace.math.atan", common.NativeFunction.SINH: "dace.math.sinh", common.NativeFunction.COSH: "dace.math.cosh", common.NativeFunction.TANH: "dace.math.tanh", - common.NativeFunction.ARCSINH: "asinh", - common.NativeFunction.ARCCOSH: "acosh", - common.NativeFunction.ARCTANH: "atanh", + common.NativeFunction.ARCSINH: "dace.math.asinh", + common.NativeFunction.ARCCOSH: "dace.math.acosh", + common.NativeFunction.ARCTANH: "dace.math.atanh", common.NativeFunction.SQRT: "dace.math.sqrt", common.NativeFunction.POW: "dace.math.pow", common.NativeFunction.EXP: "dace.math.exp", common.NativeFunction.LOG: "dace.math.log", - common.NativeFunction.LOG10: "log10", - common.NativeFunction.GAMMA: "tgamma", - common.NativeFunction.CBRT: "cbrt", + common.NativeFunction.LOG10: "dace.math.log10", + common.NativeFunction.GAMMA: "dace.math.tgamma", + common.NativeFunction.CBRT: "dace.math.cbrt", common.NativeFunction.ISFINITE: "isfinite", common.NativeFunction.ISINF: "isinf", common.NativeFunction.ISNAN: "isnan", common.NativeFunction.FLOOR: "dace.math.ifloor", - common.NativeFunction.CEIL: "ceil", - common.NativeFunction.TRUNC: "trunc", + common.NativeFunction.CEIL: "dace.math.ceil", + common.NativeFunction.TRUNC: "dace.math.trunc", common.NativeFunction.INT32: "dace.int32", common.NativeFunction.INT64: "dace.int64", common.NativeFunction.FLOAT32: "dace.float32", common.NativeFunction.FLOAT64: "dace.float64", - common.NativeFunction.ERF: "erf", - common.NativeFunction.ERFC: "erfc", + common.NativeFunction.ERF: "dace.math.erf", + common.NativeFunction.ERFC: "dace.math.erfc", common.NativeFunction.ROUND: "nearbyint", common.NativeFunction.ROUND_AWAY_FROM_ZERO: "round", } diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py index 1ecb7659ee..980727fa3c 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py @@ -139,3 +139,22 @@ def test_integer_power_of_integer() -> None: tasklet_code = visitor.visit_NativeFuncCall(pow_call, ctx=fake_context, is_target=False) assert "ipow" not in tasklet_code + + +@pytest.mark.parametrize( + "arg", + [ + oir.Literal(value="2", dtype=common.DataType.FLOAT32), + oir.Literal(value="2", dtype=common.DataType.FLOAT64), + ], +) +def test_log10_respects_floating_point_precision(arg: oir.Literal) -> None: + log10_call = oir.NativeFuncCall(func=common.NativeFunction.LOG10, args=[arg]) + + visitor = oir_to_tasklet.OIRToTasklet() + fake_context = oir_to_tasklet.Context( + code="asdf", targets=set(), inputs={}, outputs={}, tree=None, scope=None + ) + tasklet_code = visitor.visit_NativeFuncCall(log10_call, ctx=fake_context, is_target=False) + + assert "dace.math.log10" in tasklet_code diff --git a/uv.lock b/uv.lock index 59e8b3010b..437365b901 100644 --- a/uv.lock +++ b/uv.lock @@ -1210,7 +1210,7 @@ wheels = [ [[package]] name = "dace" version = "1.0.0" -source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2#d5fbadb626389e425fac5ed93d2a880811eca41f" } +source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fmath-functions#3df061c8aeabcaeea966f79e39a4dbded2628df9" } resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", @@ -1789,7 +1789,7 @@ build = [ { name = "wheel" }, ] dace-cartesian = [ - { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2#d5fbadb626389e425fac5ed93d2a880811eca41f" } }, + { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fmath-functions#3df061c8aeabcaeea966f79e39a4dbded2628df9" } }, ] dace-next = [ { name = "dace", version = "43!2026.4.27", source = { registry = "https://gridtools.github.io/pypi/" } }, @@ -1961,7 +1961,7 @@ build = [ { name = "setuptools", specifier = ">=77.0.3" }, { name = "wheel", specifier = ">=0.33.6" }, ] -dace-cartesian = [{ name = "dace", git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2" }] +dace-cartesian = [{ name = "dace", git = "https://github.com/GridTools/dace?branch=romanc%2Fmath-functions" }] dace-next = [{ name = "dace", specifier = "==43!2026.4.27", index = "https://gridtools.github.io/pypi/", conflict = { package = "gt4py", group = "dace-next" } }] dev = [ { name = "atlas4py", specifier = ">=0.41", index = "https://test.pypi.org/simple" },