diff --git a/Project.toml b/Project.toml index 02810e7..20a3503 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ Cubature = "667455a9-e2ce-5579-9412-b964f529a492" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -30,7 +31,8 @@ IntegralsCubatureExt = "Cubature" IntegralsFastGaussQuadratureExt = "FastGaussQuadrature" IntegralsForwardDiffExt = "ForwardDiff" IntegralsMCIntegrationExt = "MCIntegration" -IntegralsZygoteExt = ["Zygote", "ChainRulesCore"] +IntegralsMooncakeExt = ["Mooncake", "Zygote", "ChainRulesCore"] +IntegralsZygoteExt = ["Zygote", "ChainRulesCore", "Mooncake"] [compat] Aqua = "0.8" @@ -46,6 +48,7 @@ ForwardDiff = "0.10.36, 1" HCubature = "1.7" LinearAlgebra = "1.10" MCIntegration = "0.4.2" +Mooncake = "0.4.184" MonteCarloIntegration = "0.2" QuadGK = "2.11" Random = "1.10" @@ -68,10 +71,11 @@ FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"] +test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration", "Mooncake"] diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl new file mode 100644 index 0000000..4581458 --- /dev/null +++ b/ext/IntegralsMooncakeExt.jl @@ -0,0 +1,314 @@ +module IntegralsMooncakeExt +using Mooncake +using LinearAlgebra: dot +using Integrals, SciMLBase, QuadGK +using Mooncake: @from_chainrules, @is_primitive, increment!!, MinimalCtx, rrule!!, NoFData, NoRData, CoDual, primal, NoRData, zero_fcodual +import Mooncake: increment_and_get_rdata!, @zero_derivative +using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem +import ChainRulesCore +import ChainRulesCore: Tangent, NoTangent, ProjectTo +using Zygote # use chainrules defined in ZygoteExt + +batch_unwrap(x::AbstractArray) = dropdims(x; dims=ndims(x)) + +@zero_derivative MinimalCtx Tuple{typeof(QuadGK.quadgk),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(QuadGK.cachedrule),Any,Integer} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.checkkwargs),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.isinplace),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.init_cacheval),Union{<:SciMLBase.AbstractIntegralAlgorithm,<:AbstractIntegralMetaAlgorithm},Union{<:IntegralProblem,<:SampledIntegralProblem}} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_f),Union{<:BatchIntegralFunction,<:IntegralFunction},Any,Any,Any} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_v),Any,Any,Union{<:AbstractVector,<:Number},Union{<:AbstractVector,<:Number}} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_bv),Any,AbstractArray,Union{<:AbstractVector,<:Number},Union{<:AbstractVector,<:Number}} + +# @from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} true +@is_primitive MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} +function Mooncake.rrule!!(::CoDual{Type{IntegralProblem{iip}}}, f::CoDual, domain::CoDual, p::CoDual; kwargs...) where {iip} + f_prim, domain_prim, p_prim = map(primal, (f, domain, p)) + prob = IntegralProblem{iip}(f_prim, domain_prim, p_prim; kwargs...) + + function IntegralProblem_iip_pullback(Δ) + data = Δ isa NoRData ? Δ : Δ.data + ddomain = hasproperty(data, :domain) ? data.domain : NoRData() + dp = hasproperty(data, :p) ? data.p : NoRData() + dkwargs = hasproperty(Δ, :kwargs) ? data.kwargs : NoRData() + + # domain is always a Tuple, so it always has NoFData + # below conditional is in case p is an Array or similar + if Mooncake.rdata_type(typeof(p_prim)) == NoRData() + Mooncake.increment!!(p.dx, dp) + grad_p = NoRData() + else + grad_p = dp + end + + return NoRData(), NoRData(), ddomain, grad_p, dkwargs + end + return zero_fcodual(prob), IntegralProblem_iip_pullback +end + +# Mooncake does not need chainrule for evaluate! as it supports mutation. +@from_chainrules MinimalCtx Tuple{Type{IntegralProblem},Any,Any,Any} true +@from_chainrules MinimalCtx Tuple{typeof(Integrals.u2t),Any,Any} true +@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.build_solution),IntegralProblem,Any,Any,Any} true + +@from_chainrules MinimalCtx Tuple{typeof(Integrals.__solvebp),Any,Any,Any,Any,Any} true +function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain, + p; + kwargs...) + # TODO: integrate the primal and dual in the same call to the quadrature library + out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...) + + # the adjoint will be the integral of the input sensitivities, so it maps the + # sensitivity of the output to an object of the type of the parameters + function quadrature_adjoint(Δ) + # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes + if sensealg.vjp isa Integrals.ZygoteVJP + if isinplace(cache) + # zygote doesn't support mutation, so we build an oop pullback + if cache.f isa BatchIntegralFunction + dx = similar(cache.f.integrand_prototype, + size(cache.f.integrand_prototype)[begin:(end-1)]..., 1) + _f = x -> (cache.f(dx, x, p); dx) + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + z, back = Zygote.pullback(p) do p + _dx = Zygote.Buffer(dx) + cache.f(_dx, x_, p) + copy(_dx) + end + return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : + Δ))[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + else + dx = similar(cache.f.integrand_prototype) + _f = x -> (cache.f(dx, x, p); dx) + dfdp_ = function (x, p) + _, back = Zygote.pullback(p) do p + _dx = Zygote.Buffer(dx) + cache.f(_dx, x, p) + copy(_dx) + end + back(Δ)[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + else + _f = x -> cache.f(x, p) + if cache.f isa BatchIntegralFunction + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + z, back = Zygote.pullback(p -> cache.f(x_, p), p) + return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + else + dfdp_ = function (x, p) + z, back = Zygote.pullback(p -> cache.f(x, p), p) + back(z isa Number ? only(Δ) : Δ)[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + end + elseif sensealg.vjp isa Integrals.MooncakeVJP + # SOMETHINGS UP WITH DFDP FUNCTION prob.f it cant accept two ints and error. + if isinplace(cache) + if cache.f isa BatchIntegralFunction + error("TODO") + else + dx = similar(cache.f.integrand_prototype) + _f = x -> (cache.f(dx, x, p); dx) + dfdp_ = function (x, p) + # dx is modified inplace by dfdp/integralfunc_closure_p calls AND the Reverse pass tangent comes externally (from Δ). + # Therefore, Δ.u is Tangent passed to the pullback AND integralfunc_closure_p must always return dx as Output. + # i.e. (tangent(output) == Δ.u). Otherwise integralfunc_closure_p only outputs "nothing" and tangent(output) != Δ.u + integralfunc_closure_p = p -> (cache.f(dx, x, p); dx) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) + z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, integralfunc_closure_p, p) + return grads[2] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + else + _f = x -> cache.f(x, p) + if cache.f isa BatchIntegralFunction + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + error("TODO") + else + dfdp_ = function (x, p) + integralfunc_closure_p = p -> cache.f(x, p) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) + # Δ.u is integrand function's output sensitivity which we pass into Mooncake's pullback + z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, integralfunc_closure_p, p) + return grads[2] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + end + elseif sensealg.vjp isa Integrals.ReverseDiffVJP + error("TODO") + end + + prob = Integrals.build_problem(cache) + # dp_prob = remake(prob, f = dfdp) # fails because we change iip + dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...) + # the infinity transformation was already applied to f so we don't apply it to dfdp + dp_cache = init(dp_prob, + alg; + sensealg=sensealg, + cache.kwargs...) + + project_p = ProjectTo(p) + dp = project_p(solve!(dp_cache).u) + + # Because Mooncake tangent structure vs Zygote, Chainrules, ReverseDiff + du_adj = sensealg.vjp isa Integrals.MooncakeVJP ? Δ.u : Δ + + lb, ub = domain + if lb isa Number + # TODO replace evaluation at endpoint (which anyone can do without Integrals.jl) + # with integration of dfdx uing the same quadrature + dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb) + dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub) + return (NoTangent(), + NoTangent(), + NoTangent(), + NoTangent(), + Tangent{typeof(domain)}(dot(dlb, du_adj), dot(dub, du_adj)), + dp) + else + # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we + # can see from writing the multidimensional integral as an iterated integral + # alternatively we can use Stokes' theorem to replace the integral on the + # boundary with a volume integral of the flux of the integrand + # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the + # dimensionality of the integral or the quadrature used (such as quadratures + # that don't evaluate points on the boundaries) and it could be generalized to + # other kinds of domains. The only question is to determine ω in terms of f and + # the deformation of the surface (e.g. consider integral over an ellipse and + # asking for the derivative of the result w.r.t. the semiaxes of the ellipse) + return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp) + end + end + out, quadrature_adjoint +end + +# Internal Mooncake overloads to accommodate IntegralSolution etc. Struct's Tangent Types. +# Allows clear translation from ChainRules -> Mooncake's tangent. +function Mooncake.increment_and_get_rdata!( + f::NoFData, r::Tuple{T,T}, t::Union{Tangent{Tuple{T,T},Tuple{T,T}},Tangent{Any,Tuple{Float64,Float64}}} +) where {T<:Base.IEEEFloat} + return r .+ t.backing +end + +function Mooncake.increment_and_get_rdata!( + f::Tuple{Vector{T},Vector{T}}, + r::NoRData, + t::Tangent{Any,Tuple{Vector{T},Vector{T}}}, +) where {T<:Base.IEEEFloat} + Mooncake.increment!!(f, t.backing) + return NoRData() +end + +# sol.u & p are single scalar values, domain (lb,ub) is single/multi - variate. +function Mooncake.increment_and_get_rdata!( + f::NoFData, + r::T, + t::Tangent{Any, + @NamedTuple{ + u::T, + resid::R, + prob::Tangent{Any, + @NamedTuple{ + f::NoTangent, + domain::Tangent{Any,Tuple{M,M}}, + p::P, + kwargs::NoTangent + } + }, + alg::A, + retcode::NoTangent, + chi::NoTangent, + stats::NoTangent + } + } +) where {T<:Base.IEEEFloat, + R<:Union{NoTangent,T}, + P<:Union{T,Vector{T}}, + M<:Union{T,Vector{T}}, + A<:Union{NoTangent, + Tangent{Any, + @NamedTuple{ + nodes::Vector{T}, + weights::Vector{T}, + subintervals::NoTangent + } + } + } +} + # rdata component of t + r (u field) + return Mooncake.increment_and_get_rdata!(f, r, t.u) +end + +# sol.u is vector valued, p is scalar/vector valued, domain can be single/multi - variate +# resid can be single/vector valued. For inplace integrals (iip true) : included integrand_prototype field in typeof{prob.f} +function Mooncake.increment_and_get_rdata!( + f::Vector{T}, + r::NoRData, + t::Union{ + Tangent{ + Any, + @NamedTuple{ + u::Vector{T}, + resid::R, + prob::Tangent{ + Any, + @NamedTuple{ + f::F, + domain::Tangent{Any,M}, + p::P, + kwargs::NoTangent + } + }, + alg::A, + retcode::NoTangent, + chi::NoTangent, + stats::NoTangent + } + } + } +) where {T<:Base.IEEEFloat, + R<:Union{NoTangent,T,Vector{T}}, + P<:Union{T,Vector{T}}, + M<:Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}, + F<:Union{NoTangent, + Tangent{ + Any, + @NamedTuple{ + f::NoTangent, + integrand_prototype::Vector{T} + } + } + }, + A<:Union{NoTangent, + Tangent{Any, + @NamedTuple{ + nodes::Vector{T}, + weights::Vector{T}, + subintervals::NoTangent + } + } + } +} + Mooncake.increment!!(f, t.u) + # rdata component(t) + r + return t.prob.domain +end + +# cannot mutate NoRData() in place, therefore return as is. +function Mooncake.increment!!(::Mooncake.NoRData, y::Tangent{Any,Y}) where {T<:Base.IEEEFloat,Y<:Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}} + return Mooncake.NoRData() +end +end \ No newline at end of file diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 2e77adf..4bc43be 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -4,6 +4,7 @@ using Integrals using Zygote import ChainRulesCore import ChainRulesCore: Tangent, NoTangent, ProjectTo +using Mooncake # call __solve_bp's chainrule from Mooncake's extension ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) ChainRulesCore.@non_differentiable Integrals.isinplace(f, args...) # fixes #99 @@ -62,117 +63,6 @@ function ChainRulesCore.rrule(::typeof(Integrals.u2t), lb, ub) return out, u2t_pullback end -function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain, - p; - kwargs...) - # TODO: integrate the primal and dual in the same call to the quadrature library - out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...) - - # the adjoint will be the integral of the input sensitivities, so it maps the - # sensitivity of the output to an object of the type of the parameters - function quadrature_adjoint(Δ) - # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes - if isinplace(cache) - # zygote doesn't support mutation, so we build an oop pullback - if sensealg.vjp isa Integrals.ZygoteVJP - if cache.f isa BatchIntegralFunction - dx = similar(cache.f.integrand_prototype, - size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1) - _f = x -> (cache.f(dx, x, p); dx) - # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction - dfdp_ = function (x, p) - x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] - z, back = Zygote.pullback(p) do p - _dx = Zygote.Buffer(dx) - cache.f(_dx, x_, p) - copy(_dx) - end - return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : - Δ))[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - else - dx = similar(cache.f.integrand_prototype) - _f = x -> (cache.f(dx, x, p); dx) - dfdp_ = function (x, p) - _, back = Zygote.pullback(p) do p - _dx = Zygote.Buffer(dx) - cache.f(_dx, x, p) - copy(_dx) - end - back(Δ)[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - end - elseif sensealg.vjp isa Integrals.ReverseDiffVJP - error("TODO") - end - else - _f = x -> cache.f(x, p) - if sensealg.vjp isa Integrals.ZygoteVJP - if cache.f isa BatchIntegralFunction - # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction - dfdp_ = function (x, p) - x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] - z, back = Zygote.pullback(p -> cache.f(x_, p), p) - return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - else - dfdp_ = function (x, p) - z, back = Zygote.pullback(p -> cache.f(x, p), p) - back(z isa Number ? only(Δ) : Δ)[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - end - elseif sensealg.vjp isa Integrals.ReverseDiffVJP - error("TODO") - end - end - - prob = Integrals.build_problem(cache) - # dp_prob = remake(prob, f = dfdp) # fails because we change iip - dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...) - # the infinity transformation was already applied to f so we don't apply it to dfdp - dp_cache = init(dp_prob, - alg; - sensealg = sensealg, - cache.kwargs...) - - project_p = ProjectTo(p) - dp = project_p(solve!(dp_cache).u) - - lb, ub = domain - if lb isa Number - # TODO replace evaluation at endpoint (which anyone can do without Integrals.jl) - # with integration of dfdx uing the same quadrature - dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb) - dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub) - return (NoTangent(), - NoTangent(), - NoTangent(), - NoTangent(), - Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)), - dp) - else - # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we - # can see from writing the multidimensional integral as an iterated integral - # alternatively we can use Stokes' theorem to replace the integral on the - # boundary with a volume integral of the flux of the integrand - # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the - # dimensionality of the integral or the quadrature used (such as quadratures - # that don't evaluate points on the boundaries) and it could be generalized to - # other kinds of domains. The only question is to determine ω in terms of f and - # the deformation of the surface (e.g. consider integral over an ellipse and - # asking for the derivative of the result w.r.t. the semiaxes of the ellipse) - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp) - end - end - out, quadrature_adjoint -end - -batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x)) - Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution, ::Val{:u}) sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),) diff --git a/src/Integrals.jl b/src/Integrals.jl index 15f33d5..771fe60 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -61,6 +61,13 @@ struct ReverseDiffVJP <: IntegralVJP compile::Bool end +""" + MooncakeVJP <: IntegralVJP + +Uses Mooncake.jl for vector-Jacobian products in automatic differentiation of integrals. +""" +struct MooncakeVJP <: IntegralVJP end + function scale_x!(_x, ub, lb, x) _x .= (ub .- lb) .* x .+ lb _x diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 7a21dea..ad17e88 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -1,4 +1,4 @@ -using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity +using Integrals, Mooncake, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity using Cuba, Cubature using FastGaussQuadrature using Test @@ -84,6 +84,10 @@ Base.axes(::Scalar) = ScalarAxes() Base.iterate(::ScalarAxes) = nothing Base.reshape(A::AbstractArray, ::ScalarAxes) = Scalar(only(A)) +# Scalar struct defined around Real Numbers (test/derivative_tests.jl) +# Mooncake, like Zygote also treats 0-D data wrt to the type of datastructure. +Mooncake.rdata_type(::Type{Scalar{T}}) where {T<:Real} = Mooncake.rdata_type(T) + # here we assume f evaluated at scalar inputs gives a scalar output # p will be able to be a number after https://github.com/FluxML/Zygote.jl/pull/1489 # p will be able to be a 0-array after https://github.com/FluxML/Zygote.jl/pull/1491 @@ -144,6 +148,52 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol) return end +# Mooncake Sensealg testing helper function +do_tests_mooncake = function (; f, scalarize, lb, ub, p, alg, abstol, reltol) + testf = function (lb, ub, p) + prob = IntegralProblem(f, (lb, ub), p) + scalarize(solve(prob, alg; reltol, abstol, sensealg=Integrals.ReCallVJP{Integrals.MooncakeVJP}(Integrals.MooncakeVJP()))) + end + sol_fp = testf(lb, ub, p) + + # sensealg when non zygoet? + cache = Mooncake.prepare_gradient_cache(testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p) + forwpassval, gradients = Mooncake.value_and_gradient!!(cache, testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p) + + @test forwpassval == sol_fp + + f_lb = lb -> testf(lb, ub, p) + f_ub = ub -> testf(lb, ub, p) + + dlb = lb isa AbstractArray ? :gradient : :derivative + dub = ub isa AbstractArray ? :gradient : :derivative + + dlb2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dlb))(f_lb, lb) + dub2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dub))(f_ub, ub) + + if lb isa Number + @test gradients[2] ≈ dlb2 atol = abstol rtol = reltol + @test gradients[3] ≈ dub2 atol = abstol rtol = reltol + else # TODO: implement multivariate limit derivatives in MooncakeExt + @test_broken gradients[2] ≈ dlb2 atol = abstol rtol = reltol + @test_broken gradients[3] ≈ dub2 atol = abstol rtol = reltol + end + + f_p = p -> testf(lb, ub, p) + dp = p isa AbstractArray ? :gradient : :derivative + + dp2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dp))(f_p, p) + dp3 = getproperty(ForwardDiff, dp)(f_p, p) + + @test dp2 ≈ dp3 atol = abstol rtol = reltol + + # test Mooncake for parameter p + @test gradients[4] ≈ dp2 atol = abstol rtol = reltol + @test dp2 ≈ dp3 atol = abstol rtol = reltol + + return +end + ### One Dimensional for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), (i, scalarize) in enumerate(scalarize_solution) @@ -152,6 +202,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "One-dimensional, scalar, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i do_tests(; f, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) end ## One-dimensional nout @@ -163,6 +214,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "One-dimensional, multivariate, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i nout do_tests(; f, scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; + f, scalarize, lb=1.0, ub=3.0, p=[2.0i for i in 1:nout], alg, abstol, reltol) end ### N-dimensional @@ -173,6 +226,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "Multi-dimensional, scalar, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i dim do_tests(; f, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f, scalarize, lb=ones(dim), ub=3ones(dim), p=2.0, alg, abstol, reltol) end ### N-dimensional nout @@ -185,8 +239,11 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "Multi-dimensional, multivariate, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i dim nout do_tests(; f, scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; f, scalarize, lb=ones(dim), ub=3ones(dim), + p=[2.0i for i in 1:nout], alg, abstol, reltol) end +#### in place IntegralCache, IntegralFunction Tests ### One Dimensional for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), (i, scalarize) in enumerate(scalarize_solution) @@ -197,6 +254,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "One-dimensional, scalar, iip derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(1)) do_tests(; f = fiip, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f=fiip, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) end ## One-dimensional nout @@ -210,6 +268,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(nout)) do_tests(; f = fiip, scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; f=fiip, scalarize, lb = 1.0, ub = 3.0, + p = [2.0i for i in 1:nout], alg, abstol, reltol) end ### N-dimensional @@ -223,6 +283,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(1)) do_tests(; f = fiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; + f=fiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol) end ### N-dimensional nout iip @@ -237,6 +299,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(nout)) do_tests(; f = fiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; f = fiip, scalarize, lb = ones(dim), ub = 3ones(dim), + p = [2.0i for i in 1:nout], alg, abstol, reltol) end ### Batch, One Dimensional @@ -347,7 +411,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), p = [2.0i for i in 1:nout], alg, abstol, reltol) end -@testset "ChangeOfVariables rrule" begin +@testset "ChangeOfVariables rrules" begin alg = QuadGKJL() # test a simple u-substitution of x = 2.7u + 1.3 talg = Integrals.ChangeOfVariables(alg) do f, domain @@ -358,18 +422,43 @@ end error("not implemented") end end - testf = (f, lb, ub, p, alg) -> begin + + testf = (f, lb, ub, p, alg, sensealg) -> begin prob = IntegralProblem(f, (lb, ub), p) - solve(prob, alg; abstol, reltol).u + solve(prob, alg; abstol, reltol, sensealg = sensealg).u end _testf = (x, p) -> x^2 * p lb, ub, p = 1.0, 5.0, 2.0 - sol = Zygote.withgradient((args...) -> testf(_testf, args..., alg), lb, ub, p) - tsol = Zygote.withgradient((args...) -> testf(_testf, args..., talg), lb, ub, p) - @test sol.val ≈ tsol.val - # Fundamental theorem of Calculus part 1 - @test sol.grad[1] ≈ tsol.grad[1] ≈ -_testf(lb, p) - @test sol.grad[2] ≈ tsol.grad[2] ≈ _testf(ub, p) - # This is to check ∂p - @test sol.grad[3] ≈ tsol.grad[3] -end + + @testset "Sensitivity using Zygote" begin + sensealg = Integrals.ReCallVJP(Integrals.ZygoteVJP()) + sol = Zygote.withgradient((args...) -> testf(_testf, args...), lb, ub, p, alg, sensealg) + tsol = Zygote.withgradient((args...) -> testf(_testf, args...), lb, ub, p, talg, sensealg) + @test sol.val ≈ tsol.val + # Fundamental theorem of Calculus part 1 + @test sol.grad[1] ≈ tsol.grad[1] ≈ -_testf(lb, p) + @test sol.grad[2] ≈ tsol.grad[2] ≈ _testf(ub, p) + # This is to check ∂p + @test sol.grad[3] ≈ tsol.grad[3] + end + + @testset "Sensitivity using Mooncake" begin + sensealg = Integrals.ReCallVJP(Integrals.MooncakeVJP()) + # anonymous function for cache creation and gradient evaluation call must be the same. + func = (args...) -> testf(_testf, args...) + cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, alg, sensealg) + sol = Mooncake.value_and_gradient!!(cache, func, + lb, ub, p, alg, sensealg) + + cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, talg, sensealg) + tsol = Mooncake.value_and_gradient!!( + cache, func, lb, ub, p, talg, sensealg) + + @test sol[1] ≈ tsol[1] + # Fundamental theorem of Calculus part 1 + @test sol[2][2] ≈ tsol[2][2] ≈ -_testf(lb, p) + @test sol[2][3] ≈ tsol[2][3] ≈ _testf(ub, p) + # To check ∂p + @test sol[2][4] ≈ tsol[2][4] + end +end \ No newline at end of file