Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions src/underworld3/discretisation/discretisation_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2481,6 +2481,9 @@ def write_timestep(
mesh_file = output_base_name + f".mesh.{index:05}.h5"
self.write(mesh_file)

if create_xdmf:
_write_mesh_viz_groups(self, mesh_file)

if meshVars is not None:
for var in meshVars:
save_location = output_base_name + f".mesh.{var.clean_name}.{index:05}.h5"
Expand Down Expand Up @@ -4101,6 +4104,116 @@ def mesh_update_callback(array, change_context):
## Simplified to allow us to decide how we want to checkpoint


def _petsc_numbering_to_global_ids(numbering):
"""Convert PETSc numbering entries to non-negative global ids."""

gids = numpy.asarray(numbering, dtype=numpy.int64).copy()
negative = gids < 0
gids[negative] = -gids[negative] - 1
return gids


def _local_viz_cell_connectivity(mesh):
"""Return local cell-to-vertex connectivity in global vertex ids."""

dm = mesh.dm
pStart, pEnd = dm.getDepthStratum(0)
cStart, cEnd = dm.getHeightStratum(0)
vertex_numbering = dm.getVertexNumbering().getIndices()
vertex_gids = _petsc_numbering_to_global_ids(vertex_numbering)
cell_num_points = mesh.element.entities[mesh.dim]

cell_points_list = []
for cell_id in range(cStart, cEnd):
closure = dm.getTransitiveClosure(cell_id)[0]
# Filter closure to strictly retain true vertices
points = numpy.asarray(
[p for p in closure if pStart <= p < pEnd],
dtype=numpy.int64,
)
if len(points) != cell_num_points:
raise RuntimeError(f"Expected {cell_num_points} vertices for cell {cell_id}, got {len(points)}.")
cell_points_list.append(vertex_gids[points - pStart])

if not cell_points_list:
return numpy.empty((0, cell_num_points), dtype=numpy.int64)

if mesh.dim == 3:
if dm.isSimplex():
reorder = [0, 2, 1, 3]
else:
reorder = [0, 3, 2, 1, 4, 5, 6, 7]
cell_points_list = [pts[reorder] for pts in cell_points_list]
Comment on lines +4141 to +4146

return numpy.asarray(cell_points_list, dtype=numpy.int64)


def _write_mesh_viz_groups(mesh, mesh_h5_path):
"""Write ParaView-safe ``/viz`` geometry/topology groups into a mesh HDF5."""

import underworld3 as uw
dm = mesh.dm
pStart, pEnd = dm.getDepthStratum(0)
vertex_numbering = dm.getVertexNumbering().getIndices()
vertex_gids = _petsc_numbering_to_global_ids(vertex_numbering)

coords_local = numpy.asarray(dm.getCoordinatesLocal().array, dtype=numpy.float64).reshape(-1, mesh.dim)
n_local_vertices = pEnd - pStart
if coords_local.shape[0] != n_local_vertices:
coords_local = numpy.asarray(mesh.X.coords, dtype=numpy.float64)
if coords_local.shape[0] != n_local_vertices:
raise RuntimeError(
f"Could not match local coordinate rows ({coords_local.shape[0]}) "
f"to DMPlex vertex count ({n_local_vertices}) for {mesh_h5_path}."
)

local_cells = _local_viz_cell_connectivity(mesh)

# Gather GIDs and coordinates separately to prevent float64 upcasting of integer GIDs
gathered_gids = uw.mpi.comm.gather(vertex_gids, root=0)
gathered_coords = uw.mpi.comm.gather(coords_local, root=0)
gathered_cells = uw.mpi.comm.gather(local_cells, root=0)
uw.mpi.barrier()

if uw.mpi.rank == 0:
import h5py

gid_blocks = [block for block in gathered_gids if block is not None and block.size > 0]
coord_blocks = [block for block in gathered_coords if block is not None and block.size > 0]
cell_blocks = [block for block in gathered_cells if block is not None and block.size > 0]

if gid_blocks:
all_gids = numpy.concatenate(gid_blocks)
all_coords = numpy.vstack(coord_blocks)

# Vectorized deduplication (numpy.unique returns sorted unique elements)
ordered_gids, unique_indices = numpy.unique(all_gids, return_index=True)
ordered_vertices = all_coords[unique_indices]
else:
ordered_gids = numpy.empty((0,), dtype=numpy.int64)
ordered_vertices = numpy.empty((0, mesh.dim), dtype=numpy.float64)

if cell_blocks:
all_cells = numpy.vstack(cell_blocks)
# Vectorized remapping using searchsorted (O(N log M) instead of Python dict lookup)
dense_cells = numpy.searchsorted(ordered_gids, all_cells)
else:
all_cells = numpy.empty((0, mesh.element.entities[mesh.dim]), dtype=numpy.int64)
dense_cells = numpy.empty_like(all_cells)

with h5py.File(mesh_h5_path, "a") as h5:
if "viz" in h5:
del h5["viz"]
viz = h5.create_group("viz")
geom = viz.create_group("geometry")
topo = viz.create_group("topology")
geom.create_dataset("vertices", data=ordered_vertices)
topo_cells = topo.create_dataset("cells", data=dense_cells)
topo_cells.attrs["cell_dim"] = mesh.dim

uw.mpi.barrier()


def _write_compat_groups(mesh, var, var_h5_path):
"""Write ``/vertex_fields/`` or ``/cell_fields/`` compatibility groups.

Expand Down Expand Up @@ -4161,6 +4274,7 @@ def checkpoint_xdmf(
):
import h5py
import os
import warnings

"""Create xdmf file for checkpoints"""

Expand Down Expand Up @@ -4197,19 +4311,42 @@ def checkpoint_xdmf(
numCorners = cells.shape[1]
cellDim = topo["cells"].attrs["cell_dim"]

if topoPath == "topology":
warnings.warn(
"Using raw '/topology/cells' for XDMF. This may not be Paraview-compatible. "
"Expected '/viz/topology/cells'.",
stacklevel=2,
)
Comment on lines +4315 to +4319

cells_data = cells[...]
c_min, c_max = cells_data.min(), cells_data.max()
if c_min < 0 or c_max >= numVertices:
warnings.warn(
f"XDMF connectivity is invalid! cells max {c_max} >= "
f"numVertices {numVertices} or min {c_min} < 0. ParaView will likely crash. "
f"Ensure cell-to-vertex connectivity is written.",
stacklevel=2,
)

h5.close()

# We only use a subset of the possible cell types
if spaceDim == 2:
if numCorners == 3:
topology_type = "Triangle"
elif numCorners == 4:
topology_type = "Quadrilateral"
else:
warnings.warn(f"Unexpected numCorners={numCorners} for 2D spaceDim. Expected 3 or 4.", stacklevel=2)
topology_type = "Quadrilateral"
geomType = "XY"
else:
if numCorners == 4:
topology_type = "Tetrahedron"
elif numCorners == 8:
topology_type = "Hexahedron"
else:
warnings.warn(f"Unexpected numCorners={numCorners} for 3D spaceDim. Expected 4 or 8.", stacklevel=2)
topology_type = "Hexahedron"
geomType = "XYZ"

Expand Down Expand Up @@ -4301,6 +4438,19 @@ def get_field_info(h5_filename, mesh_var, center):
center = "Node"
numItems, numComponents, dataset_path = get_field_info(var_filename, var, center)

if center == "Node" and numItems != numVertices:
warnings.warn(
f"Attribute '{var.clean_name}' Center is 'Node' but numItems "
f"({numItems}) != numVertices ({numVertices}).",
stacklevel=2,
)
elif center == "Cell" and numItems != numCells:
warnings.warn(
f"Attribute '{var.clean_name}' Center is 'Cell' but numItems "
f"({numItems}) != numCells ({numCells}).",
stacklevel=2,
)

# Use variable type when available, but reflect actual stored component count.
if hasattr(var, "vtype") and var.vtype in (
uw.VarType.TENSOR,
Expand Down
76 changes: 76 additions & 0 deletions tests/test_0005_xdmf_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,79 @@ def test_tensor_variable_repacking(tmp_path):
_check_xdmf_refs(xdmf_file, str(tmp_path))

del mesh


# ---------------------------------------------------------------------------
# Test: Valid /viz/topology connectivity generation
# ---------------------------------------------------------------------------


def test_xdmf_viz_topology_written_correctly(tmp_path):
"""Verify that write_timestep correctly writes /viz/topology/cells and XDMF points to it."""

mesh = uw.meshing.StructuredQuadBox(elementRes=(3, 3))

# Write just the mesh (no vars needed to test mesh topology)
mesh.write_timestep("test_topo", index=0, outputPath=str(tmp_path))

mesh_h5 = os.path.join(str(tmp_path), "test_topo.mesh.00000.h5")
assert _check_h5_group_exists(mesh_h5, "viz/topology/cells"), (
"Mesh HDF5 must contain the /viz/topology/cells dataset"
)
assert _check_h5_group_exists(mesh_h5, "viz/geometry/vertices"), (
"Mesh HDF5 must contain the /viz/geometry/vertices dataset"
)

# Validate cell-to-vertex connectivity bounds
cells = _read_h5_dataset(mesh_h5, "viz/topology/cells")
vertices = _read_h5_dataset(mesh_h5, "viz/geometry/vertices")

num_vertices = vertices.shape[0]
assert cells.max() < num_vertices, (
f"Invalid topology: cells max ({cells.max()}) must be < numVertices ({num_vertices})"
)
assert cells.min() >= 0, "Invalid topology: cells indices cannot be negative"

# Validate XDMF actually points to the viz group, not the raw DMPlex group
xdmf_file = os.path.join(str(tmp_path), "test_topo.mesh.00000.xdmf")
assert os.path.exists(xdmf_file), "XDMF file should exist"

with open(xdmf_file, "r") as f:
xdmf_content = f.read()

assert "/viz/topology/cells" in xdmf_content, (
"XDMF file should explicitly point to /viz/topology/cells"
)
assert "/viz/geometry/vertices" in xdmf_content, (
"XDMF file should explicitly point to /viz/geometry/vertices"
)

del mesh


def test_xdmf_viz_topology_3d_written_correctly(tmp_path):
"""Verify that write_timestep correctly writes /viz/topology/cells for 3D meshes."""

mesh = uw.meshing.StructuredQuadBox(elementRes=(2, 2, 2))

# Write just the mesh
mesh.write_timestep("test_topo_3d", index=0, outputPath=str(tmp_path))

mesh_h5 = os.path.join(str(tmp_path), "test_topo_3d.mesh.00000.h5")
assert _check_h5_group_exists(mesh_h5, "viz/topology/cells"), (
"3D Mesh HDF5 must contain the /viz/topology/cells dataset"
)

# Validate connectivity
cells = _read_h5_dataset(mesh_h5, "viz/topology/cells")
vertices = _read_h5_dataset(mesh_h5, "viz/geometry/vertices")

num_vertices = vertices.shape[0]
assert cells.max() < num_vertices, "Invalid 3D topology bounds"

xdmf_file = os.path.join(str(tmp_path), "test_topo_3d.mesh.00000.xdmf")
assert os.path.exists(xdmf_file)
with open(xdmf_file, "r") as f:
assert "/viz/topology/cells" in f.read()

del mesh