Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/development/ADRs/next/0019-Connectivities.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ We update and introduce the following concepts

**NeighborTable** is a _NeighborConnectivity_ backed by a buffer.

**ConnectivityType**, **NeighborConnectivityType** contains all information that is needed for compilation.
**ConnectivityType** contains all information that is needed for compilation.

### Full definitions

Expand Down
67 changes: 12 additions & 55 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,22 +1012,18 @@ def remapping(cls) -> ConnectivityKind:


@dataclasses.dataclass(frozen=True)
class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import
class ConnectivityType:
domain: tuple[Dimension, ...]
codomain: Dimension
skip_value: Optional[core_defs.IntegralScalar]
dtype: core_defs.DType
# TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain
max_neighbors: int

@property
def has_skip_values(self) -> bool:
return self.skip_value is not None


@dataclasses.dataclass(frozen=True)
class NeighborConnectivityType(ConnectivityType):
# TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain
max_neighbors: int

@property
def source_dim(self) -> Dimension:
return self.domain[0]
Expand All @@ -1053,21 +1049,13 @@ def codomain(self) -> DimT_co:
"""

def __gt_type__(self) -> ConnectivityType:
if is_neighbor_connectivity(self):
return NeighborConnectivityType(
domain=self.domain.dims,
codomain=self.codomain,
dtype=self.dtype,
skip_value=self.skip_value,
max_neighbors=self.ndarray.shape[1],
)
else:
return ConnectivityType(
domain=self.domain.dims,
codomain=self.codomain,
dtype=self.dtype,
skip_value=self.skip_value,
)
return ConnectivityType(
domain=self.domain.dims,
codomain=self.codomain,
dtype=self.dtype,
skip_value=self.skip_value,
max_neighbors=self.ndarray.shape[1],
)

@property
def kind(self) -> ConnectivityKind:
Expand Down Expand Up @@ -1178,39 +1166,8 @@ def _connectivity(
raise NotImplementedError


class NeighborConnectivity(Connectivity, Protocol):
# TODO(havogt): work towards encoding this properly in the type
def __gt_type__(self) -> NeighborConnectivityType: ...


def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]:
if not isinstance(obj, Connectivity):
return False
domain_dims = obj.domain.dims
return (
len(domain_dims) == 2
and domain_dims[0].kind is DimensionKind.HORIZONTAL
and domain_dims[1].kind is DimensionKind.LOCAL
)


class NeighborTable(
NeighborConnectivity, Protocol
): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`)
@property
def ndarray(self) -> core_defs.NDArrayObject:
# Note that this property is currently already there from inheriting from `Field`,
# however this seems wrong, therefore we explicitly introduce it here (or it should come
# implicitly from the `NdArrayConnectivityField` protocol).
...


def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]:
return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray")


OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity
OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType
OffsetProviderElem: TypeAlias = Dimension | Connectivity
OffsetProviderTypeElem: TypeAlias = Dimension | ConnectivityType
# Note: `OffsetProvider` and `OffsetProviderType` should not be accessed directly,
# use the `get_offset` and `get_offset_type` functions instead.
OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem]
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ def _builtin_op(
offset_definition = common.get_offset(
current_offset_provider, axis.value
) # assumes offset and local dimension have same name
assert common.is_neighbor_table(offset_definition)
assert isinstance(offset_definition, common.Connectivity)
new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis])

broadcast_slice = tuple(
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,6 @@ def __getitem__(self, offset: int) -> common.Connectivity:
if isinstance(offset_definition, common.Dimension):
connectivity = common.CartesianConnectivity(offset_definition, offset)
elif isinstance(offset_definition, common.Connectivity):
assert common.is_neighbor_connectivity(offset_definition)
named_index = common.NamedIndex(self.target[-1], offset)
connectivity = offset_definition[named_index]
else:
Expand Down
11 changes: 3 additions & 8 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def __init__(
def __gt_origin__(self) -> xtyping.Never:
raise NotImplementedError

def __gt_type__(self) -> common.NeighborConnectivityType:
return common.NeighborConnectivityType(
def __gt_type__(self) -> common.ConnectivityType:
return common.ConnectivityType(
domain=self.domain_dims,
codomain=self.codomain_dim,
max_neighbors=self._max_neighbors,
Expand Down Expand Up @@ -563,7 +563,6 @@ def execute_shift(
new_entry[i] = 0
else:
offset_implementation = common.get_offset(offset_provider, tag)
assert common.is_neighbor_connectivity(offset_implementation)
source_dim = offset_implementation.__gt_type__().source_dim
cur_index = pos[source_dim.value]
assert common.is_int_index(cur_index)
Expand All @@ -586,7 +585,7 @@ def execute_shift(
else:
raise AssertionError()
return new_pos
elif common.is_neighbor_connectivity(offset_implementation):
elif isinstance(offset_implementation, common.Connectivity):
source_dim = offset_implementation.__gt_type__().source_dim
assert source_dim.value in pos
new_pos = pos.copy()
Expand Down Expand Up @@ -1400,7 +1399,6 @@ def __gt_type__(self) -> ts.ListType:
offset_provider = embedded_context.get_offset_provider()
assert offset_provider is not None
connectivity = common.get_offset(offset_provider, offset_tag)
assert common.is_neighbor_connectivity(connectivity)
local_dim = connectivity.__gt_type__().neighbor_dim
return ts.ListType(element_type=element_type, offset_type=local_dim)

Expand Down Expand Up @@ -1428,7 +1426,6 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List:
offset_provider = embedded_context.get_offset_provider()
assert offset_provider is not None
connectivity = common.get_offset(offset_provider, offset_str)
assert common.is_neighbor_connectivity(connectivity)
return _List(
values=tuple(
shifted.deref()
Expand Down Expand Up @@ -1506,7 +1503,6 @@ def deref(self) -> Any:
offset_provider = embedded_context.get_offset_provider()
assert offset_provider is not None
connectivity = common.get_offset(offset_provider, self.list_offset)
assert common.is_neighbor_connectivity(connectivity)
return _List(
values=tuple(
shifted.deref()
Expand Down Expand Up @@ -1763,7 +1759,6 @@ def _fieldspec_list_to_value(
offset_type = type_.offset_type
assert isinstance(offset_type, common.Dimension)
connectivity = common.get_offset(offset_provider, offset_type.value)
assert common.is_neighbor_connectivity(connectivity)
return domain.insert(
len(domain),
common.named_range((offset_type, connectivity.__gt_type__().max_neighbors)),
Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def _unstructured_translate_range_statically(
"""
assert common.is_offset_provider(offset_provider)
connectivity = offset_provider[tag]
assert common.is_neighbor_connectivity(connectivity)
skip_value = connectivity.skip_value

