diff --git a/legate/core/__init__.py b/legate/core/__init__.py index 8a278ed92f..92c4fbdbfa 100644 --- a/legate/core/__init__.py +++ b/legate/core/__init__.py @@ -52,6 +52,8 @@ PartitionByRestriction, PartitionByImage, PartitionByImageRange, + PartitionByPreimage, + PartitionByPreimageRange, EqualPartition, PartitionByWeights, IndexPartition, diff --git a/legate/core/_legion/__init__.py b/legate/core/_legion/__init__.py index f67fa67930..8c6529365e 100644 --- a/legate/core/_legion/__init__.py +++ b/legate/core/_legion/__init__.py @@ -38,6 +38,8 @@ PartitionByRestriction, PartitionByImage, PartitionByImageRange, + PartitionByPreimage, + PartitionByPreimageRange, EqualPartition, PartitionByWeights, PartitionByDomain, @@ -88,6 +90,8 @@ "PartitionByDomain", "PartitionByImage", "PartitionByImageRange", + "PartitionByPreimage", + "PartitionByPreimageRange", "PartitionByRestriction", "PartitionByWeights", "PartitionFunctor", diff --git a/legate/core/_legion/future.py b/legate/core/_legion/future.py index 470af4dd03..0b4b16248a 100644 --- a/legate/core/_legion/future.py +++ b/legate/core/_legion/future.py @@ -343,7 +343,7 @@ def from_list( context, domain, points, - futures, + futures_, num_futures, False, 0, diff --git a/legate/core/_legion/partition.py b/legate/core/_legion/partition.py index c0d4c5c24d..70aff9c671 100644 --- a/legate/core/_legion/partition.py +++ b/legate/core/_legion/partition.py @@ -109,6 +109,14 @@ def get_root(self) -> Region: """ return self.parent.get_root() + @property + def disjoint(self) -> bool: + return self.index_partition.disjoint + + @property + def complete(self) -> bool: + return self.index_partition.complete + class IndexPartition: _logical_handle: Any @@ -260,3 +268,15 @@ def get_root(self) -> IndexSpace: Return the root IndexSpace in this tree. 
""" return self.parent.get_root() + + @property + def disjoint(self) -> bool: + return legion.legion_index_partition_is_disjoint( # type: ignore + self.runtime, self.handle + ) + + @property + def complete(self) -> bool: + return legion.legion_index_partition_is_complete( # type: ignore + self.runtime, self.handle + ) diff --git a/legate/core/_legion/partition_functor.py b/legate/core/_legion/partition_functor.py index 2f98224419..f34998de9a 100644 --- a/legate/core/_legion/partition_functor.py +++ b/legate/core/_legion/partition_functor.py @@ -17,11 +17,20 @@ from typing import TYPE_CHECKING, Any, Union from .. import ffi, legion +from . import FieldID from .future import FutureMap from .geometry import Point if TYPE_CHECKING: - from . import FieldID, IndexPartition, IndexSpace, Rect, Region, Transform + from . import ( + FieldID, + IndexPartition, + IndexSpace, + Partition, + Rect, + Region, + Transform, + ) class PartitionFunctor: @@ -103,7 +112,7 @@ class PartitionByImage(PartitionFunctor): def __init__( self, region: Region, - part: IndexPartition, + part: Partition, field: Union[int, FieldID], mapper: int = 0, tag: int = 0, @@ -148,7 +157,7 @@ class PartitionByImageRange(PartitionFunctor): def __init__( self, region: Region, - part: IndexPartition, + part: Partition, field: Union[int, FieldID], mapper: int = 0, tag: int = 0, diff --git a/legate/core/constraints.py b/legate/core/constraints.py index 5b910c3e32..81ecd55266 100644 --- a/legate/core/constraints.py +++ b/legate/core/constraints.py @@ -17,7 +17,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Iterator, Optional, Protocol, Union -from .partition import Restriction +from .partition import ImagePartition, Replicate, Restriction if TYPE_CHECKING: from .partition import PartitionBase @@ -219,6 +219,79 @@ def unknowns(self) -> Iterator[PartSym]: yield unknown +class Image(Expr): + def __init__( + self, + source_store: Store, + dst_store: Store, + src_part_sym: Expr, + 
mapper: int, + range: bool = False, + functor: Any = ImagePartition, + disjoint: bool = True, + complete: bool = True, + ): + self._source_store = source_store + self._dst_store = dst_store + self._src_part_sym = src_part_sym + self._mapper = mapper + self._range = range + self._functor = functor + self._disjoint = disjoint + self._complete = complete + + def subst(self, mapping: dict[PartSym, PartitionBase]) -> Expr: + return Image( + self._source_store, + self._dst_store, + self._src_part_sym.subst(mapping), + self._mapper, + range=self._range, + functor=self._functor, + disjoint=self._disjoint, + complete=self._complete, + ) + + @property + def ndim(self) -> int: + return self._src_part_sym.ndim + + def reduce(self) -> Lit: + expr = self._src_part_sym.reduce() + assert isinstance(expr, Lit) + part = expr._part + if isinstance(part, Replicate): + return Lit(part) + return Lit( + self._functor( + self._source_store, + part, + self._mapper, + range=self._range, + disjoint=self._disjoint, + complete=self._complete, + ) + ) + + def unknowns(self) -> Iterator[PartSym]: + for unknown in self._src_part_sym.unknowns(): + yield unknown + + def equals(self, other: object) -> bool: + return ( + isinstance(other, Image) + and self._source_store == other._source_store + and self._dst_store == other._dst_store + # Careful! Overloaded equals operator. + and self._src_part_sym is other._src_part_sym + and self._range == other._range + and self._mapper == other._mapper + and self._functor == other._functor + and self._disjoint == other._disjoint + and self._complete == other._complete + ) + + class Constraint: pass diff --git a/legate/core/operation.py b/legate/core/operation.py index bdc2828096..ac4fa68b77 100644 --- a/legate/core/operation.py +++ b/legate/core/operation.py @@ -27,9 +27,9 @@ import legate.core.types as ty from . 
import Future, FutureMap, Rect -from .constraints import PartSym +from .constraints import Image, PartSym from .launcher import CopyLauncher, FillLauncher, TaskLauncher -from .partition import REPLICATE, Weighted +from .partition import REPLICATE, ImagePartition, Weighted from .shape import Shape from .store import Store, StorePartition from .utils import OrderedSet, capture_traceback_repr @@ -157,6 +157,12 @@ def get_all_stores(self) -> OrderedSet[Store]: result.update(store for (store, _) in self._reductions) return result + def get_all_modified_stores(self) -> OrderedSet[Store]: + result: OrderedSet[Store] = OrderedSet() + result.update(self._outputs) + result.update(store for (store, _) in self._reductions) + return result + def add_alignment(self, store1: Store, store2: Store) -> None: """ Sets an alignment between stores. Equivalent to the following code: @@ -216,6 +222,36 @@ def add_broadcast( part = self._get_unique_partition(store) self.add_constraint(part.broadcast(axes=axes)) + # add_image_constraint adds a constraint that the image of store1 is + # contained within the partition of store2. + def add_image_constraint( + self, + store1: Store, + store2: Store, + range: bool = False, + functor: Any = ImagePartition, + disjoint: bool = True, + complete: bool = True, + ) -> None: + self._check_store(store1) + self._check_store(store2) + # TODO (rohany): We only support point (and rect types if range) here. + # It seems like rects should be added to legate.core's type system + # rather than an external type system to understand this then. 
+ part1 = self._get_unique_partition(store1) + part2 = self._get_unique_partition(store2) + image = Image( + store1, + store2, + part1, + self._context.mapper_id, + range=range, + functor=functor, + disjoint=disjoint, + complete=complete, + ) + self.add_constraint(image <= part2) + def add_constraint(self, constraint: Constraint) -> None: """ Adds a partitioning constraint to the operation @@ -851,7 +887,8 @@ def __init__( op_id=op_id, ) self._launch_domain: Rect = launch_domain - self._input_projs: list[Union[ProjFn, None]] = [] + # TODO (rohany): The int here is an explicit ID. + self._input_projs: list[Union[ProjFn, None, int]] = [] self._output_projs: list[Union[ProjFn, None]] = [] self._reduction_projs: list[Union[ProjFn, None]] = [] @@ -859,6 +896,8 @@ def __init__( self._output_parts: list[StorePartition] = [] self._reduction_parts: list[tuple[StorePartition, int]] = [] + self._scalar_future_maps: list[FutureMap] = [] + @property def launch_ndim(self) -> int: return self._launch_domain.dim @@ -876,7 +915,7 @@ def _check_arg(arg: Union[Store, StorePartition]) -> None: def add_input( self, arg: Union[Store, StorePartition], - proj: Optional[ProjFn] = None, + proj: Optional[Union[ProjFn, int]] = None, ) -> None: """ Adds a store as input to the task @@ -918,10 +957,9 @@ def add_output( self._check_arg(arg) if isinstance(arg, Store): if arg.unbound: - raise NotImplementedError( - "Unbound store cannot be used with " - "manually parallelized task" - ) + self._unbound_outputs.append(len(self._outputs)) + self._outputs.append(arg) + return if arg.kind is Future: self._scalar_outputs.append(len(self._outputs)) self._outputs.append(arg) @@ -1006,6 +1044,16 @@ def launch(self, strategy: Strategy) -> None: part.store, req, tag=0, read_write=can_read_write ) + for fm in self._scalar_future_maps: + launcher.add_future_map(fm) + + # Add all unbound stores. 
+ for store_idx in self._unbound_outputs: + store = self._outputs[store_idx] + fspace = self.context.runtime.create_field_space() + field_id = fspace.allocate_field(store.type) + launcher.add_unbound_output(store, fspace, field_id) + self._add_scalar_args_to_launcher(launcher) launcher.set_can_raise_exception(self.can_raise_exception) diff --git a/legate/core/partition.py b/legate/core/partition.py index c0eb24ad6b..41e2b15580 100644 --- a/legate/core/partition.py +++ b/legate/core/partition.py @@ -16,14 +16,22 @@ from abc import ABC, abstractmethod, abstractproperty from functools import lru_cache -from typing import TYPE_CHECKING, Optional, Sequence, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, Union from . import ( + FutureMap, IndexPartition, + PartitionByDomain, + PartitionByImage, + PartitionByImageRange, + PartitionByPreimage, + PartitionByPreimageRange, PartitionByRestriction, PartitionByWeights, + Point, Rect, Transform, + ffi, legion, ) from .launcher import Broadcast, Partition @@ -32,7 +40,7 @@ from .shape import Shape if TYPE_CHECKING: - from . import FutureMap, Partition as LegionPartition, Region + from . import Partition as LegionPartition, Region RequirementType = Union[Type[Broadcast], Type[Partition]] @@ -49,7 +57,11 @@ def even(self) -> bool: @abstractmethod def construct( - self, region: Region, complete: bool = False + self, + region: Region, + complete: bool = False, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, ) -> Optional[LegionPartition]: ... 
@@ -125,7 +137,11 @@ def scale(self, scale: tuple[int]) -> Replicate: return self def construct( - self, region: Region, complete: bool = False + self, + region: Region, + complete: bool = False, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, ) -> Optional[LegionPartition]: return None @@ -256,7 +272,7 @@ def translate(self, offset: Shape) -> Tiling: self._offset + offset, ) - # This function promotes the translated partition to REPLICATE if it + # This function promotes the translated partition to Replicate if it # doesn't overlap with the original partition. def translate_range(self, offset: Shape) -> Union[Replicate, Tiling]: promote = False @@ -292,11 +308,16 @@ def scale(self, scale: tuple[int]) -> Tiling: ) def construct( - self, region: Region, complete: bool = False + self, + region: Region, + complete: bool = False, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, ) -> Optional[LegionPartition]: + assert color_shape is None or color_transform is not None index_space = region.index_space index_partition = runtime.partition_manager.find_index_partition( - index_space, self + index_space, self, color_shape=color_shape ) if index_partition is None: tile_shape = self._tile_shape @@ -309,12 +330,26 @@ def construct( extent = Rect(hi, lo, exclusive=False) - color_space = runtime.find_or_create_index_space(self._color_shape) + color_space = runtime.find_or_create_index_space( + self._color_shape if color_shape is None else color_shape + ) + + if color_transform is not None: + transform = color_transform.compose(transform) + functor = PartitionByRestriction(transform, extent) if complete: - kind = legion.LEGION_DISJOINT_COMPLETE_KIND + kind = ( + legion.LEGION_DISJOINT_COMPLETE_KIND + if color_shape is None + else legion.LEGION_ALIASED_COMPLETE_KIND # type: ignore + ) else: - kind = legion.LEGION_DISJOINT_INCOMPLETE_KIND + kind = ( + legion.LEGION_DISJOINT_INCOMPLETE_KIND + if color_shape 
is None + else legion.LEGION_ALIASED_INCOMPLETE_KIND # type: ignore + ) index_partition = IndexPartition( runtime.legion_context, runtime.legion_runtime, @@ -325,7 +360,7 @@ def construct( keep=True, # export this partition functor to other libraries ) runtime.partition_manager.record_index_partition( - index_space, self, index_partition + index_space, self, index_partition, color_shape=color_shape ) return region.get_child(index_partition) @@ -405,7 +440,11 @@ def translate_range(self, offset: Shape) -> None: raise NotImplementedError("This method shouldn't be invoked") def construct( - self, region: Region, complete: bool = False + self, + region: Region, + complete: bool = False, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, ) -> Optional[LegionPartition]: assert complete @@ -430,3 +469,465 @@ def construct( index_space, self, index_partition ) return region.get_child(index_partition) + + +class ImagePartition(PartitionBase): + def __init__( + self, + store: Any, + part: PartitionBase, + mapper: int, + range: bool = False, + disjoint: bool = True, + complete: bool = True, + ) -> None: + self._mapper = mapper + self._store = store + self._part = part + # Whether this is an image or image_range operation. + self._range = range + self._disjoint = disjoint + self._complete = complete + + @property + def color_shape(self) -> Optional[Shape]: + return self._part.color_shape + + @property + def even(self) -> bool: + return False + + def construct( + self, + region: Region, + complete: bool = False, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, + ) -> Optional[LegionPartition]: + # TODO (rohany): We can't import RegionField due to an import cycle. + # assert(isinstance(self._store.storage, RegionField)) + source_region = self._store.storage.region + source_field = self._store.storage.field + + # TODO (rohany): What should the value of complete be? 
+ source_part = self._store.find_or_create_legion_partition( + self._part, + preserve_colors=True, + ) + if self._range: + functor = PartitionByImageRange( + source_region, + source_part, + source_field.field_id, + mapper=self._mapper, + ) + else: + functor = PartitionByImage( # type: ignore + source_region, + source_part, + source_field.field_id, + mapper=self._mapper, + ) + index_partition = runtime.partition_manager.find_index_partition( + region.index_space, self + ) + if index_partition is None: + if self._disjoint and self._complete: + kind = legion.LEGION_DISJOINT_COMPLETE_KIND + elif self._disjoint and not self._complete: + kind = legion.LEGION_DISJOINT_INCOMPLETE_KIND + elif not self._disjoint and self._complete: + kind = legion.LEGION_ALIASED_COMPLETE_KIND # type: ignore + else: + kind = legion.LEGION_ALIASED_INCOMPLETE_KIND # type: ignore + index_partition = IndexPartition( + runtime.legion_context, + runtime.legion_runtime, + region.index_space, + source_part.color_space, + functor=functor, + kind=kind, + keep=True, + ) + runtime.partition_manager.record_index_partition( + region.index_space, self, index_partition + ) + return region.get_child(index_partition) + + def is_complete_for(self, extents: Shape, offsets: Shape) -> bool: + return self._complete + + def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool: + return self._disjoint + + def satisfies_restriction( + self, restrictions: Sequence[Restriction] + ) -> bool: + for restriction in restrictions: + # If there are some restricted dimensions to this store, + # then this key partition is likely not a good choice. 
+ if restriction == Restriction.RESTRICTED: + return False + return True + + def needs_delinearization(self, launch_ndim: int) -> bool: + assert self.color_shape is not None + return launch_ndim != self.color_shape.ndim + + @property + def requirement(self) -> RequirementType: + return Partition + + def __hash__(self) -> int: + return hash( + ( + self.__class__, + self._store._storage, + # Importantly, we _cannot_ store the version of the store + # in the hash value. This is because the store's version may + # change after we've already put this functor into a table. + # That would result in the hash value changing without moving + # the position in the table, breaking invariants of the table. + # However, we must still check for version in equality to avoid + # using old values. + # self._store._version, + self._part, + self._range, + self._mapper, + ) + ) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, ImagePartition) + # Importantly, we check equality of Storage objects rather than + # Stores. This is because Stores can have equivalent storages but + # not be equal due to transformations on the store. By checking + # that the Storages are equal, we are basically checking whether + # we have the same RegionField object. + and self._store._storage == other._store._storage + and self._store.version == other._store.version + and self._part == other._part + and self._range == other._range + and self._mapper == other._mapper + ) + + def __str__(self) -> str: + return f"image({self._store}, {self._part}, range={self._range})" + + def __repr__(self) -> str: + return str(self) + + +class PreimagePartition(PartitionBase): + # TODO (rohany): I don't even know if I need a store here. I really just + # need the index space that is being partitioned (or the IndexPartition). + # For simplicity's sake it seems like taking the store is fine. 
+ def __init__( + self, + source: Any, + dest: Any, + part: PartitionBase, + mapper: int, + range: bool = False, + disjoint: bool = False, + complete: bool = True, + ) -> None: + self._mapper = mapper + self._source = source + # Importantly, we don't store a reference to `dest` and instead + # hold onto a handle of the underlying region. This is important + # because if we store dest itself on the partition then legate + # can't collect and reuse the storage under dest. Since all we + # actually need from dest is the underlying index space, storing + # the region sidesteps this limitation. + self._dest_region = dest.storage.region + self._part = part + # Whether this is a preimage or preimage_range operation. + self._range = range + self._disjoint = disjoint + self._complete = complete + + @property + def color_shape(self) -> Optional[Shape]: + return self._part.color_shape + + @property + def even(self) -> bool: + return False + + def construct( + self, + region: Region, + complete: bool = False, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, + ) -> Optional[LegionPartition]: + dest_part = self._part.construct(self._dest_region) + source_region = self._source.storage.region + source_field = self._source.storage.field.field_id + functorFn = ( + PartitionByPreimageRange if self._range else PartitionByPreimage + ) + functor = functorFn( + dest_part.index_partition, # type: ignore + source_region, + source_region, + source_field, + mapper=self._mapper, + ) + index_partition = runtime.partition_manager.find_index_partition( + region.index_space, self + ) + if index_partition is None: + if self._disjoint and self._complete: + kind = legion.LEGION_DISJOINT_COMPLETE_KIND + elif self._disjoint and not self._complete: + kind = legion.LEGION_DISJOINT_INCOMPLETE_KIND + elif not self._disjoint and self._complete: + kind = legion.LEGION_ALIASED_COMPLETE_KIND # type: ignore + else: + kind = legion.LEGION_ALIASED_INCOMPLETE_KIND # type: ignore +
# Discharge some typing errors. + assert dest_part is not None + index_partition = IndexPartition( + runtime.legion_context, + runtime.legion_runtime, + region.index_space, + dest_part.color_space, + functor=functor, + kind=kind, + keep=True, + ) + runtime.partition_manager.record_index_partition( + region.index_space, self, index_partition + ) + return region.get_child(index_partition) + + def is_complete_for(self, extents: Shape, offsets: Shape) -> bool: + return self._complete + + def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool: + return self._disjoint + + def satisfies_restriction( + self, restrictions: Sequence[Restriction] + ) -> bool: + for restriction in restrictions: + if restriction != Restriction.UNRESTRICTED: + raise NotImplementedError + return True + + def needs_delinearization(self, launch_ndim: int) -> bool: + assert self.color_shape is not None + return launch_ndim != self.color_shape.ndim + + @property + def requirement(self) -> RequirementType: + return Partition + + def __hash__(self) -> int: + return hash( + ( + self.__class__, + self._source._storage, + # Importantly, we _cannot_ store the version of the store + # in the hash value. This is because the store's version may + # change after we've already put this functor into a table. + # That would result in the hash value changing without moving + # the position in the table, breaking invariants of the table. + # However, we must still check for version in equality to avoid + # using old values. + # self._store._version, + self._dest_region.index_space, + self._part, + self._range, + self._mapper, + ) + ) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, PreimagePartition) + # See the comment on ImagePartition.__eq__ about why we use + # source._storage for equality. 
+ and self._source._storage == other._source._storage + and self._source._version == other._source._version + and self._dest_region.index_space == other._dest_region.index_space + and self._part == other._part + and self._range == other._range + and self._mapper == other._mapper + ) + + def __str__(self) -> str: + return f"preimage({self._source}, {self._part}, range={self._range})" + + def __repr__(self) -> str: + return str(self) + + +class DomainPartition(PartitionBase): + def __init__( + self, + shape: Shape, + color_shape: Shape, + domains: Union[FutureMap, dict[Point, Rect]], + ): + self._color_shape = color_shape + self._domains = domains + self._shape = shape + if len(shape) == 0: + raise AssertionError + + @property + def color_shape(self) -> Optional[Shape]: + return self._color_shape + + @property + def even(self) -> bool: + return False + + def construct( + self, + region: Region, + complete: bool = False, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, + ) -> Optional[LegionPartition]: + index_space = region.index_space + index_partition = runtime.partition_manager.find_index_partition( + index_space, self + ) + if index_partition is None: + functor = PartitionByDomain(self._domains) + index_partition = IndexPartition( + runtime.legion_context, + runtime.legion_runtime, + index_space, + runtime.find_or_create_index_space(self._color_shape), + functor=functor, + keep=True, + ) + runtime.partition_manager.record_index_partition( + index_space, self, index_partition + ) + return region.get_child(index_partition) + + # TODO (rohany): We could figure this out by staring at the domain map. + def is_complete_for(self, extents: Shape, offsets: Shape) -> bool: + return False + + # TODO (rohany): We could figure this out by staring at the domain map. 
+ def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool: + return False + + def satisfies_restriction( + self, restrictions: Sequence[Restriction] + ) -> bool: + for restriction in restrictions: + # If there are some restricted dimensions to this store, + # then this key partition is likely not a good choice. + if restriction == Restriction.RESTRICTED: + return False + return True + + @property + def requirement(self) -> RequirementType: + return Partition + + def needs_delinearization(self, launch_ndim: int) -> bool: + return launch_ndim != self._color_shape.ndim + + def __hash__(self) -> int: + return hash( + ( + self.__class__, + self._shape, + self._color_shape, + # TODO (rohany): No better ideas... + id(self._domains), + ) + ) + + # TODO (rohany): Implement this. + def __eq__(self, other: object) -> bool: + return False + + def __str__(self) -> str: + return f"by_domain({self._color_shape}, {self._domains})" + + def __repr__(self) -> str: + return str(self) + + +# AffineProjection is translated from C++ to Python from the DISTAL +# AffineProjection functor. In particular, it encapsulates applying affine +# projections on `DomainPartition` objects. +class AffineProjection: + # Project each point to the following dimensions of the output point. + # Passing `None` as an entry in `projs` discards the chosen dimension + # from the projection. + def __init__(self, projs: list[Optional[int]]): + self.projs = projs + + @property + def dim(self) -> int: + return len(self.projs) + + def project_point(self, point: Point, output_bound: Point) -> Point: + output_dim = output_bound.dim + set_mask = [False] * output_dim + result = Point(dim=output_dim) + for i in range(0, self.dim): + mapTo = self.projs[i] + if mapTo is None: + continue + result[mapTo] = point[i] + set_mask[mapTo] = True + # Replace unset indices with their boundaries. 
+ for i in range(0, output_dim): + if not set_mask[i]: + result[i] = output_bound[i] + return result + + def project_partition( + self, part: DomainPartition, bounds: Rect, tx_point: Any = None + ) -> DomainPartition: + projected = {} + if isinstance(part._domains, FutureMap): + for point in Rect(hi=part.color_shape): + fut = part._domains.get_future(point) + buf = fut.get_buffer() + dom = ffi.from_buffer("legion_domain_t*", buf)[0] # type: ignore # noqa + lg_rect = getattr( + legion, f"legion_domain_get_rect_{dom.dim}d" + )(dom) + lo = Point(dim=bounds.dim) + hi = Point(dim=bounds.dim) + for i in range(dom.dim): + lo[i] = lg_rect.lo.x[i] + hi[i] = lg_rect.hi.x[i] + lo = self.project_point(lo, bounds.lo) + hi = self.project_point(hi, bounds.hi) + if tx_point is not None: + point = tx_point(point) + projected[point] = Rect( + lo=tuple(lo), hi=tuple(hi), exclusive=False + ) + else: + for p, r in part._domains.items(): + lo = self.project_point(r.lo, bounds.lo) + hi = self.project_point(r.hi, bounds.hi) + if tx_point is not None: + p = tx_point(p) + projected[p] = Rect( + lo=tuple(lo), hi=tuple(hi), exclusive=False + ) + new_shape = Shape( + tuple(bounds.hi[idx] + 1 for idx in range(bounds.dim)) + ) + color_shape = part.color_shape + if tx_point is not None: + color_shape = Shape(tx_point(color_shape, exclusive=True)) + assert color_shape is not None + return DomainPartition(new_shape, color_shape, projected) diff --git a/legate/core/runtime.py b/legate/core/runtime.py index 00b48510fb..5df82f4973 100644 --- a/legate/core/runtime.py +++ b/legate/core/runtime.py @@ -31,6 +31,7 @@ from . import ( Fence, FieldSpace, + Fill, Future, FutureMap, IndexSpace, @@ -327,6 +328,20 @@ def allocate_field(self) -> tuple[Region, int]: def free_field( self, region: Region, field_id: int, ordered: bool = False ) -> None: + # When freeing this field, also issue a fill to invalidate any + # valid instances attached to this region. 
This allows us to reuse + # that space without having to make an instance allocation of the + # same size and shape. + buf = ffi.new("char[]", self.field_size) + fut = Future.from_buffer(self.runtime.legion_runtime, ffi.buffer(buf)) + fill = Fill( + region, + region, + field_id, + fut, + mapper=self.runtime.core_context.mapper_id, + ) + fill.launch(self.runtime.legion_runtime, self.runtime.legion_context) self.free_fields.append((region, field_id)) region_manager = self.runtime.find_region_manager(region) if region_manager.decrease_active_field_count(): @@ -612,7 +627,7 @@ def __init__(self, runtime: Runtime) -> None: ) self._piece_factors = list(reversed(factors)) self._index_partitions: dict[ - tuple[IndexSpace, PartitionBase], IndexPartition + tuple[IndexSpace, PartitionBase, Optional[Shape]], IndexPartition ] = {} # Maps storage id-partition pairs to Legion partitions self._legion_partitions: dict[ @@ -817,9 +832,12 @@ def use_complete_tiling(self, shape: Shape, tile_shape: Shape) -> bool: return not (num_tiles > 256 and num_tiles > 16 * self._num_pieces) def find_index_partition( - self, index_space: IndexSpace, functor: PartitionBase + self, + index_space: IndexSpace, + functor: PartitionBase, + color_shape: Optional[Shape] = None, ) -> Union[IndexPartition, None]: - key = (index_space, functor) + key = (index_space, functor, color_shape) return self._index_partitions.get(key) def record_index_partition( @@ -827,8 +845,9 @@ def record_index_partition( index_space: IndexSpace, functor: PartitionBase, index_partition: IndexPartition, + color_shape: Optional[Shape] = None, ) -> None: - key = (index_space, functor) + key = (index_space, functor, color_shape) assert key not in self._index_partitions self._index_partitions[key] = index_partition @@ -1259,9 +1278,13 @@ def _schedule(self, ops: List[Operation]) -> None: must_be_single = len(op.scalar_outputs) > 0 partitioner = Partitioner([op], must_be_single=must_be_single) 
strategies.append(partitioner.partition_stores()) - for op, strategy in zip(ops, strategies): op.launch(strategy) + # We also need to bump the versions for each modified store. + # TODO (rohany): We also need a callback here to evict cached + # partitions with old store values so that we don't leak these. + for store in op.get_all_modified_stores(): + store.bump_version() def flush_scheduling_window(self) -> None: if len(self._outstanding_ops) == 0: diff --git a/legate/core/shape.py b/legate/core/shape.py index af84a61323..3909030de8 100644 --- a/legate/core/shape.py +++ b/legate/core/shape.py @@ -17,6 +17,7 @@ from functools import reduce from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Union, overload +import numpy as np from typing_extensions import TypeAlias if TYPE_CHECKING: @@ -27,9 +28,15 @@ def _cast_tuple(value: int | Iterable[int], ndim: int) -> tuple[int, ...]: - if isinstance(value, int): + if isinstance(value, Shape): + return value.extents + elif isinstance(value, Iterable): + return tuple(value) + elif isinstance(value, int) or np.issubdtype( + type(value), np.integer + ): # type: ignore return (value,) * ndim - return tuple(value) + return tuple(value) # type: ignore class _ShapeComparisonResult(tuple[bool, ...]): diff --git a/legate/core/solver.py b/legate/core/solver.py index 8efe888664..8fd759cf56 100644 --- a/legate/core/solver.py +++ b/legate/core/solver.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Generic, List, Optional, Tuple, TypeVar from . 
import FieldSpace, Future, Rect -from .constraints import Alignment, Broadcast, Containment, PartSym +from .constraints import Alignment, Broadcast, Containment, Image, PartSym from .partition import REPLICATE from .runtime import runtime from .shape import Shape @@ -363,7 +363,7 @@ def compute_launch_shape( # If we're here, this means that replicated stores are safe to access # in parallel, so we filter those out to determine the launch domain - parts = [part for part in partitions.values() if part is not REPLICATE] + parts = [part for part in partitions.values() if part != REPLICATE] # If all stores are replicated, we can't parallelize the operation if len(parts) == 0: @@ -428,16 +428,38 @@ def _solve_store_constraints( restrictions = all_restrictions[unknown] cls = constraints.find(unknown) - partition = store.compute_key_partition(restrictions) - if not partition.even and len(cls) > 1: - partition, unknown = self.maybe_find_alternative_key_partition( - partition, - unknown, - cls, - restrictions, - must_be_even, - ) - key_parts.add(unknown) + # If we are supposed to be aligned with a partition in dependents, + # then don't make a decision right now. + depends = False + for to_align in cls: + if to_align in dependent: + depends = True + if depends: + continue + + # If we already have a partition for this equality class we + # need to use it. + for to_align in cls: + # Don't align to Futures, as those are trivially replicated. 
+ if to_align.store.kind is Future: + continue + if to_align in result: + partition = result[to_align] + break + else: + partition = store.compute_key_partition(restrictions) + if not partition.even and len(cls) > 1: + ( + partition, + unknown, + ) = self.maybe_find_alternative_key_partition( + partition, + unknown, + cls, + restrictions, + must_be_even, + ) + key_parts.add(unknown) for to_align in cls: if to_align in result: @@ -450,6 +472,14 @@ def _solve_store_constraints( assert isinstance(expr, Lit) result[rhs] = expr._part + for unknown in unknowns: + if unknown in result: + continue + cls = constraints.find(unknown) + for to_align in cls: + if to_align in result: + result[unknown] = result[to_align] + return result, key_parts @staticmethod @@ -489,24 +519,38 @@ def partition_stores(self) -> Strategy: c._lhs, PartSym ): if c._lhs in dependent: - raise NotImplementedError( - "Partitions constrained by multiple constraints " - "are not supported yet" - ) - for unknown in c._rhs.unknowns(): - must_be_even.add(unknown) + rhs = dependent[c._lhs] + # While we can't have multiple constraints, we are ok + # with seeing a duplicate image constraint, as that + # doesn't affect the solving. + if not (isinstance(rhs, Image) and rhs.equals(c._rhs)): + raise NotImplementedError( + "Partitions constrained by multiple " + "constraints are not supported yet" + ) + if not isinstance(c._rhs, Image): + for unknown in c._rhs.unknowns(): + must_be_even.add(unknown) dependent[c._lhs] = c._rhs elif isinstance(c, Containment) and isinstance( c._rhs, PartSym ): if c._rhs in dependent: - raise NotImplementedError( - "Partitions constrained by multiple constraints " - "are not supported yet" - ) - for unknown in c._lhs.unknowns(): - must_be_even.add(unknown) + lhs = dependent[c._rhs] + # While we can't have multiple constraints, we are ok + # with seeing a duplicate image constraint, as that + # doesn't affect the solving. 
+ if not (isinstance(lhs, Image) and lhs.equals(c._lhs)): + raise NotImplementedError( + "Partitions constrained by multiple " + "constraints are not supported yet" + ) + if not isinstance(c._lhs, Image): + for unknown in c._lhs.unknowns(): + must_be_even.add(unknown) dependent[c._rhs] = c._lhs + else: + raise NotImplementedError for op in self._ops: all_outputs.update( store for store in op.outputs if not store.unbound diff --git a/legate/core/store.py b/legate/core/store.py index 309585999a..8b31c8e630 100644 --- a/legate/core/store.py +++ b/legate/core/store.py @@ -34,7 +34,7 @@ DistributedAllocation, InlineMappedAllocation, ) -from .partition import REPLICATE, PartitionBase, Restriction, Tiling +from .partition import REPLICATE, PartitionBase, Restriction, Tiling, Weighted from .projection import execute_functor_symbolically from .runtime import runtime from .shape import Shape @@ -59,7 +59,7 @@ from .context import Context from .launcher import Proj from .projection import ProjFn - from .transform import TransformStackBase + from .transform import Transform, TransformStackBase from math import prod @@ -752,18 +752,28 @@ def reset_key_partition(self) -> None: runtime.partition_manager.reset_storage_key_partition(self._unique_id) def find_or_create_legion_partition( - self, functor: PartitionBase, complete: bool + self, + functor: PartitionBase, + complete: bool, + color_shape: Optional[Shape] = None, + color_transform: Optional[Transform] = None, ) -> Optional[LegionPartition]: if self.kind is not RegionField: return None assert isinstance(self.data, RegionField) + assert color_shape is None or color_transform is not None part, found = runtime.partition_manager.find_legion_partition( self._unique_id, functor ) - if not found: - part = functor.construct(self.data.region, complete=complete) + if not found or color_shape is None: + part = functor.construct( + self.data.region, + complete=complete, + color_shape=color_shape, + color_transform=color_transform, # 
type: ignore + ) runtime.partition_manager.record_legion_partition( self._unique_id, functor, part ) @@ -838,14 +848,17 @@ def get_child_store(self, *indices: int) -> Store: def get_requirement( self, launch_ndim: int, - proj_fn: Optional[ProjFn] = None, + proj_fn: Optional[Union[ProjFn, int]] = None, ) -> Proj: part = self._storage_partition.find_or_create_legion_partition() if part is not None: - proj_id = self._store.compute_projection(proj_fn, launch_ndim) - if self._partition.needs_delinearization(launch_ndim): - assert proj_id == 0 - proj_id = runtime.get_delinearize_functor() + if isinstance(proj_fn, int): + proj_id = proj_fn + else: + proj_id = self._store.compute_projection(proj_fn, launch_ndim) + if self._partition.needs_delinearization(launch_ndim): + assert proj_id == 0 + proj_id = runtime.get_delinearize_functor() else: proj_id = 0 return self._partition.requirement(part, proj_id) @@ -900,6 +913,11 @@ def __init__( # when no custom functor is given self._projection: Union[None, int] = None self._restrictions: Union[None, tuple[Restriction, ...]] = None + # We maintain a version on store objects to cache dependent + # partitions created from the store. Operations that write + # to stores will bump their version and invalidate dependent + # partitions that were created with this store as the source. 
+ self._version = 0 if self._shape is not None: if any(extent < 0 for extent in self._shape.extents): @@ -1072,6 +1090,13 @@ def transformed(self) -> bool: """ return not self._transform.bottom + @property + def version(self) -> int: + return self._version + + def bump_version(self) -> None: + self._version += 1 + def attach_external_allocation( self, context: Context, alloc: Attachable, share: bool ) -> None: @@ -1633,7 +1658,9 @@ def compute_key_partition( else: partition = None - if partition is not None: + if partition is not None and ( + not (self.transformed and isinstance(partition, Weighted)) + ): partition = self._transform.convert_partition(partition) return partition else: @@ -1686,15 +1713,28 @@ def find_restrictions(self) -> tuple[Restriction, ...]: return self._restrictions def find_or_create_legion_partition( - self, partition: PartitionBase, complete: bool = False + self, + partition: PartitionBase, + complete: bool = False, + preserve_colors: bool = False, ) -> Optional[LegionPartition]: # Create a Legion partition for a given functor. # Before we do that, we need to map the partition back # to the original coordinate space. 
- return self._storage.find_or_create_legion_partition( - self._transform.invert_partition(partition), - complete=complete, - ) + if preserve_colors: + return self._storage.find_or_create_legion_partition( + self._transform.invert_partition(partition), + complete=complete, + color_shape=partition.color_shape, + color_transform=self._transform.get_inverse_color_transform( # type: ignore # noqa + partition.color_shape.ndim, # type: ignore + ), + ) + else: + return self._storage.find_or_create_legion_partition( + self._transform.invert_partition(partition), + complete=complete, + ) def partition(self, partition: PartitionBase) -> StorePartition: storage_partition = self._storage.partition( diff --git a/legate/core/transform.py b/legate/core/transform.py index 9a7d5baa0e..926709d58f 100644 --- a/legate/core/transform.py +++ b/legate/core/transform.py @@ -14,12 +14,18 @@ # from __future__ import annotations -from typing import TYPE_CHECKING, Protocol, Tuple +from typing import TYPE_CHECKING, Optional, Protocol, Tuple import numpy as np -from . import AffineTransform -from .partition import Replicate, Restriction, Tiling +from . import AffineTransform, Point, Rect, Transform as LegionTransform +from .partition import ( + AffineProjection, + DomainPartition, + Replicate, + Restriction, + Tiling, +) from .projection import ProjExpr from .runtime import runtime from .shape import Shape @@ -72,6 +78,9 @@ def invert_restrictions(self, restrictions: Restrictions) -> Restrictions: def get_inverse_transform(self, ndim: int) -> AffineTransform: ... + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + ... 
+ class Transform(TransformProto, Protocol): def invert(self, partition: PartitionBase) -> PartitionBase: @@ -160,6 +169,9 @@ def get_inverse_transform(self, ndim: int) -> AffineTransform: result.offset[self._dim] = -self._offset return result + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + raise NotImplementedError("Not implemented yet") + def serialize(self, buf: BufferBuilder) -> None: code = runtime.get_transform_code(self.__class__.__name__) buf.pack_32bit_int(code) @@ -204,6 +216,23 @@ def invert(self, partition: PartitionBase) -> PartitionBase: partition.color_shape.drop(self._extra_dim), partition.offset.drop(self._extra_dim), ) + if isinstance(partition, DomainPartition): + # Project away the desired dimension. + all_axes = list(range(0, len(partition._shape))) + shape = partition._shape.drop(self._extra_dim) + axes: list[Optional[int]] = ( + all_axes[: self._extra_dim] + + [None] + + [x - 1 for x in all_axes[self._extra_dim + 1 :]] + ) + + def tx_point(p: Point, exclusive: bool = False) -> Point: + return Point(Shape(p).drop(self._extra_dim)) + + result = AffineProjection(axes).project_partition( + partition, Rect(hi=shape), tx_point=tx_point + ) + return result else: raise ValueError( f"Unsupported partition: {type(partition).__name__}" @@ -236,6 +265,31 @@ def convert(self, partition: PartitionBase) -> PartitionBase: ) elif isinstance(partition, Replicate): return partition + elif isinstance(partition, DomainPartition): + # The idea here is to project all of the dimensions except + # the promoted one into a new affine projection. In the + # future, we could imagine caching these to avoid redundantly + # computing them. 
+ all_axes = list(range(0, len(partition._shape))) + axes = all_axes[: self._extra_dim] + [ + x + 1 for x in all_axes[self._extra_dim :] + ] + shape = list(partition._shape.extents) + new_shape = Shape( + shape[: self._extra_dim] + + [self._dim_size] + + shape[self._extra_dim :] + ) + + def tx_point(p: Point, exclusive: bool = False) -> Point: + return Point( + Shape(p).insert(self._extra_dim, 1 if exclusive else 0) + ) + + result = AffineProjection(axes).project_partition( # type: ignore + partition, Rect(hi=new_shape), tx_point=tx_point + ) + return result else: raise ValueError( f"Unsupported partition: {type(partition).__name__}" @@ -257,6 +311,16 @@ def get_inverse_transform(self, ndim: int) -> AffineTransform: parent_dim += 1 return result + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + parent_ndim = ndim - 1 + result = LegionTransform(parent_ndim, ndim) + parent_dim = 0 + for child_dim in range(ndim): + if child_dim != self._extra_dim: + result.trans[parent_dim, child_dim] = 1 + parent_dim += 1 + return result + def serialize(self, buf: BufferBuilder) -> None: code = runtime.get_transform_code(self.__class__.__name__) buf.pack_32bit_int(code) @@ -356,6 +420,9 @@ def get_inverse_transform(self, ndim: int) -> AffineTransform: child_dim += 1 return result + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + raise NotImplementedError("Not implemented yet") + def serialize(self, buf: BufferBuilder) -> None: code = runtime.get_transform_code(self.__class__.__name__) buf.pack_32bit_int(code) @@ -442,6 +509,9 @@ def get_inverse_transform(self, ndim: int) -> AffineTransform: result.trans[self._axes[dim], dim] = 1 return result + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + raise NotImplementedError("Not implemented yet") + def serialize(self, buf: BufferBuilder) -> None: code = runtime.get_transform_code(self.__class__.__name__) buf.pack_32bit_int(code) @@ -567,6 +637,9 @@ def 
get_inverse_transform(self, ndim: int) -> AffineTransform: return result + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + raise NotImplementedError("Not implemented yet") + def serialize(self, buf: BufferBuilder) -> None: code = runtime.get_transform_code(self.__class__.__name__) buf.pack_32bit_int(code) @@ -668,6 +741,11 @@ def get_inverse_transform(self, ndim: int) -> AffineTransform: parent = self._parent.get_inverse_transform(transform.M) return transform.compose(parent) + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + transform = self._transform.get_inverse_color_transform(ndim) + parent = self._parent.get_inverse_color_transform(transform.M) + return transform.compose(parent) + def stack(self, transform: Transform) -> TransformStack: return TransformStack(transform, self) @@ -727,6 +805,9 @@ def invert_restrictions(self, restrictions: Restrictions) -> Restrictions: def get_inverse_transform(self, ndim: int) -> AffineTransform: return AffineTransform(ndim, ndim, True) + def get_inverse_color_transform(self, ndim: int) -> LegionTransform: + return LegionTransform(ndim, ndim, True) + def stack(self, transform: Transform) -> TransformStack: return TransformStack(transform, self) diff --git a/src/core/mapping/core_mapper.cc b/src/core/mapping/core_mapper.cc index 0e42875a1b..2bb09063df 100644 --- a/src/core/mapping/core_mapper.cc +++ b/src/core/mapping/core_mapper.cc @@ -79,6 +79,10 @@ class CoreMapper : public Legion::Mapping::NullMapper { const Legion::Task& task, const SelectShardingFunctorInput& input, SelectShardingFunctorOutput& output) override; + virtual void select_sharding_functor(const Legion::Mapping::MapperContext ctx, + const Legion::Fill& fill, + const SelectShardingFunctorInput& input, + SelectShardingFunctorOutput& output); void select_steal_targets(const Legion::Mapping::MapperContext ctx, const SelectStealingInput& input, SelectStealingOutput& output) override; @@ -321,6 +325,16 @@ void 
CoreMapper::select_sharding_functor(const Legion::Mapping::MapperContext ct output.chosen_functor = context.get_sharding_id(LEGATE_CORE_TOPLEVEL_TASK_SHARD_ID); } +void CoreMapper::select_sharding_functor(const Legion::Mapping::MapperContext ctx, + const Legion::Fill& fill, + const SelectShardingFunctorInput& input, + SelectShardingFunctorOutput& output) +{ + const int launch_dim = fill.index_domain.get_dim(); + assert(launch_dim == 1); + output.chosen_functor = context.get_sharding_id(LEGATE_CORE_TOPLEVEL_TASK_SHARD_ID); +} + void CoreMapper::select_steal_targets(const Legion::Mapping::MapperContext ctx, const SelectStealingInput& input, SelectStealingOutput& output) diff --git a/src/core/runtime/context.h b/src/core/runtime/context.h index dac640a3fd..233ae8f8de 100644 --- a/src/core/runtime/context.h +++ b/src/core/runtime/context.h @@ -297,11 +297,14 @@ class TaskContext { ReturnValues pack_return_values_with_exception(int32_t index, const std::string& error_message) const; + // The API doesn't handle passing through future maps well right now, so we need + // to access this directly. + const Legion::Task* task_; + private: std::vector get_return_values() const; private: - const Legion::Task* task_; const std::vector& regions_; Legion::Context context_; Legion::Runtime* runtime_;