diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..c341a311b1 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -412,23 +412,15 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) + true_ = self.visit(node.args[1]) + false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - def create_if( - true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] - ) -> itir.FunCall: - return _map( - "if_", - (im.ref(cond_symref_name), true_, false_), - (node.args[0].type, *arg_types), + result = im.tree_map( + im.lambda_("__a", "__b")( + im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) ) - - result = lowering_utils.process_elements( - create_if, - (self.visit(node.args[1]), self.visit(node.args[2])), - node.type, - arg_types=(node.args[1].type, node.args[2].type), - ) + )(true_, false_) return im.let(cond_symref_name, cond_)(result) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e54c6ea3d7..b60932eed8 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -57,6 +57,11 @@ def map_(*args): raise BackendNotSelectedError() +@builtin_dispatch +def tree_map(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def make_const_list(*args): raise BackendNotSelectedError() @@ -498,7 +503,8 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_", + "tree_map", + "map_", # TODO: rename to map_list "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 4b30e878fe..9545525a94 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -629,6 +629,11 @@ def map_(op): return call(call("map_")(op)) +def tree_map(op): + """Create a `tree_map` call: tree_map(op)(tup1, tup2, ...).""" + return call(call("tree_map")(op)) + + def reduce(op, expr): """Create a `reduce` call.""" return call(call("reduce")(op, expr)) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index a6228c6125..32d847841b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -23,6 +23,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_tree_map, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -176,6 +177,26 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + + # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that + # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements + # (which would receive NEVER domain). Do multiple iterations for nested `let`s. + for _ in range(10): + collapsed = ir + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + if ir == collapsed: + break + else: + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") ir = infer_domain.infer_program( ir, @@ -290,6 +311,23 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + for _ in range(10): + prev = ir + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + if ir == prev: + break + else: + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py new file mode 100644 index 0000000000..d2b0b7b642 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -0,0 +1,62 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses +import functools + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class UnrollTreeMap(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.FunCall): + node = self.generic_visit(node) + + if not cpm.is_call_to(node.fun, "tree_map"): + return node + + f = node.fun.args[0] + tup_args = node.args + tup_types: list[ts.TupleType] = [] + for tup in tup_args: + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_types.append(tup.type) + + tup_refs = [next(self.uids["_utm"]) for _ in tup_args] + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + with_path_arg=True, + ) + def mapper(*args): + *_el_types, path = args + return im.call(f)( + *( + functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name)) + for ref_name in tup_refs + ) + ) + + result = im.let(*zip(tup_refs, tup_args))(mapper(*tup_types)) + + itir_inference.reinfer(result) + return result diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 16d5da7e3b..6437e23973 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -20,7 +20,6 @@ from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next.utils import tree_map def _type_synth_arg_cache_key(type_or_synth: TypeOrTypeSynthesizer) -> int: @@ -203,7 +202,7 @@ def if_( pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType ) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): - return tree_map( + return utils.tree_map( collection_type=ts.TupleType, result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), )(functools.partial(if_, pred))(true_branch, false_branch) @@ -633,6 +632,32 @@ def applied_map( return applied_map +@_register_builtin_type_synthesizer +def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + *args: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + if not args: + raise TypeError("tree_map requires at least one argument.") + if not all(isinstance(a, ts.TupleType) for a in args): + raise TypeError( + "tree_map requires all top-level arguments to be TupleType, " + f"got {[type(a).__name__ for a in args]}." + ) + + def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec: + return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] + + return utils.tree_map( # type: ignore[return-value] + leaf_op, + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), + )(*args) + + return applied_map + + @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @type_synthesizer diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 8e3bba90b9..bf2978a8f2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,10 +207,9 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.make_tuple( - im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")), - im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")), - ) + reference = im.tree_map( + im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) + )("b", "c") assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d73fc1945f..f901a7ce39 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -155,6 +155,23 @@ def expression_test_cases(): ), ts.ListType(element_type=int_type, offset_type=V2EDim), ), + # tree_map + ( + im.tree_map(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, int_type])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + ts.TupleType(types=[int_type, int_type]), + ), + ( + im.tree_map(im.ref("not_"))( + im.ref( + "t", + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), + ), + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), # reduce (im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type), ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py new file mode 100644 index 0000000000..3462ef4084 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -0,0 +1,96 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next import common, utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap +from gt4py.next.type_system import type_specifications as ts + +IDim = common.Dimension("IDim") +T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +i_field = ts.FieldType(dims=[IDim], dtype=T) +i_tuple_field = ts.TupleType(types=[i_field, i_field]) +i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field]) + +i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) + + +def _make_program( + params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field +) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[*params, im.sym("out", out_type)], + declarations=[], + body=[ + itir.SetAt( + expr=expr, + domain=i_domain, + target=im.ref("out", out_type), + ) + ], + ) + + +def _neg(): + return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) + + +def _plus(): + return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) + + +def test_multi_arg(): + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.call(im.call("tree_map")(_plus()))( + im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) + ), + out_type=i_tuple_field, + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.let(("_utm_0", "a"), ("_utm_1", "b"))( + im.make_tuple( + im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), + im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), + ) + ), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_nested(): + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("t", i_nested_tuple_field)], + im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), + out_type=i_nested_tuple_field, + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", i_nested_tuple_field)], + im.let("_utm_0", "t")( + im.make_tuple( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))), + ), + im.call(_neg())(im.tuple_get(1, "_utm_0")), + ) + ), + out_type=i_nested_tuple_field, + ) + assert result == expected