Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 9cb861a

Browse files
committed
Update batched Raphson to use same parameters as unbatched
1 parent 534e1db commit 9cb861a

File tree

10 files changed

+98
-120
lines changed

10 files changed

+98
-120
lines changed

Project.toml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,16 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1616
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1717

1818
[weakdeps]
19-
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
20-
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2119
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2220

2321
[extensions]
24-
SimpleNonlinearSolveADLinearSolveExt = ["AbstractDifferentiation", "LinearSolve"]
2522
SimpleNonlinearSolveNNlibExt = "NNlib"
2623

2724
[compat]
28-
AbstractDifferentiation = "0.5"
2925
ArrayInterface = "6, 7"
3026
DiffEqBase = "6.126"
3127
FiniteDiff = "2"
3228
ForwardDiff = "0.10.3"
33-
LinearSolve = "2"
3429
NNlib = "0.8, 0.9"
3530
PackageExtensionCompat = "1"
3631
PrecompileTools = "1"
@@ -40,6 +35,4 @@ StaticArraysCore = "1.4"
4035
julia = "1.6"
4136

4237
[extras]
43-
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
44-
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
4538
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

ext/SimpleNonlinearSolveADLinearSolveExt.jl

Lines changed: 0 additions & 90 deletions
This file was deleted.

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ function __init__()
1515
@require_extensions
1616
end
1717

18-
const ADLinearSolveExtLoaded = Ref{Bool}(false)
1918
const NNlibExtLoaded = Ref{Bool}(false)
2019

2120
abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
@@ -78,6 +77,6 @@ end
7877
# DiffEq styled algorithms
7978
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
8079
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
81-
export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane
80+
export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane
8281

8382
end # module

src/batched/dfsane.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
1+
Base.@kwdef struct BatchedSimpleDFSane{T, F, TC <: NLSolveTerminationCondition} <:
22
AbstractBatchedNonlinearSolveAlgorithm
33
σₘᵢₙ::T = 1.0f-10
44
σₘₐₓ::T = 1.0f+10
@@ -16,7 +16,7 @@ Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition}
1616
end
1717

