From efbbe04482441f6973113511d084a1aa744f7b17 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 19 Dec 2025 18:42:28 +0000 Subject: [PATCH 1/3] remove code --- firedrake/interpolation.py | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index c8693ce6db..14c8b4a838 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -726,20 +726,8 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None) # We need to split the target space V and generate separate kernels if self.rank == 2: expressions = {(0,): self.ufl_interpolate} - elif isinstance(self.dual_arg, Coargument): - # Split in the coargument - expressions = dict(split_form(self.ufl_interpolate)) else: - assert isinstance(self.dual_arg, Cofunction) - # Split in the cofunction: split_form can only split in the coargument - # Replace the cofunction with a coargument to construct the Jacobian - interp = self.ufl_interpolate._ufl_expr_reconstruct_(self.operand, self.target_space) - # Split the Jacobian into blocks - interp_split = dict(split_form(interp)) - # Split the cofunction - dual_split = dict(split_form(self.dual_arg)) - # Combine the splits by taking their action - expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split} + expressions = dict(split_form(self.ufl_interpolate)) # Interpolate each sub expression into each function space for indices, sub_expr in expressions.items(): @@ -1649,14 +1637,6 @@ def _get_sub_interpolators( # See https://github.com/firedrakeproject/firedrake/issues/4668 space_equals = lambda V1, V2: V1 == V2 and V1.parent == V2.parent and V1.index == V2.index - # We need a Coargument in order to split the Interpolate - needs_action = not any(isinstance(a, Coargument) for a in self.interpolate_args) - if needs_action: - # Split the dual argument - dual_split = dict(split_form(self.dual_arg)) - # Create the Jacobian to be split into blocks - self.ufl_interpolate = self.ufl_interpolate._ufl_expr_reconstruct_(self.operand, self.target_space) - # Get sub-interpolators and sub-bcs for each block Isub: dict[tuple[int] | tuple[int, int], tuple[Interpolator, list[DirichletBC]]] = {} for indices, form in split_form(self.ufl_interpolate): @@ -1667,9 +1647,6 @@ def _get_sub_interpolators( for space, index in zip(spaces, indices): subspace = space.sub(index) sub_bcs.extend(bc for bc in bcs if space_equals(bc.function_space(), subspace)) - if needs_action: - # Take the action of each sub-cofunction against each block - form = action(form, dual_split[indices[-1:]]) Isub[indices] = (get_interpolator(form), sub_bcs) return Isub From 263b923fede87ed4ccf7aebb4d1bf908db17dbfc Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Fri, 19 Dec 2025 22:43:56 +0000 Subject: [PATCH 2/3] working --- firedrake/formmanipulation.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index eb830493e1..f5fbe2ee80 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -3,7 +3,7 @@ import collections from ufl import as_tensor, as_vector, split -from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm +from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm, Interpolate from ufl.algorithms.map_integrands import map_integrand_dags from ufl.algorithms import expand_derivatives from ufl.corealg.map_dag import MultiFunction, map_expr_dags @@ -71,6 +71,11 @@ def split(self, form, argument_indices): args = form.arguments() self._arg_cache = {} self.blocks = dict(enumerate(map(as_tuple, argument_indices))) + + if isinstance(form, Interpolate) and not args: + dual_arg, _ = form.argument_slots() + args = dual_arg.arguments() + if len(args) == 0: # Functional can't be split return form @@ -191,14 +196,14 @@ def interpolate(self, o, operand): return self(ZeroBaseForm(o.arguments())) dual_arg, _ = o.argument_slots() - if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1: - # The dual argument has been contracted or does not need to be split + dual_arguments = dual_arg.arguments() + if len(dual_arguments) == 1 and len(dual_arguments[0].function_space()) == 1: return o._ufl_expr_reconstruct_(operand, dual_arg) - if not isinstance(dual_arg, Coargument): + if not isinstance(dual_arg, Coargument | Cofunction): raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.") - indices = self.blocks[dual_arg.number()] + indices = self.blocks[dual_arguments[0].number()] V = dual_arg.function_space() # Split the target (dual) argument @@ -254,6 +259,11 @@ def split_form(form, diagonal=False): """ splitter = ExtractSubBlock() args = form.arguments() + + if isinstance(form, Interpolate) and not args: + dual_arg, _ = form.argument_slots() + args = dual_arg.arguments() + shape = tuple(len(a.function_space()) for a in args) forms = [] rank = len(shape) From c1273ad54647cb96d859e03f762076b95fe5448e Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Sat, 20 Dec 2025 01:13:35 +0000 Subject: [PATCH 3/3] lint --- firedrake/formmanipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index f5fbe2ee80..906877e9b6 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -263,7 +263,7 @@ def split_form(form, diagonal=False): if isinstance(form, Interpolate) and not args: dual_arg, _ = form.argument_slots() args = dual_arg.arguments() - + shape = tuple(len(a.function_space()) for a in args) forms = [] rank = len(shape)