Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions finat/ufl/brokenelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,30 @@
# Modified by Matthew Scroggs, 2023

from finat.ufl.finiteelementbase import FiniteElementBase
from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement
from ufl.sobolevspace import L2


class BrokenElement(FiniteElementBase):
"""The discontinuous version of an existing Finite Element space."""
def __new__(cls, element):
"""
Broken qualifier must be below Mixed/Vector/Tensor so we
overload __new__ to return:

BrokenElement(MixedElement(elem0, elem1)) -> MixedElement(BrokenElement(elem0), BrokenElement(elem1))

and similarly for VectorElement and TensorElement.
"""
if isinstance(element, (VectorElement, TensorElement)):
return element.reconstruct(sub_element=BrokenElement(element.sub_elements[0]))

elif isinstance(element, MixedElement):
return MixedElement(list(map(BrokenElement, element.sub_elements)))

else: # hopefully no special casing needed
return super().__new__(cls)

def __init__(self, element):
"""Init."""
self._element = element
Expand Down
10 changes: 6 additions & 4 deletions finat/ufl/mixedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,10 @@ def __repr__(self):
"""Doc."""
return self._repr

def reconstruct(self, **kwargs):
def reconstruct(self, sub_element=None, **kwargs):
"""Doc."""
sub_element = self._sub_element.reconstruct(**kwargs)
if sub_element is None:
sub_element = self._sub_element.reconstruct(**kwargs)
return VectorElement(sub_element, dim=len(self.sub_elements))

def variant(self):
Expand Down Expand Up @@ -544,9 +545,10 @@ def symmetry(self):
"""
return self._symmetry

def reconstruct(self, **kwargs):
def reconstruct(self, sub_element=None, **kwargs):
"""Doc."""
sub_element = self._sub_element.reconstruct(**kwargs)
if sub_element is None:
sub_element = self._sub_element.reconstruct(**kwargs)
return TensorElement(sub_element, shape=self._shape, symmetry=self._symmetry)

def __str__(self):
Expand Down
49 changes: 49 additions & 0 deletions test/finat/test_create_broken_element.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import ufl
from finat.ufl import FiniteElement, BrokenElement, VectorElement, TensorElement, MixedElement

sub_elements = [
FiniteElement("CG", ufl.triangle, 1),
FiniteElement("BDM", ufl.triangle, 2),
FiniteElement("DG", ufl.interval, 2, variant="spectral")
]

sub_ids = [
"CG(1)",
"BDM(2)",
"DG(2,spectral)"
]


@pytest.mark.parametrize("sub_element", sub_elements, ids=sub_ids)
@pytest.mark.parametrize("shape", (1, 2, (2, 3)), ids=("1", "2", "(2,3)"))
def test_create_broken_vector_or_tensor_element(shape, sub_element):
"""Check that BrokenElement returns a nested element
for mixed, vector, and tensor elements.
"""
if not isinstance(shape, int):
make_element = lambda elem: TensorElement(elem, shape=shape)
else:
make_element = lambda elem: VectorElement(elem, dim=shape)

tensor = make_element(sub_element)
expected = make_element(BrokenElement(sub_element))

assert BrokenElement(tensor) == expected


@pytest.mark.parametrize("sub_elements", [sub_elements, sub_elements[-1:]],
ids=(f"nsubs={len(sub_elements)}", "nsubs=1"))
def test_create_broken_mixed_element(sub_elements):
"""Check that BrokenElement returns a nested element
for mixed, vector, and tensor elements.
"""
mixed = MixedElement(sub_elements)
expected = MixedElement([BrokenElement(elem) for elem in sub_elements])
assert BrokenElement(mixed) == expected


if __name__ == "__main__":
import os
import sys
pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])