diff --git a/meshmode/mesh/io.py b/meshmode/mesh/io.py index 4eb52211..2912db58 100644 --- a/meshmode/mesh/io.py +++ b/meshmode/mesh/io.py @@ -20,6 +20,8 @@ THE SOFTWARE. """ +from typing import Tuple, Optional, List + import numpy as np from gmsh_interop.reader import ( # noqa: F401 @@ -47,6 +49,8 @@ # {{{ gmsh receiver class GmshMeshReceiver(GmshMeshReceiverBase): + tags: Optional[List[Tuple[int, int]]] + def __init__(self, mesh_construction_kwargs=None): # Use data fields similar to meshpy.triangle.MeshInfo and # meshpy.tet.MeshInfo @@ -82,6 +86,8 @@ def set_up_elements(self, count): self.element_nodes = [None] * count self.element_types = [None] * count self.element_markers = [None] * count + self.tags = [] + self.gmsh_tag_index_to_mine = {} def add_element(self, element_nr, element_type, vertex_nrs, lexicographic_nodes, tag_numbers): @@ -98,10 +104,6 @@ def finalize_elements(self): # May raise ValueError if try to add different tags with the same name def add_tag(self, name, index, dimension): - if self.tags is None: - self.tags = [] - if self.gmsh_tag_index_to_mine is None: - self.gmsh_tag_index_to_mine = {} # add tag if new if index not in self.gmsh_tag_index_to_mine: self.gmsh_tag_index_to_mine[index] = len(self.tags) diff --git a/test/test_mesh.py b/test/test_mesh.py index 6d5014ec..3887fa2e 100644 --- a/test/test_mesh.py +++ b/test/test_mesh.py @@ -1481,6 +1481,27 @@ def separated(x, y): # }}} +def test_gmsh_tag_reading(): + from meshmode.mesh.io import generate_gmsh, ScriptSource + generate_gmsh(ScriptSource(""" + h = 1; // Characteristic length of a mesh element + Point(1) = {0, 0, 0, h}; // Point construction + Point(2) = {10, 0, 0, h}; + Point(3) = {10, 10, 0, h}; + Point(4) = {0, 10, 0, h}; + Line(1) = {1,2}; //Lines + Line(2) = {2,3}; + Line(3) = {3,4}; + Line(4) = {4,1}; + Curve Loop(1) = {1,2,3,4}; // A Boundary + Plane Surface(1) = {1}; // A Surface + Physical Surface(1) = {1}; // Setting a label to the Surface + + Mesh 2; + RecombineMesh; + """, ".geo"), dimensions=2) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: