Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit cb3723f

Browse files
committed
Revert "LBroyden"
This reverts commit e905737.
1 parent e905737 commit cb3723f

File tree

6 files changed

+119
-210
lines changed

6 files changed

+119
-210
lines changed

ext/SimpleNonlinearSolveADLinearSolveExt.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
module SimpleNonlinearSolveADLinearSolveExt
22

3-
using AbstractDifferentiation,
4-
ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
3+
using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
54
SimpleNonlinearSolve, SciMLBase
6-
import SimpleNonlinearSolve: _construct_batched_problem_structure,
7-
_get_storage, _result_from_storage, _get_tolerance, @maybeinplace
5+
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace
86

97
const AD = AbstractDifferentiation
108

@@ -22,18 +20,19 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}()
2220
# TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl
2321
chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size
2422
ad = SciMLBase._unwrap_val(autodiff) ?
25-
AD.ForwardDiffBackend(; chunksize) :
26-
AD.FiniteDifferencesBackend()
27-
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad,
23+
AD.ForwardDiffBackend(; chunksize) :
24+
AD.FiniteDifferencesBackend()
25+
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(
26+
ad,
2827
nothing,
2928
termination_condition)
3029
end
3130

3231
function SciMLBase.__solve(prob::NonlinearProblem,
3332
alg::SimpleBatchedNewtonRaphson;
34-
abstol = nothing,
35-
reltol = nothing,
36-
maxiters = 1000,
33+
abstol=nothing,
34+
reltol=nothing,
35+
maxiters=1000,
3736
kwargs...)
3837
iip = isinplace(prob)
3938
@assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems."
@@ -58,9 +57,9 @@ function SciMLBase.__solve(prob::NonlinearProblem,
5857
alg,
5958
reconstruct(xₙ),
6059
reconstruct(fₙ);
61-
retcode = ReturnCode.Success)
60+
retcode=ReturnCode.Success)
6261

63-
solve(LinearProblem(𝓙, vec(fₙ); u0 = vec(δx)), alg.linsolve; kwargs...)
62+
solve(LinearProblem(𝓙, vec(fₙ); u0=vec(δx)), alg.linsolve; kwargs...)
6463
xₙ .-= δx
6564

6665
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
@@ -84,7 +83,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
8483
alg,
8584
reconstruct(xₙ),
8685
reconstruct(fₙ);
87-
retcode = ReturnCode.MaxIters)
86+
retcode=ReturnCode.MaxIters)
8887
end
8988

9089
end
Lines changed: 12 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
module SimpleNonlinearSolveNNlibExt
22

33
using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
4-
import SimpleNonlinearSolve: _construct_batched_problem_structure,
5-
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
4+
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
65

76
function __init__()
87
SimpleNonlinearSolve.NNlibExtLoaded[] = true
98
return
109
end
1110

12-
# Broyden's method
1311
@views function SciMLBase.__solve(prob::NonlinearProblem,
1412
alg::BatchedBroyden;
15-
abstol = nothing,
16-
reltol = nothing,
17-
maxiters = 1000,
13+
abstol=nothing,
14+
reltol=nothing,
15+
maxiters=1000,
1816
kwargs...)
1917
iip = isinplace(prob)
2018

@@ -26,7 +24,7 @@ end
2624

2725
storage = _get_storage(mode, u)
2826

29-
xₙ, xₙ₋₁, δxₙ, δf = ntuple(_ -> copy(u), 4)
27+
xₙ, xₙ₋₁, δx, δf = ntuple(_ -> copy(u), 4)
3028
T = eltype(u)
3129

3230
atol = _get_tolerance(abstol, tc.abstol, T)
@@ -43,16 +41,16 @@ end
4341
xₙ .= xₙ₋₁ .- 𝓙⁻¹f
4442

4543
@maybeinplace iip fₙ=f(xₙ)
46-
δxₙ .= xₙ .- xₙ₋₁
44+
δx .= xₙ .- xₙ₋₁
4745
δf .= fₙ .- fₙ₋₁
4846

4947
batched_mul!(reshape(𝓙⁻¹f, L, 1, N), 𝓙⁻¹, reshape(δf, L, 1, N))
50-
δxₙᵀ = reshape(δxₙ, 1, L, N)
48+
δxᵀ = reshape(δx, 1, L, N)
5149

