Merged
18 changes: 10 additions & 8 deletions examples/wave/wave-min-mpi.py
@@ -49,7 +49,7 @@ class WaveTag:
 
 def main(ctx_factory, dim=2, order=4, visualize=False):
     comm = MPI.COMM_WORLD
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
@@ -60,10 +60,10 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
             force_device_scalars=True,
     )
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
+    from meshmode.mesh.processing import partition_mesh
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(
                 a=(-0.5,)*dim,
@@ -72,14 +72,16 @@
 
         logger.info("%d elements", mesh.nelements)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                       membership_list_to_map(
+                           get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
 
         del mesh
 
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=order)
 
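Editor's note: the change above replaces meshmode's manager/worker MPIMeshDistributor with a plain mpi4py collective. Rank 0 builds a list with one mesh part per rank, and every rank, root included, gets exactly one entry back from comm.scatter. A minimal standalone sketch of just that collective, with demo strings standing in for mesh parts (the file name is illustrative):

# scatter_demo.py -- run with, e.g., "mpiexec -n 4 python scatter_demo.py"
from mpi4py import MPI

comm = MPI.COMM_WORLD

if comm.rank == 0:
    # One entry per rank; comm.scatter delivers entry i to rank i.
    parts = [f"part {i}" for i in range(comm.size)]
else:
    # scatter is collective: non-root ranks must still call it,
    # passing None as the send object.
    parts = None

local_part = comm.scatter(parts)
print(f"rank {comm.rank} received {local_part!r}")

Note that mpi4py's lowercase scatter pickles each object it ships, which is acceptable here since mesh distribution happens once at setup time.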
17 changes: 10 additions & 7 deletions examples/wave/wave-op-mpi.py
@@ -184,7 +184,7 @@ def main(ctx_factory, dim=2, order=3,
     queue = cl.CommandQueue(cl_ctx)
 
     comm = MPI.COMM_WORLD
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
     from grudge.array_context import get_reasonable_array_context_class
     actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
@@ -195,12 +195,12 @@ def main(ctx_factory, dim=2, order=3,
             allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
             force_device_scalars=True)
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
+    from meshmode.mesh.processing import partition_mesh
 
     nel_1d = 16
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         if use_nonaffine_mesh:
             from meshmode.mesh.generation import generate_warped_rect_mesh
             # FIXME: *generate_warped_rect_mesh* in meshmode warps a
@@ -218,14 +218,17 @@
 
         logger.info("%d elements", mesh.nelements)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                       membership_list_to_map(
+                           get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
 
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        local_mesh = comm.scatter(parts)
 
         del mesh
 
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     from meshmode.discretization.poly_element import \
             QuadratureSimplexGroupFactory, \
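Editor's note: the other moving piece is membership_list_to_map, interposed between get_partition_by_pymetis and partition_mesh. Judging from its use in these diffs, it converts the flat per-element part assignment that pymetis returns into a mapping from part id to element indices, which partition_mesh then consumes. A rough stand-in for illustration only, not meshmode's actual implementation:

import numpy as np

def membership_list_to_map_sketch(membership_list):
    """Hypothetical stand-in: turn a per-element part assignment
    (entry i is the part that element i belongs to) into a
    {part_id: element_indices} mapping."""
    membership_list = np.asarray(membership_list)
    return {part: np.where(membership_list == part)[0]
            for part in np.unique(membership_list)}

# Five elements distributed over three parts:
print(membership_list_to_map_sketch([0, 1, 1, 0, 2]))
# {0: array([0, 3]), 1: array([1, 2]), 2: array([4])}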
44 changes: 24 additions & 20 deletions test/test_mpi_communication.py
@@ -115,24 +115,26 @@ def _test_func_comparison_mpi_communication_entrypoint(actx):
 
     comm = actx.mpi_communicator
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
+    from meshmode.distributed import (
+        get_partition_by_pymetis, membership_list_to_map)
     from meshmode.mesh import BTAG_ALL
+    from meshmode.mesh.processing import partition_mesh
 
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
-    mesh_dist = MPIMeshDistributor(comm)
-
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(a=(-1,)*2,
                                           b=(1,)*2,
                                           nelements_per_axis=(2,)*2)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                       membership_list_to_map(
+                           get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=5)
 
@@ -188,28 +190,30 @@ def test_mpi_wave_op(actx_class, num_ranks):
 
 def _test_mpi_wave_op_entrypoint(actx, visualize=False):
     comm = actx.mpi_communicator
-    i_local_rank = comm.Get_rank()
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import (
+        get_partition_by_pymetis, membership_list_to_map)
+    from meshmode.mesh.processing import partition_mesh
 
     dim = 2
     order = 4
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(a=(-0.5,)*dim,
                                           b=(0.5,)*dim,
                                           nelements_per_axis=(16,)*dim)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                       membership_list_to_map(
+                           get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
 
         del mesh
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=order)
 
@@ -270,7 +274,7 @@ def rhs(t, w):
 
     final_t = 4
     nsteps = int(final_t/dt)
-    logger.info("[%04d] dt %.5e nsteps %4d", i_local_rank, dt, nsteps)
+    logger.info("[%04d] dt %.5e nsteps %4d", comm.rank, dt, nsteps)
 
     step = 0
 
@@ -308,7 +312,7 @@ def rhs(t, w):
 
     logmgr.tick_after()
     logmgr.close()
-    logger.info("Rank %d exiting", i_local_rank)
+    logger.info("Rank %d exiting", comm.rank)
 
 # }}}
 
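Editor's note: the remaining cleanup in these diffs is incidental. The comm.Get_rank()/comm.Get_size() calls (and the i_local_rank local) give way to mpi4py's equivalent rank/size properties; the two spellings are interchangeable:

from mpi4py import MPI

comm = MPI.COMM_WORLD

# mpi4py exposes both the MPI-standard getter methods and Python
# properties; they return the same values.
assert comm.rank == comm.Get_rank()
assert comm.size == comm.Get_size()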