11module SimpleNonlinearSolveADLinearSolveExt
22
3- using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
3+ using AbstractDifferentiation,
4+ ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
45 SimpleNonlinearSolve, SciMLBase
5- import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace
6+ import SimpleNonlinearSolve: _construct_batched_problem_structure,
7+ _get_storage, _result_from_storage, _get_tolerance, @maybeinplace
68
79const AD = AbstractDifferentiation
810
@@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}()
2022 # TODO : Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl
2123 chunksize = SciMLBase. _unwrap_val (chunk_size) == 0 ? nothing : chunk_size
2224 ad = SciMLBase. _unwrap_val (autodiff) ?
23- AD. ForwardDiffBackend (; chunksize) :
24- AD. FiniteDifferencesBackend ()
25- return SimpleBatchedNewtonRaphson {typeof(ad), Nothing, typeof(termination_condition)} (
26- ad,
25+ AD. ForwardDiffBackend (; chunksize) :
26+ AD. FiniteDifferencesBackend ()
27+ return SimpleBatchedNewtonRaphson {typeof(ad), Nothing, typeof(termination_condition)} (ad,
2728 nothing ,
2829 termination_condition)
2930end
3031
3132function SciMLBase. __solve (prob:: NonlinearProblem ,
3233 alg:: SimpleBatchedNewtonRaphson ;
33- abstol= nothing ,
34- reltol= nothing ,
35- maxiters= 1000 ,
34+ abstol = nothing ,
35+ reltol = nothing ,
36+ maxiters = 1000 ,
3637 kwargs... )
3738 iip = isinplace (prob)
3839 @assert ! iip " SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems."
@@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem,
5758 alg,
5859 reconstruct (xₙ),
5960 reconstruct (fₙ);
60- retcode= ReturnCode. Success)
61+ retcode = ReturnCode. Success)
6162
62- solve (LinearProblem (𝓙, vec (fₙ); u0= vec (δx)), alg. linsolve; kwargs... )
63+ solve (LinearProblem (𝓙, vec (fₙ); u0 = vec (δx)), alg. linsolve; kwargs... )
6364 xₙ .- = δx
6465
6566 if termination_condition (fₙ, xₙ, xₙ₋₁, atol, rtol)
@@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
8384 alg,
8485 reconstruct (xₙ),
8586 reconstruct (fₙ);
86- retcode= ReturnCode. MaxIters)
87+ retcode = ReturnCode. MaxIters)
8788end
8889
8990end
0 commit comments