52-
batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxₙᵀ, reshape(𝓙⁻¹f, L, 1, N))
53-
batched_mul!(xᵀ𝓙⁻¹, δxₙᵀ, 𝓙⁻¹)
54-
δxₙ .= (δxₙ .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
55-
batched_mul!(𝓙⁻¹, reshape(δxₙ, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))
50+
batched_mul!(reshape(xᵀ𝓙⁻¹δf, 1, 1, N), δxᵀ, reshape(𝓙⁻¹f, L, 1, N))
51+
batched_mul!(xᵀ𝓙⁻¹, δxᵀ, 𝓙⁻¹)
52+
δx .= (δx .- 𝓙⁻¹f) ./ (xᵀ𝓙⁻¹δf .+ T(1e-5))
53+
batched_mul!(𝓙⁻¹, reshape(δx, L, 1, N), xᵀ𝓙⁻¹, one(T), one(T))
5654

5755
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
5856
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
@@ -76,103 +74,7 @@ end
7674
alg,
7775
reconstruct(xₙ),
7876
reconstruct(fₙ);
79-
retcode = ReturnCode.MaxIters)
80-
end
81-
82-
# Limited Memory Broyden's method
83-
@views function SciMLBase.__solve(prob::NonlinearProblem,
84-
alg::BatchedLBroyden;
85-
abstol = nothing,
86-
reltol = nothing,
87-
maxiters = 1000,
88-
kwargs...)
89-
iip = isinplace(prob)
90-
91-
u, f, reconstruct = _construct_batched_problem_structure(prob)
92-
L, N = size(u)
93-
T = eltype(u)
94-
95-
tc = alg.termination_condition
96-
mode = DiffEqBase.get_termination_mode(tc)
97-
98-
storage = _get_storage(mode, u)
99-
100-
η = min(maxiters, alg.threshold)
101-
U = fill!(similar(u, (η, L, N)), zero(T))
102-
Vᵀ = fill!(similar(u, (L, η, N)), zero(T))
103-
104-
xₙ, xₙ₋₁, δfₙ = ntuple(_ -> copy(u), 3)
105-
106-
atol = _get_tolerance(abstol, tc.abstol, T)
107-
rtol = _get_tolerance(reltol, tc.reltol, T)
108-
termination_condition = tc(storage)
109-
110-
@maybeinplace iip fₙ₋₁=f(xₙ) u
111-
iip && (fₙ = copy(fₙ₋₁))
112-
δxₙ = -copy(fₙ₋₁)
113-
ηNx = similar(xₙ, η, N)
114-
115-
for i in 1:maxiters
116-
@. xₙ = xₙ₋₁ - δxₙ
117-
@maybeinplace iip fₙ=f(xₙ)
118-
@. δxₙ = xₙ - xₙ₋₁
119-
@. δfₙ = fₙ - fₙ₋₁
120-
121-
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
122-
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
123-
return DiffEqBase.build_solution(prob,
124-
alg,
125-
reconstruct(xₙ),
126-
reconstruct(fₙ);
127-
retcode)
128-
end
129-
130-
_L = min(i, η)
131-
_U = U[1:_L, :, :]
132-
_Vᵀ = Vᵀ[:, 1:_L, :]
133-
134-
idx = mod1(i, η)
135-
136-
if i > 1
137-
partial_ηNx = ηNx[1:_L, :]
138-
139-
_ηNx = reshape(partial_ηNx, 1, :, N)
140-
batched_mul!(_ηNx, reshape(δxₙ, 1, L, N), _Vᵀ)
141-
batched_mul!(Vᵀ[:, idx:idx, :], _ηNx, _U)
142-
Vᵀ[:, idx, :] .-= δxₙ
143-
144-
_ηNx = reshape(partial_ηNx, :, 1, N)
145-
batched_mul!(_ηNx, _U, reshape(δfₙ, L, 1, N))
146-
batched_mul!(U[idx:idx, :, :], _Vᵀ, _ηNx)
147-
U[idx, :, :] .-= δfₙ
148-
else
149-
Vᵀ[:, idx, :] .= -δxₙ
150-
U[idx, :, :] .= -δfₙ
151-
end
152-
153-
U[idx, :, :] .= (δxₙ .- U[idx, :, :]) ./
154-
(sum(Vᵀ[:, idx, :] .* δfₙ; dims = 1) .+
155-
convert(T, 1e-5))
156-
157-
_L = min(i + 1, η)
158-
_ηNx = reshape(ηNx[1:_L, :], :, 1, N)
159-
batched_mul!(_ηNx, U[1:_L, :, :], reshape(δfₙ, L, 1, N))
160-
batched_mul!(reshape(δxₙ, L, 1, N), Vᵀ[:, 1:_L, :], _ηNx)
161-
162-
xₙ₋₁ .= xₙ
163-
fₙ₋₁ .= fₙ
164-
end
165-
166-
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
167-
xₙ = storage.u
168-
@maybeinplace iip fₙ=f(xₙ)
169-
end
170-
171-
return DiffEqBase.build_solution(prob,
172-
alg,
173-
reconstruct(xₙ),
174-
reconstruct(fₙ);
175-
retcode = ReturnCode.MaxIters)
77+
retcode=ReturnCode.MaxIters)
17678
end
17779

17880
end

src/batched/dfsane.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
1+
@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
22
AbstractBatchedNonlinearSolveAlgorithm
33
σₘᵢₙ::T = 1.0f-10
44
σₘₐₓ::T = 1.0f+10

src/batched/lbroyden.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +0,0 @@
1-
struct BatchedLBroyden{TC <: NLSolveTerminationCondition} <:
2-
AbstractBatchedNonlinearSolveAlgorithm
3-
termination_condition::TC
4-
threshold::Int
5-
end
6-
7-
# Implementation of solve using Package Extensions

src/broyden.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ end
3030

3131
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
3232
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
33-
if SciMLBase.isinplace(prob)
34-
error("Broyden currently only supports out-of-place nonlinear problems")
35-
end
3633
tc = alg.termination_condition
3734
mode = DiffEqBase.get_termination_mode(tc)
3835
f = Base.Fix2(prob.f, prob.p)
@@ -42,14 +39,19 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...;
4239
T = eltype(x)
4340
J⁻¹ = init_J(x)
4441

42+
if SciMLBase.isinplace(prob)
43+
error("Broyden currently only supports out-of-place nonlinear problems")
44+
end
45+
4546
atol = _get_tolerance(abstol, tc.abstol, T)
4647
rtol = _get_tolerance(reltol, tc.reltol, T)
4748

4849
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
4950
error("Broyden currently doesn't support SAFE_BEST termination modes")
5051
end
5152

52-
storage = _get_storage(mode, x)
53+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
54+
nothing
5355
termination_condition = tc(storage)
5456

5557
xₙ = x

0 commit comments

Comments
 (0)