diff --git a/src/integrating_GK_sum.jl b/src/integrating_GK_sum.jl index cacf524b..26735201 100644 --- a/src/integrating_GK_sum.jl +++ b/src/integrating_GK_sum.jl @@ -10,6 +10,7 @@ mutable struct SavingIntegrandGKSumAffect{ gk_step_cache::IntegrandCacheType gk_err_cache::IntegrandCacheType tol::Float64 + integrand_inplace::Union{Nothing, Bool} end function integrate_gk!( @@ -18,44 +19,39 @@ function integrate_gk!( ) affect!.gk_step_cache = recursive_zero!(affect!.gk_step_cache) affect!.gk_err_cache = recursive_zero!(affect!.gk_err_cache) + isinplace_prob = DiffEqBase.isinplace(integrator.sol.prob) + inplace_integrand = affect!.integrand_inplace === nothing ? + (isinplace_prob && affect!.integrand_cache !== nothing) : + affect!.integrand_inplace for i in 1:(2 * order + 1) t_temp = (gk_points[order][i] + 1) * ((bound_r - bound_l) / 2) + bound_l - if DiffEqBase.isinplace(integrator.sol.prob) + if isinplace_prob curu = first(get_tmp_cache(integrator)) integrator(curu, t_temp) - if affect!.integrand_cache == nothing - recursive_axpy!( - gk_weights[order][i], - affect!.integrand_func(curu, t_temp, integrator), affect!.gk_step_cache - ) - if i % 2 == 0 - recursive_axpy!( - g_weights[order][div(i, 2)], - affect!.integrand_func(curu, t_temp, integrator), affect!.gk_err_cache - ) - end - else - affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator) + else + curu = integrator(t_temp) + end + if inplace_integrand + affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator) + recursive_axpy!( + gk_weights[order][i], + affect!.integrand_cache, affect!.gk_step_cache + ) + if i % 2 == 0 recursive_axpy!( - gk_weights[order][i], - affect!.integrand_cache, affect!.gk_step_cache + g_weights[order][div(i, 2)], + affect!.integrand_cache, affect!.gk_err_cache ) - if i % 2 == 0 - recursive_axpy!( - g_weights[order][div(i, 2)], - affect!.integrand_cache, affect!.gk_err_cache - ) - end end else recursive_axpy!( gk_weights[order][i], - affect!.integrand_func(integrator(t_temp), t_temp, integrator), affect!.gk_step_cache + affect!.integrand_func(curu, t_temp, integrator), affect!.gk_step_cache ) if i % 2 == 0 recursive_axpy!( g_weights[order][div(i, 2)], - affect!.integrand_func(integrator(t_temp), t_temp, integrator), affect!.gk_err_cache + affect!.integrand_func(curu, t_temp, integrator), affect!.gk_err_cache ) end end @@ -108,6 +104,18 @@ returns Integral(integrand_func(u(t),t)dt over the problem tspan. that `integrand_func` will output (or higher compatible type). - `integrand_prototype` is a prototype of the output from the integrand. +## Keyword Arguments + + - `integrand_inplace = nothing`: controls which form of `integrand_func` is called. + With the default `nothing`, the in-place `integrand_func(out, u, t, integrator)` + form is used for in-place problems (when an `integrand_prototype` is given) and + the allocating `integrand_func(u, t, integrator)` form for out-of-place problems. + Pass `integrand_inplace = true` to force the in-place form even for out-of-place + problems — the integrand output (e.g. a parameter-shaped buffer) may be mutable + even when the state is immutable, which avoids allocating the output on every + quadrature node. Requires an `integrand_prototype`. Pass `integrand_inplace = false` + to force the allocating form. + The outputted values are saved into `integrand_values`. The values are found via `integrand_values.integrand`. @@ -119,11 +127,21 @@ via `integrand_values.integrand`. solvers are required. """ function IntegratingGKSumCallback( - integrand_func, integrand_values::IntegrandValuesSum, integrand_prototype, tol = 1.0e-7 + integrand_func, integrand_values::IntegrandValuesSum, integrand_prototype, + tol = 1.0e-7; + integrand_inplace::Union{Nothing, Bool} = nothing ) + if integrand_inplace === true && integrand_prototype === nothing + throw( + ArgumentError( + "integrand_inplace = true requires an integrand_prototype to use as the output buffer." + ) + ) + end affect! = SavingIntegrandGKSumAffect( integrand_func, integrand_values, integrand_prototype, - allocate_zeros(integrand_prototype), allocate_zeros(integrand_prototype), allocate_zeros(integrand_prototype), tol + allocate_zeros(integrand_prototype), allocate_zeros(integrand_prototype), + allocate_zeros(integrand_prototype), tol, integrand_inplace ) condition = true_condition return DiscreteCallback(condition, affect!, save_positions = (false, false)) diff --git a/src/integrating_sum.jl b/src/integrating_sum.jl index 07b57df2..d69b2266 100644 --- a/src/integrating_sum.jl +++ b/src/integrating_sum.jl @@ -69,6 +69,15 @@ mutable struct SavingIntegrandSumAffect{IntegrandFunc, integrandType, IntegrandC integrand_values::IntegrandValuesSum{integrandType} integrand_cache::IntegrandCacheType accumulation_cache::IntegrandCacheType + integrand_inplace::Union{Nothing, Bool} +end + +function SavingIntegrandSumAffect( + integrand_func, integrand_values, integrand_cache, accumulation_cache + ) + return SavingIntegrandSumAffect( + integrand_func, integrand_values, integrand_cache, accumulation_cache, nothing + ) end function (affect!::SavingIntegrandSumAffect)(integrator) @@ -79,26 +88,27 @@ function (affect!::SavingIntegrandSumAffect)(integrator) n = div(SciMLBase.alg_order(integrator.alg) + 1, 2) end accumulation_cache = recursive_zero!(affect!.accumulation_cache) + isinplace_prob = DiffEqBase.isinplace(integrator.sol.prob) + inplace_integrand = affect!.integrand_inplace === nothing ? + (isinplace_prob && affect!.integrand_cache !== nothing) : + affect!.integrand_inplace for i in 1:n t_temp = ((integrator.t - integrator.tprev) / 2) * gauss_points[n][i] + (integrator.t + integrator.tprev) / 2 - if DiffEqBase.isinplace(integrator.sol.prob) + if isinplace_prob curu = first(get_tmp_cache(integrator)) integrator(curu, t_temp) - if affect!.integrand_cache === nothing - recursive_axpy!( - gauss_weights[n][i], - affect!.integrand_func(curu, t_temp, integrator), accumulation_cache - ) - else - affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator) - recursive_axpy!( - gauss_weights[n][i], affect!.integrand_cache, accumulation_cache - ) - end + else + curu = integrator(t_temp) + end + if inplace_integrand + affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator) + recursive_axpy!( + gauss_weights[n][i], affect!.integrand_cache, accumulation_cache + ) else recursive_axpy!( gauss_weights[n][i], - affect!.integrand_func(integrator(t_temp), t_temp, integrator), accumulation_cache + affect!.integrand_func(curu, t_temp, integrator), accumulation_cache ) end end @@ -128,6 +138,18 @@ returns Integral(integrand_func(u(t),t)dt over the problem tspan. that `integrand_func` will output (or higher compatible type). - `integrand_prototype` is a prototype of the output from the integrand. +## Keyword Arguments + + - `integrand_inplace = nothing`: controls which form of `integrand_func` is called. + With the default `nothing`, the in-place `integrand_func(out, u, t, integrator)` + form is used for in-place problems (when an `integrand_prototype` is given) and + the allocating `integrand_func(u, t, integrator)` form for out-of-place problems. + Pass `integrand_inplace = true` to force the in-place form even for out-of-place + problems — the integrand output (e.g. a parameter-shaped buffer) may be mutable + even when the state is immutable, which avoids allocating the output on every + quadrature node. Requires an `integrand_prototype`. Pass `integrand_inplace = false` + to force the allocating form. + The outputted values are saved into `integrand_values`. The values are found via `integrand_values.integrand`. @@ -137,11 +159,19 @@ via `integrand_values.integrand`. solvers are required. """ function IntegratingSumCallback( - integrand_func, integrand_values::IntegrandValuesSum, integrand_prototype + integrand_func, integrand_values::IntegrandValuesSum, integrand_prototype; + integrand_inplace::Union{Nothing, Bool} = nothing ) + if integrand_inplace === true && integrand_prototype === nothing + throw( + ArgumentError( + "integrand_inplace = true requires an integrand_prototype to use as the output buffer." + ) + ) + end affect! = SavingIntegrandSumAffect( integrand_func, integrand_values, integrand_prototype, - allocate_zeros(integrand_prototype) + allocate_zeros(integrand_prototype), integrand_inplace ) condition = true_condition return DiscreteCallback(condition, affect!, save_positions = (false, false)) diff --git a/test/integrating_GK_sum_tests.jl b/test/integrating_GK_sum_tests.jl index adcd67c4..3e405695 100644 --- a/test/integrating_GK_sum_tests.jl +++ b/test/integrating_GK_sum_tests.jl @@ -80,3 +80,22 @@ dGdp_analytical = analytical_derivative(p, tspan[end]) @test isapprox( dGdp_analytical, integrand_values_inplace.integrand, atol = 1.0e-11, rtol = 1.0e-11 ) + +# integrand_inplace = true: in-place integrand with an out-of-place problem +using StaticArrays +prob_oop = ODEProblem((u, p, t) -> SVector(1.0), SVector(0.0), (0.0, 1.0)) +integrated = IntegrandValuesSum(zeros(1)) +sol = solve( + prob_oop, Euler(), + callback = IntegratingGKSumCallback( + (out, u, t, integrator) -> (out[1] = u[1]; nothing), integrated, Float64[0.0]; + integrand_inplace = true + ), + dt = 0.1 +) +@test integrated.integrand[1] == 0.5 + +@test_throws ArgumentError IntegratingGKSumCallback( + (out, u, t, integrator) -> nothing, IntegrandValuesSum(zeros(1)), nothing; + integrand_inplace = true +) diff --git a/test/integrating_sum_tests.jl b/test/integrating_sum_tests.jl index 5acc211e..093d7b54 100644 --- a/test/integrating_sum_tests.jl +++ b/test/integrating_sum_tests.jl @@ -20,3 +20,37 @@ sol = solve( dt = 0.1 ) @test integrated.integrand[1] == 0.5 + +# integrand_inplace = true: in-place integrand with an out-of-place problem +# (e.g. immutable state with a mutable, parameter-shaped integrand buffer) +using StaticArrays +prob_oop = ODEProblem((u, p, t) -> SVector(1.0), SVector(0.0), (0.0, 1.0)) +integrated = IntegrandValuesSum(zeros(1)) +sol = solve( + prob_oop, Euler(), + callback = IntegratingSumCallback( + (out, u, t, integrator) -> (out[1] = u[1]; nothing), integrated, Float64[0.0]; + integrand_inplace = true + ), + dt = 0.1 +) +@test integrated.integrand[1] == 0.5 + +# integrand_inplace = true requires a prototype +@test_throws ArgumentError IntegratingSumCallback( + (out, u, t, integrator) -> nothing, IntegrandValuesSum(zeros(1)), nothing; + integrand_inplace = true +) + +# integrand_inplace = false forces the allocating form even for in-place problems +prob_iip = ODEProblem((du, u, p, t) -> (du[1] = 1.0; nothing), [0.0], (0.0, 1.0)) +integrated = IntegrandValuesSum(zeros(1)) +sol = solve( + prob_iip, Euler(), + callback = IntegratingSumCallback( + (u, t, integrator) -> [u[1]], integrated, Float64[0.0]; + integrand_inplace = false + ), + dt = 0.1 +) +@test integrated.integrand[1] == 0.5