BUG: incorrect adjoint when using checkpoint scheduler #5082

@stephankramer

Consider the following code:

from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import SingleMemoryStorageSchedule
continue_annotation()
tape = get_working_tape()

schedule = SingleMemoryStorageSchedule()

tape.enable_checkpointing(schedule)

mesh = UnitSquareMesh(1,1)
V = FunctionSpace(mesh, "CG", 1)

m = Function(V).assign(1.0)

sumf = Function(V)
u = Function(V)
tst = TestFunction(V)
F = tst*u*dx - tst*m*m*dx

problem = NonlinearVariationalProblem(F, u)
solver = NonlinearVariationalSolver(problem)

for step in tape.timestepper(iter(range(4))):
    solver.solve()
    sumf.assign(sumf + u)

J = assemble(sumf*dx)
rf = ReducedFunctional(J, Control(m))

m0 = Function(V).assign(2.0)
print(rf(m0))
print(rf.derivative(apply_riesz=True).dat.data)

The taped model computes $$J=4m^2$$, so the derivative should be $$\frac{dJ}{dm} = 8m$$; for m=2 I expect the derivative to be 16. With the checkpointing schedule above, the forward (re)run with m=2 correctly produces $$J=16$$, but the derivative it computes is 8. It seems that the adjoint evaluation reuses the old control value (m=1) rather than the current one. If I comment out the enable_checkpointing line, I do get the expected result.
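
The expected numbers can be checked independently of Firedrake and the tape. The following is a minimal sketch, assuming the model reduces to a scalar problem: each of the four steps yields u = m², and integrating the constant sum over the unit square multiplies it by an area of 1. The names `J` and `dJ_dm` are hypothetical helpers for illustration and are not part of Firedrake or pyadjoint.

```python
def J(m, steps=4):
    """Scalar stand-in for the taped model (an assumption, not the Firedrake code).

    Each step "solves" u = m**2 and adds it to a running sum; integrating the
    constant result over the unit square multiplies by the area, which is 1.
    """
    total = 0.0
    for _ in range(steps):
        u = m * m      # solution of tst*u*dx - tst*m*m*dx = 0
        total += u     # sumf.assign(sumf + u)
    return total


def dJ_dm(m, h=1e-6):
    """Central finite-difference approximation of dJ/dm."""
    return (J(m + h) - J(m - h)) / (2 * h)


print(J(2.0))      # forward value at the new control m = 2
print(dJ_dm(2.0))  # approximately 16, the expected derivative at m = 2
print(dJ_dm(1.0))  # approximately 8, the derivative at the original control m = 1
```

The finite-difference derivative at m=1 matches the value of 8 reported with checkpointing enabled, which is consistent with the adjoint being evaluated at the original control value.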
