
Commit 88ec3d5

Fix Tridiagonal cache mutation with non-allocating DirectLdiv! (issue #825)
The issue: DirectLdiv! for Tridiagonal matrices calls ldiv!(u, A, b), which performs an in-place LU factorization, mutating cache.A and breaking subsequent solves with the same cache.

The fix preserves DirectLdiv! for performance while preventing cache.A mutation:
- Add init_cacheval for DirectLdiv! with Tridiagonal/SymTridiagonal that allocates a copy of the matrix during init (a one-time allocation)
- Add specialized solve! methods that copy cache.A values into the cached workspace before calling ldiv! (non-allocating copyto!)
- The original cache.A is never touched, preserving it for subsequent solves

This gives the best of both worlds:
- Fast DirectLdiv! performance (direct ldiv! without LU factorization overhead)
- Non-destructive behavior (cache.A is preserved)
- Minimal allocations during solve! (the same 48 bytes as LUFactorization)

Added a regression test that verifies:
- The default algorithm for Tridiagonal is DirectLdiv! on Julia 1.11+
- cache.A is not mutated after solve!
- Multiple solves with the same cache give correct answers
- Minimal allocations during solve!

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c8cd7a9 commit 88ec3d5
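
For context, a minimal sketch of the failure mode this commit fixes. This reproduction script is not part of the commit; names and values are illustrative, mirroring the regression test added below. Before the fix, the second solve! used a cache.A that ldiv! had already overwritten with its LU factors; with the fix, both solves are correct and cache.A is preserved:

```julia
using LinearAlgebra, LinearSolve

k = 6
A = Tridiagonal(ones(k - 1), -2 * ones(k), ones(k - 1))
A_orig = copy(A)                  # keep pristine values for checking
b1, b2 = rand(k), rand(k)

cache = init(LinearProblem(A, b1))    # picks DirectLdiv! on Julia 1.11+
sol1 = solve!(cache)
@assert A_orig * sol1.u ≈ b1

cache.b = b2                      # reuse the cache with a new right-hand side
sol2 = solve!(cache)
@assert A_orig * sol2.u ≈ b2      # previously wrong: cache.A held LU factors
@assert cache.A ≈ A_orig          # with this commit, cache.A is untouched
```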

3 files changed: +125 −25 lines

src/default.jl

14 additions, 13 deletions

@@ -134,7 +134,7 @@ end
 function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
     if assump.issq
-        @static if VERSION>=v"1.11"
+        @static if VERSION >= v"1.11"
             DirectLdiv!()
         else
             DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)

@@ -239,10 +239,11 @@ Get the tuned algorithm preference for the given element type and matrix size.
 Returns `nothing` if no preference exists. Uses preloaded constants for efficiency.
 Fast path when no preferences are set.
 """
-@inline function get_tuned_algorithm(::Type{eltype_A}, ::Type{eltype_b}, matrix_size::Integer) where {eltype_A, eltype_b}
+@inline function get_tuned_algorithm(
+        ::Type{eltype_A}, ::Type{eltype_b}, matrix_size::Integer) where {eltype_A, eltype_b}
     # Determine the element type to use for preference lookup
     target_eltype = eltype_A !== Nothing ? eltype_A : eltype_b
-
+
     # Determine size category based on matrix size (matching LinearSolveAutotune categories)
     size_category = if matrix_size <= 20
         :tiny

@@ -255,10 +256,10 @@ Fast path when no preferences are set.
     else
         :big
     end
-
+
     # Fast path: if no preferences are set, return nothing immediately
     AUTOTUNE_PREFS_SET || return nothing
-
+
     # Look up the tuned algorithm from preloaded constants with type specialization
     return _get_tuned_algorithm_impl(target_eltype, size_category)
 end

@@ -286,11 +287,10 @@ end
 
 @inline _get_tuned_algorithm_impl(::Type, ::Symbol) = nothing # Fallback for other types
 
-
-
 # Convenience method for when A is nothing - delegate to main implementation
-@inline get_tuned_algorithm(::Type{Nothing}, ::Type{eltype_b}, matrix_size::Integer) where {eltype_b} =
-    get_tuned_algorithm(eltype_b, eltype_b, matrix_size)
+@inline get_tuned_algorithm(::Type{Nothing},
+        ::Type{eltype_b},
+        matrix_size::Integer) where {eltype_b} = get_tuned_algorithm(eltype_b, eltype_b, matrix_size)
 
 # Allows A === nothing as a stand-in for dense matrix
 function defaultalg(A, b, assump::OperatorAssumptions{Bool})

@@ -304,7 +304,7 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
             ArrayInterface.can_setindex(b) &&
             (__conditioning(assump) === OperatorCondition.IllConditioned ||
              __conditioning(assump) === OperatorCondition.WellConditioned)
-
+
     # Small matrix override - always use GenericLUFactorization for tiny problems
     if length(b) <= 10
         DefaultAlgorithmChoice.GenericLUFactorization

@@ -313,7 +313,7 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
         matrix_size = length(b)
         eltype_A = A === nothing ? Nothing : eltype(A)
         tuned_alg = get_tuned_algorithm(eltype_A, eltype(b), matrix_size)
-
+
         if tuned_alg !== nothing
             tuned_alg
         elseif appleaccelerate_isavailable() && b isa Array &&

@@ -513,7 +513,7 @@ end
     newex = quote
         sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
         if sol.retcode === ReturnCode.Failure && alg.safetyfallback
-            @SciMLMessage("LU factorization failed, falling back to QR factorization. `A` is potentially rank-deficient.",
+            @SciMLMessage("LU factorization failed, falling back to QR factorization. `A` is potentially rank-deficient.",
                 cache.verbose, :default_lu_fallback)
             sol = SciMLBase.solve!(
                 cache, QRFactorization(ColumnNorm()), args...; kwargs...)

@@ -641,7 +641,8 @@ end
 @generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
     ex = :()
     for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
-        newex = if alg in Symbol.((DefaultAlgorithmChoice.RFLUFactorization, DefaultAlgorithmChoice.GenericLUFactorization))
+        newex = if alg in Symbol.((DefaultAlgorithmChoice.RFLUFactorization,
+            DefaultAlgorithmChoice.GenericLUFactorization))
             quote
                 getproperty(cache.cacheval, $(Meta.quot(alg)))[1]' \ dy
             end
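
As a quick illustration of the version-gated branch above (a usage sketch, not part of the diff; defaultalg is an internal function, called here the same way the tests call it), the selected default can be inspected directly:

```julia
using LinearAlgebra, LinearSolve

A = Tridiagonal(ones(4), -2 * ones(5), ones(4))
b = rand(5)
alg = LinearSolve.defaultalg(A, b, OperatorAssumptions(true))
# On Julia 1.11+ this is DirectLdiv!(); on older releases it is
# DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization).
@show alg
```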

src/solve_function.jl

42 additions, 0 deletions

@@ -84,3 +84,45 @@ function SciMLBase.solve!(cache::LinearCache, alg::DirectLdiv!, args...; kwargs...)
 
     return SciMLBase.build_linear_solution(alg, u, nothing, cache)
 end
+
+# Specialized handling for Tridiagonal matrices to avoid mutating cache.A
+# ldiv! for Tridiagonal performs in-place LU factorization which would corrupt the cache.
+# We cache a copy of the Tridiagonal matrix and use that for the factorization.
+# See https://github.com/SciML/LinearSolve.jl/issues/825
+
+function init_cacheval(alg::DirectLdiv!, A::Tridiagonal, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Union{LinearVerbosity, Bool},
+        assumptions::OperatorAssumptions)
+    # Allocate a copy of the Tridiagonal matrix to use as workspace for ldiv!
+    return copy(A)
+end
+
+function init_cacheval(alg::DirectLdiv!, A::SymTridiagonal, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Union{LinearVerbosity, Bool},
+        assumptions::OperatorAssumptions)
+    # SymTridiagonal also gets mutated by ldiv!, cache a copy
+    return copy(A)
+end
+
+function SciMLBase.solve!(cache::LinearCache{<:Tridiagonal}, alg::DirectLdiv!,
+        args...; kwargs...)
+    (; A, b, u, cacheval) = cache
+    # Copy current A values into the cached workspace (non-allocating)
+    copyto!(cacheval.dl, A.dl)
+    copyto!(cacheval.d, A.d)
+    copyto!(cacheval.du, A.du)
+    # Perform ldiv! on the copy, preserving the original A
+    ldiv!(u, cacheval, b)
+    return SciMLBase.build_linear_solution(alg, u, nothing, cache)
+end
+
+function SciMLBase.solve!(cache::LinearCache{<:SymTridiagonal}, alg::DirectLdiv!,
+        args...; kwargs...)
+    (; A, b, u, cacheval) = cache
+    # Copy current A values into the cached workspace (non-allocating)
+    copyto!(cacheval.dv, A.dv)
+    copyto!(cacheval.ev, A.ev)
+    # Perform ldiv! on the copy, preserving the original A
+    ldiv!(u, cacheval, b)
+    return SciMLBase.build_linear_solution(alg, u, nothing, cache)
+end
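
The core trick in isolation, for readers skimming the diff: refresh a preallocated Tridiagonal workspace with copyto! (no allocation), then let ldiv! factorize the workspace in place while the original matrix stays intact. A minimal sketch assuming Julia 1.11+, where the direct ldiv! method for Tridiagonal that this commit relies on is available:

```julia
using LinearAlgebra

A  = Tridiagonal(rand(5), rand(6) .+ 4, rand(5))  # diagonally dominant
ws = copy(A)                      # one-time workspace (what init_cacheval stores)
u, b = zeros(6), rand(6)

# Per-solve path: refresh the workspace bands, then solve on the copy.
copyto!(ws.dl, A.dl)
copyto!(ws.d, A.d)
copyto!(ws.du, A.du)
ldiv!(u, ws, b)                   # mutates ws (in-place LU), never A

@assert A * u ≈ b                 # A still holds the original bands
```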

test/basictests.jl

69 additions, 12 deletions

@@ -34,7 +34,7 @@ A4 = A2 .|> ComplexF32
 b4 = b2 .|> ComplexF32
 x4 = x2 .|> ComplexF32
 
-A5_ = A - 0.01Tridiagonal(ones(n,n)) + sparse([1], [8], 0.5, n,n)
+A5_ = A - 0.01Tridiagonal(ones(n, n)) + sparse([1], [8], 0.5, n, n)
 A5 = sparse(transpose(A5_) * A5_)
 x5 = zeros(n)
 u5 = ones(n)

@@ -46,7 +46,7 @@ prob3 = LinearProblem(A3, b3; u0 = x3)
 prob4 = LinearProblem(A4, b4; u0 = x4)
 prob5 = LinearProblem(A5, b5)
 
-cache_kwargs = (;abstol = 1e-8, reltol = 1e-8, maxiter = 30)
+cache_kwargs = (; abstol = 1e-8, reltol = 1e-8, maxiter = 30)
 
 function test_interface(alg, prob1, prob2)
     A1, b1 = prob1.A, prob1.b

@@ -79,10 +79,10 @@ end
 
 function test_tolerance_update(alg, prob, u)
     cache = init(prob, alg)
-    LinearSolve.update_tolerances!(cache; reltol = 1e-2, abstol=1e-8)
+    LinearSolve.update_tolerances!(cache; reltol = 1e-2, abstol = 1e-8)
     u1 = copy(solve!(cache).u)
 
-    LinearSolve.update_tolerances!(cache; reltol = 1e-8, abstol=1e-8)
+    LinearSolve.update_tolerances!(cache; reltol = 1e-8, abstol = 1e-8)
     u2 = solve!(cache).u
 
     @test norm(u2 - u) < norm(u1 - u)

@@ -303,30 +303,86 @@ end
     ρ = 0.95
     A_tri = SymTridiagonal(ones(k) .+ ρ^2, -ρ * ones(k-1))
     b = rand(k)
-
+
     # Test with explicit LDLtFactorization
     prob_tri = LinearProblem(A_tri, b)
     sol = solve(prob_tri, LDLtFactorization())
     @test A_tri * sol.u ≈ b
-
+
     # Test that default algorithm uses LDLtFactorization for SymTridiagonal
     default_alg = LinearSolve.defaultalg(A_tri, b, OperatorAssumptions(true))
     @test default_alg isa LinearSolve.DefaultLinearSolver
     @test default_alg.alg == LinearSolve.DefaultAlgorithmChoice.LDLtFactorization
-
+
     # Test that the factorization is cached and reused
     cache = init(prob_tri, LDLtFactorization())
     sol1 = solve!(cache)
     @test A_tri * sol1.u ≈ b
     @test !cache.isfresh # Cache should not be fresh after first solve
-
+
     # Solve again with same matrix to ensure cache is reused
     cache.b = rand(k) # Change RHS
     sol2 = solve!(cache)
     @test A_tri * sol2.u ≈ cache.b
     @test !cache.isfresh # Cache should still not be fresh
 end
 
+@testset "Tridiagonal cache not mutated (issue #825)" begin
+    # Test that solving with Tridiagonal does not mutate cache.A
+    # See https://github.com/SciML/LinearSolve.jl/issues/825
+    k = 6
+    lower = ones(k - 1)
+    diag = -2 * ones(k)
+    upper = ones(k - 1)
+    A_tri = Tridiagonal(lower, diag, upper)
+    b = rand(k)
+
+    # Store original matrix values for comparison
+    A_orig = Tridiagonal(copy(lower), copy(diag), copy(upper))
+
+    # Test that default algorithm uses DirectLdiv! for Tridiagonal on Julia 1.11+
+    default_alg = LinearSolve.defaultalg(A_tri, b, OperatorAssumptions(true))
+    @static if VERSION >= v"1.11"
+        @test default_alg isa DirectLdiv!
+    else
+        @test default_alg isa LinearSolve.DefaultLinearSolver
+        @test default_alg.alg == LinearSolve.DefaultAlgorithmChoice.LUFactorization
+    end
+
+    # Test with default algorithm
+    prob_tri = LinearProblem(A_tri, b)
+    cache = init(prob_tri)
+
+    # Verify solution is correct
+    sol1 = solve!(cache)
+    @test A_orig * sol1.u ≈ b
+
+    # Verify cache.A is not mutated
+    @test cache.A ≈ A_orig
+
+    # Verify multiple solves give correct answers
+    b2 = rand(k)
+    cache.b = b2
+    sol2 = solve!(cache)
+    @test A_orig * sol2.u ≈ b2
+
+    # Cache.A should still be unchanged
+    @test cache.A ≈ A_orig
+
+    # Verify solve! allocates minimally after first solve (warm-up)
+    # The small allocation (48 bytes) is from the return type construction,
+    # same as other factorization methods like LUFactorization
+    @static if VERSION >= v"1.11"
+        # Warm up
+        for _ in 1:3
+            solve!(cache)
+        end
+        # Test minimal allocations (same as LUFactorization)
+        allocs = @allocated solve!(cache)
+        @test allocs <= 64 # Allow small overhead from return type
+    end
+end
+
 test_algs = [
     LUFactorization(),
     QRFactorization(),

@@ -680,8 +736,10 @@ end
 prob3 = LinearProblem(op1, b1; u0 = x1)
 prob4 = LinearProblem(op2, b2; u0 = x2)
 
-@test LinearSolve.defaultalg(op1, x1).alg === LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
-@test LinearSolve.defaultalg(op2, x2).alg === LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
+@test LinearSolve.defaultalg(op1, x1).alg ===
+      LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
+@test LinearSolve.defaultalg(op2, x2).alg ===
+      LinearSolve.DefaultAlgorithmChoice.DirectLdiv!
 @test LinearSolve.defaultalg(op3, x1).alg ===
       LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
 @test LinearSolve.defaultalg(op4, x2).alg ===

@@ -800,7 +858,6 @@ end
     reinit!(cache; A = B1, b = b1)
     u = solve!(cache)
     @test norm(u - u0, Inf) < 1.0e-8
-
 end
 
 @testset "ParallelSolves" begin

@@ -818,7 +875,7 @@ end
     for i in 1:2
         @test sol[i] ≈ U[i]
    end
-
+
    Threads.@threads for i in 1:2
        sol[i] = solve(LinearProblem(A_sparse, B[i]), KLUFactorization())
    end
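
The regression test covers only Tridiagonal; a hypothetical analogue for the SymTridiagonal path added in src/solve_function.jl might look like the following. This is not part of the commit's test suite; it assumes Julia 1.11+ and passes DirectLdiv! explicitly, since the default algorithm for SymTridiagonal is LDLtFactorization:

```julia
using LinearAlgebra, LinearSolve

k = 6
S = SymTridiagonal(2 * ones(k), -ones(k - 1))
S_orig = copy(S)
b = rand(k)

cache = init(LinearProblem(S, b), DirectLdiv!())  # explicit algorithm choice
sol = solve!(cache)
@assert S_orig * sol.u ≈ b
@assert cache.A ≈ S_orig          # cache.A preserved, as for Tridiagonal
```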
