From cdcf6999d87c30e1495f25a320405ad8f6def40d Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 9 Dec 2025 19:40:51 +0000 Subject: [PATCH 1/3] overload BrokenElement.__new__ to nest Broken-ness below Vector/Tensor/Mixed-ness --- finat/ufl/brokenelement.py | 28 ++++++++++++++ test/finat/test_create_broken_element.py | 49 ++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 test/finat/test_create_broken_element.py diff --git a/finat/ufl/brokenelement.py b/finat/ufl/brokenelement.py index 3e8202883..77f4a2f75 100644 --- a/finat/ufl/brokenelement.py +++ b/finat/ufl/brokenelement.py @@ -10,11 +10,39 @@ # Modified by Matthew Scroggs, 2023 from finat.ufl.finiteelementbase import FiniteElementBase +from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement from ufl.sobolevspace import L2 class BrokenElement(FiniteElementBase): """The discontinuous version of an existing Finite Element space.""" + def __new__(cls, element): + """ + Broken qualifier must be below Mixed/Vector/Tensor so we + overload __new__ to return: + + BrokenElement(MixedElement(elem0, elem1)) -> MixedElement(BrokenElement(elem0), BrokenElement(elem1)) + + and similarly for VectorElement and TensorElement. + """ + if isinstance(element, VectorElement): + return VectorElement( + BrokenElement(element.sub_elements[0]), + dim=len(element.sub_elements)) + + elif isinstance(element, TensorElement): + return TensorElement( + BrokenElement(element.sub_elements[0]), + symmetry=element._symmetry, + shape=element._shape) + + elif isinstance(element, MixedElement): + return MixedElement( + [BrokenElement(elem) for elem in element.sub_elements]) + + else: # hopefully no special casing needed + return super().__new__(cls) + def __init__(self, element): """Init.""" self._element = element diff --git a/test/finat/test_create_broken_element.py b/test/finat/test_create_broken_element.py new file mode 100644 index 000000000..6b6a04eda --- /dev/null +++ b/test/finat/test_create_broken_element.py @@ -0,0 +1,49 @@ +import pytest +import ufl +from finat.ufl import FiniteElement, BrokenElement, VectorElement, TensorElement, MixedElement + +sub_elements = [ + FiniteElement("CG", ufl.triangle, 1), + FiniteElement("BDM", ufl.triangle, 2), + FiniteElement("DG", ufl.interval, 2, variant="spectral") +] + +sub_ids = [ + "CG(1)", + "BDM(2)", + "DG(2,spectral)" +] + + +@pytest.mark.parametrize("sub_element", sub_elements, ids=sub_ids) +@pytest.mark.parametrize("shape", (1, 2, (2, 3)), ids=("1", "2", "(2,3)")) +def test_create_broken_vector_or_tensor_element(shape, sub_element): + """Check that BrokenElement returns a nested element + for mixed, vector, and tensor elements. + """ + if not isinstance(shape, int): + make_element = lambda elem: TensorElement(elem, shape=shape) + else: + make_element = lambda elem: VectorElement(elem, dim=shape) + + tensor = make_element(sub_element) + expected = make_element(BrokenElement(sub_element)) + + assert BrokenElement(tensor) == expected + + +@pytest.mark.parametrize("sub_elements", [sub_elements, sub_elements[-1:]], + ids=(f"nsubs={len(sub_elements)}", "nsubs=1")) +def test_create_broken_mixed_element(sub_elements): + """Check that BrokenElement returns a nested element + for mixed, vector, and tensor elements. + """ + mixed = MixedElement(sub_elements) + expected = MixedElement([BrokenElement(elem) for elem in sub_elements]) + assert BrokenElement(mixed) == expected + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]) From a9cd9c54d06f6064d43b0792e751440b8ae8b3be Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 11 Dec 2025 11:05:31 +0000 Subject: [PATCH 2/3] Update finat/ufl/brokenelement.py Co-authored-by: Pablo Brubeck --- finat/ufl/brokenelement.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/finat/ufl/brokenelement.py b/finat/ufl/brokenelement.py index 77f4a2f75..d231e8b9a 100644 --- a/finat/ufl/brokenelement.py +++ b/finat/ufl/brokenelement.py @@ -37,8 +37,7 @@ def __new__(cls, element): shape=element._shape) elif isinstance(element, MixedElement): - return MixedElement( - [BrokenElement(elem) for elem in element.sub_elements]) + return MixedElement(list(map(BrokenElement, element.sub_elements))) else: # hopefully no special casing needed return super().__new__(cls) From 230b5e2e1409e7a047946d4688173985ab9f6759 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 11 Dec 2025 11:17:20 +0000 Subject: [PATCH 3/3] optional sub_element reconstruct kwarg to Vector and Tensor elements --- finat/ufl/brokenelement.py | 12 ++---------- finat/ufl/mixedelement.py | 10 ++++++---- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/finat/ufl/brokenelement.py b/finat/ufl/brokenelement.py index d231e8b9a..2db1748e5 100644 --- a/finat/ufl/brokenelement.py +++ b/finat/ufl/brokenelement.py @@ -25,16 +25,8 @@ def __new__(cls, element): and similarly for VectorElement and TensorElement. """ - if isinstance(element, VectorElement): - return VectorElement( - BrokenElement(element.sub_elements[0]), - dim=len(element.sub_elements)) - - elif isinstance(element, TensorElement): - return TensorElement( - BrokenElement(element.sub_elements[0]), - symmetry=element._symmetry, - shape=element._shape) + if isinstance(element, (VectorElement, TensorElement)): + return element.reconstruct(sub_element=BrokenElement(element.sub_elements[0])) elif isinstance(element, MixedElement): return MixedElement(list(map(BrokenElement, element.sub_elements))) diff --git a/finat/ufl/mixedelement.py b/finat/ufl/mixedelement.py index 20c253535..fb3378128 100644 --- a/finat/ufl/mixedelement.py +++ b/finat/ufl/mixedelement.py @@ -345,9 +345,10 @@ def __repr__(self): """Doc.""" return self._repr - def reconstruct(self, **kwargs): + def reconstruct(self, sub_element=None, **kwargs): """Doc.""" - sub_element = self._sub_element.reconstruct(**kwargs) + if sub_element is None: + sub_element = self._sub_element.reconstruct(**kwargs) return VectorElement(sub_element, dim=len(self.sub_elements)) def variant(self): @@ -544,9 +545,10 @@ def symmetry(self): """ return self._symmetry - def reconstruct(self, **kwargs): + def reconstruct(self, sub_element=None, **kwargs): """Doc.""" - sub_element = self._sub_element.reconstruct(**kwargs) + if sub_element is None: + sub_element = self._sub_element.reconstruct(**kwargs) return TensorElement(sub_element, shape=self._shape, symmetry=self._symmetry) def __str__(self):