examples/parallel-vtkhdf.py (34 changes: 18 additions & 16 deletions)
@@ -32,46 +32,48 @@ def main(*, ambient_dim: int) -> None:

     from mpi4py import MPI
     comm = MPI.COMM_WORLD
-    mpisize = comm.Get_size()
-    mpirank = comm.Get_rank()

-    from meshmode.distributed import MPIMeshDistributor
-    dist = MPIMeshDistributor(comm)
+    from meshmode.mesh.processing import partition_mesh
+    from meshmode.distributed import membership_list_to_map

     order = 5
     nelements = 64 if ambient_dim == 3 else 256

-    logger.info("[%4d] distributing mesh: started", mpirank)
+    logger.info("[%4d] distributing mesh: started", comm.rank)

-    if dist.is_mananger_rank():
+    if comm.rank == 0:
         mesh = make_example_mesh(ambient_dim, nelements, order=order)
         logger.info("[%4d] mesh: nelements %d nvertices %d",
-                mpirank, mesh.nelements, mesh.nvertices)
+                comm.rank, mesh.nelements, mesh.nvertices)

         rng = np.random.default_rng()
-        part_per_element = rng.integers(mpisize, size=mesh.nelements)

-        local_mesh = dist.send_mesh_parts(mesh, part_per_element, mpisize)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    rng.integers(comm.size, size=mesh.nelements)))
+        parts = [part_id_to_part[i] for i in range(comm.size)]
+        local_mesh = comm.scatter(parts)
     else:
-        local_mesh = dist.receive_mesh_part()
+        # Reason for type-ignore: presumed faulty type annotation in mpi4py
+        local_mesh = comm.scatter(None)  # type: ignore[arg-type]

-    logger.info("[%4d] distributing mesh: finished", mpirank)
+    logger.info("[%4d] distributing mesh: finished", comm.rank)

     from meshmode.discretization import Discretization
     from meshmode.discretization.poly_element import default_simplex_group_factory
     discr = Discretization(actx, local_mesh,
             default_simplex_group_factory(local_mesh.dim, order=order))

-    logger.info("[%4d] discretization: finished", mpirank)
+    logger.info("[%4d] discretization: finished", comm.rank)

     vector_field = actx.thaw(discr.nodes())
     scalar_field = actx.np.sin(vector_field[0])
-    part_id = 1.0 + mpirank + discr.zeros(actx)  # type: ignore[operator]
-    logger.info("[%4d] fields: finished", mpirank)
+    part_id = 1.0 + comm.rank + discr.zeros(actx)  # type: ignore[operator]
+    logger.info("[%4d] fields: finished", comm.rank)

     from meshmode.discretization.visualization import make_visualizer
     vis = make_visualizer(actx, discr, vis_order=order, force_equidistant=False)
-    logger.info("[%4d] make_visualizer: finished", mpirank)
+    logger.info("[%4d] make_visualizer: finished", comm.rank)

     filename = f"parallel-vtkhdf-example-{ambient_dim}d.hdf"
     vis.write_vtkhdf_file(filename, [
@@ -80,7 +82,7 @@ def main(*, ambient_dim: int) -> None:
             ("part_id", part_id)
         ], comm=comm, overwrite=True, use_high_order=False)

-    logger.info("[%4d] write: finished: %s", mpirank, filename)
+    logger.info("[%4d] write: finished: %s", comm.rank, filename)


 if __name__ == "__main__":
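For readers who want to lift the same migration into their own drivers, the rank-0-partitions-then-scatters pattern used above can be sketched on its own. This is a minimal sketch built only from the calls visible in this diff (partition_mesh, membership_list_to_map, comm.scatter, generate_warped_rect_mesh); the mesh parameters and the fixed RNG seed are illustrative choices, not part of the change.

```python
# Minimal sketch of the new distribution pattern; run under mpiexec.
# Mesh parameters and the RNG seed are illustrative only.
import numpy as np
from mpi4py import MPI

from meshmode.distributed import membership_list_to_map
from meshmode.mesh.processing import partition_mesh

comm = MPI.COMM_WORLD

if comm.rank == 0:
    from meshmode.mesh.generation import generate_warped_rect_mesh
    mesh = generate_warped_rect_mesh(2, order=3, nelements_side=8)

    # Assign every element to a random part, one part id per rank.
    rng = np.random.default_rng(seed=17)
    part_id_to_part = partition_mesh(
            mesh,
            membership_list_to_map(rng.integers(comm.size, size=mesh.nelements)))

    # Order the parts by rank so that scatter delivers part i to rank i.
    parts = [part_id_to_part[i] for i in range(comm.size)]
    local_mesh = comm.scatter(parts)
else:
    # Non-root ranks only receive their part.
    local_mesh = comm.scatter(None)

print(f"rank {comm.rank}: {local_mesh.nelements} local elements")
```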
meshmode/distributed.py (5 changes: 4 additions & 1 deletion)
@@ -1,5 +1,4 @@
 """
-.. autoclass:: MPIMeshDistributor
 .. autoclass:: InterRankBoundaryInfo
 .. autoclass:: MPIBoundaryCommSetupHelper

@@ -85,6 +84,10 @@ def __init__(self, mpi_comm, manager_rank=0):
         self.mpi_comm = mpi_comm
         self.manager_rank = manager_rank

+        warn("MPIMeshDistributor is deprecated and will be removed in 2024. "
+            "Directly call partition_mesh and use mpi_comm.scatter instead.",
+            DeprecationWarning, stacklevel=2)
+
     def is_mananger_rank(self):
         return self.mpi_comm.Get_rank() == self.manager_rank

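Because the class itself is kept, existing callers keep working and only see a DeprecationWarning when they construct the distributor. A quick single-rank way to confirm the new behavior (a hypothetical snippet, not part of this PR's tests) is:

```python
# Hypothetical check that constructing MPIMeshDistributor now emits a
# DeprecationWarning; runs fine on a single rank.
import warnings

from mpi4py import MPI
from meshmode.distributed import MPIMeshDistributor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    MPIMeshDistributor(MPI.COMM_WORLD)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```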
test/test_partition.py (26 changes: 13 additions & 13 deletions)
@@ -368,16 +368,14 @@ def count_tags(mesh, tag):
 # {{{ MPI test boundary swap

 def _test_mpi_boundary_swap(dim, order, num_groups):
-    from meshmode.distributed import MPIMeshDistributor, MPIBoundaryCommSetupHelper
+    from meshmode.distributed import (MPIBoundaryCommSetupHelper,
+        membership_list_to_map)
+    from meshmode.mesh.processing import partition_mesh

     from mpi4py import MPI
     mpi_comm = MPI.COMM_WORLD
-    i_local_part = mpi_comm.Get_rank()
-    num_parts = mpi_comm.Get_size()

-    mesh_dist = MPIMeshDistributor(mpi_comm)
-
-    if mesh_dist.is_mananger_rank():
+    if mpi_comm.rank == 0:
         np.random.seed(42)
         from meshmode.mesh.generation import generate_warped_rect_mesh
         meshes = [generate_warped_rect_mesh(dim, order=order, nelements_side=4)
@@ -389,11 +387,14 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
         else:
             mesh = meshes[0]

-        part_per_element = np.random.randint(num_parts, size=mesh.nelements)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    np.random.randint(mpi_comm.size, size=mesh.nelements)))
+        parts = [part_id_to_part[i] for i in range(mpi_comm.size)]

-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        local_mesh = mpi_comm.scatter(parts)
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = mpi_comm.scatter(None)

     group_factory = default_simplex_group_factory(base_dim=dim, order=order)

@@ -436,14 +437,13 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
             remote_to_local_bdry_conns,
             connected_parts)

-    logger.debug("Rank %d exiting", i_local_part)
+    logger.debug("Rank %d exiting", mpi_comm.rank)


 def _test_connected_parts(mpi_comm, connected_parts):
     num_parts = mpi_comm.Get_size()
-    i_local_part = mpi_comm.Get_rank()

-    assert i_local_part not in connected_parts
+    assert mpi_comm.rank not in connected_parts

     # Get the full adjacency
     connected_mask = np.empty(num_parts, dtype=bool)
@@ -456,7 +456,7 @@ def _test_connected_parts(mpi_comm, connected_parts):
     # make sure it agrees with connected_parts
     parts_connected_to_me = set()
     for i_remote_part in range(num_parts):
-        if all_connected_masks[i_remote_part][i_local_part]:
+        if all_connected_masks[i_remote_part][mpi_comm.rank]:
             parts_connected_to_me.add(i_remote_part)
     assert parts_connected_to_me == connected_parts

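The symmetry check in _test_connected_parts reduces to: every rank publishes a boolean mask of the parts it considers itself connected to, then verifies that the remote masks point back at it. The exchange of the masks is elided in the hunk above; the sketch below assumes it is done with allgather and uses a toy ring connectivity in place of the partitioned mesh.

```python
# Standalone sketch of the connectivity symmetry check. The allgather-based
# exchange and the ring connectivity are assumptions for illustration; the
# real test derives connected_parts from the partitioned mesh.
import numpy as np
from mpi4py import MPI


def check_symmetric_connectivity(mpi_comm, connected_parts):
    num_parts = mpi_comm.Get_size()

    # A part is never connected to itself.
    assert mpi_comm.rank not in connected_parts

    # Publish this rank's view of its neighbors ...
    connected_mask = np.zeros(num_parts, dtype=bool)
    connected_mask[list(connected_parts)] = True
    all_connected_masks = mpi_comm.allgather(connected_mask)

    # ... and make sure every neighbor also points back at this rank.
    parts_connected_to_me = {
        i_remote_part
        for i_remote_part in range(num_parts)
        if all_connected_masks[i_remote_part][mpi_comm.rank]}
    assert parts_connected_to_me == connected_parts


if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    # Toy connectivity: each rank talks to its ring neighbors.
    neighbors = {(comm.rank - 1) % comm.size, (comm.rank + 1) % comm.size}
    neighbors.discard(comm.rank)  # degenerate single-rank case
    check_symmetric_connectivity(comm, neighbors)
```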