From 9e62502b82d5e28e8420fa14bc0fd20e54552c25 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner
Date: Thu, 3 Mar 2022 18:10:16 -0600
Subject: [PATCH] Tag negotiation

---
 grudge/array_context.py |  21 +++++-
 grudge/trace_pair.py    | 153 +++++++++++++++++++++++++++-------
 2 files changed, 123 insertions(+), 51 deletions(-)

diff --git a/grudge/array_context.py b/grudge/array_context.py
index 675e47a29..6abdcba05 100644
--- a/grudge/array_context.py
+++ b/grudge/array_context.py
@@ -32,7 +32,8 @@
 # {{{ imports

 from typing import (
-    TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type)
+    TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type,
+    Dict)
 from dataclasses import dataclass

 from meshmode.array_context import (
@@ -63,6 +64,8 @@
     import pyopencl.tools
     from mpi4py import MPI

+    from grudge.trace_pair import CommunicationTag
+

 class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
     """Inherits from :class:`meshmode.array_context.PyOpenCLArrayContext`. Extends it
@@ -233,13 +236,18 @@ class MPIPyOpenCLArrayContext(PyOpenCLArrayContext, MPIBasedArrayContext):
     .. autofunction:: __init__
     """

+    _source_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
+    _dest_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
+    _dest_rank_to_taken_num_tag: Dict[int, int]
+    mpi_base_tag: int
+
     def __init__(self,
             mpi_communicator,
             queue: "pyopencl.CommandQueue",
             *, allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
             wait_event_queue_length: Optional[int] = None,
-            force_device_scalars: bool = False) -> None:
+            force_device_scalars: bool = False,
+            mpi_base_tag: int) -> None:
         """
         See :class:`arraycontext.impl.pyopencl.PyOpenCLArrayContext` for most
         arguments.
@@ -250,13 +258,20 @@ def __init__(self,

         self.mpi_communicator = mpi_communicator

+        self.mpi_base_tag = mpi_base_tag
+
+        self._source_rank_sym_tag_to_num_tag = {}
+        self._dest_rank_sym_tag_to_num_tag = {}
+        self._dest_rank_to_taken_num_tag = {}
+
     def clone(self):
         # type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
         # pylint: disable=no-member
         return type(self)(self.mpi_communicator, self.queue,
                 allocator=self.allocator,
                 wait_event_queue_length=self._wait_event_queue_length,
-                force_device_scalars=self._force_device_scalars)
+                force_device_scalars=self._force_device_scalars,
+                mpi_base_tag=self.mpi_base_tag)

 # }}}
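[Aside, not part of the patch: with this change, mpi_base_tag becomes a required
keyword-only argument of MPIPyOpenCLArrayContext, and clone() propagates it. A
minimal construction sketch under that assumption; the tag value 17000 is
arbitrary and only needs to stay clear of tags the application uses itself,
with room above it for the negotiation tag and the per-rank data tags:]

    from mpi4py import MPI
    import pyopencl as cl
    from grudge.array_context import MPIPyOpenCLArrayContext

    cl_ctx = cl.create_some_context()
    queue = cl.CommandQueue(cl_ctx)

    # Assumption: MPI tags at and above mpi_base_tag are reserved for
    # grudge's tag negotiation and boundary-data exchange.
    actx = MPIPyOpenCLArrayContext(MPI.COMM_WORLD, queue, mpi_base_tag=17000)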
diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py
index 872028d74..41c94f771 100644
--- a/grudge/trace_pair.py
+++ b/grudge/trace_pair.py
@@ -46,9 +46,7 @@
 """

-from typing import List, Hashable, Optional, Type, Any
-
-from pytools.persistent_dict import KeyBuilder
+from typing import List, Hashable, Dict, Tuple, TYPE_CHECKING, Callable

 from arraycontext import (
     ArrayContainer,
@@ -75,6 +73,9 @@
 import numpy as np
 import grudge.dof_desc as dof_desc

+if TYPE_CHECKING:
+    import mpi4py.MPI
+

 # {{{ trace pair container class

@@ -310,27 +311,102 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *,
 # }}}


-# {{{ distributed-memory functionality
+# {{{ generic distributed support
+
+CommunicationTag = Hashable
+

 @memoize_on_first_arg
 def connected_ranks(dcoll: DiscretizationCollection):
     from meshmode.distributed import get_connected_partitions
     return get_connected_partitions(dcoll._volume_discr.mesh)

+# }}}
+
+
+# {{{ eager distributed
+
+@dataclass
+class _EagerMPITags:
+    send_mpi_tag: int
+    recv_mpi_tag: int
+
+
+@dataclass
+class _EagerMPIState:
+    mpi_communicator: "mpi4py.MPI.Comm"

-class _RankBoundaryCommunication:
-    base_comm_tag = 1273
+    # base_tag is used for tag negotiation; numeric data tags are
+    # handed out starting at first_assignable_tag.
+    tag_negotiation_tag: int
+    first_assignable_tag: int
+
+    source_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
+    dest_rank_sym_tag_to_num_tag: Dict[Tuple[int, CommunicationTag], int]
+    dest_rank_to_taken_num_tag: Dict[int, int]
+
+
+class _EagerSymbolicTagNegotiator:
+    # You may ask: why do we need to communicate at all to agree
+    # on tag mappings? Imagine the case where different ranks
+    # hit different tags in a different order. (Well, as long
+    # as we're expecting eager comm to complete inside of
+    # cross_rank_trace_pairs, I guess that would deadlock.
+    # But still.)
+
+    def __init__(self, eager_mpi_state: _EagerMPIState, sym_tag: CommunicationTag,
+            remote_rank: int,
+            continuation: Callable[
+                [_EagerMPITags], "_RankBoundaryCommunicationEager"]):
+        self.eager_mpi_state = eager_mpi_state
+        self.sym_tag = sym_tag
+        self.remote_rank = remote_rank
+        self.continuation = continuation
+
+        rank_n_tag = (remote_rank, sym_tag)
+        assert rank_n_tag not in eager_mpi_state.source_rank_sym_tag_to_num_tag
+        assert rank_n_tag not in eager_mpi_state.dest_rank_sym_tag_to_num_tag
+
+        # Hand out a fresh numeric tag for each (rank, sym_tag) pair,
+        # recording the most recently taken tag per destination rank.
+        self.send_num_tag = eager_mpi_state.dest_rank_to_taken_num_tag.get(
+                remote_rank, eager_mpi_state.first_assignable_tag - 1) + 1
+        eager_mpi_state.dest_rank_to_taken_num_tag[remote_rank] = \
+                self.send_num_tag
+        eager_mpi_state.dest_rank_sym_tag_to_num_tag[rank_n_tag] = \
+                self.send_num_tag
+
+        comm = eager_mpi_state.mpi_communicator
+        self.send_req = comm.isend((sym_tag, self.send_num_tag),
+                dest=remote_rank, tag=eager_mpi_state.tag_negotiation_tag)
+        self.recv_req = comm.irecv(
+                source=remote_rank, tag=eager_mpi_state.tag_negotiation_tag)
+
+    def finish(self):
+        recv_sym_tag: CommunicationTag
+        recv_num_tag: int
+        recv_sym_tag, recv_num_tag = self.recv_req.wait()
+        self.send_req.wait()
+        self.eager_mpi_state.source_rank_sym_tag_to_num_tag[
+                self.remote_rank, recv_sym_tag] = recv_num_tag
+
+        # FIXME This asserts that the whole tag negotiation process
+        # is pointless. Unless there is a way to have eager communication
+        # for more than one tag pending at the same time (which, for now,
+        # there isn't), this whole endeavor is thoroughly unnecessary.
+        assert recv_sym_tag == self.sym_tag
+
+        return self.continuation(_EagerMPITags(
+            send_mpi_tag=self.send_num_tag, recv_mpi_tag=recv_num_tag))
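[Aside, not part of the patch: the handshake above boils down to a symmetric
isend/irecv of a (symbolic tag, chosen numeric tag) pair on a reserved
negotiation tag, followed by matching waits. A self-contained two-rank mpi4py
sketch of just that exchange; all names and tag values here are illustrative:]

    # negotiate_demo.py -- run with: mpiexec -n 2 python negotiate_demo.py
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    peer = 1 - rank  # the other rank in this two-rank demo

    NEGOTIATION_TAG = 17000   # plays the role of tag_negotiation_tag
    FIRST_DATA_TAG = 17001    # plays the role of first_assignable_tag

    sym_tag = "flux-exchange"    # any hashable symbolic tag
    my_num_tag = FIRST_DATA_TAG  # numeric tag this rank picks for its sends

    # Symmetric exchange of (symbolic tag, chosen numeric tag), as in
    # _EagerSymbolicTagNegotiator.__init__.
    send_req = comm.isend((sym_tag, my_num_tag), dest=peer, tag=NEGOTIATION_TAG)
    recv_req = comm.irecv(source=peer, tag=NEGOTIATION_TAG)

    # ...and the matching waits, as in finish(): afterwards this rank sends
    # its data with my_num_tag and receives the peer's with peer_num_tag.
    peer_sym_tag, peer_num_tag = recv_req.wait()
    send_req.wait()
    assert peer_sym_tag == sym_tag
    print(f"rank {rank}: send tag {my_num_tag}, recv tag {peer_num_tag}")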
+class _RankBoundaryCommunicationEager:
     def __init__(self,
-            dcoll: DiscretizationCollection,
-            array_container: ArrayOrContainerT,
-            remote_rank, comm_tag: Optional[int] = None):
+            mpi_communicator,
+            dcoll: DiscretizationCollection,
+            array_container: ArrayOrContainerT,
+            *, remote_rank: int, send_mpi_tag: int, recv_mpi_tag: int):
         actx = get_container_context_recursively(array_container)
+        assert actx is not None
+
         btag = BTAG_PARTITION(remote_rank)

         local_bdry_data = project(dcoll, "vol", btag, array_container)
-        comm = dcoll.mpi_communicator
-
+        comm = mpi_communicator
+
         self.dcoll = dcoll
         self.array_context = actx
         self.remote_btag = btag
@@ -339,10 +415,6 @@ def __init__(self,
         self.local_bdry_data_np = \
             to_numpy(flatten(self.local_bdry_data, actx), actx)

-        self.comm_tag = self.base_comm_tag
-        if comm_tag is not None:
-            self.comm_tag += comm_tag
-
         # Here, we initialize both send and receive operations through
         # mpi4py `Request` (MPI_Request) instances for comm.Isend (MPI_Isend)
         # and comm.Irecv (MPI_Irecv) respectively. These initiate non-blocking
@@ -364,11 +436,11 @@ def __init__(self,
         # as well, just in case.

         self.send_req = comm.Isend(self.local_bdry_data_np, remote_rank,
-                                   tag=self.comm_tag)
+                                   tag=send_mpi_tag)

         self.remote_data_host_numpy = np.empty_like(self.local_bdry_data_np)
         self.recv_req = comm.Irecv(self.remote_data_host_numpy, remote_rank,
-                                   tag=self.comm_tag)
+                                   tag=recv_mpi_tag)

     def finish(self):
         # Wait for the nonblocking receive request to complete before
@@ -393,15 +465,18 @@ def finish(self):
             interior=self.local_bdry_data,
             exterior=swapped_remote_bdry_data)

+# }}}

-from pytato import make_distributed_recv, staple_distributed_send
+
+# {{{ lazy distributed

 class _RankBoundaryCommunicationLazy:
     def __init__(self,
             dcoll: DiscretizationCollection,
             array_container: ArrayOrContainerT,
-            remote_rank: int, comm_tag: Hashable):
+            remote_rank: int, comm_tag: CommunicationTag):
+        from pytato import make_distributed_recv, staple_distributed_send
+
         if comm_tag is None:
             raise ValueError("lazy communication requires 'tag' to be supplied")

@@ -433,16 +508,15 @@ def finish(self):
             interior=self.local_bdry_data,
             exterior=bdry_conn(self.remote_data))

+# }}}

-class _TagKeyBuilder(KeyBuilder):
-    def update_for_type(self, key_hash, key: Type[Any]):
-        self.rec(key_hash, (key.__module__, key.__name__, key.__name__,))
+
+# {{{ cross_rank_trace_pairs

 def cross_rank_trace_pairs(
         dcoll: DiscretizationCollection, ary,
-        comm_tag: Hashable = None,
-        tag: Hashable = None) -> List[TracePair]:
+        comm_tag: CommunicationTag = None,
+        tag: CommunicationTag = None) -> List[TracePair]:
     r"""Get a :class:`list` of *ary* trace pairs for each partition boundary.

     For each partition boundary, the field data values in *ary* are
@@ -481,6 +555,12 @@ def cross_rank_trace_pairs(
         comm_tag = tag
         del tag

+    # {{{
+
+
+    # }}}
+
+
     if isinstance(ary, Number):
         # NOTE: Assumed that the same number is passed on every rank
         return [TracePair(BTAG_PARTITION(remote_rank), interior=ary, exterior=ary)
                 for remote_rank in connected_ranks(dcoll)]

     actx = get_container_context_recursively(ary)

     if isinstance(actx, MPIPytatoArrayContextBase):
         rbc = _RankBoundaryCommunicationLazy
     else:
-        rbc = _RankBoundaryCommunication
-        if comm_tag is not None:
-            num_tag: Optional[int] = None
-            if isinstance(comm_tag, int):
-                num_tag = comm_tag
-
-            if num_tag is None:
-                # FIXME: This isn't guaranteed to be correct.
-                # See here for discussion:
-                # - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716  # noqa
-                # - https://github.com/inducer/grudge/pull/222
-                from mpi4py import MPI
-                tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB)
-                key_builder = _TagKeyBuilder()
-                digest = key_builder(comm_tag)
-                num_tag = sum(ord(ch) << i for i, ch in enumerate(digest)) % tag_ub
-
-                from warnings import warn
-                warn("Encountered unknown symbolic tag "
-                        f"'{comm_tag}', assigning a value of '{num_tag}'. "
-                        "This is a temporary workaround, please ensure that "
-                        "tags are sufficiently distinct for your use case.")
-
-                comm_tag = num_tag
+        rbc = partial(_RankBoundaryCommunicationEager,

     # Initialize and post all sends/receives
     rank_bdry_communcators = [