Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ PrecompileTools = "1.2"
QuadGK = "2.9"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.27, 4"
SafeTestsets = "0.1, 1"
SciMLBase = "2.54, 3.1"
StaticArrays = "1.9.7"
StaticArraysCore = "1.4"
Expand All @@ -70,9 +71,10 @@ OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ADTypes", "Aqua", "DataInterpolations", "ForwardDiff", "Pkg", "NonlinearSolve", "ODEProblemLibrary", "OrdinaryDiffEqBDF", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqTsit5", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqVerner", "QuadGK", "StaticArrays", "Test", "Functors"]
test = ["ADTypes", "Aqua", "DataInterpolations", "ForwardDiff", "Pkg", "NonlinearSolve", "ODEProblemLibrary", "OrdinaryDiffEqBDF", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqTsit5", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqVerner", "QuadGK", "SafeTestsets", "StaticArrays", "Test", "Functors"]
1 change: 1 addition & 0 deletions test/independentlylinearizedtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test, DiffEqCallbacks
using DiffEqCallbacks: sample, store!, IndependentlyLinearizedSolutionChunks, finish!
using SciMLBase: ReturnCode

@testset "IndependentlyLinearizedSolution" begin
ils = IndependentlyLinearizedSolution{Float64, Float64}(
Expand Down
141 changes: 141 additions & 0 deletions test/integrating_GK_shared.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
function compute_dGdp(integrand)
temp = zeros(length(integrand.integrand), length(integrand.integrand[1]))
for i in 1:length(integrand.integrand)
for j in 1:length(integrand.integrand[1])
temp[i, j] = integrand.integrand[i][j]
end
end
return sum(temp, dims = 1)[:]
end

function compute_dGdp_nt(integrand)
temp = zeros(length(integrand.integrand), 4)
for i in 1:length(integrand.integrand)
temp[i, 1:2] .= integrand.integrand[i].x.αβ
temp[i, 3:4] .= integrand.integrand[i].δγ
end
return sum(temp, dims = 1)[:]
end

#### TESTING ON LINEAR SYSTEM WITH ANALYTICAL SOLUTION ####

function simple_linear_system(u, p, t)
a, b = p
return [-a * u[2], b * u[1]]
end

function adjoint_linear(u, p, t, sol)
a, b = p
return -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

function adjoint_linear_inplace(du, u, p, t, sol)
a, b = p
return du .= -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

function analytical_derivative(p, t)
a, b = p
d1 = (
b * (
(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) / (b * sqrt(a / b)) +
(b * t * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4b * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
2(
(2b * t * sin(t * sqrt(a * b))) / sqrt(a * b) +
(-b * t * sin(2t * sqrt(a * b))) / sqrt(a * b)
) * sqrt(a / b)
) +
(b * t * (a + 3b)) / sqrt(a * b) +
(-a * b * t * cos(2t * sqrt(a * b))) / sqrt(a * b) + 3 / sqrt(a / b) +
2t * sqrt(a * b) - sin(2t * sqrt(a * b))
) / (4.0(a^0.5) * (b^1.5)) +
(
(3a * (b / (a^2))) / sqrt(b / a) +
(a * (b / (a^2)) * cos(2t * sqrt(a * b))) / sqrt(b / a) +
(b * t * (b + 3a)) / sqrt(a * b) +
(b * t * (a - b) * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4a * (b / (a^2)) * cos(t * sqrt(a * b))) / sqrt(b / a) +
(-4a * b * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
(2a * b * t * sqrt(b / a) * sin(2t * sqrt(a * b))) / sqrt(a * b) +
(-4a * b * t * sqrt(b / a) * sin(t * sqrt(a * b))) / sqrt(a * b) +
6t * sqrt(a * b) + 8sqrt(b / a) * cos(t * sqrt(a * b)) + sin(2t * sqrt(a * b)) -
6sqrt(b / a) - 8sin(t * sqrt(a * b)) - 2sqrt(b / a) * cos(2t * sqrt(a * b))
) /
(4.0(a^1.5) * (b^0.5)) +
(
-2.0(b^1.5) *
(
(
b * (
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) + 6b * sqrt(a / b) +
2t * (a + 3b) * sqrt(a * b) - a * sin(2t * sqrt(a * b))
) / (16.0a * (b^3.0))
)
) /
(a^0.5) -
6.0(a^0.5) * (b^0.5) *
(
(
(a - b) * sin(2t * sqrt(a * b)) + 8a * sqrt(b / a) * cos(t * sqrt(a * b)) +
2t * (b + 3a) * sqrt(a * b) - 6a * sqrt(b / a) - 8a * sin(t * sqrt(a * b)) -
2a * sqrt(b / a) * cos(2t * sqrt(a * b))
) / (16.0b * (a^3.0))
)
d2 = (
b * (
(a * t * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-(a / (b^2)) * (cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b)))) / sqrt(a / b) +
(-4a * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
2(
(2a * t * sin(t * sqrt(a * b))) / sqrt(a * b) +
(-a * t * sin(2t * sqrt(a * b))) / sqrt(a * b)
) * sqrt(a / b)
) +
(-3b * (a / (b^2))) / sqrt(a / b) + (a * t * (a + 3b)) / sqrt(a * b) +
(-t * (a^2) * cos(2t * sqrt(a * b))) / sqrt(a * b) + 6sqrt(a / b) +
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
6t * sqrt(a * b) + sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) /
(4.0(a^0.5) * (b^1.5)) +
(
(4cos(t * sqrt(a * b))) / sqrt(b / a) + (-cos(2t * sqrt(a * b))) / sqrt(b / a) +
(a * t * (b + 3a)) / sqrt(a * b) +
(-4t * (a^2) * cos(t * sqrt(a * b))) / sqrt(a * b) +
(a * t * (a - b) * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4t * (a^2) * sqrt(b / a) * sin(t * sqrt(a * b))) / sqrt(a * b) +
(2t * (a^2) * sqrt(b / a) * sin(2t * sqrt(a * b))) / sqrt(a * b) +
-3 / sqrt(b / a) + 2t * sqrt(a * b) - sin(2t * sqrt(a * b))
) /
(4.0(a^1.5) * (b^0.5)) +
(
-2.0(a^1.5) *
(
(
(a - b) * sin(2t * sqrt(a * b)) + 8a * sqrt(b / a) * cos(t * sqrt(a * b)) +
2t * (b + 3a) * sqrt(a * b) - 6a * sqrt(b / a) - 8a * sin(t * sqrt(a * b)) -
2a * sqrt(b / a) * cos(2t * sqrt(a * b))
) / (16.0b * (a^3.0))
)
) / (b^0.5) -
6.0(a^0.5) * (b^0.5) *
(
(
b * (
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) + 6b * sqrt(a / b) +
2t * (a + 3b) * sqrt(a * b) - a * sin(2t * sqrt(a * b))
) / (16.0a * (b^3.0))
)
return [d1, d2]
end

function callback_saving_linear(u, t, integrator, sol)
return -1 .* [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end
function callback_saving_linear_inplace(du, u, t, integrator, sol)
return du .= -1 .* [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end
7 changes: 4 additions & 3 deletions test/integrating_GK_sum_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using OrdinaryDiffEqLowOrderRK, OrdinaryDiffEqTsit5, DiffEqCallbacks
using QuadGK
using Test

include("integrating_GK_shared.jl")

prob = ODEProblem((u, p, t) -> [1.0], [0.0], (0.0, 1.0))
integrated = IntegrandValuesSum(zeros(1))
sol = solve(
Expand All @@ -26,10 +28,9 @@ sol = solve(

#### TESTING ON LINEAR SYSTEM WITH ANALYTICAL SOLUTION ####

# Reuse shared helper functions defined in integrating_GK_tests.jl:
# compute_dGdp, compute_dGdp_nt, simple_linear_system, adjoint_linear,
# Shared helper functions (compute_dGdp, simple_linear_system, adjoint_linear,
# adjoint_linear_inplace, analytical_derivative, callback_saving_linear,
# callback_saving_linear_inplace
# callback_saving_linear_inplace) are included from integrating_GK_shared.jl above.

u0 = [1.0, 1.0] # initial condition
tspan = (0.0, 10.0) # simulation time
Expand Down
141 changes: 2 additions & 139 deletions test/integrating_GK_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using OrdinaryDiffEqLowOrderRK, OrdinaryDiffEqTsit5, DiffEqCallbacks
using QuadGK
using Test

include("integrating_GK_shared.jl")

prob = ODEProblem((u, p, t) -> [1.0], [0.0], (0.0, 1.0))
integrated = IntegrandValues(Float64, Vector{Float64})
sol = solve(
Expand All @@ -24,155 +26,16 @@ sol = solve(
)
@test sum(integrated.integrand)[1] .≈ sin(1000) / 1000

function compute_dGdp(integrand)
temp = zeros(length(integrand.integrand), length(integrand.integrand[1]))
for i in 1:length(integrand.integrand)
for j in 1:length(integrand.integrand[1])
temp[i, j] = integrand.integrand[i][j]
end
end
return sum(temp, dims = 1)[:]
end

function compute_dGdp_nt(integrand)
temp = zeros(length(integrand.integrand), 4)
for i in 1:length(integrand.integrand)
temp[i, 1:2] .= integrand.integrand[i].x.αβ
temp[i, 3:4] .= integrand.integrand[i].δγ
end
return sum(temp, dims = 1)[:]
end

#### TESTING ON LINEAR SYSTEM WITH ANALYTICAL SOLUTION ####

function simple_linear_system(u, p, t)
a, b = p
return [-a * u[2], b * u[1]]
end

function adjoint_linear(u, p, t, sol)
a, b = p
return -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

function adjoint_linear_inplace(du, u, p, t, sol)
a, b = p
return du .= -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

u0 = [1.0, 1.0] # initial condition
tspan = (0.0, 10.0) # simulation time
p = [1.0, 2.0] # parameters
prob = ODEProblem(simple_linear_system, u0, tspan, p)
sol = solve(prob, Tsit5(), abstol = 1.0e-14, reltol = 1.0e-14)

function analytical_derivative(p, t)
a, b = p
d1 = (
b * (
(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) / (b * sqrt(a / b)) +
(b * t * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4b * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
2(
(2b * t * sin(t * sqrt(a * b))) / sqrt(a * b) +
(-b * t * sin(2t * sqrt(a * b))) / sqrt(a * b)
) * sqrt(a / b)
) +
(b * t * (a + 3b)) / sqrt(a * b) +
(-a * b * t * cos(2t * sqrt(a * b))) / sqrt(a * b) + 3 / sqrt(a / b) +
2t * sqrt(a * b) - sin(2t * sqrt(a * b))
) / (4.0(a^0.5) * (b^1.5)) +
(
(3a * (b / (a^2))) / sqrt(b / a) +
(a * (b / (a^2)) * cos(2t * sqrt(a * b))) / sqrt(b / a) +
(b * t * (b + 3a)) / sqrt(a * b) +
(b * t * (a - b) * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4a * (b / (a^2)) * cos(t * sqrt(a * b))) / sqrt(b / a) +
(-4a * b * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
(2a * b * t * sqrt(b / a) * sin(2t * sqrt(a * b))) / sqrt(a * b) +
(-4a * b * t * sqrt(b / a) * sin(t * sqrt(a * b))) / sqrt(a * b) +
6t * sqrt(a * b) + 8sqrt(b / a) * cos(t * sqrt(a * b)) + sin(2t * sqrt(a * b)) -
6sqrt(b / a) - 8sin(t * sqrt(a * b)) - 2sqrt(b / a) * cos(2t * sqrt(a * b))
) /
(4.0(a^1.5) * (b^0.5)) +
(
-2.0(b^1.5) *
(
(
b * (
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) + 6b * sqrt(a / b) +
2t * (a + 3b) * sqrt(a * b) - a * sin(2t * sqrt(a * b))
) / (16.0a * (b^3.0))
)
) /
(a^0.5) -
6.0(a^0.5) * (b^0.5) *
(
(
(a - b) * sin(2t * sqrt(a * b)) + 8a * sqrt(b / a) * cos(t * sqrt(a * b)) +
2t * (b + 3a) * sqrt(a * b) - 6a * sqrt(b / a) - 8a * sin(t * sqrt(a * b)) -
2a * sqrt(b / a) * cos(2t * sqrt(a * b))
) / (16.0b * (a^3.0))
)
d2 = (
b * (
(a * t * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-(a / (b^2)) * (cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b)))) / sqrt(a / b) +
(-4a * t * cos(t * sqrt(a * b))) / sqrt(a * b) +
2(
(2a * t * sin(t * sqrt(a * b))) / sqrt(a * b) +
(-a * t * sin(2t * sqrt(a * b))) / sqrt(a * b)
) * sqrt(a / b)
) +
(-3b * (a / (b^2))) / sqrt(a / b) + (a * t * (a + 3b)) / sqrt(a * b) +
(-t * (a^2) * cos(2t * sqrt(a * b))) / sqrt(a * b) + 6sqrt(a / b) +
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
6t * sqrt(a * b) + sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) /
(4.0(a^0.5) * (b^1.5)) +
(
(4cos(t * sqrt(a * b))) / sqrt(b / a) + (-cos(2t * sqrt(a * b))) / sqrt(b / a) +
(a * t * (b + 3a)) / sqrt(a * b) +
(-4t * (a^2) * cos(t * sqrt(a * b))) / sqrt(a * b) +
(a * t * (a - b) * cos(2t * sqrt(a * b))) / sqrt(a * b) +
(-4t * (a^2) * sqrt(b / a) * sin(t * sqrt(a * b))) / sqrt(a * b) +
(2t * (a^2) * sqrt(b / a) * sin(2t * sqrt(a * b))) / sqrt(a * b) +
-3 / sqrt(b / a) + 2t * sqrt(a * b) - sin(2t * sqrt(a * b))
) /
(4.0(a^1.5) * (b^0.5)) +
(
-2.0(a^1.5) *
(
(
(a - b) * sin(2t * sqrt(a * b)) + 8a * sqrt(b / a) * cos(t * sqrt(a * b)) +
2t * (b + 3a) * sqrt(a * b) - 6a * sqrt(b / a) - 8a * sin(t * sqrt(a * b)) -
2a * sqrt(b / a) * cos(2t * sqrt(a * b))
) / (16.0b * (a^3.0))
)
) / (b^0.5) -
6.0(a^0.5) * (b^0.5) *
(
(
b * (
2(cos(2t * sqrt(a * b)) - 4cos(t * sqrt(a * b))) * sqrt(a / b) +
sin(2t * sqrt(a * b)) - 8sin(t * sqrt(a * b))
) + 6b * sqrt(a / b) +
2t * (a + 3b) * sqrt(a * b) - a * sin(2t * sqrt(a * b))
) / (16.0a * (b^3.0))
)
return [d1, d2]
end

integrand_values = IntegrandValues(Float64, Vector{Float64})
integrand_values_inplace = IntegrandValues(Float64, Vector{Float64})
function callback_saving_linear(u, t, integrator, sol)
return -1 .* [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end
function callback_saving_linear_inplace(du, u, t, integrator, sol)
return du .= -1 .* [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end
cb = IntegratingGKCallback(
(u, t, integrator) -> callback_saving_linear(u, t, integrator, sol),
integrand_values, zeros(length(p))
Expand Down
2 changes: 2 additions & 0 deletions test/nopre/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OrdinaryDiffEqLowOrderRK = "1344f307-1e59-4825-a18e-ace9aa3fa4c6"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -22,6 +23,7 @@ JLArrays = "0.1, 0.2"
OrdinaryDiffEqLowOrderRK = "2"
OrdinaryDiffEqTsit5 = "2"
QuadGK = "2.9"
SafeTestsets = "0.1, 1"
SciMLSensitivity = "7.105"
Tracker = "0.2.35"
Zygote = "0.6.69, 0.7"
Loading