Consider the following code:
```python
from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import SingleMemoryStorageSchedule

continue_annotation()
tape = get_working_tape()
schedule = SingleMemoryStorageSchedule()
tape.enable_checkpointing(schedule)

mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "CG", 1)
m = Function(V).assign(1.0)
sumf = Function(V)
u = Function(V)

# Each solve enforces u = m^2 (weak form: v*u*dx = v*m*m*dx).
tst = TestFunction(V)
F = tst*u*dx - tst*m*m*dx
problem = NonlinearVariationalProblem(F, u)
solver = NonlinearVariationalSolver(problem)

# Four timesteps accumulating u into sumf, so J = 4*m^2.
for step in tape.timestepper(iter(range(4))):
    solver.solve()
    sumf.assign(sumf + u)

J = assemble(sumf*dx)
rf = ReducedFunctional(J, Control(m))

m0 = Function(V).assign(2.0)
print(rf(m0))
print(rf.derivative(apply_riesz=True).dat.data)
```
The taped model computes $$J = 4m^2$$, so the derivative should be $$\frac{dJ}{dm} = 8m$$, and for m = 2 I therefore expect 16. With the checkpointing schedule above, however, even though the forward (re)run with m = 2 correctly produces $$J = 16$$, the computed derivative is 8: during the adjoint evaluation it appears to reuse the old control value (m = 1) rather than the current one. If I comment out the `enable_checkpointing` line, I do get the expected result.
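For reference, here is the arithmetic behind the expected values as a plain-Python sanity check (no Firedrake needed). The helper names `expected_J` and `expected_dJdm` are just illustrative, not part of any API:

```python
# Each timestep solves u = m^2, and sumf accumulates u over 4 steps,
# so on the unit square J = integral(sumf) = 4*m^2 and dJ/dm = 8*m.
def expected_J(m):
    return 4 * m**2

def expected_dJdm(m):
    return 8 * m

print(expected_J(2.0))     # 16.0 (matches the forward re-run)
print(expected_dJdm(2.0))  # 16.0 (what I expect)
print(expected_dJdm(1.0))  # 8.0  (what I actually get, i.e. the old control)
```

The fact that the reported derivative equals `expected_dJdm(1.0)`, the value at the originally taped control, is what suggests the adjoint sweep is not picking up the updated control.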