From 865bac0843d7b22deb0dc74c9953de24a459fe26 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Fri, 10 Jan 2025 17:18:45 +0000 Subject: [PATCH 01/16] generate SOAR matrices from coordinate vector and set up AssembledMatrix & LinearSolver --- correlations/correlation_solves.py | 100 +++++++++++++++++++++++++++++ correlations/correlations.py | 92 ++++++++++++++++++++++++++ correlations/soar_mat.py | 30 +++++++++ 3 files changed, 222 insertions(+) create mode 100644 correlations/correlation_solves.py create mode 100644 correlations/correlations.py create mode 100644 correlations/soar_mat.py diff --git a/correlations/correlation_solves.py b/correlations/correlation_solves.py new file mode 100644 index 0000000..d2a3fef --- /dev/null +++ b/correlations/correlation_solves.py @@ -0,0 +1,100 @@ +import firedrake as fd +import numpy as np +import scipy.sparse.linalg as spla +from functools import partial +from correlations import ( + chordal_separation, soar_csr, csr_to_petsc, petsc_to_csr) + +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + +parser = ArgumentParser( + description='Construct and solve an SOAR correlation matrix on a periodic interval.', # noqa: E501 + formatter_class=ArgumentDefaultsHelpFormatter +) +parser.add_argument('--n', type=int, default=512, help='Number of mesh nodes.') # noqa: E501 +parser.add_argument('--L', type=float, default=0.6, help='Correlation length scale.') # noqa: E501 +parser.add_argument('--sigma', type=float, default=0.4, help='SOAR scaling.') # noqa: E501 +parser.add_argument('--minval', type=float, default=0.1, help='Minimum correlation value. Any below value this will be clipped to zero.') # noqa: E501 +parser.add_argument('--maxnz', type=int, default=16, help='Maximum number of nonzeros per row. 
Any values beyond this bandwidth will be clipped to zero.') # noqa: E501 +parser.add_argument('--psi', type=float, default=0.1, help='Shift to add to the major diagonal if the matrix is not positive-definite.') # noqa: E501 +parser.add_argument('--neigs', type=int, default=8, help='Number of eigenvalues to calculate to estimate positive-definiteness.') # noqa: E501 +parser.add_argument('--show_args', action='store_true', help='Print all the arguments when the script starts.') # noqa: E501 + +args, _ = parser.parse_known_args() + +if args.show_args: + print(args) + +np.random.seed(42) + +np.set_printoptions(legacy='1.25', precision=3, + threshold=2000, linewidth=200) + +# number of nodes +n = args.n +maxsep = 0.5*args.maxnz/n + +# eigenvalue estimates + +mesh = fd.PeriodicUnitIntervalMesh(n) +V = fd.FunctionSpace(mesh, 'DG', 0) +coords = fd.Function(V).interpolate(fd.SpatialCoordinate(mesh)[0]) +x = coords.dat.data + +chord_sep = partial(chordal_separation, + start=0, end=1) + +print('>>> Building SOAR CSR matrix') +Bcsr = soar_csr(x, args.L*maxsep, sigma=args.sigma, + tol=args.minval, maxsep=maxsep, + separation=chord_sep, + triangular=False) +size = n*n +nnz = Bcsr.nnz +nnzrow = nnz/n +fill = nnz/size +print(f'{size = } | {nnz = } | {nnzrow = } | {fill = }') + +print('>>> Checking positive-definiteness') +neigs = min(args.neigs, n-2) +emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) +emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) +cond = emax/emin +print('Eigenvalues:') +print(f'{emax = } | {emin = } | {cond = }') + +Bmat = csr_to_petsc(Bcsr) + +if emin < 0: + print('>>> Matrix is not SPD: shifting...') + Bmat.shift(abs(emin) + args.psi) + Bcsr = petsc_to_csr(Bmat) + emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) + emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) + cond = emax/emin + print('Eigenvalues after shift:') + 
print(f'{emax = } | {emin = } | {cond = }') + +params = { + 'ksp_converged_rate': None, + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'icc', +} + +print('>>> Setting up solver') +arguments = (fd.TestFunction(V), fd.TrialFunction(V)) +B = fd.AssembledMatrix(arguments, bcs=[], petscmat=Bmat) +solver = fd.LinearSolver(B, options_prefix='', + solver_parameters=params) + +x = fd.Function(V) +b = fd.Cofunction(V.dual()) +b.dat.data[:] = np.random.random_sample(b.dat.data.shape) + +print('>>> Solving...') +solver.solve(x, b) diff --git a/correlations/correlations.py b/correlations/correlations.py new file mode 100644 index 0000000..e6de260 --- /dev/null +++ b/correlations/correlations.py @@ -0,0 +1,92 @@ +from firedrake.petsc import PETSc +from scipy import sparse +import numpy as np + + +def minmag(a, b): + return np.where(np.abs(a) < np.abs(b), a, b) + + +def periodic_separation(y, x, start=None, end=None): + start = x[0] if start is None else start + end = x[-1] if end is None else end + internal = x - y + left = (y - start) + (end - x) + right = (end - y) + x + return minmag(minmag(left, internal), right) + + +def chordal_distance(separation, circumference): + diameter = circumference/np.pi + angle = 2*np.pi*separation/circumference + return diameter*np.sin(angle/2) + + +def chordal_separation(y, x, start=None, end=None): + start = x[0] if start is None else start + end = x[-1] if end is None else end + return chordal_distance( + periodic_separation(y, x, start, end), + end - start) + + +def separation_csr(x, maxsep=None, separation=None, + triangular=True): + n = len(x) + upper = sparse.csr_array((n, n)) + for i, y in enumerate(x): + if separation: + d = separation(y, x[i:]) + else: + d = x[i:] - y + if maxsep: + d[d > maxsep] = 0 + upper[i, i:] = d + upper.eliminate_zeros() + upper.setdiag(0) + if triangular: + return upper + else: + full = symmetrise_csr(upper) + full.setdiag(0) + return full + + +def symmetrise_csr(csr): + transpose = csr.T.tocsr() + 
transpose.setdiag(0) + return csr + transpose + + +def soar(r, L, sigma=1.): + rL = np.abs(r)/L + return sigma*(1 + rL)*np.exp(-rL) + + +def soar_csr(x, L, sigma=1., tol=None, + maxsep=None, separation=None, + triangular=True): + upper = separation_csr(x, maxsep=maxsep, + separation=separation, + triangular=True) + upper.data[:] = soar(upper.data, L, sigma=sigma) + if tol: + upper.data[upper.data < tol] = 0 + upper.eliminate_zeros() + return upper if triangular else symmetrise_csr(upper) + + +def csr_to_petsc(scipy_mat): + mat = PETSc.Mat().create() + mat.setType('aij') + mat.setSizes(scipy_mat.shape) + mat.setValuesCSR(scipy_mat.indptr, + scipy_mat.indices, + scipy_mat.data) + mat.assemble() + return mat + + +def petsc_to_csr(petsc_mat): + return sparse.csr_matrix(petsc_mat.getValuesCSR()[::-1], + shape=petsc_mat.getSize()) diff --git a/correlations/soar_mat.py b/correlations/soar_mat.py new file mode 100644 index 0000000..760f60a --- /dev/null +++ b/correlations/soar_mat.py @@ -0,0 +1,30 @@ +import numpy as np +from functools import partial +from correlations import ( + chordal_separation, soar_csr, csr_to_petsc, petsc_to_csr) + +np.set_printoptions(legacy='1.25', precision=2, + linewidth=200, threshold=10000) + +n = 16 + +L = 0.1 +tol = 0.4 + +maxsep = 0.3 + +width = 1 +start = 0 +end = start + width + +x = np.linspace(start, end, n, endpoint=False) + +chordsep = partial(chordal_separation, + start=start, end=end) + +petsc_mat = csr_to_petsc( + soar_csr(x, L, tol=tol, maxsep=maxsep, + separation=chordsep, + triangular=False)) + +print(f'mat =\n{petsc_to_csr(petsc_mat).todense()}') From be9d0aff697667b70f274206e58189899f89200c Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 21 Jan 2025 13:06:44 +0000 Subject: [PATCH 02/16] solver external operator --- correlations/solver_external_operator.py | 49 ++++++++ correlations/taylor_test.py | 154 +++++++++++++++++++++++ 2 files changed, 203 insertions(+) create mode 100644 
correlations/solver_external_operator.py create mode 100644 correlations/taylor_test.py diff --git a/correlations/solver_external_operator.py b/correlations/solver_external_operator.py new file mode 100644 index 0000000..b20b5e6 --- /dev/null +++ b/correlations/solver_external_operator.py @@ -0,0 +1,49 @@ +import firedrake as fd +from firedrake.external_operators import AbstractExternalOperator, assemble_method + + +class LinearSolverOperator(AbstractExternalOperator): + def __init__(self, *operands, function_space, + derivatives=None, argument_slots=(), + operator_data=None): + + AbstractExternalOperator.__init__( + self, *operands, + function_space=function_space, + derivatives=derivatives, + argument_slots=argument_slots, + operator_data=operator_data) + + self.b, = operands + self._b = self.b.copy() + self.bexpr = fd.inner(self.b, fd.TestFunction(self._b.function_space()))*fd.dx + + self.A = operator_data["A"] + self.solver = fd.LinearSolver( + self.A, **operator_data.get("solver_kwargs", {})) + + def _solve(self, b): + x = fd.Function(self.function_space()) + self.solver.solve(x, b) + return x + + def _solve_transpose(self, b): + x = fd.Cofunction(self.function_space.dual()) + self.solver.solve_transpose(x, b) + return x + + def _assemble_b(self, b): + self._b.assign(b) + return fd.assemble(self.bexpr) + + @assemble_method(0, (0,)) + def assemble_operator(self, *args, **kwargs): + return self._solve(self._assemble_b(self.b)) + + @assemble_method(1, (0, None)) + def assemble_jacobian_action(self, *args, **kwargs): + return self._solve(self._assemble_b(self.argument_slots()[-1])) + + @assemble_method(1, (1, None)) + def assemble_jacobian_adjoint_action(self, *args, **kwargs): + return self._assemble_b(self._solve_transpose(self.argument_slots()[0])) diff --git a/correlations/taylor_test.py b/correlations/taylor_test.py new file mode 100644 index 0000000..8e88fb7 --- /dev/null +++ b/correlations/taylor_test.py @@ -0,0 +1,154 @@ +import firedrake as fd +from 
firedrake.adjoint import ( + continue_annotation, pause_annotation, Control, ReducedFunctional) +import numpy as np +import scipy.sparse.linalg as spla +from functools import partial +from correlations import ( + chordal_separation, soar_csr, csr_to_petsc, petsc_to_csr) + +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + + +def weighted_norm(M, x, **kwargs): + y = fd.Function(x.function_space().dual()) + fd.solve(M, y, x, **kwargs) + return fd.assemble(fd.inner(y, y)*fd.dx) + + +parser = ArgumentParser( + description='Construct and solve an SOAR correlation matrix on a periodic interval.', # noqa: E501 + formatter_class=ArgumentDefaultsHelpFormatter +) +parser.add_argument('--n', type=int, default=128, help='Number of mesh nodes.') # noqa: E501 +parser.add_argument('--L', type=float, default=0.6, help='Correlation length scale.') # noqa: E501 +parser.add_argument('--sigma', type=float, default=0.4, help='SOAR scaling.') # noqa: E501 +parser.add_argument('--minval', type=float, default=0.1, help='Minimum correlation value. Any below value this will be clipped to zero.') # noqa: E501 +parser.add_argument('--maxnz', type=int, default=16, help='Maximum number of nonzeros per row. 
Any values beyond this bandwidth will be clipped to zero.') # noqa: E501 +parser.add_argument('--psi', type=float, default=0.1, help='Shift to add to the major diagonal if the matrix is not positive-definite.') # noqa: E501 +parser.add_argument('--neigs', type=int, default=8, help='Number of eigenvalues to calculate to estimate positive-definiteness.') # noqa: E501 +parser.add_argument('--show_args', action='store_true', help='Print all the arguments when the script starts.') # noqa: E501 + +args, _ = parser.parse_known_args() + +if args.show_args: + print(args) + +np.random.seed(42) + +np.set_printoptions(legacy='1.25', precision=3, + threshold=2000, linewidth=200) + +# number of nodes +n = args.n +maxsep = 0.5*args.maxnz/n + +# eigenvalue estimates + +mesh = fd.PeriodicUnitIntervalMesh(n) +V = fd.FunctionSpace(mesh, 'DG', 0) +coords = fd.Function(V).interpolate(fd.SpatialCoordinate(mesh)[0]) +x = coords.dat.data + +chord_sep = partial(chordal_separation, + start=0, end=1) + +print('>>> Building SOAR CSR matrix') +Bcsr = soar_csr(x, args.L, sigma=args.sigma, + tol=args.minval, maxsep=maxsep, + separation=chord_sep, + triangular=False) +size = n*n +nnz = Bcsr.nnz +nnzrow = nnz/n +fill = nnz/size +print(f'{size = } | {nnz = } | {nnzrow = } | {fill = }') + +print('>>> Checking positive-definiteness') +neigs = min(args.neigs, n-2) +emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) +emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) +cond = emax/emin +print('Eigenvalues:') +print(f'{emax = } | {emin = } | {cond = }') + +Bmat = csr_to_petsc(Bcsr) + +if emin < 0: + print('>>> Matrix is not SPD: shifting...') + Bmat.shift(abs(emin) + args.psi) + Bcsr = petsc_to_csr(Bmat) + emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) + emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) + cond = emax/emin + print('Eigenvalues after shift:') + print(f'{emax = } 
| {emin = } | {cond = }') + +params = { + 'ksp_converged_rate': None, + 'ksp_monitor': None, + + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'icc', + + # 'ksp_type': 'preonly', + # 'pc_type': 'lu', + # 'pc_factor_mat_solver_type': 'mumps', +} + +print('>>> Setting up solver') +arguments = (fd.TestFunction(V), fd.TrialFunction(V)) + +B = fd.AssembledMatrix(arguments, bcs=[], petscmat=Bmat) + +x0 = fd.Function(V) +x1 = fd.Function(V) + +b0 = fd.Cofunction(V.dual()) +b0.dat.data[:] = np.random.random_sample(b0.dat.data.shape) + +b1 = fd.Cofunction(V.dual()) +b1.dat.data[:] = np.random.random_sample(b1.dat.data.shape) + +print('>>> Solving LinearSolver...') +solver = fd.LinearSolver(B, options_prefix='', + solver_parameters=params) +solver.solve(x0, b0) +xBx = fd.assemble(fd.inner(b0.riesz_representation(), x0)*fd.dx) +print(f'{xBx = }') + +print('>>> Solving LinearSolverOperator...') + +from solver_external_operator import LinearSolverOperator +bfunc = fd.Function(V) +solve_operator = LinearSolverOperator( + bfunc, function_space=V, + operator_data={ + "A": B, + "solver_kwargs": { + "options_prefix": '', + "solver_parameters": params + } + }) + +bfunc.assign(b0.riesz_representation()) +xBx = fd.assemble(fd.inner(b0.riesz_representation(), solve_operator)*fd.dx) +print(f'{xBx = }') + +bfunc.assign(b0.riesz_representation()) +continue_annotation() +J = fd.assemble(fd.inner(bfunc, solve_operator)*fd.dx) +Jhat = ReducedFunctional(J, Control(bfunc)) +pause_annotation() + +print(f'{Jhat(b1.riesz_representation()) = }') +print(f'{fd.norm(Jhat.derivative()) = }') + +solver.solve(x1, b1) +xBx = fd.assemble(fd.inner(b1.riesz_representation(), x1)*fd.dx) +print(f'{xBx = }') From c7fedd0f5cb1f7875b27e3743c7fbb455263e5f7 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 4 Feb 2025 10:44:58 +0000 Subject: [PATCH 03/16] fdvar module and tao solver --- fdvar/__init__.py | 1 + fdvar/tao_solver.py | 186 ++++++++++++++++++++++++++++++++++++++++++++ setup.py | 7 ++ 3 
files changed, 194 insertions(+) create mode 100644 fdvar/__init__.py create mode 100644 fdvar/tao_solver.py create mode 100644 setup.py diff --git a/fdvar/__init__.py b/fdvar/__init__.py new file mode 100644 index 0000000..a5acf8a --- /dev/null +++ b/fdvar/__init__.py @@ -0,0 +1 @@ +from .tao_solver import * # noqa: F401, F403 diff --git a/fdvar/tao_solver.py b/fdvar/tao_solver.py new file mode 100644 index 0000000..9ff68ba --- /dev/null +++ b/fdvar/tao_solver.py @@ -0,0 +1,186 @@ +import firedrake as fd +from firedrake import PETSc +from pyadjoint.optimization.tao_solver import ( + OptionsManager, TAOConvergenceError, _tao_reasons) +from functools import cached_property + +__all__ = ("TAOObjective", "TAOConvergenceError", "TAOSolver") + + +class TAOObjective: + def __init__(self, Jhat, dual_options=None): + self.Jhat = Jhat + self.dual_options = dual_options + self._control = Jhat.control.copy() + self._m = Jhat.control.copy() + self._mdot = Jhat.control.copy() + + self.n = self._m._vec.getLocalSize() + self.N = self._m._vec.getSize() + self.sizes = (self.n, self.N) + + def objective(self, tao, x): + with self._control.vec_wo() as cvec: + x.copy(cvec) + return self.Jhat(self._control) + + def gradient(self, tao, x, g): + dJ = self.Jhat.derivative(options=self.dual_options) + with dJ.vec_ro() as dvec: + dvec.copy(g) + # self.objective_gradient(tao, x, g) + + def objective_gradient(self, tao, x, g): + with self._control.vec_wo() as cvec: + x.copy(cvec) + J = self.Jhat(self._control) + dJ = self.Jhat.derivative(options=self.dual_options) + with dJ.vec_ro() as dvec: + dvec.copy(g) + # self.gradient(tao, x, g) + return J + + def hessian(self, A, x, y): + with self._mdot.vec_ro() as mvec: + x.copy(mvec) + ddJ = self.Jhat.hessian() + with ddJ.vec_ro() as dvec: + dvec.copy(y) + if self._shift != 0.0: + y.axpy(self._shift, x) + + @cached_property + def hessian_mat(self): + ctx = HessianCtx(self.Jhat, dual_options=self.dual_options) + mat = PETSc.Mat().createPython( + 
(self.sizes, self.sizes), ctx, + comm=self.Jhat.ensemble.global_comm) + mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) + mat.setUp() + mat.assemble() + return mat + + @cached_property + def gradient_norm_mat(self): + ctx = GradientNormCtx(self.Jhat) + mat = PETSc.Mat().createPython( + (self.sizes, self.sizes), ctx, + comm=self.Jhat.ensemble.global_comm) + mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) + mat.setUp() + mat.assemble() + return mat + + +class HessianCtx: + @classmethod + def update(cls, tao, x, H, P): + ctx = H.getPythonContext() + with ctx._m.vec_wo() as mvec: + x.copy(mvec) + ctx._shift = 0.0 + + def __init__(self, Jhat, dual_options=None): + self.Jhat = Jhat + self._m = Jhat.control.copy() + self._mdot = Jhat.control.copy() + self._shift = 0.0 + self.dual_options = dual_options + + def shift(self, A, alpha): + self._shift += alpha + + def mult(self, A, x, y): + with self._mdot.vec_wo() as mdvec: + x.copy(mdvec) + + # TODO: Why do we need to reevaluate and derivate? + # _ = self.Jhat(self._m) + # _ = self.Jhat.derivative(options=self.dual_options) + ddJ = self.Jhat.hessian([self._mdot]) + + with ddJ.vec_ro() as dvec: + dvec.copy(y) + + if self._shift != 0.0: + y.axpy(self._shift, x) + + +class GradientNormCtx: + def __init__(self, Jhat): + self._xfunc = Jhat.control.copy() + self._ycofunc = self._xfunc.riesz_representation() + + # TODO: Just implement EnsembleFunction._ad_convert_type + v = fd.TestFunction(self._xfunc._function_space) + self.M = fd.inner(v, self._xfunc._fbuf)*fd.dx + + def mult(self, mat, x, y): + with self._xfunc.vec_wo() as xvec: + x.copy(xvec) + + fd.assemble(self.M, tensor=self._ycofunc._fbuf) + + with self._ycofunc.vec_ro() as yvec: + yvec.copy(y) + + +class TAOSolver: + def __init__(self, Jhat, *, options_prefix=None, + solver_parameters=None): + self.Jhat = Jhat + self.ensemble = Jhat.ensemble + + dual_options = {'riesz_representation': None} + + self.tao_objective = TAOObjective(Jhat, dual_options) + + self.tao = 
PETSc.TAO().create( + comm=Jhat.ensemble.global_comm) + + # solution vector + self._x = Jhat.control._vec.duplicate() + self.tao.setSolution(self._x) + + # evaluate objective and gradient + self.tao.setObjective( + self.tao_objective.objective) + + self.tao.setGradient( + self.tao_objective.gradient) + + self.tao.setObjectiveGradient( + self.tao_objective.objective_gradient) + + # evaluate hessian action + hessian_mat = self.tao_objective.hessian_mat + self.tao.setHessian( + hessian_mat.getPythonContext().update, + hessian_mat) + + # gradient norm in correct space + self.tao.setGradientNorm( + self.tao_objective.gradient_norm_mat) + + # solver parameters and finish setup + self.options = OptionsManager( + solver_parameters, options_prefix) + self.options.set_from_options(self.tao) + self.tao.setUp() + + def solve(self): + control = self.Jhat.control + with control.tape_value().vec_ro() as cvec: + cvec.copy(self._x) + + with self.options.inserted_options(): + self.tao.solve() + + if self.tao.getConvergedReason() <= 0: + # Using the same format as Firedrake linear solver errors + raise TAOConvergenceError( + f"TAOSolver failed to converge after {self.tao.getIterationNumber()} iterations " + f"with reason: {_tao_reasons[self.tao.getConvergedReason()]}") + + with control.vec_wo() as cvec: + self._x.copy(cvec) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..293d5f0 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import find_packages, setup + +setup( + name="fdvar", + version="0.1", + packages=find_packages() +) From 4a765688daf26fda20c7d7cacdff9522798a7df7 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 4 Feb 2025 10:46:02 +0000 Subject: [PATCH 04/16] tao solve advection example --- advection/advection_utils.py | 14 +-- advection/advection_wc4dvar_aaorf.py | 117 ++++++++++++----------- advection/advection_wc4dvar_aaorf_tao.py | 56 +++++++++++ 3 files changed, 123 insertions(+), 64 deletions(-) create mode 100644 
advection/advection_wc4dvar_aaorf_tao.py diff --git a/advection/advection_utils.py b/advection/advection_utils.py index f6de905..3206f83 100644 --- a/advection/advection_utils.py +++ b/advection/advection_utils.py @@ -3,13 +3,9 @@ import numpy as np from sys import exit -np.set_printoptions(legacy='1.25', precision=6) - +from firedrake.adjoint.fourdvar_reduced_functional import covariance_norm -def norm2(w): - def n2(x): - return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx) - return n2 +np.set_printoptions(legacy='1.25', precision=6) def timestepper(mesh, V, dt, u): @@ -45,9 +41,9 @@ def analytic_solution(mesh, u, t, mag=1.0, phase=0.0): x, = fd.SpatialCoordinate(mesh) return mag*fd.sin(2*fd.pi*((x + phase) - u*t)) -B = 1 -R = 1 -Q = 1 +B = 10. +R = 0.1 +Q = 0.2*B ensemble = fd.Ensemble(fd.COMM_WORLD, 1) diff --git a/advection/advection_wc4dvar_aaorf.py b/advection/advection_wc4dvar_aaorf.py index 61406d4..b04b695 100644 --- a/advection/advection_wc4dvar_aaorf.py +++ b/advection/advection_wc4dvar_aaorf.py @@ -5,58 +5,65 @@ from firedrake.adjoint import FourDVarReducedFunctional from advection_utils import * -control = fd.EnsembleFunction( - ensemble, [V for _ in range(len(targets))]) - -for x in control.subfunctions: - x.assign(background) - -continue_annotation() - -# create 4DVar reduced functional and record -# background and initial observation functionals - -Jhat = FourDVarReducedFunctional( - Control(control), - background_iprod=norm2(B), - observation_iprod=norm2(R), - observation_err=observation_error(0), - weak_constraint=True) - -nstep = 0 -# record observation stages -with Jhat.recording_stages() as stages: - # loop over stages - for stage, ctx in stages: - # start forward model - qn.assign(stage.control) - - # propogate - for _ in range(observation_freq): - qn1.assign(qn) - stepper.solve() - qn.assign(qn1) - nstep += 1 - - # take observation - obs_err = observation_error(stage.observation_index) - stage.set_observation(qn, obs_err, - 
observation_iprod=norm2(R), - forward_model_iprod=norm2(Q)) - -pause_annotation() - -# the perturbation values need to be held in the -# same type as the control i.e. and EnsembleFunction -vals = control.copy() -for v0, v1 in zip(vals.subfunctions, values): - v0.assign(v1) - -print(f"{Jhat(control) = }") -print(f"{taylor_test(Jhat, control, vals) = }") - -options = {'disp': True, 'ftol': 1e-2} -derivative_options = {'riesz_representation': 'l2'} - -opt = minimize(Jhat, options=options, method="L-BFGS-B", - derivative_options=derivative_options) + +def make_fdvrf(): + control = fd.EnsembleFunction( + ensemble, [V for _ in range(len(targets))]) + + for x in control.subfunctions: + x.assign(background) + + continue_annotation() + + # create 4DVar reduced functional and record + # background and initial observation functionals + + Jhat = FourDVarReducedFunctional( + Control(control), + background_covariance=B, + observation_covariance=R, + observation_error=observation_error(0), + weak_constraint=True) + + nstep = 0 + # record observation stages + with Jhat.recording_stages() as stages: + # loop over stages + for stage, ctx in stages: + # start forward model + qn.assign(stage.control) + + # propogate + for _ in range(observation_freq): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + nstep += 1 + + # take observation + obs_err = observation_error(stage.observation_index) + stage.set_observation(qn, obs_err, + observation_covariance=R, + forward_model_covariance=Q) + pause_annotation() + + return Jhat, control + + +if __name__ == '__main__': + Jhat, control = make_fdvrf() + + # the perturbation values need to be held in the + # same type as the control i.e. 
and EnsembleFunction + vals = control.copy() + for v0, v1 in zip(vals.subfunctions, values): + v0.assign(v1) + + print(f"{Jhat(control) = }") + print(f"{taylor_test(Jhat, control, vals) = }") + + options = {'disp': True, 'ftol': 1e-2} + derivative_options = {'riesz_representation': 'l2'} + + opt = minimize(Jhat, options=options, method="L-BFGS-B", + derivative_options=derivative_options) diff --git a/advection/advection_wc4dvar_aaorf_tao.py b/advection/advection_wc4dvar_aaorf_tao.py new file mode 100644 index 0000000..35e0641 --- /dev/null +++ b/advection/advection_wc4dvar_aaorf_tao.py @@ -0,0 +1,56 @@ +import firedrake as fd +from firedrake.adjoint import ( + continue_annotation, pause_annotation, minimize, + stop_annotating, Control, taylor_test) +from firedrake.adjoint import FourDVarReducedFunctional +from advection_utils import * +from fdvar import TAOSolver +from sys import exit + +from advection_wc4dvar_aaorf import make_fdvrf +Jhat, control = make_fdvrf() + +# the perturbation values need to be held in the +# same type as the control i.e. 
and EnsembleFunction +vals = control.copy() +for v0, v1 in zip(vals.subfunctions, values): + v0.assign(v1) + +# print(f"{Jhat(control) = }") +# print(f"{taylor_test(Jhat, control, vals) = }") + +# options = {'disp': True, 'ftol': 1e-2} +# derivative_options = {'riesz_representation': None} +# opt = minimize(Jhat, options=options, method="L-BFGS-B", +# derivative_options=derivative_options) +# exit() + +ksp_params = { + 'monitor': None, + 'converged_rate': None, + 'rtol': 1e-1, +} + +tao_params = { + 'tao_view': ':tao_view.log', + 'tao': { + 'monitor': None, + 'converged_reason': None, + 'gatol': 1e-1, + 'grtol': 1e-1, + 'gttol': 1e-1, + }, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp': ksp_params, + 'ksp_type': 'cg', + 'pc_type': 'lmvm', + }, + 'tao_cg': { + 'ksp': ksp_params, + 'type': 'pr', # fr-pr-prp-hs-dy + }, +} +tao = TAOSolver(Jhat, options_prefix="", + solver_parameters=tao_params) +tao.solve() From c30578ba42254c9de92e70812c982e570d6ecd0d Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 4 Feb 2025 10:46:38 +0000 Subject: [PATCH 05/16] WIP: mats/ksps for saddle point formulation --- advection/advection_wc4dvar_saddlepc.py | 281 ++++++++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100644 advection/advection_wc4dvar_saddlepc.py diff --git a/advection/advection_wc4dvar_saddlepc.py b/advection/advection_wc4dvar_saddlepc.py new file mode 100644 index 0000000..ab5ec7f --- /dev/null +++ b/advection/advection_wc4dvar_saddlepc.py @@ -0,0 +1,281 @@ +import firedrake as fd +from firedrake.petsc import PETSc, OptionsManager +from firedrake.adjoint import pyadjoint # noqa: F401 +from pyadjoint.optimization.tao_solver import PETScVecInterface +from advection_wc4dvar_aaorf import make_fdvrf +from mpi4py import MPI +import numpy as np +from typing import Optional + + +# CovarianceNormRF Mat +def CovarianceMat(covariancerf): + covariance = covariancerf.covariance + space = covariancerf.controls[0].control.function_space() + comm = 
space.mesh().comm + sizes = space.dof_dset.layout_vec.sizes + shape = (sizes, sizes) + covmat = PETSc.Mat().createConstantDiagonal( + shape, covariance, comm=comm) + covmat.setUp() + covmat.assemble() + return covmat + + +# pyadjoint RF Mat +class ReducedFunctionalMatCtx: + """ + PythonMat context to apply action of a pyadjoint.ReducedFunctional. + + Parameters + ---------- + + action_type + Union['hessian', 'tlm', 'adjoint'] + """ + def __init__(self, Jhat: pyadjoint.ReducedFunctional, + action_type: str = 'hessian', + derivative_options: Optional[dict] = None, + comm: MPI.Comm = PETSc.COMM_WORLD): + self.Jhat = Jhat + self.control_interface = PETScVecInterface(Jhat.controls, comm=comm) + self.functional_interface = PETScVecInterface( + Jhat.functional, comm=comm) + + if action_type == 'hessian': + self.xinterface = self.control_interface + self.yinterface = self.control_interface + elif action_type == 'adjoint': + self.xinterface = self.functional_interface + self.yinterface = self.control_interface + elif action_type == 'tlm': + self.xinterface = self.control_interface + self.yinterface = self.functional_interface + else: + raise ValueError( + 'Unrecognised {action_type = }.') + + self.action_type = action_type + self._m = Jhat.control.copy() + self.derivative_options = derivative_options + self._shift = 0 + + self.default_mult = { + 'hessian': self._mult_hessian, + 'tlm': self._mult_tlm, + 'adjoint': self._mult_adjoint + }[action_type] + + # Storage for result of action. + # Possibly in the dual space for adjoint actions. 
+ if action_type == 'adjoint': + self.Jhat(self._m) + self._mdot = self.Jhat( + derivative_options=derivative_options) + else: + self._mdot = Jhat.control.copy() + + @classmethod + def update(cls, obj, x, A, P): + ctx = A.getPythonContext() + ctx.control_interface.from_petsc(x, ctx._m) + ctx._shift = 0 + + def update_tape_values(self, update_adjoint=True): + _ = self.Jhat(self._m) + if update_adjoint: + _ = self.Jhat.derivative(options=self.derivative_options) + + def mult(self, A, x, y): + self.xinterface.from_petsc(x, self._mdot) + out = self.default_mult(A, self._mdot) + self.yinterface.to_petsc(out, y) + if self._shift != 0: + y.axpy(self._shift, x) + + def _mult_hessian(self, A, x): + if self.action_type != 'hessian': + raise NotImplementedError( + f'Cannot apply hessian action if {self.action_type = }') + self.update_tape_values() + return self.Jhat.hessian(x) + + def _mult_tlm(self, A, x): + if self.action_type != 'tlm': + raise NotImplementedError( + f'Cannot apply tlm action if {self.action_type = }') + self.update_tape_values(update_adjoint=False) + return self.Jhat.tlm(x) + + def _mult_adjoint(self, A, x): + if self.action_type != 'adjoint': + raise NotImplementedError( + f'Cannot apply adjoint action if {self.action_type = }') + self.update_tape_values(update_adjoint=False) + return self.Jhat.derivative( + adj_value=x, derivative_options=self.derivative_options) + + +def ReducedFunctionalMat(self, Jhat, action_type='hessian', + derivative_options=None, + comm=PETSc.COMM_WORLD): + ctx = ReducedFunctionalMatCtx( + Jhat, action_type, + derivative_options, + comm=comm) + # TODO: use functional_interface sizes + # to allow non-square matrices. 
+ n = ctx.control_interface.n + N = ctx.control_interface.N + mat = PETSc.Mat().createPython( + ((n, N), (n, N)), ctx, comm=comm) + mat.setUp() + mat.assemble() + return mat + + +class EnsembleBlockDiagonalMat: + def __init__(self, ensemble, spaces, blocks): + if isinstance(spaces, fd.EnsembleFunction): + if spaces.ensemble is not ensemble: + raise ValueError( + "Ensemble of EnsembleFunction must match ensemble provided") + spaces = spaces.local_function_spaces + if len(blocks) != len(spaces): + raise ValueError( + f"EnsembleBlockDiagonalMat requires one submatrix for each of the" + f" {len(spaces)} local subfunctions of theEnsembleFunction, but" + f" only {len(blocks)} provided.") + + for i, (subspace, block) in enumerate(zip(spaces, blocks)): + vsizes = subspace.dof_dset.layout_vec.sizes + msizes = block.sizes + if msizes[0] != msizes[1]: + raise ValueError( + f"Block {i} of EnsembleBlockDiagonalMat must be square, not {msizes}") + if msizes[0] != vsizes: + raise ValueError( + f"Block {i} of EnsembleBlockDiagonalMat must have shape {(vsizes, vsizes)}" + f" to match the EnsembleFunction, not shape {msizes}") + + self.ensemble = ensemble + self.blocks = blocks + self.spaces = spaces + + # EnsembleFunction knows how to split out subvecs for each block + self.x = fd.EnsembleFunction(self.ensemble, spaces) + self.y = fd.EnsembleCofunction(self.ensemble, [V.dual() for V in spaces]) + + def mult(self, A, x, y): + with self.x.vec_wo() as xvec: + x.copy(xvec) + + # compute answer + subvecs = zip(self.x.subfunctions, self.y.subfunctions) + for block, (xsub, ysub) in zip(self.blocks, subvecs): + with xsub.dat.vec_ro as xvec, ysub.dat.vec_wo as yvec: + block.mult(xvec, yvec) + + with self.y.vec_ro() as yvec: + yvec.copy(y) + + +class EnsembleBlockDiagonalPC: + prefix = "ensemblejacobi_" + + def __init__(self): + self.initialized = False + + def setUp(self, pc): + if not self.initialized: + self.initialize(pc) + self.update(pc) + + def initialize(self, pc): + if pc.getType() 
!= "python": + raise ValueError("Expecting PC type python") + + pcprefix = pc.getOptionsPrefix() + prefix = pcprefix + self.prefix + options = PETSc.Options(prefix) + + _, P = pc.getOperators() + ensemble_mat = P.getPythonContext() + ensemble = ensemble_mat.ensemble + spaces = ensemble_mat.spaces + submats = ensemble_mat.blocks + + self.ensemble = ensemble + self.spaces = spaces + self.submats = submats + + self.x = fd.EnsembleFunction(ensemble, spaces) + self.y = fd.EnsembleCofunction(ensemble, + [V.dual() for V in spaces]) + + subksps = [] + for i, mat in enumerate(submats): + ksp = PETSc.KSP().create(comm=ensemble.comm) + ksp.setOperators(mat) + + sub_prefix = pcprefix + f"sub_{i}_" + # TODO: default options + options = OptionsManager({}, sub_prefix) + options.set_from_options(ksp) + self.subksps.append((ksp, options)) + + self.subksps = tuple(subksps) + + def apply(self, pc, x, y): + with self.x.vec_wo() as xvec: + x.copy(xvec) + + subfuncs = zip(self.x.subfunctions, self.y.subfunctions) + for (subksp, suboptions), (subx, suby) in zip(self.subksps, subfuncs): + with subx.dat.vec_ro as rhs, suby.dat.vec_wo as sol: + with suboptions.inserted_options(): + subksp.solve(rhs, sol) + + with self.y.vec_ro() as yvec: + yvec.copy(y) + + +class EnsembleMat: + def __init__(self, ensemblefunction, ctx): + self.ensemble = ensemblefunction.ensemble + sizes = ensemblefunction._vec.sizes + self.petsc_mat = PETSc.Mat().createPython( + (sizes, sizes), ctx, + comm=self.ensemble.comm) + self.petsc_mat.setUp() + self.petsc_mat.assemble() + + +Jhat, control = make_fdvrf() +ensemble = Jhat.ensemble + +# >>>>> Covariance + +# Covariance Mat +# covrf = Jhat.background_norm +covrf = Jhat.stages[0].model_norm +covmat = CovarianceMat(covrf) + +# Covariance KSP +covksp = PETSc.KSP().create(comm=ensemble.comm) +covksp.setOptionsPrefix('cov_') +covksp.setOperators(covmat) + +covksp.pc.setType(PETSc.PC.Type.JACOBI) +covksp.setType(PETSc.KSP.Type.PREONLY) +covksp.setFromOptions() 
+covksp.setUp() +print(PETSc.Options().getAll()) + +x = covmat.createVecRight() +b = covmat.createVecLeft() + +b.array_w[:] = np.random.random_sample(b.array_w.shape) +print(f'{b.norm() = }') +covksp.solve(b, x) +print(f'{x.norm() = }') From 9ca4c887bd0f5273cf8398a253f08e9cfd9e4a50 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Sat, 15 Feb 2025 11:43:37 +0000 Subject: [PATCH 06/16] efs update --- advection/advection_sc4dvar_aaorf.py | 10 +++---- advection/advection_sc4dvar_pyadj.py | 8 +++--- advection/advection_wc4dvar_aaorf.py | 5 ++-- advection/advection_wc4dvar_pyadj.py | 8 +++--- burgers/aaorf_4dvar.py | 42 ++++++++-------------------- burgers/burgers_wc4dvar_demo.py | 36 ++++++++---------------- fdvar/tao_solver.py | 7 +++-- 7 files changed, 43 insertions(+), 73 deletions(-) diff --git a/advection/advection_sc4dvar_aaorf.py b/advection/advection_sc4dvar_aaorf.py index 1a008cc..9f869f0 100644 --- a/advection/advection_sc4dvar_aaorf.py +++ b/advection/advection_sc4dvar_aaorf.py @@ -15,9 +15,9 @@ Jhat = FourDVarReducedFunctional( Control(control), - background_iprod=norm2(B), - observation_iprod=norm2(R), - observation_err=observation_error(0), + background_covariance=B, + observation_covariance=R, + observation_error=observation_error(0), weak_constraint=False) nstep = 0 @@ -38,13 +38,13 @@ # take observation obs_err = observation_error(stage.observation_index) stage.set_observation(qn, obs_err, - observation_iprod=norm2(R)) + observation_covariance=R) pause_annotation() print(f"{taylor_test(Jhat, control, values[0]) = }") -options = {'disp': True, 'ftol': 1e-2} +options = {'disp': fd.COMM_WORLD.rank == 0, 'ftol': 1e-2} derivative_options = {'riesz_representation': None} opt = minimize(Jhat, options=options, method="L-BFGS-B", diff --git a/advection/advection_sc4dvar_pyadj.py b/advection/advection_sc4dvar_pyadj.py index 252ec6a..beccde9 100644 --- a/advection/advection_sc4dvar_pyadj.py +++ b/advection/advection_sc4dvar_pyadj.py @@ -11,10 +11,10 @@ 
continue_annotation() # background functional -J = norm2(B)(control - background) +J = covariance_norm(control - background, B) # initial observation functional -J += norm2(R)(observation_error(0)(control)) +J += covariance_norm(observation_error(0)(control), R) nstep = 0 qn.assign(control) @@ -28,7 +28,7 @@ nstep += 1 # observation functional - J += norm2(R)(observation_error(i)(qn)) + J += covariance_norm(observation_error(i)(qn), R) pause_annotation() @@ -36,7 +36,7 @@ print(f"{taylor_test(Jhat, control, values[0]) = }") -options = {'disp': True, 'ftol': 1e-2} +options = {'disp': fd.COMM_WORLD.rank == 0, 'ftol': 1e-2} derivative_options = {'riesz_representation': 'l2'} opt = minimize(Jhat, options=options, method="L-BFGS-B", diff --git a/advection/advection_wc4dvar_aaorf.py b/advection/advection_wc4dvar_aaorf.py index b04b695..e3bd307 100644 --- a/advection/advection_wc4dvar_aaorf.py +++ b/advection/advection_wc4dvar_aaorf.py @@ -7,8 +7,9 @@ def make_fdvrf(): - control = fd.EnsembleFunction( - ensemble, [V for _ in range(len(targets))]) + control_space = fd.EnsembleFunctionSpace( + [V for _ in range(len(targets))], ensemble) + control = fd.EnsembleFunction(control_space) for x in control.subfunctions: x.assign(background) diff --git a/advection/advection_wc4dvar_pyadj.py b/advection/advection_wc4dvar_pyadj.py index c4d9bf4..a4825e8 100644 --- a/advection/advection_wc4dvar_pyadj.py +++ b/advection/advection_wc4dvar_pyadj.py @@ -12,10 +12,10 @@ continue_annotation() # background functional -J = norm2(B)(control[0] - background) +J = covariance_norm(control[0] - background, B) # initial observation functional -J += norm2(R)(observation_error(0)(control[0])) +J += covariance_norm(observation_error(0)(control[0]), R) nstep = 0 for i in range(1, len(control)): @@ -33,10 +33,10 @@ control[i].assign(qn) # model error functional - J += norm2(Q)(qn - control[i]) + J += covariance_norm(qn - control[i], Q) # observation functional - J += 
norm2(R)(observation_error(i)(control[i])) + J += covariance_norm(observation_error(i)(control[i]), R) pause_annotation() diff --git a/burgers/aaorf_4dvar.py b/burgers/aaorf_4dvar.py index 408c692..007840a 100644 --- a/burgers/aaorf_4dvar.py +++ b/burgers/aaorf_4dvar.py @@ -185,36 +185,19 @@ def H(x, name=None): global_comm.Barrier() -# weighted l2 inner product -def wl2prod(x, w=1.0, ad_block_tag=None): - return fd.assemble(fd.inner(x, w*x)*fd.dx, ad_block_tag=ad_block_tag)**2 - - def observation_err(i, state, name=None): return fd.Function(Vobs, name=f'Observation error H{i}(x{i}) - y{i}').assign(H(state, name) - y[i], ad_block_tag=f"Observation error calculation {i}") -background_iprod = partial(wl2prod, w=B, ad_block_tag='Background inner product') -observation_iprod = partial(wl2prod, w=R, ad_block_tag='Observation inner product') -model_iprod = partial(wl2prod, w=Q, ad_block_tag='Model inner product') - # Initialise forward model from prior/background initial conditions # and accumulate weak constraint functional as we go uapprox = [background.copy(deepcopy=True, annotate=False)] # Only initial rank needs data for initial conditions or time -if trank == 0: - background_iprod0 = background_iprod - if initial_observations: - observation_iprod0 = observation_iprod - observation_err0 = partial(observation_err, 0, name='Model observation 0') - else: - observation_iprod0 = None - observation_err0 = None +if trank == 0 and initial_observations: + observation_err0 = partial(observation_err, 0, name='Model observation 0') else: - background_iprod0 = None - observation_iprod0 = None observation_err0 = None ################################################## @@ -226,18 +209,20 @@ def observation_err(i, state, name=None): # first rank has one extra control for the initial conditions nlocal_controls = nlocal_observations + (1 if trank == 0 else 0) -aaofunc = fd.EnsembleFunction(ensemble, [V for _ in range(nlocal_controls)]) +control_space = fd.EnsembleFunctionSpace( + [V 
for _ in range(nlocal_controls)], ensemble) +control = fd.EnsembleFunction(control_space) if trank == 0: - aaofunc.subfunctions[0].assign(background) + control.subfunctions[0].assign(background) continue_annotation() Jhat = FourDVarReducedFunctional( - Control(aaofunc), - background_iprod=background_iprod0, - observation_iprod=observation_iprod0, - observation_err=observation_err0, + Control(control), + background_covariance=B, + observation_covariance=R, + observation_error=observation_err0, weak_constraint=(args.constraint == 'weak')) Jhat.background.topological.rename("Background") @@ -282,13 +267,10 @@ def observation_err(i, state, name=None): observation_err, local_obs_idx, name=f'Model observation {stage.observation_index}') - model_iprod = partial(wl2prod, w=Q, - ad_block_tag=f'Model inner product {stage.observation_index}') - # record the observation at the end of the stage stage.set_observation(un, obs_error, - observation_iprod=observation_iprod, - forward_model_iprod=model_iprod) + observation_covariance=R, + forward_model_covariance=Q) # PETSc.Sys.Print(f"{fd.norm(un) = }") global_comm.Barrier() diff --git a/burgers/burgers_wc4dvar_demo.py b/burgers/burgers_wc4dvar_demo.py index 574653f..af9fdfa 100644 --- a/burgers/burgers_wc4dvar_demo.py +++ b/burgers/burgers_wc4dvar_demo.py @@ -167,19 +167,10 @@ def H(x, name=None): global_comm.Barrier() -# weighted l2 inner product -def wl2prod(x, w=1.0, ad_block_tag=None): - return fd.assemble(fd.inner(x, w*x)*fd.dx, ad_block_tag=ad_block_tag) - - def observation_err(i, state, name=None): return fd.Function(Vobs, name=f'Observation error H{i}(x{i}) - y{i}').assign(H(state, name) - y[i], ad_block_tag=f"Observation error calculation {i}") -background_iprod = partial(wl2prod, w=B, ad_block_tag='Background inner product') -observation_iprod = partial(wl2prod, w=R, ad_block_tag='Observation inner product') -model_iprod = partial(wl2prod, w=Q, ad_block_tag='Model inner product') - # Initialise forward model from 
prior/background initial conditions # and accumulate weak constraint functional as we go @@ -189,12 +180,9 @@ def observation_err(i, state, name=None): ### Create the 4dvar reduced functional ################################################## -background_iprod0 = background_iprod if initial_observations: - observation_iprod0 = observation_iprod observation_err0 = partial(observation_err, 0, name='Model observation 0') else: - observation_iprod0 = None observation_err0 = None ## Make sure this is the only point we are requiring user to know partition specifics @@ -202,18 +190,20 @@ def observation_err(i, state, name=None): # first rank has one extra control for the initial conditions nlocal_controls = nlocal_observations + (1 if trank == 0 else 0) -aaofunc = fd.EnsembleFunction(ensemble, [V for _ in range(nlocal_controls)]) +control_space = fd.EnsembleFunctionSpace( + [V for _ in range(nlocal_controls)], ensemble) +control = fd.EnsembleFunction(control_space) if trank == 0: - aaofunc.subfunctions[0].assign(background) + control.subfunctions[0].assign(background) continue_annotation() Jhat = FourDVarReducedFunctional( - Control(aaofunc), - background_iprod=background_iprod0, - observation_iprod=observation_iprod0, - observation_err=observation_err0, + Control(control), + background_covariance=B, + observation_covariance=R, + observation_error=observation_err0, weak_constraint=(args.constraint == 'weak')) Jhat.background.topological.rename("Background") @@ -222,8 +212,7 @@ def observation_err(i, state, name=None): PETSc.Sys.Print("Running forward model") global_comm.Barrier() -observation_idx = 1 if initial_observations else 0 -obs_offset = observation_idx +obs_offset = 1 if initial_observations else 0 ################################################## ### Record the forward model and observations @@ -256,13 +245,10 @@ def observation_err(i, state, name=None): observation_err, local_obs_idx, name=f'Model observation {stage.observation_index}') - model_iprod = 
partial(wl2prod, w=Q, - ad_block_tag=f'Model inner product {stage.observation_index}') - # record the observation at the end of the stage stage.set_observation(un, obs_error, - observation_iprod=observation_iprod, - forward_model_iprod=model_iprod) + observation_covariance=R, + forward_model_covariance=Q) global_comm.Barrier() diff --git a/fdvar/tao_solver.py b/fdvar/tao_solver.py index 9ff68ba..193a6a0 100644 --- a/fdvar/tao_solver.py +++ b/fdvar/tao_solver.py @@ -112,14 +112,15 @@ def __init__(self, Jhat): self._ycofunc = self._xfunc.riesz_representation() # TODO: Just implement EnsembleFunction._ad_convert_type - v = fd.TestFunction(self._xfunc._function_space) - self.M = fd.inner(v, self._xfunc._fbuf)*fd.dx + efs = Jhat.control_space + v = fd.TestFunction(efs._full_local_space) + self.M = fd.inner(v, self._xfunc._full_local_function)*fd.dx def mult(self, mat, x, y): with self._xfunc.vec_wo() as xvec: x.copy(xvec) - fd.assemble(self.M, tensor=self._ycofunc._fbuf) + fd.assemble(self.M, tensor=self._ycofunc._full_local_function) with self._ycofunc.vec_ro() as yvec: yvec.copy(y) From 01a3b93a2e043f6f5bcfab125ba154cb5d98696f Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 18 Feb 2025 10:52:52 +0000 Subject: [PATCH 07/16] WIP: saddle point pc impl --- advection/advection_wc4dvar_covariancemat.py | 57 ++ advection/advection_wc4dvar_saddlepc.py | 358 +++--------- fdvar/mat.py | 554 +++++++++++++++++++ fdvar/pc.py | 57 ++ 4 files changed, 751 insertions(+), 275 deletions(-) create mode 100644 advection/advection_wc4dvar_covariancemat.py create mode 100644 fdvar/mat.py create mode 100644 fdvar/pc.py diff --git a/advection/advection_wc4dvar_covariancemat.py b/advection/advection_wc4dvar_covariancemat.py new file mode 100644 index 0000000..9c303d4 --- /dev/null +++ b/advection/advection_wc4dvar_covariancemat.py @@ -0,0 +1,57 @@ +import firedrake as fd +from firedrake.petsc import PETSc, OptionsManager +from firedrake.adjoint import pyadjoint # noqa: F401 +from 
firedrake.matrix import ImplicitMatrix +from pyadjoint.optimization.tao_solver import PETScVecInterface +from advection_wc4dvar_aaorf import make_fdvrf +from mpi4py import MPI +import numpy as np +from typing import Optional, Collection +from fdvar.mat import * + + +def CovarianceMat(covariancerf): + space = covariancerf.controls[0].control.function_space() + comm = space.mesh().comm + covariance = covariancerf.covariance + if isinstance(covariance, Collection): + covariance, power = covariance + + sizes = space.dof_dset.layout_vec.sizes + shape = (sizes, sizes) + covmat = PETSc.Mat().createConstantDiagonal( + shape, covariance, comm=comm) + + covmat.setUp() + covmat.assemble() + return covmat + + +Jhat, control = make_fdvrf() +ensemble = Jhat.ensemble + +# >>>>> Covariance + +# Covariance Mat +# covrf = Jhat.background_norm +covrf = Jhat.stages[0].model_norm +covmat = CovarianceMat(covrf) + +# Covariance KSP +covksp = PETSc.KSP().create(comm=ensemble.comm) +covksp.setOptionsPrefix('cov_') +covksp.setOperators(covmat) + +covksp.pc.setType(PETSc.PC.Type.JACOBI) +covksp.setType(PETSc.KSP.Type.PREONLY) +covksp.setFromOptions() +covksp.setUp() +print(PETSc.Options().getAll()) + +x = covmat.createVecRight() +b = covmat.createVecLeft() + +b.array_w[:] = np.random.random_sample(b.array_w.shape) +print(f'{b.norm() = }') +covksp.solve(b, x) +print(f'{x.norm() = }') diff --git a/advection/advection_wc4dvar_saddlepc.py b/advection/advection_wc4dvar_saddlepc.py index ab5ec7f..60dbe66 100644 --- a/advection/advection_wc4dvar_saddlepc.py +++ b/advection/advection_wc4dvar_saddlepc.py @@ -1,281 +1,89 @@ import firedrake as fd -from firedrake.petsc import PETSc, OptionsManager -from firedrake.adjoint import pyadjoint # noqa: F401 -from pyadjoint.optimization.tao_solver import PETScVecInterface from advection_wc4dvar_aaorf import make_fdvrf -from mpi4py import MPI -import numpy as np -from typing import Optional - - -# CovarianceNormRF Mat -def CovarianceMat(covariancerf): - 
covariance = covariancerf.covariance - space = covariancerf.controls[0].control.function_space() - comm = space.mesh().comm - sizes = space.dof_dset.layout_vec.sizes - shape = (sizes, sizes) - covmat = PETSc.Mat().createConstantDiagonal( - shape, covariance, comm=comm) - covmat.setUp() - covmat.assemble() - return covmat - - -# pyadjoint RF Mat -class ReducedFunctionalMatCtx: - """ - PythonMat context to apply action of a pyadjoint.ReducedFunctional. - - Parameters - ---------- - - action_type - Union['hessian', 'tlm', 'adjoint'] - """ - def __init__(self, Jhat: pyadjoint.ReducedFunctional, - action_type: str = 'hessian', - derivative_options: Optional[dict] = None, - comm: MPI.Comm = PETSc.COMM_WORLD): - self.Jhat = Jhat - self.control_interface = PETScVecInterface(Jhat.controls, comm=comm) - self.functional_interface = PETScVecInterface( - Jhat.functional, comm=comm) - - if action_type == 'hessian': - self.xinterface = self.control_interface - self.yinterface = self.control_interface - elif action_type == 'adjoint': - self.xinterface = self.functional_interface - self.yinterface = self.control_interface - elif action_type == 'tlm': - self.xinterface = self.control_interface - self.yinterface = self.functional_interface - else: - raise ValueError( - 'Unrecognised {action_type = }.') - - self.action_type = action_type - self._m = Jhat.control.copy() - self.derivative_options = derivative_options - self._shift = 0 - - self.default_mult = { - 'hessian': self._mult_hessian, - 'tlm': self._mult_tlm, - 'adjoint': self._mult_adjoint - }[action_type] - - # Storage for result of action. - # Possibly in the dual space for adjoint actions. 
- if action_type == 'adjoint': - self.Jhat(self._m) - self._mdot = self.Jhat( - derivative_options=derivative_options) - else: - self._mdot = Jhat.control.copy() - - @classmethod - def update(cls, obj, x, A, P): - ctx = A.getPythonContext() - ctx.control_interface.from_petsc(x, ctx._m) - ctx._shift = 0 - - def update_tape_values(self, update_adjoint=True): - _ = self.Jhat(self._m) - if update_adjoint: - _ = self.Jhat.derivative(options=self.derivative_options) - - def mult(self, A, x, y): - self.xinterface.from_petsc(x, self._mdot) - out = self.default_mult(A, self._mdot) - self.yinterface.to_petsc(out, y) - if self._shift != 0: - y.axpy(self._shift, x) - - def _mult_hessian(self, A, x): - if self.action_type != 'hessian': - raise NotImplementedError( - f'Cannot apply hessian action if {self.action_type = }') - self.update_tape_values() - return self.Jhat.hessian(x) - - def _mult_tlm(self, A, x): - if self.action_type != 'tlm': - raise NotImplementedError( - f'Cannot apply tlm action if {self.action_type = }') - self.update_tape_values(update_adjoint=False) - return self.Jhat.tlm(x) - - def _mult_adjoint(self, A, x): - if self.action_type != 'adjoint': - raise NotImplementedError( - f'Cannot apply adjoint action if {self.action_type = }') - self.update_tape_values(update_adjoint=False) - return self.Jhat.derivative( - adj_value=x, derivative_options=self.derivative_options) - - -def ReducedFunctionalMat(self, Jhat, action_type='hessian', - derivative_options=None, - comm=PETSc.COMM_WORLD): - ctx = ReducedFunctionalMatCtx( - Jhat, action_type, - derivative_options, - comm=comm) - # TODO: use functional_interface sizes - # to allow non-square matrices. 
- n = ctx.control_interface.n - N = ctx.control_interface.N - mat = PETSc.Mat().createPython( - ((n, N), (n, N)), ctx, comm=comm) - mat.setUp() - mat.assemble() - return mat - - -class EnsembleBlockDiagonalMat: - def __init__(self, ensemble, spaces, blocks): - if isinstance(spaces, fd.EnsembleFunction): - if spaces.ensemble is not ensemble: - raise ValueError( - "Ensemble of EnsembleFunction must match ensemble provided") - spaces = spaces.local_function_spaces - if len(blocks) != len(spaces): - raise ValueError( - f"EnsembleBlockDiagonalMat requires one submatrix for each of the" - f" {len(spaces)} local subfunctions of theEnsembleFunction, but" - f" only {len(blocks)} provided.") - - for i, (subspace, block) in enumerate(zip(spaces, blocks)): - vsizes = subspace.dof_dset.layout_vec.sizes - msizes = block.sizes - if msizes[0] != msizes[1]: - raise ValueError( - f"Block {i} of EnsembleBlockDiagonalMat must be square, not {msizes}") - if msizes[0] != vsizes: - raise ValueError( - f"Block {i} of EnsembleBlockDiagonalMat must have shape {(vsizes, vsizes)}" - f" to match the EnsembleFunction, not shape {msizes}") - - self.ensemble = ensemble - self.blocks = blocks - self.spaces = spaces - - # EnsembleFunction knows how to split out subvecs for each block - self.x = fd.EnsembleFunction(self.ensemble, spaces) - self.y = fd.EnsembleCofunction(self.ensemble, [V.dual() for V in spaces]) - - def mult(self, A, x, y): - with self.x.vec_wo() as xvec: - x.copy(xvec) - - # compute answer - subvecs = zip(self.x.subfunctions, self.y.subfunctions) - for block, (xsub, ysub) in zip(self.blocks, subvecs): - with xsub.dat.vec_ro as xvec, ysub.dat.vec_wo as yvec: - block.mult(xvec, yvec) - - with self.y.vec_ro() as yvec: - yvec.copy(y) - - -class EnsembleBlockDiagonalPC: - prefix = "ensemblejacobi_" - - def __init__(self): - self.initialized = False - - def setUp(self, pc): - if not self.initialized: - self.initialize(pc) - self.update(pc) - - def initialize(self, pc): - if pc.getType() 
!= "python": - raise ValueError("Expecting PC type python") - - pcprefix = pc.getOptionsPrefix() - prefix = pcprefix + self.prefix - options = PETSc.Options(prefix) - - _, P = pc.getOperators() - ensemble_mat = P.getPythonContext() - ensemble = ensemble_mat.ensemble - spaces = ensemble_mat.spaces - submats = ensemble_mat.blocks - - self.ensemble = ensemble - self.spaces = spaces - self.submats = submats - - self.x = fd.EnsembleFunction(ensemble, spaces) - self.y = fd.EnsembleCofunction(ensemble, - [V.dual() for V in spaces]) - - subksps = [] - for i, mat in enumerate(submats): - ksp = PETSc.KSP().create(comm=ensemble.comm) - ksp.setOperators(mat) - - sub_prefix = pcprefix + f"sub_{i}_" - # TODO: default options - options = OptionsManager({}, sub_prefix) - options.set_from_options(ksp) - self.subksps.append((ksp, options)) - - self.subksps = tuple(subksps) - - def apply(self, pc, x, y): - with self.x.vec_wo() as xvec: - x.copy(xvec) - - subfuncs = zip(self.x.subfunctions, self.y.subfunctions) - for (subksp, suboptions), (subx, suby) in zip(self.subksps, subfuncs): - with subx.dat.vec_ro as rhs, suby.dat.vec_wo as sol: - with suboptions.inserted_options(): - subksp.solve(rhs, sol) - - with self.y.vec_ro() as yvec: - yvec.copy(y) - - -class EnsembleMat: - def __init__(self, ensemblefunction, ctx): - self.ensemble = ensemblefunction.ensemble - sizes = ensemblefunction._vec.sizes - self.petsc_mat = PETSc.Mat().createPython( - (sizes, sizes), ctx, - comm=self.ensemble.comm) - self.petsc_mat.setUp() - self.petsc_mat.assemble() - +from fdvar.mat import * Jhat, control = make_fdvrf() ensemble = Jhat.ensemble -# >>>>> Covariance - -# Covariance Mat -# covrf = Jhat.background_norm -covrf = Jhat.stages[0].model_norm -covmat = CovarianceMat(covrf) - -# Covariance KSP -covksp = PETSc.KSP().create(comm=ensemble.comm) -covksp.setOptionsPrefix('cov_') -covksp.setOperators(covmat) - -covksp.pc.setType(PETSc.PC.Type.JACOBI) -covksp.setType(PETSc.KSP.Type.PREONLY) 
-covksp.setFromOptions() -covksp.setUp() -print(PETSc.Options().getAll()) - -x = covmat.createVecRight() -b = covmat.createVecLeft() - -b.array_w[:] = np.random.random_sample(b.array_w.shape) -print(f'{b.norm() = }') -covksp.solve(b, x) -print(f'{x.norm() = }') +saddlepoint_params = { + 'ksp_monitor': None, + 'ksp_converged_rate': None, + 'ksp_rtol': 1e-5, + 'ksp_type': 'fgmres', + 'pc_type': 'fieldsplit', + 'pc_fieldsplit_type': 'schur', + 'pc_fieldsplit_0_fields': '0,1', # schur complement on dx + 'pc_fieldsplit_1_fields': '2', + + # diagonal pc + 'pc_fieldsplit_schur_fact_type': 'diag', + + # triangular pc + 'pc_fieldsplit_schur_fact_type': 'upper', + + # inexact constraint pc + 'pc_fieldsplit_0_fields': '2', + 'pc_fieldsplit_1_fields': '1', + 'pc_fieldsplit_2_fields': '0', + 'pc_fieldsplit_type': 'multiplicative', + + 'fieldsplit_0': { # D + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'EnsembleBJacobiPC', + 'sub': { + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'gamg', + 'mg_levels': { + 'ksp_type': 'chebyshev', + 'ksp_max_it': 2, + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu', + } + } + } + 'fieldsplit_1': { # R + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'EnsembleBJacobiPC', + 'sub': { + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'icc', + } + } + 'fieldsplit_2': { # S + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'SaddleSchurPC', + 'pc_saddle_schur_observation_type': 'none', # or low-rank? 
+ 'model_sub': { # L = I + 'ksp_type': 'preonly', + 'pc_type': 'none', + } + 'model_sub': { # Low bandwidth approx - its=N is direct solve + 'ksp_max_it': 1, + 'ksp_type': 'richardson', + 'pc_type': 'none', + } + 'model_sub': { # L_M pc - bjacobi with bsize = k + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'AllAtOnceBJacobiPC', + 'pc_aaobjacobi_bsize': 4, + 'sub': { + 'ksp_max_it': 4, + 'ksp_type': 'richardson', + 'pc_type': 'none', + } + } + } +} + +ksp, options = FDVarSaddlePointKSP(fdvrf, saddlepoint_params) +b = FDVarSaddlePointRHS(fdvrf) # TODO +x = b.duplicate() + +with options.inserted_options(): + ksp.solve(b, x) diff --git a/fdvar/mat.py b/fdvar/mat.py new file mode 100644 index 0000000..eeecc57 --- /dev/null +++ b/fdvar/mat.py @@ -0,0 +1,554 @@ +import firedrake as fd +from firedrake.petsc import PETSc +from firedrake.petsc import OptionsManager +from firedrake.adjoint import ReducedFunctional +from firedrake.adjoint.fourdvar_reduced_functional import CovarianceNormReducedFunctional +from pyop2.mpi import MPI +from pyadjoint.optimization.tao_solver import PETScVecInterface +from pyadjoint.enlisting import Enlist +from typing import Optional +from enum import Enum +from functools import partial + + +class ISNest: + attrs = ( + 'getSizes', + 'getSize', + 'getLocalSize', + 'getIndices', + 'getComm', + ) + def __init__(self, ises): + self._comm = ises[0].getComm() + self._ises = ises + for attr in self.attrs: + setattr(self, attr, + partial(self._getattr, attr)) + + def _getattr(self, attr, i, *args, **kwargs): + return getattr(self.ises[i], attr)(*args, **kwargs) + + @cached_property + def globalSize(self): + return sum(self.getSize(i) for i in range(len(self))) + + @cached_property + def localSize(self): + return sum(self.getLocalSize(i) for i in range(len(self))) + + @property + def sizes(self): + return (self.localSize, self.globalSize) + + @property + def comm(self): + return self._comm + + @property + def ises(self): + return 
self._ises + + def __getitem__(self, i): + return self.ises[i] + + def __len__(self): + return len(self.ises) + + def __iter__(self): + return iter(self.ises) + + def createVec(self, i=None, vec_type=PETSc.Vec.Type.MPI): + vec = PETSc.Vec().create(comm=self.comm) + vec.setType(vec_type) + vec.setSizes(self.sizes if i is None else self.getSizes(i)) + return vec + + def createVecs(self, vec_type=PETSc.Vec.Type.MPI): + return (self.createVec(i, vec_type=vec_type) for i in range(len(self))) + + def createVecNest(self, vecs=None): + if vecs is None: + vecs = self.createVecs() + else: + if all(vecs[i].getSizes() != self.getSizes(i) for i in range(len(self))): + raise ValueError("vec sizes must match is sizes") + return PETSc.Vec().createNest(vecs, self.ises, self.comm) + + +class RFAction(Enum): + TLM = 'tlm' + Adjoint = 'adjoint' + Hessian = 'hessian' +TLM = RFAction.TLM +Adjoint = RFAction.Adjoint +Hessian = RFAction.Hessian + + +def copy_controls(controls): + return controls.delist([c.copy() for c in controls]) + + +def convert_types(overloaded, options): + overloaded = Enlist(overloaded) + return overloaded.delist([o._ad_convert_type(o, options=options) + for o in overloaded]) + + +def saddle_ises(fdvrf): + ensemble = fdvrf.ensemble + rank = ensemble.ensemble_rank + global_comm = ensemble.global_comm + + Vsol = fdvrf.solution_space + Vobs = fdvrf.observation_space + Vs = Vsol.local_spaces[0] + Vo = Vobs.local_spaces[0] + + # ndofs per (dn, dl, dx) block + bsol = Vsol.dof_dset.layout_vec.getLocalSize() + bobs = Vobs.dof_dset.layout_vec.getLocalSize() + nlocal_blocks = Vsol.nlocal_spaces + + bsize_dn = bsol + bsize_dl = bobs + bsize_dx = bsol + bsize = bsize_dn + bsize_dl + bsize_dx + + # number of blocks on previous ranks + nprev_blocks = ensemble.ensemble_comm.exscan(nlocal_blocks) + if rank == 0: # exscan returns None + nprev_blocks = 0 + + # offset to start of global indices of each field in local block j + offset = bsize*nprev_blocks + offset_dn = lambda j: offset 
+ j*bsize + offset_dl = lambda j: offset_dn(j) + bsize_dn + offset_dx = lambda j: offset_dl(j) + bsize_dl + + indices_dn = np.concatenate( + [offset_dn(j) + np.arange(bsize_dn, dtype=np.int32) + for j in range(nlocal_blocks)]) + + indices_dl = np.concatenate( + [offset_dl(j) + np.arange(bsize_dl, dtype=np.int32) + for j in range(nlocal_blocks)]) + + indices_dx = np.concatenate( + [offset_dx(j) + np.arange(bsize_dx, dtype=np.int32) + for j in range(nlocal_blocks)]) + + is_dn = PETSc.IS().createGeneral(indices_dn, comm=global_comm) + is_dl = PETSc.IS().createGeneral(indices_dl, comm=global_comm) + is_dx = PETSc.IS().createGeneral(indices_dx, comm=global_comm) + + return ISNest((is_dn, is_dl, is_dx)) + + +def FDVarSaddlePointKSP(fdvrf, solver_parameters, options_prefix=None): + ensemble = fdvrf.ensemble + + saddlemat = FDVarSaddlePointMat(fdvrf) + + options = OptionsManager(solver_parameters, options_prefix) + + ksp = PETSc.KSP().create(comm=ensemble.global_comm) + options.set_from_options(ksp) + + return ksp, options + + +# Saddle-point MatNest +def FDVarSaddlePointMat(fdvrf): + ensemble = fdvrf.ensemble + + isnest = saddle_ises(fdvrf) + dn_is, dl_is, dx_is = isnest.ises + + # L Mat + L = AllAtOnceRFMat(fdvrf, action=TLM) + Lt = AllAtOnceRFMat(fdvrf, action=Adjoint) + + Lrow = dn_is + Lcol = dx_is + + Ltrow = dx_is + Ltcol = dn_is + + # H Mat + H = ObservationEnsembleRFMat(fdvrf, action=TLM) + Ht = ObservationEnsembleRFMat(fdvrf, action=Adjoint) + + Hrow = dl_is + Hcol = dx_is + + Htrow = dx_is + Htcol = dl_is + + # D Mat + D = ModelCovarianceEnsembleRFMat(fdvrf) + + Drow = dn_is + Dcol = dn_is + + # R Mat + R = ObservationCovarianceEnsembleRFMat(fdvrf) + + Rrow = dl_is + Rcol = dl_is + + fdvmat = PETSc.Mat().createNest( + mats=[D, L, # noqa: E127,E202 + R, H, # noqa: E127,E202 + Lt, Ht ], # noqa: E127,E202 + isrows=[Drow, Lrow, # noqa: E127,E202 + Rrow, Hrow, # noqa: E127,E202 + Ltrow, Htrow ], # noqa: E127,E202 + iscols=[Dcol, Lcol, # noqa: E127,E202 + Rcol, Hcol, 
# noqa: E127,E202 + Ltcol, Htcol ], # noqa: E127,E202 + comm=ensemble.global_comm) + + return fdvmat + + +class ReducedFunctionalMatCtx: + """ + PythonMat context to apply action of a pyadjoint.ReducedFunctional. + + Parameters + ---------- + + action : RFAction + """ + def __init__(self, Jhat: ReducedFunctional, + action: str = Hessian, + options: Optional[dict] = None, + input_options: Optional[dict] = None, + comm: MPI.Comm = PETSc.COMM_WORLD): + self.Jhat = Jhat + self.control_interface = PETScVecInterface(Jhat.controls, comm=comm) + self.functional_interface = PETScVecInterface(Jhat.functional, comm=comm) + + if action == Hessian: # control -> control + self.xinterface = self.control_interface + self.yinterface = self.control_interface + self.x = copy_controls(Jhat.controls) + self.mult_impl = self._mult_hessian + + elif action == Adjoint: # functional -> control + self.xinterface = self.functional_interface + self.yinterface = self.control_interface + self.x = Jhat.functional._ad_copy() + self.mult_impl = self._mult_adjoint + + elif action == TLM: # control -> functional + self.xinterface = self.control_interface + self.yinterface = self.functional_interface + self.x = copy_controls(Jhat.controls) + self.mult_impl = self._mult_tlm + else: + raise ValueError( + 'Unrecognised {action = }.') + + self.action = action + self._m = copy_controls(Jhat.controls) + self.input_options = input_options + self.options = options + self._shift = 0 + + @classmethod + def update(cls, obj, x, A, P): + ctx = A.getPythonContext() + ctx.control_interface.from_petsc(x, ctx._m) + ctx._shift = 0 + + def update_tape_values(self, update_adjoint=True): + _ = self.Jhat(self._m) + if update_adjoint: + _ = self.Jhat.derivative(options=self.options) + + def mult(self, A, x, y): + self.xinterface.from_petsc(x, self.x) + if self.input_options is None: + _x = self.x + else: + _x = convert_types(self.x, self.input_options) + out = self.mult_impl(A, _x) + self.yinterface.to_petsc(y, out) + if 
self._shift != 0: + y.axpy(self._shift, x) + + def _mult_hessian(self, A, x): + if self.action != Hessian: + raise NotImplementedError( + f'Cannot apply hessian action if {self.action = }') + self.update_tape_values(update_adjoint=True) + return self.Jhat.hessian(x, options=self.options) + + def _mult_tlm(self, A, x): + if self.action != TLM: + raise NotImplementedError( + f'Cannot apply tlm action if {self.action = }') + self.update_tape_values(update_adjoint=False) + return self.Jhat.tlm(x, options=self.options) + + def _mult_adjoint(self, A, x): + if self.action != Adjoint: + raise NotImplementedError( + f'Cannot apply adjoint action if {self.action = }') + self.update_tape_values(update_adjoint=False) + return self.Jhat.derivative(adj_input=x, options=self.options) + + +def ReducedFunctionalMat(Jhat, action=Hessian, + options=None, input_options=None, + comm=PETSc.COMM_WORLD): + ctx = ReducedFunctionalMatCtx( + Jhat, action, + options=options, + input_options=input_options, + comm=comm) + + ncol = ctx.xinterface.n + Ncol = ctx.xinterface.N + + nrow = ctx.yinterface.n + Nrow = ctx.yinterface.N + + mat = PETSc.Mat().createPython( + ((nrow, Nrow), (ncol, Ncol)), + ctx, comm=comm) + mat.setUp() + mat.assemble() + return mat + + +def EnsembleMat(ctx, row_space, col_space=None): + if col_space is None: + col_space = row_space + + # number of columns is row length, and vice-versa + ncol = row_space.nlocal_rank_dofs + Ncol = row_space.nglobal_dofs + + nrow = col_space.nlocal_rank_dofs + Nrow = col_space.nglobal_dofs + + mat = PETSc.Mat().createPython( + ((nrow, Nrow), (ncol, Ncol)), ctx, + comm=row_space.ensemble.global_comm) + mat.setUp() + mat.assemble() + return mat + + +class EnsembleMatCtxBase: + def __init__(self, row_space, col_space=None): + if col_space is None: + col_space = row_space + + if not isinstance(row_space, fd.EnsembleFunctionSpace): + raise ValueError( + f"EnsembleMat row_space must be EnsembleFunctionSpace not {type(row_space).__name__}") + if not 
isinstance(col_space, fd.EnsembleFunctionSpace): + raise ValueError( + f"EnsembleMat col_space must be EnsembleFunctionSpace not {type(col_space).__name__}") + + self.row_space = row_space + self.col_space = col_space + + # input/output Vecs will be copied in/out of these + # so that base classes can implement mult only in + # terms of Ensemble objects not Vecs. + self.x = fd.EnsembleFunction(self.row_space) + self.y = fd.EnsembleFunction(self.col_space.dual()) + + def mult(self, A, x, y): + with self.x.vec_wo() as xvec: + x.copy(xvec) + + self.mult_impl(A, self.x, self.y) + + with self.y.vec_ro() as yvec: + yvec.copy(y) + + +class EnsembleBlockDiagonalMatCtx(EnsembleMatCtxBase): + def __init__(self, blocks, row_space, col_space=None): + super().__init__(row_space, col_space) + self.blocks = blocks + + if self.row_space.nlocal_spaces != self.col_space.nlocal_spaces: + raise ValueError( + "EnsembleBlockDiagonalMat row and col spaces must be the same length," + f" not {row_space.nlocal_spaces = } and {col_space.nlocal_spaces = }") + + if len(self.blocks) != self.row_space.nlocal_spaces: + raise ValueError( + f"EnsembleBlockDiagonalMat requires one submatrix for each of the" + f" {self.row_space.nlocal_spaces} local subfunctions of the EnsembleFunctionSpace," + f" but only {len(self.blocks)} provided.") + + for i, (Vrow, Vcol, block) in enumerate(zip(self.row_space.local_spaces, + self.col_space.local_spaces, + self.blocks)): + # number of columns is row length, and vice-versa + vc_sizes = Vrow.dof_dset.layout_vec.sizes + vr_sizes = Vcol.dof_dset.layout_vec.sizes + mr_sizes, mc_sizes = block.sizes + if (vr_sizes[0] != mr_sizes[0]) or (vr_sizes[1] != mr_sizes[1]): + raise ValueError( + f"Row sizes {mr_sizes} of block {i} and {vr_sizes} of row_space {i} of EnsembleBlockDiagonalMat must match.") + if (vc_sizes[0] != mc_sizes[0]) or (vc_sizes[1] != mc_sizes[1]): + raise ValueError( + f"Col sizes of block {i} and col_space {i} of EnsembleBlockDiagonalMat must match.") + + 
def mult_impl(self, A, x, y): + for block, xsub, ysub in zip(self.blocks, + self.x.subfunctions, + self.y.subfunctions): + with xsub.dat.vec_ro as xvec, ysub.dat.vec_wo as yvec: + block.mult(xvec, yvec) + + +def EnsembleBlockDiagonalMat(blocks, row_space, col_space=None): + return EnsembleMat( + EnsembleBlockDiagonalMatCtx(blocks, row_space, col_space), + row_space, col_space) + + +# L Mat +class AllAtOnceRFMatCtx(EnsembleMatCtxBase): + def __init__(self, fdvrf, action, **kwargs): + super().__init__(fdvrf.solution_space) + + if action not in (TLM, Adjoint): + raise ValueError( + f"AllAtOnceRFMat action type must be 'tlm' or 'adjoint', not {action}") + + self.action = action + self.fdvrf = fdvrf + self.ensemble = fdvrf.ensemble + + self.models = [ReducedFunctionalMat(M, action, **kwargs) + for M in fdvrf.model_rfs] + + if action == Adjoint: + self.models = reversed(self.models) + + self.xhalo = fd.Function(self.row_space.local_spaces[0]) + self.mx = self.x.copy() + + # Set up list of x_{i-1} functions to propogate. + # TLM means we propogate forwards, and use halo from previous rank. + # Adjoint means we propogate backwards, and use halo from next rank. + # The initial timestep on the initial rank doesn't have a halo. 
+ if self.action == TLM: + self.xprevs = [*self.x.subfunctions[:-1]] + else: + self.xprevs = [*reversed(self.x.subfunctions[1:])] + + initial_rank = 0 if action == TLM else (self.ensemble.ensemble_size - 1) + if self.ensemble.ensemble_rank != initial_rank: + self.xprevs.insert(self.xhalo, 0) + + + def update_halos(self, x): + ensemble_rank = self.ensemble.ensemble_rank + ensemble_size = self.ensemble.ensemble_size + + # halo swap is a right shift + next_rank = (ensemble_rank + 1) % ensemble_size + prev_rank = (ensemble_rank - 1) % ensemble_size + + src = prev_rank if self.action == TLM else next_rank + dst = next_rank if self.action == TLM else prev_rank + + frecv = self.xhalo + fsend = self.x.subfunctions[-1 if self.action == TLM else 0] + + self.ensemble.sendrecv( + fsend=fsend, dest=dst, sendtag=dst, + frecv=frecv, source=src, recvtag=ensemble_rank) + + def mult_impl(self, A, x, y): + self.update_halos(x) + + # propogate from last step + for M, xi, mxi in zip(self.models, self.xprevs, self.mx.subfunctions): + mxi.assign(M.mult(xi)) + + # diagonal contribution + x -= self.M + + y.assign(x.riesz_representation()) + + +def AllAtOnceRFMat(fdvrf, action, **kwargs): + return EnsembleMat( + AllAtOnceRFMatCtx(fdvrf, action, **kwargs), + fdvrf.solution_space) + + +# H Mat +def ObservationEnsembleRFMat(fdvrf, action, **kwargs): + if action == TLM: + row_space = fdvrf.solution_space + col_space = fdvrf.observation_space + elif action == Adjoint: + row_space = fdvrf.observation_space + col_space = fdvrf.solution_space + else: + raise ValueError( + f"Unrecognised matrix action type {action}") + + blocks = [ReducedFunctionalMat(Jobs, action, **kwargs) + for Jobs in fdvrf.observation_rfs] + return EnsembleBlockDiagonalMat(blocks, row_space, col_space) + + +# CovarianceMat +def CovarianceMat(covariancerf, action='mult'): + """ + action='mult' for action of B, or 'inv' for action of B^{-1} + """ + if not isinstance(covariancerf, CovarianceNormReducedFunctional): + raise 
TypeError( + "CovarianceMat can only be constructed from a CovarianceNormReducedFunctional" + f" not a {type(covariancerf).__name__}") + space = covariancerf.controls[0].control.function_space() + comm = space.mesh().comm + covariance = covariancerf.covariance + + if action == 'mult': + weight = float(covariance) + elif action == 'inv': + weight = float(1/covariance) + else: + raise ValueError(f"Unrecognised action type {action} for CovarianceMat") + + sizes = space.dof_dset.layout_vec.sizes + shape = (sizes, sizes) + covmat = PETSc.Mat().createConstantDiagonal( + shape, covariance, comm=comm) + covmat.setUp() + covmat.assemble() + return covmat + + +# D Mat +def ModelCovarianceEnsembleRFMat(fdvrf, action, **kwargs): + blocks = [CovarianceMat(mnorm, action, **kwargs) + for mnorm in fdvrf.model_norms] + if fdvrf.ensemble.ensemble_rank == 0: + blocks.insert( + 0, CovarianceMat(fdvrf.background_norm, action, **kwargs)) + return EnsembleBlockDiagonalMat(blocks, fdvrf.solution_space) + + +# R Mat +def ObservationCovarianceEnsembleRFMat(fdvrf, action, **kwargs): + blocks = [CovarianceMat(obs_norm, action, **kwargs) + for obs_norm in fdvrf.observation_norms] + return EnsembleBlockDiagonalMat(blocks, fdvrf.observation_space) diff --git a/fdvar/pc.py b/fdvar/pc.py new file mode 100644 index 0000000..d6cd3bb --- /dev/null +++ b/fdvar/pc.py @@ -0,0 +1,57 @@ + + +class EnsembleBlockDiagonalPC: + prefix = "ensemblejacobi_" + + def __init__(self): + self.initialized = False + + def setUp(self, pc): + if not self.initialized: + self.initialize(pc) + self.update(pc) + + def initialize(self, pc): + if pc.getType() != "python": + raise ValueError("Expecting PC type python") + + pcprefix = pc.getOptionsPrefix() + prefix = pcprefix + self.prefix + options = PETSc.Options(prefix) + + _, P = pc.getOperators() + self.mat = P.getPythonContext() + self.function_space = ensemble_mat.function_space + self.ensemble = function_space.ensemble + + submats = ensemble_mat.blocks + + self.x = 
fd.EnsembleFunction(self.function_space) + self.y = fd.EnsembleCofunction(self.function_space.dual()) + + subksps = [] + for i, submat in enumerate(self.mat.blocks): + ksp = PETSc.KSP().create(comm=ensemble.comm) + ksp.setOperators(mat) + + sub_prefix = pcprefix + f"sub_{i}_" + # TODO: default options + options = OptionsManager({}, sub_prefix) + options.set_from_options(ksp) + self.subksps.append((ksp, options)) + + self.subksps = tuple(subksps) + + def apply(self, pc, x, y): + with self.x.vec_wo() as xvec: + x.copy(xvec) + + subvecs = zip(self.x.subfunctions, self.y.subfunctions) + for (subksp, suboptions), (subx, suby) in zip(self.subksps, subvecs): + with subx.dat.vec_ro as rhs, suby.dat.vec_wo as sol: + with suboptions.inserted_options(): + subksp.solve(rhs, sol) + + with self.y.vec_ro() as yvec: + yvec.copy(y) + From 3607860dd69c527365d2a95298d1259180f2b9fc Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 18 Feb 2025 10:57:33 +0000 Subject: [PATCH 08/16] accidentally deleted line --- fdvar/mat.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fdvar/mat.py b/fdvar/mat.py index eeecc57..1045e1b 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -146,13 +146,12 @@ def saddle_ises(fdvrf): def FDVarSaddlePointKSP(fdvrf, solver_parameters, options_prefix=None): - ensemble = fdvrf.ensemble - saddlemat = FDVarSaddlePointMat(fdvrf) - options = OptionsManager(solver_parameters, options_prefix) + ksp = PETSc.KSP().create(comm=fdvrf.ensemble.global_comm) + ksp.setOperators(saddlemat, saddlemat) - ksp = PETSc.KSP().create(comm=ensemble.global_comm) + options = OptionsManager(solver_parameters, options_prefix) options.set_from_options(ksp) return ksp, options @@ -453,7 +452,6 @@ def __init__(self, fdvrf, action, **kwargs): if self.ensemble.ensemble_rank != initial_rank: self.xprevs.insert(self.xhalo, 0) - def update_halos(self, x): ensemble_rank = self.ensemble.ensemble_rank ensemble_size = self.ensemble.ensemble_size From 
707b589ea815c88635a9f0c9dd0c99886e66df98 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 11 Mar 2025 13:22:02 +0000 Subject: [PATCH 09/16] wip --- advection/advection_wc4dvar_aaorf_tao.py | 11 ++- fdvar/mat.py | 16 ++--- fdvar/pc.py | 89 ++++++++++++++++++++--- tao/rfmat_rectangular.py | 91 ++++++++++++++++++++++++ tao/rfmat_square.py | 82 +++++++++++++++++++++ 5 files changed, 268 insertions(+), 21 deletions(-) create mode 100644 tao/rfmat_rectangular.py create mode 100644 tao/rfmat_square.py diff --git a/advection/advection_wc4dvar_aaorf_tao.py b/advection/advection_wc4dvar_aaorf_tao.py index 35e0641..c163ee7 100644 --- a/advection/advection_wc4dvar_aaorf_tao.py +++ b/advection/advection_wc4dvar_aaorf_tao.py @@ -28,7 +28,6 @@ ksp_params = { 'monitor': None, 'converged_rate': None, - 'rtol': 1e-1, } tao_params = { @@ -43,11 +42,17 @@ 'tao_type': 'nls', 'tao_nls': { 'ksp': ksp_params, - 'ksp_type': 'cg', - 'pc_type': 'lmvm', + # 'ksp_type': 'cg', + # 'pc_type': 'lmvm', + 'ksp_rtol': 1e-3, + 'ksp_type': 'gmres', + 'pc_type': 'python', + 'pc_python_type': 'fdvar.AllAtOnceJacobiPC', + }, 'tao_cg': { 'ksp': ksp_params, + 'ksp_rtol': 1e-1, 'type': 'pr', # fr-pr-prp-hs-dy }, } diff --git a/fdvar/mat.py b/fdvar/mat.py index 1045e1b..68e763e 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -71,7 +71,7 @@ def createVecNest(self, vecs=None): if vecs is None: vecs = self.createVecs() else: - if all(vecs[i].getSizes() != self.getSizes(i) for i in range(len(self))): + if not all(vecs[i].getSizes() == self.getSizes(i) for i in range(len(self))): raise ValueError("vec sizes must match is sizes") return PETSc.Vec().createNest(vecs, self.ises, self.comm) @@ -161,8 +161,7 @@ def FDVarSaddlePointKSP(fdvrf, solver_parameters, options_prefix=None): def FDVarSaddlePointMat(fdvrf): ensemble = fdvrf.ensemble - isnest = saddle_ises(fdvrf) - dn_is, dl_is, dx_is = isnest.ises + dn_is, dl_is, dx_is = saddle_ises(fdvrf) # L Mat L = AllAtOnceRFMat(fdvrf, action=TLM) @@ -411,10 
+410,10 @@ def mult_impl(self, A, x, y): block.mult(xvec, yvec) -def EnsembleBlockDiagonalMat(blocks, row_space, col_space=None): +def EnsembleBlockDiagonalMat(row_space, blocks, **kwargs): return EnsembleMat( - EnsembleBlockDiagonalMatCtx(blocks, row_space, col_space), - row_space, col_space) + EnsembleBlockDiagonalMatCtx(blocks, row_space, **kwargs), + row_space, kwargs.get('col_space', None)) # L Mat @@ -473,12 +472,13 @@ def update_halos(self, x): def mult_impl(self, A, x, y): self.update_halos(x) - # propogate from last step + # propogate from last step mx <- M*x for M, xi, mxi in zip(self.models, self.xprevs, self.mx.subfunctions): mxi.assign(M.mult(xi)) # diagonal contribution - x -= self.M + # x_{i} <- x_{i} - M*x_{i-1} + x -= self.mx y.assign(x.riesz_representation()) diff --git a/fdvar/pc.py b/fdvar/pc.py index d6cd3bb..571ddf2 100644 --- a/fdvar/pc.py +++ b/fdvar/pc.py @@ -1,7 +1,10 @@ +import firedrake as fd +from fdvar.mat import EnsembleMatCtxBase -class EnsembleBlockDiagonalPC: - prefix = "ensemblejacobi_" +class PCBase: + needs_python_amat = False + needs_python_pmat = False def __init__(self): self.initialized = False @@ -9,18 +12,63 @@ def __init__(self): def setUp(self, pc): if not self.initialized: self.initialize(pc) + self.initialized = True self.update(pc) def initialize(self, pc): if pc.getType() != "python": raise ValueError("Expecting PC type python") - pcprefix = pc.getOptionsPrefix() - prefix = pcprefix + self.prefix - options = PETSc.Options(prefix) + self.A, self.P = pc.getOperators() + pcname = f"{type(self).__module__}.{type(self).__name__}" + if self.needs_python_amat: + if self.A.getType() != "python": + raise ValueError( + f"PC {pcname} needs a python type amat, not {self.A.getType()}") + self.amat = self.A.getPythonContext() + if self.needs_python_pmat: + if self.P.getType() != "python": + raise ValueError( + f"PC {pcname} needs a python type pmat, not {self.P.getType()}") + self.pmat = self.P.getPythonContext() + + 
self.parent_prefix = pc.getOptionsPrefix() + self.full_prefix = self.parent_prefix + self.prefix + + def update(self, pc): + pass + + +class EnsemblePCBase(PCBase): + requires_python_amat = True + requires_python_pmat = True + + def __init__(self): + self.initialized = False + + def initialize(self, pc): + super().initialize(pc) + + if not isinstance(self.pmat, EnsembleMatCtxBase): + pcname = f"{type(self).__module__}.{type(self).__name__}" + matname = f"{type(self.pmat).__module__}.{type(self).pmat.__name__}" + raise TypeError( + f"PC {pname} needs an EnsembleMatCtxBase pmat, but it is a {matname}") + + self.row_space = self.pmat.row_space + self.col_space = self.pmat.col_space + + self.x = fd.EnsembleFunction(self.row_space.dual()) + self.y = fd.EnsembleFunction(self.col_space) + + +class EnsembleBlockDiagonalPC(PCBase): + prefix = "ensemblejacobi_" + + def initialize(self, pc): + super().initialize(pc) - _, P = pc.getOperators() - self.mat = P.getPythonContext() + ensemble_mat = self.pmat self.function_space = ensemble_mat.function_space self.ensemble = function_space.ensemble @@ -30,11 +78,11 @@ def initialize(self, pc): self.y = fd.EnsembleCofunction(self.function_space.dual()) subksps = [] - for i, submat in enumerate(self.mat.blocks): + for i, submat in enumerate(self.pmat.blocks): ksp = PETSc.KSP().create(comm=ensemble.comm) - ksp.setOperators(mat) + ksp.setOperators(submat) - sub_prefix = pcprefix + f"sub_{i}_" + sub_prefix = self.parent_prefix + f"sub_{i}_" # TODO: default options options = OptionsManager({}, sub_prefix) options.set_from_options(ksp) @@ -55,3 +103,24 @@ def apply(self, pc, x, y): with self.y.vec_ro() as yvec: yvec.copy(y) + +class AllAtOnceJacobiPC(PCBase): + prefix = "aaojacobi_" + + def initialize(self, pc): + self.fdvrf = self.pmat.fdvrf + if not isinstance(self.Jhat, FourDVarReducedFunctional): + raise TypeError( + "AllAtOnceJacobiPC expects a FourDVarReducedFunctional not a {type(self.Jhat).__name__}") + + def apply(self, pc, x, y): 
+ with self.x.vec_wo() as xvec: + x.copy(xvec) + + # P = LL^{T} + + # apply L^{T} + ltx = self.Jhat.derivative + + with self.y.vec_ro() as yvec: + yvec.copy(y) diff --git a/tao/rfmat_rectangular.py b/tao/rfmat_rectangular.py new file mode 100644 index 0000000..ae5907c --- /dev/null +++ b/tao/rfmat_rectangular.py @@ -0,0 +1,91 @@ +import firedrake as fd +from firedrake.adjoint import ( + ReducedFunctional, Control, set_working_tape, + continue_annotation, pause_annotation) +from fdvar.mat import ReducedFunctionalMat + +mesh = fd.UnitIntervalMesh(3) +x, = fd.SpatialCoordinate(mesh) +expr = 0.5*x*x - 7 + +Vin = fd.FunctionSpace(mesh, "CG", 1) +Vout = fd.FunctionSpace(mesh, "DG", 1) + +u = fd.TrialFunction(Vin) +v = fd.TestFunction(Vout) + +A = (u*v + fd.inner(fd.grad(u), fd.grad(v)) + v*u.dx(0))*fd.dx + +def tlm(y): + return fd.assemble(fd.action(A, y)).riesz_representation() + +def adj(y): + return fd.assemble(fd.action(fd.adjoint(A), y)).riesz_representation() + +continue_annotation() +with set_working_tape() as tape: + x = fd.Function(Vin) + Jhat = ReducedFunctional(tlm(x), Control(x), tape=tape) +pause_annotation() + +def riesz(how): + return {'riesz_representation': how} + +mat_tlm = ReducedFunctionalMat( + Jhat, action_type='tlm', + options={'riesz_representation': None}) + +mat_adj = ReducedFunctionalMat( + Jhat, action_type='adjoint', + input_options={'riesz_representation': 'l2', 'function_space': Vout.dual()}) + +ctx_adj = mat_adj.getPythonContext() + +print("tlm action") + +# manual calculation +uin = fd.Function(Vin).project(expr) +uout = tlm(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(Vin).project(expr) +vout = Jhat.tlm(vin, options={'riesz_representation': None}) +print(f"{type(vout) = }") +print(f"{vout.dat.data = }") + +# Mat action +win = fd.Function(Vin).project(expr) +wout = fd.Function(Vout) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_tlm.mult(x, y) +print(f"{type(wout) = }") 
+print(f"{wout.dat.data = }") + +print() + +print("adjoint action") + +# manual calculation +uin = fd.Function(Vout).project(expr) +uout = adj(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(Vout).project(expr) +vout = Jhat.derivative(adj_input=vin.riesz_representation()) +print(f"{type(vout) = }") +print(f"{vout.dat.data = }") + +# Mat action +win = fd.Function(Vout).project(expr).riesz_representation() +wout = fd.Function(Vin) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_adj.mult(x, y) + +print(f"{type(wout) = }") +print(f"{wout.dat.data = }") + +print() diff --git a/tao/rfmat_square.py b/tao/rfmat_square.py new file mode 100644 index 0000000..d5af064 --- /dev/null +++ b/tao/rfmat_square.py @@ -0,0 +1,82 @@ +import firedrake as fd +from firedrake.adjoint import ( + ReducedFunctional, Control, set_working_tape, + continue_annotation, pause_annotation) +from fdvar.mat import ReducedFunctionalMat + +mesh = fd.UnitIntervalMesh(3) +x, = fd.SpatialCoordinate(mesh) +expr = 0.5*x*x - 7 + +V = fd.FunctionSpace(mesh, "CG", 1) + +u = fd.TrialFunction(V) +v = fd.TestFunction(V) + +A = (u*v + fd.inner(fd.grad(u), fd.grad(v)) + v*u.dx(0))*fd.dx + +def tlm(y): + return fd.assemble(fd.action(A, y)) + +def adj(y): + return fd.assemble(fd.action(fd.adjoint(A), y)).riesz_representation() + +continue_annotation() +with set_working_tape() as tape: + x = fd.Function(V) + Jhat = ReducedFunctional(tlm(x), Control(x), tape=tape) +pause_annotation() + +mat_tlm = ReducedFunctionalMat( + Jhat, action_type='tlm') + +mat_adj = ReducedFunctionalMat( + Jhat, action_type='adjoint', + input_options={'riesz_representation': 'l2', 'function_space': V}) + +print("tlm action") + +# manual calculation +uin = fd.Function(V).project(expr) +uout = tlm(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(V).project(expr) +vout = Jhat.tlm(vin) +print(f"{type(vout) = }") +print(f"{vout.dat.data = 
}") + +# Mat action +win = fd.Function(V).project(expr) +wout = fd.Function(V.dual()) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_tlm.mult(x, y) +print(f"{type(wout) = }") +print(f"{wout.dat.data = }") + +print() + +print("adjoint action") + +# manual calculation +uin = fd.Function(V).project(expr) +uout = adj(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(V).project(expr) +vout = Jhat.derivative(adj_input=vin) +print(f"{type(vout) = }") +print(f"{vout.dat.data = }") + +# Mat action +win = fd.Function(V).project(expr) +wout = fd.Function(V) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_adj.mult(x, y) + +print(f"{type(wout) = }") +print(f"{wout.dat.data = }") From 99fdc8f41abb9f9390ed40fe4ccbe2135d206c78 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 11 Mar 2025 15:29:17 +0000 Subject: [PATCH 10/16] wip --- advection/advection_wc4dvar_aaorf.py | 22 ++++++++--- advection/advection_wc4dvar_aaorf_tao.py | 49 ++++++++++++------------ fdvar/mat.py | 16 +++++--- fdvar/tao_solver.py | 23 ++++++----- 4 files changed, 66 insertions(+), 44 deletions(-) diff --git a/advection/advection_wc4dvar_aaorf.py b/advection/advection_wc4dvar_aaorf.py index e3bd307..db31442 100644 --- a/advection/advection_wc4dvar_aaorf.py +++ b/advection/advection_wc4dvar_aaorf.py @@ -7,8 +7,14 @@ def make_fdvrf(): + if ensemble.ensemble_size == 1: + nlocal_observations = len(targets) + elif ensemble.ensemble_size == 3: + nlocal_observations = 2 if ensemble.ensemble_rank == 0 else 1 + else: + raise ValueError("Must use either 1 or 3 ensemble ranks") control_space = fd.EnsembleFunctionSpace( - [V for _ in range(len(targets))], ensemble) + [V for _ in range(nlocal_observations)], ensemble) control = fd.EnsembleFunction(control_space) for x in control.subfunctions: @@ -28,7 +34,7 @@ def make_fdvrf(): nstep = 0 # record observation stages - with Jhat.recording_stages() as stages: + with Jhat.recording_stages(nstep=nstep) as 
stages: # loop over stages for stage, ctx in stages: # start forward model @@ -39,7 +45,7 @@ def make_fdvrf(): qn1.assign(qn) stepper.solve() qn.assign(qn1) - nstep += 1 + ctx.nstep += 1 # take observation obs_err = observation_error(stage.observation_index) @@ -63,8 +69,14 @@ def make_fdvrf(): print(f"{Jhat(control) = }") print(f"{taylor_test(Jhat, control, vals) = }") - options = {'disp': True, 'ftol': 1e-2} + options = { + 'disp': True, + 'ftol': 1e-2 + } derivative_options = {'riesz_representation': 'l2'} - opt = minimize(Jhat, options=options, method="L-BFGS-B", + opt = minimize(Jhat, method="L-BFGS-B", options=options, derivative_options=derivative_options) + J0 = Jhat(control) + Jopt = Jhat(opt) + print(f"{J0 = :.3e} | {Jopt = :.3e} | {Jopt/J0 = :.3e}") diff --git a/advection/advection_wc4dvar_aaorf_tao.py b/advection/advection_wc4dvar_aaorf_tao.py index c163ee7..e9f0ecf 100644 --- a/advection/advection_wc4dvar_aaorf_tao.py +++ b/advection/advection_wc4dvar_aaorf_tao.py @@ -3,6 +3,7 @@ continue_annotation, pause_annotation, minimize, stop_annotating, Control, taylor_test) from firedrake.adjoint import FourDVarReducedFunctional +from firedrake.petsc import PETSc from advection_utils import * from fdvar import TAOSolver from sys import exit @@ -10,24 +11,16 @@ from advection_wc4dvar_aaorf import make_fdvrf Jhat, control = make_fdvrf() +Print = PETSc.Sys.Print + # the perturbation values need to be held in the # same type as the control i.e. 
and EnsembleFunction -vals = control.copy() -for v0, v1 in zip(vals.subfunctions, values): - v0.assign(v1) - -# print(f"{Jhat(control) = }") -# print(f"{taylor_test(Jhat, control, vals) = }") - -# options = {'disp': True, 'ftol': 1e-2} -# derivative_options = {'riesz_representation': None} -# opt = minimize(Jhat, options=options, method="L-BFGS-B", -# derivative_options=derivative_options) -# exit() +x0 = control.copy() ksp_params = { - 'monitor': None, - 'converged_rate': None, + 'monitor_short': None, + # 'converged_rate': None, + # 'converged_reason': None, } tao_params = { @@ -42,20 +35,28 @@ 'tao_type': 'nls', 'tao_nls': { 'ksp': ksp_params, - # 'ksp_type': 'cg', - # 'pc_type': 'lmvm', - 'ksp_rtol': 1e-3, 'ksp_type': 'gmres', - 'pc_type': 'python', - 'pc_python_type': 'fdvar.AllAtOnceJacobiPC', - - }, - 'tao_cg': { - 'ksp': ksp_params, + 'pc_type': 'lmvm', 'ksp_rtol': 1e-1, - 'type': 'pr', # fr-pr-prp-hs-dy + # 'ksp_type': 'gmres', + # 'pc_type': 'python', + # 'pc_python_type': 'fdvar.AllAtOnceJacobiPC', + }, + # 'tao_cg': { + # 'ksp': ksp_params, + # 'ksp_rtol': 1e-1, + # 'type': 'pr', # fr-pr-prp-hs-dy + # }, } tao = TAOSolver(Jhat, options_prefix="", solver_parameters=tao_params) tao.solve() + +xopt = Jhat.control.copy() +J0 = Jhat(x0) +Jopt = Jhat(xopt) + +Print(f"{J0 = :.3e} | {Jopt = :.3e} | {Jopt/J0 = :.3e}") + + diff --git a/fdvar/mat.py b/fdvar/mat.py index 68e763e..9f8e483 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -8,7 +8,7 @@ from pyadjoint.enlisting import Enlist from typing import Optional from enum import Enum -from functools import partial +from functools import partial, cached_property class ISNest: @@ -225,8 +225,11 @@ def __init__(self, Jhat: ReducedFunctional, input_options: Optional[dict] = None, comm: MPI.Comm = PETSc.COMM_WORLD): self.Jhat = Jhat - self.control_interface = PETScVecInterface(Jhat.controls, comm=comm) - self.functional_interface = PETScVecInterface(Jhat.functional, comm=comm) + self.control_interface = 
PETScVecInterface( + Jhat.controls, comm=comm) + if action in (Adjoint, TLM): + self.functional_interface = PETScVecInterface( + Jhat.functional, comm=comm) if action == Hessian: # control -> control self.xinterface = self.control_interface @@ -261,6 +264,9 @@ def update(cls, obj, x, A, P): ctx.control_interface.from_petsc(x, ctx._m) ctx._shift = 0 + def shift(self, A, alpha): + self._shift += alpha + def update_tape_values(self, update_adjoint=True): _ = self.Jhat(self._m) if update_adjoint: @@ -282,14 +288,14 @@ def _mult_hessian(self, A, x): raise NotImplementedError( f'Cannot apply hessian action if {self.action = }') self.update_tape_values(update_adjoint=True) - return self.Jhat.hessian(x, options=self.options) + return self.Jhat.hessian(x) def _mult_tlm(self, A, x): if self.action != TLM: raise NotImplementedError( f'Cannot apply tlm action if {self.action = }') self.update_tape_values(update_adjoint=False) - return self.Jhat.tlm(x, options=self.options) + return self.Jhat.tlm(x) def _mult_adjoint(self, A, x): if self.action != Adjoint: diff --git a/fdvar/tao_solver.py b/fdvar/tao_solver.py index 193a6a0..bde46ff 100644 --- a/fdvar/tao_solver.py +++ b/fdvar/tao_solver.py @@ -3,6 +3,7 @@ from pyadjoint.optimization.tao_solver import ( OptionsManager, TAOConvergenceError, _tao_reasons) from functools import cached_property +from fdvar.mat import ReducedFunctionalMat __all__ = ("TAOObjective", "TAOConvergenceError", "TAOSolver") @@ -51,14 +52,16 @@ def hessian(self, A, x, y): @cached_property def hessian_mat(self): - ctx = HessianCtx(self.Jhat, dual_options=self.dual_options) - mat = PETSc.Mat().createPython( - (self.sizes, self.sizes), ctx, - comm=self.Jhat.ensemble.global_comm) - mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) - mat.setUp() - mat.assemble() - return mat + return ReducedFunctionalMat(self.Jhat, options=self.dual_options) + # ctx = HessianCtx(self.Jhat, dual_options=self.dual_options) + # mat = PETSc.Mat().createPython( + # (self.sizes, 
self.sizes), ctx, + # comm=self.Jhat.ensemble.global_comm) + # mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) + # mat.setUp() + # mat.assemble() + # return mat + pass # trick folding @cached_property def gradient_norm_mat(self): @@ -95,8 +98,8 @@ def mult(self, A, x, y): x.copy(mdvec) # TODO: Why do we need to reevaluate and derivate? - # _ = self.Jhat(self._m) - # _ = self.Jhat.derivative(options=self.dual_options) + _ = self.Jhat(self._m) + _ = self.Jhat.derivative(options=self.dual_options) ddJ = self.Jhat.hessian([self._mdot]) with ddJ.vec_ro() as dvec: From 651c4addda1fffb9638bae359f78aa36841739f4 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Fri, 21 Mar 2025 13:52:55 +0000 Subject: [PATCH 11/16] wip --- fdvar/mat.py | 34 ++++++++++++++-------------------- fdvar/tao_solver.py | 2 +- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/fdvar/mat.py b/fdvar/mat.py index 9f8e483..91f42bc 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -214,15 +214,20 @@ class ReducedFunctionalMatCtx: """ PythonMat context to apply action of a pyadjoint.ReducedFunctional. 
+ Jhat : V -> U + TLM : V -> U + Adjoint : U* -> V* + Hessian : V x U* -> V* | V -> V* + Parameters ---------- action : RFAction """ + dual_options = {'riesz_representation': None} + def __init__(self, Jhat: ReducedFunctional, - action: str = Hessian, - options: Optional[dict] = None, - input_options: Optional[dict] = None, + action: str = Hessian, *, comm: MPI.Comm = PETSc.COMM_WORLD): self.Jhat = Jhat self.control_interface = PETScVecInterface( @@ -254,8 +259,6 @@ def __init__(self, Jhat: ReducedFunctional, self.action = action self._m = copy_controls(Jhat.controls) - self.input_options = input_options - self.options = options self._shift = 0 @classmethod @@ -270,15 +273,11 @@ def shift(self, A, alpha): def update_tape_values(self, update_adjoint=True): _ = self.Jhat(self._m) if update_adjoint: - _ = self.Jhat.derivative(options=self.options) + _ = self.Jhat.derivative(options=self.dual_options) def mult(self, A, x, y): self.xinterface.from_petsc(x, self.x) - if self.input_options is None: - _x = self.x - else: - _x = convert_types(self.x, self.input_options) - out = self.mult_impl(A, _x) + out = self.mult_impl(A, self.x) self.yinterface.to_petsc(y, out) if self._shift != 0: y.axpy(self._shift, x) @@ -288,7 +287,7 @@ def _mult_hessian(self, A, x): raise NotImplementedError( f'Cannot apply hessian action if {self.action = }') self.update_tape_values(update_adjoint=True) - return self.Jhat.hessian(x) + return self.Jhat.hessian(x, options=self.dual_options) def _mult_tlm(self, A, x): if self.action != TLM: @@ -302,17 +301,12 @@ def _mult_adjoint(self, A, x): raise NotImplementedError( f'Cannot apply adjoint action if {self.action = }') self.update_tape_values(update_adjoint=False) - return self.Jhat.derivative(adj_input=x, options=self.options) + return self.Jhat.derivative(adj_input=x, options=self.dual_options) -def ReducedFunctionalMat(Jhat, action=Hessian, - options=None, input_options=None, - comm=PETSc.COMM_WORLD): +def ReducedFunctionalMat(Jhat, 
action=Hessian, *, comm=PETSc.COMM_WORLD, **kwargs): ctx = ReducedFunctionalMatCtx( - Jhat, action, - options=options, - input_options=input_options, - comm=comm) + Jhat, action, comm=comm, **kwargs) ncol = ctx.xinterface.n Ncol = ctx.xinterface.N diff --git a/fdvar/tao_solver.py b/fdvar/tao_solver.py index bde46ff..5fd8ff6 100644 --- a/fdvar/tao_solver.py +++ b/fdvar/tao_solver.py @@ -52,7 +52,7 @@ def hessian(self, A, x, y): @cached_property def hessian_mat(self): - return ReducedFunctionalMat(self.Jhat, options=self.dual_options) + return ReducedFunctionalMat(self.Jhat) # ctx = HessianCtx(self.Jhat, dual_options=self.dual_options) # mat = PETSc.Mat().createPython( # (self.sizes, self.sizes), ctx, From 6a877e19b9cb2484f65da827f5ef03bbad97f3c7 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 24 Apr 2025 08:37:40 +0100 Subject: [PATCH 12/16] arf upd --- fdvar/__init__.py | 1 + fdvar/generate_data.py | 47 ++++++++++++++++++++++++++++++++++++++++++ fdvar/mat.py | 20 +++++++++++++----- fdvar/tao_solver.py | 23 ++++++++++----------- 4 files changed, 74 insertions(+), 17 deletions(-) create mode 100644 fdvar/generate_data.py diff --git a/fdvar/__init__.py b/fdvar/__init__.py index a5acf8a..c308d89 100644 --- a/fdvar/__init__.py +++ b/fdvar/__init__.py @@ -1 +1,2 @@ from .tao_solver import * # noqa: F401, F403 +from .generate_data import * # noqa: F401, F403 diff --git a/fdvar/generate_data.py b/fdvar/generate_data.py new file mode 100644 index 0000000..11698f4 --- /dev/null +++ b/fdvar/generate_data.py @@ -0,0 +1,47 @@ +import firedrake as fd +import numpy as np + + +def noisify(u, sigma, seed=None): + if seed is not None: + np.random.seed(seed) + for dat in u.dat: + dat.data[:] += np.random.normal(0, sigma, dat.data.shape) + return u + + +def generate_observation_data(ensemble, ic, stepper, un, un1, H, + nw, nt, sigma_b, sigma_r, sigma_q, seed=6): + if seed is not None: + np.random.seed(seed) + + rank = ensemble.ensemble_rank + + nlocal_stages = 
nw//ensemble.ensemble_size + if rank == 0: + nlocal_stages -= 1 + + # background is reference plus noise + background = ic.copy(deepcopy=True) + noisify(background, sigma_b) + + y = [] + # initial observation + if rank == 0: + y.append(noisify(H(ic), sigma_r)) + + widx = 0 + with ensemble.sequential(widx=widx, un=un) as ctx: + # if rank != 0: + # y.extend([None for _ in range(ctx.widx)]) + un.assign(ctx.un) + un1.assign(un) + for k in range(nlocal_stages): + for i in range(nt): + stepper.solve() + un.assign(un1) + noisify(un, sigma_q) + y.append(noisify(H(un), sigma_r)) + ctx.widx += 1 + + return y, background diff --git a/fdvar/mat.py b/fdvar/mat.py index 91f42bc..da81322 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -86,7 +86,7 @@ class RFAction(Enum): def copy_controls(controls): - return controls.delist([c.copy() for c in controls]) + return controls.delist([c.control._ad_init_zero() for c in controls]) def convert_types(overloaded, options): @@ -224,14 +224,15 @@ class ReducedFunctionalMatCtx: action : RFAction """ - dual_options = {'riesz_representation': None} def __init__(self, Jhat: ReducedFunctional, action: str = Hessian, *, + apply_riesz: bool = False, comm: MPI.Comm = PETSc.COMM_WORLD): self.Jhat = Jhat self.control_interface = PETScVecInterface( Jhat.controls, comm=comm) + self.apply_riesz = apply_riesz if action in (Adjoint, TLM): self.functional_interface = PETScVecInterface( Jhat.functional, comm=comm) @@ -239,18 +240,21 @@ def __init__(self, Jhat: ReducedFunctional, if action == Hessian: # control -> control self.xinterface = self.control_interface self.yinterface = self.control_interface + self.x = copy_controls(Jhat.controls) self.mult_impl = self._mult_hessian elif action == Adjoint: # functional -> control self.xinterface = self.functional_interface self.yinterface = self.control_interface + self.x = Jhat.functional._ad_copy() self.mult_impl = self._mult_adjoint elif action == TLM: # control -> functional self.xinterface = 
self.control_interface self.yinterface = self.functional_interface + self.x = copy_controls(Jhat.controls) self.mult_impl = self._mult_tlm else: @@ -273,12 +277,13 @@ def shift(self, A, alpha): def update_tape_values(self, update_adjoint=True): _ = self.Jhat(self._m) if update_adjoint: - _ = self.Jhat.derivative(options=self.dual_options) + _ = self.Jhat.derivative(apply_riesz=False) def mult(self, A, x, y): self.xinterface.from_petsc(x, self.x) out = self.mult_impl(A, self.x) self.yinterface.to_petsc(y, out) + if self._shift != 0: y.axpy(self._shift, x) @@ -286,13 +291,16 @@ def _mult_hessian(self, A, x): if self.action != Hessian: raise NotImplementedError( f'Cannot apply hessian action if {self.action = }') + self.update_tape_values(update_adjoint=True) - return self.Jhat.hessian(x, options=self.dual_options) + return self.Jhat.hessian( + x, apply_riesz=self.apply_riesz) def _mult_tlm(self, A, x): if self.action != TLM: raise NotImplementedError( f'Cannot apply tlm action if {self.action = }') + self.update_tape_values(update_adjoint=False) return self.Jhat.tlm(x) @@ -300,8 +308,10 @@ def _mult_adjoint(self, A, x): if self.action != Adjoint: raise NotImplementedError( f'Cannot apply adjoint action if {self.action = }') + self.update_tape_values(update_adjoint=False) - return self.Jhat.derivative(adj_input=x, options=self.dual_options) + return self.Jhat.derivative( + adj_input=x, apply_riesz=self.apply_riesz) def ReducedFunctionalMat(Jhat, action=Hessian, *, comm=PETSc.COMM_WORLD, **kwargs): diff --git a/fdvar/tao_solver.py b/fdvar/tao_solver.py index 5fd8ff6..db4cba9 100644 --- a/fdvar/tao_solver.py +++ b/fdvar/tao_solver.py @@ -9,9 +9,9 @@ class TAOObjective: - def __init__(self, Jhat, dual_options=None): + def __init__(self, Jhat, apply_riesz=False): self.Jhat = Jhat - self.dual_options = dual_options + self.apply_riesz = apply_riesz self._control = Jhat.control.copy() self._m = Jhat.control.copy() self._mdot = Jhat.control.copy() @@ -26,19 +26,17 @@ def 
objective(self, tao, x): return self.Jhat(self._control) def gradient(self, tao, x, g): - dJ = self.Jhat.derivative(options=self.dual_options) + dJ = self.Jhat.derivative() with dJ.vec_ro() as dvec: dvec.copy(g) - # self.objective_gradient(tao, x, g) def objective_gradient(self, tao, x, g): with self._control.vec_wo() as cvec: x.copy(cvec) J = self.Jhat(self._control) - dJ = self.Jhat.derivative(options=self.dual_options) + dJ = self.Jhat.derivative() with dJ.vec_ro() as dvec: dvec.copy(g) - # self.gradient(tao, x, g) return J def hessian(self, A, x, y): @@ -83,12 +81,11 @@ def update(cls, tao, x, H, P): x.copy(mvec) ctx._shift = 0.0 - def __init__(self, Jhat, dual_options=None): + def __init__(self, Jhat): self.Jhat = Jhat self._m = Jhat.control.copy() self._mdot = Jhat.control.copy() self._shift = 0.0 - self.dual_options = dual_options def shift(self, A, alpha): self._shift += alpha @@ -99,7 +96,7 @@ def mult(self, A, x, y): # TODO: Why do we need to reevaluate and derivate? _ = self.Jhat(self._m) - _ = self.Jhat.derivative(options=self.dual_options) + _ = self.Jhat.derivative() ddJ = self.Jhat.hessian([self._mdot]) with ddJ.vec_ro() as dvec: @@ -125,6 +122,9 @@ def mult(self, mat, x, y): fd.assemble(self.M, tensor=self._ycofunc._full_local_function) + # ycofunc = self.Jhat.control._ad_convert_riesz( + # self._xfunc, riesz_map=self.Jhat.control.riesz_map) + with self._ycofunc.vec_ro() as yvec: yvec.copy(y) @@ -135,9 +135,7 @@ def __init__(self, Jhat, *, options_prefix=None, self.Jhat = Jhat self.ensemble = Jhat.ensemble - dual_options = {'riesz_representation': None} - - self.tao_objective = TAOObjective(Jhat, dual_options) + self.tao_objective = TAOObjective(Jhat) self.tao = PETSc.TAO().create( comm=Jhat.ensemble.global_comm) @@ -170,6 +168,7 @@ def __init__(self, Jhat, *, options_prefix=None, self.options = OptionsManager( solver_parameters, options_prefix) self.options.set_from_options(self.tao) + self.tao.setUp() def solve(self): From 
1c41a008f76b62c2d4adffb8dfc27ccc7be1c28e Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Thu, 24 Apr 2025 08:38:08 +0100 Subject: [PATCH 13/16] simple burgers tao demo --- burgers/burgers_simple.py | 112 ++++++++++++++++++++++++++++++++ burgers/burgers_wc4dvar_demo.py | 60 ++++++++++++----- 2 files changed, 156 insertions(+), 16 deletions(-) create mode 100644 burgers/burgers_simple.py diff --git a/burgers/burgers_simple.py b/burgers/burgers_simple.py new file mode 100644 index 0000000..6a52c1d --- /dev/null +++ b/burgers/burgers_simple.py @@ -0,0 +1,112 @@ +from firedrake import * +from firedrake.__future__ import interpolate +from firedrake.adjoint import continue_annotation, pause_annotation, Control, FourDVarReducedFunctional +from fdvar import TAOSolver, generate_observation_data +from math import sqrt +np.random.seed(42) + +nw, dt, nt, nu = 10, 1e-4, 6, 0.05 + +sigma_b = 0.1 +sigma_r = 0.03 +sigma_q = 0.0002 + +# ensemble parallelism +nspatial_ranks = 1 +ensemble = Ensemble(COMM_WORLD, nspatial_ranks) +ensemble_size = ensemble.ensemble_size + +# mesh +mesh = PeriodicUnitIntervalMesh( + 100, comm=ensemble.comm) +x, = SpatialCoordinate(mesh) + +# finite element forms +V = VectorFunctionSpace(mesh, "CG", 2) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) +uh = (un + un1)/2 + +F = (inner(un1 - un, v)*dx + + dt*inner(dot(uh, nabla_grad(uh)), v)*dx + + dt*inner(nu*grad(uh), grad(v))*dx) + +# timestepper solver +stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1)) + +# "ground truth" reference solutions +reference_ic = Function(V).project( + as_vector([1 + 0.5*sin(2*pi*x)])) + +# observation mesh and operator +observation_locations = [ + [x] for x in np.random.random_sample(20)] +vom = VertexOnlyMesh(mesh, observation_locations) +Y = VectorFunctionSpace(vom, "DG", 0) + +def H(u): + return assemble(interpolate(u, Y)) + +# generate ground-truth observational data +y, background = generate_observation_data( + ensemble, 
reference_ic, stepper, un, un1, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create Ensemble control +V_ensemble = EnsembleFunctionSpace( + [V for _ in range(nw//ensemble_size)], ensemble) +control = EnsembleFunction(V_ensemble) + +# start recording +continue_annotation() + +# create 4DVar ReducedFunctional +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=sigma_b, + observation_covariance=sigma_r, + observation_error=observation_error(0), + weak_constraint=True) + +with Jhat.recording_stages() as stages: + for stage, ctx in stages: + un.assign(stage.control) + un1.assign(un) + + for i in range(nt): + stepper.solve() + un.assign(un1) + + # record observation at end of each stage + idx = stage.local_index + + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=sigma_r, + forward_model_covariance=sigma_q) + +# finish recording +pause_annotation() + +tao_parameters = { + 'tao_monitor': None, + 'tao_gttol': 1e-1, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_rtol': 2e-1, + 'ksp_type': 'gmres', + 'pc_type': 'none'} + # 'tao_type': 'cg', + # 'tao_cg_type': 'fr', # fr-pr-prp-hs-dy +} +tao = TAOSolver(Jhat, options_prefix="", + solver_parameters=tao_parameters) +tao.solve() diff --git a/burgers/burgers_wc4dvar_demo.py b/burgers/burgers_wc4dvar_demo.py index af9fdfa..5ee5731 100644 --- a/burgers/burgers_wc4dvar_demo.py +++ b/burgers/burgers_wc4dvar_demo.py @@ -2,15 +2,15 @@ import firedrake as fd from firedrake.petsc import PETSc from firedrake.__future__ import interpolate -from firedrake.adjoint import (continue_annotation, pause_annotation, - get_working_tape, Control, minimize) +from firedrake.adjoint import (continue_annotation, pause_annotation, Control, minimize) +from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy from firedrake.adjoint 
import FourDVarReducedFunctional from burgers_utils import noisy_double_sin, burgers_stepper import numpy as np from functools import partial import argparse -from sys import exit from math import ceil +from fdvar import TAOSolver np.set_printoptions(legacy='1.25') @@ -265,22 +265,50 @@ def observation_err(i, state, name=None): ucontrol = Jhat.control.copy() -# minimiser should be given the derivative not the gradient -derivative_options = {'riesz_representation': 'l2'} - -options = { - 'disp': trank == 0, - 'maxcor': args.maxcor, - 'ftol': args.ftol, - 'gtol': args.gtol +# # scipy minimise +# uoptimised = minimize( +# ReducedFunctionalNumPy(Jhat), +# method="L-BFGS-B", +# options = { +# 'disp': trank == 0, +# 'maxcor': args.maxcor, +# 'ftol': args.ftol, +# 'gtol': args.gtol +# }) + +# # tao minimise +tao_params = { + 'tao_view': ':tao_view.log', + 'tao': { + 'monitor': None, + 'converged_reason': None, + 'ls_monitor': None, + 'gatol': 1e-3, + 'grtol': 1e-3, + 'gttol': 1e-3, + }, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp': { + 'monitor_short': None, + 'converged_rate': None, + 'converged_maxits': None, + 'max_it': 10, + 'rtol': 1e-1, + }, + 'ksp_type': 'gmres', + 'ksp_pc_side': 'right', + 'pc_type': 'none', + }, + 'tao_cg_type': 'fr', # fr-pr-prp-hs-dy } - -uoptimised = minimize(Jhat, options=options, method="L-BFGS-B", - derivative_options=derivative_options) - -uopts = uoptimised.subfunctions +tao = TAOSolver(Jhat, options_prefix="", + solver_parameters=tao_params) +tao.solve() +uoptimised = Jhat.control.copy() global_comm.Barrier() +uopts = uoptimised.subfunctions PETSc.Sys.Print(f"Initial functional: {Jhat(ucontrol)}") PETSc.Sys.Print(f"Final functional: {Jhat(uoptimised)}") global_comm.Barrier() From 93f4d0288bbfbc165cc40e59638b6c97854f1dcb Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Sun, 27 Apr 2025 16:18:49 +0200 Subject: [PATCH 14/16] arf updates --- burgers/burgers_simple_sc4dvar.py | 151 +++++++++++ burgers/burgers_simple_sc4dvar_dirichlet.py 
| 214 ++++++++++++++++ .../burgers_simple_sc4dvar_dirichlet_irk.py | 239 ++++++++++++++++++ ...rs_simple.py => burgers_simple_wc4dvar.py} | 62 ++--- fdvar/generate_data.py | 28 +- fdvar/mat.py | 40 +-- fdvar/pc.py | 21 +- 7 files changed, 693 insertions(+), 62 deletions(-) create mode 100644 burgers/burgers_simple_sc4dvar.py create mode 100644 burgers/burgers_simple_sc4dvar_dirichlet.py create mode 100644 burgers/burgers_simple_sc4dvar_dirichlet_irk.py rename burgers/{burgers_simple.py => burgers_simple_wc4dvar.py} (59%) diff --git a/burgers/burgers_simple_sc4dvar.py b/burgers/burgers_simple_sc4dvar.py new file mode 100644 index 0000000..6ebc77a --- /dev/null +++ b/burgers/burgers_simple_sc4dvar.py @@ -0,0 +1,151 @@ +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import interpolate +from fdvar import generate_observation_data +np.random.seed(42) + +# number of observation windows, and steps per window +nw, nt, dt, nu = 50, 6, 1e-4, 0.01 + +# Covariance of background, observation, and model noise +sigma_b = 0.1 +sigma_r = 0.03 +sigma_q = 0.0002 + +# 1D periodic mesh +mesh = PeriodicUnitIntervalMesh(100) +x, = SpatialCoordinate(mesh) + +# Burger's equation with implicit midpoint integration +V = VectorFunctionSpace(mesh, "CG", 2) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) +uh = (un + un1)/2 + +# finite element forms +F = (inner(un1 - un, v)*dx + + dt*inner(dot(uh, nabla_grad(uh)), v)*dx + + dt*inner(nu*grad(uh), grad(v))*dx) + +# timestepper solver +stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1)) + +def solve_step(): + un1.assign(un) + stepper.solve() + un.assign(un1) + +# "ground truth" reference solution +reference_ic = Function(V).project( + as_vector([1 + 0.5*sin(2*pi*x)])) + +# observations are point evaluations at random locations +observation_locations = [ + [x] for x in np.random.random_sample(20)] +vom = VertexOnlyMesh(mesh, observation_locations) +Y = VectorFunctionSpace(vom, 
"DG", 0) + +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) + +# generate "ground-truth" observational data +y, background = generate_observation_data( + None, reference_ic, solve_step, un, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +# create function evaluating observation error at window i +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create distributed control variable for entire timeseries +control = Function(V).assign(background) + +# tell pyadjoint to start taping operations +continue_annotation() + +# This object will construct and solve the 4DVar system +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=sigma_b, + observation_covariance=sigma_r, + observation_error=observation_error(0), + weak_constraint=False) + +# loop over each observation stage on the local communicator +with Jhat.recording_stages(nstages=nw) as stages: + for stage, ctx in stages: + idx = stage.local_index + un.assign(stage.control) + un1.assign(un) + + # let pyadjoint tape the time integration + for i in range(nt): + stepper.solve() + un.assign(un1) + + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=sigma_r) + +# tell pyadjoint to finish taping operations +pause_annotation() + + +class CovariancePC(PCBase): + def initialize(self, pc): + w = Constant(sigma_b) + u = Function(V) + b = Cofunction(V.dual()) + a = (1/w)*inner(TrialFunction(V), TestFunction(V))*dx + solver = LinearVariationalSolver( + LinearVariationalProblem(a, b, u), + solver_parameters={'ksp_type': 'preonly', + 'pc_type': 'lu'}) + self.u, self.b, self.solver = u, b, solver + + def apply(self, pc, x, y): + with self.b.dat.vec_wo as vb: + x.copy(vb) + self.solver.solve() + with self.u.dat.vec_ro as vu: + vu.copy(y) + + # x.copy(y) + # y.scale(sigma_b) + 
+ def update(self, pc, x): + pass + + def applyTranspose(self, pc, x, y): + raise NotImplementedError + + +# Solution strategy is controlled via this options dictionary +tao_parameters = { + 'tao_view': ':tao_view.log', + 'tao_monitor': None, + 'tao_converged_reason': None, + 'tao_gttol': 1e-2, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_max_it': 10, + 'ksp_converged_maxits': None, + 'ksp_rtol': 1e-1, + 'ksp_type': 'gmres', + 'pc_type': 'none', + # 'pc_type': 'python', + # 'pc_python_type': f'{__name__}.CovariancePC', + }, +} +tao = TAOSolver(MinimizationProblem(Jhat), + parameters=tao_parameters) +xopt = tao.solve() +PETSc.Sys.Print(f"{errornorm(reference_ic, background) = :.3e}") +PETSc.Sys.Print(f"{errornorm(reference_ic, xopt) = :.3e}") diff --git a/burgers/burgers_simple_sc4dvar_dirichlet.py b/burgers/burgers_simple_sc4dvar_dirichlet.py new file mode 100644 index 0000000..62ed812 --- /dev/null +++ b/burgers/burgers_simple_sc4dvar_dirichlet.py @@ -0,0 +1,214 @@ +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import interpolate +from fdvar import generate_observation_data +from sys import exit +np.random.seed(42) + + +class BackgroundPC(PCBase): + def initialize(self, pc): + A, _ = pc.getOperators() + scfdv = A.getPythonContext().problem.reduced_functional + B = scfdv.background_norm.covariance + Vc = scfdv.controls[0].control.function_space() + + u = Function(Vc) + b = Cofunction(Vc.dual()) + a = inner(TrialFunction(Vc)*(1/B), TestFunction(Vc))*dx + + solver = LinearVariationalSolver( + LinearVariationalProblem( + a, b, u, constant_jacobian=True), + solver_parameters={'ksp_type': 'preonly', + 'pc_type': 'lu'}) + + self.u, self.b, self.solver = u, b, solver + + def apply(self, pc, x, y): + with self.b.dat.vec_wo as vb: + x.copy(vb) + self.solver.solve() + with self.u.dat.vec_ro as vu: + vu.copy(y) + + def update(self, pc, x): + pass + + def applyTranspose(self, pc, x, y): + raise 
NotImplementedError + +# number of observation windows, and steps per window +nx, nw, nt, dt, nu = 30, 15, 5, 5e-4, 0.25 + +T = 0.03 +nsw = 50 + +# Covariance of background, observation, and model noise +sigma_b = 1e-2 +sigma_r = 1e-3 +sigma_q = (1e-4)*T/nsw + +B = sigma_b +R = sigma_r +Q = sigma_q + +# 1D periodic mesh +mesh = UnitIntervalMesh(nx) +x, = SpatialCoordinate(mesh) + +# Burger's equation with implicit midpoint integration +V = VectorFunctionSpace(mesh, "CG", 1) +V0 = FunctionSpace(mesh, "CG", 1) +Real = FunctionSpace(mesh, "R", 0) + +B = Function(V0).project(sigma_b*(1 - 0.9*cos(2*pi*x)*sin(4*pi*(0.5-x)))) + +t = Function(Real).assign(0) +fdt = Function(Real).assign(dt) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) + +zero = Constant(0) +bcs = [DirichletBC(V, as_vector([zero]), "on_boundary")] + +params = { + 'ksp_type': 'gmres', + 'pc_type': 'ilu', +} + +# forcing +k = Constant(0.1) + +x1 = 1 - x +t1 = t + 1 +g = ( + pi*k*( + (x + k*t1*sin(pi*x1*t1)) + *cos(pi*x*t1)*sin(pi*x1*t1) + + + (x1 - k*t1*sin(pi*x*t1)) + *sin(pi*x*t1)*cos(pi*x1*t1) + ) + + + (2*nu*(k*pi*t1)**2) + *(sin(pi*x*t1)*sin(pi*x1*t1) + + cos(pi*x*t1)*cos(pi*x1*t1)) +) +uh = (un + un1)/2 + +# finite element forms +F = ( + (inner(un1 - un, v)/fdt)*dx + + inner(dot(uh, nabla_grad(uh)), v)*dx + + inner(nu*grad(uh), grad(v))*dx + - inner(as_vector([g]), v)*dx +) + +# timestepper solver +stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1, bcs=bcs), + solver_parameters=params) + + +def solve_step(): + un1.assign(un) + stepper.solve() + un.assign(un1) + + +# "ground truth" reference solution +reference_ic = Function(V).project( + as_vector([k*sin(2*pi*x)])) + +# observations are point evaluations at random locations +observation_locations = [ + [x] for x in np.random.random_sample(20)] +vom = VertexOnlyMesh(mesh, observation_locations) +Y = VectorFunctionSpace(vom, "DG", 0) +Y0 = FunctionSpace(vom, "DG", 0) + +# vary between (sigma_r =< R =< 1) +Rprofile = 
sin(6*pi*(x + 0.3)) +Rexpr = ((1 + sigma_r) + (1 - sigma_r)*Rprofile)/2 +R = Function(Y0).interpolate(Rexpr) + +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) + +# generate "ground-truth" observational data +y, background = generate_observation_data( + None, reference_ic, stepper, un, un1, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +# create function evaluating observation error at window i +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create distributed control variable for entire timeseries +control = Function(V).assign(background) + +# tell pyadjoint to start taping operations +continue_annotation() + +# This object will construct and solve the 4DVar system +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=B, + observation_covariance=R, + observation_error=observation_error(0), + weak_constraint=False) + +# loop over each observation stage on the local communicator +t.assign(0.) 
+with Jhat.recording_stages(nstages=nw, t=t) as stages: + for stage, ctx in stages: + idx = stage.local_index + un.assign(stage.control) + t.assign(ctx.t) + + # let pyadjoint tape the time integration + for i in range(nt): + un1.assign(un) + stepper.solve() + un.assign(un1) + t += dt + ctx.t.assign(t) + + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=R) + + +# tell pyadjoint to finish taping operations +pause_annotation() + + +# Solution strategy is controlled via this options dictionary +tao_parameters = { + 'tao_view': ':tao_view.log', + 'tao_monitor': None, + 'tao_converged_reason': None, + 'tao_gttol': 1e-1, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_converged_rate': None, + 'ksp_converged_maxits': None, + 'ksp_max_it': 6, + 'ksp_rtol': 1e-1, + 'ksp_type': 'gmres', + 'pc_type': 'python', + 'pc_python_type': f'{__name__}.BackgroundPC', + }, +} +tao = TAOSolver(MinimizationProblem(Jhat), + parameters=tao_parameters) +xopt = tao.solve() +PETSc.Sys.Print(f"{errornorm(reference_ic, background) = :.3e}") +PETSc.Sys.Print(f"{errornorm(reference_ic, xopt) = :.3e}") diff --git a/burgers/burgers_simple_sc4dvar_dirichlet_irk.py b/burgers/burgers_simple_sc4dvar_dirichlet_irk.py new file mode 100644 index 0000000..784f909 --- /dev/null +++ b/burgers/burgers_simple_sc4dvar_dirichlet_irk.py @@ -0,0 +1,239 @@ +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import interpolate +from fdvar import generate_observation_data +from numpy import mean as nmean +from numpy import max as nmax +from sys import exit +from irksome import TimeStepper, GaussLegendre, Dt +np.random.seed(42) + +# number of observation windows, and steps per window +nx, nw, nt, dt, nu = 20, 5, 3, 1e-4, 0.25 + +T = 0.03 +nsw = 50 + +# Covariance of background, observation, and model noise 
+sigma_b = 1e-2 +sigma_r = 1e-3 +sigma_q = (1e-4)*T/nsw + +# 1D periodic mesh +mesh = UnitIntervalMesh(nx) +x, = SpatialCoordinate(mesh) + +# Burger's equation with implicit midpoint integration +V = VectorFunctionSpace(mesh, "CG", 1) +R = FunctionSpace(mesh, "R", 0) +t = Function(R).assign(0) +fdt = Function(R).assign(dt) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) + +zero = Constant(0) +bcs = [DirichletBC(V, as_vector([zero]), "on_boundary")] + +params = { + 'ksp_type': 'gmres', + 'pc_type': 'ilu', +} + +# forcing +k = Constant(0.1) + +x1 = 1 - x +t1 = t + 1 +g = ( + pi*k*( + (x + k*t1*sin(pi*x1*t1)) + *cos(pi*x*t1)*sin(pi*x1*t1) + + + (x1 - k*t1*sin(pi*x*t1)) + *sin(pi*x*t1)*cos(pi*x1*t1) + ) + + + (2*nu*(k*pi*t1)**2) + *(sin(pi*x*t1)*sin(pi*x1*t1) + + cos(pi*x*t1)*cos(pi*x1*t1)) +) +uh = (un + un1)/2 + +irk = True + +# finite element forms +F = ( + (inner(un1 - un, v)/fdt)*dx + # + inner(dot(as_vector([uh]), nabla_grad(uh)), v)*dx + # + inner(nu*grad(uh), grad(v))*dx + # - inner(g, v)*dx + + inner(dot(uh, nabla_grad(uh)), v)*dx + + inner(nu*grad(uh), grad(v))*dx + - inner(as_vector([g]), v)*dx +) + +Firk = ( + inner(Dt(un), v)*dx + # + inner(dot(as_vector([un]), nabla_grad(un)), v)*dx + # + inner(nu*grad(un), grad(v))*dx + # - inner(g, v)*dx + + inner(dot(un, nabla_grad(un)), v)*dx + + inner(nu*grad(un), grad(v))*dx + - inner(as_vector([g]), v)*dx +) +irk_stepper = TimeStepper( + Firk, GaussLegendre(1), t, fdt, un, + bcs=bcs, solver_parameters=params) + +# timestepper solver +nvs_stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1, bcs=bcs), + solver_parameters=params) +stepper = nvs_stepper + +def solve_step(): + if irk: + irk_stepper.advance() + else: + un1.assign(un) + nvs_stepper.solve() + un.assign(un1) + +# "ground truth" reference solution +reference_ic = Function(V).project( + as_vector([k*sin(2*pi*x)])) + # k*sin(2*pi*x)) + +# observations are point evaluations at random locations +observation_locations = [ + [x] 
for x in np.random.random_sample(20)] +vom = VertexOnlyMesh(mesh, observation_locations) +Y = VectorFunctionSpace(vom, "DG", 0) + +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) + +# generate "ground-truth" observational data +y, background = generate_observation_data( + None, reference_ic, solve_step, un, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +# create function evaluating observation error at window i +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create distributed control variable for entire timeseries +control = Function(V).assign(background) + +# tell pyadjoint to start taping operations +continue_annotation() + +# This object will construct and solve the 4DVar system +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=sigma_b, + observation_covariance=sigma_r, + observation_error=observation_error(0), + weak_constraint=False) + +# loop over each observation stage on the local communicator +trajectory = [background.copy(deepcopy=True, annotate=False)] +t.assign(0.) 
+with Jhat.recording_stages(nstages=nw, t=t) as stages: + for stage, ctx in stages: + idx = stage.local_index + un.assign(stage.control) + + t.assign(ctx.t) + + # let pyadjoint tape the time integration + for i in range(nt): + solve_step() + t += dt + trajectory.append(un.copy(deepcopy=True, annotate=False)) + + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=sigma_r) + + +# tell pyadjoint to finish taping operations +pause_annotation() + +# print(f"{t.dat.data[0] = :.2e}") +# sc4dvar = Jhat.strong_reduced_functional +# print(f"{Jhat._total_functional = }") +# print(f"{sc4dvar.functional = :.2e}") +# print(f"{sc4dvar(background) = :.2e}") +# print(f"{sc4dvar(reference_ic) = :.2e}") +# print(f"{sc4dvar(background) = :.2e}") +# print(f"{sc4dvar(reference_ic) = :.2e}") + +# vtk = VTKFile('output/burgers.pvd') +# t = 0 +# for j, u in enumerate(trajectory): +# # ud = u.dat.data +# # print(f"{j = :>3d} | {norm(u) = :.2e} | {nmean(ud) = :.2e} | {nmax(-ud) = :.2e} | {nmax(ud) = :.2e}") +# # vtk.write(u, time=t) +# # t += dt +# # pass + + +class CovariancePC(PCBase): + def initialize(self, pc): + w = Constant(sigma_b) + u = Function(V) + b = Cofunction(V.dual()) + a = (1/w)*inner(TrialFunction(V), TestFunction(V))*dx + solver = LinearVariationalSolver( + LinearVariationalProblem(a, b, u), + solver_parameters={'ksp_type': 'preonly', + 'pc_type': 'lu'}) + self.u, self.b, self.solver = u, b, solver + + def apply(self, pc, x, y): + with self.b.dat.vec_wo as vb: + x.copy(vb) + self.solver.solve() + with self.u.dat.vec_ro as vu: + vu.copy(y) + + # x.copy(y) + # y.scale(sigma_b) + + def update(self, pc, x): + pass + + def applyTranspose(self, pc, x, y): + raise NotImplementedError + + +# Solution strategy is controlled via this options dictionary +tao_parameters = { + 'tao_view': ':tao_view.log', + 'tao_monitor': None, + 
'tao_converged_reason': None, + 'tao_gttol': 2e-1, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_converged_rate': None, + 'ksp_converged_maxits': None, + 'ksp_max_it': 10, + 'ksp_rtol': 1e-1, + 'ksp_type': 'gmres', + 'pc_type': 'none', + # 'pc_type': 'python', + # 'pc_python_type': f'{__name__}.CovariancePC', + }, +} +tao = TAOSolver(MinimizationProblem(Jhat), + parameters=tao_parameters) +xopt = tao.solve() +PETSc.Sys.Print(f"{errornorm(reference_ic, background) = :.3e}") +PETSc.Sys.Print(f"{errornorm(reference_ic, xopt) = :.3e}") diff --git a/burgers/burgers_simple.py b/burgers/burgers_simple_wc4dvar.py similarity index 59% rename from burgers/burgers_simple.py rename to burgers/burgers_simple_wc4dvar.py index 6a52c1d..11b5fa1 100644 --- a/burgers/burgers_simple.py +++ b/burgers/burgers_simple_wc4dvar.py @@ -1,71 +1,72 @@ from firedrake import * +from firedrake.adjoint import * from firedrake.__future__ import interpolate -from firedrake.adjoint import continue_annotation, pause_annotation, Control, FourDVarReducedFunctional from fdvar import TAOSolver, generate_observation_data -from math import sqrt np.random.seed(42) -nw, dt, nt, nu = 10, 1e-4, 6, 0.05 +# number of observation windows, and steps per window +nw, nt, dt, nu = 5, 6, 1e-4, 0.05 +# Covariance of background, observation, and model noise sigma_b = 0.1 sigma_r = 0.03 sigma_q = 0.0002 -# ensemble parallelism +# time-parallelism using firedrake's Ensemble nspatial_ranks = 1 ensemble = Ensemble(COMM_WORLD, nspatial_ranks) ensemble_size = ensemble.ensemble_size -# mesh +# 1D periodic mesh mesh = PeriodicUnitIntervalMesh( 100, comm=ensemble.comm) x, = SpatialCoordinate(mesh) -# finite element forms +# Burger's equation with implicit midpoint integration V = VectorFunctionSpace(mesh, "CG", 2) un, un1 = Function(V), Function(V) v = TestFunction(V) uh = (un + un1)/2 -F = (inner(un1 - un, v)*dx - + dt*inner(dot(uh, nabla_grad(uh)), v)*dx - + dt*inner(nu*grad(uh), grad(v))*dx) +# 
finite element forms +F = (inner(un1 - un, v)/dt + + inner(dot(uh, grad(uh)), v) + + inner(nu*grad(uh), grad(v)))*dx # timestepper solver stepper = NonlinearVariationalSolver( NonlinearVariationalProblem(F, un1)) -# "ground truth" reference solutions +# "ground truth" reference solution reference_ic = Function(V).project( as_vector([1 + 0.5*sin(2*pi*x)])) -# observation mesh and operator -observation_locations = [ - [x] for x in np.random.random_sample(20)] -vom = VertexOnlyMesh(mesh, observation_locations) +# observations are point evaluations at random locations +vom = VertexOnlyMesh(mesh, np.random.rand(20, 1)) Y = VectorFunctionSpace(vom, "DG", 0) -def H(u): - return assemble(interpolate(u, Y)) +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) -# generate ground-truth observational data +# generate "ground-truth" observational data y, background = generate_observation_data( ensemble, reference_ic, stepper, un, un1, H, nw, nt, sigma_b, sigma_r, sigma_q) +# create function evaluating observation error at window i def observation_error(i): return lambda x: Function(Y).assign(H(x) - y[i]) -# create Ensemble control +# create distributed control variable for entire timeseries V_ensemble = EnsembleFunctionSpace( [V for _ in range(nw//ensemble_size)], ensemble) control = EnsembleFunction(V_ensemble) -# start recording +# tell pyadjoint to start taping operations continue_annotation() -# create 4DVar ReducedFunctional +# This object will construct and solve the 4DVar system Jhat = FourDVarReducedFunctional( Control(control), background=background, @@ -74,39 +75,40 @@ def observation_error(i): observation_error=observation_error(0), weak_constraint=True) +# loop over each observation stage on the local communicator with Jhat.recording_stages() as stages: for stage, ctx in stages: + idx = stage.local_index un.assign(stage.control) un1.assign(un) + # let pyadjoint tape the time integration for i in range(nt): stepper.solve() un.assign(un1) - 
# record observation at end of each stage - idx = stage.local_index - + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error stage.set_observation( state=un, observation_error=observation_error(idx), observation_covariance=sigma_r, forward_model_covariance=sigma_q) -# finish recording +# tell pyadjoint to finish taping operations pause_annotation() +# Solution strategy is controlled via this options dictionary tao_parameters = { 'tao_monitor': None, + 'tao_converged_reason': None, 'tao_gttol': 1e-1, 'tao_type': 'nls', 'tao_nls': { 'ksp_monitor_short': None, - 'ksp_rtol': 2e-1, - 'ksp_type': 'gmres', - 'pc_type': 'none'} - # 'tao_type': 'cg', - # 'tao_cg_type': 'fr', # fr-pr-prp-hs-dy + 'ksp_rtol': 5e-1, + 'ksp_type': 'gmres'} } -tao = TAOSolver(Jhat, options_prefix="", +tao = TAOSolver(Jhat, options_prefix="fdv", solver_parameters=tao_parameters) tao.solve() diff --git a/fdvar/generate_data.py b/fdvar/generate_data.py index 11698f4..c01074e 100644 --- a/fdvar/generate_data.py +++ b/fdvar/generate_data.py @@ -15,33 +15,39 @@ def generate_observation_data(ensemble, ic, stepper, un, un1, H, if seed is not None: np.random.seed(seed) - rank = ensemble.ensemble_rank + if ensemble is None: + rank = 0 + nlocal_stages = nw + else: + rank = ensemble.ensemble_rank + nlocal_stages = nw//ensemble.ensemble_size - nlocal_stages = nw//ensemble.ensemble_size if rank == 0: nlocal_stages -= 1 # background is reference plus noise - background = ic.copy(deepcopy=True) - noisify(background, sigma_b) + background = noisify(ic.copy(deepcopy=True), sigma_b) y = [] # initial observation if rank == 0: y.append(noisify(H(ic), sigma_r)) - widx = 0 - with ensemble.sequential(widx=widx, un=un) as ctx: - # if rank != 0: - # y.extend([None for _ in range(ctx.widx)]) - un.assign(ctx.un) - un1.assign(un) + def solve_local_stages(): for k in range(nlocal_stages): for i in range(nt): + un1.assign(un) stepper.solve() un.assign(un1) noisify(un, sigma_q) 
y.append(noisify(H(un), sigma_r)) - ctx.widx += 1 + + un.assign(ic) + if ensemble is None: + solve_local_stages() + else: + with ensemble.sequential(u=un) as ctx: + un.assign(ctx.u) + solve_local_stages() return y, background diff --git a/fdvar/mat.py b/fdvar/mat.py index da81322..b069e2a 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -8,7 +8,7 @@ from pyadjoint.enlisting import Enlist from typing import Optional from enum import Enum -from functools import partial, cached_property +from functools import partial, cached_property, wraps class ISNest: @@ -210,6 +210,18 @@ def FDVarSaddlePointMat(fdvrf): return fdvmat +def check_rf_action(action): + def check_rf_action_decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.action != action: + raise NotImplementedError( + f'Cannot apply {str(action)} action if {self.action = }') + return func(*args, **kwargs) + return wrapper + return check_rf_action_decorator + + class ReducedFunctionalMatCtx: """ PythonMat context to apply action of a pyadjoint.ReducedFunctional. 
@@ -269,6 +281,7 @@ def __init__(self, Jhat: ReducedFunctional, def update(cls, obj, x, A, P): ctx = A.getPythonContext() ctx.control_interface.from_petsc(x, ctx._m) + ctx.update_tape_values(update_adjoint=True) ctx._shift = 0 def shift(self, A, alpha): @@ -287,29 +300,20 @@ def mult(self, A, x, y): if self._shift != 0: y.axpy(self._shift, x) + @check_rf_action(action=Hessian) def _mult_hessian(self, A, x): - if self.action != Hessian: - raise NotImplementedError( - f'Cannot apply hessian action if {self.action = }') - - self.update_tape_values(update_adjoint=True) + # self.update_tape_values(update_adjoint=True) return self.Jhat.hessian( x, apply_riesz=self.apply_riesz) + @check_rf_action(TLM) def _mult_tlm(self, A, x): - if self.action != TLM: - raise NotImplementedError( - f'Cannot apply tlm action if {self.action = }') - - self.update_tape_values(update_adjoint=False) + # self.update_tape_values(update_adjoint=False) return self.Jhat.tlm(x) + @check_rf_action(Adjoint) def _mult_adjoint(self, A, x): - if self.action != Adjoint: - raise NotImplementedError( - f'Cannot apply adjoint action if {self.action = }') - - self.update_tape_values(update_adjoint=False) + # self.update_tape_values(update_adjoint=False) return self.Jhat.derivative( adj_input=x, apply_riesz=self.apply_riesz) @@ -354,7 +358,7 @@ def EnsembleMat(ctx, row_space, col_space=None): class EnsembleMatCtxBase: def __init__(self, row_space, col_space=None): if col_space is None: - col_space = row_space + col_space = row_space.dual() if not isinstance(row_space, fd.EnsembleFunctionSpace): raise ValueError( @@ -370,7 +374,7 @@ def __init__(self, row_space, col_space=None): # so that base classes can implement mult only in # terms of Ensemble objects not Vecs. 
self.x = fd.EnsembleFunction(self.row_space) - self.y = fd.EnsembleFunction(self.col_space.dual()) + self.y = fd.EnsembleFunction(self.col_space) def mult(self, A, x, y): with self.x.vec_wo() as xvec: diff --git a/fdvar/pc.py b/fdvar/pc.py index 571ddf2..d7fed77 100644 --- a/fdvar/pc.py +++ b/fdvar/pc.py @@ -63,7 +63,7 @@ def initialize(self, pc): class EnsembleBlockDiagonalPC(PCBase): - prefix = "ensemblejacobi_" + prefix = "ebjacobi_" def initialize(self, pc): super().initialize(pc) @@ -104,8 +104,23 @@ def apply(self, pc, x, y): yvec.copy(y) -class AllAtOnceJacobiPC(PCBase): - prefix = "aaojacobi_" +class CovarianceMassPC(PCBase): + prefix = "covmass_" + pass + + +class CovarianceDiffusionPC(PCBase): + prefix = "covdiff_" + pass + + +class SC4DVarBackgroundPC(PCBase): + prefix = "bkg_" + pass + + +class WC4DVarLTDLPC(PCBase): + prefix = "ltdl_" def initialize(self, pc): self.fdvrf = self.pmat.fdvrf From bb3196aa95d6eb2046564f30f40c569dd97ecc15 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 6 May 2025 08:35:59 +0100 Subject: [PATCH 15/16] diffusion correlation operators --- burgers/burgers_simple_sc4dvar.py | 4 +- burgers/burgers_simple_sc4dvar_dirichlet.py | 110 +++++++++----------- burgers/burgers_simple_wc4dvar.py | 8 +- fdvar/__init__.py | 1 + fdvar/generate_data.py | 14 ++- fdvar/mat.py | 6 +- fdvar/pc.py | 54 +++++++++- 7 files changed, 123 insertions(+), 74 deletions(-) diff --git a/burgers/burgers_simple_sc4dvar.py b/burgers/burgers_simple_sc4dvar.py index 6ebc77a..82294c4 100644 --- a/burgers/burgers_simple_sc4dvar.py +++ b/burgers/burgers_simple_sc4dvar.py @@ -5,7 +5,7 @@ np.random.seed(42) # number of observation windows, and steps per window -nw, nt, dt, nu = 50, 6, 1e-4, 0.01 +nw, nt, dt, nu = 32, 6, 1e-4, 0.05 # Covariance of background, observation, and model noise sigma_b = 0.1 @@ -140,8 +140,6 @@ def applyTranspose(self, pc, x, y): 'ksp_rtol': 1e-1, 'ksp_type': 'gmres', 'pc_type': 'none', - # 'pc_type': 'python', - # 'pc_python_type': 
f'{__name__}.CovariancePC', }, } tao = TAOSolver(MinimizationProblem(Jhat), diff --git a/burgers/burgers_simple_sc4dvar_dirichlet.py b/burgers/burgers_simple_sc4dvar_dirichlet.py index 62ed812..7c6441e 100644 --- a/burgers/burgers_simple_sc4dvar_dirichlet.py +++ b/burgers/burgers_simple_sc4dvar_dirichlet.py @@ -4,67 +4,45 @@ from fdvar import generate_observation_data from sys import exit np.random.seed(42) - - -class BackgroundPC(PCBase): - def initialize(self, pc): - A, _ = pc.getOperators() - scfdv = A.getPythonContext().problem.reduced_functional - B = scfdv.background_norm.covariance - Vc = scfdv.controls[0].control.function_space() - - u = Function(Vc) - b = Cofunction(Vc.dual()) - a = inner(TrialFunction(Vc)*(1/B), TestFunction(Vc))*dx - - solver = LinearVariationalSolver( - LinearVariationalProblem( - a, b, u, constant_jacobian=True), - solver_parameters={'ksp_type': 'preonly', - 'pc_type': 'lu'}) - - self.u, self.b, self.solver = u, b, solver - - def apply(self, pc, x, y): - with self.b.dat.vec_wo as vb: - x.copy(vb) - self.solver.solve() - with self.u.dat.vec_ro as vu: - vu.copy(y) - - def update(self, pc, x): - pass - - def applyTranspose(self, pc, x, y): - raise NotImplementedError +Print = PETSc.Sys.Print # number of observation windows, and steps per window -nx, nw, nt, dt, nu = 30, 15, 5, 5e-4, 0.25 +nx, nw, nt, dt, nu = 100, 50, 6, 1e-4, 0.25 -T = 0.03 -nsw = 50 +# T = 0.03 +# nsw = 50 +# qscale = T/nsw +qscale = nt*dt # Covariance of background, observation, and model noise sigma_b = 1e-2 sigma_r = 1e-3 -sigma_q = (1e-4)*T/nsw +sigma_q = (1e-4)*qscale -B = sigma_b -R = sigma_r -Q = sigma_q +# lengthscales of the background and model covariances +L_b = 0.25 +L_q = 0.05 + +cfl_b = (L_b*L_b/2)*(nx*nx) +Print(f"{cfl_b = }") # 1D periodic mesh mesh = UnitIntervalMesh(nx) x, = SpatialCoordinate(mesh) # Burger's equation with implicit midpoint integration -V = VectorFunctionSpace(mesh, "CG", 1) -V0 = FunctionSpace(mesh, "CG", 1) +degree = 2 +V = 
VectorFunctionSpace(mesh, "CG", degree) +V0 = FunctionSpace(mesh, "CG", degree) Real = FunctionSpace(mesh, "R", 0) -B = Function(V0).project(sigma_b*(1 - 0.9*cos(2*pi*x)*sin(4*pi*(0.5-x)))) +# vary between (Bmin =< B/sigma_b =< 1) +# Bmin = 0.1 +# Bprofile = cos(2*pi*x)*sin(4*pi*(0.5-x)) +# Bexpr = ((1 + Bmin) + (1 - Bmin)*Bprofile)/2 +# B = Function(V0).project(sigma_b*Bexpr) -t = Function(Real).assign(0) +t = Function(Real).assign(0.) fdt = Function(Real).assign(dt) un, un1 = Function(V), Function(V) @@ -73,9 +51,24 @@ def applyTranspose(self, pc, x, y): zero = Constant(0) bcs = [DirichletBC(V, as_vector([zero]), "on_boundary")] +# B = (sigma_b, L_b, bcs) +# Q = (sigma_q, L_q, bcs) +# R = sigma_r + +lu_params = {'ksp_type': 'preonly', 'pc_type': 'lu'} + +B = CovarianceOperator( + V, form_type="diffusion", + sigma=sigma_b, L=L_b, bcs=bcs, + solver_parameters=lu_params) + +# B = CovarianceOperator( +# V, form_type="mass", sigma=sigma_b, +# bcs=bcs, solver_parameters=lu_params) + params = { - 'ksp_type': 'gmres', - 'pc_type': 'ilu', + 'snes_rtol': 1e-10, + **lu_params, } # forcing @@ -111,16 +104,11 @@ def applyTranspose(self, pc, x, y): NonlinearVariationalProblem(F, un1, bcs=bcs), solver_parameters=params) - -def solve_step(): - un1.assign(un) - stepper.solve() - un.assign(un1) - - # "ground truth" reference solution reference_ic = Function(V).project( as_vector([k*sin(2*pi*x)])) +for bc in bcs: + bc.apply(reference_ic) # observations are point evaluations at random locations observation_locations = [ @@ -132,14 +120,17 @@ def solve_step(): # vary between (sigma_r =< R =< 1) Rprofile = sin(6*pi*(x + 0.3)) Rexpr = ((1 + sigma_r) + (1 - sigma_r)*Rprofile)/2 -R = Function(Y0).interpolate(Rexpr) +Rfunc = Function(Y0).interpolate(Rexpr) +R = CovarianceOperator( + Y, form_type="mass", sigma=Rfunc, + solver_parameters=lu_params) def H(x): # operator to take observations return assemble(interpolate(x, Y)) # generate "ground-truth" observational data y, background = 
generate_observation_data( - None, reference_ic, stepper, un, un1, + None, reference_ic, stepper, un, un1, bcs, H, nw, nt, sigma_b, sigma_r, sigma_q) # create function evaluating observation error at window i @@ -200,15 +191,16 @@ def observation_error(i): 'ksp_monitor_short': None, 'ksp_converged_rate': None, 'ksp_converged_maxits': None, - 'ksp_max_it': 6, + 'ksp_max_it': 15, 'ksp_rtol': 1e-1, - 'ksp_type': 'gmres', + 'ksp_type': 'cg', 'pc_type': 'python', - 'pc_python_type': f'{__name__}.BackgroundPC', + 'pc_python_type': f'fdvar.SC4DVarBackgroundPC', }, } tao = TAOSolver(MinimizationProblem(Jhat), - parameters=tao_parameters) + parameters=tao_parameters, + options_prefix="sc") xopt = tao.solve() PETSc.Sys.Print(f"{errornorm(reference_ic, background) = :.3e}") PETSc.Sys.Print(f"{errornorm(reference_ic, xopt) = :.3e}") diff --git a/burgers/burgers_simple_wc4dvar.py b/burgers/burgers_simple_wc4dvar.py index 11b5fa1..ed996a0 100644 --- a/burgers/burgers_simple_wc4dvar.py +++ b/burgers/burgers_simple_wc4dvar.py @@ -5,7 +5,7 @@ np.random.seed(42) # number of observation windows, and steps per window -nw, nt, dt, nu = 5, 6, 1e-4, 0.05 +nw, nt, dt, nu = 32, 6, 1e-4, 0.05 # Covariance of background, observation, and model noise sigma_b = 0.1 @@ -102,11 +102,13 @@ def observation_error(i): tao_parameters = { 'tao_monitor': None, 'tao_converged_reason': None, - 'tao_gttol': 1e-1, + 'tao_gttol': 1e-2, 'tao_type': 'nls', 'tao_nls': { 'ksp_monitor_short': None, - 'ksp_rtol': 5e-1, + 'ksp_converged_maxits': None, + 'ksp_max_it': 8, + 'ksp_rtol': 1e-1, 'ksp_type': 'gmres'} } tao = TAOSolver(Jhat, options_prefix="fdv", diff --git a/fdvar/__init__.py b/fdvar/__init__.py index c308d89..beb3e17 100644 --- a/fdvar/__init__.py +++ b/fdvar/__init__.py @@ -1,2 +1,3 @@ from .tao_solver import * # noqa: F401, F403 from .generate_data import * # noqa: F401, F403 +from .pc import * # noqa: F401, F403 diff --git a/fdvar/generate_data.py b/fdvar/generate_data.py index c01074e..fe11e8f 
100644 --- a/fdvar/generate_data.py +++ b/fdvar/generate_data.py @@ -2,15 +2,18 @@ import numpy as np -def noisify(u, sigma, seed=None): +def noisify(u, sigma, bcs=None, seed=None): if seed is not None: np.random.seed(seed) for dat in u.dat: dat.data[:] += np.random.normal(0, sigma, dat.data.shape) + if bcs: + for bc in bcs: + bc.apply(u) return u -def generate_observation_data(ensemble, ic, stepper, un, un1, H, +def generate_observation_data(ensemble, ic, stepper, un, un1, bcs, H, nw, nt, sigma_b, sigma_r, sigma_q, seed=6): if seed is not None: np.random.seed(seed) @@ -26,7 +29,10 @@ def generate_observation_data(ensemble, ic, stepper, un, un1, H, nlocal_stages -= 1 # background is reference plus noise - background = noisify(ic.copy(deepcopy=True), sigma_b) + for bc in (bcs or ()): + bc.apply(ic) + + background = noisify(ic.copy(deepcopy=True), sigma_b, bcs=bcs) y = [] # initial observation @@ -39,7 +45,7 @@ def solve_local_stages(): un1.assign(un) stepper.solve() un.assign(un1) - noisify(un, sigma_q) + noisify(un, sigma_q, bcs=bcs) y.append(noisify(H(un), sigma_r)) un.assign(ic) diff --git a/fdvar/mat.py b/fdvar/mat.py index b069e2a..b467cfc 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -300,18 +300,18 @@ def mult(self, A, x, y): if self._shift != 0: y.axpy(self._shift, x) - @check_rf_action(action=Hessian) + # @check_rf_action(action=Hessian) def _mult_hessian(self, A, x): # self.update_tape_values(update_adjoint=True) return self.Jhat.hessian( x, apply_riesz=self.apply_riesz) - @check_rf_action(TLM) + # @check_rf_action(TLM) def _mult_tlm(self, A, x): # self.update_tape_values(update_adjoint=False) return self.Jhat.tlm(x) - @check_rf_action(Adjoint) + # @check_rf_action(Adjoint) def _mult_adjoint(self, A, x): # self.update_tape_values(update_adjoint=False) return self.Jhat.derivative( diff --git a/fdvar/pc.py b/fdvar/pc.py index d7fed77..914ec27 100644 --- a/fdvar/pc.py +++ b/fdvar/pc.py @@ -1,5 +1,8 @@ import firedrake as fd from fdvar.mat import EnsembleMatCtxBase 
+from math import pi, sqrt + +__all__ = ("SC4DVarBackgroundPC",) class PCBase: @@ -115,8 +118,55 @@ class CovarianceDiffusionPC(PCBase): class SC4DVarBackgroundPC(PCBase): - prefix = "bkg_" - pass + prefix = "scbkg_" + + def initialize(self, pc): + prefix = pc.getOptionsPrefix() or "" + options_prefix = prefix + self.prefix + A, _ = pc.getOperators() + scfdv = A.getPythonContext().problem.reduced_functional + + self.covariance = scfdv.background_norm.covariance + self.x = fd.Cofunction(self.covariance.V.dual()) + self.y = fd.Function(self.covariance.V) + + # # diffusion coefficient for lengthscale L + # nu = 0.5*L*L + # # normalisation for diffusion operator + # lambda_g = L*sqrt(2*pi) + # scale = lambda_g/sigma + + # V = scfdv.controls[0].control.function_space() + # sol = fd.Function(V) + # b = fd.Cofunction(V.dual()) + # u = fd.TrialFunction(V) + # v = fd.TestFunction(V) + # + # Binv = scale*(fd.inner(u, v) - nu*fd.inner(fd.grad(u), fd.grad(v)))*fd.dx + + # solver = fd.LinearVariationalSolver( + # fd.LinearVariationalProblem( + # Binv, b, sol, bcs=bcs, + # constant_jacobian=True), + # options_prefix=options_prefix, + # solver_parameters={'ksp_type': 'preonly', + # 'pc_type': 'lu'}) + + # self.sol, self.b, self.solver = sol, b, solver + + def apply(self, pc, x, y): + with self.x.dat.vec_wo as vx: + x.copy(vx) + self.covariance.action(self.x, tensor=self.y) + with self.y.dat.vec_ro as vy: + # with self.y.riesz_representation().dat.vec_ro as vy: + vy.copy(y) + + def update(self, pc): + pass + + def applyTranspose(self, pc, x, y): + raise NotImplementedError class WC4DVarLTDLPC(PCBase): From 854cc09c5da5c8363adb424110213bc58434d901 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 6 May 2025 10:40:13 +0100 Subject: [PATCH 16/16] saddle impl to separate file --- fdvar/mat.py | 180 ------------------------------------------- fdvar/saddle.py | 201 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+), 180 deletions(-) create mode 
100644 fdvar/saddle.py diff --git a/fdvar/mat.py b/fdvar/mat.py index b467cfc..f5e6e49 100644 --- a/fdvar/mat.py +++ b/fdvar/mat.py @@ -11,71 +11,6 @@ from functools import partial, cached_property, wraps -class ISNest: - attrs = ( - 'getSizes', - 'getSize', - 'getLocalSize', - 'getIndices', - 'getComm', - ) - def __init__(self, ises): - self._comm = ises[0].getComm() - self._ises = ises - for attr in self.attrs: - setattr(self, attr, - partial(self._getattr, attr)) - - def _getattr(self, attr, i, *args, **kwargs): - return getattr(self.ises[i], attr)(*args, **kwargs) - - @cached_property - def globalSize(self): - return sum(self.getSize(i) for i in range(len(self))) - - @cached_property - def localSize(self): - return sum(self.getLocalSize(i) for i in range(len(self))) - - @property - def sizes(self): - return (self.localSize, self.globalSize) - - @property - def comm(self): - return self._comm - - @property - def ises(self): - return self._ises - - def __getitem__(self, i): - return self.ises[i] - - def __len__(self): - return len(self.ises) - - def __iter__(self): - return iter(self.ises) - - def createVec(self, i=None, vec_type=PETSc.Vec.Type.MPI): - vec = PETSc.Vec().create(comm=self.comm) - vec.setType(vec_type) - vec.setSizes(self.sizes if i is None else self.getSizes(i)) - return vec - - def createVecs(self, vec_type=PETSc.Vec.Type.MPI): - return (self.createVec(i, vec_type=vec_type) for i in range(len(self))) - - def createVecNest(self, vecs=None): - if vecs is None: - vecs = self.createVecs() - else: - if not all(vecs[i].getSizes() == self.getSizes(i) for i in range(len(self))): - raise ValueError("vec sizes must match is sizes") - return PETSc.Vec().createNest(vecs, self.ises, self.comm) - - class RFAction(Enum): TLM = 'tlm' Adjoint = 'adjoint' @@ -95,121 +30,6 @@ def convert_types(overloaded, options): for o in overloaded]) -def saddle_ises(fdvrf): - ensemble = fdvrf.ensemble - rank = ensemble.ensemble_rank - global_comm = ensemble.global_comm - - Vsol 
= fdvrf.solution_space - Vobs = fdvrf.observation_space - Vs = Vsol.local_spaces[0] - Vo = Vobs.local_spaces[0] - - # ndofs per (dn, dl, dx) block - bsol = Vsol.dof_dset.layout_vec.getLocalSize() - bobs = Vobs.dof_dset.layout_vec.getLocalSize() - nlocal_blocks = Vsol.nlocal_spaces - - bsize_dn = bsol - bsize_dl = bobs - bsize_dx = bsol - bsize = bsize_dn + bsize_dl + bsize_dx - - # number of blocks on previous ranks - nprev_blocks = ensemble.ensemble_comm.exscan(nlocal_blocks) - if rank == 0: # exscan returns None - nprev_blocks = 0 - - # offset to start of global indices of each field in local block j - offset = bsize*nprev_blocks - offset_dn = lambda j: offset + j*bsize - offset_dl = lambda j: offset_dn(j) + bsize_dn - offset_dx = lambda j: offset_dl(j) + bsize_dl - - indices_dn = np.concatenate( - [offset_dn(j) + np.arange(bsize_dn, dtype=np.int32) - for j in range(nlocal_blocks)]) - - indices_dl = np.concatenate( - [offset_dl(j) + np.arange(bsize_dl, dtype=np.int32) - for j in range(nlocal_blocks)]) - - indices_dx = np.concatenate( - [offset_dx(j) + np.arange(bsize_dx, dtype=np.int32) - for j in range(nlocal_blocks)]) - - is_dn = PETSc.IS().createGeneral(indices_dn, comm=global_comm) - is_dl = PETSc.IS().createGeneral(indices_dl, comm=global_comm) - is_dx = PETSc.IS().createGeneral(indices_dx, comm=global_comm) - - return ISNest((is_dn, is_dl, is_dx)) - - -def FDVarSaddlePointKSP(fdvrf, solver_parameters, options_prefix=None): - saddlemat = FDVarSaddlePointMat(fdvrf) - - ksp = PETSc.KSP().create(comm=fdvrf.ensemble.global_comm) - ksp.setOperators(saddlemat, saddlemat) - - options = OptionsManager(solver_parameters, options_prefix) - options.set_from_options(ksp) - - return ksp, options - - -# Saddle-point MatNest -def FDVarSaddlePointMat(fdvrf): - ensemble = fdvrf.ensemble - - dn_is, dl_is, dx_is = saddle_ises(fdvrf) - - # L Mat - L = AllAtOnceRFMat(fdvrf, action=TLM) - Lt = AllAtOnceRFMat(fdvrf, action=Adjoint) - - Lrow = dn_is - Lcol = dx_is - - Ltrow = dx_is 
- Ltcol = dn_is - - # H Mat - H = ObservationEnsembleRFMat(fdvrf, action=TLM) - Ht = ObservationEnsembleRFMat(fdvrf, action=Adjoint) - - Hrow = dl_is - Hcol = dx_is - - Htrow = dx_is - Htcol = dl_is - - # D Mat - D = ModelCovarianceEnsembleRFMat(fdvrf) - - Drow = dn_is - Dcol = dn_is - - # R Mat - R = ObservationCovarianceEnsembleRFMat(fdvrf) - - Rrow = dl_is - Rcol = dl_is - - fdvmat = PETSc.Mat().createNest( - mats=[D, L, # noqa: E127,E202 - R, H, # noqa: E127,E202 - Lt, Ht ], # noqa: E127,E202 - isrows=[Drow, Lrow, # noqa: E127,E202 - Rrow, Hrow, # noqa: E127,E202 - Ltrow, Htrow ], # noqa: E127,E202 - iscols=[Dcol, Lcol, # noqa: E127,E202 - Rcol, Hcol, # noqa: E127,E202 - Ltcol, Htcol ], # noqa: E127,E202 - comm=ensemble.global_comm) - - return fdvmat - - def check_rf_action(action): def check_rf_action_decorator(func): @wraps(func) diff --git a/fdvar/saddle.py b/fdvar/saddle.py new file mode 100644 index 0000000..7a8d13a --- /dev/null +++ b/fdvar/saddle.py @@ -0,0 +1,201 @@ +import firedrake as fd +from firedrake.petsc import PETSc +from firedrake.petsc import OptionsManager +from firedrake.adjoint import ReducedFunctional +from firedrake.adjoint.fourdvar_reduced_functional import CovarianceNormReducedFunctional +from pyop2.mpi import MPI +from pyadjoint.optimization.tao_solver import PETScVecInterface +from pyadjoint.enlisting import Enlist +from typing import Optional +from enum import Enum +from functools import partial, cached_property, wraps + + +class ISNest: + attrs = ( + 'getSizes', + 'getSize', + 'getLocalSize', + 'getIndices', + 'getComm', + ) + def __init__(self, ises): + self._comm = ises[0].getComm() + self._ises = ises + for attr in self.attrs: + setattr(self, attr, + partial(self._getattr, attr)) + + def _getattr(self, attr, i, *args, **kwargs): + return getattr(self.ises[i], attr)(*args, **kwargs) + + @cached_property + def globalSize(self): + return sum(self.getSize(i) for i in range(len(self))) + + @cached_property + def localSize(self): + 
return sum(self.getLocalSize(i) for i in range(len(self))) + + @property + def sizes(self): + return (self.localSize, self.globalSize) + + @property + def comm(self): + return self._comm + + @property + def ises(self): + return self._ises + + def __getitem__(self, i): + return self.ises[i] + + def __len__(self): + return len(self.ises) + + def __iter__(self): + return iter(self.ises) + + def createVec(self, i=None, vec_type=PETSc.Vec.Type.MPI): + vec = PETSc.Vec().create(comm=self.comm) + vec.setType(vec_type) + vec.setSizes(self.sizes if i is None else self.getSizes(i)) + return vec + + def createVecs(self, vec_type=PETSc.Vec.Type.MPI): + return (self.createVec(i, vec_type=vec_type) for i in range(len(self))) + + def createVecNest(self, vecs=None): + if vecs is None: + vecs = self.createVecs() + else: + if not all(vecs[i].getSizes() == self.getSizes(i) for i in range(len(self))): + raise ValueError("vec sizes must match is sizes") + return PETSc.Vec().createNest(vecs, self.ises, self.comm) + + +def copy_controls(controls): + return controls.delist([c.control._ad_init_zero() for c in controls]) + + +def convert_types(overloaded, options): + overloaded = Enlist(overloaded) + return overloaded.delist([o._ad_convert_type(o, options=options) + for o in overloaded]) + + +def saddle_ises(fdvrf): + ensemble = fdvrf.ensemble + rank = ensemble.ensemble_rank + global_comm = ensemble.global_comm + + Vsol = fdvrf.solution_space + Vobs = fdvrf.observation_space + Vs = Vsol.local_spaces[0] + Vo = Vobs.local_spaces[0] + + # ndofs per (dn, dl, dx) block + bsol = Vsol.dof_dset.layout_vec.getLocalSize() + bobs = Vobs.dof_dset.layout_vec.getLocalSize() + nlocal_blocks = Vsol.nlocal_spaces + + bsize_dn = bsol + bsize_dl = bobs + bsize_dx = bsol + bsize = bsize_dn + bsize_dl + bsize_dx + + # number of blocks on previous ranks + nprev_blocks = ensemble.ensemble_comm.exscan(nlocal_blocks) + if rank == 0: # exscan returns None + nprev_blocks = 0 + + # offset to start of global indices of 
each field in local block j + offset = bsize*nprev_blocks + offset_dn = lambda j: offset + j*bsize + offset_dl = lambda j: offset_dn(j) + bsize_dn + offset_dx = lambda j: offset_dl(j) + bsize_dl + + indices_dn = np.concatenate( + [offset_dn(j) + np.arange(bsize_dn, dtype=np.int32) + for j in range(nlocal_blocks)]) + + indices_dl = np.concatenate( + [offset_dl(j) + np.arange(bsize_dl, dtype=np.int32) + for j in range(nlocal_blocks)]) + + indices_dx = np.concatenate( + [offset_dx(j) + np.arange(bsize_dx, dtype=np.int32) + for j in range(nlocal_blocks)]) + + is_dn = PETSc.IS().createGeneral(indices_dn, comm=global_comm) + is_dl = PETSc.IS().createGeneral(indices_dl, comm=global_comm) + is_dx = PETSc.IS().createGeneral(indices_dx, comm=global_comm) + + return ISNest((is_dn, is_dl, is_dx)) + + +def FDVarSaddlePointKSP(fdvrf, solver_parameters, options_prefix=None): + saddlemat = FDVarSaddlePointMat(fdvrf) + + ksp = PETSc.KSP().create(comm=fdvrf.ensemble.global_comm) + ksp.setOperators(saddlemat, saddlemat) + + options = OptionsManager(solver_parameters, options_prefix) + options.set_from_options(ksp) + + return ksp, options + + +# Saddle-point MatNest +def FDVarSaddlePointMat(fdvrf): + ensemble = fdvrf.ensemble + + dn_is, dl_is, dx_is = saddle_ises(fdvrf) + + # L Mat + L = AllAtOnceRFMat(fdvrf, action=TLM) + Lt = AllAtOnceRFMat(fdvrf, action=Adjoint) + + Lrow = dn_is + Lcol = dx_is + + Ltrow = dx_is + Ltcol = dn_is + + # H Mat + H = ObservationEnsembleRFMat(fdvrf, action=TLM) + Ht = ObservationEnsembleRFMat(fdvrf, action=Adjoint) + + Hrow = dl_is + Hcol = dx_is + + Htrow = dx_is + Htcol = dl_is + + # D Mat + D = ModelCovarianceEnsembleRFMat(fdvrf) + + Drow = dn_is + Dcol = dn_is + + # R Mat + R = ObservationCovarianceEnsembleRFMat(fdvrf) + + Rrow = dl_is + Rcol = dl_is + + fdvmat = PETSc.Mat().createNest( + mats=[D, L, # noqa: E127,E202 + R, H, # noqa: E127,E202 + Lt, Ht ], # noqa: E127,E202 + isrows=[Drow, Lrow, # noqa: E127,E202 + Rrow, Hrow, # noqa: E127,E202 + 
Ltrow, Htrow ], # noqa: E127,E202 + iscols=[Dcol, Lcol, # noqa: E127,E202 + Rcol, Hcol, # noqa: E127,E202 + Ltcol, Htcol ], # noqa: E127,E202 + comm=ensemble.global_comm) + + return fdvmat