From a8948b707aa82bfaba9d02b522f96412357512d0 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 5 Dec 2025 17:26:50 +0000 Subject: [PATCH 01/12] rrules for integrals.jl --- ext/IntegralsMooncakeExt.jl | 149 ++++++++++++++++++++++++++++++++++++ src/Integrals.jl | 7 ++ 2 files changed, 156 insertions(+) create mode 100644 ext/IntegralsMooncakeExt.jl diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl new file mode 100644 index 0000000..90bd823 --- /dev/null +++ b/ext/IntegralsMooncakeExt.jl @@ -0,0 +1,149 @@ +module IntegralsMooncakeExt +using Mooncake +using Integrals +using Zygote +using SciMLBase +using Mooncake: @from_chainrules, @zero_derivative, MinimalCtx, rrule!!, Dual, CoDual +using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem + +@zero_derivative MinimalCtx Tuple{typeof(Integrals.checkkwargs),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.isinplace),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.init_cacheval),Union{<:SciMLBase.AbstractIntegralAlgorithm,<:AbstractIntegralMetaAlgorithm},Union{<:IntegralProblem,<:SampledIntegralProblem}} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_f),Union{<:BatchIntegralFunction,<:IntegralFunction},Any,Any,Any} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_v),Any,Any,Union{<:AbstractVector,<:Number},Union{<:AbstractVector,<:Number}} +@zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_bv),Any,AbstractArray,Union{<:AbstractVector,<:Number},Union{<:AbstractVector,<:Number}} + +@from_chainrules MinimalCtx Tuple{Type{IntegralProblem},Any,Any,Any} true +@from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} true +@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.build_solution),IntegralProblem,Any,Any,Any} true +@from_chainrules MinimalCtx Tuple{typeof(Integrals.u2t),Any,Any} true + +# evaluate doesnt need rrules as Mooncake supports mutation +# @is_primitive MinimalCtx Tuple{typeof(Integrals._evaluate!),Any,Any,Any,Any} +@from_chainrules MinimalCtx Tuple{typeof(Integrals.__solvebp),Any,Any,Any,Any,Any} true + +function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain, + p; + kwargs...) + # TODO: integrate the primal and dual in the same call to the quadrature library + out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...) + + # the adjoint will be the integral of the input sensitivities, so it maps the + # sensitivity of the output to an object of the type of the parameters + function quadrature_adjoint(Δ) + # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes + if sensealg.vjp isa Integrals.ZygoteVJP + if isinplace(cache) + # zygote doesn't support mutation, so we build an oop pullback + if cache.f isa BatchIntegralFunction + dx = similar(cache.f.integrand_prototype, + size(cache.f.integrand_prototype)[begin:(end-1)]..., 1) + _f = x -> (cache.f(dx, x, p); dx) + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + z, back = Zygote.pullback(p) do p + _dx = Zygote.Buffer(dx) + cache.f(_dx, x_, p) + copy(_dx) + end + return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : + Δ))[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + else + dx = similar(cache.f.integrand_prototype) + _f = x -> (cache.f(dx, x, p); dx) + dfdp_ = function (x, p) + _, back = Zygote.pullback(p) do p + _dx = Zygote.Buffer(dx) + cache.f(_dx, x, p) + copy(_dx) + end + back(Δ)[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + else + _f = x -> cache.f(x, p) + if cache.f isa BatchIntegralFunction + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + z, back = Zygote.pullback(p -> cache.f(x_, p), p) + return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + else + dfdp_ = function (x, p) + z, back = Zygote.pullback(p -> cache.f(x, p), p) + back(z isa Number ? only(Δ) : Δ)[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + end + elseif sensealg.vjp isa Integrals.MooncakeVJP + _f = x -> cache.f(x, p) + if cache.f isa BatchIntegralFunction + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + dfdp_ = function (x, p) + x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] + clos = p -> cache.f(x_, p) + z, back = Mooncake.value_and_gradient!!(Mooncake.prepare_gradient_cache(clos, p), clos, p) + return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + else + dfdp_ = function (x, p) + clos = p -> cache.f(x_, p) + z, back = Mooncake.value_and_gradient!!(Mooncake.prepare_gradient_cache(clos, p), clos, p) + back(z isa Number ? only(Δ) : Δ)[1] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) + end + elseif sensealg.vjp isa Integrals.ReverseDiffVJP + error("TODO") + end + + prob = Integrals.build_problem(cache) + # dp_prob = remake(prob, f = dfdp) # fails because we change iip + dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...) + # the infinity transformation was already applied to f so we don't apply it to dfdp + dp_cache = init(dp_prob, + alg; + sensealg=sensealg, + cache.kwargs...) + + project_p = ProjectTo(p) + dp = project_p(solve!(dp_cache).u) + + lb, ub = domain + if lb isa Number + # TODO replace evaluation at endpoint (which anyone can do without Integrals.jl) + # with integration of dfdx uing the same quadrature + dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb) + dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub) + return (NoTangent(), + NoTangent(), + NoTangent(), + NoTangent(), + Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)), + dp) + else + # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we + # can see from writing the multidimensional integral as an iterated integral + # alternatively we can use Stokes' theorem to replace the integral on the + # boundary with a volume integral of the flux of the integrand + # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the + # dimensionality of the integral or the quadrature used (such as quadratures + # that don't evaluate points on the boundaries) and it could be generalized to + # other kinds of domains. The only question is to determine ω in terms of f and + # the deformation of the surface (e.g. consider integral over an ellipse and + # asking for the derivative of the result w.r.t. the semiaxes of the ellipse) + return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp) + end + end + out, quadrature_adjoint +end +# Test files: tests/ - derivative_tests,nested_ad_tests.jl, Quadrule_tests.jl +end \ No newline at end of file diff --git a/src/Integrals.jl b/src/Integrals.jl index 15f33d5..771fe60 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -61,6 +61,13 @@ struct ReverseDiffVJP <: IntegralVJP compile::Bool end +""" + MooncakeVJP <: IntegralVJP + +Uses Mooncake.jl for vector-Jacobian products in automatic differentiation of integrals. +""" +struct MooncakeVJP <: IntegralVJP end + function scale_x!(_x, ub, lb, x) _x .= (ub .- lb) .* x .+ lb _x From f2ecbafcbb1903b8ed5a7b481b837dde38bc09cb Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Thu, 11 Dec 2025 18:02:08 +0000 Subject: [PATCH 02/12] sensealg Mooncake choice, most algs, use Chainrules --- Project.toml | 7 ++- ext/IntegralsMooncakeExt.jl | 101 +++++++++++++++++++++++++++----- ext/IntegralsZygoteExt.jl | 112 +----------------------------------- test/derivative_tests.jl | 107 +++++++++++++++++++++++++++++----- 4 files changed, 187 insertions(+), 140 deletions(-) diff --git a/Project.toml b/Project.toml index 02810e7..4e3e3b2 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ Cubature = "667455a9-e2ce-5579-9412-b964f529a492" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -30,7 +31,8 @@ IntegralsCubatureExt = "Cubature" IntegralsFastGaussQuadratureExt = "FastGaussQuadrature" IntegralsForwardDiffExt = "ForwardDiff" IntegralsMCIntegrationExt = "MCIntegration" -IntegralsZygoteExt = ["Zygote", "ChainRulesCore"] +IntegralsMooncakeExt = ["Mooncake", "Zygote", "ChainRulesCore"] +IntegralsZygoteExt = ["Zygote", "ChainRulesCore", "Mooncake"] [compat] Aqua = "0.8" @@ -68,10 +70,11 @@ FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"] +test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration", "Mooncake"] diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 90bd823..8526537 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -1,11 +1,18 @@ module IntegralsMooncakeExt using Mooncake +using LinearAlgebra: dot using Integrals -using Zygote -using SciMLBase -using Mooncake: @from_chainrules, @zero_derivative, MinimalCtx, rrule!!, Dual, CoDual +using SciMLBase, QuadGK +using Mooncake: @from_chainrules, @zero_derivative, @is_primitive, increment!!, increment_and_get_rdata!, MinimalCtx, rrule!!, NoFData, CoDual, primal, NoRData, zero_fcodual using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem +import ChainRulesCore +import ChainRulesCore: Tangent, NoTangent, ProjectTo +using Zygote # use chainrules defined in ZygoteExt +batch_unwrap(x::AbstractArray) = dropdims(x; dims=ndims(x)) + +@zero_derivative MinimalCtx Tuple{typeof(QuadGK.quadgk),Vararg} +@zero_derivative MinimalCtx Tuple{typeof(QuadGK.cachedrule),Any,Integer} @zero_derivative MinimalCtx Tuple{typeof(Integrals.checkkwargs),Vararg} @zero_derivative MinimalCtx Tuple{typeof(Integrals.isinplace),Vararg} @zero_derivative MinimalCtx Tuple{typeof(Integrals.init_cacheval),Union{<:SciMLBase.AbstractIntegralAlgorithm,<:AbstractIntegralMetaAlgorithm},Union{<:IntegralProblem,<:SampledIntegralProblem}} @@ -13,13 +20,49 @@ using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem @zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_v),Any,Any,Union{<:AbstractVector,<:Number},Union{<:AbstractVector,<:Number}} @zero_derivative MinimalCtx Tuple{typeof(Integrals.substitute_bv),Any,AbstractArray,Union{<:AbstractVector,<:Number},Union{<:AbstractVector,<:Number}} +# @from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} true +@is_primitive MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} +function Mooncake.rrule!!(::CoDual{Type{IntegralProblem{iip}}}, f::CoDual, domain::CoDual, p::CoDual; kwargs...) where {iip} + f_prim, domain_prim, p_prim = map(primal, (f, domain, p)) + prob = IntegralProblem{iip}(f_prim, domain_prim, p_prim; kwargs...) + + function IntegralProblem_iip_pullback(Δ) + data = Δ isa NoRData ? Δ : Δ.data + ddomain = hasproperty(data, :domain) ? data.domain : NoRData() + dp = hasproperty(data, :p) ? data.p : NoRData() + dkwargs = Δ isa NoRData ? Δ : data.kwargs + + # domain is always a Tuple, so it always has NoFData + # below conditional is in case p is an Array or similar + if Mooncake.rdata_type(typeof(p_prim)) == NoRData() + p.dx .+= dp + grad_p = NoRData() + else + grad_p = dp + end + + return NoRData(), NoRData(), ddomain, grad_p, dkwargs + end + return zero_fcodual(prob), IntegralProblem_iip_pullback +end + +@is_primitive MinimalCtx Tuple{typeof(SciMLBase.build_solution),IntegralProblem,Any,Any,Any} +function Mooncake.rrule!!(::CoDual{typeof(SciMLBase.build_solution)}, prob::CoDual{IntegralProblem}, alg::CoDual, u::CoDual, resid::CoDual; kwargs) + kwargs_fp = map(primal, kwargs) + function pb!!(Δ) + return NoRData(), + NoRData(), + NoRData(), + Δ, + NoRData() + end + return SciMLBase.build_solution(prob.primal, alg.primal, u.primal, resid.primal; kwargs_fp...), pb!! +end + +# Mooncake does not need chainrule for evaluate! as it supports mutation. @from_chainrules MinimalCtx Tuple{Type{IntegralProblem},Any,Any,Any} true -@from_chainrules MinimalCtx Tuple{Type{IntegralProblem{iip}},Any,Any,Any} where {iip} true @from_chainrules MinimalCtx Tuple{typeof(SciMLBase.build_solution),IntegralProblem,Any,Any,Any} true @from_chainrules MinimalCtx Tuple{typeof(Integrals.u2t),Any,Any} true - -# evaluate doesnt need rrules as Mooncake supports mutation -# @is_primitive MinimalCtx Tuple{typeof(Integrals._evaluate!),Any,Any,Any,Any} @from_chainrules MinimalCtx Tuple{typeof(Integrals.__solvebp),Any,Any,Any,Any,Any} true function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain, @@ -89,15 +132,18 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal dfdp_ = function (x, p) x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] clos = p -> cache.f(x_, p) - z, back = Mooncake.value_and_gradient!!(Mooncake.prepare_gradient_cache(clos, p), clos, p) - return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1] + cache_z = Mooncake.prepare_pullback_cache(clos, p) + z, grads = Mooncake.value_and_pullback!!(cache_z, [Δ.u], clos, p) + return grads[2] end dfdp = IntegralFunction{false}(dfdp_, nothing) else dfdp_ = function (x, p) - clos = p -> cache.f(x_, p) - z, back = Mooncake.value_and_gradient!!(Mooncake.prepare_gradient_cache(clos, p), clos, p) - back(z isa Number ? only(Δ) : Δ)[1] + clos = p -> cache.f(x, p) + cache_z = Mooncake.prepare_pullback_cache(clos, p) + # Δ.u is integrand function's output sensitivity which we pass into Mooncake's pullback + z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, clos, p) + return grads[2] end dfdp = IntegralFunction{false}(dfdp_, nothing) end @@ -116,6 +162,9 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal project_p = ProjectTo(p) dp = project_p(solve!(dp_cache).u) + + # Because Mooncake tangent structure vs Zygote, Chainrules, ReverseDiff + du_adj = sensealg.vjp isa Integrals.MooncakeVJP ? Δ.u : Δ lb, ub = domain if lb isa Number @@ -127,7 +176,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal NoTangent(), NoTangent(), NoTangent(), - Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)), + Tangent{typeof(domain)}(dot(dlb, du_adj), dot(dub, du_adj)), dp) else # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we @@ -145,5 +194,29 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal end out, quadrature_adjoint end -# Test files: tests/ - derivative_tests,nested_ad_tests.jl, Quadrule_tests.jl + +function Mooncake.increment_and_get_rdata!( + f::Tuple{Vector{Float64},Vector{Float64}}, + r::Mooncake.NoRData, + t::Tangent{Any,Tuple{Vector{Float64},Vector{Float64}}}, +) + Mooncake.increment!!(f[1], t[1]) + Mooncake.increment!!(f[2], t[2]) + return NoRData() +end + +function Mooncake.increment_and_get_rdata!( + f::NoFData, r::Tuple{T,T}, t::Union{Tangent{Any,Tuple{T,T}},Tangent{Tuple{T,T},Tuple{T,T}}} +) where {T<:Base.IEEEFloat} + return r .+ t.backing +end + +function Mooncake.increment_and_get_rdata!( + f::NoFData, + r::T, + t::Tangent{Any,@NamedTuple{u::T, resid::T, prob::Tangent{Any,@NamedTuple{f::NoTangent, domain::Tangent{Any,Tuple{T,T}}, p::T, kwargs::NoTangent}}, + alg::NoTangent, retcode::NoTangent, chi::NoTangent, stats::NoTangent}}) where T<:Number + return Mooncake.increment!!(t.u, r) +end + end \ No newline at end of file diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 2e77adf..4bc43be 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -4,6 +4,7 @@ using Integrals using Zygote import ChainRulesCore import ChainRulesCore: Tangent, NoTangent, ProjectTo +using Mooncake # call __solve_bp's chainrule from Mooncake's extension ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) ChainRulesCore.@non_differentiable Integrals.isinplace(f, args...) # fixes #99 @@ -62,117 +63,6 @@ function ChainRulesCore.rrule(::typeof(Integrals.u2t), lb, ub) return out, u2t_pullback end -function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain, - p; - kwargs...) - # TODO: integrate the primal and dual in the same call to the quadrature library - out = Integrals.__solvebp_call(cache, alg, sensealg, domain, p; kwargs...) - - # the adjoint will be the integral of the input sensitivities, so it maps the - # sensitivity of the output to an object of the type of the parameters - function quadrature_adjoint(Δ) - # https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes - if isinplace(cache) - # zygote doesn't support mutation, so we build an oop pullback - if sensealg.vjp isa Integrals.ZygoteVJP - if cache.f isa BatchIntegralFunction - dx = similar(cache.f.integrand_prototype, - size(cache.f.integrand_prototype)[begin:(end - 1)]..., 1) - _f = x -> (cache.f(dx, x, p); dx) - # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction - dfdp_ = function (x, p) - x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] - z, back = Zygote.pullback(p) do p - _dx = Zygote.Buffer(dx) - cache.f(_dx, x_, p) - copy(_dx) - end - return back(z .= (Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : - Δ))[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - else - dx = similar(cache.f.integrand_prototype) - _f = x -> (cache.f(dx, x, p); dx) - dfdp_ = function (x, p) - _, back = Zygote.pullback(p) do p - _dx = Zygote.Buffer(dx) - cache.f(_dx, x, p) - copy(_dx) - end - back(Δ)[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - end - elseif sensealg.vjp isa Integrals.ReverseDiffVJP - error("TODO") - end - else - _f = x -> cache.f(x, p) - if sensealg.vjp isa Integrals.ZygoteVJP - if cache.f isa BatchIntegralFunction - # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction - dfdp_ = function (x, p) - x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] - z, back = Zygote.pullback(p -> cache.f(x_, p), p) - return back(Δ isa AbstractArray ? reshape(Δ, size(Δ)..., 1) : [Δ])[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - else - dfdp_ = function (x, p) - z, back = Zygote.pullback(p -> cache.f(x, p), p) - back(z isa Number ? only(Δ) : Δ)[1] - end - dfdp = IntegralFunction{false}(dfdp_, nothing) - end - elseif sensealg.vjp isa Integrals.ReverseDiffVJP - error("TODO") - end - end - - prob = Integrals.build_problem(cache) - # dp_prob = remake(prob, f = dfdp) # fails because we change iip - dp_prob = IntegralProblem(dfdp, prob.domain, prob.p; prob.kwargs...) - # the infinity transformation was already applied to f so we don't apply it to dfdp - dp_cache = init(dp_prob, - alg; - sensealg = sensealg, - cache.kwargs...) - - project_p = ProjectTo(p) - dp = project_p(solve!(dp_cache).u) - - lb, ub = domain - if lb isa Number - # TODO replace evaluation at endpoint (which anyone can do without Integrals.jl) - # with integration of dfdx uing the same quadrature - dlb = cache.f isa BatchIntegralFunction ? -batch_unwrap(_f([lb])) : -_f(lb) - dub = cache.f isa BatchIntegralFunction ? batch_unwrap(_f([ub])) : _f(ub) - return (NoTangent(), - NoTangent(), - NoTangent(), - NoTangent(), - Tangent{typeof(domain)}(dot(dlb, Δ), dot(dub, Δ)), - dp) - else - # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we - # can see from writing the multidimensional integral as an iterated integral - # alternatively we can use Stokes' theorem to replace the integral on the - # boundary with a volume integral of the flux of the integrand - # ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the - # dimensionality of the integral or the quadrature used (such as quadratures - # that don't evaluate points on the boundaries) and it could be generalized to - # other kinds of domains. The only question is to determine ω in terms of f and - # the deformation of the surface (e.g. consider integral over an ellipse and - # asking for the derivative of the result w.r.t. the semiaxes of the ellipse) - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp) - end - end - out, quadrature_adjoint -end - -batch_unwrap(x::AbstractArray) = dropdims(x; dims = ndims(x)) - Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution, ::Val{:u}) sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 7a21dea..c04e553 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -1,4 +1,4 @@ -using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity +using Integrals, Mooncake, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity using Cuba, Cubature using FastGaussQuadrature using Test @@ -84,6 +84,10 @@ Base.axes(::Scalar) = ScalarAxes() Base.iterate(::ScalarAxes) = nothing Base.reshape(A::AbstractArray, ::ScalarAxes) = Scalar(only(A)) +# Scalar struct defined around Real Numbers (test/derivative_tests.jl) +# Mooncake, like Zygote also treats 0-D data wrt to the type of datastructure. +Mooncake.rdata_type(::Type{Scalar{T}}) where {T<:Real} = Mooncake.rdata_type(T) + # here we assume f evaluated at scalar inputs gives a scalar output # p will be able to be a number after https://github.com/FluxML/Zygote.jl/pull/1489 # p will be able to be a 0-array after https://github.com/FluxML/Zygote.jl/pull/1491 @@ -144,6 +148,52 @@ do_tests = function (; f, scalarize, lb, ub, p, alg, abstol, reltol) return end +# Mooncake Sensealg testing helper function +do_tests_mooncake = function (; f, scalarize, lb, ub, p, alg, abstol, reltol) + testf = function (lb, ub, p) + prob = IntegralProblem(f, (lb, ub), p) + scalarize(solve(prob, alg; reltol, abstol, sensealg=Integrals.ReCallVJP{Integrals.MooncakeVJP}(Integrals.MooncakeVJP()))) + end + sol_fp = testf(lb, ub, p) + + # sensealg when non zygoet? + cache = Mooncake.prepare_gradient_cache(testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p) + forwpassval, gradients = Mooncake.value_and_gradient!!(cache, testf, lb, ub, p isa Number && f isa BatchIntegralFunction ? Scalar(p) : p) + + @test forwpassval == sol_fp + + f_lb = lb -> testf(lb, ub, p) + f_ub = ub -> testf(lb, ub, p) + + dlb = lb isa AbstractArray ? :gradient : :derivative + dub = ub isa AbstractArray ? :gradient : :derivative + + dlb2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dlb))(f_lb, lb) + dub2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dub))(f_ub, ub) + + if lb isa Number + @test gradients[2] ≈ dlb2 atol = abstol rtol = reltol + @test gradients[3] ≈ dub2 atol = abstol rtol = reltol + else # TODO: implement multivariate limit derivatives in MooncakeExt + @test_broken gradients[2] ≈ dlb2 atol = abstol rtol = reltol + @test_broken gradients[3] ≈ dub2 atol = abstol rtol = reltol + end + + f_p = p -> testf(lb, ub, p) + dp = p isa AbstractArray ? :gradient : :derivative + + dp2 = getproperty(FiniteDiff, Symbol(:finite_difference_, dp))(f_p, p) + dp3 = getproperty(ForwardDiff, dp)(f_p, p) + + @test dp2 ≈ dp3 atol = abstol rtol = reltol + + # test Mooncake for parameter p + @test gradients[4] ≈ dp2 atol = abstol rtol = reltol + @test dp2 ≈ dp3 atol = abstol rtol = reltol + + return +end + ### One Dimensional for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), (i, scalarize) in enumerate(scalarize_solution) @@ -152,6 +202,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "One-dimensional, scalar, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i do_tests(; f, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) end ## One-dimensional nout @@ -163,6 +214,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "One-dimensional, multivariate, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i nout do_tests(; f, scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; + f, scalarize, lb=1.0, ub=3.0, p=[2.0i for i in 1:nout], alg, abstol, reltol) end ### N-dimensional @@ -173,6 +226,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "Multi-dimensional, scalar, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i dim do_tests(; f, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f, scalarize, lb=ones(dim), ub=3ones(dim), p=2.0, alg, abstol, reltol) end ### N-dimensional nout @@ -185,6 +239,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "Multi-dimensional, multivariate, oop derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i dim nout do_tests(; f, scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; f, scalarize, lb=ones(dim), ub=3ones(dim), + p=[2.0i for i in 1:nout], alg, abstol, reltol) end ### One Dimensional @@ -347,7 +403,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), p = [2.0i for i in 1:nout], alg, abstol, reltol) end -@testset "ChangeOfVariables rrule" begin +@testset "ChangeOfVariables rrules" begin alg = QuadGKJL() # test a simple u-substitution of x = 2.7u + 1.3 talg = Integrals.ChangeOfVariables(alg) do f, domain @@ -358,18 +414,43 @@ end error("not implemented") end end - testf = (f, lb, ub, p, alg) -> begin + + testf = (f, lb, ub, p, alg, sensealg) -> begin prob = IntegralProblem(f, (lb, ub), p) - solve(prob, alg; abstol, reltol).u + solve(prob, alg; abstol, reltol, sensealg = sensealg).u end _testf = (x, p) -> x^2 * p lb, ub, p = 1.0, 5.0, 2.0 - sol = Zygote.withgradient((args...) -> testf(_testf, args..., alg), lb, ub, p) - tsol = Zygote.withgradient((args...) -> testf(_testf, args..., talg), lb, ub, p) - @test sol.val ≈ tsol.val - # Fundamental theorem of Calculus part 1 - @test sol.grad[1] ≈ tsol.grad[1] ≈ -_testf(lb, p) - @test sol.grad[2] ≈ tsol.grad[2] ≈ _testf(ub, p) - # This is to check ∂p - @test sol.grad[3] ≈ tsol.grad[3] -end + + @testset "Sensitivity using Zygote" begin + sensealg = Integrals.ReCallVJP{IntegralZygoteVJP}(IntegralZygoteVJP()) + sol = Zygote.withgradient((args...) -> testf(_testf, args..., alg), lb, ub, p, sensealg) + tsol = Zygote.withgradient((args...) -> testf(_testf, args..., talg), lb, ub, p, sensealg) + @test sol.val ≈ tsol.val + # Fundamental theorem of Calculus part 1 + @test sol.grad[1] ≈ tsol.grad[1] ≈ -_testf(lb, p) + @test sol.grad[2] ≈ tsol.grad[2] ≈ _testf(ub, p) + # This is to check ∂p + @test sol.grad[3] ≈ tsol.grad[3] + end + + @testset "Sensitivity using Mooncake" begin + sensealg = Integrals.ReCallVJP{Integrals.MooncakeVJP}(Integrals.MooncakeVJP()) + # anonymous function for cache creation and gradient evaluation call must be the same. + func = (args...) -> testf(_testf, args..., alg, sensealg) + cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, sensealg) + sol = Mooncake.value_and_gradient!!(cache, func, + lb, ub, p, sensealg) + + cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, sensealg) + tsol = Mooncake.value_and_gradient!!( + cache, func, lb, ub, p, sensealg) + + @test sol[1] ≈ tsol[1] + # Fundamental theorem of Calculus part 1 + @test sol[2][2] ≈ tsol[2][2] ≈ -_testf(lb, p) + @test sol[2][3] ≈ tsol[2][3] ≈ _testf(ub, p) + # To check ∂p + @test sol[2][4] ≈ tsol[2][4] + end +end \ No newline at end of file From cf02fce2874bfa2c0ef66756a6666aa14b63c987 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 12 Dec 2025 01:48:27 +0000 Subject: [PATCH 03/12] Cubature algorithms now work + compat entries. --- Project.toml | 1 + ext/IntegralsMooncakeExt.jl | 65 +++++++++++++++++++++++++++++++++---- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 4e3e3b2..20a3503 100644 --- a/Project.toml +++ b/Project.toml @@ -48,6 +48,7 @@ ForwardDiff = "0.10.36, 1" HCubature = "1.7" LinearAlgebra = "1.10" MCIntegration = "0.4.2" +Mooncake = "0.4.184" MonteCarloIntegration = "0.2" QuadGK = "2.11" Random = "1.10" diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 8526537..578d530 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -1,9 +1,8 @@ module IntegralsMooncakeExt using Mooncake using LinearAlgebra: dot -using Integrals -using SciMLBase, QuadGK -using Mooncake: @from_chainrules, @zero_derivative, @is_primitive, increment!!, increment_and_get_rdata!, MinimalCtx, rrule!!, NoFData, CoDual, primal, NoRData, zero_fcodual +using Integrals, SciMLBase, QuadGK +using Mooncake: @from_chainrules, @zero_derivative, @is_primitive, increment!!, increment_and_get_rdata!, MinimalCtx, rrule!!, NoFData, NoRData, CoDual, primal, NoRData, zero_fcodual using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem import ChainRulesCore import ChainRulesCore: Tangent, NoTangent, ProjectTo @@ -197,7 +196,7 @@ end function Mooncake.increment_and_get_rdata!( f::Tuple{Vector{Float64},Vector{Float64}}, - r::Mooncake.NoRData, + r::NoRData, t::Tangent{Any,Tuple{Vector{Float64},Vector{Float64}}}, ) Mooncake.increment!!(f[1], t[1]) @@ -215,8 +214,62 @@ function Mooncake.increment_and_get_rdata!( f::NoFData, r::T, t::Tangent{Any,@NamedTuple{u::T, resid::T, prob::Tangent{Any,@NamedTuple{f::NoTangent, domain::Tangent{Any,Tuple{T,T}}, p::T, kwargs::NoTangent}}, - alg::NoTangent, retcode::NoTangent, chi::NoTangent, stats::NoTangent}}) where T<:Number - return Mooncake.increment!!(t.u, r) + alg::NoTangent, retcode::NoTangent, chi::NoTangent, stats::NoTangent}}) where {T<:Base.IEEEFloat} + + # rdata component of t + r (u field) + return t.u + r +end + +function Mooncake.increment_and_get_rdata!( + f::Vector{T}, + r::NoRData, + t::Tangent{ + Any, + @NamedTuple{ + u::Vector{T}, + resid::Vector{T}, + prob::Tangent{ + Any, + @NamedTuple{ + f::NoTangent, + domain::Tangent{Any,Tuple{T,T}}, + p::Vector{Float64}, + kwargs::NoTangent, + } + }, + alg::NoTangent, + retcode::NoTangent, + chi::NoTangent, + stats::NoTangent, + } + }, +) where {T<:Base.IEEEFloat} + + f .+= t.u + # rdata component of t + r + return t.prob.domain +end + +function Mooncake.increment_and_get_rdata!( + f::Vector{T}, + r::NoRData, + t::Tangent{Any,@NamedTuple{u::Vector{T}, resid::Vector{T}, + prob::Tangent{Any,@NamedTuple{f::NoTangent, domain::Tangent{Any,Tuple{Vector{T},Vector{T}}}, + p::Vector{T}, kwargs::NoTangent}}, alg::NoTangent, retcode::NoTangent, + chi::NoTangent, stats::NoTangent}}) where {T<:Base.IEEEFloat} + + f .+= t.u + # rdata component of t + r + return NoRData() +end + +function Mooncake.increment_and_get_rdata!( + f::NoFData, + r::T, + t::Tangent{Any,@NamedTuple{u::T, resid::T, prob::Tangent{Any,@NamedTuple{f::NoTangent, domain::Tangent{Any,Tuple{Vector{T},Vector{T}}}, p::T, kwargs::NoTangent}}, alg::NoTangent, retcode::NoTangent, chi::NoTangent, stats::NoTangent}} +) where {T<:Base.IEEEFloat} + # rdata component of t + r (u field) + return r + t.u end end \ No newline at end of file From af711db1f483afe601f23523c0ce3ad95d38a9fb Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Fri, 12 Dec 2025 01:52:19 +0000 Subject: [PATCH 04/12] spell checks+minor format --- ext/IntegralsMooncakeExt.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 578d530..7a9ee7c 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -130,18 +130,18 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction dfdp_ = function (x, p) x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] - clos = p -> cache.f(x_, p) - cache_z = Mooncake.prepare_pullback_cache(clos, p) - z, grads = Mooncake.value_and_pullback!!(cache_z, [Δ.u], clos, p) + integralfunc_closure_p = p -> cache.f(x_, p) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) + z, grads = Mooncake.value_and_pullback!!(cache_z, [Δ.u], integralfunc_closure_p, p) return grads[2] end dfdp = IntegralFunction{false}(dfdp_, nothing) else dfdp_ = function (x, p) - clos = p -> cache.f(x, p) - cache_z = Mooncake.prepare_pullback_cache(clos, p) + integralfunc_closure_p = p -> cache.f(x, p) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) # Δ.u is integrand function's output sensitivity which we pass into Mooncake's pullback - z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, clos, p) + z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, integralfunc_closure_p, p) return grads[2] end dfdp = IntegralFunction{false}(dfdp_, nothing) @@ -161,7 +161,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal project_p = ProjectTo(p) dp = project_p(solve!(dp_cache).u) - + # Because Mooncake tangent structure vs Zygote, Chainrules, ReverseDiff du_adj = sensealg.vjp isa Integrals.MooncakeVJP ? Δ.u : Δ @@ -175,7 +175,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal NoTangent(), NoTangent(), NoTangent(), - Tangent{typeof(domain)}(dot(dlb, du_adj), dot(dub, du_adj)), + Tangent{typeof(domain)}(dot(dlb, du_adj), dot(dub, du_adj)), dp) else # we need to compute 2*length(lb) integrals on the faces of the hypercube, as we From ec808e8b234bd61bc8db25cff13b0fbb35f7cacd Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 14 Dec 2025 02:26:54 +0000 Subject: [PATCH 05/12] almost done. --- ext/IntegralsMooncakeExt.jl | 195 +++++++++++++++++++----------------- test/derivative_tests.jl | 8 ++ 2 files changed, 111 insertions(+), 92 deletions(-) diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 7a9ee7c..12ff507 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -2,7 +2,8 @@ module IntegralsMooncakeExt using Mooncake using LinearAlgebra: dot using Integrals, SciMLBase, QuadGK -using Mooncake: @from_chainrules, @zero_derivative, @is_primitive, increment!!, increment_and_get_rdata!, MinimalCtx, rrule!!, NoFData, NoRData, CoDual, primal, NoRData, zero_fcodual +using Mooncake: @from_chainrules, @is_primitive, increment!!, MinimalCtx, rrule!!, NoFData, NoRData, CoDual, primal, NoRData, zero_fcodual +import Mooncake: increment_and_get_rdata!, @zero_derivative using Integrals: AbstractIntegralMetaAlgorithm, IntegralProblem import ChainRulesCore import ChainRulesCore: Tangent, NoTangent, ProjectTo @@ -29,12 +30,12 @@ function Mooncake.rrule!!(::CoDual{Type{IntegralProblem{iip}}}, f::CoDual, domai data = Δ isa NoRData ? Δ : Δ.data ddomain = hasproperty(data, :domain) ? data.domain : NoRData() dp = hasproperty(data, :p) ? data.p : NoRData() - dkwargs = Δ isa NoRData ? Δ : data.kwargs + dkwargs = hasproperty(Δ, :kwargs) ? data.kwargs : NoRData() # domain is always a Tuple, so it always has NoFData # below conditional is in case p is an Array or similar if Mooncake.rdata_type(typeof(p_prim)) == NoRData() - p.dx .+= dp + Mooncake.increment!!(p.dx, dp) grad_p = NoRData() else grad_p = dp @@ -45,25 +46,12 @@ function Mooncake.rrule!!(::CoDual{Type{IntegralProblem{iip}}}, f::CoDual, domai return zero_fcodual(prob), IntegralProblem_iip_pullback end -@is_primitive MinimalCtx Tuple{typeof(SciMLBase.build_solution),IntegralProblem,Any,Any,Any} -function Mooncake.rrule!!(::CoDual{typeof(SciMLBase.build_solution)}, prob::CoDual{IntegralProblem}, alg::CoDual, u::CoDual, resid::CoDual; kwargs) - kwargs_fp = map(primal, kwargs) - function pb!!(Δ) - return NoRData(), - NoRData(), - NoRData(), - Δ, - NoRData() - end - return SciMLBase.build_solution(prob.primal, alg.primal, u.primal, resid.primal; kwargs_fp...), pb!! -end - # Mooncake does not need chainrule for evaluate! as it supports mutation. @from_chainrules MinimalCtx Tuple{Type{IntegralProblem},Any,Any,Any} true -@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.build_solution),IntegralProblem,Any,Any,Any} true @from_chainrules MinimalCtx Tuple{typeof(Integrals.u2t),Any,Any} true -@from_chainrules MinimalCtx Tuple{typeof(Integrals.__solvebp),Any,Any,Any,Any,Any} true +@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.build_solution),IntegralProblem,Any,Any,Any} true +@from_chainrules MinimalCtx Tuple{typeof(Integrals.__solvebp),Any,Any,Any,Any,Any} true function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, domain, p; kwargs...) @@ -125,26 +113,39 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal end end elseif sensealg.vjp isa Integrals.MooncakeVJP - _f = x -> cache.f(x, p) - if cache.f isa BatchIntegralFunction - # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction - dfdp_ = function (x, p) - x_ = x isa AbstractArray ? reshape(x, size(x)..., 1) : [x] - integralfunc_closure_p = p -> cache.f(x_, p) - cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) - z, grads = Mooncake.value_and_pullback!!(cache_z, [Δ.u], integralfunc_closure_p, p) - return grads[2] + # SOMETHINGS UP WITH DFDP FUNCTION prob.f it cant accept two ints and error. + if isinplace(cache) + if cache.f isa BatchIntegralFunction + error("TODO") + else + dx = similar(cache.f.integrand_prototype) + _f = x -> (cache.f(dx, x, p); dx) + dfdp_ = function (x, p) + # dx is modified inplace by dfdp/integralfunc_closure_p calls AND the Reverse pass tangent comes externally (from Δ). + # Therefore, Δ.u is Tangent passed to the pullback AND integralfunc_closure_p must always return dx as Output. + # i.e. (tangent(output) == Δ.u). Otherwise integralfunc_closure_p only outputs "nothing" and tangent(output) != Δ.u + integralfunc_closure_p = p -> (cache.f(dx, x, p); dx) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) + z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, integralfunc_closure_p, p) + return grads[2] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) end - dfdp = IntegralFunction{false}(dfdp_, nothing) else - dfdp_ = function (x, p) - integralfunc_closure_p = p -> cache.f(x, p) - cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) - # Δ.u is integrand function's output sensitivity which we pass into Mooncake's pullback - z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, integralfunc_closure_p, p) - return grads[2] + _f = x -> cache.f(x, p) + if cache.f isa BatchIntegralFunction + # TODO: let the user pass a batched jacobian so we can return a BatchIntegralFunction + error("TODO") + else + dfdp_ = function (x, p) + integralfunc_closure_p = p -> cache.f(x, p) + cache_z = Mooncake.prepare_pullback_cache(integralfunc_closure_p, p) + # Δ.u is integrand function's output sensitivity which we pass into Mooncake's pullback + z, grads = Mooncake.value_and_pullback!!(cache_z, Δ.u, integralfunc_closure_p, p) + return grads[2] + end + dfdp = IntegralFunction{false}(dfdp_, nothing) end - dfdp = IntegralFunction{false}(dfdp_, nothing) end elseif sensealg.vjp isa Integrals.ReverseDiffVJP error("TODO") @@ -194,82 +195,92 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal out, quadrature_adjoint end +# Internal Mooncake overloads to accomodate IntegralSolution etc. Struct's Tangent Types. +# Allows clear translation from ChainRules -> Mooncake's tangent. function Mooncake.increment_and_get_rdata!( - f::Tuple{Vector{Float64},Vector{Float64}}, - r::NoRData, - t::Tangent{Any,Tuple{Vector{Float64},Vector{Float64}}}, -) - Mooncake.increment!!(f[1], t[1]) - Mooncake.increment!!(f[2], t[2]) - return NoRData() + f::NoFData, r::Tuple{T,T}, t::Union{Tangent{Any,Tuple{T,T}},Tangent{Tuple{T,T},Tuple{T,T}}} +) where {T<:Base.IEEEFloat} + return r .+ t.backing end function Mooncake.increment_and_get_rdata!( - f::NoFData, r::Tuple{T,T}, t::Union{Tangent{Any,Tuple{T,T}},Tangent{Tuple{T,T},Tuple{T,T}}} + f::Tuple{Vector{T},Vector{T}}, + r::NoRData, + t::Tangent{Any,Tuple{Vector{T},Vector{T}}}, ) where {T<:Base.IEEEFloat} - return r .+ t.backing + Mooncake.increment!!(f, t.backing) + return NoRData() end +# sol.u & p are single scalar values, domain (lb,ub) is single/multi - variate. function Mooncake.increment_and_get_rdata!( f::NoFData, r::T, - t::Tangent{Any,@NamedTuple{u::T, resid::T, prob::Tangent{Any,@NamedTuple{f::NoTangent, domain::Tangent{Any,Tuple{T,T}}, p::T, kwargs::NoTangent}}, - alg::NoTangent, retcode::NoTangent, chi::NoTangent, stats::NoTangent}}) where {T<:Base.IEEEFloat} - + t::Union{ + Tangent{Any, + @NamedTuple{ + u::T, + resid::T, + prob::Tangent{Any, + @NamedTuple{ + f::S, + domain::Union{Tangent{Any,Tuple{T,T}},Tangent{Any,Tuple{Vector{T},Vector{T}}}}, + p::T, + kwargs::S + } + }, + alg::S, + retcode::S, + chi::S, + stats::S + } + } + }) where {T<:Base.IEEEFloat,S<:NoTangent} # rdata component of t + r (u field) - return t.u + r + return Mooncake.increment_and_get_rdata!(f, r, t.u) end +# sol.u is vector valued, p is vector valued, domain can be single/multi - variate +# resid can be single/vector valued, integrand_prototype field in prob for inplace integrals (iip true) function Mooncake.increment_and_get_rdata!( f::Vector{T}, r::NoRData, - t::Tangent{ - Any, - @NamedTuple{ - u::Vector{T}, - resid::Vector{T}, - prob::Tangent{ - Any, - @NamedTuple{ - f::NoTangent, - domain::Tangent{Any,Tuple{T,T}}, - p::Vector{Float64}, - kwargs::NoTangent, - } - }, - alg::NoTangent, - retcode::NoTangent, - chi::NoTangent, - stats::NoTangent, + t::Union{ + Tangent{ + Any, + @NamedTuple{ + u::Vector{T}, + resid::Union{T,Vector{T}}, + prob::Tangent{ + Any, + @NamedTuple{ + f::Union{S,Tangent{ + Any, + @NamedTuple{ + f::S, + integrand_prototype::Vector{T} + } + }}, + domain::Union{Tangent{Any,Tuple{T,T}},Tangent{Any,Tuple{Vector{T},Vector{T}}}}, + p::Union{T,Vector{T}}, + kwargs::S + } + }, + alg::S, + retcode::S, + chi::S, + stats::S + } } - }, -) where {T<:Base.IEEEFloat} - - f .+= t.u - # rdata component of t + r + } +) where {T<:Base.IEEEFloat,S<:NoTangent} + Mooncake.increment!!(f, t.u) + # rdata component(t) + r return t.prob.domain end -function Mooncake.increment_and_get_rdata!( - f::Vector{T}, - r::NoRData, - t::Tangent{Any,@NamedTuple{u::Vector{T}, resid::Vector{T}, - prob::Tangent{Any,@NamedTuple{f::NoTangent, domain::Tangent{Any,Tuple{Vector{T},Vector{T}}}, - p::Vector{T}, kwargs::NoTangent}}, alg::NoTangent, retcode::NoTangent, - chi::NoTangent, stats::NoTangent}}) where {T<:Base.IEEEFloat} - - f .+= t.u - # rdata component of t + r - return NoRData() -end - -function Mooncake.increment_and_get_rdata!( - f::NoFData, - r::T, - t::Tangent{Any,@NamedTuple{u::T, resid::T, prob::Tangent{Any,@NamedTuple{f::NoTangent, domain::Tangent{Any,Tuple{Vector{T},Vector{T}}}, p::T, kwargs::NoTangent}}, alg::NoTangent, retcode::NoTangent, chi::NoTangent, stats::NoTangent}} -) where {T<:Base.IEEEFloat} - # rdata component of t + r (u field) - return r + t.u +# cannot mutate NoRData() in place, therefore return as is. +function Mooncake.increment!!(::Mooncake.NoRData, y::Union{Tangent{Any,Tuple{T,T}},Tangent{Any,Tuple{Vector{T},Vector{T}}}}) where {T<:Base.IEEEFloat} + return Mooncake.NoRData() end - end \ No newline at end of file diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index c04e553..2ee2fd1 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -243,6 +243,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), p=[2.0i for i in 1:nout], alg, abstol, reltol) end +#### in place IntegralCache, IntegralFunction Tests ### One Dimensional for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), (i, scalarize) in enumerate(scalarize_solution) @@ -253,6 +254,7 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), @info "One-dimensional, scalar, iip derivative test" alg=nameof(typeof(alg)) integrand=j scalarize=i fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(1)) do_tests(; f = fiip, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; f=fiip, scalarize, lb = 1.0, ub = 3.0, p = 2.0, alg, abstol, reltol) end ## One-dimensional nout @@ -266,6 +268,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(nout)) do_tests(; f = fiip, scalarize, lb = 1.0, ub = 3.0, p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; f=fiip, scalarize, lb = 1.0, ub = 3.0, + p = [2.0i for i in 1:nout], alg, abstol, reltol) end ### N-dimensional @@ -279,6 +283,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(1)) do_tests(; f = fiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol) + do_tests_mooncake(; + f=fiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = 2.0, alg, abstol, reltol) end ### N-dimensional nout iip @@ -293,6 +299,8 @@ for (alg, req) in pairs(alg_req), (j, f) in enumerate(integrands), fiip = IntegralFunction((y, x, p) -> f_helper!(f, y, x, p), zeros(nout)) do_tests(; f = fiip, scalarize, lb = ones(dim), ub = 3ones(dim), p = [2.0i for i in 1:nout], alg, abstol, reltol) + do_tests_mooncake(; f = fiip, scalarize, lb = ones(dim), ub = 3ones(dim), + p = [2.0i for i in 1:nout], alg, abstol, reltol) end ### Batch, One Dimensional From ecea7e663bcd09d733cbaf1902e3714c0b138a4c Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 14 Dec 2025 02:29:01 +0000 Subject: [PATCH 06/12] spellcheck --- ext/IntegralsMooncakeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 12ff507..44ed3f0 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -195,7 +195,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal out, quadrature_adjoint end -# Internal Mooncake overloads to accomodate IntegralSolution etc. Struct's Tangent Types. +# Internal Mooncake overloads to accommodate IntegralSolution etc. Struct's Tangent Types. # Allows clear translation from ChainRules -> Mooncake's tangent. function Mooncake.increment_and_get_rdata!( f::NoFData, r::Tuple{T,T}, t::Union{Tangent{Any,Tuple{T,T}},Tangent{Tuple{T,T},Tuple{T,T}}} From 415518a4dbf7eb733974095c70a97c823c0e362c Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 14 Dec 2025 03:45:55 +0000 Subject: [PATCH 07/12] . --- ext/IntegralsMooncakeExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 44ed3f0..1c7359a 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -198,7 +198,7 @@ end # Internal Mooncake overloads to accommodate IntegralSolution etc. Struct's Tangent Types. # Allows clear translation from ChainRules -> Mooncake's tangent. function Mooncake.increment_and_get_rdata!( - f::NoFData, r::Tuple{T,T}, t::Union{Tangent{Any,Tuple{T,T}},Tangent{Tuple{T,T},Tuple{T,T}}} + f::NoFData, r::Tuple{T,T}, t::Tangent{Any,Union{Tuple{T,T},Tangent{Tuple{T,T},Tuple{T,T}}}} ) where {T<:Base.IEEEFloat} return r .+ t.backing end @@ -224,7 +224,7 @@ function Mooncake.increment_and_get_rdata!( prob::Tangent{Any, @NamedTuple{ f::S, - domain::Union{Tangent{Any,Tuple{T,T}},Tangent{Any,Tuple{Vector{T},Vector{T}}}}, + domain::Tangent{Any,Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}}, p::T, kwargs::S } @@ -261,7 +261,7 @@ function Mooncake.increment_and_get_rdata!( integrand_prototype::Vector{T} } }}, - domain::Union{Tangent{Any,Tuple{T,T}},Tangent{Any,Tuple{Vector{T},Vector{T}}}}, + domain::Tangent{Any,Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}}, p::Union{T,Vector{T}}, kwargs::S } @@ -280,7 +280,7 @@ function Mooncake.increment_and_get_rdata!( end # cannot mutate NoRData() in place, therefore return as is. -function Mooncake.increment!!(::Mooncake.NoRData, y::Union{Tangent{Any,Tuple{T,T}},Tangent{Any,Tuple{Vector{T},Vector{T}}}}) where {T<:Base.IEEEFloat} +function Mooncake.increment!!(::Mooncake.NoRData, y::Tangent{Any,Union{Tuple{T,T},Tangent{Any,Tuple{Vector{T},Vector{T}}}}}) where {T<:Base.IEEEFloat} return Mooncake.NoRData() end end \ No newline at end of file From 473d4d81e5b9444be004a2e3160949d1d21b6c58 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 14 Dec 2025 17:43:16 +0000 Subject: [PATCH 08/12] fixed union type in overloading --- ext/IntegralsMooncakeExt.jl | 84 ++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 1c7359a..518822d 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -198,7 +198,7 @@ end # Internal Mooncake overloads to accommodate IntegralSolution etc. Struct's Tangent Types. # Allows clear translation from ChainRules -> Mooncake's tangent. function Mooncake.increment_and_get_rdata!( - f::NoFData, r::Tuple{T,T}, t::Tangent{Any,Union{Tuple{T,T},Tangent{Tuple{T,T},Tuple{T,T}}}} + f::NoFData, r::Tuple{T,T}, t::Union{Tangent{Tuple{T,T},Tuple{T,T}},Tangent{Any,Tuple{Float64,Float64}}} ) where {T<:Base.IEEEFloat} return r .+ t.backing end @@ -216,32 +216,31 @@ end function Mooncake.increment_and_get_rdata!( f::NoFData, r::T, - t::Union{ - Tangent{Any, - @NamedTuple{ - u::T, - resid::T, - prob::Tangent{Any, - @NamedTuple{ - f::S, - domain::Tangent{Any,Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}}, - p::T, - kwargs::S - } - }, - alg::S, - retcode::S, - chi::S, - stats::S - } + t::Tangent{Any, + @NamedTuple{ + u::T, + resid::T, + prob::Tangent{Any, + @NamedTuple{ + f::NoTangent, + domain::Tangent{Any,Tuple{M,M}}, + p::P, + kwargs::NoTangent + } + }, + alg::NoTangent, + retcode::NoTangent, + chi::NoTangent, + stats::NoTangent } - }) where {T<:Base.IEEEFloat,S<:NoTangent} + } +) where {T<:Base.IEEEFloat,P<:Union{T,Vector{T}},M<:Union{T,Vector{T}}} # rdata component of t + r (u field) return Mooncake.increment_and_get_rdata!(f, r, t.u) end -# sol.u is vector valued, p is vector valued, domain can be single/multi - variate -# resid can be single/vector valued, integrand_prototype field in prob for inplace integrals (iip true) +# sol.u is vector valued, p is scalar/vector valued, domain can be single/multi - variate +# resid can be single/vector valued. For inplace integrals (iip true) : included integrand_prototype field in typeof{prob.f} function Mooncake.increment_and_get_rdata!( f::Vector{T}, r::NoRData, @@ -250,37 +249,44 @@ function Mooncake.increment_and_get_rdata!( Any, @NamedTuple{ u::Vector{T}, - resid::Union{T,Vector{T}}, + resid::R, prob::Tangent{ Any, @NamedTuple{ - f::Union{S,Tangent{ - Any, - @NamedTuple{ - f::S, - integrand_prototype::Vector{T} - } - }}, - domain::Tangent{Any,Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}}, - p::Union{T,Vector{T}}, - kwargs::S + f::F, + domain::Tangent{Any,M}, + p::P, + kwargs::NoTangent } }, - alg::S, - retcode::S, - chi::S, - stats::S + alg::NoTangent, + retcode::NoTangent, + chi::NoTangent, + stats::NoTangent + } + } + } +) where {T<:Base.IEEEFloat, + R<:Union{T,Vector{T}}, + P<:Union{T,Vector{T}}, + M<:Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}, + F<:Union{NoTangent, + Tangent{ + Any, + @NamedTuple{ + f::NoTangent, + integrand_prototype::Vector{T} } } } -) where {T<:Base.IEEEFloat,S<:NoTangent} +} Mooncake.increment!!(f, t.u) # rdata component(t) + r return t.prob.domain end # cannot mutate NoRData() in place, therefore return as is. -function Mooncake.increment!!(::Mooncake.NoRData, y::Tangent{Any,Union{Tuple{T,T},Tangent{Any,Tuple{Vector{T},Vector{T}}}}}) where {T<:Base.IEEEFloat} +function Mooncake.increment!!(::Mooncake.NoRData, y::Tangent{Any,Y}) where {T<:Base.IEEEFloat,Y<:Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}} return Mooncake.NoRData() end end \ No newline at end of file From 9b9257cbf3a29d6fe7bc69f1ca5927f1b0876bc8 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 14 Dec 2025 17:55:03 +0000 Subject: [PATCH 09/12] more dispatch stuff --- ext/IntegralsMooncakeExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 518822d..38abaf9 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -219,7 +219,7 @@ function Mooncake.increment_and_get_rdata!( t::Tangent{Any, @NamedTuple{ u::T, - resid::T, + resid::R, prob::Tangent{Any, @NamedTuple{ f::NoTangent, @@ -234,7 +234,7 @@ function Mooncake.increment_and_get_rdata!( stats::NoTangent } } -) where {T<:Base.IEEEFloat,P<:Union{T,Vector{T}},M<:Union{T,Vector{T}}} +) where {T<:Base.IEEEFloat,R<:Union{NoTangent,T},P<:Union{T,Vector{T}},M<:Union{T,Vector{T}}} # rdata component of t + r (u field) return Mooncake.increment_and_get_rdata!(f, r, t.u) end From 67b5791ba6a332f9c9f31d8331c5ab39bed472ee Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Sun, 14 Dec 2025 22:06:58 +0000 Subject: [PATCH 10/12] more dispatch, all tests. --- ext/IntegralsMooncakeExt.jl | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/ext/IntegralsMooncakeExt.jl b/ext/IntegralsMooncakeExt.jl index 38abaf9..4581458 100644 --- a/ext/IntegralsMooncakeExt.jl +++ b/ext/IntegralsMooncakeExt.jl @@ -228,13 +228,26 @@ function Mooncake.increment_and_get_rdata!( kwargs::NoTangent } }, - alg::NoTangent, + alg::A, retcode::NoTangent, chi::NoTangent, stats::NoTangent } } -) where {T<:Base.IEEEFloat,R<:Union{NoTangent,T},P<:Union{T,Vector{T}},M<:Union{T,Vector{T}}} +) where {T<:Base.IEEEFloat, + R<:Union{NoTangent,T}, + P<:Union{T,Vector{T}}, + M<:Union{T,Vector{T}}, + A<:Union{NoTangent, + Tangent{Any, + @NamedTuple{ + nodes::Vector{T}, + weights::Vector{T}, + subintervals::NoTangent + } + } + } +} # rdata component of t + r (u field) return Mooncake.increment_and_get_rdata!(f, r, t.u) end @@ -259,7 +272,7 @@ function Mooncake.increment_and_get_rdata!( kwargs::NoTangent } }, - alg::NoTangent, + alg::A, retcode::NoTangent, chi::NoTangent, stats::NoTangent @@ -267,7 +280,7 @@ function Mooncake.increment_and_get_rdata!( } } ) where {T<:Base.IEEEFloat, - R<:Union{T,Vector{T}}, + R<:Union{NoTangent,T,Vector{T}}, P<:Union{T,Vector{T}}, M<:Union{Tuple{T,T},Tuple{Vector{T},Vector{T}}}, F<:Union{NoTangent, @@ -278,6 +291,15 @@ function Mooncake.increment_and_get_rdata!( integrand_prototype::Vector{T} } } + }, + A<:Union{NoTangent, + Tangent{Any, + @NamedTuple{ + nodes::Vector{T}, + weights::Vector{T}, + subintervals::NoTangent + } + } } } Mooncake.increment!!(f, t.u) From e2f2b8c19fd1e4769defda1a088cdba3fd4061eb Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Mon, 15 Dec 2025 00:49:12 +0000 Subject: [PATCH 11/12] sensealg choice, GTG --- test/derivative_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 2ee2fd1..b2af20b 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -431,7 +431,7 @@ end lb, ub, p = 1.0, 5.0, 2.0 @testset "Sensitivity using Zygote" begin - sensealg = Integrals.ReCallVJP{IntegralZygoteVJP}(IntegralZygoteVJP()) + sensealg = Integrals.ReCallVJP(Integrals.ZygoteVJP()) sol = Zygote.withgradient((args...) -> testf(_testf, args..., alg), lb, ub, p, sensealg) tsol = Zygote.withgradient((args...) -> testf(_testf, args..., talg), lb, ub, p, sensealg) @test sol.val ≈ tsol.val @@ -443,7 +443,7 @@ end end @testset "Sensitivity using Mooncake" begin - sensealg = Integrals.ReCallVJP{Integrals.MooncakeVJP}(Integrals.MooncakeVJP()) + sensealg = Integrals.ReCallVJP(Integrals.MooncakeVJP()) # anonymous function for cache creation and gradient evaluation call must be the same. func = (args...) -> testf(_testf, args..., alg, sensealg) cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, sensealg) From 6b5aedcefe8358eeaea20e6bcf5805db571de3e1 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Mon, 15 Dec 2025 01:11:22 +0000 Subject: [PATCH 12/12] fin. --- test/derivative_tests.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index b2af20b..ad17e88 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -432,8 +432,8 @@ end @testset "Sensitivity using Zygote" begin sensealg = Integrals.ReCallVJP(Integrals.ZygoteVJP()) - sol = Zygote.withgradient((args...) -> testf(_testf, args..., alg), lb, ub, p, sensealg) - tsol = Zygote.withgradient((args...) -> testf(_testf, args..., talg), lb, ub, p, sensealg) + sol = Zygote.withgradient((args...) -> testf(_testf, args...), lb, ub, p, alg, sensealg) + tsol = Zygote.withgradient((args...) -> testf(_testf, args...), lb, ub, p, talg, sensealg) @test sol.val ≈ tsol.val # Fundamental theorem of Calculus part 1 @test sol.grad[1] ≈ tsol.grad[1] ≈ -_testf(lb, p) @@ -445,14 +445,14 @@ end @testset "Sensitivity using Mooncake" begin sensealg = Integrals.ReCallVJP(Integrals.MooncakeVJP()) # anonymous function for cache creation and gradient evaluation call must be the same. - func = (args...) -> testf(_testf, args..., alg, sensealg) - cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, sensealg) + func = (args...) -> testf(_testf, args...) + cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, alg, sensealg) sol = Mooncake.value_and_gradient!!(cache, func, - lb, ub, p, sensealg) + lb, ub, p, alg, sensealg) - cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, sensealg) + cache = Mooncake.prepare_gradient_cache(func, lb, ub, p, talg, sensealg) tsol = Mooncake.value_and_gradient!!( - cache, func, lb, ub, p, sensealg) + cache, func, lb, ub, p, talg, sensealg) @test sol[1] ≈ tsol[1] # Fundamental theorem of Calculus part 1