diff --git a/src/underworld3/discretisation/discretisation_mesh.py b/src/underworld3/discretisation/discretisation_mesh.py index b2527343..433b830b 100644 --- a/src/underworld3/discretisation/discretisation_mesh.py +++ b/src/underworld3/discretisation/discretisation_mesh.py @@ -2481,6 +2481,9 @@ def write_timestep( mesh_file = output_base_name + f".mesh.{index:05}.h5" self.write(mesh_file) + if create_xdmf: + _write_mesh_viz_groups(self, mesh_file) + if meshVars is not None: for var in meshVars: save_location = output_base_name + f".mesh.{var.clean_name}.{index:05}.h5" @@ -4101,6 +4104,116 @@ def mesh_update_callback(array, change_context): ## Simplified to allow us to decide how we want to checkpoint +def _petsc_numbering_to_global_ids(numbering): + """Convert PETSc numbering entries to non-negative global ids.""" + + gids = numpy.asarray(numbering, dtype=numpy.int64).copy() + negative = gids < 0 + gids[negative] = -gids[negative] - 1 + return gids + + +def _local_viz_cell_connectivity(mesh): + """Return local cell-to-vertex connectivity in global vertex ids.""" + + dm = mesh.dm + pStart, pEnd = dm.getDepthStratum(0) + cStart, cEnd = dm.getHeightStratum(0) + vertex_numbering = dm.getVertexNumbering().getIndices() + vertex_gids = _petsc_numbering_to_global_ids(vertex_numbering) + cell_num_points = mesh.element.entities[mesh.dim] + + cell_points_list = [] + for cell_id in range(cStart, cEnd): + closure = dm.getTransitiveClosure(cell_id)[0] + # Filter closure to strictly retain true vertices + points = numpy.asarray( + [p for p in closure if pStart <= p < pEnd], + dtype=numpy.int64, + ) + if len(points) != cell_num_points: + raise RuntimeError(f"Expected {cell_num_points} vertices for cell {cell_id}, got {len(points)}.") + cell_points_list.append(vertex_gids[points - pStart]) + + if not cell_points_list: + return numpy.empty((0, cell_num_points), dtype=numpy.int64) + + if mesh.dim == 3: + if dm.isSimplex(): + reorder = [0, 2, 1, 3] + else: + reorder = [0, 3, 2, 1, 4, 5, 6, 7] + cell_points_list = [pts[reorder] for pts in cell_points_list] + + return numpy.asarray(cell_points_list, dtype=numpy.int64) + + +def _write_mesh_viz_groups(mesh, mesh_h5_path): + """Write ParaView-safe ``/viz`` geometry/topology groups into a mesh HDF5.""" + + import underworld3 as uw + dm = mesh.dm + pStart, pEnd = dm.getDepthStratum(0) + vertex_numbering = dm.getVertexNumbering().getIndices() + vertex_gids = _petsc_numbering_to_global_ids(vertex_numbering) + + coords_local = numpy.asarray(dm.getCoordinatesLocal().array, dtype=numpy.float64).reshape(-1, mesh.dim) + n_local_vertices = pEnd - pStart + if coords_local.shape[0] != n_local_vertices: + coords_local = numpy.asarray(mesh.X.coords, dtype=numpy.float64) + if coords_local.shape[0] != n_local_vertices: + raise RuntimeError( + f"Could not match local coordinate rows ({coords_local.shape[0]}) " + f"to DMPlex vertex count ({n_local_vertices}) for {mesh_h5_path}." + ) + + local_cells = _local_viz_cell_connectivity(mesh) + + # Gather GIDs and coordinates separately to prevent float64 upcasting of integer GIDs + gathered_gids = uw.mpi.comm.gather(vertex_gids, root=0) + gathered_coords = uw.mpi.comm.gather(coords_local, root=0) + gathered_cells = uw.mpi.comm.gather(local_cells, root=0) + uw.mpi.barrier() + + if uw.mpi.rank == 0: + import h5py + + gid_blocks = [block for block in gathered_gids if block is not None and block.size > 0] + coord_blocks = [block for block in gathered_coords if block is not None and block.size > 0] + cell_blocks = [block for block in gathered_cells if block is not None and block.size > 0] + + if gid_blocks: + all_gids = numpy.concatenate(gid_blocks) + all_coords = numpy.vstack(coord_blocks) + + # Vectorized deduplication (numpy.unique returns sorted unique elements) + ordered_gids, unique_indices = numpy.unique(all_gids, return_index=True) + ordered_vertices = all_coords[unique_indices] + else: + ordered_gids = numpy.empty((0,), dtype=numpy.int64) + ordered_vertices = numpy.empty((0, mesh.dim), dtype=numpy.float64) + + if cell_blocks: + all_cells = numpy.vstack(cell_blocks) + # Vectorized remapping using searchsorted (O(N log M) instead of Python dict lookup) + dense_cells = numpy.searchsorted(ordered_gids, all_cells) + else: + all_cells = numpy.empty((0, mesh.element.entities[mesh.dim]), dtype=numpy.int64) + dense_cells = numpy.empty_like(all_cells) + + with h5py.File(mesh_h5_path, "a") as h5: + if "viz" in h5: + del h5["viz"] + viz = h5.create_group("viz") + geom = viz.create_group("geometry") + topo = viz.create_group("topology") + geom.create_dataset("vertices", data=ordered_vertices) + topo_cells = topo.create_dataset("cells", data=dense_cells) + topo_cells.attrs["cell_dim"] = mesh.dim + + uw.mpi.barrier() + + def _write_compat_groups(mesh, var, var_h5_path): """Write ``/vertex_fields/`` or ``/cell_fields/`` compatibility groups. @@ -4161,6 +4274,7 @@ def checkpoint_xdmf( ): import h5py import os + import warnings """Create xdmf file for checkpoints""" @@ -4197,19 +4311,42 @@ def checkpoint_xdmf( numCorners = cells.shape[1] cellDim = topo["cells"].attrs["cell_dim"] + if topoPath == "topology": + warnings.warn( + "Using raw '/topology/cells' for XDMF. This may not be Paraview-compatible. " + "Expected '/viz/topology/cells'.", + stacklevel=2, + ) + + cells_data = cells[...] + c_min, c_max = cells_data.min(), cells_data.max() + if c_min < 0 or c_max >= numVertices: + warnings.warn( + f"XDMF connectivity is invalid! cells max {c_max} >= " + f"numVertices {numVertices} or min {c_min} < 0. ParaView will likely crash. " + f"Ensure cell-to-vertex connectivity is written.", + stacklevel=2, + ) + h5.close() # We only use a subset of the possible cell types if spaceDim == 2: if numCorners == 3: topology_type = "Triangle" + elif numCorners == 4: + topology_type = "Quadrilateral" else: + warnings.warn(f"Unexpected numCorners={numCorners} for 2D spaceDim. Expected 3 or 4.", stacklevel=2) topology_type = "Quadrilateral" geomType = "XY" else: if numCorners == 4: topology_type = "Tetrahedron" + elif numCorners == 8: + topology_type = "Hexahedron" else: + warnings.warn(f"Unexpected numCorners={numCorners} for 3D spaceDim. Expected 4 or 8.", stacklevel=2) topology_type = "Hexahedron" geomType = "XYZ" @@ -4301,6 +4438,19 @@ def get_field_info(h5_filename, mesh_var, center): center = "Node" numItems, numComponents, dataset_path = get_field_info(var_filename, var, center) + if center == "Node" and numItems != numVertices: + warnings.warn( + f"Attribute '{var.clean_name}' Center is 'Node' but numItems " + f"({numItems}) != numVertices ({numVertices}).", + stacklevel=2, + ) + elif center == "Cell" and numItems != numCells: + warnings.warn( + f"Attribute '{var.clean_name}' Center is 'Cell' but numItems " + f"({numItems}) != numCells ({numCells}).", + stacklevel=2, + ) + # Use variable type when available, but reflect actual stored component count. if hasattr(var, "vtype") and var.vtype in ( uw.VarType.TENSOR, diff --git a/tests/test_0005_xdmf_compat.py b/tests/test_0005_xdmf_compat.py index 72ba4439..d0222750 100644 --- a/tests/test_0005_xdmf_compat.py +++ b/tests/test_0005_xdmf_compat.py @@ -288,3 +288,79 @@ def test_tensor_variable_repacking(tmp_path): _check_xdmf_refs(xdmf_file, str(tmp_path)) del mesh + + +# --------------------------------------------------------------------------- +# Test: Valid /viz/topology connectivity generation +# --------------------------------------------------------------------------- + + +def test_xdmf_viz_topology_written_correctly(tmp_path): + """Verify that write_timestep correctly writes /viz/topology/cells and XDMF points to it.""" + + mesh = uw.meshing.StructuredQuadBox(elementRes=(3, 3)) + + # Write just the mesh (no vars needed to test mesh topology) + mesh.write_timestep("test_topo", index=0, outputPath=str(tmp_path)) + + mesh_h5 = os.path.join(str(tmp_path), "test_topo.mesh.00000.h5") + assert _check_h5_group_exists(mesh_h5, "viz/topology/cells"), ( + "Mesh HDF5 must contain the /viz/topology/cells dataset" + ) + assert _check_h5_group_exists(mesh_h5, "viz/geometry/vertices"), ( + "Mesh HDF5 must contain the /viz/geometry/vertices dataset" + ) + + # Validate cell-to-vertex connectivity bounds + cells = _read_h5_dataset(mesh_h5, "viz/topology/cells") + vertices = _read_h5_dataset(mesh_h5, "viz/geometry/vertices") + + num_vertices = vertices.shape[0] + assert cells.max() < num_vertices, ( + f"Invalid topology: cells max ({cells.max()}) must be < numVertices ({num_vertices})" + ) + assert cells.min() >= 0, "Invalid topology: cells indices cannot be negative" + + # Validate XDMF actually points to the viz group, not the raw DMPlex group + xdmf_file = os.path.join(str(tmp_path), "test_topo.mesh.00000.xdmf") + assert os.path.exists(xdmf_file), "XDMF file should exist" + + with open(xdmf_file, "r") as f: + xdmf_content = f.read() + + assert "/viz/topology/cells" in xdmf_content, ( + "XDMF file should explicitly point to /viz/topology/cells" + ) + assert "/viz/geometry/vertices" in xdmf_content, ( + "XDMF file should explicitly point to /viz/geometry/vertices" + ) + + del mesh + + +def test_xdmf_viz_topology_3d_written_correctly(tmp_path): + """Verify that write_timestep correctly writes /viz/topology/cells for 3D meshes.""" + + mesh = uw.meshing.StructuredQuadBox(elementRes=(2, 2, 2)) + + # Write just the mesh + mesh.write_timestep("test_topo_3d", index=0, outputPath=str(tmp_path)) + + mesh_h5 = os.path.join(str(tmp_path), "test_topo_3d.mesh.00000.h5") + assert _check_h5_group_exists(mesh_h5, "viz/topology/cells"), ( + "3D Mesh HDF5 must contain the /viz/topology/cells dataset" + ) + + # Validate connectivity + cells = _read_h5_dataset(mesh_h5, "viz/topology/cells") + vertices = _read_h5_dataset(mesh_h5, "viz/geometry/vertices") + + num_vertices = vertices.shape[0] + assert cells.max() < num_vertices, "Invalid 3D topology bounds" + + xdmf_file = os.path.join(str(tmp_path), "test_topo_3d.mesh.00000.xdmf") + assert os.path.exists(xdmf_file) + with open(xdmf_file, "r") as f: + assert "/viz/topology/cells" in f.read() + + del mesh