# fold & convert expr into actual integers
Expand Down Expand Up @@ -192,7 +191,7 @@ def translate(
new_ranges[current_dim] = SymbolicRange.translate(
self.ranges[current_dim], val.value
)
elif isinstance(connectivity_type, common.NeighborConnectivityType):
elif isinstance(connectivity_type, common.ConnectivityType):
# unstructured shift
assert (
isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _max_domain_range_sizes(offset_provider: common.OffsetProvider) -> dict[str,
"""
sizes: dict[str, int] = {}
for provider in offset_provider.values():
if common.is_neighbor_connectivity(provider):
if isinstance(provider, common.Connectivity):
src_dim = provider.__gt_type__().source_dim.value
codomain_dim = provider.__gt_type__().codomain.value
sizes[src_dim] = max(sizes.get(src_dim, 0), provider.ndarray.shape[0])
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/unroll_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]:
def _get_connectivity(
applied_reduce_node: itir.FunCall,
offset_provider_type: common.OffsetProviderType,
) -> common.NeighborConnectivityType:
) -> common.ConnectivityType:
"""Return single connectivity that is compatible with the arguments of the reduce."""
if not cpm.is_applied_reduce(applied_reduce_node):
raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.")

connectivities: list[common.NeighborConnectivityType] = []
connectivities: list[common.ConnectivityType] = []
for o in _get_partial_offset_tags(applied_reduce_node.args):
conn = common.get_offset_type(offset_provider_type, o)
assert isinstance(conn, common.NeighborConnectivityType)
assert isinstance(conn, common.ConnectivityType)
connectivities.append(conn)

if not connectivities:
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def neighbors(
)
assert isinstance(it, it_ts.IteratorType)
conn_type = common.get_offset_type(offset_provider_type, offset_literal.value)
assert isinstance(conn_type, common.NeighborConnectivityType)
assert isinstance(conn_type, common.ConnectivityType)
return ts.ListType(element_type=it.element_type, offset_type=conn_type.neighbor_dim)


Expand Down Expand Up @@ -494,7 +494,7 @@ def _resolve_dimensions(
... itir.OffsetLiteral(value=0),
... )
>>> offset_provider_type = {
... "V2E": common.NeighborConnectivityType(
... "V2E": common.ConnectivityType(
... domain=(Vertex, V2E),
... codomain=Edge,
... skip_value=None,
Expand All @@ -515,7 +515,7 @@ def _resolve_dimensions(
offset_type = common.get_offset_type(offset_provider_type, off_literal.value)
if isinstance(offset_type, common.Dimension) and input_dim == offset_type:
continue # No shift applied
if isinstance(offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)):
if isinstance(offset_type, (fbuiltins.FieldOffset, common.ConnectivityType)):
if input_dim == offset_type.codomain: # Check if input fits to offset
input_dim = offset_type.domain[0] # Update input_dim for next iteration
resolved_dims.append(input_dim)
Expand Down Expand Up @@ -670,7 +670,7 @@ def apply_shift(
type_ = common.get_offset_type(offset_provider_type, offset_axis.value)
if isinstance(type_, common.Dimension):
pass
elif isinstance(type_, common.NeighborConnectivityType):
elif isinstance(type_, common.ConnectivityType):
found = False
for i, dim in enumerate(new_position_dims):
if dim.value == type_.source_dim.value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _process_regular_arguments(
# translate sparse dimensions to tuple dtype
dim_name = dim.value
connectivity = common.get_offset_type(offset_provider_type, dim_name)
assert isinstance(connectivity, common.NeighborConnectivityType)
assert isinstance(connectivity, common.ConnectivityType)
size = connectivity.max_neighbors
arg = f"gridtools::sid::dimension_to_tuple_like<generated::{dim_name}_t, {size}>({arg})"
arg_exprs.append(arg)
Expand All @@ -106,7 +106,7 @@ def _process_connectivity_args(
arg_exprs: list[str] = []

for name, connectivity_type in offset_provider_type.items():
if isinstance(connectivity_type, common.NeighborConnectivityType):
if isinstance(connectivity_type, common.ConnectivityType):
if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]:
raise ValueError(
"Neighbor table indices must be of type 'np.int32' or 'np.int64'."
Expand Down Expand Up @@ -138,7 +138,7 @@ def _process_connectivity_args(
pass
else:
raise AssertionError(
f"Expected offset provider type '{name}' to be a 'NeighborConnectivityType' or 'Dimension', "
f"Expected offset provider type '{name}' to be a 'ConnectivityType' or 'Dimension', "
f"got '{type(connectivity_type).__name__}'."
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def _collect_offset_definitions(
offset_definitions[offset_name] = TagDefinition(
name=Sym(id=offset_name), alias=SymRef(id=dim.value)
)
elif isinstance(
connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType
):
elif isinstance(connectivity_type := dim_or_connectivity_type, common.ConnectivityType):
assert grid_type == common.GridType.UNSTRUCTURED
offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name))
if offset_name != connectivity_type.neighbor_dim.value:
Expand Down Expand Up @@ -456,7 +454,7 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node:
for o in shift_offsets:
if o in self.offset_provider_type and isinstance(
common.get_offset_type(self.offset_provider_type, o),
common.NeighborConnectivityType,
common.ConnectivityType,
):
connectivities.append(SymRef(id=o))
return UnstructuredDomain(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def _visit_if_branch_arg(
offset_provider_type = self.subgraph_builder.get_offset_provider_type(
arg.gt_dtype.offset_type.value
)
assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
assert isinstance(offset_provider_type, gtx_common.ConnectivityType)
inner_desc = dace.data.Array(
dtype=arg_desc.dtype, shape=[offset_provider_type.max_neighbors]
)
Expand Down Expand Up @@ -985,7 +985,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
offset = node.args[0].value
assert isinstance(offset, str)
conn_type = self.subgraph_builder.get_offset_provider_type(offset)
assert isinstance(conn_type, gtx_common.NeighborConnectivityType)
assert isinstance(conn_type, gtx_common.ConnectivityType)

it = self.visit(node.args[1])
assert isinstance(it, IteratorExpr)
Expand Down Expand Up @@ -1210,7 +1210,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr:
tasklet_expression = f"{output_connector} = {fun_python_code}"

input_args = [self.visit(arg) for arg in node.args]
input_conn_types: dict[gtx_common.Dimension, gtx_common.NeighborConnectivityType] = {}
input_conn_types: dict[gtx_common.Dimension, gtx_common.ConnectivityType] = {}
for input_arg in input_args:
assert isinstance(input_arg.gt_dtype, ts.ListType)
assert input_arg.gt_dtype.offset_type is not None
Expand All @@ -1219,7 +1219,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr:
# this input argument is the result of `make_const_list`
continue
offset_provider_t = self.subgraph_builder.get_offset_provider_type(offset_type.value)
assert isinstance(offset_provider_t, gtx_common.NeighborConnectivityType)
assert isinstance(offset_provider_t, gtx_common.ConnectivityType)
input_conn_types[offset_type] = offset_provider_t

if len(input_conn_types) == 0:
Expand Down Expand Up @@ -1332,7 +1332,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr:
def _make_reduce_with_skip_values(
self,
input_expr: ValueExpr | MemletExpr,
offset_provider_type: gtx_common.NeighborConnectivityType,
offset_provider_type: gtx_common.ConnectivityType,
reduce_init: SymbolExpr,
reduce_identity: SymbolExpr,
reduce_wcr: str,
Expand Down Expand Up @@ -1478,7 +1478,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr:
)
offset_type = input_expr.gt_dtype.offset_type
offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value)
assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
assert isinstance(offset_provider_type, gtx_common.ConnectivityType)

if offset_provider_type.has_skip_values:
self._make_reduce_with_skip_values(
Expand Down Expand Up @@ -1655,7 +1655,7 @@ def _make_dynamic_neighbor_offset(
def _make_unstructured_shift(
self,
it: IteratorExpr,
conn_type: gtx_common.NeighborConnectivityType,
conn_type: gtx_common.ConnectivityType,
conn_node: dace_nodes.AccessNode,
offset_expr: DataExpr,
) -> IteratorExpr:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _construct_if_branch_output(
offset_provider_type = sdfg_builder.get_offset_provider_type(
out_type.dtype.offset_type.value
)
assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
assert isinstance(offset_provider_type, gtx_common.ConnectivityType)
shape = [*shape, offset_provider_type.max_neighbors]

out, _ = sdfg_builder.add_temp_array(ctx.sdfg, shape, dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def get_scan_output_shape(
assert scan_init_data.gt_type.offset_type
offset_type = scan_init_data.gt_type.offset_type
offset_provider_type = sdfg_builder.get_offset_provider_type(offset_type.value)
assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType)
assert isinstance(offset_provider_type, gtx_common.ConnectivityType)
list_size = offset_provider_type.max_neighbors
return [scan_column_size, dace.symbolic.SymExpr(list_size)]

Expand Down
Loading
Loading