From 1b4707cfd1d6b5eb9623f3e51aa31c8d4696655c Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 20 Feb 2026 11:41:56 +0100 Subject: [PATCH 1/7] Tracer prototype --- src/gt4py/next/ffront/field_operator_ast.py | 12 +++++ .../ffront/foast_passes/type_deduction.py | 26 +++++++++- src/gt4py/next/ffront/foast_pretty_printer.py | 2 + src/gt4py/next/ffront/foast_to_gtir.py | 9 ++++ src/gt4py/next/ffront/foast_to_past.py | 8 ++-- src/gt4py/next/ffront/func_to_foast.py | 31 ++++++++++-- .../next/ffront/past_passes/type_deduction.py | 2 +- src/gt4py/next/iterator/builtins.py | 3 +- .../next/iterator/transforms/pass_manager.py | 3 ++ .../iterator/transforms/unroll_map_tuple.py | 47 +++++++++++++++++++ .../iterator/type_system/type_synthesizer.py | 13 +++++ src/gt4py/next/type_system/type_info.py | 8 ++++ .../next/type_system/type_specifications.py | 9 ++++ .../next/type_system/type_translation.py | 22 +++++++-- tests/next_tests/integration_tests/cases.py | 11 +++++ .../ffront_tests/test_execution.py | 30 ++++++++++++ 16 files changed, 222 insertions(+), 14 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/unroll_map_tuple.py diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index fa5bc4889f..f4950ea402 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -113,6 +113,18 @@ class TupleExpr(Expr): elts: list[Expr] +# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals +# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name. +class TupleComprehension(Expr): + """ + tuple(element_expr for target in iterable) + """ + + element_expr: Expr + target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`? + iterable: Expr + + class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 68bf108a0a..e545f9e002 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +import collections from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast @@ -501,6 +501,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex new_type = types[index] + case ts.VarArgType(element_type=element_type): + new_type = ( + element_type # TODO: we only temporarily allow any index for vararg types + ) case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( @@ -747,6 +751,26 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx new_type = ts.TupleType(types=[element.type for element in new_elts]) return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location) + def visit_TupleComprehension( + self, node: foast.TupleComprehension, **kwargs: Any + ) -> foast.TupleComprehension: + symtable: collections.ChainMap = kwargs["symtable"] # todo annotation + iterable = self.visit(node.iterable, **kwargs) + target = self.visit(node.target, **kwargs) + assert isinstance(iterable.type, ts.VarArgType) + target.type = iterable.type.element_type + element_expr = self.visit( + node.element_expr, + **{**kwargs, "symtable": symtable.new_child({node.target.id: target})}, + ) + return foast.TupleComprehension( + element_expr=element_expr, + target=target, + iterable=iterable, + location=node.location, + type=ts.VarArgType(element_type=element_expr.type), + ) + def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_func = self.visit(node.func, **kwargs) new_args = self.visit(node.args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 8b2e369501..77495d78f7 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -118,6 +118,8 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})") + TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))") + UnaryOp = as_fmt("{op}{operand}") def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..2e587c346e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -257,6 +257,15 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) + def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: + return im.call( + im.call("map_tuple")( + im.lambda_(self.visit(node.target, **kwargs))( + self.visit(node.element_expr, **kwargs) + ) + ) + )(self.visit(node.iterable, **kwargs)) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 05b080b70b..c37cba5a78 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.type_system import type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: *partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - assert arg_types[-1] == type_info.return_type( - type_, with_args=list(arg_types), with_kwargs=kwarg_types - ) + # assert arg_types[-1] == type_info.return_type( + # type_, with_args=list(arg_types), with_kwargs=kwarg_types + # ) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ced0ff3905..adefa7ba9e 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -337,7 +337,12 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: - return foast.Name(id=node.id, location=self.get_location(node)) + loc = self.get_location(node) + if isinstance(node.ctx, ast.Store): + return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) + else: + assert isinstance(node.ctx, ast.Load) + return foast.Name(id=node.id, location=loc) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: return foast.UnaryOp( @@ -469,8 +474,10 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - if len(node.args) > 0: - arg = node.args[0] + (arg,) = ( + node.args + ) # note for review: the change here is unrelated to the actual pr and just a small cleanup + if node.func.id == "tuple": if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) @@ -484,9 +491,25 @@ def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: - # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if isinstance(node.func, ast.Name): func_name = self._func_name(node) + + if func_name == "tuple": + (gen_expr,) = node.args + assert ( + len(gen_expr.generators) == 1 + ) # we don't support (... for ... in ... for ... in ...) + assert ( + gen_expr.generators[0].ifs == [] + ) # we don't support if conditions in comprehensions + return foast.TupleComprehension( + element_expr=self.visit(gen_expr.elt, **kwargs), + target=self.visit(gen_expr.generators[0].target, **kwargs), + iterable=self.visit(gen_expr.generators[0].iter, **kwargs), + location=self.get_location(node), + ) + + # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if func_name in fbuiltins.TYPE_BUILTIN_NAMES: self._verify_builtin_type_constructor(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9d021ceb51..530d407459 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: operator_return_type = type_info.return_type( new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) - if operator_return_type != new_kwargs["out"].type: + if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type): raise ValueError( "Expected keyword argument 'out' to be of " f"type '{operator_return_type}', got " diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e54c6ea3d7..7b24c91884 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -498,7 +498,8 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_", + "map_tuple", + "map_", # TODO: rename to map_list "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 08ca9d94e0..4102790129 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,6 +24,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_map_tuple, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -179,6 +180,7 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = infer_domain.infer_program( ir, @@ -293,6 +295,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py new file mode 100644 index 0000000000..66f96d66fa --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class UnrollMapTuple(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "map_tuple"): + # TODO: we have to duplicate the function here since the domain inference can not handle them yet + f = node.fun.args[0] + tup = node.args[0] + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_ref = next(self.uids["_ump"]) + + result = im.let(tup_ref, tup)( + im.make_tuple( + *(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types))) + ) + ) + itir_inference.reinfer(result) + + return result + return node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6d77c70375..4406dd9aa8 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -633,6 +633,19 @@ def applied_map( return applied_map +@_register_builtin_type_synthesizer +def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + arg: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + return ts.TupleType( + types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types] + ) + + return applied_map + + @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @type_synthesizer diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..69fccd33da 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -566,6 +566,14 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: or issubclass(type_class(to_type), symbol_type.constraint) ): return True + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType): + return is_concretizable(symbol_type.element_type, to_type.element_type) + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType): + if len(to_type.types) == 0 or ( + all(type_ == to_type.types[0] for type_ in to_type.types) + and is_concretizable(symbol_type.element_type, to_type.types[0]) + ): + return True elif is_concrete(symbol_type): return symbol_type == to_type return False diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 59ac40f0f3..409138d593 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -148,6 +148,15 @@ def __len__(self) -> int: return len(self.types) +class VarArgType(DataType): + """Represents a variable number of arguments of the same type.""" + + element_type: DataType # TODO: maybe also support different DataTypes + + def __str__(self) -> str: + return f"VarArg[{self.element_type}]" + + class AnyPythonType: """Marker type representing any Python type which cannot be used for instantiation. diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 0f145e04aa..0ca020625a 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -180,8 +180,12 @@ def from_type_hint( case builtins.tuple: if not args: raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if Ellipsis in args: - raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") + if len(args) == 2 and args[1] is Ellipsis: + return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) + elif Ellipsis in args: + raise ValueError( + f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." + ) tuple_types = [from_type_hint_same_ns(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) return ts.TupleType(types=tuple_types) @@ -321,7 +325,19 @@ def from_value(value: Any) -> ts.TypeSpec: return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - symbol_type = from_type_hint(type_) + if type_ == type[tuple]: + # TODO: this special casing here is not nice, but infer_type is also called on the annotations where + # we don't want to allow unparameterized tuples (or do we?). + symbol_type = ts.ConstructorType( + definition=ts.FunctionType( + pos_only_args=[ts.DeferredType(constraint=None)], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.DeferredType(constraint=ts.VarArgType), + ) + ) + else: + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 78e6c62781..e723c963de 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -603,6 +603,15 @@ def _allocate_from_type( for t in types ) ) + case ts.VarArgType(element_type=element_type): + return tuple( + ( + _allocate_from_type( + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy + ) + for t in [element_type] * 3 # TODO: revisit + ) + ) case ts.NamedCollectionType(types=types) as named_collection_type_spec: container_constructor = ( named_collections.make_named_collection_constructor_from_type_spec( @@ -648,6 +657,8 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> return sum([get_param_size(t, sizes=sizes) for t in types]) case ts.NamedCollectionType(types=types): return sum([get_param_size(t, sizes=sizes) for t in types]) + case ts.VarArgType(element_type=element_type): + return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit case _: raise TypeError(f"Can not get size for parameter of type '{param_type}'.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8060d5bb36..14f14b3ffb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -336,6 +336,36 @@ def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: ) +@pytest.mark.uses_tuple_args +def test_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, ...]: + return tuple(tracer * factor for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_tuple_vararg(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, cases.IFloatField]: + return tracers[0] * factor, tracers[1] * factor + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t[:2]), + ) + + @pytest.mark.uses_tuple_args @pytest.mark.xfail(reason="Iterator of tuple approach in lowering does not allow this.") def test_tuple_arg_with_unpromotable_dims(unstructured_case): From 02f881fef3276bd782eea4c98d414053b37987b1 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Apr 2026 15:48:50 +0200 Subject: [PATCH 2/7] Introduce GTIR tree_map builtin and transform to make_tuple, also supporting nesting (extracted from #2487) --- src/gt4py/next/ffront/field_operator_ast.py | 12 ---- .../ffront/foast_passes/type_deduction.py | 27 +------- src/gt4py/next/ffront/foast_pretty_printer.py | 2 - src/gt4py/next/ffront/foast_to_gtir.py | 30 ++------ src/gt4py/next/ffront/foast_to_past.py | 8 +-- src/gt4py/next/ffront/func_to_foast.py | 31 ++------- .../next/ffront/past_passes/type_deduction.py | 2 +- src/gt4py/next/iterator/builtins.py | 2 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 5 ++ .../next/iterator/transforms/pass_manager.py | 6 +- .../iterator/transforms/unroll_map_tuple.py | 47 ------------- .../iterator/transforms/unroll_tree_map.py | 69 +++++++++++++++++++ .../iterator/type_system/type_synthesizer.py | 18 +++-- src/gt4py/next/type_system/type_info.py | 8 --- .../next/type_system/type_specifications.py | 9 --- .../next/type_system/type_translation.py | 22 +----- tests/next_tests/integration_tests/cases.py | 11 --- .../ffront_tests/test_execution.py | 30 -------- .../ffront_tests/test_foast_to_gtir.py | 9 +-- .../transforms_tests/test_unroll_tree_map.py | 43 ++++++++++++ 20 files changed, 159 insertions(+), 232 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/unroll_map_tuple.py create mode 100644 src/gt4py/next/iterator/transforms/unroll_tree_map.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 8ee216b96b..95a2588077 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -123,18 +123,6 @@ class TupleExpr(Expr): elts: list[Expr] -# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals -# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name. -class TupleComprehension(Expr): - """ - tuple(element_expr for target in iterable) - """ - - element_expr: Expr - target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`? - iterable: Expr - - class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 2b33d54cca..11c0bfd88b 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,11 +5,10 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import collections + import textwrap from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast - import gt4py.next.ffront.field_operator_ast as foast from gt4py import eve from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -429,10 +428,6 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex new_type = types[index] - case ts.VarArgType(element_type=element_type): - new_type = ( - element_type # TODO: we only temporarily allow any index for vararg types - ) case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( @@ -679,26 +674,6 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx new_type = ts.TupleType(types=[element.type for element in new_elts]) return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location) - def visit_TupleComprehension( - self, node: foast.TupleComprehension, **kwargs: Any - ) -> foast.TupleComprehension: - symtable: collections.ChainMap = kwargs["symtable"] # todo annotation - iterable = self.visit(node.iterable, **kwargs) - target = self.visit(node.target, **kwargs) - assert isinstance(iterable.type, ts.VarArgType) - target.type = iterable.type.element_type - element_expr = self.visit( - node.element_expr, - **{**kwargs, "symtable": symtable.new_child({node.target.id: target})}, - ) - return foast.TupleComprehension( - element_expr=element_expr, - target=target, - iterable=iterable, - location=node.location, - type=ts.VarArgType(element_type=element_expr.type), - ) - def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_func = self.visit(node.func, **kwargs) new_args = self.visit(node.args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 77495d78f7..8b2e369501 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -118,8 +118,6 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})") - TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))") - UnaryOp = as_fmt("{op}{operand}") def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2e587c346e..78b0671db1 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -257,15 +257,6 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: - return im.call( - im.call("map_tuple")( - im.lambda_(self.visit(node.target, **kwargs))( - self.visit(node.element_expr, **kwargs) - ) - ) - )(self.visit(node.iterable, **kwargs)) - def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) @@ -421,23 +412,16 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) + true_ = self.visit(node.args[1]) + false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - def create_if( - true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] - ) -> itir.FunCall: - return _map( - "if_", - (im.ref(cond_symref_name), true_, false_), - (node.args[0].type, *arg_types), + # tree_map(lambda a, b: as_fieldop(if_)(cond_ref, a, b))(true_tup, false_tup) + result = im.tree_map( + im.lambda_("__a", "__b")( + im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) ) - - result = lowering_utils.process_elements( - create_if, - (self.visit(node.args[1]), self.visit(node.args[2])), - node.type, - arg_types=(node.args[1].type, node.args[2].type), - ) + )(true_, false_) return im.let(cond_symref_name, cond_)(result) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index c37cba5a78..05b080b70b 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info, type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: *partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - # assert arg_types[-1] == type_info.return_type( - # type_, with_args=list(arg_types), with_kwargs=kwarg_types - # ) + assert arg_types[-1] == type_info.return_type( + type_, with_args=list(arg_types), with_kwargs=kwarg_types + ) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index adefa7ba9e..ced0ff3905 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -337,12 +337,7 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: - loc = self.get_location(node) - if isinstance(node.ctx, ast.Store): - return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) - else: - assert isinstance(node.ctx, ast.Load) - return foast.Name(id=node.id, location=loc) + return foast.Name(id=node.id, location=self.get_location(node)) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: return foast.UnaryOp( @@ -474,10 +469,8 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - (arg,) = ( - node.args - ) # note for review: the change here is unrelated to the actual pr and just a small cleanup - if node.func.id == "tuple": + if len(node.args) > 0: + arg = node.args[0] if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) @@ -491,25 +484,9 @@ def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: + # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if isinstance(node.func, ast.Name): func_name = self._func_name(node) - - if func_name == "tuple": - (gen_expr,) = node.args - assert ( - len(gen_expr.generators) == 1 - ) # we don't support (... for ... in ... for ... in ...) - assert ( - gen_expr.generators[0].ifs == [] - ) # we don't support if conditions in comprehensions - return foast.TupleComprehension( - element_expr=self.visit(gen_expr.elt, **kwargs), - target=self.visit(gen_expr.generators[0].target, **kwargs), - iterable=self.visit(gen_expr.generators[0].iter, **kwargs), - location=self.get_location(node), - ) - - # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if func_name in fbuiltins.TYPE_BUILTIN_NAMES: self._verify_builtin_type_constructor(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 530d407459..9d021ceb51 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: operator_return_type = type_info.return_type( new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) - if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type): + if operator_return_type != new_kwargs["out"].type: raise ValueError( "Expected keyword argument 'out' to be of " f"type '{operator_return_type}', got " diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 7b24c91884..273dca847f 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -498,7 +498,7 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_tuple", + "tree_map", "map_", # TODO: rename to map_list "named_range", "neighbors", diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 4b30e878fe..9545525a94 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -629,6 +629,11 @@ def map_(op): return call(call("map_")(op)) +def tree_map(op): + """Create a `tree_map` call: tree_map(op)(tup1, tup2, ...).""" + return call(call("tree_map")(op)) + + def reduce(op, expr): """Create a `reduce` call.""" return call(call("reduce")(op, expr)) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8825ad00ed..5feba5be9e 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -23,7 +23,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, - unroll_map_tuple, + unroll_tree_map, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -177,7 +177,7 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) + ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) ir = infer_domain.infer_program( ir, @@ -292,7 +292,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) + ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py deleted file mode 100644 index 66f96d66fa..0000000000 --- a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py +++ /dev/null @@ -1,47 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause -import dataclasses - -from gt4py import eve -from gt4py.next import utils -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.type_system import inference as itir_inference -from gt4py.next.type_system import type_specifications as ts - - -@dataclasses.dataclass -class UnrollMapTuple(eve.NodeTranslator): - PRESERVED_ANNEX_ATTRS = ("domain",) - - uids: utils.IDGeneratorPool - - @classmethod - def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): - return cls(uids=uids).visit(program) - - def visit_FunCall(self, node: itir.Expr): - node = self.generic_visit(node) - - if cpm.is_call_to(node.fun, "map_tuple"): - # TODO: we have to duplicate the function here since the domain inference can not handle them yet - f = node.fun.args[0] - tup = node.args[0] - itir_inference.reinfer(tup) - assert isinstance(tup.type, ts.TupleType) - tup_ref = next(self.uids["_ump"]) - - result = im.let(tup_ref, tup)( - im.make_tuple( - *(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types))) - ) - ) - itir_inference.reinfer(result) - - return result - return node diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py new file mode 100644 index 0000000000..c341797216 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -0,0 +1,69 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +def _unroll( + f: itir.Expr, + tup_types: list[ts.TupleType], + tup_exprs: list[itir.Expr], +) -> itir.Expr: + """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" + n = len(tup_types[0].types) + + elements: list[itir.Expr] = [] + for i in range(n): + child_types = [t.types[i] for t in tup_types] + child_exprs = [im.tuple_get(i, e) for e in tup_exprs] + + if all(isinstance(ct, ts.TupleType) for ct in child_types): + elements.append(_unroll(f, child_types, child_exprs)) # type: ignore[arg-type] + else: + elements.append(im.call(f)(*child_exprs)) + + return im.make_tuple(*elements) + + +@dataclasses.dataclass +class UnrollTreeMap(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.Expr): + node = self.generic_visit(node) + + if not cpm.is_call_to(node.fun, "tree_map"): + return node + + f = node.fun.args[0] + tup_args = node.args + for tup in tup_args: + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + + tup_refs = [next(self.uids["_utm"]) for _ in tup_args] + body = _unroll(f, [tup.type for tup in tup_args], [im.ref(r) for r in tup_refs]) + + result = body + for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): + result = im.let(ref_name, tup)(result) + + itir_inference.reinfer(result) + return result diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 98f3540d91..02b7a52c3b 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -634,14 +634,22 @@ def applied_map( @_register_builtin_type_synthesizer -def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer: +def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( - arg: ts.TupleType, offset_provider_type: common.OffsetProviderType + *args: ts.TupleType, offset_provider_type: common.OffsetProviderType ) -> ts.TupleType: - return ts.TupleType( - types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types] - ) + def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: + if isinstance(arg_types[0], ts.TupleType): + return ts.TupleType( + types=[ + _recurse(*(a.types[i] for a in arg_types)) # type: ignore[union-attr] + for i in range(len(arg_types[0].types)) # type: ignore[union-attr] + ] + ) + return op(*arg_types, offset_provider_type=offset_provider_type) + + return _recurse(*args) return applied_map diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 69fccd33da..eb70d15947 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -566,14 +566,6 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: or issubclass(type_class(to_type), symbol_type.constraint) ): return True - if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType): - return is_concretizable(symbol_type.element_type, to_type.element_type) - if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType): - if len(to_type.types) == 0 or ( - all(type_ == to_type.types[0] for type_ in to_type.types) - and is_concretizable(symbol_type.element_type, to_type.types[0]) - ): - return True elif is_concrete(symbol_type): return symbol_type == to_type return False diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 409138d593..59ac40f0f3 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -148,15 +148,6 @@ def __len__(self) -> int: return len(self.types) -class VarArgType(DataType): - """Represents a variable number of arguments of the same type.""" - - element_type: DataType # TODO: maybe also support different DataTypes - - def __str__(self) -> str: - return f"VarArg[{self.element_type}]" - - class AnyPythonType: """Marker type representing any Python type which cannot be used for instantiation. diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 1d7a9aa2f7..3671c5b344 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -181,12 +181,8 @@ def from_type_hint( case builtins.tuple: if not args: raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if len(args) == 2 and args[1] is Ellipsis: - return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) - elif Ellipsis in args: - raise ValueError( - f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." - ) + if Ellipsis in args: + raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") tuple_types = [from_type_hint_same_ns(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) return ts.TupleType(types=tuple_types) @@ -330,19 +326,7 @@ def from_value(value: Any) -> ts.TypeSpec: return NamespaceProxy(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - if type_ == type[tuple]: - # TODO: this special casing here is not nice, but infer_type is also called on the annotations where - # we don't want to allow unparameterized tuples (or do we?). - symbol_type = ts.ConstructorType( - definition=ts.FunctionType( - pos_only_args=[ts.DeferredType(constraint=None)], - pos_or_kw_args={}, - kw_only_args={}, - returns=ts.DeferredType(constraint=ts.VarArgType), - ) - ) - else: - symbol_type = from_type_hint(type_) + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d65ecefb10..d552a09a2a 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -603,15 +603,6 @@ def _allocate_from_type( for t in types ) ) - case ts.VarArgType(element_type=element_type): - return tuple( - ( - _allocate_from_type( - case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy - ) - for t in [element_type] * 3 # TODO: revisit - ) - ) case ts.NamedCollectionType(types=types) as named_collection_type_spec: container_constructor = ( named_collections.make_named_collection_constructor_from_type_spec( @@ -657,8 +648,6 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> return sum([get_param_size(t, sizes=sizes) for t in types]) case ts.NamedCollectionType(types=types): return sum([get_param_size(t, sizes=sizes) for t in types]) - case ts.VarArgType(element_type=element_type): - return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit case _: raise TypeError(f"Can not get size for parameter of type '{param_type}'.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 986fa5f5cb..c58ac5f497 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -338,36 +338,6 @@ def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: ) -@pytest.mark.uses_tuple_args -def test_tuple_comprehension(cartesian_case): - @gtx.field_operator - def testee( - tracers: tuple[cases.IFloatField, ...], factor: float - ) -> tuple[cases.IFloatField, ...]: - return tuple(tracer * factor for tracer in tracers) - - cases.verify_with_default_data( - cartesian_case, - testee, - ref=lambda t, f: tuple(el * f for el in t), - ) - - -@pytest.mark.uses_tuple_args -def test_tuple_vararg(cartesian_case): - @gtx.field_operator - def testee( - tracers: tuple[cases.IFloatField, ...], factor: float - ) -> tuple[cases.IFloatField, cases.IFloatField]: - return tracers[0] * factor, tracers[1] * factor - - cases.verify_with_default_data( - cartesian_case, - testee, - ref=lambda t, f: tuple(el * f for el in t[:2]), - ) - - @pytest.mark.uses_tuple_args @pytest.mark.xfail(reason="Iterator of tuple approach in lowering does not allow this.") def test_tuple_arg_with_unpromotable_dims(unstructured_case): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 8e3bba90b9..e516d7ddbd 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,10 +207,11 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.make_tuple( - im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")), - im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")), - ) + reference = im.tree_map( # TODO: check if this is what we want + im.lambda_("__a", "__b")( + im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b")) + ) + )("b", "c") assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py new file mode 100644 index 0000000000..9f1a379236 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -0,0 +1,43 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_tree_map import _unroll +from gt4py.next.type_system import type_specifications as ts + + +T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +TT = ts.TupleType(types=[T, T]) + + +def test_single_arg(): + result = _unroll(im.ref("f"), [TT], [im.ref("t")]) + expected = im.make_tuple(im.call("f")(im.tuple_get(0, "t")), im.call("f")(im.tuple_get(1, "t"))) + assert result == expected + + +def test_multi_arg(): + result = _unroll(im.ref("f"), [TT, TT], [im.ref("a"), im.ref("b")]) + expected = im.make_tuple( + im.call("f")(im.tuple_get(0, "a"), im.tuple_get(0, "b")), + im.call("f")(im.tuple_get(1, "a"), im.tuple_get(1, "b")), + ) + assert result == expected + + +def test_nested(): + outer = ts.TupleType(types=[TT, T]) + result = _unroll(im.ref("f"), [outer], [im.ref("t")]) + expected = im.make_tuple( + im.make_tuple( + im.call("f")(im.tuple_get(0, im.tuple_get(0, "t"))), + im.call("f")(im.tuple_get(1, im.tuple_get(0, "t"))), + ), + im.call("f")(im.tuple_get(1, "t")), + ) + assert result == expected From 0ec4692ac28ecd536e4756d5304a54b4b649a7fb Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Apr 2026 16:57:07 +0200 Subject: [PATCH 3/7] Run pre-commit and fix some tests --- src/gt4py/next/ffront/foast_to_gtir.py | 1 - src/gt4py/next/iterator/builtins.py | 5 +++++ .../next/iterator/transforms/unroll_tree_map.py | 9 ++++++--- .../next/iterator/type_system/type_synthesizer.py | 13 +++++++------ .../unit_tests/ffront_tests/test_foast_to_gtir.py | 6 ++---- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 78b0671db1..c341a311b1 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -416,7 +416,6 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - # tree_map(lambda a, b: as_fieldop(if_)(cond_ref, a, b))(true_tup, false_tup) result = im.tree_map( im.lambda_("__a", "__b")( im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 273dca847f..b60932eed8 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -57,6 +57,11 @@ def map_(*args): raise BackendNotSelectedError() +@builtin_dispatch +def tree_map(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def make_const_list(*args): raise BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py index c341797216..03d12355ec 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -29,7 +29,8 @@ def _unroll( child_exprs = [im.tuple_get(i, e) for e in tup_exprs] if all(isinstance(ct, ts.TupleType) for ct in child_types): - elements.append(_unroll(f, child_types, child_exprs)) # type: ignore[arg-type] + nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] + elements.append(_unroll(f, nested_types, child_exprs)) else: elements.append(im.call(f)(*child_exprs)) @@ -46,7 +47,7 @@ class UnrollTreeMap(eve.NodeTranslator): def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): return cls(uids=uids).visit(program) - def visit_FunCall(self, node: itir.Expr): + def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if not cpm.is_call_to(node.fun, "tree_map"): @@ -54,12 +55,14 @@ def visit_FunCall(self, node: itir.Expr): f = node.fun.args[0] tup_args = node.args + tup_types: list[ts.TupleType] = [] for tup in tup_args: itir_inference.reinfer(tup) assert isinstance(tup.type, ts.TupleType) + tup_types.append(tup.type) tup_refs = [next(self.uids["_utm"]) for _ in tup_args] - body = _unroll(f, [tup.type for tup in tup_args], [im.ref(r) for r in tup_refs]) + body = _unroll(f, tup_types, [im.ref(r) for r in tup_refs]) result = body for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 02b7a52c3b..4d5fe5e6d0 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -633,23 +633,24 @@ def applied_map( return applied_map -@_register_builtin_type_synthesizer -def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: +@_register_builtin_type_synthesizer(fun_names=["tree_map"]) +def _tree_map(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( *args: ts.TupleType, offset_provider_type: common.OffsetProviderType ) -> ts.TupleType: def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: if isinstance(arg_types[0], ts.TupleType): + tup_types = [a for a in arg_types if isinstance(a, ts.TupleType)] return ts.TupleType( types=[ - _recurse(*(a.types[i] for a in arg_types)) # type: ignore[union-attr] - for i in range(len(arg_types[0].types)) # type: ignore[union-attr] + _recurse(*(a.types[i] for a in tup_types)) + for i in range(len(arg_types[0].types)) ] ) - return op(*arg_types, offset_provider_type=offset_provider_type) + return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] - return _recurse(*args) + return _recurse(*args) # type: ignore[return-value] return applied_map diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index e516d7ddbd..726be8051b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,10 +207,8 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.tree_map( # TODO: check if this is what we want - im.lambda_("__a", "__b")( - im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b")) - ) + reference = im.tree_map( # TODO: check if this is what we want + im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) )("b", "c") assert lowered_inlined.expr == reference From ab84ecc0597d35a0c06724f67d1990c957858768 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Apr 2026 12:42:37 +0200 Subject: [PATCH 4/7] Run CollapseTuple after UnrollTreeMap --- .../next/iterator/transforms/pass_manager.py | 30 +++++++++++++++++++ .../ffront_tests/test_foast_to_gtir.py | 2 +- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 5feba5be9e..6b98ccbae6 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -179,6 +179,23 @@ def apply_common_transforms( ir = concat_where.canonicalize_domain_argument(ir) ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that + # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements + # (which would receive NEVER domain). Do multiple iterations for nested `let`s. + for _ in range(10): + collapsed = ir + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + if ir == collapsed: + break + ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -293,6 +310,19 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + for _ in range(10): + prev = ir + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + if ir == prev: + break ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 726be8051b..bf2978a8f2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,7 +207,7 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.tree_map( # TODO: check if this is what we want + reference = im.tree_map( im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) )("b", "c") From 152300e75e46def284a98bf8a19bca5dc35435de Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Apr 2026 15:27:33 +0200 Subject: [PATCH 5/7] Address review comments --- .../next/iterator/transforms/pass_manager.py | 5 ++ .../iterator/transforms/unroll_tree_map.py | 17 +++- .../iterator/type_system/type_synthesizer.py | 26 ++++-- .../iterator_tests/test_type_inference.py | 17 ++++ .../transforms_tests/test_unroll_tree_map.py | 85 +++++++++++++++---- 5 files changed, 124 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6b98ccbae6..32d847841b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -195,6 +195,8 @@ def apply_common_transforms( ) # type: ignore[assignment] # always an itir.Program if ir == collapsed: break + else: + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") ir = infer_domain.infer_program( ir, @@ -323,6 +325,9 @@ def apply_fieldview_transforms( ) # type: ignore[assignment] # always an itir.Program if ir == prev: break + else: + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py index 03d12355ec..8ef9dad7ec 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -21,18 +21,31 @@ def _unroll( tup_exprs: list[itir.Expr], ) -> itir.Expr: """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" + assert tup_types, "tree_map requires at least one tuple argument." n = len(tup_types[0].types) + if any(len(t.types) != n for t in tup_types[1:]): + raise ValueError( + f"All tree_map arguments must have the same tuple structure at each level, " + f"got {[len(t.types) for t in tup_types]}." + ) elements: list[itir.Expr] = [] for i in range(n): child_types = [t.types[i] for t in tup_types] child_exprs = [im.tuple_get(i, e) for e in tup_exprs] - if all(isinstance(ct, ts.TupleType) for ct in child_types): + all_tuples = all(isinstance(ct, ts.TupleType) for ct in child_types) + all_leaves = all(not isinstance(ct, ts.TupleType) for ct in child_types) + if all_tuples: nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] elements.append(_unroll(f, nested_types, child_exprs)) - else: + elif all_leaves: elements.append(im.call(f)(*child_exprs)) + else: + raise ValueError( + "All tree_map arguments must have the same tree structure " + "(all leaves must be reached simultaneously)." + ) return im.make_tuple(*elements) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 4d5fe5e6d0..38b9ce1bbf 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -639,16 +639,30 @@ def _tree_map(op: TypeSynthesizer) -> TypeSynthesizer: def applied_map( *args: ts.TupleType, offset_provider_type: common.OffsetProviderType ) -> ts.TupleType: + if not args: + raise TypeError("tree_map requires at least one argument.") + def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: - if isinstance(arg_types[0], ts.TupleType): + all_tuples = all(isinstance(a, ts.TupleType) for a in arg_types) + all_leaves = all(not isinstance(a, ts.TupleType) for a in arg_types) + if all_tuples: tup_types = [a for a in arg_types if isinstance(a, ts.TupleType)] + n = len(tup_types[0].types) + if any(len(t.types) != n for t in tup_types[1:]): + raise TypeError( + f"All tree_map arguments must have the same tuple structure at each level, " + f"got {[len(t.types) for t in tup_types]}." + ) return ts.TupleType( - types=[ - _recurse(*(a.types[i] for a in tup_types)) - for i in range(len(arg_types[0].types)) - ] + types=[_recurse(*(a.types[i] for a in tup_types)) for i in range(n)] + ) + elif all_leaves: + return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] + else: + raise TypeError( + "All tree_map arguments must have the same tree structure " + "(all leaves must be reached simultaneously)." ) - return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] return _recurse(*args) # type: ignore[return-value] diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d73fc1945f..f901a7ce39 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -155,6 +155,23 @@ def expression_test_cases(): ), ts.ListType(element_type=int_type, offset_type=V2EDim), ), + # tree_map + ( + im.tree_map(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, int_type])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + ts.TupleType(types=[int_type, int_type]), + ), + ( + im.tree_map(im.ref("not_"))( + im.ref( + "t", + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), + ), + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), # reduce (im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type), ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py index 9f1a379236..0c0f348d88 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -6,38 +6,87 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common, utils +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.unroll_tree_map import _unroll +from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap from gt4py.next.type_system import type_specifications as ts - +IDim = common.Dimension("IDim") T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -TT = ts.TupleType(types=[T, T]) +i_field = ts.FieldType(dims=[IDim], dtype=T) +i_tuple_field = ts.TupleType(types=[i_field, i_field]) +i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field]) +i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) + + +def _make_program(params: list[itir.Sym], expr: itir.Expr) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[*params, im.sym("out", i_field)], + declarations=[], + body=[ + itir.SetAt( + expr=expr, + domain=i_domain, + target=im.ref("out"), + ) + ], + ) + + +def _neg(): + return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) -def test_single_arg(): - result = _unroll(im.ref("f"), [TT], [im.ref("t")]) - expected = im.make_tuple(im.call("f")(im.tuple_get(0, "t")), im.call("f")(im.tuple_get(1, "t"))) - assert result == expected + +def _plus(): + return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) def test_multi_arg(): - result = _unroll(im.ref("f"), [TT, TT], [im.ref("a"), im.ref("b")]) - expected = im.make_tuple( - im.call("f")(im.tuple_get(0, "a"), im.tuple_get(0, "b")), - im.call("f")(im.tuple_get(1, "a"), im.tuple_get(1, "b")), + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.call(im.call("tree_map")(_plus()))( + im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) + ), + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.let("_utm_0", "a")( + im.let("_utm_1", "b")( + im.make_tuple( + im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), + im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), + ) + ) + ), ) assert result == expected def test_nested(): - outer = ts.TupleType(types=[TT, T]) - result = _unroll(im.ref("f"), [outer], [im.ref("t")]) - expected = im.make_tuple( - im.make_tuple( - im.call("f")(im.tuple_get(0, im.tuple_get(0, "t"))), - im.call("f")(im.tuple_get(1, im.tuple_get(0, "t"))), + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("t", i_nested_tuple_field)], + im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", i_nested_tuple_field)], + im.let("_utm_0", "t")( + im.make_tuple( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))), + ), + im.call(_neg())(im.tuple_get(1, "_utm_0")), + ) ), - im.call("f")(im.tuple_get(1, "t")), ) assert result == expected From d459b0e7a05aa42c18878f3fe10c2ffecd3a6239 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Apr 2026 16:00:38 +0200 Subject: [PATCH 6/7] Address further review comments --- .../next/iterator/type_system/type_synthesizer.py | 5 +++++ .../transforms_tests/test_unroll_tree_map.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 38b9ce1bbf..209d6ecf80 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -641,6 +641,11 @@ def applied_map( ) -> ts.TupleType: if not args: raise TypeError("tree_map requires at least one argument.") + if not all(isinstance(a, ts.TupleType) for a in args): + raise TypeError( + "tree_map requires all top-level arguments to be TupleType, " + f"got {[type(a).__name__ for a in args]}." + ) def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: all_tuples = all(isinstance(a, ts.TupleType) for a in arg_types) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py index 0c0f348d88..f0b4165f1a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -21,17 +21,19 @@ i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) -def _make_program(params: list[itir.Sym], expr: itir.Expr) -> itir.Program: +def _make_program( + params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field +) -> itir.Program: return itir.Program( id="testee", function_definitions=[], - params=[*params, im.sym("out", i_field)], + params=[*params, im.sym("out", out_type)], declarations=[], body=[ itir.SetAt( expr=expr, domain=i_domain, - target=im.ref("out"), + target=im.ref("out", out_type), ) ], ) @@ -52,6 +54,7 @@ def test_multi_arg(): im.call(im.call("tree_map")(_plus()))( im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) ), + out_type=i_tuple_field, ) result = UnrollTreeMap.apply(program, uids=uids) @@ -65,6 +68,7 @@ def test_multi_arg(): ) ) ), + out_type=i_tuple_field, ) assert result == expected @@ -74,6 +78,7 @@ def test_nested(): program = _make_program( [im.sym("t", i_nested_tuple_field)], im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), + out_type=i_nested_tuple_field, ) result = UnrollTreeMap.apply(program, uids=uids) @@ -88,5 +93,6 @@ def test_nested(): im.call(_neg())(im.tuple_get(1, "_utm_0")), ) ), + out_type=i_nested_tuple_field, ) assert result == expected From 97af81e1f3a91dc6a048371ba23d664735ea05a5 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 29 Apr 2026 15:24:01 +0200 Subject: [PATCH 7/7] Apply review comments --- .../iterator/transforms/unroll_tree_map.py | 55 ++++++------------- .../iterator/type_system/type_synthesizer.py | 36 ++++-------- .../transforms_tests/test_unroll_tree_map.py | 10 ++-- 3 files changed, 30 insertions(+), 71 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py index 8ef9dad7ec..d2b0b7b642 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause import dataclasses +import functools from gt4py import eve from gt4py.next import utils @@ -15,41 +16,6 @@ from gt4py.next.type_system import type_specifications as ts -def _unroll( - f: itir.Expr, - tup_types: list[ts.TupleType], - tup_exprs: list[itir.Expr], -) -> itir.Expr: - """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" - assert tup_types, "tree_map requires at least one tuple argument." - n = len(tup_types[0].types) - if any(len(t.types) != n for t in tup_types[1:]): - raise ValueError( - f"All tree_map arguments must have the same tuple structure at each level, " - f"got {[len(t.types) for t in tup_types]}." - ) - - elements: list[itir.Expr] = [] - for i in range(n): - child_types = [t.types[i] for t in tup_types] - child_exprs = [im.tuple_get(i, e) for e in tup_exprs] - - all_tuples = all(isinstance(ct, ts.TupleType) for ct in child_types) - all_leaves = all(not isinstance(ct, ts.TupleType) for ct in child_types) - if all_tuples: - nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] - elements.append(_unroll(f, nested_types, child_exprs)) - elif all_leaves: - elements.append(im.call(f)(*child_exprs)) - else: - raise ValueError( - "All tree_map arguments must have the same tree structure " - "(all leaves must be reached simultaneously)." - ) - - return im.make_tuple(*elements) - - @dataclasses.dataclass class UnrollTreeMap(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) @@ -75,11 +41,22 @@ def visit_FunCall(self, node: itir.FunCall): tup_types.append(tup.type) tup_refs = [next(self.uids["_utm"]) for _ in tup_args] - body = _unroll(f, tup_types, [im.ref(r) for r in tup_refs]) - result = body - for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): - result = im.let(ref_name, tup)(result) + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + with_path_arg=True, + ) + def mapper(*args): + *_el_types, path = args + return im.call(f)( + *( + functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name)) + for ref_name in tup_refs + ) + ) + + result = im.let(*zip(tup_refs, tup_args))(mapper(*tup_types)) itir_inference.reinfer(result) return result diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 209d6ecf80..6437e23973 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -20,7 +20,6 @@ from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next.utils import tree_map def _type_synth_arg_cache_key(type_or_synth: TypeOrTypeSynthesizer) -> int: @@ -203,7 +202,7 @@ def if_( pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType ) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): - return tree_map( + return utils.tree_map( collection_type=ts.TupleType, result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), )(functools.partial(if_, pred))(true_branch, false_branch) @@ -633,8 +632,8 @@ def applied_map( return applied_map -@_register_builtin_type_synthesizer(fun_names=["tree_map"]) -def _tree_map(op: TypeSynthesizer) -> TypeSynthesizer: +@_register_builtin_type_synthesizer +def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( *args: ts.TupleType, offset_provider_type: common.OffsetProviderType @@ -647,29 +646,14 @@ def applied_map( f"got {[type(a).__name__ for a in args]}." ) - def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: - all_tuples = all(isinstance(a, ts.TupleType) for a in arg_types) - all_leaves = all(not isinstance(a, ts.TupleType) for a in arg_types) - if all_tuples: - tup_types = [a for a in arg_types if isinstance(a, ts.TupleType)] - n = len(tup_types[0].types) - if any(len(t.types) != n for t in tup_types[1:]): - raise TypeError( - f"All tree_map arguments must have the same tuple structure at each level, " - f"got {[len(t.types) for t in tup_types]}." - ) - return ts.TupleType( - types=[_recurse(*(a.types[i] for a in tup_types)) for i in range(n)] - ) - elif all_leaves: - return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] - else: - raise TypeError( - "All tree_map arguments must have the same tree structure " - "(all leaves must be reached simultaneously)." - ) + def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec: + return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] - return _recurse(*args) # type: ignore[return-value] + return utils.tree_map( # type: ignore[return-value] + leaf_op, + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), + )(*args) return applied_map diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py index f0b4165f1a..3462ef4084 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -60,12 +60,10 @@ def test_multi_arg(): expected = _make_program( [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], - im.let("_utm_0", "a")( - im.let("_utm_1", "b")( - im.make_tuple( - im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), - im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), - ) + im.let(("_utm_0", "a"), ("_utm_1", "b"))( + im.make_tuple( + im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), + im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), ) ), out_type=i_tuple_field,