diff --git a/docs/development/ADRs/next/0019-Connectivities.md b/docs/development/ADRs/next/0019-Connectivities.md index 76e85e49a6..be585124b1 100644 --- a/docs/development/ADRs/next/0019-Connectivities.md +++ b/docs/development/ADRs/next/0019-Connectivities.md @@ -31,7 +31,7 @@ We update and introduce the following concepts **NeighborTable** is a _NeighborConnectivity_ backed by a buffer. -**ConnectivityType**, **NeighborConnectivityType** contains all information that is needed for compilation. +**ConnectivityType** contains all information that is needed for compilation. ### Full definitions diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 69d42d3dea..b7bd575a53 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -1012,22 +1012,18 @@ def remapping(cls) -> ConnectivityKind: @dataclasses.dataclass(frozen=True) -class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import +class ConnectivityType: domain: tuple[Dimension, ...] codomain: Dimension skip_value: Optional[core_defs.IntegralScalar] dtype: core_defs.DType + # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain + max_neighbors: int @property def has_skip_values(self) -> bool: return self.skip_value is not None - -@dataclasses.dataclass(frozen=True) -class NeighborConnectivityType(ConnectivityType): - # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain - max_neighbors: int - @property def source_dim(self) -> Dimension: return self.domain[0] @@ -1053,21 +1049,13 @@ def codomain(self) -> DimT_co: """ def __gt_type__(self) -> ConnectivityType: - if is_neighbor_connectivity(self): - return NeighborConnectivityType( - domain=self.domain.dims, - codomain=self.codomain, - dtype=self.dtype, - skip_value=self.skip_value, - max_neighbors=self.ndarray.shape[1], - ) - else: - return ConnectivityType( - domain=self.domain.dims, - codomain=self.codomain, - dtype=self.dtype, - skip_value=self.skip_value, - ) + return ConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + max_neighbors=self.ndarray.shape[1], + ) @property def kind(self) -> ConnectivityKind: @@ -1178,39 +1166,8 @@ def _connectivity( raise NotImplementedError -class NeighborConnectivity(Connectivity, Protocol): - # TODO(havogt): work towards encoding this properly in the type - def __gt_type__(self) -> NeighborConnectivityType: ... - - -def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: - if not isinstance(obj, Connectivity): - return False - domain_dims = obj.domain.dims - return ( - len(domain_dims) == 2 - and domain_dims[0].kind is DimensionKind.HORIZONTAL - and domain_dims[1].kind is DimensionKind.LOCAL - ) - - -class NeighborTable( - NeighborConnectivity, Protocol -): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`) - @property - def ndarray(self) -> core_defs.NDArrayObject: - # Note that this property is currently already there from inheriting from `Field`, - # however this seems wrong, therefore we explicitly introduce it here (or it should come - # implicitly from the `NdArrayConnectivityField` protocol). - ... - - -def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: - return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") - - -OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity -OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType +OffsetProviderElem: TypeAlias = Dimension | Connectivity +OffsetProviderTypeElem: TypeAlias = Dimension | ConnectivityType # Note: `OffsetProvider` and `OffsetProviderType` should not be accessed directly, # use the `get_offset` and `get_offset_type` functions instead. OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e9aff84a15..171b00a08a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -1009,7 +1009,7 @@ def _builtin_op( offset_definition = common.get_offset( current_offset_provider, axis.value ) # assumes offset and local dimension have same name - assert common.is_neighbor_table(offset_definition) + assert isinstance(offset_definition, common.Connectivity) new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index af93beeb3b..b82ad2219b 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -481,7 +481,6 @@ def __getitem__(self, offset: int) -> common.Connectivity: if isinstance(offset_definition, common.Dimension): connectivity = common.CartesianConnectivity(offset_definition, offset) elif isinstance(offset_definition, common.Connectivity): - assert common.is_neighbor_connectivity(offset_definition) named_index = common.NamedIndex(self.target[-1], offset) connectivity = offset_definition[named_index] else: diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a9b36a4624..e9f6230b3a 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -123,8 +123,8 @@ def __init__( def __gt_origin__(self) -> xtyping.Never: raise NotImplementedError - def __gt_type__(self) -> common.NeighborConnectivityType: - return common.NeighborConnectivityType( + def __gt_type__(self) -> common.ConnectivityType: + return common.ConnectivityType( domain=self.domain_dims, codomain=self.codomain_dim, max_neighbors=self._max_neighbors, @@ -563,7 +563,6 @@ def execute_shift( new_entry[i] = 0 else: offset_implementation = common.get_offset(offset_provider, tag) - assert common.is_neighbor_connectivity(offset_implementation) source_dim = offset_implementation.__gt_type__().source_dim cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) @@ -586,7 +585,7 @@ def execute_shift( else: raise AssertionError() return new_pos - elif common.is_neighbor_connectivity(offset_implementation): + elif isinstance(offset_implementation, common.Connectivity): source_dim = offset_implementation.__gt_type__().source_dim assert source_dim.value in pos new_pos = pos.copy() @@ -1400,7 +1399,6 @@ def __gt_type__(self) -> ts.ListType: offset_provider = embedded_context.get_offset_provider() assert offset_provider is not None connectivity = common.get_offset(offset_provider, offset_tag) - assert common.is_neighbor_connectivity(connectivity) local_dim = connectivity.__gt_type__().neighbor_dim return ts.ListType(element_type=element_type, offset_type=local_dim) @@ -1428,7 +1426,6 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_provider = embedded_context.get_offset_provider() assert offset_provider is not None connectivity = common.get_offset(offset_provider, offset_str) - assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() @@ -1506,7 +1503,6 @@ def deref(self) -> Any: offset_provider = embedded_context.get_offset_provider() assert offset_provider is not None connectivity = common.get_offset(offset_provider, self.list_offset) - assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() @@ -1763,7 +1759,6 @@ def _fieldspec_list_to_value( offset_type = type_.offset_type assert isinstance(offset_type, common.Dimension) connectivity = common.get_offset(offset_provider, offset_type.value) - assert common.is_neighbor_connectivity(connectivity) return domain.insert( len(domain), common.named_range((offset_type, connectivity.__gt_type__().max_neighbors)), diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 08275cf2cc..0f3f67fa56 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -71,7 +71,6 @@ def _unstructured_translate_range_statically( """ assert common.is_offset_provider(offset_provider) connectivity = offset_provider[tag] - assert common.is_neighbor_connectivity(connectivity) skip_value = connectivity.skip_value # fold & convert expr into actual integers @@ -192,7 +191,7 @@ def translate( new_ranges[current_dim] = SymbolicRange.translate( self.ranges[current_dim], val.value ) - elif isinstance(connectivity_type, common.NeighborConnectivityType): + elif isinstance(connectivity_type, common.ConnectivityType): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 1236a3209a..aaa0905ed1 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -53,7 +53,7 @@ def _max_domain_range_sizes(offset_provider: common.OffsetProvider) -> dict[str, """ sizes: dict[str, int] = {} for provider in offset_provider.values(): - if common.is_neighbor_connectivity(provider): + if isinstance(provider, common.Connectivity): src_dim = provider.__gt_type__().source_dim.value codomain_dim = provider.__gt_type__().codomain.value sizes[src_dim] = max(sizes.get(src_dim, 0), provider.ndarray.shape[0]) diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index c4ea32fc36..0cbad8385c 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -53,15 +53,15 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: def _get_connectivity( applied_reduce_node: itir.FunCall, offset_provider_type: common.OffsetProviderType, -) -> common.NeighborConnectivityType: +) -> common.ConnectivityType: """Return single connectivity that is compatible with the arguments of the reduce.""" if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - connectivities: list[common.NeighborConnectivityType] = [] + connectivities: list[common.ConnectivityType] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): conn = common.get_offset_type(offset_provider_type, o) - assert isinstance(conn, common.NeighborConnectivityType) + assert isinstance(conn, common.ConnectivityType) connectivities.append(conn) if not connectivities: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 16d5da7e3b..2dd42ecc71 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -326,7 +326,7 @@ def neighbors( ) assert isinstance(it, it_ts.IteratorType) conn_type = common.get_offset_type(offset_provider_type, offset_literal.value) - assert isinstance(conn_type, common.NeighborConnectivityType) + assert isinstance(conn_type, common.ConnectivityType) return ts.ListType(element_type=it.element_type, offset_type=conn_type.neighbor_dim) @@ -494,7 +494,7 @@ def _resolve_dimensions( ... itir.OffsetLiteral(value=0), ... ) >>> offset_provider_type = { - ... "V2E": common.NeighborConnectivityType( + ... "V2E": common.ConnectivityType( ... domain=(Vertex, V2E), ... codomain=Edge, ... skip_value=None, @@ -515,7 +515,7 @@ def _resolve_dimensions( offset_type = common.get_offset_type(offset_provider_type, off_literal.value) if isinstance(offset_type, common.Dimension) and input_dim == offset_type: continue # No shift applied - if isinstance(offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)): + if isinstance(offset_type, (fbuiltins.FieldOffset, common.ConnectivityType)): if input_dim == offset_type.codomain: # Check if input fits to offset input_dim = offset_type.domain[0] # Update input_dim for next iteration resolved_dims.append(input_dim) @@ -670,7 +670,7 @@ def apply_shift( type_ = common.get_offset_type(offset_provider_type, offset_axis.value) if isinstance(type_, common.Dimension): pass - elif isinstance(type_, common.NeighborConnectivityType): + elif isinstance(type_, common.ConnectivityType): found = False for i, dim in enumerate(new_position_dims): if dim.value == type_.source_dim.value: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 2135af7fbb..4c98ed3da6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -93,7 +93,7 @@ def _process_regular_arguments( # translate sparse dimensions to tuple dtype dim_name = dim.value connectivity = common.get_offset_type(offset_provider_type, dim_name) - assert isinstance(connectivity, common.NeighborConnectivityType) + assert isinstance(connectivity, common.ConnectivityType) size = connectivity.max_neighbors arg = f"gridtools::sid::dimension_to_tuple_like({arg})" arg_exprs.append(arg) @@ -106,7 +106,7 @@ def _process_connectivity_args( arg_exprs: list[str] = [] for name, connectivity_type in offset_provider_type.items(): - if isinstance(connectivity_type, common.NeighborConnectivityType): + if isinstance(connectivity_type, common.ConnectivityType): if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]: raise ValueError( "Neighbor table indices must be of type 'np.int32' or 'np.int64'." @@ -138,7 +138,7 @@ def _process_connectivity_args( pass else: raise AssertionError( - f"Expected offset provider type '{name}' to be a 'NeighborConnectivityType' or 'Dimension', " + f"Expected offset provider type '{name}' to be a 'ConnectivityType' or 'Dimension', " f"got '{type(connectivity_type).__name__}'." ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ecd8ed88ed..cda68bfa42 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -166,9 +166,7 @@ def _collect_offset_definitions( offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) - elif isinstance( - connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType - ): + elif isinstance(connectivity_type := dim_or_connectivity_type, common.ConnectivityType): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) if offset_name != connectivity_type.neighbor_dim.value: @@ -456,7 +454,7 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: for o in shift_offsets: if o in self.offset_provider_type and isinstance( common.get_offset_type(self.offset_provider_type, o), - common.NeighborConnectivityType, + common.ConnectivityType, ): connectivities.append(SymRef(id=o)) return UnstructuredDomain( diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index e7dc855428..9a33873f28 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -714,7 +714,7 @@ def _visit_if_branch_arg( offset_provider_type = self.subgraph_builder.get_offset_provider_type( arg.gt_dtype.offset_type.value ) - assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + assert isinstance(offset_provider_type, gtx_common.ConnectivityType) inner_desc = dace.data.Array( dtype=arg_desc.dtype, shape=[offset_provider_type.max_neighbors] ) @@ -985,7 +985,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: offset = node.args[0].value assert isinstance(offset, str) conn_type = self.subgraph_builder.get_offset_provider_type(offset) - assert isinstance(conn_type, gtx_common.NeighborConnectivityType) + assert isinstance(conn_type, gtx_common.ConnectivityType) it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) @@ -1210,7 +1210,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: tasklet_expression = f"{output_connector} = {fun_python_code}" input_args = [self.visit(arg) for arg in node.args] - input_conn_types: dict[gtx_common.Dimension, gtx_common.NeighborConnectivityType] = {} + input_conn_types: dict[gtx_common.Dimension, gtx_common.ConnectivityType] = {} for input_arg in input_args: assert isinstance(input_arg.gt_dtype, ts.ListType) assert input_arg.gt_dtype.offset_type is not None @@ -1219,7 +1219,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: # this input argument is the result of `make_const_list` continue offset_provider_t = self.subgraph_builder.get_offset_provider_type(offset_type.value) - assert isinstance(offset_provider_t, gtx_common.NeighborConnectivityType) + assert isinstance(offset_provider_t, gtx_common.ConnectivityType) input_conn_types[offset_type] = offset_provider_t if len(input_conn_types) == 0: @@ -1332,7 +1332,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: def _make_reduce_with_skip_values( self, input_expr: ValueExpr | MemletExpr, - offset_provider_type: gtx_common.NeighborConnectivityType, + offset_provider_type: gtx_common.ConnectivityType, reduce_init: SymbolExpr, reduce_identity: SymbolExpr, reduce_wcr: str, @@ -1478,7 +1478,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: ) offset_type = input_expr.gt_dtype.offset_type offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value) - assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + assert isinstance(offset_provider_type, gtx_common.ConnectivityType) if offset_provider_type.has_skip_values: self._make_reduce_with_skip_values( @@ -1655,7 +1655,7 @@ def _make_dynamic_neighbor_offset( def _make_unstructured_shift( self, it: IteratorExpr, - conn_type: gtx_common.NeighborConnectivityType, + conn_type: gtx_common.ConnectivityType, conn_node: dace_nodes.AccessNode, offset_expr: DataExpr, ) -> IteratorExpr: diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 007202a87d..5f1e67dd44 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -330,7 +330,7 @@ def _construct_if_branch_output( offset_provider_type = sdfg_builder.get_offset_provider_type( out_type.dtype.offset_type.value ) - assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + assert isinstance(offset_provider_type, gtx_common.ConnectivityType) shape = [*shape, offset_provider_type.max_neighbors] out, _ = sdfg_builder.add_temp_array(ctx.sdfg, shape, dtype) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index a569c06fbf..38fe78b47a 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -383,7 +383,7 @@ def get_scan_output_shape( assert scan_init_data.gt_type.offset_type offset_type = scan_init_data.gt_type.offset_type offset_provider_type = sdfg_builder.get_offset_provider_type(offset_type.value) - assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + assert isinstance(offset_provider_type, gtx_common.ConnectivityType) list_size = offset_provider_type.max_neighbors return [scan_column_size, dace.symbolic.SymExpr(list_size)] diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index f8c8fd84a3..c041296794 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -153,10 +153,10 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ if not self.compilation_options.connectivities: return {} - used_connectivities: dict[str, gtx_common.NeighborConnectivity] = { + used_connectivities: dict[str, gtx_common.Connectivity] = { conn_id: conn for offset, conn in self.compilation_options.connectivities.items() - if gtx_common.is_neighbor_table(conn) + if isinstance(conn, gtx_common.Connectivity) and (conn_id := gtx_dace_args.connectivity_identifier(offset)) in self.sdfg_closure_cache["arrays"] } diff --git a/src/gt4py/next/program_processors/runners/dace/sdfg_args.py b/src/gt4py/next/program_processors/runners/dace/sdfg_args.py index f33fbf8bb5..4f10cd5a15 100644 --- a/src/gt4py/next/program_processors/runners/dace/sdfg_args.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_args.py @@ -83,7 +83,7 @@ def _field_symbol( assert m[1] in offset_provider_type offset = m[1] conn_type = offset_provider_type[offset] - assert isinstance(conn_type, gtx_common.NeighborConnectivityType) + assert isinstance(conn_type, gtx_common.ConnectivityType) if dim == conn_type.source_dim: name = f"__{field_name}_source_{sym}" elif dim == conn_type.neighbor_dim: @@ -130,14 +130,14 @@ def range_stop_symbol(field_name: str, dim: gtx_common.Dimension) -> dace.symbol def filter_connectivity_types( offset_provider_type: gtx_common.OffsetProviderType, -) -> dict[str, gtx_common.NeighborConnectivityType]: +) -> dict[str, gtx_common.ConnectivityType]: """ - Filter offset provider types of type `NeighborConnectivityType`. + Filter offset provider types of type `ConnectivityType`. In other words, filter out the cartesian offset providers. """ return { offset: conn for offset, conn in offset_provider_type.items() - if isinstance(conn, gtx_common.NeighborConnectivityType) + if isinstance(conn, gtx_common.ConnectivityType) } diff --git a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py index bfae24bef2..c79b510ec6 100644 --- a/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py @@ -99,7 +99,7 @@ def get_sdfg_conn_args( for offset, connectivity in offset_provider.items(): name = gtx_dace_args.connectivity_identifier(offset) if name in sdfg.arrays: - assert gtx_common.is_neighbor_connectivity(connectivity) + assert isinstance(connectivity, gtx_common.Connectivity) assert field_utils.verify_device_field_type( connectivity, core_defs.CUPY_DEVICE_TYPE diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index ad8e8ea04b..3969539a10 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -57,7 +57,7 @@ def find_constant_symbols( # Same for connectivity tables, for which the first dimension is always horizontal for offset, conn_type in offset_provider_type.items(): if ( - isinstance(conn_type, common.NeighborConnectivityType) + isinstance(conn_type, common.ConnectivityType) and (conn_id := gtx_dace_args.connectivity_identifier(offset)) in sdfg.arrays ): assert not sdfg.arrays[conn_id].transient diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c1743dea6a..3565eca7e7 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -98,7 +98,7 @@ def extract_connectivity_args( if (ndarray := getattr(conn, "ndarray", None)) is not None ] assert all( - common.is_neighbor_table(conn) and field_utils.verify_device_field_type(conn, device) + isinstance(conn, common.Connectivity) and field_utils.verify_device_field_type(conn, device) for conn in offset_provider.values() if hasattr(conn, "ndarray") ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f410dadc95..93a5164ae1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -766,14 +766,14 @@ def testee(a: cases.VField) -> cases.EField: a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, cases.RETURN)() - ORIGIN = 2 + origin = 2 e2v_table = unstructured_case.offset_provider["E2V"].asnumpy() neighbor_iter = iter(enumerate(e2v_table)) - edge_start = next(i for i, v in neighbor_iter if all(v >= ORIGIN)) - edge_stop = next(i for i, v in neighbor_iter if any(v < ORIGIN)) + edge_start = next(i for i, v in neighbor_iter if all(v >= origin)) + edge_stop = next(i for i, v in neighbor_iter if any(v < origin)) ref = np.sum(a.ndarray[e2v_table[edge_start:edge_stop,]], axis=1) - cases.verify(unstructured_case, testee, a[ORIGIN:], out=out[edge_start:edge_stop], ref=ref) + cases.verify(unstructured_case, testee, a[origin:], out=out[edge_start:edge_stop], ref=ref) @pytest.mark.uses_unstructured_shift diff --git a/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py index 9b340b5cde..410fb87ba2 100644 --- a/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py +++ b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py @@ -51,7 +51,7 @@ def test_offset_dimension_name_differ(case): """ Ensure that gtfn works with offset name that differs from the name of the local dimension. - If the value of the `NeighborConnectivityType.neighbor_dim` did not match the `FieldOffset` value, + If the value of the `ConnectivityType.neighbor_dim` did not match the `FieldOffset` value, gtfn would silently ignore the neighbor index, see https://github.com/GridTools/gridtools/pull/1814. """ diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 8083c56a8f..387a27634e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -16,7 +16,7 @@ def dummy_connectivity_type(max_neighbors: int, has_skip_values: bool): - return common.NeighborConnectivityType( + return common.ConnectivityType( domain=[common.Dimension("dummy_origin"), common.Dimension("dummy_neighbor")], codomain=common.Dimension("dummy_codomain"), skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None,