diff --git a/finat/ufl/brokenelement.py b/finat/ufl/brokenelement.py index 3e820288..2db1748e 100644 --- a/finat/ufl/brokenelement.py +++ b/finat/ufl/brokenelement.py @@ -10,11 +10,30 @@ # Modified by Matthew Scroggs, 2023 from finat.ufl.finiteelementbase import FiniteElementBase +from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement from ufl.sobolevspace import L2 class BrokenElement(FiniteElementBase): """The discontinuous version of an existing Finite Element space.""" + def __new__(cls, element): + """ + Broken qualifier must be below Mixed/Vector/Tensor so we + overload __new__ to return: + + BrokenElement(MixedElement(elem0, elem1)) -> MixedElement(BrokenElement(elem0), BrokenElement(elem1)) + + and similarly for VectorElement and TensorElement. + """ + if isinstance(element, (VectorElement, TensorElement)): + return element.reconstruct(sub_element=BrokenElement(element.sub_elements[0])) + + elif isinstance(element, MixedElement): + return MixedElement(list(map(BrokenElement, element.sub_elements))) + + else: # hopefully no special casing needed + return super().__new__(cls) + def __init__(self, element): """Init.""" self._element = element diff --git a/finat/ufl/mixedelement.py b/finat/ufl/mixedelement.py index 20c25353..fb337812 100644 --- a/finat/ufl/mixedelement.py +++ b/finat/ufl/mixedelement.py @@ -345,9 +345,10 @@ def __repr__(self): """Doc.""" return self._repr - def reconstruct(self, **kwargs): + def reconstruct(self, sub_element=None, **kwargs): """Doc.""" - sub_element = self._sub_element.reconstruct(**kwargs) + if sub_element is None: + sub_element = self._sub_element.reconstruct(**kwargs) return VectorElement(sub_element, dim=len(self.sub_elements)) def variant(self): @@ -544,9 +545,10 @@ def symmetry(self): """ return self._symmetry - def reconstruct(self, **kwargs): + def reconstruct(self, sub_element=None, **kwargs): """Doc.""" - sub_element = self._sub_element.reconstruct(**kwargs) + if sub_element is None: + sub_element = self._sub_element.reconstruct(**kwargs) return TensorElement(sub_element, shape=self._shape, symmetry=self._symmetry) def __str__(self): diff --git a/test/finat/test_create_broken_element.py b/test/finat/test_create_broken_element.py new file mode 100644 index 00000000..6b6a04ed --- /dev/null +++ b/test/finat/test_create_broken_element.py @@ -0,0 +1,49 @@ +import pytest +import ufl +from finat.ufl import FiniteElement, BrokenElement, VectorElement, TensorElement, MixedElement + +sub_elements = [ + FiniteElement("CG", ufl.triangle, 1), + FiniteElement("BDM", ufl.triangle, 2), + FiniteElement("DG", ufl.interval, 2, variant="spectral") +] + +sub_ids = [ + "CG(1)", + "BDM(2)", + "DG(2,spectral)" +] + + +@pytest.mark.parametrize("sub_element", sub_elements, ids=sub_ids) +@pytest.mark.parametrize("shape", (1, 2, (2, 3)), ids=("1", "2", "(2,3)")) +def test_create_broken_vector_or_tensor_element(shape, sub_element): + """Check that BrokenElement returns a nested element + for mixed, vector, and tensor elements. + """ + if not isinstance(shape, int): + make_element = lambda elem: TensorElement(elem, shape=shape) + else: + make_element = lambda elem: VectorElement(elem, dim=shape) + + tensor = make_element(sub_element) + expected = make_element(BrokenElement(sub_element)) + + assert BrokenElement(tensor) == expected + + +@pytest.mark.parametrize("sub_elements", [sub_elements, sub_elements[-1:]], + ids=(f"nsubs={len(sub_elements)}", "nsubs=1")) +def test_create_broken_mixed_element(sub_elements): + """Check that BrokenElement returns a nested element + for mixed, vector, and tensor elements. + """ + mixed = MixedElement(sub_elements) + expected = MixedElement([BrokenElement(elem) for elem in sub_elements]) + assert BrokenElement(mixed) == expected + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])