Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/probints.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Read the integrator's scalar error estimate. Older integrators stored it as the
# `EEst` field directly; current OrdinaryDiffEqCore moved it onto the controller
# cache (`integrator.controller_cache.EEst`, exposed there as `get_EEst`). Reading
# it this way keeps DiffEqCallbacks free of an OrdinaryDiffEqCore dependency while
# supporting both integrator layouts.
function _integrator_EEst(integrator)
if hasproperty(integrator, :EEst)
return getproperty(integrator, :EEst)
else
return integrator.controller_cache.EEst
end
end

struct ProbIntsCache{T}
σ::T
order::Int
Expand Down Expand Up @@ -40,7 +53,7 @@ struct AdaptiveProbIntsCache
order::Int
end
function (p::AdaptiveProbIntsCache)(integrator)
return integrator.u .= integrator.u .+ integrator.EEst * sqrt(integrator.dt^(2 * p.order)) * randn(size(integrator.u))
return integrator.u .= integrator.u .+ _integrator_EEst(integrator) * sqrt(integrator.dt^(2 * p.order)) * randn(size(integrator.u))
end

"""
Expand Down
6 changes: 5 additions & 1 deletion test/AD/saving_tracker_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ u0 = TrackedArray([1.0f0, 0.0f0, 0.0f0])
tspan = TrackedArray([0.0f0, 1.0f0])
prob = ODEProblem{false}(rober, u0, tspan, p)
saved_values = SavedValues(eltype(tspan), eltype(p))
cb = SavingCallback((u, t, integrator) -> integrator.EEst * integrator.dt, saved_values)
# OrdinaryDiffEqCore's ODEIntegrator no longer has an `EEst` field; the scalar
# error estimate now lives on the controller cache.
eest(integrator) = hasproperty(integrator, :EEst) ? integrator.EEst :
integrator.controller_cache.EEst
cb = SavingCallback((u, t, integrator) -> eest(integrator) * integrator.dt, saved_values)

@test !all(
iszero.(
Expand Down
1 change: 1 addition & 0 deletions test/independentlylinearizedtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test, DiffEqCallbacks
using DiffEqCallbacks: sample, store!, IndependentlyLinearizedSolutionChunks, finish!
using SciMLBase: ReturnCode

@testset "IndependentlyLinearizedSolution" begin
ils = IndependentlyLinearizedSolution{Float64, Float64}(
Expand Down
134 changes: 130 additions & 4 deletions test/integrating_GK_sum_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,136 @@ sol = solve(

#### TESTING ON LINEAR SYSTEM WITH ANALYTICAL SOLUTION ####

# Reuse shared helper functions defined in integrating_GK_tests.jl:
# compute_dGdp, compute_dGdp_nt, simple_linear_system, adjoint_linear,
# adjoint_linear_inplace, analytical_derivative, callback_saving_linear,
# callback_saving_linear_inplace
function compute_dGdp(integrand)
temp = zeros(length(integrand.integrand), length(integrand.integrand[1]))
for i in 1:length(integrand.integrand)
for j in 1:length(integrand.integrand[1])
temp[i, j] = integrand.integrand[i][j]
end
end
return sum(temp, dims = 1)[:]
end

function simple_linear_system(u, p, t)
a, b = p
return [-a * u[2], b * u[1]]
end

function adjoint_linear(u, p, t, sol)
a, b = p
return -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

function adjoint_linear_inplace(du, u, p, t, sol)
a, b = p
return du .= -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

function analytical_derivative(p, t)
a, b = p
d1 = (
b * (
(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) / (b * sqrt(a / b)) +
(b * t * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4b * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
2(
(2b * t * sin(t * sqrt(a * b))) / sqrt(a * b) +
(-b * t * sin(2t * sqrt(a * b))) / sqrt(a * b)
) * sqrt(a / b)
) +
(b * t * (a + 3b)) / sqrt(a * b) +
(-a * b * t * cos(2t * sqrt(a * b))) / sqrt(a * b) + 3 / sqrt(a / b) +
2t * sqrt(a * b) - sin(2t * sqrt(a * b))
) / (4.0(a^0.5) * (b^1.5)) +
(
(3a * (b / (a^2))) / sqrt(b / a) +
(a * (b / (a^2)) * cos(2t * sqrt(a * b))) / sqrt(b / a) +
(b * t * (b + 3a)) / sqrt(a * b) +
(b * t * (a - b) * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4a * (b / (a^2)) * cos(t * sqrt(a * b))) / sqrt(b / a) +
(-4a * b * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
(2a * b * t * sqrt(b / a) * sin(2t * sqrt(a * b))) / sqrt(a * b) +
(-4a * b * t * sqrt(b / a) * sin(t * sqrt(a * b))) / sqrt(a * b) +
6t * sqrt(a * b) + 8sqrt(b / a) * cos(t * sqrt(a * b)) + sin(2t * sqrt(a * b)) -
6sqrt(b / a) - 8sin(t * sqrt(a * b)) - 2sqrt(b / a) * cos(2t * sqrt(a * b))
) /
(4.0(a^1.5) * (b^0.5)) +
(
-2.0(b^1.5) *
(
(
b * (
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) + 6b * sqrt(a / b) +
2t * (a + 3b) * sqrt(a * b) - a * sin(2t * sqrt(a * b))
) / (16.0a * (b^3.0))
)
) /
(a^0.5) -
6.0(a^0.5) * (b^0.5) *
(
(
(a - b) * sin(2t * sqrt(a * b)) + 8a * sqrt(b / a) * cos(t * sqrt(a * b)) +
2t * (b + 3a) * sqrt(a * b) - 6a * sqrt(b / a) - 8a * sin(t * sqrt(a * b)) -
2a * sqrt(b / a) * cos(2t * sqrt(a * b))
) / (16.0b * (a^3.0))
)
d2 = (
b * (
(a * t * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-(a / (b^2)) * (cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b)))) / sqrt(a / b) +
(-4a * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
2(
(2a * t * sin(t * sqrt(a * b))) / sqrt(a * b) +
(-a * t * sin(2t * sqrt(a * b))) / sqrt(a * b)
) * sqrt(a / b)
) +
(-3b * (a / (b^2))) / sqrt(a / b) + (a * t * (a + 3b)) / sqrt(a * b) +
(-t * (a^2) * cos(2t * sqrt(a * b))) / sqrt(a * b) + 6sqrt(a / b) +
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
6t * sqrt(a * b) + sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) /
(4.0(a^0.5) * (b^1.5)) +
(
(4cos(t * sqrt(a * b))) / sqrt(b / a) + (-cos(2t * sqrt(a * b))) / sqrt(b / a) +
(a * t * (b + 3a)) / sqrt(a * b) +
(-4t * (a^2) * cos(t * sqrt(a * b))) / sqrt(a * b) +
(a * t * (a - b) * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4t * (a^2) * sqrt(b / a) * sin(t * sqrt(a * b))) / sqrt(a * b) +
(2t * (a^2) * sqrt(b / a) * sin(2t * sqrt(a * b))) / sqrt(a * b) +
-3 / sqrt(b / a) + 2t * sqrt(a * b) - sin(2t * sqrt(a * b))
) /
(4.0(a^1.5) * (b^0.5)) +
(
-2.0(a^1.5) *
(
(
(a - b) * sin(2t * sqrt(a * b)) + 8a * sqrt(b / a) * cos(t * sqrt(a * b)) +
2t * (b + 3a) * sqrt(a * b) - 6a * sqrt(b / a) - 8a * sin(t * sqrt(a * b)) -
2a * sqrt(b / a) * cos(2t * sqrt(a * b))
) / (16.0b * (a^3.0))
)
) / (b^0.5) -
6.0(a^0.5) * (b^0.5) *
(
(
b * (
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) + 6b * sqrt(a / b) +
2t * (a + 3b) * sqrt(a * b) - a * sin(2t * sqrt(a * b))
) / (16.0a * (b^3.0))
)
return [d1, d2]
end

function callback_saving_linear(u, t, integrator, sol)
return -1 .* [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end
function callback_saving_linear_inplace(du, u, t, integrator, sol)
return du .= -1 .* [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end

u0 = [1.0, 1.0] # initial condition
tspan = (0.0, 10.0) # simulation time
Expand Down
Loading