1818
function SciMLBase.__solve(prob::NonlinearProblem,
19-
alg::SimpleBatchedDFSane,
19+
alg::BatchedSimpleDFSane,
2020
args...;
2121
abstol = nothing,
2222
reltol = nothing,

src/batched/raphson.jl

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,77 @@
1-
struct SimpleBatchedNewtonRaphson{AD, LS, TC <: NLSolveTerminationCondition} <:
1+
struct BatchedSimpleNewtonRaphson{CS, AD, FDT, TC <: NLSolveTerminationCondition} <:
22
AbstractBatchedNonlinearSolveAlgorithm
3-
autodiff::AD
4-
linsolve::LS
53
termination_condition::TC
64
end
75

8-
# Implementation of solve using Package Extensions
6+
alg_autodiff(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} = AD
7+
diff_type(alg::BatchedSimpleNewtonRaphson{CS, AD, FDT}) where {CS, AD, FDT} = FDT
8+
9+
function BatchedSimpleNewtonRaphson(; chunk_size = Val{0}(),
10+
autodiff = Val{true}(),
11+
diff_type = Val{:forward},
12+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
13+
abstol = nothing,
14+
reltol = nothing))
15+
return BatchedSimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size),
16+
SciMLBase._unwrap_val(autodiff),
17+
SciMLBase._unwrap_val(diff_type), typeof(termination_condition)}(termination_condition)
18+
end
19+
20+
function SciMLBase.__solve(prob::NonlinearProblem, alg::BatchedSimpleNewtonRaphson;
21+
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
22+
iip = SciMLBase.isinplace(prob)
23+
@assert !iip "BatchedSimpleNewtonRaphson currently only supports out-of-place nonlinear problems."
24+
u, f, reconstruct = _construct_batched_problem_structure(prob)
25+
26+
tc = alg.termination_condition
27+
mode = DiffEqBase.get_termination_mode(tc)
28+
29+
storage = _get_storage(mode, u)
30+
31+
xₙ, xₙ₋₁ = copy(u), copy(u)
32+
T = eltype(u)
33+
34+
atol = _get_tolerance(abstol, tc.abstol, T)
35+
rtol = _get_tolerance(reltol, tc.reltol, T)
36+
termination_condition = tc(storage)
37+
38+
for i in 1:maxiters
39+
if alg_autodiff(alg)
40+
fₙ, 𝓙 = value_derivative(f, xₙ)
41+
else
42+
fₙ = f(xₙ)
43+
𝓙 = FiniteDiff.finite_difference_jacobian(f, xₙ, diff_type(alg), eltype(xₙ), fₙ)
44+
end
45+
46+
iszero(fₙ) && return DiffEqBase.build_solution(prob,
47+
alg,
48+
reconstruct(xₙ),
49+
reconstruct(fₙ);
50+
retcode = ReturnCode.Success)
51+
52+
δx = reshape(𝓙 \ vec(fₙ), size(xₙ))
53+
xₙ .-= δx
54+
55+
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
56+
retcode, xₙ, fₙ = _result_from_storage(storage, xₙ, fₙ, f, mode, iip)
57+
return DiffEqBase.build_solution(prob,
58+
alg,
59+
reconstruct(xₙ),
60+
reconstruct(fₙ);
61+
retcode)
62+
end
63+
64+
xₙ₋₁ .= xₙ
65+
end
66+
67+
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
68+
xₙ = storage.u
69+
fₙ = f(xₙ)
70+
end
71+
72+
return DiffEqBase.build_solution(prob,
73+
alg,
74+
reconstruct(xₙ),
75+
reconstruct(fₙ);
76+
retcode = ReturnCode.MaxIters)
77+
end

src/dfsane.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
7373
batched::Bool = false,
7474
max_inner_iterations = 1000)
7575
if batched
76-
return SimpleBatchedDFSane(; σₘᵢₙ = σ_min,
76+
return BatchedSimpleDFSane(; σₘᵢₙ = σ_min,
7777
σₘₐₓ = σ_max,
7878
σ₁ = σ_1,
7979
M,

src/raphson.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
SimpleNewtonRaphson(; batched = false,
33
chunk_size = Val{0}(),
44
autodiff = Val{true}(),
5-
diff_type = Val{:forward})
5+
diff_type = Val{:forward},
6+
termination_condition = missing)
67
78
A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar
89
and static array problems.
@@ -27,11 +28,8 @@ and static array problems.
2728
- `diff_type`: the type of finite differencing used if `autodiff = false`. Defaults to
2829
`Val{:forward}` for forward finite differences. For more details on the choices, see the
2930
[FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) documentation.
30-
31-
!!! note
32-
33-
To use the `batched` version, remember to load `AbstractDifferentiation` and
34-
`LinearSolve`.
31+
- `termination_condition`: control the termination of the algorithm. (Only works for batched
32+
problems)
3533
"""
3634
struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end
3735

@@ -44,16 +42,19 @@ function SimpleNewtonRaphson(; batched = false,
4442
throw(ArgumentError("`termination_condition` is currently only supported for batched problems"))
4543
end
4644
if batched
47-
@assert ADLinearSolveExtLoaded[] "Please install and load `LinearSolve.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson."
45+
# @assert ADLinearSolveFDExtLoaded[] "Please install and load `LinearSolve.jl`, `FiniteDifferences.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson."
4846
termination_condition = ismissing(termination_condition) ?
4947
NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
5048
abstol = nothing,
5149
reltol = nothing) :
5250
termination_condition
53-
return SimpleBatchedNewtonRaphson(; chunk_size,
51+
return BatchedSimpleNewtonRaphson(; chunk_size,
5452
autodiff,
5553
diff_type,
5654
termination_condition)
55+
return SimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size),
56+
SciMLBase._unwrap_val(autodiff),
57+
SciMLBase._unwrap_val(diff_type)}()
5758
end
5859
return SimpleNewtonRaphson{SciMLBase._unwrap_val(chunk_size),
5960
SciMLBase._unwrap_val(autodiff),

test/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
[deps]
2-
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
32
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
43
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
54
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
65
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7-
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
86
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
97
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
108
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

test/basictests.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test,
2-
NNlib, AbstractDifferentiation, LinearSolve
2+
NNlib
33

44
const BATCHED_BROYDEN_SOLVERS = []
55
const BROYDEN_SOLVERS = []
66
const BATCHED_LBROYDEN_SOLVERS = []
77
const LBROYDEN_SOLVERS = []
88
const BATCHED_DFSANE_SOLVERS = []
99
const DFSANE_SOLVERS = []
10+
const BATCHED_RAPHSON_SOLVERS = []
1011

1112
for mode in instances(NLSolveTerminationMode.T)
1213
if mode
@@ -23,6 +24,12 @@ for mode in instances(NLSolveTerminationMode.T)
2324
push!(BATCHED_LBROYDEN_SOLVERS, LBroyden(; batched = true, termination_condition))
2425
push!(DFSANE_SOLVERS, SimpleDFSane(; batched = false, termination_condition))
2526
push!(BATCHED_DFSANE_SOLVERS, SimpleDFSane(; batched = true, termination_condition))
27+
push!(BATCHED_RAPHSON_SOLVERS,
28+
SimpleNewtonRaphson(; batched = true,
29+
termination_condition))
30+
push!(BATCHED_RAPHSON_SOLVERS,
31+
SimpleNewtonRaphson(; batched = true, autodiff = false,
32+
termination_condition))
2633
end
2734

2835
# SimpleNewtonRaphson
@@ -483,7 +490,8 @@ sol = solve(probN, Broyden(batched = true))
483490

484491
@testset "Batched Solver: $(nameof(typeof(alg)))" for alg in (BATCHED_BROYDEN_SOLVERS...,
485492
BATCHED_LBROYDEN_SOLVERS...,
486-
BATCHED_DFSANE_SOLVERS...)
493+
BATCHED_DFSANE_SOLVERS...,
494+
BATCHED_RAPHSON_SOLVERS...)
487495
sol = solve(probN, alg; abstol = 1e-3, reltol = 1e-3)
488496

489497
@test sol.retcode == ReturnCode.Success

test/inplace.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test,
2-
NNlib, AbstractDifferentiation, LinearSolve
2+
NNlib
33

4-
# Supported Solvers: BatchedBroyden, SimpleBatchedDFSane
4+
# Supported Solvers: BatchedBroyden, BatchedSimpleDFSane
55
function f!(du::AbstractArray{<:Number, N},
66
u::AbstractArray{<:Number, N},
77
p::AbstractVector) where {N}

0 commit comments

Comments
 (0)