From ae6db4e09bea081ec50b3e3559babc0585e6c614 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 6 May 2026 18:35:02 +0200 Subject: [PATCH 01/10] Fix stride of local dimension in lowering of if-expressions --- .../runners/dace/lowering/gtir_dataflow.py | 112 +++++++++++------- 1 file changed, 71 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index da590d84e0..a2fe77cba9 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -681,9 +681,11 @@ def _visit_if_branch_arg( """ use_full_shape = False if isinstance(arg, (MemletExpr, ValueExpr)): + field_dims = [] arg_desc = arg.dc_node.desc(self.sdfg) arg_expr = arg elif isinstance(arg, IteratorExpr): + field_dims = [dim for dim, _ in arg.field_domain] arg_desc = arg.field.desc(self.sdfg) if deref_on_input_memlet: # If the iterator is just dereferenced inside the branch state, @@ -710,13 +712,21 @@ def _visit_if_branch_arg( inner_desc = dace.data.Scalar(arg_desc.dtype) else: # for list of values, we retrieve the local size from the corresponding offset - assert arg.gt_dtype.offset_type is not None - offset_provider_type = self.subgraph_builder.get_offset_provider_type( - arg.gt_dtype.offset_type.value + local_dim = arg.gt_dtype.offset_type + assert local_dim is not None + assert isinstance( + self.subgraph_builder.get_offset_provider_type(local_dim.value), + gtx_common.NeighborConnectivityType, ) - assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + # find position of the local dimension in the field layout + assert isinstance(arg_desc, dace.data.Array) + assert all(dim.kind != gtx_common.DimensionKind.LOCAL for dim in field_dims) + extended_dims = gtx_common.order_dimensions([*field_dims, local_dim]) + local_dim_pos = extended_dims.index(local_dim) inner_desc = dace.data.Array( - dtype=arg_desc.dtype, shape=[offset_provider_type.max_neighbors] + dtype=arg_desc.dtype, + shape=(arg_desc.shape[local_dim_pos],), + strides=(arg_desc.strides[local_dim_pos],), ) if param_name in if_sdfg.arrays: @@ -732,7 +742,7 @@ def _visit_if_branch_arg( else: return ValueExpr(inner_node, arg.gt_dtype) - def _visit_if_branch( + def _lower_if_state( self, if_sdfg: dace.SDFG, if_branch_state: dace.SDFGState, @@ -741,9 +751,10 @@ def _visit_if_branch( direct_deref_iterators: Iterable[str], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ - Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. + Helper method to visit an expression and lower it to a dataflow inside the given nested SDFG and state. - This function is called by `_visit_if()` for each if-branch. + This function is called by `_visit_if()` for the entry state (evaulation of + if-condition) and for each if-branch. Args: if_sdfg: The nested SDFG where the if expression is lowered. @@ -759,10 +770,15 @@ def _visit_if_branch( """ assert if_branch_state in if_sdfg.states() - lambda_args = [] - lambda_params = [] + lambda_args: list[MaybeNestedInTuple[IteratorExpr | DataExpr]] = [] + lambda_params: list[gtir.Sym] = [] for pname in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): arg = self.symbol_map[pname] + if isinstance(arg, SymbolExpr): + psymbol = im.sym(pname, gtx_dace_args.as_itir_type(arg.dc_dtype)) + lambda_args.append(arg) + lambda_params.append(psymbol) + continue if isinstance(arg, tuple): ptype = get_tuple_type(arg) # type: ignore[arg-type] psymbol = im.sym(pname, ptype) @@ -781,7 +797,7 @@ def _visit_if_branch( ) )(psymbol_tree, arg) else: - psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr] + psymbol = im.sym(pname, arg.gt_dtype) deref_on_input_memlet = pname in direct_deref_iterators inner_arg = self._visit_if_branch_arg( if_sdfg, @@ -854,20 +870,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp assert len(node.args) == 3 - # evaluate the if-condition that will write to a boolean scalar node - condition_value = self.visit(node.args[0]) - assert ( - ( - isinstance(condition_value.gt_dtype, ts.ScalarType) - and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL - ) - if isinstance(condition_value, (MemletExpr, ValueExpr)) - else (condition_value.dc_dtype == dace.dtypes.bool_) - ) - nsdfg = dace.SDFG(self.subgraph_builder.unique_nsdfg_name("if_stmt")) nsdfg.debuginfo = gtir_to_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo) + # add connectivities + for aname, adesc in self.sdfg.arrays.items(): + if gtx_dace_args.is_connectivity_identifier(aname): + adesc = adesc.clone() + adesc.transient = True + nsdfg.add_datadesc(aname, adesc) + # create states inside the nested SDFG for the if-branches if_region = dace.sdfg.state.ConditionalBlock("if") nsdfg.add_node(if_region, ensure_unique_name=True) @@ -883,16 +895,6 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp # Use `None` for unconditional execution of else-branch, if the condition is not met. if_region.add_branch(None, else_body) - input_memlets: dict[str, MemletExpr | ValueExpr] = {} - nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None - - # define scalar or symbol for the condition value inside the nested SDFG - if isinstance(condition_value, SymbolExpr): - nsdfg.add_symbol("__cond", dace.dtypes.bool) - else: - nsdfg.add_scalar("__cond", dace.dtypes.bool) - input_memlets["__cond"] = condition_value - # Collect all field iterators that are shifted inside any of the then/else # branch expressions. Iterator shift expressions require the field argument # as iterator, therefore the corresponding array has to be passed with full @@ -902,7 +904,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp # be lowered outside the nested SDFG, so that just the local value (a scalar # or a list of values) is passed as input to the nested SDFG. shifted_iterator_symbols = set() - for branch_expr in node.args[1:3]: + for branch_expr in node.args: for shift_node in eve.walk_values(branch_expr).filter( lambda x: cpm.is_applied_shift(x) ): @@ -919,13 +921,28 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp if isinstance(sym_type, IteratorExpr) } direct_deref_iterators = ( - set(symbol_ref_utils.collect_symbol_refs(node.args[1:3], iterator_symbols)) + set(symbol_ref_utils.collect_symbol_refs(node.args, iterator_symbols)) - shifted_iterator_symbols ) + # collect all memlets that are needed as input to the nested SDFG + input_memlets: dict[str, MemletExpr | ValueExpr] = {} + + # evaluate the if-condition that will write to a boolean scalar node + in_edges, out_edge = self._lower_if_state( + nsdfg, entry_state, node.args[0], input_memlets, direct_deref_iterators + ) + for edge in in_edges: + edge.connect(map_entry=None) + assert isinstance(out_edge, DataflowOutputEdge) + condition_node = out_edge.result.dc_node + # write the boolean result to the '__cond' synbol on the interstate edge + nsdfg.add_symbol("__cond", dace.dtypes.bool) + nsdfg.out_edges(entry_state)[0].data.assignments["__cond"] = condition_node.data + for nstate, arg in zip([tstate, fstate], node.args[1:3]): # visit each if-branch in the corresponding state of the nested SDFG - in_edges, out_edges = self._visit_if_branch( + in_edges, out_edges = self._lower_if_state( nsdfg, nstate, arg, input_memlets, direct_deref_iterators ) for edge in in_edges: @@ -948,15 +965,18 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} - # all free symbols are mapped to the symbols available in parent SDFG - nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} - if isinstance(condition_value, SymbolExpr): - nsdfg_symbols_mapping["__cond"] = condition_value.value + # map the connectivities that were used inside the nested SDFG + used_connectivities = { + aname + for aname, adesc in nsdfg.arrays.items() + if gtx_dace_args.is_connectivity_identifier(aname) and not adesc.transient + } + nsdfg_node = self.state.add_nested_sdfg( nsdfg, - inputs=set(input_memlets.keys()), + inputs=(used_connectivities | input_memlets.keys()), outputs=outputs, - symbol_mapping=nsdfg_symbols_mapping, + symbol_mapping=None, # free symbols are mapped to symbols in the parent SDFG ) for inner, input_expr in input_memlets.items(): @@ -971,6 +991,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp self.sdfg.make_array_memlet(input_expr.dc_node.data), ) + for conn in used_connectivities: + desc = self.sdfg.data(conn) + desc.transient = False + self._add_input_data_edge( + self.state.add_access(conn), + dace_subsets.Range.from_array(desc), + nsdfg_node, + conn, + ) + return ( gtx_utils.tree_map(write_output_of_nested_sdfg_to_temporary)(result) if isinstance(result, tuple) From dbbfd526cf04d76c4342ba474215730c488a82a1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 6 May 2026 18:35:17 +0200 Subject: [PATCH 02/10] edit --- .../program_processors/runners/dace/lowering/gtir_dataflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index a2fe77cba9..6690ca21ab 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -751,7 +751,8 @@ def _lower_if_state( direct_deref_iterators: Iterable[str], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ - Helper method to visit an expression and lower it to a dataflow inside the given nested SDFG and state. + Helper method to visit each argument of an if-expression and lower it to + a dataflow gragh inside the given nested SDFG and state. This function is called by `_visit_if()` for the entry state (evaulation of if-condition) and for each if-branch. From 2cd7728233ea3577ef2d67ad2c2f6f4ea04178c5 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 6 May 2026 18:42:30 +0200 Subject: [PATCH 03/10] edit --- .../runners/dace/lowering/gtir_dataflow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 6690ca21ab..c74c282df6 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -1921,6 +1921,9 @@ def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = gtx_dace_args.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) + def visit_OffsetLiteral(self, node: gtir.OffsetLiteral) -> SymbolExpr: + return SymbolExpr(node.value, gtir_to_sdfg_types.INDEX_DTYPE) + def visit_SymRef(self, node: gtir.SymRef) -> MaybeNestedInTuple[IteratorExpr | DataExpr]: param = str(node.id) if param in self.symbol_map: @@ -1935,7 +1938,7 @@ def translate_lambda_to_dataflow( state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.DataflowBuilder, node: gtir.Lambda, - args: Sequence[MaybeNestedInTuple[IteratorExpr | MemletExpr | ValueExpr]], + args: Sequence[MaybeNestedInTuple[IteratorExpr | DataExpr]], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ Entry point to visit a `Lambda` node and lower it to a dataflow graph, From ecd57805e7fbf3e89c2c0e73fb1cb3fb3ce23bbf Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 7 May 2026 09:33:11 +0200 Subject: [PATCH 04/10] edit --- .../runners/dace/lowering/gtir_dataflow.py | 71 ++++++++++--------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index c74c282df6..0ba5583e51 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -659,25 +659,25 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: field_desc.dtype, deref_node, connector_mapping["val"] ) - def _visit_if_branch_arg( + def _lower_if_state_arg( self, - if_sdfg: dace.SDFG, - if_branch_state: dace.SDFGState, + sdfg: dace.SDFG, + state: dace.SDFGState, param_name: str, arg: IteratorExpr | DataExpr, deref_on_input_memlet: bool, - if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + input_memlets: dict[str, MemletExpr | ValueExpr], ) -> IteratorExpr | ValueExpr: """ - Helper method to be called by `_visit_if_branch()` to visit the input arguments. + Helper method to be called by `_lower_if_state()` to visit the input arguments. Args: - if_sdfg: The nested SDFG where the if expression is lowered. - if_branch_state: The state inside the nested SDFG where the if branch is lowered. + sdfg: The SDFG where the if expression is lowered. + state: The state inside the SDFG where the argument is used. param_name: The parameter name of the input argument. arg: The input argument expression. deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. - if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + input_memlets: The memlets that provide input data to the SDFG, which is updated inside this function. """ use_full_shape = False if isinstance(arg, (MemletExpr, ValueExpr)): @@ -709,7 +709,12 @@ def _visit_if_branch_arg( inner_desc = arg_desc.clone() inner_desc.transient = False elif isinstance(arg.gt_dtype, ts.ScalarType): - inner_desc = dace.data.Scalar(arg_desc.dtype) + if len(arg_desc.shape) == 1: + # TODO(edopao): cannot use a scalar because of an issue in gpu codegen, + # wich leads to compilation error: cannot convert 'const double' to 'const double*' + inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,)) + else: + inner_desc = dace.data.Scalar(arg_desc.dtype) else: # for list of values, we retrieve the local size from the corresponding offset local_dim = arg.gt_dtype.offset_type @@ -729,14 +734,14 @@ def _visit_if_branch_arg( strides=(arg_desc.strides[local_dim_pos],), ) - if param_name in if_sdfg.arrays: + if param_name in sdfg.arrays: # the data desciptor was added by the visitor of the other branch expression - assert if_sdfg.data(param_name) == inner_desc + assert sdfg.data(param_name) == inner_desc else: - if_sdfg.add_datadesc(param_name, inner_desc) - if_sdfg_input_memlets[param_name] = arg_expr + sdfg.add_datadesc(param_name, inner_desc) + input_memlets[param_name] = arg_expr - inner_node = if_branch_state.add_access(param_name) + inner_node = state.add_access(param_name) if isinstance(arg, IteratorExpr) and use_full_shape: return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) else: @@ -744,24 +749,24 @@ def _visit_if_branch_arg( def _lower_if_state( self, - if_sdfg: dace.SDFG, - if_branch_state: dace.SDFGState, + sdfg: dace.SDFG, + state: dace.SDFGState, expr: gtir.Expr, - if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + input_memlets: dict[str, MemletExpr | ValueExpr], direct_deref_iterators: Iterable[str], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ - Helper method to visit each argument of an if-expression and lower it to - a dataflow gragh inside the given nested SDFG and state. + Helper method to visit the subexpressions of an if-expression and lower them + to a dataflow graph inside the given nested SDFG and state. - This function is called by `_visit_if()` for the entry state (evaulation of + This function is called by `_visit_if()` for the entry state (evaluation of if-condition) and for each if-branch. Args: - if_sdfg: The nested SDFG where the if expression is lowered. - if_branch_state: The state inside the nested SDFG where the if branch is lowered. + sdfg: The SDFG where the if expression is lowered. + state: The state inside the SDFG where the if subexpression is lowered. expr: The if branch expression to lower. - if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + input_memlets: The memlets that provide input data to the SDFG, which are update inside this function. direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift. Returns: @@ -769,7 +774,7 @@ def _lower_if_state( - the list of input edges for the parent dataflow - the tree representation of output data, in the form of a tuple of data edges. """ - assert if_branch_state in if_sdfg.states() + assert state in sdfg.states() lambda_args: list[MaybeNestedInTuple[IteratorExpr | DataExpr]] = [] lambda_params: list[gtir.Sym] = [] @@ -787,26 +792,26 @@ def _lower_if_state( deref_on_input_memlet = pname in direct_deref_iterators inner_arg = gtx_utils.tree_map( lambda tsym, targ, deref_on_input_memlet=deref_on_input_memlet: ( - self._visit_if_branch_arg( - if_sdfg, - if_branch_state, + self._lower_if_state_arg( + sdfg, + state, str(tsym.id), targ, deref_on_input_memlet, - if_sdfg_input_memlets, + input_memlets, ) ) )(psymbol_tree, arg) else: psymbol = im.sym(pname, arg.gt_dtype) deref_on_input_memlet = pname in direct_deref_iterators - inner_arg = self._visit_if_branch_arg( - if_sdfg, - if_branch_state, + inner_arg = self._lower_if_state_arg( + sdfg, + state, pname, arg, deref_on_input_memlet, - if_sdfg_input_memlets, + input_memlets, ) lambda_args.append(inner_arg) lambda_params.append(psymbol) @@ -814,7 +819,7 @@ def _lower_if_state( # visit each branch of the if-statement as if it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) input_edges, output_tree = translate_lambda_to_dataflow( - if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args + sdfg, state, self.subgraph_builder, lambda_node, lambda_args ) return input_edges, output_tree From a8fa800672ef5095702e7463ec814254f76cef2e Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 7 May 2026 12:04:52 +0200 Subject: [PATCH 05/10] edit --- .../program_processors/runners/dace/lowering/gtir_dataflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 0ba5583e51..904c097935 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -709,7 +709,7 @@ def _lower_if_state_arg( inner_desc = arg_desc.clone() inner_desc.transient = False elif isinstance(arg.gt_dtype, ts.ScalarType): - if len(arg_desc.shape) == 1: + if isinstance(arg, MemletExpr) and len(arg_expr.gt_field.dims) == 1: # TODO(edopao): cannot use a scalar because of an issue in gpu codegen, # wich leads to compilation error: cannot convert 'const double' to 'const double*' inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,)) From 80c45e6432589964475d184f570be7604f087c15 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 7 May 2026 12:11:15 +0200 Subject: [PATCH 06/10] edit --- .../program_processors/runners/dace/lowering/gtir_dataflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 904c097935..b7eb66f849 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -709,7 +709,7 @@ def _lower_if_state_arg( inner_desc = arg_desc.clone() inner_desc.transient = False elif isinstance(arg.gt_dtype, ts.ScalarType): - if isinstance(arg, MemletExpr) and len(arg_expr.gt_field.dims) == 1: + if isinstance(arg, MemletExpr) and len(arg.gt_field.dims) == 1: # TODO(edopao): cannot use a scalar because of an issue in gpu codegen, # wich leads to compilation error: cannot convert 'const double' to 'const double*' inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,)) From ba91d43b05f45684afd85128277ce695f6bf3315 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 7 May 2026 13:34:00 +0200 Subject: [PATCH 07/10] edit --- .../runners/dace/lowering/gtir_dataflow.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index b7eb66f849..7373b5a560 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -673,11 +673,11 @@ def _lower_if_state_arg( Args: sdfg: The SDFG where the if expression is lowered. - state: The state inside the SDFG where the argument is used. - param_name: The parameter name of the input argument. - arg: The input argument expression. + state: The state inside the given SDFG where the argument is used. + param_name: The name of the input argument. + arg: The expression corresponding to the input argument. deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. - input_memlets: The memlets that provide input data to the SDFG, which is updated inside this function. + input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function. """ use_full_shape = False if isinstance(arg, (MemletExpr, ValueExpr)): @@ -710,8 +710,8 @@ def _lower_if_state_arg( inner_desc.transient = False elif isinstance(arg.gt_dtype, ts.ScalarType): if isinstance(arg, MemletExpr) and len(arg.gt_field.dims) == 1: - # TODO(edopao): cannot use a scalar because of an issue in gpu codegen, - # wich leads to compilation error: cannot convert 'const double' to 'const double*' + # TODO(edopao): we cannot use a scalar because of an issue in gpu codegen, + # which leads to compilation error: cannot convert 'const double' to 'const double*' inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,)) else: inner_desc = dace.data.Scalar(arg_desc.dtype) @@ -756,17 +756,17 @@ def _lower_if_state( direct_deref_iterators: Iterable[str], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ - Helper method to visit the subexpressions of an if-expression and lower them - to a dataflow graph inside the given nested SDFG and state. + Helper method to visit a subexpression inside an if-node and lower it + to a dataflow graph inside the given SDFG and state. This function is called by `_visit_if()` for the entry state (evaluation of - if-condition) and for each if-branch. + if-condition) and for each if-branch state. Args: sdfg: The SDFG where the if expression is lowered. - state: The state inside the SDFG where the if subexpression is lowered. - expr: The if branch expression to lower. - input_memlets: The memlets that provide input data to the SDFG, which are update inside this function. + state: The state inside the given SDFG where the subexpression is lowered. + expr: The subexpression to lower. + input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function. direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift. Returns: @@ -942,7 +942,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp edge.connect(map_entry=None) assert isinstance(out_edge, DataflowOutputEdge) condition_node = out_edge.result.dc_node - # write the boolean result to the '__cond' synbol on the interstate edge + # write the boolean result to the '__cond' symbol on the interstate edge nsdfg.add_symbol("__cond", dace.dtypes.bool) nsdfg.out_edges(entry_state)[0].data.assignments["__cond"] = condition_node.data From f56010746631697427e6c12edd211f04436dd969 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 7 May 2026 14:29:02 +0200 Subject: [PATCH 08/10] apply review comments --- .../runners/dace/lowering/gtir_dataflow.py | 85 ++++++++++--------- 1 file changed, 46 insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 7373b5a560..0e2d2142e3 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -659,7 +659,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: field_desc.dtype, deref_node, connector_mapping["val"] ) - def _lower_if_state_arg( + def _visit_if_branch_arg( self, sdfg: dace.SDFG, state: dace.SDFGState, @@ -669,13 +669,13 @@ def _lower_if_state_arg( input_memlets: dict[str, MemletExpr | ValueExpr], ) -> IteratorExpr | ValueExpr: """ - Helper method to be called by `_lower_if_state()` to visit the input arguments. + Helper method to be called by `_visit_if_branch()` to visit the input arguments. Args: sdfg: The SDFG where the if expression is lowered. - state: The state inside the given SDFG where the argument is used. - param_name: The name of the input argument. - arg: The expression corresponding to the input argument. + state: The state inside the given SDFG where the if branch is lowered. + param_name: The parameter name of the input argument. + arg: The input argument expression. deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function. """ @@ -747,7 +747,7 @@ def _lower_if_state_arg( else: return ValueExpr(inner_node, arg.gt_dtype) - def _lower_if_state( + def _visit_if_branch( self, sdfg: dace.SDFG, state: dace.SDFGState, @@ -756,16 +756,14 @@ def _lower_if_state( direct_deref_iterators: Iterable[str], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ - Helper method to visit a subexpression inside an if-node and lower it - to a dataflow graph inside the given SDFG and state. + Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. - This function is called by `_visit_if()` for the entry state (evaluation of - if-condition) and for each if-branch state. + This function is called by `_visit_if()` for each if-branch. Args: sdfg: The SDFG where the if expression is lowered. - state: The state inside the given SDFG where the subexpression is lowered. - expr: The subexpression to lower. + state: The state inside the given SDFG where the if branch is lowered. + expr: The if branch expression to lower. input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function. direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift. @@ -780,19 +778,18 @@ def _lower_if_state( lambda_params: list[gtir.Sym] = [] for pname in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): arg = self.symbol_map[pname] + inner_arg: MaybeNestedInTuple[IteratorExpr | DataExpr] if isinstance(arg, SymbolExpr): psymbol = im.sym(pname, gtx_dace_args.as_itir_type(arg.dc_dtype)) - lambda_args.append(arg) - lambda_params.append(psymbol) - continue - if isinstance(arg, tuple): + inner_arg = arg + elif isinstance(arg, tuple): ptype = get_tuple_type(arg) # type: ignore[arg-type] psymbol = im.sym(pname, ptype) psymbol_tree = gtir_to_sdfg_utils.make_symbol_tree(pname, ptype) deref_on_input_memlet = pname in direct_deref_iterators inner_arg = gtx_utils.tree_map( lambda tsym, targ, deref_on_input_memlet=deref_on_input_memlet: ( - self._lower_if_state_arg( + self._visit_if_branch_arg( sdfg, state, str(tsym.id), @@ -805,7 +802,7 @@ def _lower_if_state( else: psymbol = im.sym(pname, arg.gt_dtype) deref_on_input_memlet = pname in direct_deref_iterators - inner_arg = self._lower_if_state_arg( + inner_arg = self._visit_if_branch_arg( sdfg, state, pname, @@ -876,6 +873,17 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp assert len(node.args) == 3 + # evaluate the if-condition that will write to a boolean scalar node + condition_value = self.visit(node.args[0]) + assert ( + ( + isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + if isinstance(condition_value, (MemletExpr, ValueExpr)) + else (condition_value.dc_dtype == dace.dtypes.bool_) + ) + nsdfg = dace.SDFG(self.subgraph_builder.unique_nsdfg_name("if_stmt")) nsdfg.debuginfo = gtir_to_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo) @@ -901,6 +909,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp # Use `None` for unconditional execution of else-branch, if the condition is not met. if_region.add_branch(None, else_body) + input_memlets: dict[str, MemletExpr | ValueExpr] = {} + nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None + + # define scalar or symbol for the condition value inside the nested SDFG + if isinstance(condition_value, SymbolExpr): + nsdfg.add_symbol("__cond", dace.dtypes.bool) + else: + nsdfg.add_scalar("__cond", dace.dtypes.bool) + input_memlets["__cond"] = condition_value + # Collect all field iterators that are shifted inside any of the then/else # branch expressions. Iterator shift expressions require the field argument # as iterator, therefore the corresponding array has to be passed with full @@ -910,7 +928,7 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp # be lowered outside the nested SDFG, so that just the local value (a scalar # or a list of values) is passed as input to the nested SDFG. shifted_iterator_symbols = set() - for branch_expr in node.args: + for branch_expr in node.args[1:3]: for shift_node in eve.walk_values(branch_expr).filter( lambda x: cpm.is_applied_shift(x) ): @@ -927,28 +945,13 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp if isinstance(sym_type, IteratorExpr) } direct_deref_iterators = ( - set(symbol_ref_utils.collect_symbol_refs(node.args, iterator_symbols)) + set(symbol_ref_utils.collect_symbol_refs(node.args[1:3], iterator_symbols)) - shifted_iterator_symbols ) - # collect all memlets that are needed as input to the nested SDFG - input_memlets: dict[str, MemletExpr | ValueExpr] = {} - - # evaluate the if-condition that will write to a boolean scalar node - in_edges, out_edge = self._lower_if_state( - nsdfg, entry_state, node.args[0], input_memlets, direct_deref_iterators - ) - for edge in in_edges: - edge.connect(map_entry=None) - assert isinstance(out_edge, DataflowOutputEdge) - condition_node = out_edge.result.dc_node - # write the boolean result to the '__cond' symbol on the interstate edge - nsdfg.add_symbol("__cond", dace.dtypes.bool) - nsdfg.out_edges(entry_state)[0].data.assignments["__cond"] = condition_node.data - for nstate, arg in zip([tstate, fstate], node.args[1:3]): # visit each if-branch in the corresponding state of the nested SDFG - in_edges, out_edges = self._lower_if_state( + in_edges, out_edges = self._visit_if_branch( nsdfg, nstate, arg, input_memlets, direct_deref_iterators ) for edge in in_edges: @@ -978,11 +981,15 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp if gtx_dace_args.is_connectivity_identifier(aname) and not adesc.transient } + # all free symbols are mapped to the symbols available in parent SDFG + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + if isinstance(condition_value, SymbolExpr): + nsdfg_symbols_mapping["__cond"] = condition_value.value nsdfg_node = self.state.add_nested_sdfg( nsdfg, - inputs=(used_connectivities | input_memlets.keys()), - outputs=outputs, - symbol_mapping=None, # free symbols are mapped to symbols in the parent SDFG + inputs=sorted(used_connectivities | input_memlets.keys()), + outputs=sorted(outputs), + symbol_mapping=nsdfg_symbols_mapping, ) for inner, input_expr in input_memlets.items(): From c174e5ac10c50fef89445228a3a19cc4e59b280d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 8 May 2026 11:15:01 +0200 Subject: [PATCH 09/10] fix --- .../runners/dace/lowering/gtir_dataflow.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 0e2d2142e3..fe54255918 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -837,9 +837,16 @@ def _visit_if_branch_result( # If the result is currently written to a transient node, inside the nested SDFG, # we need to allocate a non-transient data node. result_desc = edge.result.dc_node.desc(sdfg) - output_desc = result_desc.clone() - output_desc.transient = False - output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) + if isinstance(sym.type, ts.ScalarType) and isinstance(result_desc, dace.data.Array): + # TODO(edopao): a scalar should not be represented as an array, but + # currently this can happen because of an issue workaround, see the + # todo comment above in `_visit_if_branch_arg()`. + assert len(result_desc.shape) == 1 and result_desc.shape[0] == 1 + _, output_desc = sdfg.add_scalar(output_data, result_desc.dtype) + else: + output_desc = result_desc.clone() + output_desc.transient = False + output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) output_node = state.add_access(output_data) state.add_nedge( edge.result.dc_node, From 5128b69a5e551f8d0b4c610e02de1525d0c83b6d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 12 May 2026 13:37:01 +0200 Subject: [PATCH 10/10] preserve order of connectivities --- .../program_processors/runners/dace/lowering/gtir_dataflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 41048593dc..e8d1914aa8 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -982,11 +982,11 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} # map the connectivities that were used inside the nested SDFG - used_connectivities = { + used_connectivities = [ aname for aname, adesc in nsdfg.arrays.items() if gtx_dace_args.is_connectivity_identifier(aname) and not adesc.transient - } + ] # all free symbols are mapped to the symbols available in parent SDFG nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols}