diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index e7dc855428..e8d1914aa8 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -661,29 +661,31 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: def _visit_if_branch_arg( self, - if_sdfg: dace.SDFG, - if_branch_state: dace.SDFGState, + sdfg: dace.SDFG, + state: dace.SDFGState, param_name: str, arg: IteratorExpr | DataExpr, deref_on_input_memlet: bool, - if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + input_memlets: dict[str, MemletExpr | ValueExpr], ) -> IteratorExpr | ValueExpr: """ Helper method to be called by `_visit_if_branch()` to visit the input arguments. Args: - if_sdfg: The nested SDFG where the if expression is lowered. - if_branch_state: The state inside the nested SDFG where the if branch is lowered. + sdfg: The SDFG where the if expression is lowered. + state: The state inside the given SDFG where the if branch is lowered. param_name: The parameter name of the input argument. arg: The input argument expression. deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. - if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function. """ use_full_shape = False if isinstance(arg, (MemletExpr, ValueExpr)): + field_dims = [] arg_desc = arg.dc_node.desc(self.sdfg) arg_expr = arg elif isinstance(arg, IteratorExpr): + field_dims = [dim for dim, _ in arg.field_domain] arg_desc = arg.field.desc(self.sdfg) if deref_on_input_memlet: # If the iterator is just dereferenced inside the branch state, @@ -707,26 +709,39 @@ def _visit_if_branch_arg( inner_desc = arg_desc.clone() inner_desc.transient = False elif isinstance(arg.gt_dtype, ts.ScalarType): - inner_desc = dace.data.Scalar(arg_desc.dtype) + if isinstance(arg, MemletExpr) and len(arg.gt_field.dims) == 1: + # TODO(edopao): we cannot use a scalar because of an issue in gpu codegen, + # which leads to compilation error: cannot convert 'const double' to 'const double*' + inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,)) + else: + inner_desc = dace.data.Scalar(arg_desc.dtype) else: # for list of values, we retrieve the local size from the corresponding offset - assert arg.gt_dtype.offset_type is not None - offset_provider_type = self.subgraph_builder.get_offset_provider_type( - arg.gt_dtype.offset_type.value + local_dim = arg.gt_dtype.offset_type + assert local_dim is not None + assert isinstance( + self.subgraph_builder.get_offset_provider_type(local_dim.value), + gtx_common.NeighborConnectivityType, ) - assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + # find position of the local dimension in the field layout + assert isinstance(arg_desc, dace.data.Array) + assert all(dim.kind != gtx_common.DimensionKind.LOCAL for dim in field_dims) + extended_dims = gtx_common.order_dimensions([*field_dims, local_dim]) + local_dim_pos = extended_dims.index(local_dim) inner_desc = dace.data.Array( - dtype=arg_desc.dtype, shape=[offset_provider_type.max_neighbors] + dtype=arg_desc.dtype, + shape=(arg_desc.shape[local_dim_pos],), + strides=(arg_desc.strides[local_dim_pos],), ) - if param_name in if_sdfg.arrays: + if param_name in sdfg.arrays: # the data desciptor was added by the visitor of the other branch expression - assert if_sdfg.data(param_name) == inner_desc + assert sdfg.data(param_name) == inner_desc else: - if_sdfg.add_datadesc(param_name, inner_desc) - if_sdfg_input_memlets[param_name] = arg_expr + sdfg.add_datadesc(param_name, inner_desc) + input_memlets[param_name] = arg_expr - inner_node = if_branch_state.add_access(param_name) + inner_node = state.add_access(param_name) if isinstance(arg, IteratorExpr) and use_full_shape: return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) else: @@ -734,10 +749,10 @@ def _visit_if_branch_arg( def _visit_if_branch( self, - if_sdfg: dace.SDFG, - if_branch_state: dace.SDFGState, + sdfg: dace.SDFG, + state: dace.SDFGState, expr: gtir.Expr, - if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + input_memlets: dict[str, MemletExpr | ValueExpr], direct_deref_iterators: Iterable[str], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ @@ -746,10 +761,10 @@ def _visit_if_branch( This function is called by `_visit_if()` for each if-branch. Args: - if_sdfg: The nested SDFG where the if expression is lowered. - if_branch_state: The state inside the nested SDFG where the if branch is lowered. + sdfg: The SDFG where the if expression is lowered. + state: The state inside the given SDFG where the if branch is lowered. expr: The if branch expression to lower. - if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + input_memlets: The memlets that provide input data to the SDFG, will be updated inside this function. direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift. Returns: @@ -757,13 +772,17 @@ def _visit_if_branch( - the list of input edges for the parent dataflow - the tree representation of output data, in the form of a tuple of data edges. """ - assert if_branch_state in if_sdfg.states() + assert state in sdfg.states() - lambda_args = [] - lambda_params = [] + lambda_args: list[MaybeNestedInTuple[IteratorExpr | DataExpr]] = [] + lambda_params: list[gtir.Sym] = [] for pname in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): arg = self.symbol_map[pname] - if isinstance(arg, tuple): + inner_arg: MaybeNestedInTuple[IteratorExpr | DataExpr] + if isinstance(arg, SymbolExpr): + psymbol = im.sym(pname, gtx_dace_args.as_itir_type(arg.dc_dtype)) + inner_arg = arg + elif isinstance(arg, tuple): ptype = get_tuple_type(arg) # type: ignore[arg-type] psymbol = im.sym(pname, ptype) psymbol_tree = gtir_to_sdfg_utils.make_symbol_tree(pname, ptype) @@ -771,25 +790,25 @@ def _visit_if_branch( inner_arg = gtx_utils.tree_map( lambda tsym, targ, deref_on_input_memlet=deref_on_input_memlet: ( self._visit_if_branch_arg( - if_sdfg, - if_branch_state, + sdfg, + state, str(tsym.id), targ, deref_on_input_memlet, - if_sdfg_input_memlets, + input_memlets, ) ) )(psymbol_tree, arg) else: - psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr] + psymbol = im.sym(pname, arg.gt_dtype) deref_on_input_memlet = pname in direct_deref_iterators inner_arg = self._visit_if_branch_arg( - if_sdfg, - if_branch_state, + sdfg, + state, pname, arg, deref_on_input_memlet, - if_sdfg_input_memlets, + input_memlets, ) lambda_args.append(inner_arg) lambda_params.append(psymbol) @@ -797,7 +816,7 @@ def _visit_if_branch( # visit each branch of the if-statement as if it was a Lambda node lambda_node = gtir.Lambda(params=lambda_params, expr=expr) input_edges, output_tree = translate_lambda_to_dataflow( - if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args + sdfg, state, self.subgraph_builder, lambda_node, lambda_args ) return input_edges, output_tree @@ -818,9 +837,16 @@ def _visit_if_branch_result( # If the result is currently written to a transient node, inside the nested SDFG, # we need to allocate a non-transient data node. result_desc = edge.result.dc_node.desc(sdfg) - output_desc = result_desc.clone() - output_desc.transient = False - output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) + if isinstance(sym.type, ts.ScalarType) and isinstance(result_desc, dace.data.Array): + # TODO(edopao): a scalar should not be represented as an array, but + # currently this can happen because of an issue workaround, see the + # todo comment above in `_visit_if_branch_arg()`. + assert len(result_desc.shape) == 1 and result_desc.shape[0] == 1 + _, output_desc = sdfg.add_scalar(output_data, result_desc.dtype) + else: + output_desc = result_desc.clone() + output_desc.transient = False + output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) output_node = state.add_access(output_data) state.add_nedge( edge.result.dc_node, @@ -868,6 +894,13 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp nsdfg = dace.SDFG(self.subgraph_builder.unique_nsdfg_name("if_stmt")) nsdfg.debuginfo = gtir_to_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo) + # add connectivities + for aname, adesc in self.sdfg.arrays.items(): + if gtx_dace_args.is_connectivity_identifier(aname): + adesc = adesc.clone() + adesc.transient = True + nsdfg.add_datadesc(aname, adesc) + # create states inside the nested SDFG for the if-branches if_region = dace.sdfg.state.ConditionalBlock("if") nsdfg.add_node(if_region, ensure_unique_name=True) @@ -948,14 +981,21 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} + # map the connectivities that were used inside the nested SDFG + used_connectivities = [ + aname + for aname, adesc in nsdfg.arrays.items() + if gtx_dace_args.is_connectivity_identifier(aname) and not adesc.transient + ] + # all free symbols are mapped to the symbols available in parent SDFG nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} if isinstance(condition_value, SymbolExpr): nsdfg_symbols_mapping["__cond"] = condition_value.value nsdfg_node = self.state.add_nested_sdfg( nsdfg, - inputs=set(input_memlets.keys()), - outputs=outputs, + inputs=sorted(used_connectivities | input_memlets.keys()), + outputs=sorted(outputs), symbol_mapping=nsdfg_symbols_mapping, ) @@ -971,6 +1011,16 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp self.sdfg.make_array_memlet(input_expr.dc_node.data), ) + for conn in used_connectivities: + desc = self.sdfg.data(conn) + desc.transient = False + self._add_input_data_edge( + self.state.add_access(conn), + dace_subsets.Range.from_array(desc), + nsdfg_node, + conn, + ) + return ( gtx_utils.tree_map(write_output_of_nested_sdfg_to_temporary)(result) if isinstance(result, tuple) @@ -1904,6 +1954,9 @@ def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = gtx_dace_args.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) + def visit_OffsetLiteral(self, node: gtir.OffsetLiteral) -> SymbolExpr: + return SymbolExpr(node.value, gtir_to_sdfg_types.INDEX_DTYPE) + def visit_SymRef(self, node: gtir.SymRef) -> MaybeNestedInTuple[IteratorExpr | DataExpr]: param = str(node.id) if param in self.symbol_map: @@ -1918,7 +1971,7 @@ def translate_lambda_to_dataflow( state: dace.SDFGState, sdfg_builder: gtir_to_sdfg.DataflowBuilder, node: gtir.Lambda, - args: Sequence[MaybeNestedInTuple[IteratorExpr | MemletExpr | ValueExpr]], + args: Sequence[MaybeNestedInTuple[IteratorExpr | DataExpr]], ) -> tuple[list[DataflowInputEdge], MaybeNestedInTuple[DataflowOutputEdge]]: """ Entry point to visit a `Lambda` node and lower it to a dataflow graph,