diff --git a/pyproject.toml b/pyproject.toml index e4a5609d93..9305a97f6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dace-cartesian = [ 'dace>=1.0.2' # refined in [tool.uv.sources] ] dace-next = [ - 'dace==43!2026.04.27' # uses custom index at 'https://github.com/GridTools/pypi' + 'dace==2.3.5' # uses custom index at 'https://github.com/GridTools/pypi' ] dev = [ {include-group = 'build'}, @@ -486,7 +486,7 @@ url = 'https://gridtools.github.io/pypi/' atlas4py = {index = "test.pypi"} dace = [ {git = "https://github.com/GridTools/dace", branch = "romanc/stree-v2", group = "dace-cartesian"}, - {index = "gridtools", group = "dace-next"} + {git = "https://github.com/philip-paul-mueller/dace", branch = "phimuell__new-gpu-codegen-dev", group = "dace-next"} ] # -- versioningit -- diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 1236a3209a..a25bceefde 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -138,6 +138,7 @@ def apply_common_transforms( unroll_reduce=False, common_subexpression_elimination=True, force_inline_lambda_args=False, + transform_concat_where_to_as_fieldop=True, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. 
symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None, @@ -186,7 +187,8 @@ def apply_common_transforms( ir = prune_empty_concat_where.prune_empty_concat_where(ir) ir = remove_broadcast.RemoveBroadcast.apply(ir) - ir = concat_where.transform_to_as_fieldop(ir) + if transform_concat_where_to_as_fieldop: + ir = concat_where.transform_to_as_fieldop(ir) for _ in range(10): inlined = ir @@ -258,6 +260,12 @@ def apply_common_transforms( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) + ir = infer_domain.infer_program( + ir, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + ) + assert isinstance(ir, itir.Program) return ir diff --git a/src/gt4py/next/program_processors/runners/dace/__init__.py b/src/gt4py/next/program_processors/runners/dace/__init__.py index 0bb2c40dc3..0406d85f25 100644 --- a/src/gt4py/next/program_processors/runners/dace/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/__init__.py @@ -12,6 +12,8 @@ make_dace_backend, run_dace_cpu, run_dace_cpu_cached, + run_dace_cpu_gt, + run_dace_cpu_gt_noopt, run_dace_cpu_noopt, run_dace_gpu, run_dace_gpu_cached, @@ -24,6 +26,8 @@ "make_dace_backend", "run_dace_cpu", "run_dace_cpu_cached", + "run_dace_cpu_gt", + "run_dace_cpu_gt_noopt", "run_dace_cpu_noopt", "run_dace_gpu", "run_dace_gpu_cached", diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index e8d1914aa8..c26e8625ec 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -269,6 +269,7 @@ def connect( map_exit: Optional[dace_nodes.MapExit], dest: dace_nodes.AccessNode, dest_subset: dace_subsets.Range, + allow_removal_of_last_node: bool, ) -> bool: """Create a connection to the `dest` node, writing the given `dest_subset`. 
@@ -280,24 +281,27 @@ def connect( dest_desc = self.result.dc_node.desc(self.state) write_edge = self.state.in_edges(self.result.dc_node)[0] - # Check the kind of node which writes the result - if isinstance(write_edge.src, dace_nodes.Tasklet): - # The temporary data written by a tasklet can be safely deleted. - assert map_exit is not None - remove_last_node = True - elif isinstance(write_edge.src, dace_nodes.NestedSDFG): - if isinstance(dest_desc, dace.data.Scalar): - # We keep scalar temporary storage, as a general rule, since it - # does not affect performance of the generated code. This scalar - # is only required in some cases, e.g. for nested SDFGs implementing - # reduction, which use a WCR memlet for the reduction operation. + if allow_removal_of_last_node: + if self.state.out_degree(self.result.dc_node) != 0: remove_last_node = False - else: - # We remove the transient array on the output connection of a nested - # SDFG and write directly to the destination node. - # The caller is responsible to propagate the strides of the destination - # array to the array inside the nested SDFG. + elif isinstance(write_edge.src, dace_nodes.Tasklet): + # The temporary data written by a tasklet can be safely deleted. remove_last_node = True + elif isinstance(write_edge.src, dace_nodes.NestedSDFG): + if isinstance(dest_desc, dace.data.Scalar): + # We keep scalar temporary storage, as a general rule, since it + # does not affect performance of the generated code. This scalar + # is only required in some cases, e.g. for nested SDFGs implementing + # reduction, which use a WCR memlet for the reduction operation. + remove_last_node = False + else: + # We remove the transient array on the output connection of a nested + # SDFG and write directly to the destination node. + # The caller is responsible to propagate the strides of the destination + # array to the array inside the nested SDFG. 
+ remove_last_node = True + else: + remove_last_node = False else: remove_last_node = False @@ -552,6 +556,51 @@ def _construct_tasklet_result( ), ) + def _visit_can_deref(self, node: gtir.FunCall) -> DataExpr: + assert isinstance(node.type, ts.ScalarType) and node.type.kind == ts.ScalarKind.BOOL + assert len(node.args) == 1 + if not cpm.is_applied_shift(node.args[0]): + raise NotImplementedError( + f"Only `can_deref` of unstructured `shift` expressions is supported, got {node.args[0]}." + ) + it = self._visit_shift(node.args[0]) + index_values = {k: v for k, v in it.indices.items() if not isinstance(v, SymbolExpr)} + if len(index_values) == 0: + raise ValueError(f"Unexpected `can_deref` argument: {it}.") + can_deref_node, connector_mapping = self._add_tasklet( + name="can_deref", + inputs={f"index_{dim.value}" for dim in index_values}, + outputs={"valid"}, + code="valid = " + + " and ".join( + f"index_{dim.value} != {gtx_common._DEFAULT_SKIP_VALUE}" + for dim in index_values.keys() + ), + ) + for dim, index_expr in index_values.items(): + index_connector = f"index_{dim.value}" + if isinstance(index_expr, MemletExpr): + self._add_input_data_edge( + index_expr.dc_node, + index_expr.subset, + can_deref_node, + connector_mapping[index_connector], + ) + + else: + self._add_edge( + index_expr.dc_node, + None, + can_deref_node, + connector_mapping[index_connector], + dace.Memlet(data=index_expr.dc_node.data, subset="0"), + ) + return self._construct_tasklet_result( + dc_dtype=dace.bool_, + src_node=can_deref_node, + src_connector=connector_mapping["valid"], + ) + def _visit_deref(self, node: gtir.FunCall) -> DataExpr: """ Visit a `deref` node, which represents dereferencing of an iterator. 
@@ -709,12 +758,7 @@ def _visit_if_branch_arg( inner_desc = arg_desc.clone() inner_desc.transient = False elif isinstance(arg.gt_dtype, ts.ScalarType): - if isinstance(arg, MemletExpr) and len(arg.gt_field.dims) == 1: - # TODO(edopao): we cannot use a scalar because of an issue in gpu codegen, - # which leads to compilation error: cannot convert 'const double' to 'const double*' - inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,)) - else: - inner_desc = dace.data.Scalar(arg_desc.dtype) + inner_desc = dace.data.Scalar(arg_desc.dtype) else: # for list of values, we retrieve the local size from the corresponding offset local_dim = arg.gt_dtype.offset_type @@ -837,16 +881,9 @@ def _visit_if_branch_result( # If the result is currently written to a transient node, inside the nested SDFG, # we need to allocate a non-transient data node. result_desc = edge.result.dc_node.desc(sdfg) - if isinstance(sym.type, ts.ScalarType) and isinstance(result_desc, dace.data.Array): - # TODO(edopao): a scalar should not be represented as an array, but - # currently this can happen because of an issue workaround, see the - # todo comment above in `_visit_if_branch_arg()`. 
- assert len(result_desc.shape) == 1 and result_desc.shape[0] == 1 - _, output_desc = sdfg.add_scalar(output_data, result_desc.dtype) - else: - output_desc = result_desc.clone() - output_desc.transient = False - output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) + output_desc = result_desc.clone() + output_desc.transient = False + output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) output_node = state.add_access(output_data) state.add_nedge( edge.result.dc_node, @@ -1181,7 +1218,7 @@ def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr: assert index_arg.dc_dtype in dace.dtypes.INTEGER_TYPES src_subset = ( dace_subsets.Range(src_subset[:local_dim_index]) - + dace_subsets.Range.from_string(index_arg.value) + + dace_subsets.Range.from_indices([index_arg.value]) + dace_subsets.Range(src_subset[local_dim_index + 1 :]) ) if isinstance(src_arg, MemletExpr): @@ -1586,7 +1623,6 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert any(dim == offset_dim for dim, _ in it.field_domain) new_index: SymbolExpr | ValueExpr index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): @@ -1658,9 +1694,9 @@ def _make_cartesian_shift( def _make_dynamic_neighbor_offset( self, - offset_expr: MemletExpr | ValueExpr, + offset_expr: MemletExpr | ValueExpr | SymbolExpr, offset_table_node: dace_nodes.AccessNode, - origin_index: SymbolExpr, + origin_index: MemletExpr | ValueExpr | SymbolExpr, ) -> ValueExpr: """ Implements access to neighbor connectivity table by means of a tasklet node. @@ -1669,33 +1705,56 @@ def _make_dynamic_neighbor_offset( or computed by another tasklet (`DataExpr`). 
""" new_index_connector = "neighbor_index" - tasklet_node, connector_mapping = self._add_tasklet( - "dynamic_neighbor_offset", - {"table", "offset"}, - {new_index_connector}, - f"{new_index_connector} = table[{origin_index.value}, offset]", - ) + if isinstance(offset_expr, SymbolExpr) and isinstance(origin_index, SymbolExpr): + tasklet_node, connector_mapping = self._add_tasklet( + "dynamic_neighbor_offset", + {"table"}, + {new_index_connector}, + f"{new_index_connector} = table[{origin_index.value}, {offset_expr.value}]", + ) + elif isinstance(origin_index, SymbolExpr): + tasklet_node, connector_mapping = self._add_tasklet( + "dynamic_neighbor_offset", + {"table", "offset"}, + {new_index_connector}, + f"{new_index_connector} = table[{origin_index.value}, offset]", + ) + elif isinstance(offset_expr, SymbolExpr): + tasklet_node, connector_mapping = self._add_tasklet( + "dynamic_neighbor_offset", + {"table", "origin"}, + {new_index_connector}, + f"{new_index_connector} = table[origin, {offset_expr.value}]", + ) + else: + tasklet_node, connector_mapping = self._add_tasklet( + "dynamic_neighbor_offset", + {"table", "offset", "origin"}, + {new_index_connector}, + f"{new_index_connector} = table[origin, offset]", + ) self._add_input_data_edge( offset_table_node, dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, connector_mapping["table"], ) - if isinstance(offset_expr, MemletExpr): - self._add_input_data_edge( - offset_expr.dc_node, - offset_expr.subset, - tasklet_node, - connector_mapping["offset"], - ) - else: - self._add_edge( - offset_expr.dc_node, - None, - tasklet_node, - connector_mapping["offset"], - dace.Memlet(data=offset_expr.dc_node.data, subset="0"), - ) + for conn, input_expr in [("offset", offset_expr), ("origin", origin_index)]: + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge( + input_expr.dc_node, + input_expr.subset, + tasklet_node, + connector_mapping[conn], + ) + elif isinstance(input_expr, 
ValueExpr): + self._add_edge( + input_expr.dc_node, + None, + tasklet_node, + connector_mapping[conn], + dace.Memlet(data=input_expr.dc_node.data, subset="0"), + ) dc_dtype = offset_table_node.desc(self.sdfg).dtype return self._construct_tasklet_result( @@ -1710,17 +1769,14 @@ def _make_unstructured_shift( offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - # make sure that the field can be dereferenced with the given connectivity type - assert any(dim == conn_type.codomain for dim, _ in it.field_domain) # make sure that the iterator can access the connectivity table assert conn_type.source_dim in it.indices conn_source_index = it.indices[conn_type.source_dim] - assert isinstance(conn_source_index, SymbolExpr) shifted_indices = { dim: idx for dim, idx in it.indices.items() if dim != conn_type.source_dim } - if isinstance(offset_expr, SymbolExpr): + if isinstance(offset_expr, SymbolExpr) and isinstance(conn_source_index, SymbolExpr): # use memlet to retrieve the neighbor index shifted_indices[conn_type.codomain] = MemletExpr( dc_node=conn_node, @@ -1869,7 +1925,10 @@ def _visit_tuple_get( return tuple_fields[index] def visit_FunCall(self, node: gtir.FunCall) -> MaybeNestedInTuple[IteratorExpr | DataExpr]: - if cpm.is_call_to(node, "deref"): + if cpm.is_call_to(node, "can_deref"): + return self._visit_can_deref(node) + + elif cpm.is_call_to(node, "deref"): return self._visit_deref(node) elif cpm.is_call_to(node, "if_"): @@ -2003,6 +2062,7 @@ def translate_lambda_to_dataflow( flat_arg_nodes = ( x.field if isinstance(x, IteratorExpr) else x.dc_node # type: ignore[attr-defined] for x in gtx_utils.flatten_nested_tuple(tuple(args)) + if not isinstance(x, SymbolExpr) ) state.remove_nodes_from([node for node in flat_arg_nodes if state.degree(node) == 0]) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py 
b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 007202a87d..e926cc25a5 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -9,11 +9,13 @@ from __future__ import annotations import abc +from collections import Counter from typing import TYPE_CHECKING, Iterable, Optional, Protocol import dace from dace import nodes as dace_nodes, subsets as dace_subsets +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ( @@ -73,16 +75,18 @@ def _parse_fieldop_arg( ctx: gtir_to_sdfg.SubgraphContext, sdfg_builder: gtir_to_sdfg.SDFGBuilder, domain: gtir_domain.FieldopDomain, -) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: +) -> MaybeNestedInTuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr]: """ Helper method to visit an expression passed as argument to a field operator and create the local view for the field argument. 
""" arg = sdfg_builder.visit(node, ctx=ctx) - if not isinstance(arg, gtir_to_sdfg_types.FieldopData): - raise ValueError("Expected a field, found a tuple of fields.") - return arg.get_local_view(domain, ctx.sdfg) + if isinstance(arg, gtir_to_sdfg_types.FieldopData): + return arg.get_local_view(domain, ctx.sdfg) + else: + # handle tuples of fields + return gtx_utils.tree_map(lambda targ: targ.get_local_view(domain, ctx.sdfg))(arg) def _create_field_operator_impl( @@ -92,6 +96,7 @@ def _create_field_operator_impl( output_edge: gtir_dataflow.DataflowOutputEdge, output_type: ts.FieldType, map_exit: dace_nodes.MapExit, + output_consumer_count: dict[dace_nodes.AccessNode, int], ) -> gtir_to_sdfg_types.FieldopData: """ Helper method to allocate a temporary array that stores one field computed @@ -163,8 +168,11 @@ def _create_field_operator_impl( ) field_node = ctx.state.add_access(field_name) - # and here the edge writing the dataflow result data through the map exit node - output_edge.connect(map_exit, field_node, field_subset) + # and here the edge writing the dataflow result data through the map exit node. + # Note that we cannot remove the output data access node only if this is used + # for mutiple fields in a return tuple. 
+ allow_removal_of_last_node = output_consumer_count[output_edge.result.dc_node] == 1 + output_edge.connect(map_exit, field_node, field_subset, allow_removal_of_last_node) return gtir_to_sdfg_types.FieldopData( field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin) @@ -174,10 +182,10 @@ def _create_field_operator( ctx: gtir_to_sdfg.SubgraphContext, domain: gtir_domain.FieldopDomain, - node_type: ts.FieldType, + node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_to_sdfg.SDFGBuilder, input_edges: Iterable[gtir_dataflow.DataflowInputEdge], - output_edge: gtir_dataflow.DataflowOutputEdge, + output_tree: MaybeNestedInTuple[gtir_dataflow.DataflowOutputEdge], ) -> gtir_to_sdfg_types.FieldopResult: """ Helper method to build the output of a field operator. @@ -188,11 +196,11 @@ def _create_field_operator( node_type: The GT4Py type of the IR node that produces this field. sdfg_builder: The object used to build the map scope in the provided SDFG. input_edges: List of edges to pass input data into the dataflow. - output_edge: Edge corresponding to the dataflow output. + output_tree: A tree representation of the dataflow output data. Returns: - The descriptor of the field operator result, which is a single field defined - on the domain of the field operator. + The descriptor of the field operator result, which can be either a single + field or a tuple of fields. """ if len(domain) == 0: @@ -214,7 +222,26 @@ def _create_field_operator( for edge in input_edges: edge.connect(map_entry) - return _create_field_operator_impl(ctx, sdfg_builder, domain, output_edge, node_type, map_exit) + # The same output node could be used for multiple fields in case of tuple return. + # In this case, the output access node cannot be removed. 
+ consumer_count = Counter( + oedge.result.dc_node + for oedge in gtx_utils.flatten_nested_tuple((output_tree,)) + if oedge is not None + ) + if isinstance(node_type, ts.FieldType): + assert isinstance(output_tree, gtir_dataflow.DataflowOutputEdge) + return _create_field_operator_impl( + ctx, sdfg_builder, domain, output_tree, node_type, map_exit, consumer_count + ) + else: + # handle tuples of fields + output_symbol_tree = gtir_to_sdfg_utils.make_symbol_tree("x", node_type) + return gtx_utils.tree_map( + lambda output_edge, output_sym: _create_field_operator_impl( + ctx, sdfg_builder, domain, output_edge, output_sym.type, map_exit, consumer_count + ) + )(output_tree, output_symbol_tree) def translate_as_fieldop( @@ -246,9 +273,6 @@ def translate_as_fieldop( if cpm.is_call_to(fieldop_expr, "scan"): return translate_scan(node, ctx, sdfg_builder) - if not isinstance(node.type, ts.FieldType): - raise NotImplementedError("Unexpected 'as_fieldop' with tuple output in SDFG lowering.") - # Parse the domain of the field operator. assert isinstance(fieldop_domain_expr.type, ts.DomainType) field_domain = gtir_domain.get_field_domain( @@ -258,11 +282,13 @@ def translate_as_fieldop( if cpm.is_ref_to(fieldop_expr, "deref"): arg_type = node.args[0].type assert isinstance(arg_type, (ts.FieldType, ts.ScalarType)) - if isinstance(arg_type, ts.ScalarType) or arg_type.dims != node.type.dims: + if ( + isinstance(arg_type, ts.ScalarType) or arg_type.dims != node.type.dims # type: ignore[union-attr] + ): # Special usage of 'deref' as argument to fieldop expression, to broadcast # the input value (a scalar or a field slice) on the output domain. stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype + stencil_expr.expr.type = node.type.dtype # type: ignore[union-attr] else: # Special usage of 'deref' with field argument, to access the field # on the given domain. It copies a subset of the source field. 
@@ -282,13 +308,12 @@ def translate_as_fieldop( fieldop_args = [_parse_fieldop_arg(arg, ctx, sdfg_builder, field_domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - input_edges, output_edge = gtir_dataflow.translate_lambda_to_dataflow( + input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow( ctx.sdfg, ctx.state, sdfg_builder, stencil_expr, fieldop_args ) - assert isinstance(output_edge, gtir_dataflow.DataflowOutputEdge) return _create_field_operator( - ctx, field_domain, node.type, sdfg_builder, input_edges, output_edge + ctx, field_domain, node.type, sdfg_builder, input_edges, output_edges ) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index a569c06fbf..5098b70ac9 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -22,6 +22,7 @@ from __future__ import annotations +from collections import Counter from typing import Iterable, Sequence import dace @@ -47,43 +48,23 @@ from gt4py.next.type_system import type_info as ti, type_specifications as ts -def _parse_scan_fieldop_arg( +def _parse_fieldop_arg( node: gtir.Expr, ctx: gtir_to_sdfg.SubgraphContext, sdfg_builder: gtir_to_sdfg.SDFGBuilder, - field_domain: gtir_domain.FieldopDomain, -) -> MaybeNestedInTuple[gtir_dataflow.MemletExpr]: - """Helper method to visit an expression passed as argument to a scan field operator. - - On the innermost level, a scan operator is lowered to a loop region which computes - column elements in the vertical dimension. - - It differs from the helper method `gtir_to_sdfg_primitives` in that field arguments - are passed in full shape along the vertical dimension, rather than as iterator. 
+ domain: gtir_domain.FieldopDomain, +) -> MaybeNestedInTuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr]: + """ + Helper method to visit an expression passed as argument to a field operator + and create the local view for the field argument. """ - - def _parse_fieldop_arg_impl( - arg: gtir_to_sdfg_types.FieldopData, - ) -> gtir_dataflow.MemletExpr: - arg_expr = arg.get_local_view(field_domain, ctx.sdfg) - if isinstance(arg_expr, gtir_dataflow.MemletExpr): - return arg_expr - # In scan field operator, the arguments to the vertical stencil are passed by value. - # Therefore, the full field shape is passed as `MemletExpr` rather than `IteratorExpr`. - field_type = ts.FieldType( - dims=[dim for dim, _ in arg_expr.field_domain], dtype=arg_expr.gt_dtype - ) - return gtir_dataflow.MemletExpr( - arg_expr.field, field_type, arg_expr.get_memlet_subset(ctx.sdfg) - ) - arg = sdfg_builder.visit(node, ctx=ctx) if isinstance(arg, gtir_to_sdfg_types.FieldopData): - return _parse_fieldop_arg_impl(arg) + return arg.get_local_view(domain, ctx.sdfg) else: # handle tuples of fields - return gtx_utils.tree_map(_parse_fieldop_arg_impl)(arg) + return gtx_utils.tree_map(lambda targ: targ.get_local_view(domain, ctx.sdfg))(arg) def _create_scan_field_operator_impl( @@ -93,6 +74,7 @@ def _create_scan_field_operator_impl( output_domain: infer_domain.NonTupleDomainAccess, output_type: ts.FieldType, map_exit: dace_nodes.MapExit | None, + output_consumer_count: dict[dace_nodes.AccessNode, int], ) -> gtir_to_sdfg_types.FieldopData | None: """ Helper method to allocate a temporary array that stores one field computed @@ -166,7 +148,12 @@ def _create_scan_field_operator_impl( # Up to now the nested SDFG is writing into a transient data container that # has the size to hold one column. The function below, that does the connection, # will remove that transient and write directly to the result field. 
- inner_map_output_temporary_removed = output_edge.connect(map_exit, field_node, field_subset) + # Note that we cannot remove the output data access node if it is used + # for multiple fields in a return tuple. + allow_removal_of_last_node = output_consumer_count[output_edge.result.dc_node] == 1 + inner_map_output_temporary_removed = output_edge.connect( + map_exit, field_node, field_subset, allow_removal_of_last_node + ) if not inner_map_output_temporary_removed: raise ValueError("The scan nested SDFG is expected to write directly to the result field.") @@ -264,6 +251,14 @@ def _create_scan_field_operator( else im.sym("__gtir_unused_dummy_var", node_type) ) + # The same output node could be used for multiple fields in case of tuple return. + # In this case, the output access node cannot be removed. + consumer_count = Counter( + oedge.result.dc_node + for oedge in gtx_utils.flatten_nested_tuple((output,)) + if oedge is not None + ) + return gtx_utils.tree_map( lambda edge, domain, sym: _create_scan_field_operator_impl( ctx, @@ -272,6 +267,7 @@ def _create_scan_field_operator( domain, sym.type, map_exit, + consumer_count, ) )(output, output_domain, dummy_output_symbol) @@ -427,7 +423,7 @@ def get_scan_output_shape( # inside the 'compute' state, visit the list of arguments to be passed to the stencil stencil_args = [ - _parse_scan_fieldop_arg(im.ref(p.id), compute_ctx, sdfg_builder, field_domain) + _parse_fieldop_arg(im.ref(p.id), compute_ctx, sdfg_builder, field_domain) for p in lambda_node.params ] # stil inside the 'compute' state, generate the dataflow representing the stencil diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 799e8ad228..3f803ffcf7 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ 
-751,8 +751,8 @@ def _gt_auto_process_dataflow_inside_maps( # NestedSDFGs inside the ConditionalBlocks it fuses. sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), - validate=True, - validate_all=True, + validate=False, + validate_all=validate_all, ) # Move dataflow into the branches of the `if` such that they are only evaluated diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py index 0840f77755..168f847ee8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import warnings from typing import Optional, TypeAlias import dace @@ -527,7 +528,11 @@ def _gt_map_strides_into_nested_sdfg( raise NotImplementedError("NestedSDFGs can not be used to increase the rank.") if len(new_strides) != len(inner_shape): - raise ValueError("Failed to compute the inner strides.") + # It could still be possible to access an array at index 0. Consider a memlet + # which only writes index 0 to the inner shape (dim_oinflow == 1), although + # the inner shape is larger than 1, but we only read index 0 inside the SDFG. 
+ warnings.warn("Failed to compute the inner strides.", stacklevel=2) + return # For the strides of the arrays inside the nested SDFG we will create a new unique # symbol which is initialized, through the symbol mapping, to the value of this diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py index de6778a750..44be154f3c 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -69,6 +69,7 @@ class Params: def make_dace_backend( gpu: bool, cached: bool = True, + apply_common_transform: bool = False, auto_optimize: bool = True, async_sdfg_call: bool = True, optimization_args: dict[str, Any] | None = None, @@ -82,6 +83,8 @@ def make_dace_backend( Args: gpu: Enable GPU transformations and code generation. cached: Cache the lowered SDFG as a JSON file and the compiled programs. + apply_common_transform: Whether to apply the GTIR common transform before + lowering to SDFG. auto_optimize: Enable the SDFG auto-optimize pipeline. async_sdfg_call: Make an asynchronous SDFG call on GPU to allow overlapping of GPU kernel execution with the Python driver code. 
@@ -128,6 +131,7 @@ def make_dace_backend( cached=cached, auto_optimize=auto_optimize, otf_workflow__cached_translation=cached, + otf_workflow__bare_translation__apply_common_transform=apply_common_transform, otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), otf_workflow__bare_translation__auto_optimize_args=optimization_args, otf_workflow__bare_translation__unstructured_horizontal_has_unit_stride=unstructured_horizontal_has_unit_stride, @@ -174,3 +178,18 @@ def make_dace_backend( auto_optimize=True, async_sdfg_call=True, ) + +run_dace_cpu_gt = make_dace_backend( + gpu=False, + cached=False, + apply_common_transform=True, + auto_optimize=True, + async_sdfg_call=False, +) +run_dace_cpu_gt_noopt = make_dace_backend( + gpu=False, + cached=False, + apply_common_transform=True, + auto_optimize=False, + async_sdfg_call=False, +) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index ad8e8ea04b..7bf2482a0a 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import functools from typing import Any, Optional import dace @@ -17,7 +18,8 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config from gt4py.next.instrumentation import metrics -from gt4py.next.iterator import ir as itir, transforms as itir_transforms +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import pass_manager from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.binding import interface from gt4py.next.program_processors.runners.dace import ( @@ -357,6 +359,7 @@ class DaCeTranslator( definitions.TranslationStep[code_specs.SDFGCodeSpec], ): device_type: core_defs.DeviceType + 
apply_common_transform: bool auto_optimize: bool auto_optimize_args: dict[str, Any] | None async_sdfg_call: bool @@ -375,6 +378,31 @@ def generate_sdfg( with gtx_wfdcommon.dace_context(device_type=self.device_type): return self._generate_sdfg_without_configuring_dace(*args, **kwargs) + def _preprocess_program( + self, + program: itir.Program, + offset_provider: common.OffsetProvider | common.OffsetProviderType, + ) -> itir.Program: + apply_common_transforms = functools.partial( + pass_manager.apply_common_transforms, + offset_provider=offset_provider, + force_inline_lambda_args=True, + transform_concat_where_to_as_fieldop=False, + use_max_domain_range_on_unstructured_shift=self.use_max_domain_range_on_unstructured_shift, + ) + + new_program = apply_common_transforms(program, unroll_reduce=False) + + if any( + node.id == "neighbors" + for node in new_program.pre_walk_values().if_isinstance(itir.SymRef) + ): + # if we don't unroll, there may be lifts left in the itir which can't + # be lowered to SDFG. In this case, just retry with unrolled reductions. 
+ new_program = apply_common_transforms(program, unroll_reduce=True) + + return new_program + def _generate_sdfg_without_configuring_dace( self, ir: itir.Program, @@ -382,11 +410,14 @@ def _generate_sdfg_without_configuring_dace( column_axis: Optional[common.Dimension], ) -> dace.SDFG: if not self.disable_itir_transforms: - ir = itir_transforms.apply_fieldview_transforms( - ir, - use_max_domain_range_on_unstructured_shift=self.use_max_domain_range_on_unstructured_shift, - offset_provider=offset_provider, - ) + if self.apply_common_transform: + ir = self._preprocess_program(ir, offset_provider) + else: + ir = pass_manager.apply_fieldview_transforms( + ir, + use_max_domain_range_on_unstructured_shift=self.use_max_domain_range_on_unstructured_shift, + offset_provider=offset_provider, + ) offset_provider_type = common.offset_provider_to_type(offset_provider) on_gpu = self.device_type != core_defs.DeviceType.CPU diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 3396d93d3c..b440bae543 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -73,6 +73,8 @@ class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu" DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu" DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt" + DACE_CPU_GT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_gt" + DACE_CPU_GT_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_gt_noopt" class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): @@ -203,6 +205,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_CPU_GT: DACE_SKIP_TEST_LIST, + 
OptionalProgramBackendId.DACE_CPU_GT_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index ab880868ce..ff129f1298 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -103,6 +103,14 @@ def __gt_allocator__( next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, marks=pytest.mark.requires_dace, ), + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_CPU_GT, + marks=pytest.mark.requires_dace, + ), + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_CPU_GT_NO_OPT, + marks=pytest.mark.requires_dace, + ), ], ids=lambda p: p.short_id(), ) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 580df0ee49..98fdb16ff8 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -65,6 +65,10 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, True), marks=pytest.mark.requires_dace, ), + pytest.param( + (next_tests.definitions.OptionalProgramBackendId.DACE_CPU_GT_NO_OPT, True), + marks=pytest.mark.requires_dace, + ), ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py index 117aa1a92b..c3785c371e 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py @@ -65,6 +65,7 @@ def _translate_gtir_to_sdfg( # we use the SDFG hash in build cache to avoid clashes between CPU and GPU SDFGs return dace_wf_translation.DaCeTranslator( device_type=device_type, + apply_common_transform=False, auto_optimize=auto_optimize, auto_optimize_args=None, async_sdfg_call=async_sdfg_call, diff --git a/uv.lock b/uv.lock index 59e8b3010b..e7530e7e65 100644 --- a/uv.lock +++ b/uv.lock @@ -1244,8 +1244,8 @@ dependencies = [ [[package]] name = "dace" -version = "43!2026.4.27" -source = { registry = "https://gridtools.github.io/pypi/" } +version = "2.3.5" +source = { git = "https://github.com/philip-paul-mueller/dace?branch=phimuell__new-gpu-codegen-dev#a62787d92d4ffe3f4586e8be1fdfc4169e791c17" } resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", "python_full_version >= '3.14' and sys_platform == 'emscripten'", @@ -1276,9 +1276,6 @@ dependencies = [ { name = "sympy" }, { name = "typing-extensions" }, ] -wheels = [ - { url = "https://gridtools.github.io/pypi/dace/dace-43!2026.4.27-py3-none-any.whl", hash = "sha256:9098ceed412d287d575b2ed30cc90b754b966eda893698c55129a7bc5bd37d37" }, -] [[package]] name = "debugpy" @@ -1792,7 +1789,7 @@ dace-cartesian = [ { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2#d5fbadb626389e425fac5ed93d2a880811eca41f" } }, ] dace-next = [ - { name = "dace", version = "43!2026.4.27", source = { registry = "https://gridtools.github.io/pypi/" } }, + { name = "dace", version = "2.3.5", source = { git = "https://github.com/philip-paul-mueller/dace?branch=phimuell__new-gpu-codegen-dev#a62787d92d4ffe3f4586e8be1fdfc4169e791c17" } }, ] dev = [ { name = "atlas4py" }, @@ -1962,7 +1959,7 @@ build = [ { name = "wheel", specifier = 
">=0.33.6" }, ] dace-cartesian = [{ name = "dace", git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-v2" }] -dace-next = [{ name = "dace", specifier = "==43!2026.4.27", index = "https://gridtools.github.io/pypi/", conflict = { package = "gt4py", group = "dace-next" } }] +dace-next = [{ name = "dace", git = "https://github.com/philip-paul-mueller/dace?branch=phimuell__new-gpu-codegen-dev" }] dev = [ { name = "atlas4py", specifier = ">=0.41", index = "https://test.pypi.org/simple" }, { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" },