diff --git a/examples/parallel-vtkhdf.py b/examples/parallel-vtkhdf.py index 7a647fb93..c5ead8db5 100644 --- a/examples/parallel-vtkhdf.py +++ b/examples/parallel-vtkhdf.py @@ -32,46 +32,48 @@ def main(*, ambient_dim: int) -> None: from mpi4py import MPI comm = MPI.COMM_WORLD - mpisize = comm.Get_size() - mpirank = comm.Get_rank() - from meshmode.distributed import MPIMeshDistributor - dist = MPIMeshDistributor(comm) + from meshmode.mesh.processing import partition_mesh + from meshmode.distributed import membership_list_to_map order = 5 nelements = 64 if ambient_dim == 3 else 256 - logger.info("[%4d] distributing mesh: started", mpirank) + logger.info("[%4d] distributing mesh: started", comm.rank) - if dist.is_mananger_rank(): + if comm.rank == 0: mesh = make_example_mesh(ambient_dim, nelements, order=order) logger.info("[%4d] mesh: nelements %d nvertices %d", - mpirank, mesh.nelements, mesh.nvertices) + comm.rank, mesh.nelements, mesh.nvertices) rng = np.random.default_rng() - part_per_element = rng.integers(mpisize, size=mesh.nelements) - local_mesh = dist.send_mesh_parts(mesh, part_per_element, mpisize) + part_id_to_part = partition_mesh(mesh, + membership_list_to_map( + rng.integers(comm.size, size=mesh.nelements))) + parts = [part_id_to_part[i] for i in range(comm.size)] + local_mesh = comm.scatter(parts) else: - local_mesh = dist.receive_mesh_part() + # Reason for type-ignore: presumed faulty type annotation in mpi4py + local_mesh = comm.scatter(None) # type: ignore[arg-type] - logger.info("[%4d] distributing mesh: finished", mpirank) + logger.info("[%4d] distributing mesh: finished", comm.rank) from meshmode.discretization import Discretization from meshmode.discretization.poly_element import default_simplex_group_factory discr = Discretization(actx, local_mesh, default_simplex_group_factory(local_mesh.dim, order=order)) - logger.info("[%4d] discretization: finished", mpirank) + logger.info("[%4d] discretization: finished", comm.rank) vector_field = actx.thaw(discr.nodes()) scalar_field = actx.np.sin(vector_field[0]) - part_id = 1.0 + mpirank + discr.zeros(actx) # type: ignore[operator] - logger.info("[%4d] fields: finished", mpirank) + part_id = 1.0 + comm.rank + discr.zeros(actx) # type: ignore[operator] + logger.info("[%4d] fields: finished", comm.rank) from meshmode.discretization.visualization import make_visualizer vis = make_visualizer(actx, discr, vis_order=order, force_equidistant=False) - logger.info("[%4d] make_visualizer: finished", mpirank) + logger.info("[%4d] make_visualizer: finished", comm.rank) filename = f"parallel-vtkhdf-example-{ambient_dim}d.hdf" vis.write_vtkhdf_file(filename, [ @@ -80,7 +82,7 @@ def main(*, ambient_dim: int) -> None: ("part_id", part_id) ], comm=comm, overwrite=True, use_high_order=False) - logger.info("[%4d] write: finished: %s", mpirank, filename) + logger.info("[%4d] write: finished: %s", comm.rank, filename) if __name__ == "__main__": diff --git a/meshmode/distributed.py b/meshmode/distributed.py index d99e90dc8..9ebf9b816 100644 --- a/meshmode/distributed.py +++ b/meshmode/distributed.py @@ -1,5 +1,4 @@ """ -.. autoclass:: MPIMeshDistributor .. autoclass:: InterRankBoundaryInfo .. autoclass:: MPIBoundaryCommSetupHelper @@ -85,6 +84,10 @@ def __init__(self, mpi_comm, manager_rank=0): self.mpi_comm = mpi_comm self.manager_rank = manager_rank + warn("MPIMeshDistributor is deprecated and will be removed in 2024. " + "Directly call partition_mesh and use mpi_comm.scatter instead.", + DeprecationWarning, stacklevel=2) + def is_mananger_rank(self): return self.mpi_comm.Get_rank() == self.manager_rank diff --git a/test/test_partition.py b/test/test_partition.py index 1ad219623..c80ac5b12 100644 --- a/test/test_partition.py +++ b/test/test_partition.py @@ -368,16 +368,14 @@ def count_tags(mesh, tag): # {{{ MPI test boundary swap def _test_mpi_boundary_swap(dim, order, num_groups): - from meshmode.distributed import MPIMeshDistributor, MPIBoundaryCommSetupHelper + from meshmode.distributed import (MPIBoundaryCommSetupHelper, + membership_list_to_map) + from meshmode.mesh.processing import partition_mesh from mpi4py import MPI mpi_comm = MPI.COMM_WORLD - i_local_part = mpi_comm.Get_rank() - num_parts = mpi_comm.Get_size() - mesh_dist = MPIMeshDistributor(mpi_comm) - - if mesh_dist.is_mananger_rank(): + if mpi_comm.rank == 0: np.random.seed(42) from meshmode.mesh.generation import generate_warped_rect_mesh meshes = [generate_warped_rect_mesh(dim, order=order, nelements_side=4) @@ -389,11 +387,14 @@ def _test_mpi_boundary_swap(dim, order, num_groups): else: mesh = meshes[0] - part_per_element = np.random.randint(num_parts, size=mesh.nelements) + part_id_to_part = partition_mesh(mesh, + membership_list_to_map( + np.random.randint(mpi_comm.size, size=mesh.nelements))) + parts = [part_id_to_part[i] for i in range(mpi_comm.size)] - local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts) + local_mesh = mpi_comm.scatter(parts) else: - local_mesh = mesh_dist.receive_mesh_part() + local_mesh = mpi_comm.scatter(None) group_factory = default_simplex_group_factory(base_dim=dim, order=order) @@ -436,14 +437,13 @@ def _test_mpi_boundary_swap(dim, order, num_groups): remote_to_local_bdry_conns, connected_parts) - logger.debug("Rank %d exiting", i_local_part) + logger.debug("Rank %d exiting", mpi_comm.rank) def _test_connected_parts(mpi_comm, connected_parts): num_parts = mpi_comm.Get_size() - i_local_part = mpi_comm.Get_rank() - assert i_local_part not in connected_parts + assert mpi_comm.rank not in connected_parts # Get the full adjacency connected_mask = np.empty(num_parts, dtype=bool) @@ -456,7 +456,7 @@ def _test_connected_parts(mpi_comm, connected_parts): # make sure it agrees with connected_parts parts_connected_to_me = set() for i_remote_part in range(num_parts): - if all_connected_masks[i_remote_part][i_local_part]: + if all_connected_masks[i_remote_part][mpi_comm.rank]: parts_connected_to_me.add(i_remote_part) assert parts_connected_to_me == connected_parts