Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 2233c47

Browse files
committed
multivariate halley and some tests
1 parent 77fcd76 commit 2233c47

File tree

2 files changed

+71
-27
lines changed

2 files changed

+71
-27
lines changed

src/halley.jl

Lines changed: 61 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -42,10 +42,28 @@ function SciMLBase.__solve(prob::NonlinearProblem,
4242
maxiters = 1000, kwargs...)
4343
f = Base.Fix2(prob.f, prob.p)
4444
x = float(prob.u0)
45-
fx = f(x)
46-
# fx = float(prob.u0)
47-
if !isa(fx, Number) || !isa(x, Number)
48-
error("Halley currently only supports scalar-valued single-variable functions")
45+
# Defining all derivative expressions in one place before the iterations
46+
if isa(x, AbstractArray)
47+
if alg_autodiff(alg)
48+
n = length(x)
49+
a_dfdx(x) = ForwardDiff.jacobian(f, x)
50+
a_d2fdx(x) = ForwardDiff.jacobian(a_dfdx, x)
51+
A = Array{Union{Nothing, Number}}(nothing, n, n)
52+
#fx = f(x)
53+
else
54+
n = length(x)
55+
f_dfdx(x) = FiniteDiff.finite_difference_jacobian(f, x, diff_type(alg), eltype(x))
56+
f_d2fdx(x) = FiniteDiff.finite_difference_jacobian(f_dfdx, x, diff_type(alg), eltype(x))
57+
A = Array{Union{Nothing, Number}}(nothing, n, n)
58+
end
59+
elseif isa(x, Number)
60+
if alg_autodiff(alg)
61+
sa_dfdx(x) = ForwardDiff.derivative(f, x)
62+
sa_d2fdx(x) = ForwardDiff.derivative(sa_dfdx, x)
63+
else
64+
sf_dfdx(x) = FiniteDiff.finite_difference_derivative(f, x, diff_type(alg), eltype(x))
65+
sf_d2fdx(x) = FiniteDiff.finite_difference_derivative(sf_dfdx, x, diff_type(alg), eltype(x))
66+
end
4967
end
5068
T = typeof(x)
5169

@@ -65,22 +83,49 @@ function SciMLBase.__solve(prob::NonlinearProblem,
6583

6684
for i in 1:maxiters
6785
if alg_autodiff(alg)
68-
fx = f(x)
69-
dfdx(x) = ForwardDiff.derivative(f, x)
70-
dfx = dfdx(x)
71-
d2fx = ForwardDiff.derivative(dfdx, x)
86+
if isa(x, Number)
87+
fx = f(x)
88+
dfx = sa_dfdx(x)
89+
d2fx = sa_d2fdx(x)
90+
else
91+
fx = f(x)
92+
dfx = a_dfdx(x)
93+
d2fx = reshape(a_d2fdx(x), (n,n,n)) # A 3-dim Hessian Tensor
94+
ai = -(dfx \ fx)
95+
for j in 1:n
96+
tmp = transpose(d2fx[:, :, j] * ai)
97+
A[j, :] = tmp
98+
end
99+
bi = (dfx) \ (A * ai)
100+
ci = (ai .* ai) ./ (ai .+ (0.5 .* bi))
101+
end
72102
else
73-
fx = f(x)
74-
dfx = FiniteDiff.finite_difference_derivative(f, x, diff_type(alg), eltype(x),
75-
fx)
76-
d2fx = FiniteDiff.finite_difference_derivative(x -> FiniteDiff.finite_difference_derivative(f,
77-
x),
78-
x, diff_type(alg), eltype(x), fx)
103+
if isa(x, Number)
104+
fx = f(x)
105+
dfx = sf_dfdx(x)
106+
d2fx = sf_d2fdx(x)
107+
else
108+
fx = f(x)
109+
dfx = f_dfdx(x)
110+
d2fx = reshape(f_d2fdx(x), (n,n,n)) # A 3-dim Hessian Tensor
111+
ai = -(dfx \ fx)
112+
for j in 1:n
113+
tmp = transpose(d2fx[:, :, j] * ai)
114+
A[j, :] = tmp
115+
end
116+
bi = (dfx) \ (A * ai)
117+
ci = (ai .* ai) ./ (ai .+ (0.5 .* bi))
118+
end
79119
end
80120
iszero(fx) &&
81121
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
82-
Δx = (2 * dfx^2 - fx * d2fx) \ (2fx * dfx)
83-
x -= Δx
122+
if isa(x, Number)
123+
Δx = (2 * dfx^2 - fx * d2fx) \ (2fx * dfx)
124+
x -= Δx
125+
else
126+
Δx = ci
127+
x += Δx
128+
end
84129
if isapprox(x, xo, atol = atol, rtol = rtol)
85130
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
86131
end

test/basictests.jl

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,10 @@ function benchmark_scalar(f, u0)
4949
sol = (solve(probN, Halley()))
5050
end
5151

52-
# function ff(u, p)
53-
# u .* u .- 2
54-
# end
55-
# const cu0 = @SVector[1.0, 1.0]
52+
function ff(u, p)
53+
u .* u .- 2
54+
end
55+
const cu0 = @SVector[1.0, 1.0]
5656
function sf(u, p)
5757
u * u - 2
5858
end
@@ -62,6 +62,10 @@ sol = benchmark_scalar(sf, csu0)
6262
@test sol.retcode === ReturnCode.Success
6363
@test sol.u * sol.u - 2 < 1e-9
6464

65+
sol = benchmark_scalar(ff, cu0)
66+
@test sol.retcode === ReturnCode.Success
67+
@test sol.u .* sol.u .- 2 < [1e-9, 1e-9]
68+
6569
if VERSION >= v"1.7"
6670
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
6771
end
@@ -122,7 +126,7 @@ using ForwardDiff
122126
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
123127

124128
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
125-
SimpleDFSane(), BROYDEN_SOLVERS...)
129+
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
126130
g = function (p)
127131
probN = NonlinearProblem{false}(f, csu0, p)
128132
sol = solve(probN, alg, abstol = 1e-9)
@@ -221,19 +225,14 @@ probN = NonlinearProblem(f, u0)
221225

222226
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
223227
SimpleTrustRegion(),
224-
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(),
228+
SimpleTrustRegion(; autodiff = false), Halley(), Halley(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(),
225229
BROYDEN_SOLVERS...)
226230
sol = solve(probN, alg)
227231

228232
@test sol.retcode == ReturnCode.Success
229233
@test sol.u[end] ≈ sqrt(2.0)
230234
end
231235

232-
# Separate Error check for Halley; will be included in above error checks for the improved Halley
233-
f, u0 = (u, p) -> u * u - 2.0, 1.0
234-
probN = NonlinearProblem(f, u0)
235-
236-
@test solve(probN, Halley()).u ≈ sqrt(2.0)
237236

238237
for u0 in [1.0, [1, 1.0]]
239238
local f, probN, sol

0 commit comments

Comments
 (0)