From 10b7752aad6cfd91ac62046d6ac2b97bd8925ae2 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 12 Mar 2026 15:18:10 +0000 Subject: [PATCH 1/4] Ensemble.allgather --- firedrake/ensemble/ensemble.py | 67 ++++++++++++++++++++++- tests/firedrake/ensemble/test_ensemble.py | 47 ++++++++++++++++ 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/firedrake/ensemble/ensemble.py b/firedrake/ensemble/ensemble.py index 883b0e598a..aac4cc0c11 100644 --- a/firedrake/ensemble/ensemble.py +++ b/firedrake/ensemble/ensemble.py @@ -17,8 +17,14 @@ def _ensemble_mpi_dispatch(func): """ @wraps(func) def _mpi_dispatch(self, *args, **kwargs): - if any(isinstance(arg, (Function, Cofunction)) - for arg in [*args, *kwargs.values()]): + all_args = [] + for arg in [*args, *kwargs.values()]: + if isinstance(arg, (list, tuple)): + all_args.extend(arg) + else: + all_args.append(arg) + + if any(isinstance(arg, (Function, Cofunction)) for arg in all_args): return func(self, *args, **kwargs) else: mpicall = getattr(self._ensemble_comm, func.__name__) @@ -584,3 +590,60 @@ def isendrecv(self, fsend: Function | Cofunction, dest: int, sendtag: int = 0, requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag) for dat in frecv.dat]) return requests + + @PETSc.Log.EventDecorator() + @_ensemble_mpi_dispatch + def allgather(self, fsend: Function | Cofunction | list[Function | Cofunction], + frecv: list[Function | Cofunction], + recvcounts: list[int] | None = None + ) -> list[Function | Cofunction]: + """ + Allgather the :class:`.Function` (or ``list`` of ``Functions``) ``fsend`` + from all ranks into ``frecv`` on every rank. + + Parameters + ---------- + fsend : + The :class:`.Function` on the local rank to gather on all ranks. + frecv : + The list of :class:`.Function` that ``fsend`` from each ensemble + rank will be gathered into. + recvcounts : + A list with the number of :class:`.Function` in ``fsend`` on each + rank. 
If not provided then it will be calculated using an Allgather + before communicating the data in ``fsend``. + + Returns + ------- + list[Function | Cofunction] : + The gathered functions ``frecv``. + + Raises + ------ + ValueError + If ``recvcounts`` does not have :meth:`.Ensemble.ensemble_size` + elements. + ValueError + If ``sum(recvcounts)`` does not have + :meth:`.Ensemble.ensemble_size` elements. + """ + if not isinstance(fsend, (list, tuple)): + fsend = [fsend] + if recvcounts is None: + recvcounts = self.allgather(len(fsend)) + if len(recvcounts) != self.ensemble_size: + raise ValueError( + f"Ensemble has {self.ensemble_size} ranks but {len(recvcounts)} recvcounts provided.") + if sum(recvcounts) != len(frecv): + raise ValueError( + f"Need to receive {sum(recvcounts)} items but frecv has length {len(frecv)}") + + # loop over ensemble ranks and bcast local data to all other ranks. + idx = 0 + for root, nsend in enumerate(recvcounts): + for i in range(nsend): + if self.ensemble_rank == root: + frecv[idx].assign(fsend[i]) + self.bcast(frecv[idx], root=root) + idx += 1 + return frecv diff --git a/tests/firedrake/ensemble/test_ensemble.py b/tests/firedrake/ensemble/test_ensemble.py index 6f9fef11d0..37b233e715 100644 --- a/tests/firedrake/ensemble/test_ensemble.py +++ b/tests/firedrake/ensemble/test_ensemble.py @@ -351,6 +351,53 @@ def test_sendrecv(ensemble, mesh, W, urank, blocking): parallel_assert(errornorm(urecv, u_expect) < 1e-12) +@pytest.mark.parallel(nprocs=6) +@pytest.mark.parametrize("recv_counts", ["no_recvcounts", "recvcounts"]) +@pytest.mark.parametrize("distribution", ["balanced", "imbalanced"]) +def test_allgather(ensemble, mesh, recv_counts, distribution): + U = FunctionSpace(mesh, "CG", 1) + V = FunctionSpace(mesh, "DG", 2) + W = U*V + spaces = [U, W, V] + + if distribution == "balanced": + recvcounts = [2, 2, 2] + elif distribution == "imbalanced": + recvcounts = [1, 3, 2] + else: + raise ValueError(f"Unrecognised {distribution=}") + + rank = 
ensemble.ensemble_rank + + local_spaces = spaces[:recvcounts[rank]] + + global_spaces = [] + for root in range(ensemble.ensemble_size): + for i in range(recvcounts[root]): + global_spaces.append(spaces[i]) + + fsend = [Function(fs).assign(10*(rank+1) + i) + for i, fs in enumerate(local_spaces)] + + frecv = [Function(fs) for fs in global_spaces] + + if recv_counts == "no_recvcounts": + ensemble.allgather(fsend, frecv) + elif recv_counts == "recvcounts": + ensemble.allgather(fsend, frecv, recvcounts=recvcounts) + else: + raise ValueError(f"Unrecognised {recv_counts=}") + + idx = 0 + for root in range(ensemble.ensemble_size): + for i in range(recvcounts[root]): + fcheck = Function(spaces[i]).assign(10*(root+1) + i) + f = frecv[idx] + error = errornorm(fcheck, f)/norm(fcheck) + parallel_assert(error < 1e-12, msg=f"{root=} | {i=} | {error=}") + idx += 1 + + @pytest.mark.parallel(nprocs=6) def test_ensemble_solvers(ensemble, W, urank, urank_sum): """ From 248271fa087e8856bbdc733439fee72c758979b3 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 12 Mar 2026 15:33:22 +0000 Subject: [PATCH 2/4] small ensemble function adjoint_utils fixes --- firedrake/adjoint_utils/ensemble_function.py | 30 +++++++++----------- firedrake/adjoint_utils/function.py | 3 +- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/firedrake/adjoint_utils/ensemble_function.py b/firedrake/adjoint_utils/ensemble_function.py index 179d017285..df9a00e8a5 100644 --- a/firedrake/adjoint_utils/ensemble_function.py +++ b/firedrake/adjoint_utils/ensemble_function.py @@ -13,7 +13,7 @@ class EnsembleFunctionMixin(OverloadedType): Enables EnsembleFunction to do the following: - Be a Control for a NumpyReducedFunctional (_ad_to_list and _ad_assign_numpy) - Be used with pyadjoint TAO solver (_ad_{to,from}_petsc) - - Be used as a Control for Taylor tests (_ad_dot) + - Be used as a Control for Taylor tests (_ad_dot, _ad_add, _ad_mul) """ @staticmethod @@ -32,10 +32,8 @@ def wrapper(self, *args, 
**kwargs): @staticmethod def _ad_to_list(m): with m.vec_ro() as gvec: - lvec = PETSc.Vec().createSeq(gvec.size, - comm=PETSc.COMM_SELF) - PETSc.Scatter().toAll(gvec).scatter( - gvec, lvec, addv=PETSc.InsertMode.INSERT_VALUES) + scatter, lvec = PETSc.Scatter().toAll(gvec) + scatter.scatter(gvec, lvec, addv=PETSc.InsertMode.INSERT_VALUES) return lvec.array_r.tolist() @staticmethod @@ -50,22 +48,22 @@ def _ad_dot(self, other, options=None): local_dot = sum(uself._ad_dot(uother, options=options) for uself, uother in zip(self.subfunctions, other.subfunctions)) - return self.ensemble.ensemble_comm.allreduce(local_dot) + return self.function_space().ensemble.allreduce(local_dot) - def _ad_convert_riesz(self, value, options=None): - raise NotImplementedError + def _ad_convert_riesz(self, value, riesz_map=None): + return value.riesz_representation(riesz_map=riesz_map or "L2") def _ad_init_zero(self, dual=False): - from firedrake import EnsembleFunction, EnsembleCofunction + from firedrake import EnsembleFunction + space = self.function_space() if dual: - return EnsembleCofunction(self.function_space().dual()) - else: - return EnsembleFunction(self.function_space()) + space = space.dual() + return EnsembleFunction(space) def _ad_create_checkpoint(self): if disk_checkpointing(): raise NotImplementedError( - "Disk checkpointing not implemented for EnsembleFunctions") + f"Disk checkpointing not implemented for {type(self).__name__}") else: return self.copy() @@ -73,12 +71,12 @@ def _ad_restore_at_checkpoint(self, checkpoint): if type(checkpoint) is type(self): return checkpoint raise NotImplementedError( - "Disk checkpointing not implemented for EnsembleFunctions") + f"Disk checkpointing not implemented for {type(self).__name__}") def _ad_from_petsc(self, vec): - with self.vec_wo as self_v: + with self.vec_wo() as self_v: vec.copy(self_v) def _ad_to_petsc(self, vec=None): - with self.vec_ro as self_v: + with self.vec_ro() as self_v: return self_v.copy(vec or 
self._vec.duplicate()) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index a4b83b9828..320dc2d0dc 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -253,7 +253,8 @@ def _ad_add(self, other): from firedrake import Function r = Function(self.function_space()) - Function.assign(r, self + other) + r += self + r += other return r def _ad_dot(self, other, options=None): From 1ca9ad93a3fcc7912e8a4523f0a2a41a803890b8 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 12 Mar 2026 15:33:35 +0000 Subject: [PATCH 3/4] EnsembleReducedFunctional refactor --- firedrake/adjoint/__init__.py | 6 +- firedrake/adjoint/ensemble_adjvec.py | 121 ++ .../adjoint/ensemble_reduced_functional.py | 1277 ++++++++++++++--- .../test_ensemble_reduced_functional.py | 1079 ++++++++++++-- 4 files changed, 2157 insertions(+), 326 deletions(-) create mode 100644 firedrake/adjoint/ensemble_adjvec.py diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index 56ba19fdff..2251c99830 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -32,7 +32,11 @@ from firedrake.adjoint.ufl_constraints import ( # noqa: F401 UFLInequalityConstraint, UFLEqualityConstraint ) -from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401 +from firedrake.adjoint.ensemble_adjvec import EnsembleAdjVec # noqa F401 +from firedrake.adjoint.ensemble_reduced_functional import ( # noqa F401 + EnsembleBcastReducedFunctional, EnsembleReduceReducedFunctional, + EnsembleTransformReducedFunctional, EnsembleAllgatherReducedFunctional, + EnsembleReducedFunctional) from firedrake.adjoint.transformed_functional import L2RieszMap, L2TransformedFunctional # noqa: F401 from firedrake.adjoint.covariance_operator import ( # noqa F401 WhiteNoiseGenerator, AutoregressiveCovariance, diff --git a/firedrake/adjoint/ensemble_adjvec.py b/firedrake/adjoint/ensemble_adjvec.py new 
file mode 100644 index 0000000000..46b9ec0e17 --- /dev/null +++ b/firedrake/adjoint/ensemble_adjvec.py @@ -0,0 +1,121 @@ +from functools import cached_property +from pyadjoint.overloaded_type import OverloadedType +from pyadjoint.adjfloat import AdjFloat +from firedrake.ensemble import Ensemble +from firedrake.adjoint_utils.checkpointing import disk_checkpointing + + +class EnsembleAdjVec(OverloadedType): + """ + A vector of :class:`pyadjoint.AdjFloat` distributed + over an :class:`.Ensemble`. + + Analagous to the :class:`.EnsembleFunction` and + :class:`.EnsembleCofunction` types but for :class:`~pyadjoint.AdjFloat`. + + Implements basic :class:`pyadjoint.OverloadedType` functionality + to be used as a :class:`pyadjoint.Control` or functional for the + :class:`~.ensemble_reduced_functional.EnsembleReducedFunctional` types. + + Parameters + ---------- + subvec : + The local part of the vector. + ensemble : + The :class:`.Ensemble` communicator. + + See Also + -------- + :class:`~.Ensemble` + :class:`~.EnsembleFunction` + :class:`~.EnsembleCofunction` + :class:`~.EnsembleReducedFunctional` + """ + + def __init__(self, subvec: list[AdjFloat], ensemble: Ensemble): + if not isinstance(ensemble, Ensemble): + raise TypeError( + f"EnsembleAdjVec needs an Ensemble, not a {type(ensemble).__name__}") + if not all(isinstance(v, (AdjFloat, float)) for v in subvec): + raise TypeError( + f"EnsembleAdjVec must be instantiated with a list of AdjFloats, not {subvec}") + self._subvec = [AdjFloat(x) for x in subvec] + self.ensemble = ensemble + OverloadedType.__init__(self) + + @property + def subvec(self) -> list[AdjFloat]: + """The part of the vector on the local spatial comm.""" + return self._subvec + + @cached_property + def local_size(self) -> int: + """The length of the part of the vector on the local spatial comm.""" + return len(self._subvec) + + @cached_property + def global_size(self) -> int: + """The global length of vector.""" + return 
self.ensemble.allreduce(self.local_size) + + def _ad_init_zero(self, dual: bool = False) -> "EnsembleAdjVec": + return type(self)( + [v._ad_init_zero(dual=dual) for v in self.subvec], + self.ensemble) + + def _ad_dot(self, other: OverloadedType) -> float: + local_dot = sum(s._ad_dot(o) + for s, o in zip(self.subvec, other.subvec)) + global_dot = self.ensemble.ensemble_comm.allreduce(local_dot) + return global_dot + + def _ad_add(self, other) -> "EnsembleAdjVec": + return EnsembleAdjVec( + [s._ad_add(o) for s, o in zip(self.subvec, other.subvec)], + ensemble=self.ensemble) + + def _ad_mul(self, other) -> "EnsembleAdjVec": + return EnsembleAdjVec( + [s._ad_mul(o) for s, o in zip(self.subvec, + self._maybe_scalar(other))], + ensemble=self.ensemble) + + def _ad_iadd(self, other) -> "EnsembleAdjVec": + for s, o in zip(self.subvec, other.subvec): + s._ad_iadd(o) + return self + + def _ad_imul(self, other) -> "EnsembleAdjVec": + for s, o in zip(self.subvec, self._maybe_scalar(other)): + s._ad_imul(o) + return self + + def _maybe_scalar(self, val): + if isinstance(val, EnsembleAdjVec): + return val.subvec + else: + return [val for _ in self.subvec] + + def _ad_copy(self) -> "EnsembleAdjVec": + return EnsembleAdjVec( + [v._ad_copy() for v in self.subvec], + ensemble=self.ensemble) + + def _ad_convert_riesz(self, value, riesz_map=None) -> "EnsembleAdjVec": + return EnsembleAdjVec( + [s._ad_convert_riesz(v, riesz_map=riesz_map) + for s, v in zip(self.subvec, self._maybe_scalar(value))], + ensemble=self.ensemble) + + def _ad_create_checkpoint(self): + if disk_checkpointing(): + raise NotImplementedError( + f"Disk checkpointing not implemented for {type(self).__name__}") + else: + return self._ad_copy() + + def _ad_restore_at_checkpoint(self, checkpoint): + if type(checkpoint) is type(self): + return checkpoint + raise NotImplementedError( + f"Disk checkpointing not implemented for {type(self).__name__}") diff --git a/firedrake/adjoint/ensemble_reduced_functional.py 
b/firedrake/adjoint/ensemble_reduced_functional.py index 72979a5702..105805dc7f 100644 --- a/firedrake/adjoint/ensemble_reduced_functional.py +++ b/firedrake/adjoint/ensemble_reduced_functional.py @@ -1,249 +1,1106 @@ -from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional +from functools import cached_property +from pyadjoint.reduced_functional import AbstractReducedFunctional +from pyadjoint import Control, AdjFloat, no_annotations, OverloadedType from pyadjoint.enlisting import Enlist -from pyop2.mpi import MPI - from firedrake.function import Function from firedrake.cofunction import Cofunction +from .ensemble_adjvec import EnsembleAdjVec +from firedrake.ensemble import ( + Ensemble, EnsembleFunctionSpace, EnsembleFunction) +from firedrake.ensemble.ensemble_function import EnsembleFunctionBase +__all__ = ( + "EnsembleReducedFunctional", + "EnsembleBcastReducedFunctional", + "EnsembleReduceReducedFunctional", + "EnsembleAllgatherReducedFunctional", + "EnsembleTransformReducedFunctional", +) -class EnsembleReducedFunctional(AbstractReducedFunctional): - """Enable solving simultaneously reduced functionals in parallel. - Consider a functional :math:`J` and its gradient :math:`\\dfrac{dJ}{dm}`, - where :math:`m` is the control parameter. Let us assume that :math:`J` is the sum of - :math:`N` functionals :math:`J_i(m)`, i.e., +# utility functions to hide API differences between EnsembleFunction and EnsembleAdjVec - .. math:: - J = \\sum_{i=1}^{N} J_i(m). +def _local_subs(val): + if isinstance(val, EnsembleFunctionBase): + return val.subfunctions + elif isinstance(val, EnsembleAdjVec): + return val.subvec + elif isinstance(val, (list, tuple)): + return val + else: + raise TypeError( + f"Cannot use {type(val).__name__} as an ensemble overloaded type.") - The gradient over a summation is a linear operation. Therefore, we can write the gradient - :math:`\\dfrac{dJ}{dm}` as - .. 
math:: +def _global_size(val): + if isinstance(val, EnsembleFunctionBase): + return val.function_space().nglobal_spaces + elif isinstance(val, EnsembleAdjVec): + return val.global_size + else: + raise TypeError( + f"Cannot use {type(val).__name__} as an ensemble overloaded type.") - \\frac{dJ}{dm} = \\sum_{i=1}^{N} \\frac{dJ_i}{dm}, - The :class:`EnsembleReducedFunctional` allows simultaneous evaluation of :math:`J_i` and - :math:`\\dfrac{dJ_i}{dm}`. After that, the allreduce :class:`~.ensemble.Ensemble` - operation is employed to sum the functionals and their gradients over an ensemble - communicator. +def _local_size(val): + return len(_local_subs(val)) + + +def _set_local_subs(dst, src): + assert _local_size(dst) == _local_size(src) + dst_subs = _local_subs(dst) + for i, s in enumerate(src): + if hasattr(dst_subs[i], 'assign'): + dst_subs[i].assign(s) + else: + dst_subs[i] = s + return dst + + +def _ensemble(val): + if isinstance(val, EnsembleFunctionBase): + return val.function_space().ensemble + elif isinstance(val, EnsembleAdjVec): + return val.ensemble + + +def _make_ensemble_obj(local_vals, ensemble): + if all(isinstance(val, float) for val in local_vals): + return EnsembleAdjVec(local_vals, ensemble) + elif all(isinstance(val, (Function, Cofunction)) for val in local_vals): + ensemble_space = EnsembleFunctionSpace( + [val.function_space() for val in local_vals], ensemble) + ensemble_function = EnsembleFunction(ensemble_space) + _set_local_subs(ensemble_function, local_vals) + return ensemble_function + else: + raise TypeError("All local values must be of same type," + " either AdjFloat or Function or Cofunction.") + + +class EnsembleReduceReducedFunctional(AbstractReducedFunctional): + """ + A parallel reduction from all components of an :class:`.EnsembleFunction` + or :class:`.EnsembleAdjVec` distributed over an :class:`.Ensemble` into a + :class:`.Function` or :class:`pyadjoint.AdjFloat` on each ensemble member. 
+ + Currently the only reduction operation implemented is a sum. The adjoint + operation to a sum is a broadcast, so the ``derivative`` function will + return an ensemble object with a copy of the ``adj_input`` in all + components. + + The ``functional`` must be suitable to reduce each component of ``control`` + into. The acceptable combinations are shown in the table below. For an + :class:`.EnsembleFunction` or :class:`.EnsembleCofunction` ``control`` + there is an additional restriction that the ``functional`` and all + components of ``control`` must have the same :func:`.FunctionSpace`. - If gather_functional is present, then all the values of J are communicated to all ensemble - ranks, and passed in a list to gather_functional, which is a reduced functional that expects - a list of that size of the relevant types. + .. list-table:: + :header-rows: 1 + + * - ``control`` + - ``functional`` + * - :class:`.EnsembleFunction` + - :class:`.Function` + * - :class:`.EnsembleCofunction` + - :class:`.Cofunction` + * - :class:`.EnsembleAdjVec` + - :class:`~pyadjoint.AdjFloat` Parameters ---------- - functional : pyadjoint.OverloadedType - An instance of an OverloadedType, usually :class:`pyadjoint.AdjFloat`. - This should be the functional that we want to reduce. - control : pyadjoint.Control or list of pyadjoint.Control - A single or a list of Control instances, which you want to map to the functional. - ensemble : Ensemble - An instance of the :class:`~.ensemble.Ensemble`. It is used to communicate the - functionals and their derivatives between the ensemble members. - scatter_control : bool - Whether scattering a control (or a list of controls) over the ensemble communicator - ``Ensemble.ensemble comm``. - gather_functional : An instance of the :class:`pyadjoint.ReducedFunctional`. - that takes in all of the Js. - derivative_components : list of int - The indices of the controls that the derivative should be computed with respect to. 
- If present, it overwrites ``derivative_cb_pre`` and ``derivative_cb_post``. - scale : float - A scaling factor applied to the functional and its gradient(with respect to the control). - tape : pyadjoint.Tape - A tape object that the reduced functional will use to evaluate the functional and - its gradients (or derivatives). - eval_cb_pre : :func: - Callback function before evaluating the functional. Input is a list of Controls. - eval_cb_pos : :func: - Callback function after evaluating the functional. Inputs are the functional value - and a list of Controls. - derivative_cb_pre : :func: - Callback function before evaluating gradients (or derivatives). Input is a list of - gradients (or derivatives). Should return a list of Controls (usually the same list as - the input) to be passed to :func:`pyadjoint.compute_gradient`. - derivative_cb_post : :func: - Callback function after evaluating derivatives. Inputs are the functional, a list of - gradients (or derivatives), and controls. All of them are the checkpointed versions. - Should return a list of gradients (or derivatives) (usually the same list as the input) - to be returned from ``self.derivative``. - hessian_cb_pre : :func: - Callback function before evaluating the Hessian. Input is a list of Controls. - hessian_cb_post : :func: - Callback function after evaluating the Hessian. Inputs are the functional, a list of - Hessian, and controls. + functional : + The result of the reduction, i.e. the sum of all components of the + ``control`` from all ensemble members. Must be identical on all + members. + control : + An object with several components to sum, distributed over an + :class:`.Ensemble`. Must be a single :class:`~pyadjoint.Control` + rather than a list. + + Notes + ----- + Unlike most ``ReducedFunctional`` classes, this one does not require any + operations to be taped before creating it. The ``functional`` and + ``control`` arguments are just to specify the source and destination + spaces. 
+ + This class is primarily intended as a component for building larger + ``ReducedFunctional`` classes over an :class:`.Ensemble`, for example the + :class:`.EnsembleReducedFunctional`. + + Developer Notes + --------------- + The "hidden" parameter ``_only_forward`` exists because bcast and reduce + are adjoint to each other so we want to implement the derivative of each + using the other. ``_only_forward`` is used to do this without infinite + recursion. See Also -------- - :class:`~.ensemble.Ensemble`, :class:`pyadjoint.ReducedFunctional`. + :class:`pyadjoint.ReducedFunctional`. + :class:`~.Ensemble` + :class:`~.EnsembleFunction` + :class:`~.EnsembleCofunction` + :class:`~.EnsembleAdjVec` + :class:`~.EnsembleBcastReducedFunctional` + :class:`~.EnsembleTransformReducedFunctional` + :class:`~.EnsembleAllgatherReducedFunctional` + :class:`~.EnsembleReducedFunctional` + """ + def __init__(self, functional: OverloadedType, control: Control, + _only_forward=False): + self.functional = functional + self._controls = Enlist(control) + + if isinstance(functional, AdjFloat): + if not isinstance(control.control, EnsembleAdjVec): + raise TypeError( + f"Control for {type(self).__name__} must be an" + " EnsembleAdjVec if using an AdjFloat functional.") + elif isinstance(functional, (Function, Cofunction)): + if not isinstance(control.control, EnsembleFunctionBase): + raise TypeError( + f"Control for {type(self).__name__} must be an" + " EnsembleFunction or EnsembleCofunction if using" + " a Function or Cofunction functional.") + if not all([c.function_space() == functional.function_space() + for c in control.subfunctions]): + raise ValueError( + f"All subfunctions of the {type(self).__name__} control" + " must have the same function space as the functional.") + else: + raise ValueError( + f"Do not know how to handle a {type(functional).__name__}" + f" control for {type(self).__name__}.") + + # Adjoint action is a bcast so we just piggyback. 
+ # Possibly don't do this if we're being created by the + # bcast rf to avoid infinite recursion. + if not _only_forward: + self._bcast = EnsembleBcastReducedFunctional( + functional=control.control._ad_init_zero(dual=True), + control=Control(functional._ad_init_zero(dual=True)), + _only_forward=True + ) + + @property + def controls(self): + return self._controls + + @property + def ensemble(self): + """The :class:`.Ensemble` that the control is defined over.""" + return _ensemble(self.controls[0].control) + + @no_annotations + def __call__(self, values): + for c, v in zip(self.controls, Enlist(values)): + c.update(v) + return self.tlm(values) + + @no_annotations + def derivative(self, adj_input=1.0, apply_riesz=False): + dJ = self._bcast(adj_input) + if apply_riesz: + return self.controls[0]._ad_convert_riesz( + dJ, riesz_map=self.controls[0].riesz_map) + return dJ + + @no_annotations + def tlm(self, m_dot): + if isinstance(m_dot, (list, tuple)): + m_dot = m_dot[0] + vals = _local_subs(m_dot) + local_sum = vals[0]._ad_init_zero() + for v in vals: + local_sum = local_sum._ad_add(v) + return self.ensemble.allreduce(local_sum) + + @no_annotations + def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, + apply_riesz=False): + if evaluate_tlm: + self.tlm(m_dot) + if hessian_input is None: + hessian_input = self.functional._ad_init_zero(dual=True) + return self.derivative(hessian_input, apply_riesz=apply_riesz) + + +class EnsembleBcastReducedFunctional(AbstractReducedFunctional): + """ + A parallel broadcast from a :class:`.Function` or + :class:`pyadjoint.AdjFloat` into all components of an + :class:`.EnsembleFunction` or :class:`.EnsembleAdjVec` distributed over an + :class:`.Ensemble`. + + The adjoint operation to a broadcast is a sum, so the ``derivative`` + function will return an object on each ensemble member with the sum of + all components of the ``adj_input``. + + The ``functional`` must be suitable to broadcast the ``control`` into. 
+ The acceptable combinations are shown in the table below. For a + :class:`.Function` or :class:`.Cofunction` ``control`` there is an + additional restriction that the ``control`` and all components of + ``functional`` must have the same :func:`.FunctionSpace`. + + .. list-table:: + :header-rows: 1 + + * - ``control`` + - ``functional`` + * - :class:`.Function` + - :class:`.EnsembleFunction` + * - :class:`.Cofunction` + - :class:`.EnsembleCofunction` + * - :class:`~pyadjoint.AdjFloat` + - :class:`.EnsembleAdjVec` + + Parameters + ---------- + functional : + The result of the broadcast, i.e. all components of ``functional`` + will be copies of ``control``. + control : + The object to broadcast. Must be identical on each ensemble member. + Must be a single :class:`~pyadjoint.Control` rather than a list. + root : + If ``None`` then the argument to ``__call__`` or ``tlm`` is assumed to + be the same on all ensemble members, and the "broadcast" only requires + local copies rather than global communications. + If ``root`` is an ``int`` then this ensemble member is used as the + root for the broadcast and ``__call__`` and ``tlm`` will carry out + global communication. Notes ----- - The functionals :math:`J_i` and the control must be defined over a common - `ensemble.comm` communicator. To understand more about how ensemble parallelism - works, please refer to the `Firedrake manual - `_. + Unlike most ``ReducedFunctional`` classes, this one does not require any + operations to be taped before creating it. The ``functional`` and + ``control`` arguments are just to specify the source and destination + spaces. + + This class is primarily intended as a component for building larger + ``ReducedFunctional`` classes over an :class:`.Ensemble`, for example the + :class:`.EnsembleReducedFunctional`. 
+ + Developer Notes + --------------- + The "hidden" parameter ``_only_forward`` exists because bcast and reduce + are adjoint to each other so we want to implement the derivative of each + using the other. ``_only_forward`` is used to do this without infinite + recursion. + + See Also + -------- + :class:`pyadjoint.ReducedFunctional`. + :class:`~.Ensemble` + :class:`~.EnsembleFunction` + :class:`~.EnsembleCofunction` + :class:`~.EnsembleAdjVec` + :class:`~.EnsembleReduceReducedFunctional` + :class:`~.EnsembleTransformReducedFunctional` + :class:`~.EnsembleAllgatherReducedFunctional` + :class:`~.EnsembleReducedFunctional` """ - def __init__(self, functional, control, ensemble, scatter_control=True, - gather_functional=None, - derivative_components=None, - scale=1.0, tape=None, - eval_cb_pre=lambda *args: None, - eval_cb_post=lambda *args: None, - derivative_cb_pre=lambda controls: controls, - derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components, - hessian_cb_pre=lambda *args: None, - hessian_cb_post=lambda *args: None): - self.local_reduced_functional = ReducedFunctional( - functional, control, - derivative_components=derivative_components, - scale=scale, tape=tape, - eval_cb_pre=eval_cb_pre, - eval_cb_post=eval_cb_post, - derivative_cb_pre=derivative_cb_pre, - derivative_cb_post=derivative_cb_post, - hessian_cb_pre=hessian_cb_pre, - hessian_cb_post=hessian_cb_post - ) - - self.ensemble = ensemble - self.scatter_control = scatter_control - self.gather_functional = gather_functional + def __init__(self, functional: OverloadedType, + control: Control, root: int | None = None, + _only_forward=False): + self.functional = functional + self._controls = Enlist(control) + self.root = root + + if isinstance(control.control, AdjFloat): + if not isinstance(functional, EnsembleAdjVec): + raise TypeError( + f"Functional for {type(self).__name__} must be an" + " EnsembleAdjVec if using an AdjFloat control.") + elif 
isinstance(control.control, (Function, Cofunction)): + if not isinstance(functional, EnsembleFunctionBase): + raise TypeError( + f"Functional for {type(self).__name__} must be an" + " EnsembleFunction or EnsembleCofunction if using" + " a Function or Cofunction control.") + if not all([f.function_space() == control.function_space() + for f in functional.subfunctions]): + raise ValueError( + f"All subfunctions of the {type(self).__name__} functional" + " must have the same function space as the control.") + else: + raise ValueError( + f"Do not know how to handle a {type(control).__name__}" + f" control for {type(self).__name__}.") + + # Adjoint action is a reduction so we just piggyback. + # Possibly don't do this if we're being created by the + # reduction rf to avoid infinite recursion. + if not _only_forward: + self._reduce = EnsembleReduceReducedFunctional( + functional=control.control._ad_init_zero(dual=True), + control=Control(functional._ad_init_zero(dual=True)), + _only_forward=True + ) @property def controls(self): - return self.local_reduced_functional.controls - - def _allgather_J(self, J): - if isinstance(J, float): - vals = self.ensemble.ensemble_comm.allgather(J) - elif isinstance(J, Function): - # allgather not implemented in ensemble.py - vals = [] - for i in range(self.ensemble.ensemble_comm.size): - J0 = J.copy(deepcopy=True) - vals.append(self.ensemble.bcast(J0, root=i)) + return self._controls + + @property + def ensemble(self): + """The :class:`.Ensemble` that the functional is defined over.""" + return _ensemble(self.functional) + + @no_annotations + def __call__(self, values): + for c, v in zip(self.controls, Enlist(values)): + c.update(v) + return self.tlm(values) + + @no_annotations + def derivative(self, adj_input=1.0, apply_riesz=False): + dJ = self._reduce(adj_input) + if apply_riesz: + dJ = self.controls[0]._ad_convert_riesz( + dJ, riesz_map=self.controls[0].riesz_map) + return dJ + + @no_annotations + def tlm(self, m_dot): + if 
self.root is None: + m_dot = m_dot else: - raise NotImplementedError(f"Functionals of type {type(J).__name__} are not supported.") - return vals + m_dot = self.ensemble.bcast(m_dot, root=self.root) + tlv = self.functional._ad_init_zero() + _set_local_subs( + tlv, [m_dot for _ in range(_local_size(self.functional))]) + return tlv + @no_annotations + def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, + apply_riesz=False): + if evaluate_tlm: + self.tlm(m_dot) + return self.derivative(hessian_input, apply_riesz=apply_riesz) + + +class EnsembleAllgatherReducedFunctional(AbstractReducedFunctional): + """ + A parallel allgather of all components of an :class:`.EnsembleFunction`, + :class:`.EnsembleCofunction`, or :class:`.EnsembleAdjVec` onto all ensemble + members. The result is placed in an :class:`.EnsembleFunction`, + :class:`.EnsembleCofunction`, or :class:`.EnsembleAdjVec` with enough + components locally on each ensemble member to fit all components from every + ensemble member of the control. + + For example in the code below, the control is a :class:`.EnsembleFunction` + with two components on the first ensemble member and one on the second + member, so the functional must be a :class:`.EnsembleFunction` with three + components on each ensemble member with matching spaces. + + .. code-block:: python3 + + V0 = FunctionSpace(mesh, "DG", 0) + V1 = FunctionSpace(mesh, "CG", 1) + + if ensemble.ensemble_rank == 0: + local_controls = [V0, V1] + elif ensemble.ensemble_rank == 1: + local_controls = [V0] + + local_functionals = [V0, V1, V0] + + control_space = EnsembleFunctionSpace(local_controls, ensemble) + functional_space = EnsembleFunctionSpace(local_functionals, ensemble) + + Parameters + ---------- + functional : + The result of the allgather. + control : + The object to allgather. + Must be a single :class:`~pyadjoint.Control` rather than a list. 
+ + Notes + ----- + Unlike most ``ReducedFunctional`` classes, this one does not require any + operations to be taped before creating it. The ``functional`` and + ``control`` arguments are just to specify the source and destination + spaces. + + This class is primarily intended as a component for building larger + ``ReducedFunctional`` classes over an :class:`.Ensemble`, for example the + :class:`.EnsembleReducedFunctional`. + + See Also + -------- + :class:`pyadjoint.ReducedFunctional`. + :class:`~.Ensemble` + :class:`~.EnsembleFunction` + :class:`~.EnsembleCofunction` + :class:`~.EnsembleAdjVec` + :class:`~.EnsembleReduceReducedFunctional` + :class:`~.EnsembleBcastReducedFunctional` + :class:`~.EnsembleTransformReducedFunctional` + :class:`~.EnsembleReducedFunctional` + """ + def __init__(self, functional: OverloadedType, control: Control): + assert _local_size(functional) == _global_size(control.control) + self._controls = Enlist(control) + self.functional = functional + + @property + def controls(self): + return self._controls + + @property + def ensemble(self): + """The :class:`.Ensemble` that the control and functional are defined over.""" + return _ensemble(self.functional) + + @no_annotations def __call__(self, values): - """Computes the reduced functional with supplied control value. - - Parameters - ---------- - values : pyadjoint.OverloadedType - If you have multiple controls this should be a list of - new values for each control in the order you listed the controls to the constructor. - If you have a single control it can either be a list or a single object. - Each new value should have the same type as the corresponding control. - - Returns - ------- - pyadjoint.OverloadedType - The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. 
- - """ - local_functional = self.local_reduced_functional(values) - ensemble_comm = self.ensemble.ensemble_comm - if self.gather_functional: - controls_g = self._allgather_J(local_functional) - total_functional = self.gather_functional(controls_g) - # if gather_functional is None then we do a sum - elif isinstance(local_functional, float): - total_functional = ensemble_comm.allreduce(sendobj=local_functional, op=MPI.SUM) - elif isinstance(local_functional, Function): - total_functional = type(local_functional)(local_functional.function_space()) - total_functional = self.ensemble.allreduce(local_functional, total_functional) + for c, v in zip(self.controls, Enlist(values)): + c.update(v) + return self.tlm(values) + + @no_annotations + def tlm(self, m_dot): + if isinstance(m_dot, (list, tuple)): + m_dot = m_dot[0] + tlv = self.functional._ad_init_zero() + # Ensemble.allgather has slightly different signature to Comm.allgather + # so we actually have to distinguish between the two cases here. 
+ if isinstance(self.functional, EnsembleFunctionBase): + self.ensemble.allgather( + _local_subs(m_dot), _local_subs(tlv)) + elif isinstance(self.functional, EnsembleAdjVec): + local_mdot_lists = self.ensemble.allgather(_local_subs(m_dot)) + # allgather will return a list of lists that we need to flatten + global_mdot = [md + for local_mdots in local_mdot_lists + for md in local_mdots] + _set_local_subs(tlv, global_mdot) else: - raise NotImplementedError("This type of functional is not supported.") - return total_functional + raise TypeError( + f"Cannot use {type(m_dot).__name__}" + " as an ensemble overloaded type.") + return tlv + + @cached_property + def _local_indices(self): + local_size = _local_size(self.controls[0].control) + local_offset = self.ensemble.ensemble_comm.exscan(local_size) or 0 + return [local_offset + i for i in range(local_size)] + @no_annotations def derivative(self, adj_input=1.0, apply_riesz=False): - """Compute derivatives of a functional with respect to the control parameters. - - Parameters - ---------- - adj_input : float - The adjoint input. - apply_riesz: bool - If True, apply the Riesz map of each control in order to return - a primal gradient rather than a derivative in the dual space. - - Returns - ------- - dJdm_total : pyadjoint.OverloadedType - The result of Allreduce operations of ``dJdm_local`` into ``dJdm_total`` over the`Ensemble.ensemble_comm`. - - See Also - -------- - :meth:`~.ensemble.Ensemble.allreduce`, :meth:`pyadjoint.ReducedFunctional.derivative`. 
- """ - - if self.gather_functional: - dJg_dmg = self.gather_functional.derivative(adj_input=adj_input, - apply_riesz=False) - i = self.ensemble.ensemble_comm.rank - adj_input = dJg_dmg[i] - - dJdm_local = self.local_reduced_functional.derivative(adj_input=adj_input, - apply_riesz=apply_riesz) - - if self.scatter_control: - dJdm_local = Enlist(dJdm_local) - dJdm_total = [] - - for dJdm in dJdm_local: - if not isinstance(dJdm, (Cofunction, Function, float)): - raise NotImplementedError( - f"Gradients of type {type(dJdm).__name__} are not supported.") - - dJdm_total.append( - self.ensemble.allreduce(dJdm, type(dJdm)(dJdm.function_space())) - if isinstance(dJdm, (Cofunction, Function)) - else self.ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM) + dJ_global = _local_subs(adj_input) + dJ_local = [dJ_global[i] for i in self._local_indices] + + dJ = self.controls[0].control._ad_init_zero(dual=True) + _set_local_subs(dJ, dJ_local) + if apply_riesz: + dJ = dJ._ad_convert_riesz( + dJ, riesz_map=self.controls[0].riesz_map) + return dJ + + @no_annotations + def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, + apply_riesz=False): + if evaluate_tlm: + self.tlm(m_dot) + return self.derivative(hessian_input, apply_riesz=apply_riesz) + + +class EnsembleTransformReducedFunctional(AbstractReducedFunctional): + """ + A parallel transformation from a set of :class:`.EnsembleFunction`, + :class:`.EnsembleCofunction`, or :class:`.EnsembleAdjVec` controls to + a functional which is also an ensemble object (i.e. one of these classes). + + The functional and all controls must have the same number of components on + each ensemble member, and the transformation maps from components in the + same position in each control to the corresponding component in the + functional. This is explained in more detail below. 
+ The transform for each component position is provided by an + :class:`~pyadjoint.reduced_functional.AbstractReducedFunctional` defined on + the local ensemble member, whose ``controls`` and ``functional`` must match + the corresponding components of the ``controls`` and ``functional`` of the + ``EnsembleTransformReducedFunctional``. + + For example (ignoring the ensemble parallel partition for now), if we have + two controls :math:`u` and :math:`v`, and a functional :math:`w`, which + each have three components: + + .. math:: + + u \\in U=U_{0} \\times U_{1} \\times U_{2}, + + v \\in V=V_{0} \\times V_{1} \\times V_{2}, + + w \\in W=W_{0} \\times W_{1} \\times W_{2}, + + Then the ``EnsembleTransformReducedFunctional`` maps + :math:`\\hat{J} : U \\times V \\to W`, and we need a + :class:`~pyadjoint.ReducedFunctional` for each component: + + .. math:: + + \\hat{J}_{0} : U_{0} \\times V_{0} \\to W_{0}, + + \\hat{J}_{1} : U_{1} \\times V_{1} \\to W_{1}, + + \\hat{J}_{2} : U_{2} \\times V_{2} \\to W_{2}. + + If :math:`u`, :math:`v`, and :math:`w` are each an + :class:`.EnsembleFunction`, :class:`.EnsembleCofunction` or + :class:`.EnsembleAdjVec`, then the components are distributed over the + ensemble members. The corresponding :class:`~.pyadjoint.ReducedFunctional` + for each :math:`\\hat{J}_{i}` are defined locally on each ensemble member. + For example: + + .. 
code-block:: python3 + + if ensemble.ensemble_rank == 0: + Ulocals = [U0, U1] + Vlocals = [V0, V1] + elif ensemble.ensemble_rank == 1: + Ulocals = [U2] + Vlocals = [V2] + + U = EnsembleFunctionSpace(Ulocals, ensemble) + V = EnsembleFunctionSpace(Vlocals, ensemble) + + u = EnsembleFunction(U) + v = EnsembleFunction(V) + + Jlocals = [] + wlocals = [] + for ui, vi in zip(u.subfunctions, v.subfunctions): + with set_working_tape() as tape: + wi = assemble(ui*vi*dx) + Ji = ReducedFunctional(wi, [Control(ui), Control(vi)], tape=tape) + + wlocals.append(wi) + Jlocals.append(Ji) + + w = EnsembleAdjVec(wlocals, ensemble) + + Jhat = EnsembleTransformReducedFunctional( + w, [Control(u), Control(v)], Jlocals) + + Note that by using :func:`~pyadjoint.set_working_tape` we ensure that + each local :class:`~pyadjoint.ReducedFunctional` has its own tape. + For such a simple example this is unlikely to make a difference, but + for more complex operations this will ensure that each local + ``ReducedFunctional`` does not interfere with the others. + + Parameters + ---------- + functional : + The result of the transform. + control : + The inputs to the transform. + + Notes + ----- + Unlike most ``ReducedFunctional`` classes, this one does not require any + operations on the ``control`` and ``functional`` to be taped before + creating it. The ``functional`` and ``control`` arguments are just to + specify the source and destination spaces. + + This class is primarily intended as a component for building larger + ``ReducedFunctional`` classes over an :class:`.Ensemble`, for example the + :class:`.EnsembleReducedFunctional`. + + See Also + -------- + :class:`pyadjoint.ReducedFunctional`. 
+ :class:`~.Ensemble` + :class:`~.EnsembleFunction` + :class:`~.EnsembleCofunction` + :class:`~.EnsembleAdjVec` + :class:`~.EnsembleReduceReducedFunctional` + :class:`~.EnsembleBcastReducedFunctional` + :class:`~.EnsembleAllgatherReducedFunctional` + :class:`~.EnsembleReducedFunctional` + """ + def __init__(self, functional: OverloadedType, control: Control | list[Control], + rfs: AbstractReducedFunctional | list[AbstractReducedFunctional]): + self.rfs = Enlist(rfs) + self.functional = functional + self._controls = Enlist(control) + + EnsembleTypes = (EnsembleFunctionBase, EnsembleAdjVec) + + if not isinstance(functional, EnsembleTypes): + raise TypeError( + f"Functional for {type(self).__name__} must be either an" + f" EnsembleFunction or EnsembleAdjVec, not {type(functional)}" + ) + for c in self.controls: + if not isinstance(c.control, EnsembleTypes): + raise TypeError( + f"Controls for {type(self).__name__} must be either an " + f"EnsembleFunction or EnsembleAdjVec not {type(c.control)}" ) - return dJdm_local.delist(dJdm_total) - return dJdm_local + clens = set(len(_local_subs(c.control)) for c in self.controls) + flen = len(_local_subs(functional)) + rlen = len(self.rfs) + if len(clens) != 1: + raise ValueError( + f"All Controls for {type(self).__name__} must have" + " the same number of components on each ensemble rank" + ) + clen = clens.pop() + if clen != flen: + raise ValueError( + f"Control of with {clen} local components for" + f" {type(self).__name__} must have the same number of local" + f" components as the functional ({flen})" + ) + if clen != rlen: + raise ValueError( + f"{type(self).__name__} given {rlen} local ReducedFunctionals," + f" but needs one for each local component of Control with" + f" length {clen}") + + @property + def controls(self): + return self._controls + + @property + def ensemble(self): + """The :class:`.Ensemble` that the control and functional are defined + over.""" + return _ensemble(self.functional) + + @no_annotations + def 
__call__(self, values): + for c, v in zip(self.controls, Enlist(values)): + c.update(v) + + local_vals = self._global_to_local_data(values) + local_Js = [rf(v) for rf, v in zip(self.rfs, local_vals)] + + J = self.functional._ad_init_zero() + self._local_to_global_data(local_Js, J) + + return J + + @no_annotations def tlm(self, m_dot): - """Return the action of the tangent linear model of the functional. - - The tangent linear model is evaluated w.r.t. the control on a vector - m_dot, around the last supplied value of the control. - - Parameters - ---------- - m_dot : pyadjoint.OverloadedType - The direction in which to compute the action of the tangent linear model. - - Returns - ------- - pyadjoint.OverloadedType: The action of the tangent linear model in the - direction m_dot. Should be an instance of the same type as the functional. - """ - local_tlm = self.local_reduced_functional.tlm(m_dot) - ensemble_comm = self.ensemble.ensemble_comm - if self.gather_functional: - mdot_g = self._allgather_J(local_tlm) - total_tlm = self.gather_functional.tlm(mdot_g) - # if gather_functional is None then we do a sum - elif isinstance(local_tlm, float): - total_tlm = ensemble_comm.allreduce(sendobj=local_tlm, op=MPI.SUM) - elif isinstance(local_tlm, Function): - total_tlm = type(local_tlm)(local_tlm.function_space()) - total_tlm = self.ensemble.allreduce(local_tlm, total_tlm) + local_mdot = self._global_to_local_data(m_dot) + local_tlm = [rf.tlm(md) for rf, md in zip(self.rfs, local_mdot)] + + tlm = self.functional._ad_init_zero() + self._local_to_global_data(local_tlm, tlm) + + return tlm + + @no_annotations + def derivative(self, adj_input=1.0, apply_riesz=False): + local_adj = self._global_to_local_data(adj_input) + local_dJ = [rf.derivative(adj_input=adj[0], apply_riesz=apply_riesz) + for rf, adj in zip(self.rfs, local_adj)] + + dJ = self.controls.delist( + [c.control._ad_init_zero(dual=not apply_riesz) + for c in self.controls]) + + self._local_to_global_data(local_dJ, 
dJ) + + return dJ + + @no_annotations + def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, + apply_riesz=False): + if evaluate_tlm: + self.tlm(m_dot) + + local_hin = self._global_to_local_data(hessian_input) + local_hess = [rf.hessian(m_dot=None, evaluate_tlm=False, + hessian_input=hin[0], + apply_riesz=apply_riesz) + for rf, hin in zip(self.rfs, local_hin)] + + hessian = self.controls.delist( + [c.control._ad_init_zero(dual=not apply_riesz) + for c in self.controls]) + + self._local_to_global_data(local_hess, hessian) + + return hessian + + def _local_to_global_data(self, local_data, global_data): + # N local lists of length n -> n global lists of length N + # [(1,), (2,), (3,)]-> [(1, 2, 3)] + # [(1, 11), (2, 12), (3, 13)] -> [(1, 2, 3), (11, 12, 13)] + + for j, global_group in enumerate(Enlist(global_data)): + local_group = [Enlist(local_group)[j] + for local_group in local_data] + _set_local_subs(global_group, local_group) + + return global_data + + def _global_to_local_data(self, global_data): + # n global lists of length N -> N local lists of length n + # [(1, 2, 3)] -> [(1,), (2,), (3,)] + # [(1, 2, 3), (11, 12, 13)] -> [(1, 11), (2, 12), (3, 13)] + + local_groups = [ + ld for ld in zip(*map(_local_subs, Enlist(global_data)))] + return local_groups + + +class EnsembleReducedFunctional(AbstractReducedFunctional): + """ + A reduced functional where multiple independent terms of the + objective functional are calculated in parallel. + + This class covers two cases: + + 1. The terms all depend on the same controls. + + 2. The terms each depend on different controls. + + In the first case, if we have a functional :math:`J(m)` where :math:`m` is + the control parameter, then we assume that :math:`J` is the sum of + :math:`N` functionals :math:`J_i(m)`, which all depend on :math:`m` but are + independent of each other, i.e. + + .. math:: + + J(m) = \\sum_{i=1}^{N} J_i(m). + + The gradient over a summation is a linear operation. 
Therefore, we can + write the gradient :math:`\\dfrac{dJ}{dm}` as + + .. math:: + + \\frac{dJ}{dm} = \\sum_{i=1}^{N} \\frac{dJ_i}{dm}. + + In this case both the ``controls`` and ``functional`` are distributed only + in space, e.g. are a :class:`.Function`, :class:`~pyadjoint.AdjFloat` etc. + + In the second case, we assume that the control :math:`m` now has :math:`N` + components :math:`m_i`, and that each term :math:`J_i` depends on a + different component of :math:`m`, i.e. + + .. math:: + + J(m) = \\sum_{i=1}^{N} J_i(m_i). + + Now we can write the gradient :math:`\\dfrac{dJ}{dm}` component-wise: + + .. math:: + + \\frac{dJ}{dm_i} = \\frac{dJ_i}{dm_i}. + + In this case the ``functional`` is distributed only in space, e.g. + a :class:`.Function`, :class:`~pyadjoint.AdjFloat` etc, but the + ``controls`` are distributed also over the ``Ensemble``, for example + an :class:`~.EnsembleFunction` or :class:`~.EnsembleAdjVec`. + + In both cases, :class:`EnsembleReducedFunctional` allows simultaneous + evaluation of :math:`J_i` and either :math:`\\dfrac{dJ_i}{dm}` or + :math:`\\dfrac{dJ_i}{dm_i}` (as well as the tangent linear model and + Hessian actions) by different spatial communicators on the ``ensemble``. + All the required communication is handled by the :class:`~.Ensemble`. + + The terms :math:`J_i` on each spatial communicator are provided as a list + of :class:`pyadjoint.ReducedFunctional` for the local terms. For the + second case the number of terms on each spatial comm must match the number + of local components of the controls (e.g. :meth:`.EnsembleAdjVec.subvec`). + It is essential that each :math:`J_i` has its own :class:`pyadjoint.Tape` + so that it can maintain its own state independent of the other terms. + + If the ``gather_rf`` :class:`~pyadjoint.ReducedFunctional` is provided then, instead + of just summing the values of :math:`J_i`, they are all-gathered onto every + spatial comm and passed as the controls to ``gather_rf``.
:math:`J` is then + the output of the ``gather_rf``. + + If :math:`G` is the ``gather_rf`` then for case 1 this is equivalent to + + .. math:: + + J(m) = G(J_1(m), J_2(m), \\dots, J_N(m)) + + and for case 2 this is equivalent to + + .. math:: + + J(m) = G(J_1(m_1), J_2(m_2), \\dots, J_N(m_N)) + + Parameters + ---------- + functional : + An instance of an OverloadedType, usually :class:`~pyadjoint.AdjFloat`. + This should be the functional that we want to calculate. + control : + A single or a list of :class:`pyadjoint.Control` instances, which you + want to map to the functional. + rfs : + The :class:`~pyadjoint.ReducedFunctional` for each term :math:`J_{i}`. + It is essential that each ``rf`` has its own :class:`pyadjoint.Tape` + so that it can maintain its own state independent of the other terms. + gather_rf : + The reduced functional to map all ensemble components to the + functional. Requires the functional to be a non-ensemble type (e.g. an + :class:`~pyadjoint.AdjFloat` or :class:`~.Function`) + ensemble : + An instance of the :class:`~.ensemble.Ensemble`. It is used to + communicate the functionals and their derivatives between the ensemble + members. If either the functional or controls are ensemble types (e.g. + :class:`.EnsembleFunction` or :class:`.EnsembleAdjVec`) then the + ``ensemble`` is accessed from them, and this argument is ignored. + + Notes + ----- + Each :class:`~pyadjoint.ReducedFunctional` in ``rfs`` that define the + functionals :math:`J_i` must be defined over a single ``ensemble.comm`` + communicator. + + To understand more about ensemble parallelism, please refer to the + `Firedrake manual `_. + + See Also + -------- + :class:`pyadjoint.ReducedFunctional`. 
+ :class:`~.Ensemble` + :class:`~.EnsembleFunction` + :class:`~.EnsembleCofunction` + :class:`~.EnsembleAdjVec` + :class:`~.EnsembleReduceReducedFunctional` + :class:`~.EnsembleBcastReducedFunctional` + :class:`~.EnsembleTransformReducedFunctional` + :class:`~.EnsembleAllgatherReducedFunctional` + """ + def __init__(self, functional: OverloadedType, + control: Control | list[Control], + rfs: list[AbstractReducedFunctional], + gather_rf: AbstractReducedFunctional | None = None, + ensemble: Ensemble | None = None): + self._local_rfs = Enlist(rfs) + self._controls = Enlist(control) + self.functional = functional + self._ensemble = ensemble + self._gather_rf = gather_rf + + # total number of ensemble components + local_size = len(self._local_rfs) + global_size = self.ensemble.allreduce(local_size) + + # Case 1: Standard summation reduction + # ------- ----------- -------- + # outer_controls -> | bcast | -> ensemble_controls -> | transform | -> ensemble_functional -> | reduce | -> outer_functional + # ------- ----------- -------- + # + # Case 2: Non-summation reduction requiring gather + # ------- ----------- -------- ----------- + # outer_controls -> | bcast | -> ensemble_controls -> | transform | -> ensemble_functional -> | gather | -> gather_functional -> | gather_rf | -> outer_functional + # ------- ----------- -------- ----------- + + ensemble_types = (EnsembleFunctionBase, EnsembleAdjVec) + + # Do we need to broadcast the controls? + is_ensemble_control = set(isinstance(c.control, ensemble_types) + for c in self.controls) + if len(is_ensemble_control) != 1: + raise TypeError( + "Either all or none of the controls must be ensemble types") + + if is_ensemble_control.pop(): + self._input_op = 'none' else: - raise NotImplementedError("This type of functional is not supported.") - return total_tlm + self._input_op = 'bcast' + + # Do we need to reduce or gather the functional? 
+ if isinstance(functional, ensemble_types): + self._output_op = 'none' + elif gather_rf is None: + self._output_op = 'reduce' + else: + self._output_op = 'gather' + + # build ensemble_controls + if self._input_op == 'bcast': + ensemble_controls = [ + Control(_make_ensemble_obj( + [control.control for _ in range(local_size)], + ensemble=self.ensemble)) + for control in self.controls + ] + elif self._input_op == 'none': + ensemble_controls = self.controls + + # build ensemble_functional + if self._output_op in ('reduce', 'gather'): + ensemble_functional = _make_ensemble_obj( + [rf.functional for rf in self._local_rfs], + ensemble=self.ensemble) + elif self._output_op == 'none': + ensemble_functional = self.functional + # build the transform + self._ensemble_transform = EnsembleTransformReducedFunctional( + ensemble_functional, + self.controls.delist(ensemble_controls), + self._local_rfs) + + # build the input operation + if self._input_op == 'bcast': + # controls are Functions or AdjFloats, so need to bcast to + # EnsembleFunctions or EnsembleAdjVecs for EnsembleTransform input. + self._ensemble_bcast = [ + EnsembleBcastReducedFunctional( + ec.control._ad_init_zero(), + Control(c.control._ad_init_zero())) + for ec, c in zip(ensemble_controls, self.controls)] + + # build the output operation + if self._output_op == 'reduce': + # functional is Function or AdjFloat, so need to reduce from + # EnsembleFunction or EnsembleAdjVec from EnsembleTransform output. + self._ensemble_reduce = EnsembleReduceReducedFunctional( + functional._ad_init_zero(), + Control(ensemble_functional._ad_init_zero())) + + if self._output_op == 'gather': + # Need to gather all components of the ensemble_functional onto + # each rank to pipe through gather_rf before returning result. 
+ # check gather takes right type of arguments + if len(gather_rf.controls) != global_size: + raise ValueError("gather_rf must have one control for" + " each component on all ensemble members") + + # check gather takes correct type of arguments + gather_type = type(self._local_rfs[0].functional) + if not all(isinstance(c.control, gather_type) + for c in gather_rf.controls): + raise ValueError( + "gather_rf.controls must match types of rf.functional") + + # Now make the massive output type for the gather + gather_functional = _make_ensemble_obj( + [c.control._ad_init_zero() for c in gather_rf.controls], + ensemble=ensemble) + + self._ensemble_gather = EnsembleAllgatherReducedFunctional( + gather_functional, + Control(ensemble_functional._ad_init_zero())) + + @property + def controls(self): + return self._controls + + @property + def ensemble(self): + """The :class:`.Ensemble` that the reduced functional is + defined over.""" + return self._ensemble + + @no_annotations + def __call__(self, values): + vals = Enlist(values) + for c, v in zip(self.controls, vals): + c.update(v) + + if self._input_op == 'bcast': + vals = [bcast(val) + for bcast, val in zip(self._ensemble_bcast, vals)] + + J = self._ensemble_transform(vals) + + if self._output_op == 'reduce': + # just sum the transform results + J = self._ensemble_reduce(J) + elif self._output_op == 'gather': + # pipe the gathered transform results through the gather_rf + gathered_Js = self._ensemble_gather(J) + J = self._gather_rf(_local_subs(gathered_Js)) + + return J + + @no_annotations + def derivative(self, adj_input=1.0, apply_riesz=False): + if self._output_op == 'reduce': + # just broadcast the adj_input + adj_input = self._ensemble_reduce.derivative(adj_input, + apply_riesz=False) + + elif self._output_op == 'gather': + # pipe the adj_input through the gather_rf before broadcasting + local_input = self._gather_rf.derivative(adj_input, + apply_riesz=False) + + global_input = _make_ensemble_obj(local_input, 
self.ensemble) + + adj_input = self._ensemble_gather.derivative(global_input, + apply_riesz=False) + + transform_riesz = apply_riesz if self._input_op == 'none' else False + + dJ = self._ensemble_transform.derivative( + adj_input=adj_input, apply_riesz=transform_riesz) + + if self._input_op == 'bcast': + dJ = self.controls.delist( + [bcast.derivative(adj_input=dj, apply_riesz=apply_riesz) + for bcast, dj in zip(self._ensemble_bcast, Enlist(dJ))]) + + return dJ + + @no_annotations + def tlm(self, m_dot): + if self._input_op == 'bcast': + m_dot = [bcast(md) + for bcast, md in zip(self._ensemble_bcast, Enlist(m_dot))] + + tlv = self._ensemble_transform.tlm(m_dot) + + if self._output_op == 'reduce': + # just sum the transform results + tlv = self._ensemble_reduce.tlm(tlv) + elif self._output_op == 'gather': + # pipe the gathered transform results through the gather_rf + gathered_tlvs = self._ensemble_gather.tlm(tlv) + tlv = self._gather_rf.tlm(_local_subs(gathered_tlvs)) + + return tlv + + @no_annotations def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False): - """The Hessian is not yet implemented for ensemble reduced functional. 
+ if evaluate_tlm: + self.tlm(m_dot) + + hkwargs = {'m_dot': None, 'evaluate_tlm': False} + + if self._output_op == 'reduce': + # just broadcast the hessian_input + hessian_input = self._ensemble_reduce.hessian( + **hkwargs, hessian_input=hessian_input, apply_riesz=False) + + elif self._output_op == 'gather': + # pipe the hessian_input through the gather_rf before broadcasting + local_input = self._gather_rf.hessian( + **hkwargs, hessian_input=hessian_input, apply_riesz=False) + + global_input = _make_ensemble_obj(local_input, self.ensemble) + + hessian_input = self._ensemble_gather.hessian( + **hkwargs, hessian_input=global_input, apply_riesz=False) + + transform_riesz = apply_riesz if self._input_op == 'none' else False + + hessian = self._ensemble_transform.hessian( + **hkwargs, hessian_input=hessian_input, apply_riesz=transform_riesz) + + if self._input_op == 'bcast': + hessian = self.controls.delist( + [bcast.hessian(**hkwargs, hessian_input=hess, apply_riesz=apply_riesz) + for bcast, hess in zip(self._ensemble_bcast, Enlist(hessian))]) - Raises: - NotImplementedError: This method is not yet implemented for ensemble reduced functional. 
- """ - raise NotImplementedError("Hessian is not yet implemented for ensemble reduced functional.") + return hessian diff --git a/tests/firedrake/adjoint/test_ensemble_reduced_functional.py b/tests/firedrake/adjoint/test_ensemble_reduced_functional.py index bd654ed52d..a15ef041ba 100644 --- a/tests/firedrake/adjoint/test_ensemble_reduced_functional.py +++ b/tests/firedrake/adjoint/test_ensemble_reduced_functional.py @@ -1,8 +1,11 @@ from firedrake import * from firedrake.adjoint import * +from firedrake.adjoint.ensemble_adjvec import EnsembleAdjVec from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy import pytest from numpy.testing import assert_allclose +from numpy import mean +from pytest_mpi.parallel_assert import parallel_assert @pytest.fixture(autouse=True) @@ -10,132 +13,978 @@ def autouse_set_test_tape(set_test_tape): pass -@pytest.mark.parallel(nprocs=4) -@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -def test_verification(): - ensemble = Ensemble(COMM_WORLD, 2) - size = ensemble.ensemble_comm.size - mesh = UnitSquareMesh(4, 4, comm=ensemble.comm) +@pytest.mark.parallel(nprocs=[1, 2, 3, 6]) +@pytest.mark.skipcomplex +def test_ensemble_bcast_float(): + ensemble = Ensemble(COMM_WORLD, 1) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_floats = 6 + nlocal_floats = nglobal_floats // size + + c = AdjFloat(0.0) + J = EnsembleAdjVec( + [AdjFloat(0.0) for _ in range(nlocal_floats)], ensemble) + + Jhat = EnsembleBcastReducedFunctional(J, Control(c)) + + # check the control is broadcast to all ranks + eps = 1e-12 + + x = AdjFloat(3.0) + Jx = Jhat(x) + + expect = x + match_local = all((Ji - expect) < eps for Ji in Jx.subvec) + + parallel_assert( + match_local, + msg=f"Broadcast AdjFloats {Jx.subvec} do not match expected value {expect}." + ) + + # Check the adjoint is reduced back to all ranks. + # Because the functional is an array we need to + # pass an adj_input of an array. 
+ + offset = rank*nlocal_floats + adj_input = EnsembleAdjVec( + [AdjFloat(offset + i + 1.0) for i in range(nlocal_floats)], + ensemble=ensemble) + + expect = AdjFloat(sum(i+1.0 for i in range(nglobal_floats))) + + dJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + match_local = dJ - expect < eps + + parallel_assert( + match_local, + msg=f"Broadcast derivative {dJ} does not match" + f" expected value {expect}." + ) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6]) +@pytest.mark.skipcomplex +def test_ensemble_bcast_function(): + nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_funcs = 6 + nlocal_funcs = nglobal_funcs // size + + mesh = UnitIntervalMesh(12, comm=ensemble.comm) + R = FunctionSpace(mesh, "R", 0) - x = function.Function(R, val=1.0) - J = assemble(x * x * dx(domain=mesh)) - rf = EnsembleReducedFunctional(J, Control(x), ensemble) - ensemble_J = rf(x) - assert_allclose(ensemble_J, size, rtol=1e-12) - dJdm = rf.derivative() - assert_allclose(dJdm.dat.data_ro, 2.0 * size, rtol=1e-12) - assert taylor_test(rf, x, Function(R, val=0.1)) > 1.9 + Re = EnsembleFunctionSpace( + [R for _ in range(nlocal_funcs)], ensemble) + c = Function(R).assign(1.0) + J = EnsembleFunction(Re) + + Jhat = EnsembleBcastReducedFunctional(J, Control(c)) + + # check the control is broadcast to all ranks + eps = 1e-12 + + x = Function(R).assign(3.0) + Jx = Jhat(x) + + expect = x + match_local = all(errornorm(Ji, expect) < eps + for Ji in Jx.subfunctions) + + parallel_assert( + match_local, + msg=f"Broadcast Functions do not match on rank {rank}" + ) + + # Check the adjoint is reduced back to all ranks. + # Because the functional is an EnsembleFunction we + # need to pass an adj_input of an EnsembleCofunction. 
+ + adj_input = EnsembleFunction(Re) + offset = rank*nlocal_funcs + for i, adji in enumerate(adj_input.subfunctions): + adji.assign(offset + i + 1.0) + adj_input = adj_input.riesz_representation() + + expect = Function(R).assign(sum(i+1.0 for i in range(nglobal_funcs))) + + dJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + match_local = errornorm(dJ, expect) < eps + + parallel_assert( + match_local, + msg=f"Broadcast derivative {dJ.dat.data[:]} does not match" + f" expected value {expect.dat.data[:]}." + ) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 6]) +@pytest.mark.skipcomplex +def test_ensemble_reduction_float(): + ensemble = Ensemble(COMM_WORLD, 1) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_floats = 6 + nlocal_floats = nglobal_floats // size + + control = EnsembleAdjVec( + [AdjFloat(0.0) for _ in range(nlocal_floats)], + ensemble=ensemble) + J = AdjFloat(0.0) + + Jhat = EnsembleReduceReducedFunctional(J, Control(control)) + + # check the control is reduced to all ranks + eps = 1e-12 + + offset = rank*nlocal_floats + x = EnsembleAdjVec( + [AdjFloat(offset + i + 1.0) for i in range(nlocal_floats)], + ensemble=ensemble) + + Jx = Jhat(x) + + expect = AdjFloat(sum(i+1.0 for i in range(nglobal_floats))) + match_local = Jx - expect < eps + + parallel_assert( + match_local, + msg=f"Reduced AdjFloat {Jx} does not match" + f" expected value {expect}" + ) + + # TLM + tlmx = Jhat.tlm(x) + + match_local = tlmx - expect < eps + + parallel_assert( + match_local, + msg=f"Reduced TLM AdjFloat {tlmx} does not match" + f" expected value {expect}" + ) + + # Check the adjoint is broadcast back to all ranks. + # Because the functional is a Function we need to + # pass an adj_input of an Cofunction. 
+
+    adj_value = 20.0
+    adj_input = AdjFloat(adj_value)
+
+    expect = AdjFloat(adj_value)
+    dJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True)
+
+    match_local = all((dJi - expect) < eps for dJi in dJ.subvec)
+
+    parallel_assert(
+        match_local,
+        msg=f"Reduced derivatives {dJ} do not match expected value {expect}."
+    )
+
+    for i, dj in enumerate(dJ.subvec):
+        dj *= (i+2)*0.3
+    taylor = taylor_to_dict(Jhat, x, dJ)
+
+    # derivative and hessian should be "exact"
+    assert mean(taylor['R0']['Rate']) > 0.95
+    assert all(r < 1e-14 for r in taylor['R1']['Residual'])
+    assert all(r < 1e-14 for r in taylor['R2']['Residual'])
+
+
+@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6])
+@pytest.mark.skipcomplex
+def test_ensemble_reduction_function():
+    nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1
+    ensemble = Ensemble(COMM_WORLD, nspatial_ranks)
+
+    rank = ensemble.ensemble_rank
+    size = ensemble.ensemble_size
+
+    nglobal_funcs = 6
+    nlocal_funcs = nglobal_funcs // size
+
+    mesh = UnitIntervalMesh(12, comm=ensemble.comm)
-@pytest.mark.parallel(nprocs=4)
-@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
-def test_verification_gather_functional_adjfloat_evaluation():
-    ensemble = Ensemble(COMM_WORLD, 2)
-    rank = ensemble.ensemble_comm.rank
-    mesh = UnitSquareMesh(4, 4, comm=ensemble.comm)
     R = FunctionSpace(mesh, "R", 0)
-    x = function.Function(R, val=rank+1)
-    J = assemble(x * x * dx(domain=mesh))
-    a = AdjFloat(1.0)
-    b = AdjFloat(1.0)
-    Jg_m = [Control(a), Control(b)]
-    Jg = ReducedFunctional(a**2 + b**2, Jg_m)
-    rf = EnsembleReducedFunctional(J, Control(x), ensemble,
-                                   scatter_control=False,
-                                   gather_functional=Jg)
-    ensemble_J = rf(x)
-    dJdm = rf.derivative()
-    assert_allclose(ensemble_J, 1.0**4+2.0**4, rtol=1e-12)
-    assert_allclose(dJdm.dat.data_ro, 4*(rank+1)**3, rtol=1e-12)
-
-
-@pytest.mark.parallel(nprocs=4)
-@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
-@pytest.mark.xfail(reason="Taylor's test fails because the inner product \
-                           between the perturbation and gradient is not allreduced \
-                           for `scatter_control=False`.")
-def test_verification_gather_functional_adjfloat_taylor():
-    ensemble = Ensemble(COMM_WORLD, 2)
-    rank = ensemble.ensemble_comm.rank
-    mesh = UnitSquareMesh(4, 4, comm=ensemble.comm)
+    Re = EnsembleFunctionSpace(
+        [R for _ in range(nlocal_funcs)], ensemble)
+
+    c = EnsembleFunction(Re)
+    J = Function(R)
+
+    Jhat = EnsembleReduceReducedFunctional(J, Control(c))
+
+    # check the control is reduced to all ranks
+    eps = 1e-12
+
+    x = EnsembleFunction(Re)
+
+    offset = rank*nlocal_funcs
+    for i, xi in enumerate(x.subfunctions):
+        xi.assign(offset + i + 1.0)
+
+    Jx = Jhat(x)
+
+    expect = Function(R).assign(sum(i+1.0 for i in range(nglobal_funcs)))
+    match_local = errornorm(Jx, expect) < eps
+
+    parallel_assert(
+        match_local,
+        msg=f"Reduced Function {Jx.dat.data[:]} does not match"
+        f" expected value {expect.dat.data[:]}"
+    )
+
+    # Check the adjoint is broadcast back to all ranks.
+    # Because the functional is a Function we need to
+    # pass an adj_input of a Cofunction.
+
+    adj_value = 20.0
+    adj_input = (Function(R).assign(adj_value)).riesz_representation()
+
+    expect = Function(R).assign(adj_value)
+    dJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True)
+
+    match_local = all(errornorm(dJi, expect) < eps
+                      for dJi in dJ.subfunctions)
+
+    parallel_assert(
+        match_local,
+        msg=f"Reduced derivatives {[dJi.dat.data[:] for dJi in dJ.subfunctions]}"
+        f" do not match expected value {expect.dat.data[:]}."
+ ) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6]) +@pytest.mark.skipcomplex +def test_ensemble_transform_float(): + nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_funcs = 6 + nlocal_funcs = nglobal_funcs // size + + mesh = UnitIntervalMesh(12, comm=ensemble.comm) + R = FunctionSpace(mesh, "R", 0) - x = function.Function(R, val=rank+1) - J = assemble(x * x * dx(domain=mesh)) - a = AdjFloat(1.0) - b = AdjFloat(1.0) - Jg_m = [Control(a), Control(b)] - Jg = ReducedFunctional(a**2 + b**2, Jg_m) - rf = EnsembleReducedFunctional(J, Control(x), ensemble, - scatter_control=False, - gather_functional=Jg) - assert taylor_test(rf, x, Function(R, val=0.1)) > 1.9 - - -@pytest.mark.parallel(nprocs=4) -@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -def test_verification_gather_functional_Function_evaluation(): - ensemble = Ensemble(COMM_WORLD, 2) - rank = ensemble.ensemble_comm.rank - mesh = UnitSquareMesh(4, 4, comm=ensemble.comm) + Re = EnsembleFunctionSpace( + [R for _ in range(nlocal_funcs)], ensemble) + + c = EnsembleFunction(Re) + + rfs = [] + J = [] + offset = rank*nlocal_funcs + for ci in c.subfunctions: + with set_working_tape() as tape: + Ji = assemble(ci*ci*dx) + J.append(Ji) + rfs.append(ReducedFunctional(Ji, Control(ci), tape=tape)) + + J = EnsembleAdjVec(J, ensemble) + + Jhat = EnsembleTransformReducedFunctional(J, Control(c), rfs) + + # check the control is reduced to all ranks + eps = 1e-12 + + x = EnsembleFunction(Re) + + for i, xi in enumerate(x.subfunctions): + xi.assign(offset + i + 1.0) + + # check + Jx = Jhat(x) + + expect = [rf(xi) for rf, xi in zip(rfs, x.subfunctions)] + + match_local = all((Ji - ei) < eps for Ji, ei in zip(Jx.subvec, expect)) + + parallel_assert( + match_local, + msg=f"Transformed results {Jx} do not match expected values {expect}" + ) + + # Check the adjoint matches on all 
slots. + # Because the functional is a list[AdjFloat] we need to + # pass an adj_input of a list[AdjFloat]. + + adj_input = EnsembleAdjVec( + [AdjFloat(offset + i + 1.0) + for i in range(nlocal_funcs)], + ensemble=ensemble) + + expect = EnsembleFunction(Re) + for rf, adji, ei in zip(rfs, adj_input.subvec, expect.subfunctions): + ei.assign(rf.derivative(adj_input=adji, apply_riesz=True)) + + dJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + match_local = all( + errornorm(dJi, ei) < eps + for dJi, ei in zip(dJ.subfunctions, expect.subfunctions)) + + parallel_assert( + match_local, + msg=f"Reduced derivatives {[dJi.dat.data[:] for dJi in dJ.subfunctions]}" + f" do not match expected value {[ei.dat.data[:] for ei in expect.subfunctions]}." + ) + + _ = Jhat.tlm(x) + _ = Jhat.hessian(x, hessian_input=adj_input) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6]) +@pytest.mark.skipcomplex +def test_ensemble_transform_float_two_controls(): + nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_funcs = 6 + nlocal_funcs = nglobal_funcs // size + + mesh = UnitIntervalMesh(12, comm=ensemble.comm) + R = FunctionSpace(mesh, "R", 0) - x = function.Function(R, val=rank+1) - J = Function(R).assign(x**2) - a = Function(R).assign(1.0) - b = Function(R).assign(1.0) - Jg_m = [Control(a), Control(b)] - Jg = assemble((a**2 + b**2)*dx) - Jghat = ReducedFunctional(Jg, Jg_m) - rf = EnsembleReducedFunctional(J, Control(x), ensemble, - scatter_control=False, - gather_functional=Jghat) - ensemble_J = rf(x) - dJdm = rf.derivative() - assert_allclose(ensemble_J, 1.0**4+2.0**4, rtol=1e-12) - assert_allclose(dJdm.dat.data_ro, 4*(rank+1)**3, rtol=1e-12) - - -@pytest.mark.parallel(nprocs=4) + Re = EnsembleFunctionSpace( + [R for _ in range(nlocal_funcs)], ensemble) + + c0 = EnsembleFunction(Re) + c1 = EnsembleFunction(Re) + + rfs = [] + J = [] + offset = 
rank*nlocal_funcs + for c0i, c1i in zip(c0.subfunctions, c1.subfunctions): + with set_working_tape() as tape: + Ji = assemble((c0i*c0i + c1i*c1i)*dx) + J.append(Ji) + rfs.append(ReducedFunctional( + Ji, [Control(c0i), Control(c1i)], tape=tape)) + + J = EnsembleAdjVec(J, ensemble) + + Jhat = EnsembleTransformReducedFunctional( + J, [Control(c0), Control(c1)], rfs) + + # check the control is reduced to all ranks + eps = 1e-12 + + x0 = EnsembleFunction(Re) + x1 = EnsembleFunction(Re) + + for i, (x0i, x1i) in enumerate(zip(x0.subfunctions, x1.subfunctions)): + x0i.assign(offset + i + 1.0) + x1i.assign(2*(offset + i + 1.0)) + + # check + Jx = Jhat([x0, x1]) + + expect = [rf([x0i, x1i]) + for rf, x0i, x1i in zip(rfs, x0.subfunctions, x1.subfunctions)] + + match_local = all((Ji - ei) < eps for Ji, ei in zip(Jx.subvec, expect)) + + parallel_assert( + match_local, + msg=f"Transformed results {Jx} do not match expected values {expect}" + ) + + # Check the adjoint matches on all slots. + # Because the functional is a AdjFloat we need to + # pass an adj_input of a list[AdjFloat]. + + adj_input = EnsembleAdjVec( + [AdjFloat(offset + i + 1.0) + for i in range(nlocal_funcs)], + ensemble=ensemble) + + expect0 = EnsembleFunction(Re) + expect1 = EnsembleFunction(Re) + for rf, adji, e0i, e1i in zip(rfs, adj_input.subvec, + expect0.subfunctions, + expect1.subfunctions): + e0, e1 = rf.derivative(adj_input=adji, apply_riesz=True) + e0i.assign(e0) + e1i.assign(e1) + + dJ0, dJ1 = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + match_local0 = all( + errornorm(dJi, ei) < eps + for dJi, ei in zip(dJ0.subfunctions, expect0.subfunctions)) + + match_local1 = all( + errornorm(dJi, ei) < eps + for dJi, ei in zip(dJ1.subfunctions, expect1.subfunctions)) + + parallel_assert( + match_local0, + msg=f"Reduced derivatives {[dJ0i.dat.data[:] for dJ0i in dJ0.subfunctions]}" + f" do not match expected value {[e0i.dat.data[:] for e0i in expect0.subfunctions]}." 
+ ) + + parallel_assert( + match_local1, + msg=f"Reduced derivatives {[dJ1i.dat.data[:] for dJ1i in dJ1.subfunctions]}" + f" do not match expected value {[e1i.dat.data[:] for e1i in expect1.subfunctions]}." + ) + + _ = Jhat.tlm([x0, x1]) + _ = Jhat.hessian([x0, x1], hessian_input=adj_input) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6]) +@pytest.mark.skipcomplex +def test_ensemble_transform_function(): + nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_funcs = 6 + nlocal_funcs = nglobal_funcs // size + + mesh = UnitIntervalMesh(12, comm=ensemble.comm) + + R = FunctionSpace(mesh, "R", 0) + Re = EnsembleFunctionSpace( + [R for _ in range(nlocal_funcs)], ensemble) + + c = EnsembleFunction(Re) + J = EnsembleFunction(Re) + + rfs = [] + offset = rank*nlocal_funcs + for i, (Ji, ci) in enumerate(zip(J.subfunctions, c.subfunctions)): + with set_working_tape() as tape: + Ji.assign(ci) + Ji += 2*(offset + i + 1.0) + rfs.append(ReducedFunctional(Ji, Control(ci), tape=tape)) + + Jhat = EnsembleTransformReducedFunctional(J, Control(c), rfs) + + # check the control is reduced to all ranks + eps = 1e-12 + + x = EnsembleFunction(Re) + + for i, xi in enumerate(x.subfunctions): + xi.assign(offset + i + 1.0) + + # check + Jx = Jhat(x) + + expect = EnsembleFunction(Re) + for rf, xi, ei in zip(rfs, x.subfunctions, + expect.subfunctions): + ei.assign(rf(xi)) + + match_local = all( + errornorm(Ji, ei) < eps + for Ji, ei in zip(Jx.subfunctions, expect.subfunctions)) + + parallel_assert( + match_local, + msg=f"Transformed Functions {[Ji.dat.data[:] for Ji in Jx.subfunctions]}" + f" do not match expected value {[ei.dat.data[:] for ei in expect.subfunctions]}" + ) + + # Check the adjoint matches on all slots. + # Because the functional is a Function we need to + # pass an adj_input of an Cofunction. 
+ + adj_input = EnsembleFunction(Re) + for i, adj in enumerate(adj_input.subfunctions): + adj.assign(offset + i + 1.0) + + adj_input = adj_input.riesz_representation() + + expect = EnsembleFunction(Re) + for rf, adji, ei in zip(rfs, adj_input.subfunctions, + expect.subfunctions): + ei.assign(rf.derivative(adj_input=adji, apply_riesz=True)) + + dJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + match_local = all( + errornorm(dJi, ei) < eps + for dJi, ei in zip(dJ.subfunctions, expect.subfunctions)) + + parallel_assert( + match_local, + msg=f"Reduced derivatives {[dJi.dat.data[:] for dJi in dJ.subfunctions]}" + f" do not match expected value {[ei.dat.data[:] for ei in expect.subfunctions]}." + ) + + _ = Jhat.tlm(x) + _ = Jhat.hessian(x, hessian_input=adj_input) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6]) +@pytest.mark.skipcomplex +def test_ensemble_transform_function_two_controls(): + nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_funcs = 6 + nlocal_funcs = nglobal_funcs // size + + mesh = UnitIntervalMesh(12, comm=ensemble.comm) + + R = FunctionSpace(mesh, "R", 0) + Re = EnsembleFunctionSpace( + [R for _ in range(nlocal_funcs)], ensemble) + + c0 = EnsembleFunction(Re) + c1 = EnsembleFunction(Re) + J = EnsembleFunction(Re) + + rfs = [] + offset = rank*nlocal_funcs + for i, (Ji, c0i, c1i) in enumerate(zip(J.subfunctions, + c0.subfunctions, + c1.subfunctions)): + with set_working_tape() as tape: + Ji.assign(c0i + c1i) + rfs.append(ReducedFunctional( + Ji, [Control(c0i), Control(c1i)], tape=tape)) + + Jhat = EnsembleTransformReducedFunctional(J, [Control(c0), Control(c1)], rfs) + + # check the control is reduced to all ranks + eps = 1e-12 + + x0 = EnsembleFunction(Re) + x1 = EnsembleFunction(Re) + + for i, (x0i, x1i) in enumerate(zip(x0.subfunctions, + x1.subfunctions)): + x0i.assign(offset + i + 1.0) + 
x1i.assign(2*(offset + i + 1.0)) + + Jx = Jhat([x0, x1]) + + expect = EnsembleFunction(Re) + for rf, x0i, x1i, ei in zip(rfs, x0.subfunctions, + x1.subfunctions, + expect.subfunctions): + ei.assign(rf([x0i, x1i])) + + match_local = all( + errornorm(Ji, ei) < eps + for Ji, ei in zip(Jx.subfunctions, + expect.subfunctions)) + + parallel_assert( + match_local, + msg=f"Transformed Functions {[Ji.dat.data[:] for Ji in Jx.subfunctions]}" + f" do not match expected value {[ei.dat.data[:] for ei in expect.subfunctions]}" + ) + + # Check the adjoint matches on all slots. + # Because the functional is a Function we need to + # pass an adj_input of an Cofunction. + + adj_input = EnsembleFunction(Re) + for i, adj in enumerate(adj_input.subfunctions): + adj.assign(offset + i + 1.0) + + adj_input = adj_input.riesz_representation() + + expect0 = EnsembleFunction(Re) + expect1 = EnsembleFunction(Re) + for rf, adji, e0i, e1i in zip(rfs, adj_input.subfunctions, + expect0.subfunctions, + expect1.subfunctions): + e0, e1 = rf.derivative(adj_input=adji, apply_riesz=True) + e0i.assign(e0) + e1i.assign(e1) + + dJ0, dJ1 = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + match_local0 = all( + errornorm(dJi, ei) < eps + for dJi, ei in zip(dJ0.subfunctions, expect0.subfunctions)) + + match_local1 = all( + errornorm(dJi, ei) < eps + for dJi, ei in zip(dJ1.subfunctions, expect1.subfunctions)) + + parallel_assert( + match_local0, + msg=f"Reduced derivatives {[dJ0i.dat.data[:] for dJ0i in dJ0.subfunctions]}" + f" do not match expected value {[e0i.dat.data[:] for e0i in expect0.subfunctions]}." + ) + + parallel_assert( + match_local1, + msg=f"Reduced derivatives {[dJ1i.dat.data[:] for dJ1i in dJ1.subfunctions]}" + f" do not match expected value {[e1i.dat.data[:] for e1i in expect1.subfunctions]}." 
+ ) + + _ = Jhat.tlm([x0, x1]) + _ = Jhat.hessian([x0, x1], hessian_input=adj_input) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6]) @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -@pytest.mark.xfail(reason="Taylor's test fails because the inner product \ - between the perturbation and gradient is not allreduced \ - for `scatter_control=False`.") -def test_verification_gather_functional_Function_taylor(): - ensemble = Ensemble(COMM_WORLD, 2) - rank = ensemble.ensemble_comm.rank - mesh = UnitSquareMesh(4, 4, comm=ensemble.comm) +def test_ensemble_rf_function_to_float(): + nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_rfs = 6 + nlocal_rfs = nglobal_rfs // size + + mesh = UnitIntervalMesh(12, comm=ensemble.comm) R = FunctionSpace(mesh, "R", 0) - x = function.Function(R, val=rank+1) - J = Function(R).assign(x**2) - a = Function(R).assign(1.0) - b = Function(R).assign(1.0) - Jg_m = [Control(a), Control(b)] - Jg = assemble((a**2 + b**2)*dx) - Jghat = ReducedFunctional(Jg, Jg_m) - rf = EnsembleReducedFunctional(J, Control(x), ensemble, - scatter_control=False, - gather_functional=Jghat) - assert taylor_test(rf, x, Function(R, val=0.1)) > 1.9 - - -@pytest.mark.parallel(nprocs=3) + + control = Control(Function(R)) + J = AdjFloat(0.) 
+ + rfs = [] + offset = rank*nlocal_rfs + for i in range(nlocal_rfs): + c = Function(R) + weight = (offset + i + 1.0) + with set_working_tape() as tape: + Ji = weight*assemble((c**4)*dx) + rfs.append( + ReducedFunctional(Ji, Control(c), tape=tape)) + + Jhat = EnsembleReducedFunctional( + J, control, rfs, ensemble=ensemble) + + sum_weights = sum((i + 1.0) for i in range(nglobal_rfs)) + + xval = 3.0 + Jexpect = (xval**4)*sum_weights + + x = Function(R).assign(xval) + J = Jhat(x) + assert_allclose(J, Jexpect, rtol=1e-12) + + adj_input = AdjFloat(4.0) + edJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + assert_allclose(edJ.dat.data_ro, adj_input*(4*xval**3)*sum_weights, rtol=1e-12) + + dy = Function(R, val=0.1) + assert taylor_test(Jhat, x, dy) > 1.95 + + _ = Jhat.tlm(x) + _ = Jhat.hessian(x) + + taylor = taylor_to_dict(Jhat, x, dy) + + R0 = mean(taylor['R0']['Rate']) + R1 = mean(taylor['R1']['Rate']) + R2 = mean(taylor['R2']['Rate']) + + assert R0 > 0.95 + assert R1 > 1.95 + assert R2 > 2.95 + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4, 6]) @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -def test_minimise(): - # Optimisation test using a list of controls. - # This test is equivalent to the one found at: - # https://github.com/firedrakeproject/firedrake/blob/master/tests/firedrake/adjoint/test_optimisation.py#L92 - # In this test, the functional is the result of an ensemble allreduce operation. +def test_ensemble_rf_efunction_to_float(): + nspatial_ranks = 2 if COMM_WORLD.size == 4 else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + rank = ensemble.ensemble_rank + size = ensemble.ensemble_size + + nglobal_funcs = 6 + nlocal_funcs = nglobal_funcs // size + + mesh = UnitIntervalMesh(12, comm=ensemble.comm) + R = FunctionSpace(mesh, "R", 0) + Re = EnsembleFunctionSpace( + [R for _ in range(nlocal_funcs)], ensemble) + + control = Control(EnsembleFunction(Re)) + J = AdjFloat(0.) 
+ + rfs = [] + for i in range(nlocal_funcs): + ci = Function(R) + with set_working_tape() as tape: + Ji = assemble((ci**4)*dx) + rfs.append( + ReducedFunctional(Ji, Control(ci), tape=tape)) + + Jhat = EnsembleReducedFunctional( + J, control, rfs, ensemble=ensemble) + + x = EnsembleFunction(Re) + + offset = rank*nlocal_funcs + for i, xi in enumerate(x.subfunctions): + xi.assign(offset + i + 1.0) + + eps = 1e-12 + + J = Jhat(x) + + expect = sum(w**4 for w in range(1, nglobal_funcs+1)) + + parallel_assert( + (J - expect) < eps, + msg=f"Functional {J} does not match expected value {expect}." + ) + + adj_input = 3.0 + dJ = Jhat.derivative(adj_input=adj_input, apply_riesz=True) + + expect = EnsembleFunction(Re) + for i, ei in enumerate(expect.subfunctions): + ei.assign(adj_input*4*(offset + i + 1.0)**3) + + match_local = all( + errornorm(dJi, ei) < eps + for dJi, ei in zip(dJ.subfunctions, expect.subfunctions)) + + parallel_assert( + match_local, + msg=f"Derivatives {[dJi.dat.data[:] for dJi in dJ.subfunctions]}" + f" do not match expected values {[ei.dat.data[:] for ei in expect.subfunctions]}." 
+ ) + + dy = EnsembleFunction(Re) + + for i, dyi in enumerate(dy.subfunctions): + dyi.assign(0.1*(-0.5*offset - (i + 1.0))) + + assert taylor_test(Jhat, x, dy) > 1.95 + + _ = Jhat.tlm(x) + _ = Jhat.hessian(x) + + taylor = taylor_to_dict(Jhat, x, dy) + + assert min(taylor['R0']['Rate']) > 0.95, taylor['R0'] + assert min(taylor['R1']['Rate']) > 1.95, taylor['R1'] + assert min(taylor['R2']['Rate']) > 2.95, taylor['R2'] + + +@pytest.mark.parallel(nprocs=[1, 2, 4]) +@pytest.mark.skipcomplex +@pytest.mark.parametrize("control_type", ["Function", "AdjFloat"]) +def test_ensemble_rf_gather_local_control(control_type): ensemble = Ensemble(COMM_WORLD, 1) - mesh = UnitSquareMesh(4, 4, comm=ensemble.comm) + rank = ensemble.ensemble_rank + + nglobal_spaces = 4 + nlocal_spaces = nglobal_spaces//ensemble.ensemble_size + + if control_type == "Function": + mesh = UnitIntervalMesh(1, comm=ensemble.comm) + R = FunctionSpace(mesh, "DG", 0) + + def new_val(expr): + return Function(R).interpolate(expr) # assign can't do nonlinear exprs + + elif control_type == "AdjFloat": + + def new_val(expr): + return AdjFloat(expr) + + m = new_val(0.) + + # transform rfs for all ensemble components + Jlocals = [] + for i in range(nlocal_spaces): + idv = rank*nlocal_spaces + 1. + i + x = new_val(0.) + with set_working_tape() as tape: + Jexpr = idv*x**2 + J = assemble(Jexpr*dx) if control_type == "Function" else Jexpr + Jlocals.append(ReducedFunctional(J, Control(x), tape=tape)) + + # rf for the gathered outputs from each ensemble component + allgather_controls = [Control(AdjFloat(0.)) for _ in range(nglobal_spaces)] + with set_working_tape() as tape: + Jg = sum(c.control for c in allgather_controls)**3 + Jgather = ReducedFunctional(Jg, allgather_controls, tape=tape) + + # sanity check that the gather reduced functional is valid + mg = [AdjFloat(i+1.) 
for i in range(nglobal_spaces)] + hg = [AdjFloat(-2.*(i+1)*(-1)**i) for i in range(nglobal_spaces)] + + grf_taylor = taylor_to_dict(Jgather, mg, hg) + + parallel_assert(min(grf_taylor['R0']['Rate']) > 0.95, msg=str(grf_taylor['R0'])) + parallel_assert(min(grf_taylor['R1']['Rate']) > 1.95, msg=str(grf_taylor['R1'])) + parallel_assert(min(grf_taylor['R2']['Rate']) > 2.95, msg=str(grf_taylor['R2'])) + + # Now create the full reduced functional over the ensemble + Jhat = EnsembleReducedFunctional(AdjFloat(0.), Control(m), Jlocals, + gather_rf=Jgather, ensemble=ensemble) + + mval = 13. + hval = -6. + + # Is the re-evaluation correct? + mexpected = sum((i+1)*mval**2 for i in range(nglobal_spaces))**3 + hexpected = sum((i+1)*hval**2 for i in range(nglobal_spaces))**3 + + m = new_val(mval) + h = new_val(hval) + parallel_assert(abs(Jhat(h) - hexpected) < 1e-12) + parallel_assert(abs(Jhat(m) - mexpected) < 1e-12) + + # Is the TLM correct? + assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + + # Now check the derivative and hessian + taylor = taylor_to_dict(Jhat, m, h) + + parallel_assert(min(taylor['R0']['Rate']) > 0.95, msg=str(taylor['R0'])) + parallel_assert(min(taylor['R1']['Rate']) > 1.95, msg=str(taylor['R1'])) + parallel_assert(min(taylor['R2']['Rate']) > 2.95, msg=str(taylor['R2'])) + + +@pytest.mark.parallel(nprocs=[1, 2, 4]) +@pytest.mark.skipcomplex +@pytest.mark.parametrize("control_type", ["EnsembleFunction", "EnsembleAdjVec"]) +def test_ensemble_rf_gather_global_control(control_type): + ensemble = Ensemble(COMM_WORLD, 1) + rank = ensemble.ensemble_rank + + nglobal_spaces = 4 + nlocal_spaces = nglobal_spaces//ensemble.ensemble_size + + if control_type == "EnsembleFunction": + mesh = UnitIntervalMesh(1, comm=ensemble.comm) + R = FunctionSpace(mesh, "DG", 0) + V = EnsembleFunctionSpace([R for _ in range(nlocal_spaces)], ensemble) + m = EnsembleFunction(V) + + elif control_type == "EnsembleAdjVec": + m = EnsembleAdjVec([0. 
for _ in range(nlocal_spaces)], ensemble) + + # transform rfs for all ensemble components + Jlocals = [] + offset = rank*nlocal_spaces + 1. + for i in range(nlocal_spaces): + idv = offset + i + x = Function(R) if control_type == "EnsembleFunction" else AdjFloat(0.) + with set_working_tape() as tape: + Jexpr = idv*x**2 + J = assemble(Jexpr*dx) if control_type == "EnsembleFunction" else Jexpr + Jlocals.append(ReducedFunctional(J, Control(x), tape=tape)) + + # rf for the gathered outputs from each ensemble component + allgather_controls = [Control(AdjFloat(0.)) for _ in range(nglobal_spaces)] + with set_working_tape() as tape: + Jg = sum(c.control for c in allgather_controls)**3 + Jgather = ReducedFunctional(Jg, allgather_controls, tape=tape) + + # sanity check that the gather reduced functional is valid + mg = [AdjFloat(i+1.) for i in range(nglobal_spaces)] + hg = [AdjFloat(-2.*(i+1)*(-1)**i) for i in range(nglobal_spaces)] + + grf_taylor = taylor_to_dict(Jgather, mg, hg) + + parallel_assert(min(grf_taylor['R0']['Rate']) > 0.95, msg=str(grf_taylor['R0'])) + parallel_assert(min(grf_taylor['R1']['Rate']) > 1.95, msg=str(grf_taylor['R1'])) + parallel_assert(min(grf_taylor['R2']['Rate']) > 2.95, msg=str(grf_taylor['R2'])) + + # Now create the full reduced functional over the ensemble + Jhat = EnsembleReducedFunctional(AdjFloat(0.), Control(m), Jlocals, + gather_rf=Jgather, ensemble=ensemble) + + # Is the re-evaluation correct? 
+ mexpected = sum((i+1)*(11.*(i+1))**2 for i in range(nglobal_spaces))**3 + hexpected = sum((i+1)*(-2.*(i+1))**2 for i in range(nglobal_spaces))**3 + + if control_type == "EnsembleFunction": + m = EnsembleFunction(V) + h = EnsembleFunction(V) + for i, (mi, hi) in enumerate(zip(m.subfunctions, h.subfunctions)): + mi.assign(11.*(offset+i)) + hi.assign(-2.*(offset+i)) + else: + m = EnsembleAdjVec([11.*(offset+i) for i in range(nlocal_spaces)], ensemble) + h = EnsembleAdjVec([-2.*(offset+i) for i in range(nlocal_spaces)], ensemble) + + parallel_assert(abs(Jhat(h) - hexpected) < 1e-12) + parallel_assert(abs(Jhat(m) - mexpected) < 1e-12) + + # Is the TLM correct? + assert taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.95 + + # Now check the derivative and hessian + taylor = taylor_to_dict(Jhat, m, h) + + parallel_assert(min(taylor['R0']['Rate']) > 0.95, msg=str(taylor['R0'])) + parallel_assert(min(taylor['R1']['Rate']) > 1.95, msg=str(taylor['R1'])) + parallel_assert(min(taylor['R2']['Rate']) > 2.95, msg=str(taylor['R2'])) + + +@pytest.mark.parallel(nprocs=[1, 3]) +@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done +def test_ensemble_rf_minimise(): + """ + Optimisation test using a list of controls. + This test is equivalent to the one found at: + https://github.com/firedrakeproject/firedrake/blob/master/tests/firedrake/adjoint/test_optimisation.py#L92 + In this test, the functional is the result of an ensemble allreduce operation. + """ + nspatial_ranks = 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + + size = ensemble.ensemble_size + + mesh = UnitIntervalMesh(1, comm=ensemble.comm) R = FunctionSpace(mesh, "R", 0) + + nglobal_rfs = 3 + nlocal_rfs = nglobal_rfs//size n = 2 - x = [Function(R) for i in range(n)] - c = [Control(xi) for xi in x] - # Rosenbrock function https://en.wikipedia.org/wiki/Rosenbrock_function - # with minimum at x = (1, 1, 1, ...) 
- with set_working_tape(): - f = 100*(x[1] - x[0]**2)**2 + (1 - x[0])**2 - J = assemble(f * dx(domain=mesh)) - rf = EnsembleReducedFunctional(J, c, ensemble) - rf_np = ReducedFunctionalNumPy(rf) + + rfs_local = [] + for _ in range(nlocal_rfs): + x = [Function(R) for _ in range(n)] + local_controls = [Control(xi) for xi in x] + with set_working_tape() as tape: + # Rosenbrock function https://en.wikipedia.org/wiki/Rosenbrock_function + # with minimum at x = (1, 1, 1, ...) + f = 100*(x[1] - x[0]**2)**2 + (1 - x[0])**2 + Jlocal = assemble(f*dx) + rfs_local.append( + ReducedFunctional(Jlocal, local_controls, tape=tape)) + + Jglobal = AdjFloat(0.) + controls_global = [Control(Function(R)) for _ in range(n)] + + Jhat = EnsembleReducedFunctional(Jglobal, controls_global, + rfs_local, ensemble=ensemble) + rf_np = ReducedFunctionalNumPy(Jhat) result = minimize(rf_np) + assert_allclose([float(xi) for xi in result], 1., rtol=1e-8) From 9e5ec000231402c8a670ecd983511979cc58e412 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 12 Mar 2026 15:35:32 +0000 Subject: [PATCH 4/4] update fwi demo with ensemblerf refactor --- .../full_waveform_inversion.py.rst | 52 ++++++++++++------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/demos/full_waveform_inversion/full_waveform_inversion.py.rst b/demos/full_waveform_inversion/full_waveform_inversion.py.rst index 06390072da..e47ae0b541 100644 --- a/demos/full_waveform_inversion/full_waveform_inversion.py.rst +++ b/demos/full_waveform_inversion/full_waveform_inversion.py.rst @@ -257,39 +257,51 @@ Next, the FWI problem is executed with the following steps: :align: center -To have the step 4, we need first to tape the forward problem. That is done by calling:: +To have the step 4, we need first to tape the forward problem. +That is done by calling :func:`~.pyadjoint.continue_annotation`. 

