diff --git a/advection/advection_sc4dvar_aaorf.py b/advection/advection_sc4dvar_aaorf.py index 1a008cc..9f869f0 100644 --- a/advection/advection_sc4dvar_aaorf.py +++ b/advection/advection_sc4dvar_aaorf.py @@ -15,9 +15,9 @@ Jhat = FourDVarReducedFunctional( Control(control), - background_iprod=norm2(B), - observation_iprod=norm2(R), - observation_err=observation_error(0), + background_covariance=B, + observation_covariance=R, + observation_error=observation_error(0), weak_constraint=False) nstep = 0 @@ -38,13 +38,13 @@ # take observation obs_err = observation_error(stage.observation_index) stage.set_observation(qn, obs_err, - observation_iprod=norm2(R)) + observation_covariance=R) pause_annotation() print(f"{taylor_test(Jhat, control, values[0]) = }") -options = {'disp': True, 'ftol': 1e-2} +options = {'disp': fd.COMM_WORLD.rank == 0, 'ftol': 1e-2} derivative_options = {'riesz_representation': None} opt = minimize(Jhat, options=options, method="L-BFGS-B", diff --git a/advection/advection_sc4dvar_pyadj.py b/advection/advection_sc4dvar_pyadj.py index 252ec6a..beccde9 100644 --- a/advection/advection_sc4dvar_pyadj.py +++ b/advection/advection_sc4dvar_pyadj.py @@ -11,10 +11,10 @@ continue_annotation() # background functional -J = norm2(B)(control - background) +J = covariance_norm(control - background, B) # initial observation functional -J += norm2(R)(observation_error(0)(control)) +J += covariance_norm(observation_error(0)(control), R) nstep = 0 qn.assign(control) @@ -28,7 +28,7 @@ nstep += 1 # observation functional - J += norm2(R)(observation_error(i)(qn)) + J += covariance_norm(observation_error(i)(qn), R) pause_annotation() @@ -36,7 +36,7 @@ print(f"{taylor_test(Jhat, control, values[0]) = }") -options = {'disp': True, 'ftol': 1e-2} +options = {'disp': fd.COMM_WORLD.rank == 0, 'ftol': 1e-2} derivative_options = {'riesz_representation': 'l2'} opt = minimize(Jhat, options=options, method="L-BFGS-B", diff --git a/advection/advection_utils.py b/advection/advection_utils.py index f6de905..3206f83 100644 --- a/advection/advection_utils.py +++ b/advection/advection_utils.py @@ -3,13 +3,9 @@ import numpy as np from sys import exit -np.set_printoptions(legacy='1.25', precision=6) - +from firedrake.adjoint.fourdvar_reduced_functional import covariance_norm -def norm2(w): - def n2(x): - return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx) - return n2 +np.set_printoptions(legacy='1.25', precision=6) def timestepper(mesh, V, dt, u): @@ -45,9 +41,9 @@ def analytic_solution(mesh, u, t, mag=1.0, phase=0.0): x, = fd.SpatialCoordinate(mesh) return mag*fd.sin(2*fd.pi*((x + phase) - u*t)) -B = 1 -R = 1 -Q = 1 +B = 10. +R = 0.1 +Q = 0.2*B ensemble = fd.Ensemble(fd.COMM_WORLD, 1) diff --git a/advection/advection_wc4dvar_aaorf.py b/advection/advection_wc4dvar_aaorf.py index 61406d4..db31442 100644 --- a/advection/advection_wc4dvar_aaorf.py +++ b/advection/advection_wc4dvar_aaorf.py @@ -5,58 +5,78 @@ from firedrake.adjoint import FourDVarReducedFunctional from advection_utils import * -control = fd.EnsembleFunction( - ensemble, [V for _ in range(len(targets))]) - -for x in control.subfunctions: - x.assign(background) - -continue_annotation() - -# create 4DVar reduced functional and record -# background and initial observation functionals - -Jhat = FourDVarReducedFunctional( - Control(control), - background_iprod=norm2(B), - observation_iprod=norm2(R), - observation_err=observation_error(0), - weak_constraint=True) - -nstep = 0 -# record observation stages -with Jhat.recording_stages() as stages: - # loop over stages - for stage, ctx in stages: - # start forward model - qn.assign(stage.control) - - # propogate - for _ in range(observation_freq): - qn1.assign(qn) - stepper.solve() - qn.assign(qn1) - nstep += 1 - - # take observation - obs_err = observation_error(stage.observation_index) - stage.set_observation(qn, obs_err, - observation_iprod=norm2(R), - forward_model_iprod=norm2(Q)) - -pause_annotation() - -# the perturbation values need to be held in the -# same type as the control i.e. and EnsembleFunction -vals = control.copy() -for v0, v1 in zip(vals.subfunctions, values): - v0.assign(v1) - -print(f"{Jhat(control) = }") -print(f"{taylor_test(Jhat, control, vals) = }") - -options = {'disp': True, 'ftol': 1e-2} -derivative_options = {'riesz_representation': 'l2'} - -opt = minimize(Jhat, options=options, method="L-BFGS-B", - derivative_options=derivative_options) + +def make_fdvrf(): + if ensemble.ensemble_size == 1: + nlocal_observations = len(targets) + elif ensemble.ensemble_size == 3: + nlocal_observations = 2 if ensemble.ensemble_rank == 0 else 1 + else: + raise ValueError("Must use either 1 or 3 ensemble ranks") + control_space = fd.EnsembleFunctionSpace( + [V for _ in range(nlocal_observations)], ensemble) + control = fd.EnsembleFunction(control_space) + + for x in control.subfunctions: + x.assign(background) + + continue_annotation() + + # create 4DVar reduced functional and record + # background and initial observation functionals + + Jhat = FourDVarReducedFunctional( + Control(control), + background_covariance=B, + observation_covariance=R, + observation_error=observation_error(0), + weak_constraint=True) + + nstep = 0 + # record observation stages + with Jhat.recording_stages(nstep=nstep) as stages: + # loop over stages + for stage, ctx in stages: + # start forward model + qn.assign(stage.control) + + # propogate + for _ in range(observation_freq): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + ctx.nstep += 1 + + # take observation + obs_err = observation_error(stage.observation_index) + stage.set_observation(qn, obs_err, + observation_covariance=R, + forward_model_covariance=Q) + pause_annotation() + + return Jhat, control + + +if __name__ == '__main__': + Jhat, control = make_fdvrf() + + # the perturbation values need to be held in the + # same type as the control i.e. and EnsembleFunction + vals = control.copy() + for v0, v1 in zip(vals.subfunctions, values): + v0.assign(v1) + + print(f"{Jhat(control) = }") + print(f"{taylor_test(Jhat, control, vals) = }") + + options = { + 'disp': True, + 'ftol': 1e-2 + } + derivative_options = {'riesz_representation': 'l2'} + + opt = minimize(Jhat, method="L-BFGS-B", options=options, + derivative_options=derivative_options) + J0 = Jhat(control) + Jopt = Jhat(opt) + print(f"{J0 = :.3e} | {Jopt = :.3e} | {Jopt/J0 = :.3e}") diff --git a/advection/advection_wc4dvar_aaorf_tao.py b/advection/advection_wc4dvar_aaorf_tao.py new file mode 100644 index 0000000..e9f0ecf --- /dev/null +++ b/advection/advection_wc4dvar_aaorf_tao.py @@ -0,0 +1,62 @@ +import firedrake as fd +from firedrake.adjoint import ( + continue_annotation, pause_annotation, minimize, + stop_annotating, Control, taylor_test) +from firedrake.adjoint import FourDVarReducedFunctional +from firedrake.petsc import PETSc +from advection_utils import * +from fdvar import TAOSolver +from sys import exit + +from advection_wc4dvar_aaorf import make_fdvrf +Jhat, control = make_fdvrf() + +Print = PETSc.Sys.Print + +# the perturbation values need to be held in the +# same type as the control i.e. and EnsembleFunction +x0 = control.copy() + +ksp_params = { + 'monitor_short': None, + # 'converged_rate': None, + # 'converged_reason': None, +} + +tao_params = { + 'tao_view': ':tao_view.log', + 'tao': { + 'monitor': None, + 'converged_reason': None, + 'gatol': 1e-1, + 'grtol': 1e-1, + 'gttol': 1e-1, + }, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp': ksp_params, + 'ksp_type': 'gmres', + 'pc_type': 'lmvm', + 'ksp_rtol': 1e-1, + # 'ksp_type': 'gmres', + # 'pc_type': 'python', + # 'pc_python_type': 'fdvar.AllAtOnceJacobiPC', + + }, + # 'tao_cg': { + # 'ksp': ksp_params, + # 'ksp_rtol': 1e-1, + # 'type': 'pr', # fr-pr-prp-hs-dy + # }, +} +tao = TAOSolver(Jhat, options_prefix="", + solver_parameters=tao_params) +tao.solve() + +xopt = Jhat.control.copy() +J0 = Jhat(x0) +Jopt = Jhat(xopt) + +Print(f"{J0 = :.3e} | {Jopt = :.3e} | {Jopt/J0 = :.3e}") + + diff --git a/advection/advection_wc4dvar_covariancemat.py b/advection/advection_wc4dvar_covariancemat.py new file mode 100644 index 0000000..9c303d4 --- /dev/null +++ b/advection/advection_wc4dvar_covariancemat.py @@ -0,0 +1,57 @@ +import firedrake as fd +from firedrake.petsc import PETSc, OptionsManager +from firedrake.adjoint import pyadjoint # noqa: F401 +from firedrake.matrix import ImplicitMatrix +from pyadjoint.optimization.tao_solver import PETScVecInterface +from advection_wc4dvar_aaorf import make_fdvrf +from mpi4py import MPI +import numpy as np +from typing import Optional, Collection +from fdvar.mat import * + + +def CovarianceMat(covariancerf): + space = covariancerf.controls[0].control.function_space() + comm = space.mesh().comm + covariance = covariancerf.covariance + if isinstance(covariance, Collection): + covariance, power = covariance + + sizes = space.dof_dset.layout_vec.sizes + shape = (sizes, sizes) + covmat = PETSc.Mat().createConstantDiagonal( + shape, covariance, comm=comm) + + covmat.setUp() + covmat.assemble() + return covmat + + +Jhat, control = make_fdvrf() +ensemble = Jhat.ensemble + +# >>>>> Covariance + +# Covariance Mat +# covrf = Jhat.background_norm +covrf = Jhat.stages[0].model_norm +covmat = CovarianceMat(covrf) + +# Covariance KSP +covksp = PETSc.KSP().create(comm=ensemble.comm) +covksp.setOptionsPrefix('cov_') +covksp.setOperators(covmat) + +covksp.pc.setType(PETSc.PC.Type.JACOBI) +covksp.setType(PETSc.KSP.Type.PREONLY) +covksp.setFromOptions() +covksp.setUp() +print(PETSc.Options().getAll()) + +x = covmat.createVecRight() +b = covmat.createVecLeft() + +b.array_w[:] = np.random.random_sample(b.array_w.shape) +print(f'{b.norm() = }') +covksp.solve(b, x) +print(f'{x.norm() = }') diff --git a/advection/advection_wc4dvar_pyadj.py b/advection/advection_wc4dvar_pyadj.py index c4d9bf4..a4825e8 100644 --- a/advection/advection_wc4dvar_pyadj.py +++ b/advection/advection_wc4dvar_pyadj.py @@ -12,10 +12,10 @@ continue_annotation() # background functional -J = norm2(B)(control[0] - background) +J = covariance_norm(control[0] - background, B) # initial observation functional -J += norm2(R)(observation_error(0)(control[0])) +J += covariance_norm(observation_error(0)(control[0]), R) nstep = 0 for i in range(1, len(control)): @@ -33,10 +33,10 @@ control[i].assign(qn) # model error functional - J += norm2(Q)(qn - control[i]) + J += covariance_norm(qn - control[i], Q) # observation functional - J += norm2(R)(observation_error(i)(control[i])) + J += covariance_norm(observation_error(i)(control[i]), R) pause_annotation() diff --git a/advection/advection_wc4dvar_saddlepc.py b/advection/advection_wc4dvar_saddlepc.py new file mode 100644 index 0000000..60dbe66 --- /dev/null +++ b/advection/advection_wc4dvar_saddlepc.py @@ -0,0 +1,89 @@ +import firedrake as fd +from advection_wc4dvar_aaorf import make_fdvrf +from fdvar.mat import * + +Jhat, control = make_fdvrf() +ensemble = Jhat.ensemble + +saddlepoint_params = { + 'ksp_monitor': None, + 'ksp_converged_rate': None, + 'ksp_rtol': 1e-5, + 'ksp_type': 'fgmres', + 'pc_type': 'fieldsplit', + 'pc_fieldsplit_type': 'schur', + 'pc_fieldsplit_0_fields': '0,1', # schur complement on dx + 'pc_fieldsplit_1_fields': '2', + + # diagonal pc + 'pc_fieldsplit_schur_fact_type': 'diag', + + # triangular pc + 'pc_fieldsplit_schur_fact_type': 'upper', + + # inexact constraint pc + 'pc_fieldsplit_0_fields': '2', + 'pc_fieldsplit_1_fields': '1', + 'pc_fieldsplit_2_fields': '0', + 'pc_fieldsplit_type': 'multiplicative', + + 'fieldsplit_0': { # D + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'EnsembleBJacobiPC', + 'sub': { + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'gamg', + 'mg_levels': { + 'ksp_type': 'chebyshev', + 'ksp_max_it': 2, + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu', + } + } + } + 'fieldsplit_1': { # R + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'EnsembleBJacobiPC', + 'sub': { + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'icc', + } + } + 'fieldsplit_2': { # S + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'SaddleSchurPC', + 'pc_saddle_schur_observation_type': 'none', # or low-rank? + 'model_sub': { # L = I + 'ksp_type': 'preonly', + 'pc_type': 'none', + } + 'model_sub': { # Low bandwidth approx - its=N is direct solve + 'ksp_max_it': 1, + 'ksp_type': 'richardson', + 'pc_type': 'none', + } + 'model_sub': { # L_M pc - bjacobi with bsize = k + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'AllAtOnceBJacobiPC', + 'pc_aaobjacobi_bsize': 4, + 'sub': { + 'ksp_max_it': 4, + 'ksp_type': 'richardson', + 'pc_type': 'none', + } + } + } +} + +ksp, options = FDVarSaddlePointKSP(fdvrf, saddlepoint_params) +b = FDVarSaddlePointRHS(fdvrf) # TODO +x = b.duplicate() + +with options.inserted_options(): + ksp.solve(b, x) diff --git a/burgers/aaorf_4dvar.py b/burgers/aaorf_4dvar.py index 408c692..007840a 100644 --- a/burgers/aaorf_4dvar.py +++ b/burgers/aaorf_4dvar.py @@ -185,36 +185,19 @@ def H(x, name=None): global_comm.Barrier() -# weighted l2 inner product -def wl2prod(x, w=1.0, ad_block_tag=None): - return fd.assemble(fd.inner(x, w*x)*fd.dx, ad_block_tag=ad_block_tag)**2 - - def observation_err(i, state, name=None): return fd.Function(Vobs, name=f'Observation error H{i}(x{i}) - y{i}').assign(H(state, name) - y[i], ad_block_tag=f"Observation error calculation {i}") -background_iprod = partial(wl2prod, w=B, ad_block_tag='Background inner product') -observation_iprod = partial(wl2prod, w=R, ad_block_tag='Observation inner product') -model_iprod = partial(wl2prod, w=Q, ad_block_tag='Model inner product') - # Initialise forward model from prior/background initial conditions # and accumulate weak constraint functional as we go uapprox = [background.copy(deepcopy=True, annotate=False)] # Only initial rank needs data for initial conditions or time -if trank == 0: - background_iprod0 = background_iprod - if initial_observations: - observation_iprod0 = observation_iprod - observation_err0 = partial(observation_err, 0, name='Model observation 0') - else: - observation_iprod0 = None - observation_err0 = None +if trank == 0 and initial_observations: + observation_err0 = partial(observation_err, 0, name='Model observation 0') else: - background_iprod0 = None - observation_iprod0 = None observation_err0 = None ################################################## @@ -226,18 +209,20 @@ def observation_err(i, state, name=None): # first rank has one extra control for the initial conditions nlocal_controls = nlocal_observations + (1 if trank == 0 else 0) -aaofunc = fd.EnsembleFunction(ensemble, [V for _ in range(nlocal_controls)]) +control_space = fd.EnsembleFunctionSpace( + [V for _ in range(nlocal_controls)], ensemble) +control = fd.EnsembleFunction(control_space) if trank == 0: - aaofunc.subfunctions[0].assign(background) + control.subfunctions[0].assign(background) continue_annotation() Jhat = FourDVarReducedFunctional( - Control(aaofunc), - background_iprod=background_iprod0, - observation_iprod=observation_iprod0, - observation_err=observation_err0, + Control(control), + background_covariance=B, + observation_covariance=R, + observation_error=observation_err0, weak_constraint=(args.constraint == 'weak')) Jhat.background.topological.rename("Background") @@ -282,13 +267,10 @@ def observation_err(i, state, name=None): observation_err, local_obs_idx, name=f'Model observation {stage.observation_index}') - model_iprod = partial(wl2prod, w=Q, - ad_block_tag=f'Model inner product {stage.observation_index}') - # record the observation at the end of the stage stage.set_observation(un, obs_error, - observation_iprod=observation_iprod, - forward_model_iprod=model_iprod) + observation_covariance=R, + forward_model_covariance=Q) # PETSc.Sys.Print(f"{fd.norm(un) = }") global_comm.Barrier() diff --git a/burgers/burgers_simple_sc4dvar.py b/burgers/burgers_simple_sc4dvar.py new file mode 100644 index 0000000..82294c4 --- /dev/null +++ b/burgers/burgers_simple_sc4dvar.py @@ -0,0 +1,149 @@ +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import interpolate +from fdvar import generate_observation_data +np.random.seed(42) + +# number of observation windows, and steps per window +nw, nt, dt, nu = 32, 6, 1e-4, 0.05 + +# Covariance of background, observation, and model noise +sigma_b = 0.1 +sigma_r = 0.03 +sigma_q = 0.0002 + +# 1D periodic mesh +mesh = PeriodicUnitIntervalMesh(100) +x, = SpatialCoordinate(mesh) + +# Burger's equation with implicit midpoint integration +V = VectorFunctionSpace(mesh, "CG", 2) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) +uh = (un + un1)/2 + +# finite element forms +F = (inner(un1 - un, v)*dx + + dt*inner(dot(uh, nabla_grad(uh)), v)*dx + + dt*inner(nu*grad(uh), grad(v))*dx) + +# timestepper solver +stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1)) + +def solve_step(): + un1.assign(un) + stepper.solve() + un.assign(un1) + +# "ground truth" reference solution +reference_ic = Function(V).project( + as_vector([1 + 0.5*sin(2*pi*x)])) + +# observations are point evaluations at random locations +observation_locations = [ + [x] for x in np.random.random_sample(20)] +vom = VertexOnlyMesh(mesh, observation_locations) +Y = VectorFunctionSpace(vom, "DG", 0) + +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) + +# generate "ground-truth" observational data +y, background = generate_observation_data( + None, reference_ic, solve_step, un, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +# create function evaluating observation error at window i +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create distributed control variable for entire timeseries +control = Function(V).assign(background) + +# tell pyadjoint to start taping operations +continue_annotation() + +# This object will construct and solve the 4DVar system +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=sigma_b, + observation_covariance=sigma_r, + observation_error=observation_error(0), + weak_constraint=False) + +# loop over each observation stage on the local communicator +with Jhat.recording_stages(nstages=nw) as stages: + for stage, ctx in stages: + idx = stage.local_index + un.assign(stage.control) + un1.assign(un) + + # let pyadjoint tape the time integration + for i in range(nt): + stepper.solve() + un.assign(un1) + + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=sigma_r) + +# tell pyadjoint to finish taping operations +pause_annotation() + + +class CovariancePC(PCBase): + def initialize(self, pc): + w = Constant(sigma_b) + u = Function(V) + b = Cofunction(V.dual()) + a = (1/w)*inner(TrialFunction(V), TestFunction(V))*dx + solver = LinearVariationalSolver( + LinearVariationalProblem(a, b, u), + solver_parameters={'ksp_type': 'preonly', + 'pc_type': 'lu'}) + self.u, self.b, self.solver = u, b, solver + + def apply(self, pc, x, y): + with self.b.dat.vec_wo as vb: + x.copy(vb) + self.solver.solve() + with self.u.dat.vec_ro as vu: + vu.copy(y) + + # x.copy(y) + # y.scale(sigma_b) + + def update(self, pc, x): + pass + + def applyTranspose(self, pc, x, y): + raise NotImplementedError + + +# Solution strategy is controlled via this options dictionary +tao_parameters = { + 'tao_view': ':tao_view.log', + 'tao_monitor': None, + 'tao_converged_reason': None, + 'tao_gttol': 1e-2, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_max_it': 10, + 'ksp_converged_maxits': None, + 'ksp_rtol': 1e-1, + 'ksp_type': 'gmres', + 'pc_type': 'none', + }, +} +tao = TAOSolver(MinimizationProblem(Jhat), + parameters=tao_parameters) +xopt = tao.solve() +PETSc.Sys.Print(f"{errornorm(reference_ic, background) = :.3e}") +PETSc.Sys.Print(f"{errornorm(reference_ic, xopt) = :.3e}") diff --git a/burgers/burgers_simple_sc4dvar_dirichlet.py b/burgers/burgers_simple_sc4dvar_dirichlet.py new file mode 100644 index 0000000..7c6441e --- /dev/null +++ b/burgers/burgers_simple_sc4dvar_dirichlet.py @@ -0,0 +1,206 @@ +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import interpolate +from fdvar import generate_observation_data +from sys import exit +np.random.seed(42) +Print = PETSc.Sys.Print + +# number of observation windows, and steps per window +nx, nw, nt, dt, nu = 100, 50, 6, 1e-4, 0.25 + +# T = 0.03 +# nsw = 50 +# qscale = T/nsw +qscale = nt*dt + +# Covariance of background, observation, and model noise +sigma_b = 1e-2 +sigma_r = 1e-3 +sigma_q = (1e-4)*qscale + +# lengthscales of the background and model covariances +L_b = 0.25 +L_q = 0.05 + +cfl_b = (L_b*L_b/2)*(nx*nx) +Print(f"{cfl_b = }") + +# 1D periodic mesh +mesh = UnitIntervalMesh(nx) +x, = SpatialCoordinate(mesh) + +# Burger's equation with implicit midpoint integration +degree = 2 +V = VectorFunctionSpace(mesh, "CG", degree) +V0 = FunctionSpace(mesh, "CG", degree) +Real = FunctionSpace(mesh, "R", 0) + +# vary between (Bmin =< B/sigma_b =< 1) +# Bmin = 0.1 +# Bprofile = cos(2*pi*x)*sin(4*pi*(0.5-x)) +# Bexpr = ((1 + Bmin) + (1 - Bmin)*Bprofile)/2 +# B = Function(V0).project(sigma_b*Bexpr) + +t = Function(Real).assign(0.) +fdt = Function(Real).assign(dt) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) + +zero = Constant(0) +bcs = [DirichletBC(V, as_vector([zero]), "on_boundary")] + +# B = (sigma_b, L_b, bcs) +# Q = (sigma_q, L_q, bcs) +# R = sigma_r + +lu_params = {'ksp_type': 'preonly', 'pc_type': 'lu'} + +B = CovarianceOperator( + V, form_type="diffusion", + sigma=sigma_b, L=L_b, bcs=bcs, + solver_parameters=lu_params) + +# B = CovarianceOperator( +# V, form_type="mass", sigma=sigma_b, +# bcs=bcs, solver_parameters=lu_params) + +params = { + 'snes_rtol': 1e-10, + **lu_params, +} + +# forcing +k = Constant(0.1) + +x1 = 1 - x +t1 = t + 1 +g = ( + pi*k*( + (x + k*t1*sin(pi*x1*t1)) + *cos(pi*x*t1)*sin(pi*x1*t1) + + + (x1 - k*t1*sin(pi*x*t1)) + *sin(pi*x*t1)*cos(pi*x1*t1) + ) + + + (2*nu*(k*pi*t1)**2) + *(sin(pi*x*t1)*sin(pi*x1*t1) + + cos(pi*x*t1)*cos(pi*x1*t1)) +) +uh = (un + un1)/2 + +# finite element forms +F = ( + (inner(un1 - un, v)/fdt)*dx + + inner(dot(uh, nabla_grad(uh)), v)*dx + + inner(nu*grad(uh), grad(v))*dx + - inner(as_vector([g]), v)*dx +) + +# timestepper solver +stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1, bcs=bcs), + solver_parameters=params) + +# "ground truth" reference solution +reference_ic = Function(V).project( + as_vector([k*sin(2*pi*x)])) +for bc in bcs: + bc.apply(reference_ic) + +# observations are point evaluations at random locations +observation_locations = [ + [x] for x in np.random.random_sample(20)] +vom = VertexOnlyMesh(mesh, observation_locations) +Y = VectorFunctionSpace(vom, "DG", 0) +Y0 = FunctionSpace(vom, "DG", 0) + +# vary between (sigma_r =< R =< 1) +Rprofile = sin(6*pi*(x + 0.3)) +Rexpr = ((1 + sigma_r) + (1 - sigma_r)*Rprofile)/2 +Rfunc = Function(Y0).interpolate(Rexpr) +R = CovarianceOperator( + Y, form_type="mass", sigma=Rfunc, + solver_parameters=lu_params) + +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) + +# generate "ground-truth" observational data +y, background = generate_observation_data( + None, reference_ic, stepper, un, un1, bcs, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +# create function evaluating observation error at window i +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create distributed control variable for entire timeseries +control = Function(V).assign(background) + +# tell pyadjoint to start taping operations +continue_annotation() + +# This object will construct and solve the 4DVar system +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=B, + observation_covariance=R, + observation_error=observation_error(0), + weak_constraint=False) + +# loop over each observation stage on the local communicator +t.assign(0.) +with Jhat.recording_stages(nstages=nw, t=t) as stages: + for stage, ctx in stages: + idx = stage.local_index + un.assign(stage.control) + t.assign(ctx.t) + + # let pyadjoint tape the time integration + for i in range(nt): + un1.assign(un) + stepper.solve() + un.assign(un1) + t += dt + ctx.t.assign(t) + + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=R) + + +# tell pyadjoint to finish taping operations +pause_annotation() + + +# Solution strategy is controlled via this options dictionary +tao_parameters = { + 'tao_view': ':tao_view.log', + 'tao_monitor': None, + 'tao_converged_reason': None, + 'tao_gttol': 1e-1, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_converged_rate': None, + 'ksp_converged_maxits': None, + 'ksp_max_it': 15, + 'ksp_rtol': 1e-1, + 'ksp_type': 'cg', + 'pc_type': 'python', + 'pc_python_type': f'fdvar.SC4DVarBackgroundPC', + }, +} +tao = TAOSolver(MinimizationProblem(Jhat), + parameters=tao_parameters, + options_prefix="sc") +xopt = tao.solve() +PETSc.Sys.Print(f"{errornorm(reference_ic, background) = :.3e}") +PETSc.Sys.Print(f"{errornorm(reference_ic, xopt) = :.3e}") diff --git a/burgers/burgers_simple_sc4dvar_dirichlet_irk.py b/burgers/burgers_simple_sc4dvar_dirichlet_irk.py new file mode 100644 index 0000000..784f909 --- /dev/null +++ b/burgers/burgers_simple_sc4dvar_dirichlet_irk.py @@ -0,0 +1,239 @@ +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import interpolate +from fdvar import generate_observation_data +from numpy import mean as nmean +from numpy import max as nmax +from sys import exit +from irksome import TimeStepper, GaussLegendre, Dt +np.random.seed(42) + +# number of observation windows, and steps per window +nx, nw, nt, dt, nu = 20, 5, 3, 1e-4, 0.25 + +T = 0.03 +nsw = 50 + +# Covariance of background, observation, and model noise +sigma_b = 1e-2 +sigma_r = 1e-3 +sigma_q = (1e-4)*T/nsw + +# 1D periodic mesh +mesh = UnitIntervalMesh(nx) +x, = SpatialCoordinate(mesh) + +# Burger's equation with implicit midpoint integration +V = VectorFunctionSpace(mesh, "CG", 1) +R = FunctionSpace(mesh, "R", 0) +t = Function(R).assign(0) +fdt = Function(R).assign(dt) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) + +zero = Constant(0) +bcs = [DirichletBC(V, as_vector([zero]), "on_boundary")] + +params = { + 'ksp_type': 'gmres', + 'pc_type': 'ilu', +} + +# forcing +k = Constant(0.1) + +x1 = 1 - x +t1 = t + 1 +g = ( + pi*k*( + (x + k*t1*sin(pi*x1*t1)) + *cos(pi*x*t1)*sin(pi*x1*t1) + + + (x1 - k*t1*sin(pi*x*t1)) + *sin(pi*x*t1)*cos(pi*x1*t1) + ) + + + (2*nu*(k*pi*t1)**2) + *(sin(pi*x*t1)*sin(pi*x1*t1) + + cos(pi*x*t1)*cos(pi*x1*t1)) +) +uh = (un + un1)/2 + +irk = True + +# finite element forms +F = ( + (inner(un1 - un, v)/fdt)*dx + # + inner(dot(as_vector([uh]), nabla_grad(uh)), v)*dx + # + inner(nu*grad(uh), grad(v))*dx + # - inner(g, v)*dx + + inner(dot(uh, nabla_grad(uh)), v)*dx + + inner(nu*grad(uh), grad(v))*dx + - inner(as_vector([g]), v)*dx +) + +Firk = ( + inner(Dt(un), v)*dx + # + inner(dot(as_vector([un]), nabla_grad(un)), v)*dx + # + inner(nu*grad(un), grad(v))*dx + # - inner(g, v)*dx + + inner(dot(un, nabla_grad(un)), v)*dx + + inner(nu*grad(un), grad(v))*dx + - inner(as_vector([g]), v)*dx +) +irk_stepper = TimeStepper( + Firk, GaussLegendre(1), t, fdt, un, + bcs=bcs, solver_parameters=params) + +# timestepper solver +nvs_stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1, bcs=bcs), + solver_parameters=params) +stepper = nvs_stepper + +def solve_step(): + if irk: + irk_stepper.advance() + else: + un1.assign(un) + nvs_stepper.solve() + un.assign(un1) + +# "ground truth" reference solution +reference_ic = Function(V).project( + as_vector([k*sin(2*pi*x)])) + # k*sin(2*pi*x)) + +# observations are point evaluations at random locations +observation_locations = [ + [x] for x in np.random.random_sample(20)] +vom = VertexOnlyMesh(mesh, observation_locations) +Y = VectorFunctionSpace(vom, "DG", 0) + +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) + +# generate "ground-truth" observational data +y, background = generate_observation_data( + None, reference_ic, solve_step, un, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +# create function evaluating observation error at window i +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create distributed control variable for entire timeseries +control = Function(V).assign(background) + +# tell pyadjoint to start taping operations +continue_annotation() + +# This object will construct and solve the 4DVar system +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=sigma_b, + observation_covariance=sigma_r, + observation_error=observation_error(0), + weak_constraint=False) + +# loop over each observation stage on the local communicator +trajectory = [background.copy(deepcopy=True, annotate=False)] +t.assign(0.) +with Jhat.recording_stages(nstages=nw, t=t) as stages: + for stage, ctx in stages: + idx = stage.local_index + un.assign(stage.control) + + t.assign(ctx.t) + + # let pyadjoint tape the time integration + for i in range(nt): + solve_step() + t += dt + trajectory.append(un.copy(deepcopy=True, annotate=False)) + + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=sigma_r) + + +# tell pyadjoint to finish taping operations +pause_annotation() + +# print(f"{t.dat.data[0] = :.2e}") +# sc4dvar = Jhat.strong_reduced_functional +# print(f"{Jhat._total_functional = }") +# print(f"{sc4dvar.functional = :.2e}") +# print(f"{sc4dvar(background) = :.2e}") +# print(f"{sc4dvar(reference_ic) = :.2e}") +# print(f"{sc4dvar(background) = :.2e}") +# print(f"{sc4dvar(reference_ic) = :.2e}") + +# vtk = VTKFile('output/burgers.pvd') +# t = 0 +# for j, u in enumerate(trajectory): +# # ud = u.dat.data +# # print(f"{j = :>3d} | {norm(u) = :.2e} | {nmean(ud) = :.2e} | {nmax(-ud) = :.2e} | {nmax(ud) = :.2e}") +# # vtk.write(u, time=t) +# # t += dt +# # pass + + +class CovariancePC(PCBase): + def initialize(self, pc): + w = Constant(sigma_b) + u = Function(V) + b = Cofunction(V.dual()) + a = (1/w)*inner(TrialFunction(V), TestFunction(V))*dx + solver = LinearVariationalSolver( + LinearVariationalProblem(a, b, u), + solver_parameters={'ksp_type': 'preonly', + 'pc_type': 'lu'}) + self.u, self.b, self.solver = u, b, solver + + def apply(self, pc, x, y): + with self.b.dat.vec_wo as vb: + x.copy(vb) + self.solver.solve() + with self.u.dat.vec_ro as vu: + vu.copy(y) + + # x.copy(y) + # y.scale(sigma_b) + + def update(self, pc, x): + pass + + def applyTranspose(self, pc, x, y): + raise NotImplementedError + + +# Solution strategy is controlled via this options dictionary +tao_parameters = { + 'tao_view': ':tao_view.log', + 'tao_monitor': None, + 'tao_converged_reason': None, + 'tao_gttol': 2e-1, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_converged_rate': None, + 'ksp_converged_maxits': None, + 'ksp_max_it': 10, + 'ksp_rtol': 1e-1, + 'ksp_type': 'gmres', + 'pc_type': 'none', + # 'pc_type': 'python', + # 'pc_python_type': f'{__name__}.CovariancePC', + }, +} +tao = TAOSolver(MinimizationProblem(Jhat), + parameters=tao_parameters) +xopt = tao.solve() +PETSc.Sys.Print(f"{errornorm(reference_ic, background) = :.3e}") +PETSc.Sys.Print(f"{errornorm(reference_ic, xopt) = :.3e}") diff --git a/burgers/burgers_simple_wc4dvar.py b/burgers/burgers_simple_wc4dvar.py new file mode 100644 index 0000000..ed996a0 --- /dev/null +++ b/burgers/burgers_simple_wc4dvar.py @@ -0,0 +1,116 @@ +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import interpolate +from fdvar import TAOSolver, generate_observation_data +np.random.seed(42) + +# number of observation windows, and steps per window +nw, nt, dt, nu = 32, 6, 1e-4, 0.05 + +# Covariance of background, observation, and model noise +sigma_b = 0.1 +sigma_r = 0.03 +sigma_q = 0.0002 + +# time-parallelism using firedrake's Ensemble +nspatial_ranks = 1 +ensemble = Ensemble(COMM_WORLD, nspatial_ranks) +ensemble_size = ensemble.ensemble_size + +# 1D periodic mesh +mesh = PeriodicUnitIntervalMesh( + 100, comm=ensemble.comm) +x, = SpatialCoordinate(mesh) + +# Burger's equation with implicit midpoint integration +V = VectorFunctionSpace(mesh, "CG", 2) + +un, un1 = Function(V), Function(V) +v = TestFunction(V) +uh = (un + un1)/2 + +# finite element forms +F = (inner(un1 - un, v)/dt + + inner(dot(uh, grad(uh)), v) + + inner(nu*grad(uh), grad(v)))*dx + +# timestepper solver +stepper = NonlinearVariationalSolver( + NonlinearVariationalProblem(F, un1)) + +# "ground truth" reference solution +reference_ic = Function(V).project( + as_vector([1 + 0.5*sin(2*pi*x)])) + +# observations are point evaluations at random locations +vom = VertexOnlyMesh(mesh, np.random.rand(20, 1)) +Y = VectorFunctionSpace(vom, "DG", 0) + +def H(x): # operator to take observations + return assemble(interpolate(x, Y)) + +# generate "ground-truth" observational data +y, background = generate_observation_data( + ensemble, reference_ic, stepper, un, un1, + H, nw, nt, sigma_b, sigma_r, sigma_q) + +# create function evaluating observation error at window i +def observation_error(i): + return lambda x: Function(Y).assign(H(x) - y[i]) + +# create distributed control variable for entire timeseries +V_ensemble = EnsembleFunctionSpace( + [V for _ in range(nw//ensemble_size)], ensemble) +control = EnsembleFunction(V_ensemble) + +# tell pyadjoint to start taping operations +continue_annotation() + +# This object will construct and solve the 4DVar system +Jhat = FourDVarReducedFunctional( + Control(control), + background=background, + background_covariance=sigma_b, + observation_covariance=sigma_r, + observation_error=observation_error(0), + weak_constraint=True) + +# loop over each observation stage on the local communicator +with Jhat.recording_stages() as stages: + for stage, ctx in stages: + idx = stage.local_index + un.assign(stage.control) + un1.assign(un) + + # let pyadjoint tape the time integration + for i in range(nt): + stepper.solve() + un.assign(un1) + + # tell pyadjoint a) we have finished this stage + # and b) how to evaluate this observation error + stage.set_observation( + state=un, + observation_error=observation_error(idx), + observation_covariance=sigma_r, + forward_model_covariance=sigma_q) + +# tell pyadjoint to finish taping operations +pause_annotation() + +# Solution strategy is controlled via this options dictionary +tao_parameters = { + 'tao_monitor': None, + 'tao_converged_reason': None, + 'tao_gttol': 1e-2, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp_monitor_short': None, + 'ksp_converged_maxits': None, + 'ksp_max_it': 8, + 'ksp_rtol': 1e-1, + 'ksp_type': 'gmres'} +} +tao = TAOSolver(Jhat, options_prefix="fdv", + solver_parameters=tao_parameters) +tao.solve() diff --git a/burgers/burgers_wc4dvar_demo.py b/burgers/burgers_wc4dvar_demo.py index 574653f..5ee5731 100644 --- a/burgers/burgers_wc4dvar_demo.py +++ b/burgers/burgers_wc4dvar_demo.py @@ -2,15 +2,15 @@ import firedrake as fd from firedrake.petsc import PETSc from firedrake.__future__ import interpolate -from firedrake.adjoint import (continue_annotation, pause_annotation, - get_working_tape, Control, minimize) +from firedrake.adjoint import (continue_annotation, pause_annotation, Control, minimize) +from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy from firedrake.adjoint import FourDVarReducedFunctional from burgers_utils import noisy_double_sin, burgers_stepper import numpy as np from functools import partial import argparse -from sys import exit from math import ceil +from fdvar import TAOSolver np.set_printoptions(legacy='1.25') @@ -167,19 +167,10 @@ def H(x, name=None): global_comm.Barrier() -# weighted l2 inner product -def wl2prod(x, w=1.0, ad_block_tag=None): - return fd.assemble(fd.inner(x, w*x)*fd.dx, ad_block_tag=ad_block_tag) - - def observation_err(i, state, name=None): return fd.Function(Vobs, name=f'Observation error H{i}(x{i}) - y{i}').assign(H(state, name) - y[i], ad_block_tag=f"Observation error calculation {i}") -background_iprod = partial(wl2prod, w=B, ad_block_tag='Background inner product') -observation_iprod = partial(wl2prod, w=R, ad_block_tag='Observation inner product') -model_iprod = partial(wl2prod, w=Q, ad_block_tag='Model inner product') - # Initialise forward model from prior/background initial conditions # and accumulate weak constraint functional as we go @@ -189,12 +180,9 @@ def observation_err(i, state, name=None): ### Create the 4dvar reduced functional ################################################## -background_iprod0 = background_iprod if initial_observations: - observation_iprod0 = observation_iprod observation_err0 = partial(observation_err, 0, name='Model observation 0') else: - observation_iprod0 = None observation_err0 = None ## Make sure this is the only point we are requiring user to know partition specifics @@ -202,18 +190,20 @@ def observation_err(i, state, name=None): # first rank has one extra control for the initial conditions nlocal_controls = nlocal_observations + (1 if trank == 0 else 0) -aaofunc = fd.EnsembleFunction(ensemble, [V for _ in range(nlocal_controls)]) +control_space = fd.EnsembleFunctionSpace( + [V for _ in range(nlocal_controls)], ensemble) +control = fd.EnsembleFunction(control_space) if trank == 0: - aaofunc.subfunctions[0].assign(background) + control.subfunctions[0].assign(background) continue_annotation() Jhat = FourDVarReducedFunctional( - Control(aaofunc), - background_iprod=background_iprod0, - observation_iprod=observation_iprod0, - observation_err=observation_err0, + Control(control), + background_covariance=B, + observation_covariance=R, + observation_error=observation_err0, weak_constraint=(args.constraint == 'weak')) Jhat.background.topological.rename("Background") @@ -222,8 +212,7 @@ def observation_err(i, state, name=None): PETSc.Sys.Print("Running forward model") global_comm.Barrier() -observation_idx = 1 if initial_observations else 0 -obs_offset = observation_idx +obs_offset = 1 if initial_observations else 0 ################################################## ### Record the forward model and observations @@ -256,13 +245,10 @@ def observation_err(i, state, name=None): observation_err, local_obs_idx, name=f'Model observation {stage.observation_index}') - model_iprod = partial(wl2prod, w=Q, - ad_block_tag=f'Model inner product {stage.observation_index}') - # record the observation at the end of the stage stage.set_observation(un, obs_error, - observation_iprod=observation_iprod, - forward_model_iprod=model_iprod) + observation_covariance=R, + forward_model_covariance=Q) global_comm.Barrier() @@ -279,22 +265,50 @@ def observation_err(i, state, name=None): ucontrol = Jhat.control.copy() -# minimiser should be given the derivative not the gradient -derivative_options = {'riesz_representation': 'l2'} - -options = { - 'disp': trank == 0, - 'maxcor': args.maxcor, - 'ftol': args.ftol, - 'gtol': args.gtol +# # scipy minimise +# uoptimised = minimize( +# ReducedFunctionalNumPy(Jhat), +# method="L-BFGS-B", +# options = { +# 'disp': trank == 0, +# 'maxcor': args.maxcor, +# 'ftol': args.ftol, +# 'gtol': args.gtol +# }) + +# # tao minimise +tao_params = { + 'tao_view': ':tao_view.log', + 'tao': { + 'monitor': None, + 'converged_reason': None, + 'ls_monitor': None, + 'gatol': 1e-3, + 'grtol': 1e-3, + 'gttol': 1e-3, + }, + 'tao_type': 'nls', + 'tao_nls': { + 'ksp': { + 'monitor_short': None, + 'converged_rate': None, + 'converged_maxits': None, + 'max_it': 10, + 'rtol': 1e-1, + }, + 'ksp_type': 'gmres', + 'ksp_pc_side': 'right', + 'pc_type': 'none', + }, + 'tao_cg_type': 'fr', # fr-pr-prp-hs-dy } - -uoptimised = minimize(Jhat, options=options, method="L-BFGS-B", - derivative_options=derivative_options) - -uopts = uoptimised.subfunctions +tao = TAOSolver(Jhat, options_prefix="", + solver_parameters=tao_params) +tao.solve() +uoptimised = Jhat.control.copy() global_comm.Barrier() +uopts = uoptimised.subfunctions PETSc.Sys.Print(f"Initial functional: {Jhat(ucontrol)}") PETSc.Sys.Print(f"Final functional: {Jhat(uoptimised)}") global_comm.Barrier() diff --git a/correlations/correlation_solves.py b/correlations/correlation_solves.py new file mode 100644 index 0000000..d2a3fef --- /dev/null +++ b/correlations/correlation_solves.py @@ -0,0 +1,100 @@ +import firedrake as fd +import numpy as np +import scipy.sparse.linalg as spla +from functools import partial +from correlations import ( + chordal_separation, soar_csr, csr_to_petsc, petsc_to_csr) + +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + +parser = ArgumentParser( + description='Construct and solve an SOAR correlation matrix on a periodic interval.', # noqa: E501 + formatter_class=ArgumentDefaultsHelpFormatter +) +parser.add_argument('--n', type=int, default=512, help='Number of mesh nodes.') # noqa: E501 +parser.add_argument('--L', type=float, default=0.6, help='Correlation length scale.') # noqa: E501 +parser.add_argument('--sigma', type=float, default=0.4, help='SOAR scaling.') # noqa: E501 +parser.add_argument('--minval', type=float, default=0.1, help='Minimum correlation value. Any below value this will be clipped to zero.') # noqa: E501 +parser.add_argument('--maxnz', type=int, default=16, help='Maximum number of nonzeros per row. Any values beyond this bandwidth will be clipped to zero.') # noqa: E501 +parser.add_argument('--psi', type=float, default=0.1, help='Shift to add to the major diagonal if the matrix is not positive-definite.') # noqa: E501 +parser.add_argument('--neigs', type=int, default=8, help='Number of eigenvalues to calculate to estimate positive-definiteness.') # noqa: E501 +parser.add_argument('--show_args', action='store_true', help='Print all the arguments when the script starts.') # noqa: E501 + +args, _ = parser.parse_known_args() + +if args.show_args: + print(args) + +np.random.seed(42) + +np.set_printoptions(legacy='1.25', precision=3, + threshold=2000, linewidth=200) + +# number of nodes +n = args.n +maxsep = 0.5*args.maxnz/n + +# eigenvalue estimates + +mesh = fd.PeriodicUnitIntervalMesh(n) +V = fd.FunctionSpace(mesh, 'DG', 0) +coords = fd.Function(V).interpolate(fd.SpatialCoordinate(mesh)[0]) +x = coords.dat.data + +chord_sep = partial(chordal_separation, + start=0, end=1) + +print('>>> Building SOAR CSR matrix') +Bcsr = soar_csr(x, args.L*maxsep, sigma=args.sigma, + tol=args.minval, maxsep=maxsep, + separation=chord_sep, + triangular=False) +size = n*n +nnz = Bcsr.nnz +nnzrow = nnz/n +fill = nnz/size +print(f'{size = } | {nnz = } | {nnzrow = } | {fill = }') + +print('>>> Checking positive-definiteness') +neigs = min(args.neigs, n-2) +emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) +emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) +cond = emax/emin +print('Eigenvalues:') +print(f'{emax = } | {emin = } | {cond = }') + +Bmat = csr_to_petsc(Bcsr) + +if emin < 0: + print('>>> Matrix is not SPD: shifting...') + Bmat.shift(abs(emin) + args.psi) + Bcsr = petsc_to_csr(Bmat) + emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) + emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) + cond = emax/emin + print('Eigenvalues after shift:') + print(f'{emax = } | {emin = } | {cond = }') + +params = { + 'ksp_converged_rate': None, + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'icc', +} + +print('>>> Setting up solver') +arguments = (fd.TestFunction(V), fd.TrialFunction(V)) +B = fd.AssembledMatrix(arguments, bcs=[], petscmat=Bmat) +solver = fd.LinearSolver(B, options_prefix='', + solver_parameters=params) + +x = fd.Function(V) +b = fd.Cofunction(V.dual()) +b.dat.data[:] = np.random.random_sample(b.dat.data.shape) + +print('>>> Solving...') +solver.solve(x, b) diff --git a/correlations/correlations.py b/correlations/correlations.py new file mode 100644 index 0000000..e6de260 --- /dev/null +++ b/correlations/correlations.py @@ -0,0 +1,92 @@ +from firedrake.petsc import PETSc +from scipy import sparse +import numpy as np + + +def minmag(a, b): + return np.where(np.abs(a) < np.abs(b), a, b) + + +def periodic_separation(y, x, start=None, end=None): + start = x[0] if start is None else start + end = x[-1] if end is None else end + internal = x - y + left = (y - start) + (end - x) + right = (end - y) + x + return minmag(minmag(left, internal), right) + + +def chordal_distance(separation, circumference): + diameter = circumference/np.pi + angle = 2*np.pi*separation/circumference + return diameter*np.sin(angle/2) + + +def chordal_separation(y, x, start=None, end=None): + start = x[0] if start is None else start + end = x[-1] if end is None else end + return chordal_distance( + periodic_separation(y, x, start, end), + end - start) + + +def separation_csr(x, maxsep=None, separation=None, + triangular=True): + n = len(x) + upper = sparse.csr_array((n, n)) + for i, y in enumerate(x): + if separation: + d = separation(y, x[i:]) + else: + d = x[i:] - y + if maxsep: + d[d > maxsep] = 0 + upper[i, i:] = d + upper.eliminate_zeros() + upper.setdiag(0) + if triangular: + return upper + else: + full = symmetrise_csr(upper) + full.setdiag(0) + return full + + +def symmetrise_csr(csr): + transpose = csr.T.tocsr() + transpose.setdiag(0) + return csr + transpose + + +def soar(r, L, sigma=1.): + rL = np.abs(r)/L + return sigma*(1 + rL)*np.exp(-rL) + + +def soar_csr(x, L, sigma=1., tol=None, + maxsep=None, separation=None, + triangular=True): + upper = separation_csr(x, maxsep=maxsep, + separation=separation, + triangular=True) + upper.data[:] = soar(upper.data, L, sigma=sigma) + if tol: + upper.data[upper.data < tol] = 0 + upper.eliminate_zeros() + return upper if triangular else symmetrise_csr(upper) + + +def csr_to_petsc(scipy_mat): + mat = PETSc.Mat().create() + mat.setType('aij') + mat.setSizes(scipy_mat.shape) + mat.setValuesCSR(scipy_mat.indptr, + scipy_mat.indices, + scipy_mat.data) + mat.assemble() + return mat + + +def petsc_to_csr(petsc_mat): + return sparse.csr_matrix(petsc_mat.getValuesCSR()[::-1], + shape=petsc_mat.getSize()) diff --git a/correlations/soar_mat.py b/correlations/soar_mat.py new file mode 100644 index 0000000..760f60a --- /dev/null +++ b/correlations/soar_mat.py @@ -0,0 +1,30 @@ +import numpy as np +from functools import partial +from correlations import ( + chordal_separation, soar_csr, csr_to_petsc, petsc_to_csr) + +np.set_printoptions(legacy='1.25', precision=2, + linewidth=200, threshold=10000) + +n = 16 + +L = 0.1 +tol = 0.4 + +maxsep = 0.3 + +width = 1 +start = 0 +end = start + width + +x = np.linspace(start, end, n, endpoint=False) + +chordsep = partial(chordal_separation, + start=start, end=end) + +petsc_mat = csr_to_petsc( + soar_csr(x, L, tol=tol, maxsep=maxsep, + separation=chordsep, + triangular=False)) + +print(f'mat =\n{petsc_to_csr(petsc_mat).todense()}') diff --git a/correlations/solver_external_operator.py b/correlations/solver_external_operator.py new file mode 100644 index 0000000..b20b5e6 --- /dev/null +++ b/correlations/solver_external_operator.py @@ -0,0 +1,49 @@ +import firedrake as fd +from firedrake.external_operators import AbstractExternalOperator, assemble_method + + +class LinearSolverOperator(AbstractExternalOperator): + def __init__(self, *operands, function_space, + derivatives=None, argument_slots=(), + operator_data=None): + + AbstractExternalOperator.__init__( + self, *operands, + function_space=function_space, + derivatives=derivatives, + argument_slots=argument_slots, + operator_data=operator_data) + + self.b, = operands + self._b = self.b.copy() + self.bexpr = fd.inner(self.b, fd.TestFunction(self._b.function_space()))*fd.dx + + self.A = operator_data["A"] + self.solver = fd.LinearSolver( + self.A, **operator_data.get("solver_kwargs", {})) + + def _solve(self, b): + x = fd.Function(self.function_space()) + self.solver.solve(x, b) + return x + + def _solve_transpose(self, b): + x = fd.Cofunction(self.function_space.dual()) + self.solver.solve_transpose(x, b) + return x + + def _assemble_b(self, b): + self._b.assign(b) + return fd.assemble(self.bexpr) + + @assemble_method(0, (0,)) + def assemble_operator(self, *args, **kwargs): + return self._solve(self._assemble_b(self.b)) + + @assemble_method(1, (0, None)) + def assemble_jacobian_action(self, *args, **kwargs): + return self._solve(self._assemble_b(self.argument_slots()[-1])) + + @assemble_method(1, (1, None)) + def assemble_jacobian_adjoint_action(self, *args, **kwargs): + return self._assemble_b(self._solve_transpose(self.argument_slots()[0])) diff --git a/correlations/taylor_test.py b/correlations/taylor_test.py new file mode 100644 index 0000000..8e88fb7 --- /dev/null +++ b/correlations/taylor_test.py @@ -0,0 +1,154 @@ +import firedrake as fd +from firedrake.adjoint import ( + continue_annotation, pause_annotation, Control, ReducedFunctional) +import numpy as np +import scipy.sparse.linalg as spla +from functools import partial +from correlations import ( + chordal_separation, soar_csr, csr_to_petsc, petsc_to_csr) + +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter + + +def weighted_norm(M, x, **kwargs): + y = fd.Function(x.function_space().dual()) + fd.solve(M, y, x, **kwargs) + return fd.assemble(fd.inner(y, y)*fd.dx) + + +parser = ArgumentParser( + description='Construct and solve an SOAR correlation matrix on a periodic interval.', # noqa: E501 + formatter_class=ArgumentDefaultsHelpFormatter +) +parser.add_argument('--n', type=int, default=128, help='Number of mesh nodes.') # noqa: E501 +parser.add_argument('--L', type=float, default=0.6, help='Correlation length scale.') # noqa: E501 +parser.add_argument('--sigma', type=float, default=0.4, help='SOAR scaling.') # noqa: E501 +parser.add_argument('--minval', type=float, default=0.1, help='Minimum correlation value. Any below value this will be clipped to zero.') # noqa: E501 +parser.add_argument('--maxnz', type=int, default=16, help='Maximum number of nonzeros per row. Any values beyond this bandwidth will be clipped to zero.') # noqa: E501 +parser.add_argument('--psi', type=float, default=0.1, help='Shift to add to the major diagonal if the matrix is not positive-definite.') # noqa: E501 +parser.add_argument('--neigs', type=int, default=8, help='Number of eigenvalues to calculate to estimate positive-definiteness.') # noqa: E501 +parser.add_argument('--show_args', action='store_true', help='Print all the arguments when the script starts.') # noqa: E501 + +args, _ = parser.parse_known_args() + +if args.show_args: + print(args) + +np.random.seed(42) + +np.set_printoptions(legacy='1.25', precision=3, + threshold=2000, linewidth=200) + +# number of nodes +n = args.n +maxsep = 0.5*args.maxnz/n + +# eigenvalue estimates + +mesh = fd.PeriodicUnitIntervalMesh(n) +V = fd.FunctionSpace(mesh, 'DG', 0) +coords = fd.Function(V).interpolate(fd.SpatialCoordinate(mesh)[0]) +x = coords.dat.data + +chord_sep = partial(chordal_separation, + start=0, end=1) + +print('>>> Building SOAR CSR matrix') +Bcsr = soar_csr(x, args.L, sigma=args.sigma, + tol=args.minval, maxsep=maxsep, + separation=chord_sep, + triangular=False) +size = n*n +nnz = Bcsr.nnz +nnzrow = nnz/n +fill = nnz/size +print(f'{size = } | {nnz = } | {nnzrow = } | {fill = }') + +print('>>> Checking positive-definiteness') +neigs = min(args.neigs, n-2) +emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) +emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) +cond = emax/emin +print('Eigenvalues:') +print(f'{emax = } | {emin = } | {cond = }') + +Bmat = csr_to_petsc(Bcsr) + +if emin < 0: + print('>>> Matrix is not SPD: shifting...') + Bmat.shift(abs(emin) + args.psi) + Bcsr = petsc_to_csr(Bmat) + emax = np.max(spla.eigsh(Bcsr, k=neigs, which='LM', + return_eigenvectors=False)) + emin = np.min(spla.eigsh(Bcsr, k=neigs, which='SA', + return_eigenvectors=False)) + cond = emax/emin + print('Eigenvalues after shift:') + print(f'{emax = } | {emin = } | {cond = }') + +params = { + 'ksp_converged_rate': None, + 'ksp_monitor': None, + + 'ksp_rtol': 1e-5, + 'ksp_type': 'cg', + 'pc_type': 'icc', + + # 'ksp_type': 'preonly', + # 'pc_type': 'lu', + # 'pc_factor_mat_solver_type': 'mumps', +} + +print('>>> Setting up solver') +arguments = (fd.TestFunction(V), fd.TrialFunction(V)) + +B = fd.AssembledMatrix(arguments, bcs=[], petscmat=Bmat) + +x0 = fd.Function(V) +x1 = fd.Function(V) + +b0 = fd.Cofunction(V.dual()) +b0.dat.data[:] = np.random.random_sample(b0.dat.data.shape) + +b1 = fd.Cofunction(V.dual()) +b1.dat.data[:] = np.random.random_sample(b1.dat.data.shape) + +print('>>> Solving LinearSolver...') +solver = fd.LinearSolver(B, options_prefix='', + solver_parameters=params) +solver.solve(x0, b0) +xBx = fd.assemble(fd.inner(b0.riesz_representation(), x0)*fd.dx) +print(f'{xBx = }') + +print('>>> Solving LinearSolverOperator...') + +from solver_external_operator import LinearSolverOperator +bfunc = fd.Function(V) +solve_operator = LinearSolverOperator( + bfunc, function_space=V, + operator_data={ + "A": B, + "solver_kwargs": { + "options_prefix": '', + "solver_parameters": params + } + }) + +bfunc.assign(b0.riesz_representation()) +xBx = fd.assemble(fd.inner(b0.riesz_representation(), solve_operator)*fd.dx) +print(f'{xBx = }') + +bfunc.assign(b0.riesz_representation()) +continue_annotation() +J = fd.assemble(fd.inner(bfunc, solve_operator)*fd.dx) +Jhat = ReducedFunctional(J, Control(bfunc)) +pause_annotation() + +print(f'{Jhat(b1.riesz_representation()) = }') +print(f'{fd.norm(Jhat.derivative()) = }') + +solver.solve(x1, b1) +xBx = fd.assemble(fd.inner(b1.riesz_representation(), x1)*fd.dx) +print(f'{xBx = }') diff --git a/fdvar/__init__.py b/fdvar/__init__.py new file mode 100644 index 0000000..beb3e17 --- /dev/null +++ b/fdvar/__init__.py @@ -0,0 +1,3 @@ +from .tao_solver import * # noqa: F401, F403 +from .generate_data import * # noqa: F401, F403 +from .pc import * # noqa: F401, F403 diff --git a/fdvar/generate_data.py b/fdvar/generate_data.py new file mode 100644 index 0000000..fe11e8f --- /dev/null +++ b/fdvar/generate_data.py @@ -0,0 +1,59 @@ +import firedrake as fd +import numpy as np + + +def noisify(u, sigma, bcs=None, seed=None): + if seed is not None: + np.random.seed(seed) + for dat in u.dat: + dat.data[:] += np.random.normal(0, sigma, dat.data.shape) + if bcs: + for bc in bcs: + bc.apply(u) + return u + + +def generate_observation_data(ensemble, ic, stepper, un, un1, bcs, H, + nw, nt, sigma_b, sigma_r, sigma_q, seed=6): + if seed is not None: + np.random.seed(seed) + + if ensemble is None: + rank = 0 + nlocal_stages = nw + else: + rank = ensemble.ensemble_rank + nlocal_stages = nw//ensemble.ensemble_size + + if rank == 0: + nlocal_stages -= 1 + + # background is reference plus noise + for bc in bcs: + bc.apply(ic) + + background = noisify(ic.copy(deepcopy=True), sigma_b, bcs=bcs) + + y = [] + # initial observation + if rank == 0: + y.append(noisify(H(ic), sigma_r)) + + def solve_local_stages(): + for k in range(nlocal_stages): + for i in range(nt): + un1.assign(un) + stepper.solve() + un.assign(un1) + noisify(un, sigma_q, bcs=bcs) + y.append(noisify(H(un), sigma_r)) + + un.assign(ic) + if ensemble is None: + solve_local_stages() + else: + with ensemble.sequential(u=un) as ctx: + un.assign(ctx.u) + solve_local_stages() + + return y, background diff --git a/fdvar/mat.py b/fdvar/mat.py new file mode 100644 index 0000000..f5e6e49 --- /dev/null +++ b/fdvar/mat.py @@ -0,0 +1,386 @@ +import firedrake as fd +from firedrake.petsc import PETSc +from firedrake.petsc import OptionsManager +from firedrake.adjoint import ReducedFunctional +from firedrake.adjoint.fourdvar_reduced_functional import CovarianceNormReducedFunctional +from pyop2.mpi import MPI +from pyadjoint.optimization.tao_solver import PETScVecInterface +from pyadjoint.enlisting import Enlist +from typing import Optional +from enum import Enum +from functools import partial, cached_property, wraps + + +class RFAction(Enum): + TLM = 'tlm' + Adjoint = 'adjoint' + Hessian = 'hessian' +TLM = RFAction.TLM +Adjoint = RFAction.Adjoint +Hessian = RFAction.Hessian + + +def copy_controls(controls): + return controls.delist([c.control._ad_init_zero() for c in controls]) + + +def convert_types(overloaded, options): + overloaded = Enlist(overloaded) + return overloaded.delist([o._ad_convert_type(o, options=options) + for o in overloaded]) + + +def check_rf_action(action): + def check_rf_action_decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.action != action: + raise NotImplementedError( + f'Cannot apply {str(action)} action if {self.action = }') + return func(*args, **kwargs) + return wrapper + return check_rf_action_decorator + + +class ReducedFunctionalMatCtx: + """ + PythonMat context to apply action of a pyadjoint.ReducedFunctional. + + Jhat : V -> U + TLM : V -> U + Adjoint : U* -> V* + Hessian : V x U* -> V* | V -> V* + + Parameters + ---------- + + action : RFAction + """ + + def __init__(self, Jhat: ReducedFunctional, + action: str = Hessian, *, + apply_riesz: bool = False, + comm: MPI.Comm = PETSc.COMM_WORLD): + self.Jhat = Jhat + self.control_interface = PETScVecInterface( + Jhat.controls, comm=comm) + self.apply_riesz = apply_riesz + if action in (Adjoint, TLM): + self.functional_interface = PETScVecInterface( + Jhat.functional, comm=comm) + + if action == Hessian: # control -> control + self.xinterface = self.control_interface + self.yinterface = self.control_interface + + self.x = copy_controls(Jhat.controls) + self.mult_impl = self._mult_hessian + + elif action == Adjoint: # functional -> control + self.xinterface = self.functional_interface + self.yinterface = self.control_interface + + self.x = Jhat.functional._ad_copy() + self.mult_impl = self._mult_adjoint + + elif action == TLM: # control -> functional + self.xinterface = self.control_interface + self.yinterface = self.functional_interface + + self.x = copy_controls(Jhat.controls) + self.mult_impl = self._mult_tlm + else: + raise ValueError( + 'Unrecognised {action = }.') + + self.action = action + self._m = copy_controls(Jhat.controls) + self._shift = 0 + + @classmethod + def update(cls, obj, x, A, P): + ctx = A.getPythonContext() + ctx.control_interface.from_petsc(x, ctx._m) + ctx.update_tape_values(update_adjoint=True) + ctx._shift = 0 + + def shift(self, A, alpha): + self._shift += alpha + + def update_tape_values(self, update_adjoint=True): + _ = self.Jhat(self._m) + if update_adjoint: + _ = self.Jhat.derivative(apply_riesz=False) + + def mult(self, A, x, y): + self.xinterface.from_petsc(x, self.x) + out = self.mult_impl(A, self.x) + self.yinterface.to_petsc(y, out) + + if self._shift != 0: + y.axpy(self._shift, x) + + # @check_rf_action(action=Hessian) + def _mult_hessian(self, A, x): + # self.update_tape_values(update_adjoint=True) + return self.Jhat.hessian( + x, apply_riesz=self.apply_riesz) + + # @check_rf_action(TLM) + def _mult_tlm(self, A, x): + # self.update_tape_values(update_adjoint=False) + return self.Jhat.tlm(x) + + # @check_rf_action(Adjoint) + def _mult_adjoint(self, A, x): + # self.update_tape_values(update_adjoint=False) + return self.Jhat.derivative( + adj_input=x, apply_riesz=self.apply_riesz) + + +def ReducedFunctionalMat(Jhat, action=Hessian, *, comm=PETSc.COMM_WORLD, **kwargs): + ctx = ReducedFunctionalMatCtx( + Jhat, action, comm=comm, **kwargs) + + ncol = ctx.xinterface.n + Ncol = ctx.xinterface.N + + nrow = ctx.yinterface.n + Nrow = ctx.yinterface.N + + mat = PETSc.Mat().createPython( + ((nrow, Nrow), (ncol, Ncol)), + ctx, comm=comm) + mat.setUp() + mat.assemble() + return mat + + +def EnsembleMat(ctx, row_space, col_space=None): + if col_space is None: + col_space = row_space + + # number of columns is row length, and vice-versa + ncol = row_space.nlocal_rank_dofs + Ncol = row_space.nglobal_dofs + + nrow = col_space.nlocal_rank_dofs + Nrow = col_space.nglobal_dofs + + mat = PETSc.Mat().createPython( + ((nrow, Nrow), (ncol, Ncol)), ctx, + comm=row_space.ensemble.global_comm) + mat.setUp() + mat.assemble() + return mat + + +class EnsembleMatCtxBase: + def __init__(self, row_space, col_space=None): + if col_space is None: + col_space = row_space.dual() + + if not isinstance(row_space, fd.EnsembleFunctionSpace): + raise ValueError( + f"EnsembleMat row_space must be EnsembleFunctionSpace not {type(row_space).__name__}") + if not isinstance(col_space, fd.EnsembleFunctionSpace): + raise ValueError( + f"EnsembleMat col_space must be EnsembleFunctionSpace not {type(col_space).__name__}") + + self.row_space = row_space + self.col_space = col_space + + # input/output Vecs will be copied in/out of these + # so that base classes can implement mult only in + # terms of Ensemble objects not Vecs. + self.x = fd.EnsembleFunction(self.row_space) + self.y = fd.EnsembleFunction(self.col_space) + + def mult(self, A, x, y): + with self.x.vec_wo() as xvec: + x.copy(xvec) + + self.mult_impl(A, self.x, self.y) + + with self.y.vec_ro() as yvec: + yvec.copy(y) + + +class EnsembleBlockDiagonalMatCtx(EnsembleMatCtxBase): + def __init__(self, blocks, row_space, col_space=None): + super().__init__(row_space, col_space) + self.blocks = blocks + + if self.row_space.nlocal_spaces != self.col_space.nlocal_spaces: + raise ValueError( + "EnsembleBlockDiagonalMat row and col spaces must be the same length," + f" not {row_space.nlocal_spaces = } and {col_space.nlocal_spaces = }") + + if len(self.blocks) != self.row_space.nlocal_spaces: + raise ValueError( + f"EnsembleBlockDiagonalMat requires one submatrix for each of the" + f" {self.row_space.nlocal_spaces} local subfunctions of the EnsembleFunctionSpace," + f" but only {len(self.blocks)} provided.") + + for i, (Vrow, Vcol, block) in enumerate(zip(self.row_space.local_spaces, + self.col_space.local_spaces, + self.blocks)): + # number of columns is row length, and vice-versa + vc_sizes = Vrow.dof_dset.layout_vec.sizes + vr_sizes = Vcol.dof_dset.layout_vec.sizes + mr_sizes, mc_sizes = block.sizes + if (vr_sizes[0] != mr_sizes[0]) or (vr_sizes[1] != mr_sizes[1]): + raise ValueError( + f"Row sizes {mr_sizes} of block {i} and {vr_sizes} of row_space {i} of EnsembleBlockDiagonalMat must match.") + if (vc_sizes[0] != mc_sizes[0]) or (vc_sizes[1] != mc_sizes[1]): + raise ValueError( + f"Col sizes of block {i} and col_space {i} of EnsembleBlockDiagonalMat must match.") + + def mult_impl(self, A, x, y): + for block, xsub, ysub in zip(self.blocks, + self.x.subfunctions, + self.y.subfunctions): + with xsub.dat.vec_ro as xvec, ysub.dat.vec_wo as yvec: + block.mult(xvec, yvec) + + +def EnsembleBlockDiagonalMat(row_space, blocks, **kwargs): + return EnsembleMat( + EnsembleBlockDiagonalMatCtx(blocks, row_space, **kwargs), + row_space, kwargs.get('col_space', None)) + + +# L Mat +class AllAtOnceRFMatCtx(EnsembleMatCtxBase): + def __init__(self, fdvrf, action, **kwargs): + super().__init__(fdvrf.solution_space) + + if action not in (TLM, Adjoint): + raise ValueError( + f"AllAtOnceRFMat action type must be 'tlm' or 'adjoint', not {action}") + + self.action = action + self.fdvrf = fdvrf + self.ensemble = fdvrf.ensemble + + self.models = [ReducedFunctionalMat(M, action, **kwargs) + for M in fdvrf.model_rfs] + + if action == Adjoint: + self.models = reversed(self.models) + + self.xhalo = fd.Function(self.row_space.local_spaces[0]) + self.mx = self.x.copy() + + # Set up list of x_{i-1} functions to propogate. + # TLM means we propogate forwards, and use halo from previous rank. + # Adjoint means we propogate backwards, and use halo from next rank. + # The initial timestep on the initial rank doesn't have a halo. + if self.action == TLM: + self.xprevs = [*self.x.subfunctions[:-1]] + else: + self.xprevs = [*reversed(self.x.subfunctions[1:])] + + initial_rank = 0 if action == TLM else (self.ensemble.ensemble_size - 1) + if self.ensemble.ensemble_rank != initial_rank: + self.xprevs.insert(self.xhalo, 0) + + def update_halos(self, x): + ensemble_rank = self.ensemble.ensemble_rank + ensemble_size = self.ensemble.ensemble_size + + # halo swap is a right shift + next_rank = (ensemble_rank + 1) % ensemble_size + prev_rank = (ensemble_rank - 1) % ensemble_size + + src = prev_rank if self.action == TLM else next_rank + dst = next_rank if self.action == TLM else prev_rank + + frecv = self.xhalo + fsend = self.x.subfunctions[-1 if self.action == TLM else 0] + + self.ensemble.sendrecv( + fsend=fsend, dest=dst, sendtag=dst, + frecv=frecv, source=src, recvtag=ensemble_rank) + + def mult_impl(self, A, x, y): + self.update_halos(x) + + # propogate from last step mx <- M*x + for M, xi, mxi in zip(self.models, self.xprevs, self.mx.subfunctions): + mxi.assign(M.mult(xi)) + + # diagonal contribution + # x_{i} <- x_{i} - M*x_{i-1} + x -= self.mx + + y.assign(x.riesz_representation()) + + +def AllAtOnceRFMat(fdvrf, action, **kwargs): + return EnsembleMat( + AllAtOnceRFMatCtx(fdvrf, action, **kwargs), + fdvrf.solution_space) + + +# H Mat +def ObservationEnsembleRFMat(fdvrf, action, **kwargs): + if action == TLM: + row_space = fdvrf.solution_space + col_space = fdvrf.observation_space + elif action == Adjoint: + row_space = fdvrf.observation_space + col_space = fdvrf.solution_space + else: + raise ValueError( + f"Unrecognised matrix action type {action}") + + blocks = [ReducedFunctionalMat(Jobs, action, **kwargs) + for Jobs in fdvrf.observation_rfs] + return EnsembleBlockDiagonalMat(blocks, row_space, col_space) + + +# CovarianceMat +def CovarianceMat(covariancerf, action='mult'): + """ + action='mult' for action of B, or 'inv' for action of B^{-1} + """ + if not isinstance(covariancerf, CovarianceNormReducedFunctional): + raise TypeError( + "CovarianceMat can only be constructed from a CovarianceNormReducedFunctional" + f" not a {type(covariancerf).__name__}") + space = covariancerf.controls[0].control.function_space() + comm = space.mesh().comm + covariance = covariancerf.covariance + + if action == 'mult': + weight = float(covariance) + elif action == 'inv': + weight = float(1/covariance) + else: + raise ValueError(f"Unrecognised action type {action} for CovarianceMat") + + sizes = space.dof_dset.layout_vec.sizes + shape = (sizes, sizes) + covmat = PETSc.Mat().createConstantDiagonal( + shape, covariance, comm=comm) + covmat.setUp() + covmat.assemble() + return covmat + + +# D Mat +def ModelCovarianceEnsembleRFMat(fdvrf, action, **kwargs): + blocks = [CovarianceMat(mnorm, action, **kwargs) + for mnorm in fdvrf.model_norms] + if fdvrf.ensemble.ensemble_rank == 0: + blocks.insert( + 0, CovarianceMat(fdvrf.background_norm, action, **kwargs)) + return EnsembleBlockDiagonalMat(blocks, fdvrf.solution_space) + + +# R Mat +def ObservationCovarianceEnsembleRFMat(fdvrf, action, **kwargs): + blocks = [CovarianceMat(obs_norm, action, **kwargs) + for obs_norm in fdvrf.observation_norms] + return EnsembleBlockDiagonalMat(blocks, fdvrf.observation_space) diff --git a/fdvar/pc.py b/fdvar/pc.py new file mode 100644 index 0000000..914ec27 --- /dev/null +++ b/fdvar/pc.py @@ -0,0 +1,191 @@ +import firedrake as fd +from fdvar.mat import EnsembleMatCtxBase +from math import pi, sqrt + +__all__ = ("SC4DVarBackgroundPC",) + + +class PCBase: + needs_python_amat = False + needs_python_pmat = False + + def __init__(self): + self.initialized = False + + def setUp(self, pc): + if not self.initialized: + self.initialize(pc) + self.initialized = True + self.update(pc) + + def initialize(self, pc): + if pc.getType() != "python": + raise ValueError("Expecting PC type python") + + self.A, self.P = pc.getOperators() + pcname = f"{type(self).__module__}.{type(self).__name__}" + if self.needs_python_amat: + if self.A.getType() != "python": + raise ValueError( + f"PC {pcname} needs a python type amat, not {self.A.getType()}") + self.amat = self.A.getPythonContext() + if self.needs_python_pmat: + if self.P.getType() != "python": + raise ValueError( + f"PC {pcname} needs a python type pmat, not {self.P.getType()}") + self.pmat = self.P.getPythonContext() + + self.parent_prefix = pc.getOptionsPrefix() + self.full_prefix = self.parent_prefix + self.prefix + + def update(self, pc): + pass + + +class EnsemblePCBase(PCBase): + requires_python_amat = True + requires_python_pmat = True + + def __init__(self): + self.initialized = False + + def initialize(self, pc): + super().initialize(pc) + + if not isinstance(self.pmat, EnsembleMatCtxBase): + pcname = f"{type(self).__module__}.{type(self).__name__}" + matname = f"{type(self.pmat).__module__}.{type(self).pmat.__name__}" + raise TypeError( + f"PC {pname} needs an EnsembleMatCtxBase pmat, but it is a {matname}") + + self.row_space = self.pmat.row_space + self.col_space = self.pmat.col_space + + self.x = fd.EnsembleFunction(self.row_space.dual()) + self.y = fd.EnsembleFunction(self.col_space) + + +class EnsembleBlockDiagonalPC(PCBase): + prefix = "ebjacobi_" + + def initialize(self, pc): + super().initialize(pc) + + ensemble_mat = self.pmat + self.function_space = ensemble_mat.function_space + self.ensemble = function_space.ensemble + + submats = ensemble_mat.blocks + + self.x = fd.EnsembleFunction(self.function_space) + self.y = fd.EnsembleCofunction(self.function_space.dual()) + + subksps = [] + for i, submat in enumerate(self.pmat.blocks): + ksp = PETSc.KSP().create(comm=ensemble.comm) + ksp.setOperators(submat) + + sub_prefix = self.parent_prefix + f"sub_{i}_" + # TODO: default options + options = OptionsManager({}, sub_prefix) + options.set_from_options(ksp) + self.subksps.append((ksp, options)) + + self.subksps = tuple(subksps) + + def apply(self, pc, x, y): + with self.x.vec_wo() as xvec: + x.copy(xvec) + + subvecs = zip(self.x.subfunctions, self.y.subfunctions) + for (subksp, suboptions), (subx, suby) in zip(self.subksps, subvecs): + with subx.dat.vec_ro as rhs, suby.dat.vec_wo as sol: + with suboptions.inserted_options(): + subksp.solve(rhs, sol) + + with self.y.vec_ro() as yvec: + yvec.copy(y) + + +class CovarianceMassPC(PCBase): + prefix = "covmass_" + pass + + +class CovarianceDiffusionPC(PCBase): + prefix = "covdiff_" + pass + + +class SC4DVarBackgroundPC(PCBase): + prefix = "scbkg_" + + def initialize(self, pc): + prefix = pc.getOptionsPrefix() or "" + options_prefix = prefix + self.prefix + A, _ = pc.getOperators() + scfdv = A.getPythonContext().problem.reduced_functional + + self.covariance = scfdv.background_norm.covariance + self.x = fd.Cofunction(self.covariance.V.dual()) + self.y = fd.Function(self.covariance.V) + + # # diffusion coefficient for lengthscale L + # nu = 0.5*L*L + # # normalisation for diffusion operator + # lambda_g = L*sqrt(2*pi) + # scale = lambda_g/sigma + + # V = scfdv.controls[0].control.function_space() + # sol = fd.Function(V) + # b = fd.Cofunction(V.dual()) + # u = fd.TrialFunction(V) + # v = fd.TestFunction(V) + # + # Binv = scale*(fd.inner(u, v) - nu*fd.inner(fd.grad(u), fd.grad(v)))*fd.dx + + # solver = fd.LinearVariationalSolver( + # fd.LinearVariationalProblem( + # Binv, b, sol, bcs=bcs, + # constant_jacobian=True), + # options_prefix=options_prefix, + # solver_parameters={'ksp_type': 'preonly', + # 'pc_type': 'lu'}) + + # self.sol, self.b, self.solver = sol, b, solver + + def apply(self, pc, x, y): + with self.x.dat.vec_wo as vx: + x.copy(vx) + self.covariance.action(self.x, tensor=self.y) + with self.y.dat.vec_ro as vy: + # with self.y.riesz_representation().dat.vec_ro as vy: + vy.copy(y) + + def update(self, pc): + pass + + def applyTranspose(self, pc, x, y): + raise NotImplementedError + + +class WC4DVarLTDLPC(PCBase): + prefix = "ltdl_" + + def initialize(self, pc): + self.fdvrf = self.pmat.fdvrf + if not isinstance(self.Jhat, FourDVarReducedFunctional): + raise TypeError( + "AllAtOnceJacobiPC expects a FourDVarReducedFunctional not a {type(self.Jhat).__name__}") + + def apply(self, pc, x, y): + with self.x.vec_wo() as xvec: + x.copy(xvec) + + # P = LL^{T} + + # apply L^{T} + ltx = self.Jhat.derivative + + with self.y.vec_ro() as yvec: + yvec.copy(y) diff --git a/fdvar/saddle.py b/fdvar/saddle.py new file mode 100644 index 0000000..7a8d13a --- /dev/null +++ b/fdvar/saddle.py @@ -0,0 +1,201 @@ +import firedrake as fd +from firedrake.petsc import PETSc +from firedrake.petsc import OptionsManager +from firedrake.adjoint import ReducedFunctional +from firedrake.adjoint.fourdvar_reduced_functional import CovarianceNormReducedFunctional +from pyop2.mpi import MPI +from pyadjoint.optimization.tao_solver import PETScVecInterface +from pyadjoint.enlisting import Enlist +from typing import Optional +from enum import Enum +from functools import partial, cached_property, wraps + + +class ISNest: + attrs = ( + 'getSizes', + 'getSize', + 'getLocalSize', + 'getIndices', + 'getComm', + ) + def __init__(self, ises): + self._comm = ises[0].getComm() + self._ises = ises + for attr in self.attrs: + setattr(self, attr, + partial(self._getattr, attr)) + + def _getattr(self, attr, i, *args, **kwargs): + return getattr(self.ises[i], attr)(*args, **kwargs) + + @cached_property + def globalSize(self): + return sum(self.getSize(i) for i in range(len(self))) + + @cached_property + def localSize(self): + return sum(self.getLocalSize(i) for i in range(len(self))) + + @property + def sizes(self): + return (self.localSize, self.globalSize) + + @property + def comm(self): + return self._comm + + @property + def ises(self): + return self._ises + + def __getitem__(self, i): + return self.ises[i] + + def __len__(self): + return len(self.ises) + + def __iter__(self): + return iter(self.ises) + + def createVec(self, i=None, vec_type=PETSc.Vec.Type.MPI): + vec = PETSc.Vec().create(comm=self.comm) + vec.setType(vec_type) + vec.setSizes(self.sizes if i is None else self.getSizes(i)) + return vec + + def createVecs(self, vec_type=PETSc.Vec.Type.MPI): + return (self.createVec(i, vec_type=vec_type) for i in range(len(self))) + + def createVecNest(self, vecs=None): + if vecs is None: + vecs = self.createVecs() + else: + if not all(vecs[i].getSizes() == self.getSizes(i) for i in range(len(self))): + raise ValueError("vec sizes must match is sizes") + return PETSc.Vec().createNest(vecs, self.ises, self.comm) + + +def copy_controls(controls): + return controls.delist([c.control._ad_init_zero() for c in controls]) + + +def convert_types(overloaded, options): + overloaded = Enlist(overloaded) + return overloaded.delist([o._ad_convert_type(o, options=options) + for o in overloaded]) + + +def saddle_ises(fdvrf): + ensemble = fdvrf.ensemble + rank = ensemble.ensemble_rank + global_comm = ensemble.global_comm + + Vsol = fdvrf.solution_space + Vobs = fdvrf.observation_space + Vs = Vsol.local_spaces[0] + Vo = Vobs.local_spaces[0] + + # ndofs per (dn, dl, dx) block + bsol = Vsol.dof_dset.layout_vec.getLocalSize() + bobs = Vobs.dof_dset.layout_vec.getLocalSize() + nlocal_blocks = Vsol.nlocal_spaces + + bsize_dn = bsol + bsize_dl = bobs + bsize_dx = bsol + bsize = bsize_dn + bsize_dl + bsize_dx + + # number of blocks on previous ranks + nprev_blocks = ensemble.ensemble_comm.exscan(nlocal_blocks) + if rank == 0: # exscan returns None + nprev_blocks = 0 + + # offset to start of global indices of each field in local block j + offset = bsize*nprev_blocks + offset_dn = lambda j: offset + j*bsize + offset_dl = lambda j: offset_dn(j) + bsize_dn + offset_dx = lambda j: offset_dl(j) + bsize_dl + + indices_dn = np.concatenate( + [offset_dn(j) + np.arange(bsize_dn, dtype=np.int32) + for j in range(nlocal_blocks)]) + + indices_dl = np.concatenate( + [offset_dl(j) + np.arange(bsize_dl, dtype=np.int32) + for j in range(nlocal_blocks)]) + + indices_dx = np.concatenate( + [offset_dx(j) + np.arange(bsize_dx, dtype=np.int32) + for j in range(nlocal_blocks)]) + + is_dn = PETSc.IS().createGeneral(indices_dn, comm=global_comm) + is_dl = PETSc.IS().createGeneral(indices_dl, comm=global_comm) + is_dx = PETSc.IS().createGeneral(indices_dx, comm=global_comm) + + return ISNest((is_dn, is_dl, is_dx)) + + +def FDVarSaddlePointKSP(fdvrf, solver_parameters, options_prefix=None): + saddlemat = FDVarSaddlePointMat(fdvrf) + + ksp = PETSc.KSP().create(comm=fdvrf.ensemble.global_comm) + ksp.setOperators(saddlemat, saddlemat) + + options = OptionsManager(solver_parameters, options_prefix) + options.set_from_options(ksp) + + return ksp, options + + +# Saddle-point MatNest +def FDVarSaddlePointMat(fdvrf): + ensemble = fdvrf.ensemble + + dn_is, dl_is, dx_is = saddle_ises(fdvrf) + + # L Mat + L = AllAtOnceRFMat(fdvrf, action=TLM) + Lt = AllAtOnceRFMat(fdvrf, action=Adjoint) + + Lrow = dn_is + Lcol = dx_is + + Ltrow = dx_is + Ltcol = dn_is + + # H Mat + H = ObservationEnsembleRFMat(fdvrf, action=TLM) + Ht = ObservationEnsembleRFMat(fdvrf, action=Adjoint) + + Hrow = dl_is + Hcol = dx_is + + Htrow = dx_is + Htcol = dl_is + + # D Mat + D = ModelCovarianceEnsembleRFMat(fdvrf) + + Drow = dn_is + Dcol = dn_is + + # R Mat + R = ObservationCovarianceEnsembleRFMat(fdvrf) + + Rrow = dl_is + Rcol = dl_is + + fdvmat = PETSc.Mat().createNest( + mats=[D, L, # noqa: E127,E202 + R, H, # noqa: E127,E202 + Lt, Ht ], # noqa: E127,E202 + isrows=[Drow, Lrow, # noqa: E127,E202 + Rrow, Hrow, # noqa: E127,E202 + Ltrow, Htrow ], # noqa: E127,E202 + iscols=[Dcol, Lcol, # noqa: E127,E202 + Rcol, Hcol, # noqa: E127,E202 + Ltcol, Htcol ], # noqa: E127,E202 + comm=ensemble.global_comm) + + return fdvmat diff --git a/fdvar/tao_solver.py b/fdvar/tao_solver.py new file mode 100644 index 0000000..db4cba9 --- /dev/null +++ b/fdvar/tao_solver.py @@ -0,0 +1,189 @@ +import firedrake as fd +from firedrake import PETSc +from pyadjoint.optimization.tao_solver import ( + OptionsManager, TAOConvergenceError, _tao_reasons) +from functools import cached_property +from fdvar.mat import ReducedFunctionalMat + +__all__ = ("TAOObjective", "TAOConvergenceError", "TAOSolver") + + +class TAOObjective: + def __init__(self, Jhat, apply_riesz=False): + self.Jhat = Jhat + self.apply_riesz = apply_riesz + self._control = Jhat.control.copy() + self._m = Jhat.control.copy() + self._mdot = Jhat.control.copy() + + self.n = self._m._vec.getLocalSize() + self.N = self._m._vec.getSize() + self.sizes = (self.n, self.N) + + def objective(self, tao, x): + with self._control.vec_wo() as cvec: + x.copy(cvec) + return self.Jhat(self._control) + + def gradient(self, tao, x, g): + dJ = self.Jhat.derivative() + with dJ.vec_ro() as dvec: + dvec.copy(g) + + def objective_gradient(self, tao, x, g): + with self._control.vec_wo() as cvec: + x.copy(cvec) + J = self.Jhat(self._control) + dJ = self.Jhat.derivative() + with dJ.vec_ro() as dvec: + dvec.copy(g) + return J + + def hessian(self, A, x, y): + with self._mdot.vec_ro() as mvec: + x.copy(mvec) + ddJ = self.Jhat.hessian() + with ddJ.vec_ro() as dvec: + dvec.copy(y) + if self._shift != 0.0: + y.axpy(self._shift, x) + + @cached_property + def hessian_mat(self): + return ReducedFunctionalMat(self.Jhat) + # ctx = HessianCtx(self.Jhat, dual_options=self.dual_options) + # mat = PETSc.Mat().createPython( + # (self.sizes, self.sizes), ctx, + # comm=self.Jhat.ensemble.global_comm) + # mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) + # mat.setUp() + # mat.assemble() + # return mat + pass # trick folding + + @cached_property + def gradient_norm_mat(self): + ctx = GradientNormCtx(self.Jhat) + mat = PETSc.Mat().createPython( + (self.sizes, self.sizes), ctx, + comm=self.Jhat.ensemble.global_comm) + mat.setOption(PETSc.Mat.Option.SYMMETRIC, True) + mat.setUp() + mat.assemble() + return mat + + +class HessianCtx: + @classmethod + def update(cls, tao, x, H, P): + ctx = H.getPythonContext() + with ctx._m.vec_wo() as mvec: + x.copy(mvec) + ctx._shift = 0.0 + + def __init__(self, Jhat): + self.Jhat = Jhat + self._m = Jhat.control.copy() + self._mdot = Jhat.control.copy() + self._shift = 0.0 + + def shift(self, A, alpha): + self._shift += alpha + + def mult(self, A, x, y): + with self._mdot.vec_wo() as mdvec: + x.copy(mdvec) + + # TODO: Why do we need to reevaluate and derivate? + _ = self.Jhat(self._m) + _ = self.Jhat.derivative() + ddJ = self.Jhat.hessian([self._mdot]) + + with ddJ.vec_ro() as dvec: + dvec.copy(y) + + if self._shift != 0.0: + y.axpy(self._shift, x) + + +class GradientNormCtx: + def __init__(self, Jhat): + self._xfunc = Jhat.control.copy() + self._ycofunc = self._xfunc.riesz_representation() + + # TODO: Just implement EnsembleFunction._ad_convert_type + efs = Jhat.control_space + v = fd.TestFunction(efs._full_local_space) + self.M = fd.inner(v, self._xfunc._full_local_function)*fd.dx + + def mult(self, mat, x, y): + with self._xfunc.vec_wo() as xvec: + x.copy(xvec) + + fd.assemble(self.M, tensor=self._ycofunc._full_local_function) + + # ycofunc = self.Jhat.control._ad_convert_riesz( + # self._xfunc, riesz_map=self.Jhat.control.riesz_map) + + with self._ycofunc.vec_ro() as yvec: + yvec.copy(y) + + +class TAOSolver: + def __init__(self, Jhat, *, options_prefix=None, + solver_parameters=None): + self.Jhat = Jhat + self.ensemble = Jhat.ensemble + + self.tao_objective = TAOObjective(Jhat) + + self.tao = PETSc.TAO().create( + comm=Jhat.ensemble.global_comm) + + # solution vector + self._x = Jhat.control._vec.duplicate() + self.tao.setSolution(self._x) + + # evaluate objective and gradient + self.tao.setObjective( + self.tao_objective.objective) + + self.tao.setGradient( + self.tao_objective.gradient) + + self.tao.setObjectiveGradient( + self.tao_objective.objective_gradient) + + # evaluate hessian action + hessian_mat = self.tao_objective.hessian_mat + self.tao.setHessian( + hessian_mat.getPythonContext().update, + hessian_mat) + + # gradient norm in correct space + self.tao.setGradientNorm( + self.tao_objective.gradient_norm_mat) + + # solver parameters and finish setup + self.options = OptionsManager( + solver_parameters, options_prefix) + self.options.set_from_options(self.tao) + + self.tao.setUp() + + def solve(self): + control = self.Jhat.control + with control.tape_value().vec_ro() as cvec: + cvec.copy(self._x) + + with self.options.inserted_options(): + self.tao.solve() + + if self.tao.getConvergedReason() <= 0: + # Using the same format as Firedrake linear solver errors + raise TAOConvergenceError( + f"TAOSolver failed to converge after {self.tao.getIterationNumber()} iterations " + f"with reason: {_tao_reasons[self.tao.getConvergedReason()]}") + + with control.vec_wo() as cvec: + self._x.copy(cvec) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..293d5f0 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import find_packages, setup + +setup( + name="fdvar", + version="0.1", + packages=find_packages() +) diff --git a/tao/rfmat_rectangular.py b/tao/rfmat_rectangular.py new file mode 100644 index 0000000..ae5907c --- /dev/null +++ b/tao/rfmat_rectangular.py @@ -0,0 +1,91 @@ +import firedrake as fd +from firedrake.adjoint import ( + ReducedFunctional, Control, set_working_tape, + continue_annotation, pause_annotation) +from fdvar.mat import ReducedFunctionalMat + +mesh = fd.UnitIntervalMesh(3) +x, = fd.SpatialCoordinate(mesh) +expr = 0.5*x*x - 7 + +Vin = fd.FunctionSpace(mesh, "CG", 1) +Vout = fd.FunctionSpace(mesh, "DG", 1) + +u = fd.TrialFunction(Vin) +v = fd.TestFunction(Vout) + +A = (u*v + fd.inner(fd.grad(u), fd.grad(v)) + v*u.dx(0))*fd.dx + +def tlm(y): + return fd.assemble(fd.action(A, y)).riesz_representation() + +def adj(y): + return fd.assemble(fd.action(fd.adjoint(A), y)).riesz_representation() + +continue_annotation() +with set_working_tape() as tape: + x = fd.Function(Vin) + Jhat = ReducedFunctional(tlm(x), Control(x), tape=tape) +pause_annotation() + +def riesz(how): + return {'riesz_representation': how} + +mat_tlm = ReducedFunctionalMat( + Jhat, action_type='tlm', + options={'riesz_representation': None}) + +mat_adj = ReducedFunctionalMat( + Jhat, action_type='adjoint', + input_options={'riesz_representation': 'l2', 'function_space': Vout.dual()}) + +ctx_adj = mat_adj.getPythonContext() + +print("tlm action") + +# manual calculation +uin = fd.Function(Vin).project(expr) +uout = tlm(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(Vin).project(expr) +vout = Jhat.tlm(vin, options={'riesz_representation': None}) +print(f"{type(vout) = }") +print(f"{vout.dat.data = }") + +# Mat action +win = fd.Function(Vin).project(expr) +wout = fd.Function(Vout) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_tlm.mult(x, y) +print(f"{type(wout) = }") +print(f"{wout.dat.data = }") + +print() + +print("adjoint action") + +# manual calculation +uin = fd.Function(Vout).project(expr) +uout = adj(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(Vout).project(expr) +vout = Jhat.derivative(adj_input=vin.riesz_representation()) +print(f"{type(vout) = }") +print(f"{vout.dat.data = }") + +# Mat action +win = fd.Function(Vout).project(expr).riesz_representation() +wout = fd.Function(Vin) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_adj.mult(x, y) + +print(f"{type(wout) = }") +print(f"{wout.dat.data = }") + +print() diff --git a/tao/rfmat_square.py b/tao/rfmat_square.py new file mode 100644 index 0000000..d5af064 --- /dev/null +++ b/tao/rfmat_square.py @@ -0,0 +1,82 @@ +import firedrake as fd +from firedrake.adjoint import ( + ReducedFunctional, Control, set_working_tape, + continue_annotation, pause_annotation) +from fdvar.mat import ReducedFunctionalMat + +mesh = fd.UnitIntervalMesh(3) +x, = fd.SpatialCoordinate(mesh) +expr = 0.5*x*x - 7 + +V = fd.FunctionSpace(mesh, "CG", 1) + +u = fd.TrialFunction(V) +v = fd.TestFunction(V) + +A = (u*v + fd.inner(fd.grad(u), fd.grad(v)) + v*u.dx(0))*fd.dx + +def tlm(y): + return fd.assemble(fd.action(A, y)) + +def adj(y): + return fd.assemble(fd.action(fd.adjoint(A), y)).riesz_representation() + +continue_annotation() +with set_working_tape() as tape: + x = fd.Function(V) + Jhat = ReducedFunctional(tlm(x), Control(x), tape=tape) +pause_annotation() + +mat_tlm = ReducedFunctionalMat( + Jhat, action_type='tlm') + +mat_adj = ReducedFunctionalMat( + Jhat, action_type='adjoint', + input_options={'riesz_representation': 'l2', 'function_space': V}) + +print("tlm action") + +# manual calculation +uin = fd.Function(V).project(expr) +uout = tlm(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(V).project(expr) +vout = Jhat.tlm(vin) +print(f"{type(vout) = }") +print(f"{vout.dat.data = }") + +# Mat action +win = fd.Function(V).project(expr) +wout = fd.Function(V.dual()) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_tlm.mult(x, y) +print(f"{type(wout) = }") +print(f"{wout.dat.data = }") + +print() + +print("adjoint action") + +# manual calculation +uin = fd.Function(V).project(expr) +uout = adj(uin) +print(f"{type(uout) = }") +print(f"{uout.dat.data = }") + +# re-evaluate RF +vin = fd.Function(V).project(expr) +vout = Jhat.derivative(adj_input=vin) +print(f"{type(vout) = }") +print(f"{vout.dat.data = }") + +# Mat action +win = fd.Function(V).project(expr) +wout = fd.Function(V) +with win.dat.vec_ro as x, wout.dat.vec_wo as y: + mat_adj.mult(x, y) + +print(f"{type(wout) = }") +print(f"{wout.dat.data = }")