From 35e4ce972388ad423023a78d06dc142591a535c6 Mon Sep 17 00:00:00 2001
From: Matthew Smith
Date: Wed, 16 Mar 2022 11:25:56 -0500
Subject: [PATCH 1/3] add mpi_distribute

---
 meshmode/distributed.py | 97 ++++++++++++++++++++++++++++-------------
 1 file changed, 66 insertions(+), 31 deletions(-)

diff --git a/meshmode/distributed.py b/meshmode/distributed.py
index d99e90dc8..532d52bb7 100644
--- a/meshmode/distributed.py
+++ b/meshmode/distributed.py
@@ -3,6 +3,7 @@
 .. autoclass:: InterRankBoundaryInfo
 .. autoclass:: MPIBoundaryCommSetupHelper
 
+.. autofunction:: mpi_distribute
 .. autofunction:: get_partition_by_pymetis
 .. autofunction:: membership_list_to_map
 .. autofunction:: get_connected_parts
@@ -37,6 +38,7 @@
 """
 
 from dataclasses import dataclass
+from contextlib import contextmanager
 import numpy as np
 from typing import List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
 
@@ -66,12 +68,69 @@
 import logging
 logger = logging.getLogger(__name__)
 
-TAG_BASE = 83411
-TAG_DISTRIBUTE_MESHES = TAG_BASE + 1
-
 
 # {{{ mesh distributor
 
+@contextmanager
+def _duplicate_mpi_comm(mpi_comm):
+    dup_comm = mpi_comm.Dup()
+    try:
+        yield dup_comm
+    finally:
+        dup_comm.Free()
+
+
+def mpi_distribute(mpi_comm, source_data=None, source_rank=0):
+    """
+    Distribute data to a set of processes.
+
+    :arg mpi_comm: An ``MPI.Intracomm``
+    :arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
+        Only present on the source rank.
+    :arg source_rank: The rank from which the data is being sent.
+    :returns: The data local to the current process if there is any, otherwise
+        *None*.
+    """
+    with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
+        num_proc = mpi_comm.Get_size()
+        rank = mpi_comm.Get_rank()
+
+        local_data = None
+
+        if rank == source_rank:
+            if source_data is None:
+                raise TypeError("source rank has no data.")
+
+            sending_to = np.full(num_proc, False)
+            for dest_rank in source_data.keys():
+                sending_to[dest_rank] = True
+
+            mpi_comm.scatter(sending_to, root=source_rank)
+
+            reqs = []
+            for dest_rank, data in source_data.items():
+                if dest_rank == rank:
+                    local_data = data
+                    logger.info("rank %d: kept local data", rank)
+                else:
+                    reqs.append(mpi_comm.isend(data, dest=dest_rank))
+
+            logger.info("rank %d: sent all data", rank)
+
+            from mpi4py import MPI
+            MPI.Request.waitall(reqs)
+
+        else:
+            receiving = mpi_comm.scatter(None, root=source_rank)
+
+            if receiving:
+                local_data = mpi_comm.recv(source=source_rank)
+                logger.info("rank %d: received data", rank)
+
+    return local_data
+
+
+# TODO: Deprecate?
 class MPIMeshDistributor:
     """
     .. automethod:: is_mananger_rank
@@ -99,9 +158,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
         Sends each part to a different rank.
 
         Returns one part that was not sent to any other rank.
