11# TODO : Preconditioners? Should Pl be transposed and made Pr and similar for Pr.
22
33@doc doc"""
4- LinearSolveAdjoint(; linsolve = nothing )
4+ LinearSolveAdjoint(; linsolve = missing )
55
66Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`` as:
77
@@ -18,53 +18,49 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
1818## Choice of Linear Solver
1919
2020Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
21- forward solve (this is done by keeping the linsolve as `nothing `). For example, if the
21+ forward solve (this is done by keeping the linsolve as `missing `). For example, if the
2222forward solve was performed via a Factorization, then we can reuse the factorization for the
2323adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
2424specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
2525"""
2626@kwdef struct LinearSolveAdjoint{L} < :
2727 SciMLBase. AbstractSensitivityAlgorithm{0 , false , :central }
28- linsolve:: L = nothing
28+ linsolve:: L = missing
2929end
3030
31- function CRC. rrule (:: typeof (SciMLBase. init), prob:: LinearProblem ,
32- alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
31+ function CRC. rrule (:: typeof (SciMLBase. solve), prob:: LinearProblem ,
32+ alg:: SciMLLinearSolveAlgorithm , args... ; alias_A = default_alias_A (
33+ alg, prob. A, prob. b), kwargs... )
34+ # sol = solve(prob, alg, args...; kwargs...)
3335 cache = init (prob, alg, args... ; kwargs... )
34- function ∇init (∂cache)
35- ∂∅ = NoTangent ()
36- ∂p = prob. p isa SciMLBase. NullParameters ? prob. p : ProjectTo (prob. p)(∂cache. p)
37- ∂prob = LinearProblem (∂cache. A, ∂cache. b, ∂p)
38- return (∂∅, ∂prob, ∂∅, ntuple (_ -> ∂∅, length (args))... )
39- end
40- return cache, ∇init
41- end
36+ (; A, sensealg) = cache
4237
43- function CRC. rrule (:: typeof (SciMLBase. solve!), cache:: LinearCache , alg, args... ;
44- kwargs... )
45- (; A, b, sensealg) = cache
38+ @assert sensealg isa LinearSolveAdjoint " Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
4639
4740 # Decide if we need to cache `A` and `b` for the reverse pass
48- if sensealg. linsolve === nothing
41+ if sensealg. linsolve === missing
4942 # We can reuse the factorization so no copy is needed
5043 # Krylov Methods don't modify `A`, so it's safe to just reuse it
5144 # No Copy is needed even for the default case
5245 if ! (alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
5346 alg isa DefaultLinearSolver)
54- A_ = cache . alias_A ? deepcopy (A) : A
47+ A_ = alias_A ? deepcopy (A) : A
5548 end
5649 else
57- error (" Not Implemented Yet!!!" )
50+ if alg isa DefaultLinearSolver
51+ A_ = deepcopy (A)
52+ else
53+ A_ = alias_A ? deepcopy (A) : A
54+ end
5855 end
5956
60- # Forward Solve
61- sol = solve! (cache, alg, args... ; kwargs... )
57+ sol = solve! (cache)
58+
59+ function ∇linear_solve (∂sol)
60+ ∂∅ = NoTangent ()
6261
63- function ∇solve! (∂sol)
64- @assert ! cache. isfresh " `cache.A` has been updated between the forward and the \
65- reverse pass. This is not supported."
6662 ∂u = ∂sol. u
67- if sensealg. linsolve === nothing
63+ if sensealg. linsolve === missing
6864 λ = if cache. cacheval isa Factorization
6965 cache. cacheval' \ ∂u
7066 elseif cache. cacheval isa Tuple && cache. cacheval[1 ] isa Factorization
@@ -79,25 +75,23 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearCache, alg, args...;
7975 solve (invprob, alg; cache. abstol, cache. reltol, cache. verbose). u
8076 end
8177 else
82- error (" Not Implemented Yet!!!" )
78+ invprob = LinearProblem (transpose (A_), ∂u) # We cached `A`
79+ λ = solve (
80+ invprob, sensealg. linsolve; cache. abstol, cache. reltol, cache. verbose). u
8381 end
8482
8583 ∂A = - λ * transpose (sol. u)
8684 ∂b = λ
87- ∂∅ = NoTangent ()
88-
89- ∂cache = LinearCache (∂A, ∂b, ∂∅, ∂∅, ∂∅, ∂∅, cache. isfresh, ∂∅, ∂∅, cache. abstol,
90- cache. reltol, cache. maxiters, cache. verbose, cache. assumptions, cache. sensealg)
85+ ∂prob = LinearProblem (∂A, ∂b, ∂∅)
9186
92- return (∂∅, ∂cache , ∂∅, ntuple (_ -> ∂∅, length (args))... )
87+ return (∂∅, ∂prob , ∂∅, ntuple (_ -> ∂∅, length (args))... )
9388 end
94- return sol, ∇solve!
89+
90+ return sol, ∇linear_solve
9591end
9692
9793function CRC. rrule (:: Type{<:LinearProblem} , A, b, p; kwargs... )
9894 prob = LinearProblem (A, b, p)
99- function ∇prob (∂prob)
100- return NoTangent (), ∂prob. A, ∂prob. b, ∂prob. p
101- end
95+ ∇prob (∂prob) = (NoTangent (), ∂prob. A, ∂prob. b, ∂prob. p)
10296 return prob, ∇prob
10397end
0 commit comments