diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 31bfab6b2..018a9b8cd 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -76,9 +76,8 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw ∂_b = cache.partials_b xp_linsolve_rhs!(uu, ∂_A, ∂_b, cache) - + rhs_list = cache.rhs_list - cache.linear_cache.u .= cache.dual_u0_cache # We can reuse the linear cache, because the same factorization will work for the partials. for i in eachindex(rhs_list) @@ -177,7 +176,6 @@ function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray, partials) where {T, V, N, DT <: Dual{T, V, N}} # Direct in-place construction of dual numbers without temporary allocations n_partials = length(partials) - for i in eachindex(u, dual_u) # Extract partials for this element directly partial_vals = ntuple(Val(N)) do j @@ -272,6 +270,13 @@ function __dual_init( else rhs_list = nothing end + # Use b for restructuring if sizes match (square system), otherwise use u (non-square) + # This preserves ComponentArray structure from b when possible + dual_u_init = if length(non_partial_cache.u) == length(b) + ArrayInterface.restructure(b, zeros(dual_type, length(b))) + else + ArrayInterface.restructure(non_partial_cache.u, zeros(dual_type, length(non_partial_cache.u))) + end return DualLinearCache{dual_type}( non_partial_cache, @@ -281,13 +286,13 @@ function __dual_init( partials_A_list, partials_b_list, rhs_list, - similar(new_b), - similar(new_b), - similar(new_b), + similar(non_partial_cache.u), # Use u's size, not b's size + similar(non_partial_cache.u), # primal_u_cache + similar(new_b), # primal_b_cache true, # Cache is initially valid A, b, - ArrayInterface.restructure(b, zeros(dual_type, length(b))) + dual_u_init ) end @@ -300,6 +305,7 @@ function SciMLBase.solve!( ForwardDiff.Dual} primal_sol = linearsolve_forwarddiff_solve!( cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...) + dual_sol = linearsolve_dual_solution(getfield(cache, :linear_cache).u, getfield(cache, :rhs_list), cache) # For scalars, we still need to assign since cache.dual_u might not be pre-allocated