""" - mpi_comm = self.mpi_comm - rank = mpi_comm.Get_rank() - assert num_parts <= mpi_comm.Get_size() + assert num_parts <= self.mpi_comm.Get_size() assert self.is_mananger_rank() @@ -110,38 +167,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts): from meshmode.mesh.processing import partition_mesh parts = partition_mesh(mesh, part_num_to_elements) - local_part = None - - reqs = [] - for r, part in parts.items(): - if r == self.manager_rank: - local_part = part - else: - reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES)) - - logger.info("rank %d: sent all mesh parts", rank) - for req in reqs: - req.wait() - - return local_part + return mpi_distribute( + self.mpi_comm, source_data=parts, source_rank=self.manager_rank) def receive_mesh_part(self): """ Returns the mesh sent by the manager rank. """ - mpi_comm = self.mpi_comm - rank = mpi_comm.Get_rank() - assert not self.is_mananger_rank(), "Manager rank cannot receive mesh" - from mpi4py import MPI - status = MPI.Status() - result = self.mpi_comm.recv( - source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES, - status=status) - logger.info("rank %d: received local mesh (size = %d)", rank, status.count) - - return result + return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank) # }}} From 004c95b97cfde95e30c77b42be48e8e992a9d1cf Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 29 Jun 2022 14:26:10 -0500 Subject: [PATCH 2/3] add type hints to mpi_distribute --- meshmode/distributed.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/meshmode/distributed.py b/meshmode/distributed.py index 532d52bb7..161dc78f4 100644 --- a/meshmode/distributed.py +++ b/meshmode/distributed.py @@ -40,7 +40,9 @@ from dataclasses import dataclass from contextlib import contextmanager import numpy as np -from typing import List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING +from typing import ( + Any, Optional, List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING +) from arraycontext import ArrayContext from meshmode.discretization.connection import ( @@ -80,7 +82,10 @@ def _duplicate_mpi_comm(mpi_comm): dup_comm.Free() -def mpi_distribute(mpi_comm, source_data=None, source_rank=0): +def mpi_distribute( + mpi_comm: "mpi4py.MPI.Intracomm", + source_data: Optional[Mapping[int, Any]] = None, + source_rank: int = 0) -> Optional[Any]: """ Distribute data to a set of processes. @@ -88,6 +93,7 @@ def mpi_distribute(mpi_comm, source_data=None, source_rank=0): :arg source_data: A :class:`dict` mapping destination ranks to data to be sent. Only present on the source rank. :arg source_rank: The rank from which the data is being sent. + :returns: The data local to the current process if there is any, otherwise *None*. 
""" @@ -101,7 +107,7 @@ def mpi_distribute(mpi_comm, source_data=None, source_rank=0): if source_data is None: raise TypeError("source rank has no data.") - sending_to = np.full(num_proc, False) + sending_to = [False] * num_proc for dest_rank in source_data.keys(): sending_to[dest_rank] = True @@ -121,7 +127,7 @@ def mpi_distribute(mpi_comm, source_data=None, source_rank=0): MPI.Request.waitall(reqs) else: - receiving = mpi_comm.scatter(None, root=source_rank) + receiving = mpi_comm.scatter([], root=source_rank) if receiving: local_data = mpi_comm.recv(source=source_rank) From a805c5ca6d2d12209944279dd5131786a5456f20 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 25 Apr 2023 15:50:04 -0500 Subject: [PATCH 3/3] reorder arguments --- meshmode/distributed.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/meshmode/distributed.py b/meshmode/distributed.py index 161dc78f4..154edd5ef 100644 --- a/meshmode/distributed.py +++ b/meshmode/distributed.py @@ -84,15 +84,15 @@ def _duplicate_mpi_comm(mpi_comm): def mpi_distribute( mpi_comm: "mpi4py.MPI.Intracomm", - source_data: Optional[Mapping[int, Any]] = None, - source_rank: int = 0) -> Optional[Any]: + source_rank: int = 0, + source_data: Optional[Mapping[int, Any]] = None) -> Optional[Any]: """ Distribute data to a set of processes. :arg mpi_comm: An ``MPI.Intracomm`` + :arg source_rank: The rank from which the data is being sent. :arg source_data: A :class:`dict` mapping destination ranks to data to be sent. Only present on the source rank. - :arg source_rank: The rank from which the data is being sent. :returns: The data local to the current process if there is any, otherwise *None*. @@ -174,7 +174,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts): parts = partition_mesh(mesh, part_num_to_elements) return mpi_distribute( - self.mpi_comm, source_data=parts, source_rank=self.manager_rank) + self.mpi_comm, source_rank=self.manager_rank, source_data=parts) def receive_mesh_part(self): """