diff --git a/grudge/discretization.py b/grudge/discretization.py index dc08cd11f..c05a59c35 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -7,13 +7,7 @@ .. autofunction:: make_discretization_collection .. currentmodule:: grudge.discretization - -.. autofunction:: relabel_partitions - -Internal things that are visble due to type annotations -------------------------------------------------------- - -.. class:: _InterPartitionConnectionPair +.. autoclass:: PartID """ __copyright__ = """ @@ -41,13 +35,14 @@ THE SOFTWARE. """ -from typing import Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any +from typing import Sequence, Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any from pytools import memoize_method, single_valued +from dataclasses import dataclass, replace + from grudge.dof_desc import ( VTAG_ALL, - BTAG_MULTIVOL_PARTITION, DD_VOLUME_ALL, DISCR_TAG_BASE, DISCR_TAG_MODAL, @@ -70,8 +65,7 @@ make_face_restriction, DiscretizationConnection ) -from meshmode.mesh import ( - InterPartitionAdjacencyGroup, Mesh, BTAG_PARTITION, BoundaryTag) +from meshmode.mesh import Mesh, BTAG_PARTITION from meshmode.dof_array import DOFArray from warnings import warn @@ -80,6 +74,89 @@ import mpi4py.MPI +@dataclass(frozen=True) +class PartID: + """Unique identifier for a piece of a partitioned mesh. + + .. attribute:: volume_tag + + The volume of the part. + + .. attribute:: rank + + The (optional) MPI rank of the part. + + """ + volume_tag: VolumeTag + rank: Optional[int] = None + + +# {{{ part ID normalization + +def _normalize_mesh_part_ids( + mesh: Mesh, + volume_tags: Sequence[VolumeTag], + mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None): + """Convert a mesh's configuration-dependent "part ID" into a fixed type.""" + from numbers import Integral + if VTAG_ALL not in volume_tags: + # Multi-volume + if mpi_communicator is not None: + # Accept PartID + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + else: + # Accept PartID or volume tag + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + elif mesh_part_id in volume_tags: + return PartID(mesh_part_id) + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + else: + # Single-volume + if mpi_communicator is not None: + # Accept PartID or rank + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + elif isinstance(mesh_part_id, Integral): + return PartID(VTAG_ALL, int(mesh_part_id)) + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + else: + # Shouldn't be called + def as_part_id(mesh_part_id): + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + + facial_adjacency_groups = mesh.facial_adjacency_groups + + new_facial_adjacency_groups = [] + + from meshmode.mesh import InterPartAdjacencyGroup + for grp_list in facial_adjacency_groups: + new_grp_list = [] + for fagrp in grp_list: + if isinstance(fagrp, InterPartAdjacencyGroup): + part_id = as_part_id(fagrp.part_id) + new_fagrp = replace( + fagrp, + boundary_tag=BTAG_PARTITION(part_id), + part_id=part_id) + else: + new_fagrp = fagrp + new_grp_list.append(new_fagrp) + new_facial_adjacency_groups.append(new_grp_list) + + return mesh.copy(facial_adjacency_groups=new_facial_adjacency_groups) + +# }}} + + # {{{ discr_tag_to_group_factory normalization def _normalize_discr_tag_to_group_factory( @@ -133,6 +210,11 @@ class 
DiscretizationCollection: (volume, interior facets, boundaries) and associated element groups. + .. note:: + + Do not call the constructor directly. Use + :func:`make_discretization_collection` instead. + .. autoattribute:: dim .. autoattribute:: ambient_dim .. autoattribute:: real_dtype @@ -160,8 +242,9 @@ def __init__(self, array_context: ArrayContext, discr_tag_to_group_factory: Optional[ Mapping[DiscretizationTag, ElementGroupFactory]] = None, mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None, - inter_partition_connections: Optional[ - Mapping[BoundaryDomainTag, DiscretizationConnection]] = None + inter_part_connections: Optional[ + Mapping[Tuple[PartID, PartID], + DiscretizationConnection]] = None, ) -> None: """ :arg discr_tag_to_group_factory: A mapping from discretization tags @@ -202,15 +285,19 @@ def __init__(self, array_context: ArrayContext, from meshmode.discretization import Discretization - # {{{ deprecated backward compatibility yuck - if isinstance(volume_discrs, Mesh): + # {{{ deprecated backward compatibility yuck + warn("Calling the DiscretizationCollection constructor directly " "is deprecated, call make_discretization_collection " "instead. This will stop working in 2023.", DeprecationWarning, stacklevel=2) mesh = volume_discrs + + mesh = _normalize_mesh_part_ids( + mesh, [VTAG_ALL], mpi_communicator=mpi_communicator) + discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory( dim=mesh.dim, discr_tag_to_group_factory=discr_tag_to_group_factory, @@ -224,30 +311,30 @@ def __init__(self, array_context: ArrayContext, del mesh - if inter_partition_connections is not None: - raise TypeError("may not pass inter_partition_connections when " + if inter_part_connections is not None: + raise TypeError("may not pass inter_part_connections when " "DiscretizationCollection constructor is called in " "legacy mode") - self._inter_partition_connections = \ - _set_up_inter_partition_connections( + self._inter_part_connections = \ + _set_up_inter_part_connections( array_context=self._setup_actx, mpi_communicator=mpi_communicator, volume_discrs=volume_discrs, base_group_factory=( discr_tag_to_group_factory[DISCR_TAG_BASE])) - else: - if inter_partition_connections is None: - raise TypeError("inter_partition_connections must be passed when " - "DiscretizationCollection constructor is called in " - "'modern' mode") - - self._inter_partition_connections = inter_partition_connections + # }}} + else: assert discr_tag_to_group_factory is not None self._discr_tag_to_group_factory = discr_tag_to_group_factory - # }}} + if inter_part_connections is None: + raise TypeError("inter_part_connections must be passed when " + "DiscretizationCollection constructor is called in " + "'modern' mode") + + self._inter_part_connections = inter_part_connections self._volume_discrs = volume_discrs @@ -729,104 +816,66 @@ def normal(self, dd): # {{{ distributed/multi-volume setup -def _check_btag(tag: BoundaryTag) -> Union[BTAG_MULTIVOL_PARTITION, BTAG_PARTITION]: - if isinstance(tag, BTAG_MULTIVOL_PARTITION): - return tag - - elif isinstance(tag, BTAG_PARTITION): - return tag - - else: - raise TypeError("unexpected type of inter-partition boundary tag " - f"'{type(tag)}'") - - -def _remote_rank_from_btag(btag: BoundaryTag) -> Optional[int]: - if isinstance(btag, BTAG_PARTITION): - return btag.part_nr - - elif isinstance(btag, BTAG_MULTIVOL_PARTITION): - return btag.other_rank - - else: - raise TypeError("unexpected type of inter-partition boundary tag " - f"'{type(btag)}'") - - -def _flip_dtag( 
- self_rank: Optional[int], - domain_tag: BoundaryDomainTag, - ) -> BoundaryDomainTag: - if isinstance(domain_tag.tag, BTAG_PARTITION): - assert self_rank is not None - return BoundaryDomainTag( - BTAG_PARTITION(self_rank), domain_tag.volume_tag) - - elif isinstance(domain_tag.tag, BTAG_MULTIVOL_PARTITION): - return BoundaryDomainTag( - BTAG_MULTIVOL_PARTITION( - other_rank=None if domain_tag.tag.other_rank is None else self_rank, - other_volume_tag=domain_tag.volume_tag), - domain_tag.tag.other_volume_tag) - - else: - raise TypeError("unexpected type of inter-partition boundary tag " - f"'{type(domain_tag.tag)}'") - - -def _set_up_inter_partition_connections( +def _set_up_inter_part_connections( array_context: ArrayContext, mpi_communicator: Optional["mpi4py.MPI.Intracomm"], volume_discrs: Mapping[VolumeTag, Discretization], - base_group_factory: ElementGroupFactory, + base_group_factory: ElementGroupFactory, ) -> Mapping[ - BoundaryDomainTag, + Tuple[PartID, PartID], DiscretizationConnection]: - from meshmode.distributed import (get_inter_partition_tags, + from meshmode.distributed import (get_connected_parts, make_remote_group_infos, InterRankBoundaryInfo, MPIBoundaryCommSetupHelper) - inter_part_tags = { - BoundaryDomainTag(_check_btag(btag), discr_vol_tag) - for discr_vol_tag, volume_discr in volume_discrs.items() - for btag in get_inter_partition_tags(volume_discr.mesh)} + rank = mpi_communicator.Get_rank() if mpi_communicator is not None else None + + # Save boundary restrictions as they're created to avoid potentially creating + # them twice in the loop below + cached_part_bdry_restrictions: Mapping[ + Tuple[PartID, PartID], + DiscretizationConnection] = {} + + def get_part_bdry_restriction(self_part_id, other_part_id): + cached_result = cached_part_bdry_restrictions.get( + (self_part_id, other_part_id), None) + if cached_result is not None: + return cached_result + return cached_part_bdry_restrictions.setdefault( + (self_part_id, other_part_id), + make_face_restriction( + array_context, volume_discrs[self_part_id.volume_tag], + base_group_factory, + boundary_tag=BTAG_PARTITION(other_part_id))) inter_part_conns: Mapping[ - BoundaryDomainTag, + Tuple[PartID, PartID], DiscretizationConnection] = {} - if inter_part_tags: - local_boundary_restrictions = { - domain_tag: make_face_restriction( - array_context, volume_discrs[domain_tag.volume_tag], - base_group_factory, boundary_tag=domain_tag.tag) - for domain_tag in inter_part_tags} + irbis = [] - irbis = [] + for vtag, volume_discr in volume_discrs.items(): + part_id = PartID(vtag, rank) + connected_part_ids = get_connected_parts(volume_discr.mesh) + for connected_part_id in connected_part_ids: + bdry_restr = get_part_bdry_restriction(part_id, connected_part_id) - for domain_tag in inter_part_tags: - assert isinstance( - domain_tag.tag, (BTAG_PARTITION, BTAG_MULTIVOL_PARTITION)) - - other_rank = _remote_rank_from_btag(domain_tag.tag) - btag_restr = local_boundary_restrictions[domain_tag] - - if other_rank is None: + if connected_part_id.rank == rank: # {{{ rank-local interface between multiple volumes - assert isinstance(domain_tag.tag, BTAG_MULTIVOL_PARTITION) + rev_bdry_restr = get_part_bdry_restriction( + connected_part_id, part_id) from meshmode.discretization.connection import \ make_partition_connection - remote_dtag = _flip_dtag(None, domain_tag) - inter_part_conns[domain_tag] = make_partition_connection( + inter_part_conns[connected_part_id, part_id] = \ + make_partition_connection( array_context, - local_bdry_conn=btag_restr, 
- remote_bdry_discr=( - local_boundary_restrictions[remote_dtag].to_discr), + local_bdry_conn=bdry_restr, + remote_bdry_discr=rev_bdry_restr.to_discr, remote_group_infos=make_remote_group_infos( - array_context, remote_dtag.tag, btag_restr)) + array_context, connected_part_id, rev_bdry_restr)) # }}} else: @@ -838,27 +887,25 @@ def _set_up_inter_partition_connections( irbis.append( InterRankBoundaryInfo( - local_btag=domain_tag.tag, - local_part_id=domain_tag, - remote_part_id=_flip_dtag( - mpi_communicator.rank, domain_tag), - remote_rank=other_rank, - local_boundary_connection=btag_restr)) + local_part_id=part_id, + remote_part_id=connected_part_id, + remote_rank=connected_part_id.rank, + local_boundary_connection=bdry_restr)) # }}} - if irbis: - assert mpi_communicator is not None + if irbis: + assert mpi_communicator is not None - with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, - irbis, base_group_factory) as bdry_setup_helper: - while True: - conns = bdry_setup_helper.complete_some() - if not conns: - # We're done. - break + with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, + irbis, base_group_factory) as bdry_setup_helper: + while True: + conns = bdry_setup_helper.complete_some() + if not conns: + # We're done. + break - inter_part_conns.update(conns) + inter_part_conns.update(conns) return inter_part_conns @@ -942,6 +989,7 @@ def make_discretization_collection( from pytools import single_valued, is_single_valued + assert len(volumes) > 0 assert is_single_valued(mesh_or_discr.ambient_dim for mesh_or_discr in volumes.values()) @@ -953,62 +1001,30 @@ def make_discretization_collection( del order + mpi_communicator = getattr(array_context, "mpi_communicator", None) + + if any( + isinstance(mesh_or_discr, Discretization) + for mesh_or_discr in volumes.values()): + raise NotImplementedError("Doesn't work at the moment") + volume_discrs = { - vtag: ( - Discretization( - array_context, mesh_or_discr, - discr_tag_to_group_factory[DISCR_TAG_BASE]) - if isinstance(mesh_or_discr, Mesh) else mesh_or_discr) - for vtag, mesh_or_discr in volumes.items()} + vtag: Discretization( + array_context, + _normalize_mesh_part_ids( + mesh, volumes.keys(), mpi_communicator=mpi_communicator), + discr_tag_to_group_factory[DISCR_TAG_BASE]) + for vtag, mesh in volumes.items()} return DiscretizationCollection( array_context=array_context, volume_discrs=volume_discrs, discr_tag_to_group_factory=discr_tag_to_group_factory, - inter_partition_connections=_set_up_inter_partition_connections( + inter_part_connections=_set_up_inter_part_connections( array_context=array_context, - mpi_communicator=getattr( - array_context, "mpi_communicator", None), + mpi_communicator=mpi_communicator, volume_discrs=volume_discrs, - base_group_factory=discr_tag_to_group_factory[DISCR_TAG_BASE], - )) - -# }}} - - -# {{{ relabel_partitions - -def relabel_partitions(mesh: Mesh, - self_rank: int, - part_nr_to_rank_and_vol_tag: Mapping[int, Tuple[int, VolumeTag]]) -> Mesh: - """Given a partitioned mesh (which includes :class:`meshmode.mesh.BTAG_PARTITION` - boundary tags), relabel those boundary tags into - :class:`grudge.dof_desc.BTAG_MULTIVOL_PARTITION` tags, which map each - of the incoming partitions onto a combination of rank and volume tag, - given by *part_nr_to_rank_and_vol_tag*. 
- """ - - def _new_btag(btag: BoundaryTag) -> BTAG_MULTIVOL_PARTITION: - if not isinstance(btag, BTAG_PARTITION): - raise TypeError("unexpected inter-partition boundary tags of type " - f"'{type(btag)}', expected BTAG_PARTITION") - - rank, vol_tag = part_nr_to_rank_and_vol_tag[btag.part_nr] - return BTAG_MULTIVOL_PARTITION( - other_rank=(None if rank == self_rank else rank), - other_volume_tag=vol_tag) - - assert mesh.facial_adjacency_groups is not None - - from dataclasses import replace - return mesh.copy(facial_adjacency_groups=[ - [ - replace(fagrp, - boundary_tag=_new_btag(fagrp.boundary_tag)) - if isinstance(fagrp, InterPartitionAdjacencyGroup) - else fagrp - for fagrp in grp_fagrp_list] - for grp_fagrp_list in mesh.facial_adjacency_groups]) + base_group_factory=discr_tag_to_group_factory[DISCR_TAG_BASE])) # }}} diff --git a/grudge/dof_desc.py b/grudge/dof_desc.py index e3015d1d8..46d5553a1 100644 --- a/grudge/dof_desc.py +++ b/grudge/dof_desc.py @@ -8,8 +8,6 @@ :mod:`grudge`-specific boundary tags ------------------------------------ -.. autoclass:: BTAG_MULTIVOL_PARTITION - Domain tags ----------- @@ -111,24 +109,6 @@ class VTAG_ALL: # noqa: N801 # }}} -# {{{ partition identifier - -@dataclass(init=True, eq=True, frozen=True) -class BTAG_MULTIVOL_PARTITION: # noqa: N801 - """ - .. attribute:: other_rank - - An integer, or *None*. If *None*, this marks a partition boundary - to another volume on the same rank. - - .. attribute:: other_volume_tag - """ - other_rank: Optional[int] - other_volume_tag: "VolumeTag" - -# }}} - - # {{{ domain tag @dataclass(frozen=True, eq=True) @@ -411,7 +391,7 @@ def _normalize_domain_and_discr_tag( domain = BoundaryDomainTag(FACE_RESTR_ALL) elif domain in [FACE_RESTR_INTERIOR, "int_faces"]: domain = BoundaryDomainTag(FACE_RESTR_INTERIOR) - elif isinstance(domain, (BTAG_PARTITION, BTAG_MULTIVOL_PARTITION)): + elif isinstance(domain, BTAG_PARTITION): domain = BoundaryDomainTag(domain, _contextual_volume_tag) elif domain in [BTAG_ALL, BTAG_REALLY_ALL, BTAG_NONE]: domain = BoundaryDomainTag(domain, _contextual_volume_tag) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 3db3517a6..d06fd13ea 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -78,7 +78,7 @@ from pytools import memoize_on_first_arg from pytools.obj_array import obj_array_vectorize -from grudge.discretization import DiscretizationCollection, _remote_rank_from_btag +from grudge.discretization import DiscretizationCollection, PartID from grudge.projection import project from meshmode.mesh import BTAG_PARTITION @@ -88,7 +88,7 @@ import grudge.dof_desc as dof_desc from grudge.dof_desc import ( DOFDesc, DD_VOLUME_ALL, FACE_RESTR_INTERIOR, DISCR_TAG_BASE, - VolumeTag, VolumeDomainTag, BoundaryDomainTag, BTAG_MULTIVOL_PARTITION, + VolumeTag, VolumeDomainTag, BoundaryDomainTag, ConvertibleToDOFDesc, ) @@ -378,15 +378,16 @@ def local_inter_volume_trace_pairs( raise TypeError( f"expected a base-discretized other DOFDesc, got '{other_volume_dd}'") - self_btag = BTAG_MULTIVOL_PARTITION( - other_rank=None, - other_volume_tag=other_volume_dd.domain_tag.tag) - other_btag = BTAG_MULTIVOL_PARTITION( - other_rank=None, - other_volume_tag=self_volume_dd.domain_tag.tag) + rank = ( + dcoll.mpi_communicator.Get_rank() + if dcoll.mpi_communicator is not None + else None) + + self_part_id = PartID(self_volume_dd.domain_tag.tag, rank) + other_part_id = PartID(other_volume_dd.domain_tag.tag, rank) - self_trace_dd = self_volume_dd.trace(self_btag) - other_trace_dd = 
other_volume_dd.trace(other_btag) + self_trace_dd = self_volume_dd.trace(BTAG_PARTITION(other_part_id)) + other_trace_dd = other_volume_dd.trace(BTAG_PARTITION(self_part_id)) # FIXME: In all likelihood, these traces will be reevaluated from # the other side, which is hard to prevent given the interface we @@ -396,8 +397,7 @@ def local_inter_volume_trace_pairs( other_trace = project( dcoll, other_volume_dd, other_trace_dd, other_ary) - other_to_self = dcoll._inter_partition_connections[ - BoundaryDomainTag(self_btag, self_volume_dd.domain_tag.tag)] + other_to_self = dcoll._inter_part_connections[other_part_id, self_part_id] def get_opposite_trace(el): if isinstance(el, Number): @@ -442,29 +442,17 @@ def update_for_type(self, key_hash, key: Type[Any]): @memoize_on_first_arg -def _remote_inter_partition_tags( +def _connected_parts( dcoll: DiscretizationCollection, self_volume_tag: VolumeTag, - other_volume_tag: Optional[VolumeTag] = None - ) -> Sequence[BoundaryDomainTag]: - if other_volume_tag is None: - other_volume_tag = self_volume_tag - - result: List[BoundaryDomainTag] = [] - for domain_tag in dcoll._inter_partition_connections: - if isinstance(domain_tag.tag, BTAG_PARTITION): - if domain_tag.volume_tag == self_volume_tag: - result.append(domain_tag) - - elif isinstance(domain_tag.tag, BTAG_MULTIVOL_PARTITION): - if (domain_tag.tag.other_rank is not None - and domain_tag.volume_tag == self_volume_tag - and domain_tag.tag.other_volume_tag == other_volume_tag): - result.append(domain_tag) - - else: - raise AssertionError("unexpected inter-partition tag type encountered: " - f"'{domain_tag.tag}'") + other_volume_tag: VolumeTag + ) -> Sequence[PartID]: + result: List[PartID] = [ + connected_part_id + for connected_part_id, part_id in dcoll._inter_part_connections.keys() + if ( + part_id.volume_tag == self_volume_tag + and connected_part_id.volume_tag == other_volume_tag)] return result @@ -506,21 +494,25 @@ class _RankBoundaryCommunicationEager: def __init__(self, actx: ArrayContext, dcoll: DiscretizationCollection, - domain_tag: BoundaryDomainTag, - *, local_bdry_data: ArrayOrContainer, + *, + local_part_id: PartID, + remote_part_id: PartID, + local_bdry_data: ArrayOrContainer, send_data: ArrayOrContainer, comm_tag: Optional[Hashable] = None): comm = dcoll.mpi_communicator assert comm is not None - remote_rank = _remote_rank_from_btag(domain_tag.tag) + remote_rank = remote_part_id.rank assert remote_rank is not None self.dcoll = dcoll self.array_context = actx - self.domain_tag = domain_tag - self.bdry_discr = dcoll.discr_from_dd(domain_tag) + self.local_part_id = local_part_id + self.remote_part_id = remote_part_id + self.bdry_discr = dcoll.discr_from_dd( + BoundaryDomainTag(BTAG_PARTITION(remote_part_id))) self.local_bdry_data = local_bdry_data self.comm_tag = self.base_comm_tag @@ -552,7 +544,8 @@ def finish(self): self.recv_data_np, self.array_context) unswapped_remote_bdry_data = unflatten(self.local_bdry_data, recv_data_flat, self.array_context) - bdry_conn = self.dcoll._inter_partition_connections[self.domain_tag] + bdry_conn = self.dcoll._inter_part_connections[ + self.remote_part_id, self.local_part_id] remote_bdry_data = bdry_conn(unswapped_remote_bdry_data) # Complete the nonblocking send request associated with communicating @@ -560,7 +553,9 @@ def finish(self): self.send_req.Wait() return TracePair( - DOFDesc(self.domain_tag, DISCR_TAG_BASE), + DOFDesc( + BoundaryDomainTag(BTAG_PARTITION(self.remote_part_id)), + DISCR_TAG_BASE), interior=self.local_bdry_data, 
exterior=remote_bdry_data) @@ -573,8 +568,9 @@ class _RankBoundaryCommunicationLazy: def __init__(self, actx: ArrayContext, dcoll: DiscretizationCollection, - domain_tag: BoundaryDomainTag, *, + local_part_id: PartID, + remote_part_id: PartID, local_bdry_data: ArrayOrContainer, send_data: ArrayOrContainer, comm_tag: Optional[Hashable] = None) -> None: @@ -584,10 +580,12 @@ def __init__(self, self.dcoll = dcoll self.array_context = actx - self.bdry_discr = dcoll.discr_from_dd(domain_tag) - self.domain_tag = domain_tag + self.bdry_discr = dcoll.discr_from_dd( + BoundaryDomainTag(BTAG_PARTITION(remote_part_id))) + self.local_part_id = local_part_id + self.remote_part_id = remote_part_id - remote_rank = _remote_rank_from_btag(domain_tag.tag) + remote_rank = remote_part_id.rank assert remote_rank is not None self.local_bdry_data = local_bdry_data @@ -617,10 +615,13 @@ def communicate_single_array(key, local_bdry_subary): communicate_single_array, self.local_bdry_data) def finish(self): - bdry_conn = self.dcoll._inter_partition_connections[self.domain_tag] + bdry_conn = self.dcoll._inter_part_connections[ + self.remote_part_id, self.local_part_id] return TracePair( - DOFDesc(self.domain_tag, DISCR_TAG_BASE), + DOFDesc( + BoundaryDomainTag(BTAG_PARTITION(self.remote_part_id)), + DISCR_TAG_BASE), interior=self.local_bdry_data, exterior=bdry_conn(self.remote_data)) @@ -637,9 +638,9 @@ def cross_rank_trace_pairs( r"""Get a :class:`list` of *ary* trace pairs for each partition boundary. For each partition boundary, the field data values in *ary* are - communicated to/from the neighboring partition. Presumably, this - communication is MPI (but strictly speaking, may not be, and this - routine is agnostic to the underlying communication). + communicated to/from the neighboring part. Presumably, this communication + is MPI (but strictly speaking, may not be, and this routine is agnostic to + the underlying communication). For each face on each partition boundary, a :class:`TracePair` is created with the locally, and @@ -684,19 +685,36 @@ def cross_rank_trace_pairs( # }}} - comm_bdtags = _remote_inter_partition_tags( - dcoll, self_volume_tag=volume_dd.domain_tag.tag) + if dcoll.mpi_communicator is None: + return [] + + rank = dcoll.mpi_communicator.Get_rank() + + local_part_id = PartID(volume_dd.domain_tag.tag, rank) + + connected_part_ids = _connected_parts( + dcoll, self_volume_tag=volume_dd.domain_tag.tag, + other_volume_tag=volume_dd.domain_tag.tag) + + remote_part_ids = [ + part_id + for part_id in connected_part_ids + if part_id.rank != rank] # This asserts that there is only one data exchange per rank, so that # there is no risk of mismatched data reaching the wrong recipient. # (Since we have only a single tag.) 
- assert len(comm_bdtags) == len( - {_remote_rank_from_btag(bdtag.tag) for bdtag in comm_bdtags}) + assert len(remote_part_ids) == len({part_id.rank for part_id in remote_part_ids}) if isinstance(ary, Number): # NOTE: Assumes that the same number is passed on every rank - return [TracePair(DOFDesc(bdtag, DISCR_TAG_BASE), interior=ary, exterior=ary) - for bdtag in comm_bdtags] + return [ + TracePair( + DOFDesc( + BoundaryDomainTag(BTAG_PARTITION(remote_part_id)), + DISCR_TAG_BASE), + interior=ary, exterior=ary) + for remote_part_id in remote_part_ids] actx = get_container_context_recursively(ary) assert actx is not None @@ -708,19 +726,21 @@ def cross_rank_trace_pairs( else: rbc = _RankBoundaryCommunicationEager - def start_comm(bdtag): - local_bdry_data = project( - dcoll, - DOFDesc(VolumeDomainTag(bdtag.volume_tag), DISCR_TAG_BASE), - DOFDesc(bdtag, DISCR_TAG_BASE), - ary) + def start_comm(remote_part_id): + bdtag = BoundaryDomainTag(BTAG_PARTITION(remote_part_id)) - return rbc(actx, dcoll, bdtag, + local_bdry_data = project(dcoll, volume_dd, bdtag, ary) + + return rbc(actx, dcoll, + local_part_id=local_part_id, + remote_part_id=remote_part_id, local_bdry_data=local_bdry_data, send_data=local_bdry_data, comm_tag=comm_tag) - rank_bdry_communcators = [start_comm(bdtag) for bdtag in comm_bdtags] + rank_bdry_communcators = [ + start_comm(remote_part_id) + for remote_part_id in remote_part_ids] return [rc.finish() for rc in rank_bdry_communcators] # }}} @@ -760,16 +780,26 @@ def cross_rank_inter_volume_trace_pairs( # }}} - comm_bdtags = _remote_inter_partition_tags( - dcoll, - self_volume_tag=self_volume_dd.domain_tag.tag, + if dcoll.mpi_communicator is None: + return [] + + rank = dcoll.mpi_communicator.Get_rank() + + local_part_id = PartID(self_volume_dd.domain_tag.tag, rank) + + connected_part_ids = _connected_parts( + dcoll, self_volume_tag=self_volume_dd.domain_tag.tag, other_volume_tag=other_volume_dd.domain_tag.tag) + remote_part_ids = [ + part_id + for part_id in connected_part_ids + if part_id.rank != rank] + # This asserts that there is only one data exchange per rank, so that # there is no risk of mismatched data reaching the wrong recipient. # (Since we have only a single tag.) 
- assert len(comm_bdtags) == len( - {_remote_rank_from_btag(bdtag.tag) for bdtag in comm_bdtags}) + assert len(remote_part_ids) == len({part_id.rank for part_id in remote_part_ids}) actx = get_container_context_recursively(self_ary) assert actx is not None @@ -781,25 +811,23 @@ def cross_rank_inter_volume_trace_pairs( else: rbc = _RankBoundaryCommunicationEager - def start_comm(bdtag): - assert isinstance(bdtag.tag, BTAG_MULTIVOL_PARTITION) - self_volume_dd = DOFDesc( - VolumeDomainTag(bdtag.volume_tag), DISCR_TAG_BASE) - other_volume_dd = DOFDesc( - VolumeDomainTag(bdtag.tag.other_volume_tag), DISCR_TAG_BASE) + def start_comm(remote_part_id): + bdtag = BoundaryDomainTag(BTAG_PARTITION(remote_part_id)) local_bdry_data = project(dcoll, self_volume_dd, bdtag, self_ary) send_data = project(dcoll, other_volume_dd, - BTAG_MULTIVOL_PARTITION( - other_rank=bdtag.tag.other_rank, - other_volume_tag=bdtag.volume_tag), other_ary) + BTAG_PARTITION(local_part_id), other_ary) - return rbc(actx, dcoll, bdtag, + return rbc(actx, dcoll, + local_part_id=local_part_id, + remote_part_id=remote_part_id, local_bdry_data=local_bdry_data, send_data=send_data, comm_tag=comm_tag) - rank_bdry_communcators = [start_comm(bdtag) for bdtag in comm_bdtags] + rank_bdry_communcators = [ + start_comm(remote_part_id) + for remote_part_id in remote_part_ids] return [rc.finish() for rc in rank_bdry_communcators] # }}} diff --git a/requirements.txt b/requirements.txt index 6d8841e9a..2107e5aeb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ git+https://github.com/inducer/leap.git#egg=leap git+https://github.com/inducer/meshpy.git#egg=meshpy git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/arraycontext.git#egg=arraycontext -git+https://github.com/inducer/meshmode.git@generic-part-bdry#egg=meshmode +git+https://github.com/inducer/meshmode.git#egg=meshmode git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/pymetis.git#egg=pymetis git+https://github.com/illinois-ceesd/logpyle.git#egg=logpyle diff --git a/test/test_grudge.py b/test/test_grudge.py index a0bb9ac18..819752098 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -1084,22 +1084,16 @@ def test_multi_volume(actx_factory): nelements_per_axis=(8,)*dim, order=4) meg, = mesh.groups - part_per_element = ( - mesh.vertices[0, meg.vertex_indices[:, 0]] > 0).astype(np.int32) + x = mesh.vertices[0, meg.vertex_indices] + x_elem_avg = np.sum(x, axis=1)/x.shape[1] + volume_to_elements = { + 0: np.where(x_elem_avg <= 0)[0], + 1: np.where(x_elem_avg > 0)[0]} from meshmode.mesh.processing import partition_mesh - from grudge.discretization import relabel_partitions - parts = { - i: relabel_partitions( - partition_mesh(mesh, part_per_element, i)[0], - self_rank=0, - part_nr_to_rank_and_vol_tag={ - 0: (0, 0), - 1: (0, 1), - }) - for i in range(2)} - - make_discretization_collection(actx, parts, order=4) + volume_to_mesh = partition_mesh(mesh, volume_to_elements) + + make_discretization_collection(actx, volume_to_mesh, order=4) # }}}
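
A minimal end-to-end sketch of the multi-volume setup path that the updated test_multi_volume exercises: meshmode's partition_mesh now maps part identifiers to meshes, and make_discretization_collection consumes that mapping directly, with the keys serving as volume tags. The array-context construction here is an assumption; any arraycontext.ArrayContext works.

    import numpy as np
    import pyopencl as cl
    from arraycontext import PyOpenCLArrayContext
    from meshmode.mesh.generation import generate_regular_rect_mesh
    from meshmode.mesh.processing import partition_mesh
    from grudge.discretization import make_discretization_collection

    cl_ctx = cl.create_some_context()
    queue = cl.CommandQueue(cl_ctx)
    actx = PyOpenCLArrayContext(queue)

    dim = 2
    mesh = generate_regular_rect_mesh(
            a=(-1,)*dim, b=(1,)*dim, nelements_per_axis=(8,)*dim, order=4)
    meg, = mesh.groups

    # Assign each element to volume 0 or 1 by the x coordinate of its
    # vertex centroid.
    x = mesh.vertices[0, meg.vertex_indices]
    x_elem_avg = np.sum(x, axis=1)/x.shape[1]
    volume_to_elements = {
        0: np.where(x_elem_avg <= 0)[0],
        1: np.where(x_elem_avg > 0)[0]}

    # partition_mesh returns a mapping keyed by the same identifiers;
    # those keys become the collection's volume tags.
    volume_to_mesh = partition_mesh(mesh, volume_to_elements)

    dcoll = make_discretization_collection(actx, volume_to_mesh, order=4)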
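The identifiers that tie the pieces together: PartID is frozen, hence hashable, so it can key the _inter_part_connections mapping (ordered as (source part, destination part)) and can be wrapped in meshmode's BTAG_PARTITION to tag partition boundaries. A small sketch; the concrete ranks and volume tags below are illustrative assumptions.

    from meshmode.mesh import BTAG_PARTITION
    from grudge.discretization import PartID
    from grudge.dof_desc import DD_VOLUME_ALL, VTAG_ALL

    # Single-volume distributed run: parts differ only by rank.
    local_part = PartID(VTAG_ALL, rank=0)    # this rank (assumed 0)
    remote_part = PartID(VTAG_ALL, rank=1)   # a neighboring rank (assumed 1)

    # The restriction of the local volume to the boundary facing the
    # remote part:
    bdry_dd = DD_VOLUME_ALL.trace(BTAG_PARTITION(remote_part))

    # Rank-local coupling between two volumes on the same rank leaves
    # rank as None:
    fluid_part = PartID("fluid")
    solid_part = PartID("solid")

    # A connection keyed (solid_part, fluid_part) carries solid-side
    # boundary data onto the fluid side, matching the
    # dcoll._inter_part_connections[other_part_id, self_part_id] lookup
    # in local_inter_volume_trace_pairs.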
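Finally, a sketch of the communication entry point after this change: cross_rank_trace_pairs derives the local PartID from the volume descriptor plus the array context's communicator instead of inspecting BTAG_MULTIVOL_PARTITION tags. Intended to run under MPI (e.g. mpiexec -n 2); dcoll and u stand for a distributed collection and a volume DOFArray whose construction is elided, and attaching mpi_communicator by hand is a hypothetical stand-in for a real driver's array-context setup.

    from mpi4py import MPI
    from grudge.trace_pair import cross_rank_trace_pairs

    comm = MPI.COMM_WORLD

    # make_discretization_collection looks the communicator up via
    # getattr(array_context, "mpi_communicator", None), so it must be
    # attached to the context before the collection is built.
    actx.mpi_communicator = comm  # hypothetical wiring, for illustration

    dcoll = ...  # distributed DiscretizationCollection (construction elided)
    u = ...      # a volume DOFArray on dcoll

    # One TracePair per neighboring rank; sends/receives for all of them
    # are started before any of the results are finalized.
    for tp in cross_rank_trace_pairs(dcoll, u):
        flux = 0.5*(tp.int + tp.ext)  # e.g. a central flux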