|
1 | 1 | module SimpleNonlinearSolveNNlibExt |
2 | 2 |
|
3 | 3 | using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase |
4 | | -import SimpleNonlinearSolve: _construct_batched_problem_structure, |
5 | | - _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace |
| 4 | +import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace |
6 | 5 |
|
7 | 6 | function __init__() |
8 | 7 | SimpleNonlinearSolve.NNlibExtLoaded[] = true |
9 | 8 | return |
10 | 9 | end |
11 | 10 |
|
12 | | -# Broyden's method |
13 | 11 | @views function SciMLBase.__solve(prob::NonlinearProblem, |
14 | 12 | alg::BatchedBroyden; |
15 | | - abstol = nothing, |
16 | | - reltol = nothing, |
17 | | - maxiters = 1000, |
| 13 | + abstol=nothing, |
| 14 | + reltol=nothing, |
| 15 | + maxiters=1000, |
18 | 16 | kwargs...) |
19 | 17 | iip = isinplace(prob) |
20 | 18 |
|
|
26 | 24 |
|
27 | 25 | storage = _get_storage(mode, u) |
28 | 26 |
|
29 | | - xₙ, xₙ₋₁, δxₙ, δf = ntuple(_ -> copy(u), 4) |
| 27 | + xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4) |
30 | 28 | T = eltype(u) |
31 | 29 |
|
32 | 30 | atol = _get_tolerance(abstol, tc.abstol, T) |
|
43 | 41 | xₙ .= xₙ₋₁ .- 𝓙⁻¹f |
44 | 42 |
|
45 | 43 | @maybeinplace iip fₙ=f(xₙ) |
46 | | - δxₙ .= xₙ .- xₙ₋₁ |
| 44 | + δx .= xₙ .- xₙ₋₁ |
47 | 45 | δf .= fₙ .- fₙ₋₁ |
48 | 46 |
|
49 | 47 | batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N)) |
50 | | - δxₙᵀ = reshape(δxₙ, 1, L, N) |
| 48 | + δxᵀ = reshape(δx, 1, L, N) |
51 | 49 |
|
52 | | - batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxₙᵀ, reshape(𝓙⁻¹f, L, 1, N)) |
53 | | - batched_mul!(xᵀ𝓙⁻¹, δxₙᵀ, 𝓙⁻¹) |
54 | | - δxₙ .= (δxₙ .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5)) |
55 | | - batched_mul!(𝓙⁻¹, reshape(δxₙ, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T)) |
| 50 | + batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N)) |
| 51 | + batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹) |
| 52 | + δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5)) |
| 53 | + batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T)) |
56 | 54 |
|
57 | 55 | if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) |
58 | 56 | retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip) |
|
76 | 74 | alg, |
77 | 75 | reconstruct(xₙ), |
78 | 76 | reconstruct(fₙ); |
79 | | - retcode = ReturnCode.MaxIters) |
80 | | -end |
81 | | - |
82 | | -# Limited Memory Broyden's method |
83 | | -@views function SciMLBase.__solve(prob::NonlinearProblem, |
84 | | - alg::BatchedLBroyden; |
85 | | - abstol = nothing, |
86 | | - reltol = nothing, |
87 | | - maxiters = 1000, |
88 | | - kwargs...) |
89 | | - iip = isinplace(prob) |
90 | | - |
91 | | - u, f, reconstruct = _construct_batched_problem_structure(prob) |
92 | | - L, N = size(u) |
93 | | - T = eltype(u) |
94 | | - |
95 | | - tc = alg.termination_condition |
96 | | - mode = DiffEqBase.get_termination_mode(tc) |
97 | | - |
98 | | - storage = _get_storage(mode, u) |
99 | | - |
100 | | - η = min(maxiters, alg.threshold) |
101 | | - U = fill!(similar(u, (η, L, N)), zero(T)) |
102 | | - Vᵀ = fill!(similar(u, (L, η, N)), zero(T)) |
103 | | - |
104 | | - xₙ, xₙ₋₁, δfₙ = ntuple(_ -> copy(u), 3) |
105 | | - |
106 | | - atol = _get_tolerance(abstol, tc.abstol, T) |
107 | | - rtol = _get_tolerance(reltol, tc.reltol, T) |
108 | | - termination_condition = tc(storage) |
109 | | - |
110 | | - @maybeinplace iip fₙ₋₁=f(xₙ) u |
111 | | - iip && (fₙ = copy(fₙ₋₁)) |
112 | | - δxₙ = -copy(fₙ₋₁) |
113 | | - ηNx = similar(xₙ, η, N) |
114 | | - |
115 | | - for i in 1:maxiters |
116 | | - @. xₙ = xₙ₋₁ - δxₙ |
117 | | - @maybeinplace iip fₙ=f(xₙ) |
118 | | - @. δxₙ = xₙ - xₙ₋₁ |
119 | | - @. δfₙ = fₙ - fₙ₋₁ |
120 | | - |
121 | | - if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol) |
122 | | - retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip) |
123 | | - return DiffEqBase.build_solution(prob, |
124 | | - alg, |
125 | | - reconstruct(xₙ), |
126 | | - reconstruct(fₙ); |
127 | | - retcode) |
128 | | - end |
129 | | - |
130 | | - _L = min(i, η) |
131 | | - _U = U[1:_L, :, :] |
132 | | - _Vᵀ = Vᵀ[:, 1:_L, :] |
133 | | - |
134 | | - idx = mod1(i, η) |
135 | | - |
136 | | - if i > 1 |
137 | | - partial_ηNx = ηNx[1:_L, :] |
138 | | - |
139 | | - _ηNx = reshape(partial_ηNx, 1, :, N) |
140 | | - batched_mul!(_ηNx, reshape(δxₙ, 1, L, N), _Vᵀ) |
141 | | - batched_mul!(Vᵀ[:, idx:idx, :], _ηNx, _U) |
142 | | - Vᵀ[:, idx, :] .-= δxₙ |
143 | | - |
144 | | - _ηNx = reshape(partial_ηNx, :, 1, N) |
145 | | - batched_mul!(_ηNx, _U, reshape(δfₙ, L, 1, N)) |
146 | | - batched_mul!(U[idx:idx, :, :], _Vᵀ, _ηNx) |
147 | | - U[idx, :, :] .-= δfₙ |
148 | | - else |
149 | | - Vᵀ[:, idx, :] .= -δxₙ |
150 | | - U[idx, :, :] .= -δfₙ |
151 | | - end |
152 | | - |
153 | | - U[idx, :, :] .= (δxₙ .- U[idx, :, :]) ./ |
154 | | - (sum(Vᵀ[:, idx, :] .* δfₙ; dims = 1) .+ |
155 | | - convert(T, 1e-5)) |
156 | | - |
157 | | - _L = min(i + 1, η) |
158 | | - _ηNx = reshape(ηNx[1:_L, :], :, 1, N) |
159 | | - batched_mul!(_ηNx, U[1:_L, :, :], reshape(δfₙ, L, 1, N)) |
160 | | - batched_mul!(reshape(δxₙ, L, 1, N), Vᵀ[:, 1:_L, :], _ηNx) |
161 | | - |
162 | | - xₙ₋₁ .= xₙ |
163 | | - fₙ₋₁ .= fₙ |
164 | | - end |
165 | | - |
166 | | - if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES |
167 | | - xₙ = storage.u |
168 | | - @maybeinplace iip fₙ=f(xₙ) |
169 | | - end |
170 | | - |
171 | | - return DiffEqBase.build_solution(prob, |
172 | | - alg, |
173 | | - reconstruct(xₙ), |
174 | | - reconstruct(fₙ); |
175 | | - retcode = ReturnCode.MaxIters) |
| 77 | + retcode=ReturnCode.MaxIters) |
176 | 78 | end |
177 | 79 |
|
178 | 80 | end |
0 commit comments