Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
∂_b = cache.partials_b

xp_linsolve_rhs!(uu, ∂_A, ∂_b, cache)

rhs_list = cache.rhs_list

cache.linear_cache.u .= cache.dual_u0_cache
# We can reuse the linear cache, because the same factorization will work for the partials.
for i in eachindex(rhs_list)
Expand Down Expand Up @@ -177,7 +176,6 @@ function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray,
partials) where {T, V, N, DT <: Dual{T, V, N}}
# Direct in-place construction of dual numbers without temporary allocations
n_partials = length(partials)

for i in eachindex(u, dual_u)
# Extract partials for this element directly
partial_vals = ntuple(Val(N)) do j
Expand Down Expand Up @@ -272,6 +270,13 @@ function __dual_init(
else
rhs_list = nothing
end
# Use b for restructuring if sizes match (square system), otherwise use u (non-square)
# This preserves ComponentArray structure from b when possible
dual_u_init = if length(non_partial_cache.u) == length(b)
ArrayInterface.restructure(b, zeros(dual_type, length(b)))
else
ArrayInterface.restructure(non_partial_cache.u, zeros(dual_type, length(non_partial_cache.u)))
end

return DualLinearCache{dual_type}(
non_partial_cache,
Expand All @@ -281,13 +286,13 @@ function __dual_init(
partials_A_list,
partials_b_list,
rhs_list,
similar(new_b),
similar(new_b),
similar(new_b),
similar(non_partial_cache.u), # Use u's size, not b's size
similar(non_partial_cache.u), # primal_u_cache
similar(new_b), # primal_b_cache
true, # Cache is initially valid
A,
b,
ArrayInterface.restructure(b, zeros(dual_type, length(b)))
dual_u_init
)
end

Expand All @@ -300,6 +305,7 @@ function SciMLBase.solve!(
ForwardDiff.Dual}
primal_sol = linearsolve_forwarddiff_solve!(
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)

dual_sol = linearsolve_dual_solution(getfield(cache, :linear_cache).u, getfield(cache, :rhs_list), cache)

# For scalars, we still need to assign since cache.dual_u might not be pre-allocated
Expand Down
Loading