diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index d99e90dc8..154edd5ef 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -3,6 +3,7 @@
 .. autoclass:: InterRankBoundaryInfo
 .. autoclass:: MPIBoundaryCommSetupHelper
 
+.. autofunction:: mpi_distribute
 .. autofunction:: get_partition_by_pymetis
 .. autofunction:: membership_list_to_map
 .. autofunction:: get_connected_parts
@@ -37,8 +38,11 @@
 """
 
 from dataclasses import dataclass
+from contextlib import contextmanager
 import numpy as np
-from typing import List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
+from typing import (
+    Any, Optional, List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
+)
 
 from arraycontext import ArrayContext
 from meshmode.discretization.connection import (
@@ -66,12 +70,73 @@
 import logging
 logger = logging.getLogger(__name__)
 
-TAG_BASE = 83411
-TAG_DISTRIBUTE_MESHES = TAG_BASE + 1
-
 
 # {{{ mesh distributor
 
+@contextmanager
+def _duplicate_mpi_comm(mpi_comm):
+    dup_comm = mpi_comm.Dup()
+    try:
+        yield dup_comm
+    finally:
+        dup_comm.Free()
+
+
+def mpi_distribute(
+        mpi_comm: "mpi4py.MPI.Intracomm",
+        source_rank: int = 0,
+        source_data: Optional[Mapping[int, Any]] = None) -> Optional[Any]:
+    """
+    Distribute data to a set of processes.
+
+    :arg mpi_comm: An ``MPI.Intracomm``
+    :arg source_rank: The rank from which the data is being sent.
+    :arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
+        Only present on the source rank.
+
+    :returns: The data local to the current process if there is any, otherwise
+        *None*.
+    """
+    with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
+        num_proc = mpi_comm.Get_size()
+        rank = mpi_comm.Get_rank()
+
+        local_data = None
+
+        if rank == source_rank:
+            if source_data is None:
+                raise TypeError("source rank has no data.")
+
+            sending_to = [False] * num_proc
+            for dest_rank in source_data.keys():
+                sending_to[dest_rank] = True
+
+            mpi_comm.scatter(sending_to, root=source_rank)
+
+            reqs = []
+            for dest_rank, data in source_data.items():
+                if dest_rank == rank:
+                    local_data = data
+                    logger.info("rank %d: received data", rank)
+                else:
+                    reqs.append(mpi_comm.isend(data, dest=dest_rank))
+
+            logger.info("rank %d: sent all data", rank)
+
+            from mpi4py import MPI
+            MPI.Request.waitall(reqs)
+
+        else:
+            receiving = mpi_comm.scatter([], root=source_rank)
+
+            if receiving:
+                local_data = mpi_comm.recv(source=source_rank)
+                logger.info("rank %d: received data", rank)
+
+        return local_data
+
+
+# TODO: Deprecate?
 class MPIMeshDistributor:
     """
     .. automethod:: is_mananger_rank
@@ -99,9 +164,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
         Sends each part to a different rank.
         Returns one part that was not sent to any other rank.
         """
-        mpi_comm = self.mpi_comm
-        rank = mpi_comm.Get_rank()
-        assert num_parts <= mpi_comm.Get_size()
+        assert num_parts <= self.mpi_comm.Get_size()
 
         assert self.is_mananger_rank()
 
@@ -110,38 +173,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
         from meshmode.mesh.processing import partition_mesh
         parts = partition_mesh(mesh, part_num_to_elements)
 
-        local_part = None
-
-        reqs = []
-        for r, part in parts.items():
-            if r == self.manager_rank:
-                local_part = part
-            else:
-                reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES))
-
-        logger.info("rank %d: sent all mesh parts", rank)
-        for req in reqs:
-            req.wait()
-
-        return local_part
+        return mpi_distribute(
+            self.mpi_comm, source_rank=self.manager_rank, source_data=parts)
 
     def receive_mesh_part(self):
         """
         Returns the mesh sent by the manager rank.
         """
-        mpi_comm = self.mpi_comm
-        rank = mpi_comm.Get_rank()
-
         assert not self.is_mananger_rank(), "Manager rank cannot receive mesh"
 
-        from mpi4py import MPI
-        status = MPI.Status()
-        result = self.mpi_comm.recv(
-            source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES,
-            status=status)
-        logger.info("rank %d: received local mesh (size = %d)", rank, status.count)
-
-        return result
+        return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank)
 
 # }}}