-    from firedrake.adjoint import *
-    continue_annotation()
-    get_working_tape().progress_bar = ProgressBar
 
-**Steps 2-3**: Solve the wave equation and compute the functional::
+**Steps 2-3**: Solve the wave equation and compute the functional.
+We create a ``ReducedFunctional`` for each source, which in our
+case means one per ensemble member. By creating a ``ReducedFunctional``
+per component that we are parallelising over (i.e. per source) -
+rather than one per ensemble member - we can change
+the ensemble parallel partition with minimal changes to the code::
+
+    from firedrake.adjoint import *
 
     f = Cofunction(V.dual())  # Wave equation forcing term.
     solver, u_np1, u_n, u_nm1 = wave_equation_solver(c_guess, f, dt, V)
     interpolate_receivers = interpolate(u_np1, V_r)
-    J_val = 0.0
-    for step in range(total_steps):
-        f.assign(ricker_wavelet(step * dt, frequency_peak) * q_s)
-        solver.solve()
-        u_nm1.assign(u_n)
-        u_n.assign(u_np1)
-        guess_receiver = assemble(interpolate_receivers)
-        misfit = guess_receiver - true_data_receivers[step]
-        J_val += 0.5 * assemble(inner(misfit, misfit) * dx)
-
-We now instantiate :class:`~.EnsembleReducedFunctional`::
-
-    J_hat = EnsembleReducedFunctional(J_val,
-                                      Control(c_guess, riesz_map="l2"),
-                                      my_ensemble)
+    continue_annotation()
+    J_val = 0.0
+    with set_working_tape() as tape:
+        for step in range(total_steps):
+            f.assign(ricker_wavelet(step * dt, frequency_peak) * q_s)
+            solver.solve()
+            u_nm1.assign(u_n)
+            u_n.assign(u_np1)
+            guess_receiver = assemble(interpolate_receivers)
+            misfit = guess_receiver - true_data_receivers[step]
+            J_val += 0.5 * assemble(inner(misfit, misfit) * dx)
+
+    control = Control(c_guess)
+    Jhat_local = ReducedFunctional(J_val, control, tape=tape)
+    tape.progress_bar = ProgressBar
+    pause_annotation()
+
+We now instantiate :class:`~.adjoint.ensemble_reduced_functional.EnsembleReducedFunctional`
+with the local ``ReducedFunctional`` for each source::
+
+    J_hat = EnsembleReducedFunctional(J_val, control, Jhat_local, ensemble=my_ensemble)
 
 which enables us to recompute :math:`J` and its gradient
 :math:`\nabla_{\mathtt{c\_guess}} J`, where the :math:`J_s` and its gradients
 :math:`\nabla_{\mathtt{c\_guess}} J_s` are computed in parallel based on the
 ``my_ensemble`` configuration.
 
-**Steps 4-6**: The instance of the :class:`~.EnsembleReducedFunctional`, named ``J_hat``,
+**Steps 4-6**: The instance of the :class:`~.adjoint.ensemble_reduced_functional.EnsembleReducedFunctional`, named ``J_hat``,
 is then passed as an argument to the ``minimize`` function. The default
 ``minimize`` function uses ``scipy.minimize``, and wraps the ``ReducedFunctional``
 in a ``ReducedFunctionalNumPy`` that handles transferring data between Firedrake
 and numpy data structures. However, because