diff --git a/src/gt4py/next/program_processors/runners/dace/library_nodes/__init__.py b/src/gt4py/next/program_processors/runners/dace/library_nodes/__init__.py
new file mode 100644
index 0000000000..9239878658
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace/library_nodes/__init__.py
@@ -0,0 +1,16 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from gt4py.next.program_processors.runners.dace.library_nodes.reduce_with_skip_values import (
+    ReduceWithSkipValues,
+)
+
+
+__all__ = [
+    "ReduceWithSkipValues",
+]
diff --git a/src/gt4py/next/program_processors/runners/dace/library_nodes/reduce_with_skip_values.py b/src/gt4py/next/program_processors/runners/dace/library_nodes/reduce_with_skip_values.py
new file mode 100644
index 0000000000..78ab76f848
--- /dev/null
+++ b/src/gt4py/next/program_processors/runners/dace/library_nodes/reduce_with_skip_values.py
@@ -0,0 +1,246 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any, Final
+
+import dace
+from dace import library as dace_library, properties as dace_properties
+from dace.sdfg import graph as dace_graph
+from dace.transformation import transformation as dace_transform
+
+from gt4py.next import common as gtx_common
+
+
+def _get_unique_edge(
+    edges: Iterable[dace_graph.MultiConnectorEdge], connector: str
+) -> dace_graph.MultiConnectorEdge:
+    """
+    Return the only edge attached to the given connector.
+
+    Args:
+        edges: The edges attached to the connector.
+        connector: Name of the connector, only used in the error message.
+
+    Raises:
+        ValueError: If the connector is not attached to exactly one edge.
+    """
+    found = list(edges)
+    if len(found) != 1:
+        raise ValueError(
+            f"Expected exactly one edge on connector '{connector}', found {len(found)}."
+        )
+    return found[0]
+
+
+@dace_library.node
+class ReduceWithSkipValues(dace.sdfg.nodes.LibraryNode):
+    """
+    Implements reduction with skip values.
+
+    The values received on the input connector are combined by means of the `wcr`
+    function and the result is written to the output connector. The data received
+    on the mask connector selects, element by element, which input values take
+    part in the reduction: where the mask is equal to the GT4Py skip value, the
+    `identity` value is used in place of the input value.
+    """
+
+    implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {}
+    default_implementation: Final[str | None] = "pure"
+
+    # Properties
+    # Reduction function, given as lambda expression of two arguments.
+    wcr = dace_properties.LambdaProperty(default="lambda a, b: a")
+    # Value used in place of the input elements marked as skip values.
+    identity = dace_properties.Property(default=0)
+    # Value written to the result before the reduction starts.
+    init = dace_properties.Property(default=0)
+    # Names of the connectors for input values, result and mask, respectively.
+    input_conn = dace_properties.Property(default="_in")
+    output_conn = dace_properties.Property(default="_out")
+    mask_conn = dace_properties.Property(default="_mask")
+
+    def __init__(
+        self,
+        name: str,
+        wcr: str,
+        identity: dace.symbolic.SymbolicType,
+        init: dace.symbolic.SymbolicType,
+        debuginfo: dace.dtypes.DebugInfo | None = None,
+        input_conn: str | None = None,
+        output_conn: str | None = None,
+        mask_conn: str | None = None,
+    ) -> None:
+        """
+        Construct the library node.
+
+        Args:
+            name: Label of the library node.
+            wcr: Reduction function, given as lambda expression.
+            identity: Value used in place of the input elements marked as skip values.
+            init: Value written to the result before the reduction starts.
+            debuginfo: Debug information to attach to the node.
+            input_conn: Name of the input connector, `None` keeps the current value
+                of the corresponding property.
+            output_conn: Same as `input_conn`, for the output connector.
+            mask_conn: Same as `input_conn`, for the mask connector.
+        """
+        if input_conn is None:
+            input_conn = self.input_conn
+        else:
+            self.input_conn = input_conn
+
+        if output_conn is None:
+            output_conn = self.output_conn
+        else:
+            self.output_conn = output_conn
+
+        if mask_conn is None:
+            mask_conn = self.mask_conn
+        else:
+            self.mask_conn = mask_conn
+
+        super().__init__(name, inputs={input_conn, mask_conn}, outputs={output_conn})
+        self.wcr = wcr
+        self.identity = identity
+        self.init = init
+        self.debuginfo = debuginfo
+
+    def validate(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None:
+        """
+        Check how the node is connected in the given state.
+
+        Raises:
+            ValueError: If a connector is not attached to exactly one edge, if the
+                mask array is not a 2d array with constant size in the second
+                dimension, or if a memlet has an unexpected number of elements.
+        """
+        inedge = _get_unique_edge(
+            state.in_edges_by_connector(self, self.input_conn), self.input_conn
+        )
+        outedge = _get_unique_edge(
+            state.out_edges_by_connector(self, self.output_conn), self.output_conn
+        )
+        maskedge = _get_unique_edge(
+            state.in_edges_by_connector(self, self.mask_conn), self.mask_conn
+        )
+
+        mask_desc = sdfg.arrays[maskedge.data.data]
+        if len(mask_desc.shape) != 2:
+            raise ValueError(f"Invalid shape {mask_desc.shape} of mask array, expected 2d array.")
+        max_neighbors = mask_desc.shape[1]
+        if not (isinstance(max_neighbors, int) or str(max_neighbors).isdigit()):
+            raise ValueError(
+                f"Invalid shape {mask_desc.shape} of mask array, expected constant neighbors size."
+            )
+        if inedge.data.num_elements() != max_neighbors:
+            raise ValueError(f"Invalid memlet on input connector {self.input_conn}.")
+        if maskedge.data.num_elements() != max_neighbors:
+            raise ValueError(f"Invalid memlet on input connector {self.mask_conn}.")
+        if outedge.data.num_elements() != 1:
+            raise ValueError(f"Invalid memlet on output connector {self.output_conn}.")
+
+
+@dace_library.register_expansion(ReduceWithSkipValues, "pure")
+class ReduceWithSkipValuesExpandInlined(dace_transform.ExpandTransformation):
+    """
+    Implements pure expansion of the ReduceWithSkipValues library node.
+
+    The library node is expanded to a nested SDFG made of two states: the first
+    state writes the `init` value to the result, the second state contains a
+    sequential map over the neighbors, which updates the result by means of a
+    write-conflict resolution (wcr) memlet.
+    """
+
+    environments: Final[list[Any]] = []
+
+    @staticmethod
+    def expansion(node: ReduceWithSkipValues, state: dace.SDFGState, sdfg: dace.SDFG) -> dace.SDFG:
+        """
+        Build the nested SDFG implementing the reduction.
+
+        Args:
+            node: The library node to expand.
+            state: The state containing the library node.
+            sdfg: The SDFG holding the data descriptors referenced by the memlets.
+
+        Returns:
+            The nested SDFG that implements the library node.
+        """
+        inedge = _get_unique_edge(
+            state.in_edges_by_connector(node, node.input_conn), node.input_conn
+        )
+        outedge = _get_unique_edge(
+            state.out_edges_by_connector(node, node.output_conn), node.output_conn
+        )
+        maskedge = _get_unique_edge(
+            state.in_edges_by_connector(node, node.mask_conn), node.mask_conn
+        )
+        input_desc = sdfg.arrays[inedge.data.data]
+        output_desc = sdfg.arrays[outedge.data.data]
+        mask_desc = sdfg.arrays[maskedge.data.data]
+        assert len(mask_desc.shape) == 2
+        max_neighbors = mask_desc.shape[1]
+        assert isinstance(max_neighbors, int) or str(max_neighbors).isdigit()
+
+        # The local dimension is the first one with size `max_neighbors` in the input subset.
+        local_dim_index = inedge.data.src_subset.size().index(max_neighbors)
+
+        nsdfg = dace.SDFG(node.label)
+        inp, _ = nsdfg.add_array(
+            node.input_conn,
+            (max_neighbors,),
+            input_desc.dtype,
+            strides=(input_desc.strides[local_dim_index],),
+        )
+        mask, _ = nsdfg.add_array(
+            node.mask_conn,
+            (max_neighbors,),
+            mask_desc.dtype,
+            strides=(mask_desc.strides[1],),
+        )
+        outp, _ = nsdfg.add_scalar(node.output_conn, output_desc.dtype)
+        st_init = nsdfg.add_state("init")
+        # The initial value is written to the result, so it is cast to the output type.
+        init_tasklet = st_init.add_tasklet(
+            name="write",
+            inputs={},
+            outputs={"__tlet_out"},
+            code=f"__tlet_out = {output_desc.dtype}({node.init})",
+        )
+        st_init.add_edge(
+            init_tasklet,
+            "__tlet_out",
+            st_init.add_access(outp),
+            None,
+            dace.Memlet(data=outp, subset="0"),
+        )
+        st_reduce = nsdfg.add_state_after(st_init, "compute")
+        # Fill skip values in local dimension with the reduce identity value
+        skip_value = f"{input_desc.dtype}({node.identity})"
+        # Since this map operates on a pure local dimension, we explicitly set sequential
+        # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet.
+        # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does.
+        st_reduce.add_mapped_tasklet(
+            name="reduce_with_skip_values",
+            map_ranges={"i": f"0:{max_neighbors}"},
+            inputs={
+                "__tlet_inp": dace.Memlet(data=inp, subset="i"),
+                "__tlet_mask": dace.Memlet(data=mask, subset="i"),
+            },
+            code=f"__tlet_out = __tlet_inp if __tlet_mask != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}",
+            outputs={
+                "__tlet_out": dace.Memlet(data=outp, subset="0", wcr=node.wcr, wcr_nonatomic=True),
+            },
+            external_edges=True,
+            schedule=dace.dtypes.ScheduleType.Sequential,
+        )
+
+        return nsdfg
diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py
index e8d1914aa8..b824668234 100644
--- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py
+++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py
@@ -27,6 +27,7 @@
 
 import dace
 from dace import nodes as dace_nodes, subsets as dace_subsets
+from dace.libraries import standard as dace_stdlib
 
 from gt4py import eve
 from gt4py.eve.extended_typing import MaybeNestedInTuple, NestedTuple
@@ -34,7 +35,10 @@
 from gt4py.next.iterator import ir as gtir
 from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
 from gt4py.next.iterator.transforms import symbol_ref_utils
-from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args
+from gt4py.next.program_processors.runners.dace import (
+    library_nodes as gtir_library_nodes,
+    sdfg_args as gtx_dace_args,
+)
 from gt4py.next.program_processors.runners.dace.lowering import (
     gtir_python_codegen,
     gtir_to_sdfg,
@@ -277,7 +281,6 @@ def connect(
         outside data container is removed, the caller is responsible to propagate
         the strides of the destination array to the array inside the nested SDFG.
         """
-        dest_desc = self.result.dc_node.desc(self.state)
         write_edge = self.state.in_edges(self.result.dc_node)[0]
 
         # Check the kind of node which writes the result
@@ -286,18 +289,11 @@ def connect(
             assert map_exit is not None
             remove_last_node = True
         elif isinstance(write_edge.src, dace_nodes.NestedSDFG):
-            if isinstance(dest_desc, dace.data.Scalar):
-                # We keep scalar temporary storage, as a general rule, since it
-                # does not affect performance of the generated code. This scalar
-                # is only required in some cases, e.g. for nested SDFGs implementing
-                # reduction, which use a WCR memlet for the reduction operation.
-                remove_last_node = False
-            else:
-                # We remove the transient array on the output connection of a nested
-                # SDFG and write directly to the destination node.
-                # The caller is responsible to propagate the strides of the destination
-                # array to the array inside the nested SDFG.
-                remove_last_node = True
+            # We remove the transient array on the output connection of a nested
+            # SDFG and write directly to the destination node.
+            # The caller is responsible to propagate the strides of the destination
+            # array to the array inside the nested SDFG.
+            remove_last_node = True
         else:
             remove_last_node = False
 
@@ -1379,139 +1375,6 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr:
             gt_dtype=ts.ListType(node.type.element_type, offset_type),
         )
 
-    def _make_reduce_with_skip_values(
-        self,
-        input_expr: ValueExpr | MemletExpr,
-        offset_provider_type: gtx_common.NeighborConnectivityType,
-        reduce_init: SymbolExpr,
-        reduce_identity: SymbolExpr,
-        reduce_wcr: str,
-        result_node: dace_nodes.AccessNode,
-    ) -> None:
-        """
-        Helper method to lower reduction on a local field containing skip values.
-
-        The reduction is implemented as a nested SDFG containing 2 states. In first
-        state, the result (a scalar data node passed as argumet) is initialized.
-        In second state, a mapped tasklet uses a write-conflict resolution (wcr)
-        memlet to update the result.
-        We use the offset provider as a mask to identify skip values: the value
-        that is written to the result node is either the input value, when the
-        corresponding neighbor index in the connectivity table is valid, or the
-        identity value if the neighbor index is missing.
-        """
-        origin_map_index = gtir_to_sdfg_utils.get_map_variable(offset_provider_type.source_dim)
-
-        assert (
-            isinstance(input_expr.gt_dtype, ts.ListType)
-            and input_expr.gt_dtype.offset_type is not None
-        )
-        offset_type = input_expr.gt_dtype.offset_type
-        connectivity = gtx_dace_args.connectivity_identifier(offset_type.value)
-        connectivity_node = self.state.add_access(connectivity)
-        connectivity_desc = connectivity_node.desc(self.sdfg)
-        connectivity_desc.transient = False
-
-        desc = input_expr.dc_node.desc(self.sdfg)
-        if isinstance(input_expr, MemletExpr):
-            local_dim_indices = [i for i, size in enumerate(input_expr.subset.size()) if size != 1]
-        else:
-            local_dim_indices = list(range(len(desc.shape)))
-
-        if len(local_dim_indices) != 1:
-            raise NotImplementedError(
-                f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one."
-            )
-        local_dim_index = local_dim_indices[0]
-        assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors
-
-        # we lower the reduction map with WCR out memlet in a nested SDFG
-        nsdfg = dace.SDFG(self.subgraph_builder.unique_nsdfg_name("reduce_with_skip_values"))
-        nsdfg.add_array(
-            "values",
-            (desc.shape[local_dim_index],),
-            desc.dtype,
-            strides=(desc.strides[local_dim_index],),
-        )
-        nsdfg.add_array(
-            "neighbor_indices",
-            (connectivity_desc.shape[1],),
-            connectivity_desc.dtype,
-            strides=(connectivity_desc.strides[1],),
-        )
-        nsdfg.add_scalar("acc", desc.dtype)
-        st_init = nsdfg.add_state(f"{nsdfg.label}_init")
-        init_tasklet, connector_mapping = self.subgraph_builder.add_tasklet(
-            name="init_acc",
-            sdfg=self.sdfg,
-            state=st_init,
-            inputs={},
-            outputs={"val"},
-            code=f"val = {reduce_init.dc_dtype}({reduce_init.value})",
-        )
-        st_init.add_edge(
-            init_tasklet,
-            connector_mapping["val"],
-            st_init.add_access("acc"),
-            None,
-            dace.Memlet(data="acc", subset="0"),
-        )
-        st_reduce = nsdfg.add_state_after(st_init, f"{nsdfg.label}_reduce")
-        # Fill skip values in local dimension with the reduce identity value
-        skip_value = f"{reduce_identity.dc_dtype}({reduce_identity.value})"
-        # Since this map operates on a pure local dimension, we explicitly set sequential
-        # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet.
-        # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does.
-        self.subgraph_builder.add_mapped_tasklet(
-            name="reduce_with_skip_values",
-            sdfg=self.sdfg,
-            state=st_reduce,
-            map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"},
-            inputs={
-                "val": dace.Memlet(data="values", subset="i"),
-                "neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"),
-            },
-            code=f"out = val if neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}",
-            outputs={
-                "out": dace.Memlet(data="acc", subset="0", wcr=reduce_wcr, wcr_nonatomic=True),
-            },
-            external_edges=True,
-            schedule=dace.dtypes.ScheduleType.Sequential,
-        )
-
-        nsdfg_node = self.state.add_nested_sdfg(
-            nsdfg, inputs={"values", "neighbor_indices"}, outputs={"acc"}
-        )
-
-        if isinstance(input_expr, MemletExpr):
-            self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, "values")
-        else:
-            self.state.add_edge(
-                input_expr.dc_node,
-                None,
-                nsdfg_node,
-                "values",
-                self.sdfg.make_array_memlet(input_expr.dc_node.data),
-            )
-        # The layout of connectivity tables is known.
-        assert len(offset_provider_type.domain) == 2
-        assert offset_provider_type.domain[1].kind == gtx_common.DimensionKind.LOCAL
-        self._add_input_data_edge(
-            connectivity_node,
-            dace_subsets.Range.from_string(
-                f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"
-            ),
-            nsdfg_node,
-            "neighbor_indices",
-        )
-        self.state.add_edge(
-            nsdfg_node,
-            "acc",
-            result_node,
-            None,
-            dace.Memlet(data=result_node.data, subset="0"),
-        )
-
     def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr:
         assert isinstance(node.type, ts.ScalarType)
         op_name, reduce_init, reduce_identity = get_reduce_params(node)
@@ -1530,27 +1393,73 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr:
         offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value)
         assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
 
+        inp_conn = "_in"
+        outp_conn = "_out"
+        mask_conn = "_mask"
+
         if offset_provider_type.has_skip_values:
-            self._make_reduce_with_skip_values(
-                input_expr,
-                offset_provider_type,
-                reduce_init,
-                reduce_identity,
+            name = self.subgraph_builder.unique_nsdfg_name("reduce_with_skip_values")
+            reduce_node = gtir_library_nodes.ReduceWithSkipValues(
+                name,
+                reduce_wcr,
+                identity=reduce_identity.value,
+                init=reduce_init.value,
+                debuginfo=gtir_to_sdfg_utils.debug_info(node),
+                input_conn=inp_conn,
+                output_conn=outp_conn,
+                mask_conn=mask_conn,
+            )
+        else:
+            reduce_node = dace_stdlib.Reduce(
+                "reduce",
                 reduce_wcr,
-                result_node,
+                axes=None,
+                identity=reduce_init.value,
+                schedule=dace.dtypes.ScheduleType.Default,
+                debuginfo=gtir_to_sdfg_utils.debug_info(node),
+                inputs={inp_conn},
+                outputs={outp_conn},
             )
 
+        self.state.add_node(reduce_node)
+        if isinstance(input_expr, MemletExpr):
+            self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node, inp_conn)
         else:
-            reduce_node = self.state.add_reduce(reduce_wcr, axes=None, identity=reduce_init.value)
-            if isinstance(input_expr, MemletExpr):
-                self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node)
-            else:
-                self.state.add_nedge(
-                    input_expr.dc_node,
-                    reduce_node,
-                    self.sdfg.make_array_memlet(input_expr.dc_node.data),
-                )
-            self.state.add_nedge(reduce_node, result_node, dace.Memlet(data=result, subset="0"))
+            self.state.add_edge(
+                input_expr.dc_node,
+                None,
+                reduce_node,
+                inp_conn,
+                self.sdfg.make_array_memlet(input_expr.dc_node.data),
+            )
+        self.state.add_edge(
+            reduce_node, outp_conn, result_node, None, dace.Memlet(data=result, subset="0")
+        )
+
+        if offset_provider_type.has_skip_values:
+            assert (
+                isinstance(input_expr.gt_dtype, ts.ListType)
+                and input_expr.gt_dtype.offset_type is not None
+            )
+
+            offset_type = input_expr.gt_dtype.offset_type
+            connectivity = gtx_dace_args.connectivity_identifier(offset_type.value)
+            connectivity_node = self.state.add_access(connectivity)
+            connectivity_desc = connectivity_node.desc(self.sdfg)
+            connectivity_desc.transient = False
+
+            # The layout of connectivity tables is known.
+            assert len(offset_provider_type.domain) == 2
+            assert offset_provider_type.domain[1].kind == gtx_common.DimensionKind.LOCAL
+            origin_map_index = gtir_to_sdfg_utils.get_map_variable(offset_provider_type.source_dim)
+            self._add_input_data_edge(
+                connectivity_node,
+                dace_subsets.Range.from_string(
+                    f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"
+                ),
+                reduce_node,
+                mask_conn,
+            )
 
         return ValueExpr(result_node, node.type)
 
diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py
index 799e8ad228..532d9094ae 100644
--- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py
+++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py
@@ -20,7 +20,10 @@
 from dace.transformation.passes import analysis as dace_analysis
 
 from gt4py.next import common as gtx_common
-from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations
+from gt4py.next.program_processors.runners.dace import (
+    library_nodes as gtir_library_nodes,
+    transformations as gtx_transformations,
+)
 
 
 class GT4PyAutoOptHook(enum.Enum):
@@ -234,6 +237,15 @@ def gt_auto_optimize(
     device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU
     optimization_hooks = optimization_hooks or {}
 
+    # We expand the GT4Py reduce nodes with skip values.
+    # TODO(edopao,phimuell): Check where this should be done. Doing it here ensures
+    # the same result as if the reduce expression was lowered before optimization.
+    for node, state in list(sdfg.all_nodes_recursive()):
+        if isinstance(node, gtir_library_nodes.ReduceWithSkipValues):
+            node.expand(state)
+            if validate_all:
+                sdfg.validate()
+
     with dace.config.temporary_config():
         # Do not store which transformations were applied inside the SDFG.
         dace.Config.set("store_history", value=False)
@@ -751,8 +763,8 @@ def _gt_auto_process_dataflow_inside_maps(
     # NestedSDFGs inside the ConditionalBlocks it fuses.
     sdfg.apply_transformations_repeated(
         gtx_transformations.FuseHorizontalConditionBlocks(),
-        validate=True,
-        validate_all=True,
+        validate=False,
+        validate_all=validate_all,
     )
 
     # Move dataflow into the branches of the `if` such that they are only evaluated