105 changes: 73 additions & 32 deletions meshmode/distributed.py
@@ -3,6 +3,7 @@
.. autoclass:: InterRankBoundaryInfo
.. autoclass:: MPIBoundaryCommSetupHelper

.. autofunction:: mpi_distribute
.. autofunction:: get_partition_by_pymetis
.. autofunction:: membership_list_to_map
.. autofunction:: get_connected_parts
@@ -37,8 +38,11 @@
"""

from dataclasses import dataclass
from contextlib import contextmanager
import numpy as np
from typing import List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
from typing import (
Any, Optional, List, Set, Union, Mapping, cast, Sequence, TYPE_CHECKING
)

from arraycontext import ArrayContext
from meshmode.discretization.connection import (
@@ -66,12 +70,73 @@
import logging
logger = logging.getLogger(__name__)

TAG_BASE = 83411
TAG_DISTRIBUTE_MESHES = TAG_BASE + 1


# {{{ mesh distributor

@contextmanager
def _duplicate_mpi_comm(mpi_comm):
    # Work on a private duplicate of the caller's communicator so that the
    # messages exchanged below cannot be matched against unrelated traffic
    # on *mpi_comm*; the duplicate is freed on exit.
    dup_comm = mpi_comm.Dup()
    try:
        yield dup_comm
    finally:
        dup_comm.Free()
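
A brief usage sketch (not part of this diff): duplicating the communicator gives the helper its own communication context, so sends and receives issued inside it cannot match unrelated messages posted by user code on the original communicator. The two-rank exchange below is purely illustrative.

from mpi4py import MPI

comm = MPI.COMM_WORLD
with _duplicate_mpi_comm(comm) as dup_comm:
    # Messages on dup_comm are matched only against other messages on
    # dup_comm, never against traffic on comm itself.
    if dup_comm.Get_size() > 1:
        if dup_comm.Get_rank() == 0:
            dup_comm.send("hello", dest=1)
        elif dup_comm.Get_rank() == 1:
            print(dup_comm.recv(source=0))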


def mpi_distribute(
mpi_comm: "mpi4py.MPI.Intracomm",
source_rank: int = 0,
source_data: Optional[Mapping[int, Any]] = None) -> Optional[Any]:
"""
Distribute data to a set of processes.

:arg mpi_comm: An ``MPI.Intracomm``
:arg source_rank: The rank from which the data is being sent.
:arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
Only present on the source rank.

:returns: The data local to the current process if there is any, otherwise
*None*.
"""
with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
num_proc = mpi_comm.Get_size()
rank = mpi_comm.Get_rank()

local_data = None

if rank == source_rank:
if source_data is None:
raise TypeError("source rank has no data.")

sending_to = [False] * num_proc
for dest_rank in source_data.keys():
sending_to[dest_rank] = True

mpi_comm.scatter(sending_to, root=source_rank)

reqs = []
for dest_rank, data in source_data.items():
if dest_rank == rank:
local_data = data
logger.info("rank %d: received data", rank)
else:
reqs.append(mpi_comm.isend(data, dest=dest_rank))

logger.info("rank %d: sent all data", rank)

from mpi4py import MPI
MPI.Request.waitall(reqs)

else:
receiving = mpi_comm.scatter([], root=source_rank)

if receiving:
local_data = mpi_comm.recv(source=source_rank)
logger.info("rank %d: received data", rank)

return local_data
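
A hypothetical usage sketch (not part of this diff): every rank in the communicator must call mpi_distribute, since the call contains a collective scatter; ranks that do not appear in source_data simply get *None* back. The payload strings below are placeholders.

from mpi4py import MPI

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0:
    # Rank 0 keeps its own entry and sends one payload to every other rank.
    payloads = {rank: f"payload for rank {rank}"
                for rank in range(comm.Get_size())}
    local = mpi_distribute(comm, source_rank=0, source_data=payloads)
else:
    local = mpi_distribute(comm, source_rank=0)

print(f"rank {comm.Get_rank()}: {local}")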


# TODO: Deprecate?
class MPIMeshDistributor:
"""
.. automethod:: is_mananger_rank
@@ -99,9 +164,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
Sends each part to a different rank.
Returns one part that was not sent to any other rank.
"""
mpi_comm = self.mpi_comm
rank = mpi_comm.Get_rank()
assert num_parts <= mpi_comm.Get_size()
assert num_parts <= self.mpi_comm.Get_size()

assert self.is_mananger_rank()

@@ -110,38 +173,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
from meshmode.mesh.processing import partition_mesh
parts = partition_mesh(mesh, part_num_to_elements)

local_part = None

reqs = []
for r, part in parts.items():
if r == self.manager_rank:
local_part = part
else:
reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES))

logger.info("rank %d: sent all mesh parts", rank)
for req in reqs:
req.wait()

return local_part
return mpi_distribute(
self.mpi_comm, source_rank=self.manager_rank, source_data=parts)

def receive_mesh_part(self):
"""
Returns the mesh sent by the manager rank.
"""
mpi_comm = self.mpi_comm
rank = mpi_comm.Get_rank()

assert not self.is_mananger_rank(), "Manager rank cannot receive mesh"

from mpi4py import MPI
status = MPI.Status()
result = self.mpi_comm.recv(
source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES,
status=status)
logger.info("rank %d: received local mesh (size = %d)", rank, status.count)

return result
return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank)

# }}}
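
A hypothetical end-to-end sketch (not part of this diff) of how the refactored methods pair up. The MPIMeshDistributor constructor arguments and the mesh-creation helper are assumptions rather than something taken from this patch; get_partition_by_pymetis is used here only in its documented role of producing a per-element part assignment.

from mpi4py import MPI

comm = MPI.COMM_WORLD
distributor = MPIMeshDistributor(comm)    # assumed constructor signature
num_parts = comm.Get_size()

if distributor.is_mananger_rank():
    mesh = make_global_mesh()              # hypothetical mesh-creation helper
    part_per_element = get_partition_by_pymetis(mesh, num_parts)
    local_mesh = distributor.send_mesh_parts(mesh, part_per_element, num_parts)
else:
    local_mesh = distributor.receive_mesh_part()

Both branches end up calling mpi_distribute, so the manager rank keeps its own part while every other rank receives exactly one part.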
