diff --git a/src/config.jl b/src/config.jl index 3c6c97e3..f7b74c01 100644 --- a/src/config.jl +++ b/src/config.jl @@ -20,9 +20,43 @@ end Tag(::Nothing, ::Type{V}) where {V} = nothing +# A tag `T1` provably nests a tag `T2` when `T2` appears inside `T1`'s type +# structure: in the seeded value type `V1`, or captured by the function type `F1` +# (closure/struct fields and type parameters). In that case `T1` was necessarily +# created while `T2`'s derivative was already in progress, so `Dual{T1}` must be +# composed outside `Dual{T2}` — independent of `tagcount`, whose values can be +# baked in an arbitrary order by precompilation (see #714). +function _containstag(@nospecialize(T), @nospecialize(target), seen::Base.IdSet{Any}, depth::Int) + T === target && return true + depth <= 0 && return false + if T isa Union + return _containstag(T.a, target, seen, depth - 1) || + _containstag(T.b, target, seen, depth - 1) + elseif T isa UnionAll + return _containstag(Base.unwrap_unionall(T), target, seen, depth - 1) + end + T isa DataType || return false + T in seen && return false + push!(seen, T) + for p in T.parameters + p isa Type && _containstag(p, target, seen, depth - 1) && return true + end + if isconcretetype(T) && isstructtype(T) + for ft in fieldtypes(T) + ft isa Type && _containstag(ft, target, seen, depth - 1) && return true + end + end + return false +end + +@generated function containstag(::Type{T1}, ::Type{T2}) where {T1,T2} + return _containstag(T1, T2, Base.IdSet{Any}(), 32) ? :(true) : :(false) +end @inline function ≺(::Type{Tag{F1,V1}}, ::Type{Tag{F2,V2}}) where {F1,V1,F2,V2} - tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2}) + containstag(Tag{F1,V1}, Tag{F2,V2}) && return false + containstag(Tag{F2,V2}, Tag{F1,V1}) && return true + return tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2}) end struct InvalidTagException{E,O} <: Exception diff --git a/test/ConfusionTest.jl b/test/ConfusionTest.jl index 13c62ae9..13949451 100644 --- a/test/ConfusionTest.jl +++ b/test/ConfusionTest.jl @@ -72,5 +72,56 @@ end end end == 0.0 +# Nested differentiation must not depend on tagcount instantiation order (#714) # +#-------------------------------------------------------------------------------# + +# containstag: nesting through the seeded value type V +struct TagOrderOuterMarker end +struct TagOrderInnerMarker end +let Tag = ForwardDiff.Tag, Dual = ForwardDiff.Dual + Touter = Tag{TagOrderOuterMarker, Float64} + Tinner = Tag{TagOrderInnerMarker, Dual{Touter, Float64, 1}} + @test ForwardDiff.containstag(Tinner, Touter) + @test !ForwardDiff.containstag(Touter, Tinner) + # bake tagcounts in inverted order: containment must win regardless + ForwardDiff.tagcount(Tinner) + ForwardDiff.tagcount(Touter) + @test ForwardDiff.:(≺)(Touter, Tinner) + @test !ForwardDiff.:(≺)(Tinner, Touter) +end + +# Second derivative with tagcount baked in inverted order, as precompilation can +# do: the inner tag nests the outer through V, so ordering must not consult +# tagcount at all. +struct TagOrderInnerV end +(::TagOrderInnerV)(y) = y^3 +struct TagOrderOuterV end +(::TagOrderOuterV)(x) = ForwardDiff.derivative(TagOrderInnerV(), x) +ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderInnerV, ForwardDiff.Dual{ForwardDiff.Tag{TagOrderOuterV, Float64}, Float64, 1}}) +ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderOuterV, Float64}) +@test ForwardDiff.derivative(TagOrderOuterV(), 2.0) ≈ 12.0 + +# Same with the outer perturbation entering through a capture in F (both tags +# have V === Float64): nesting is only visible through the callable's fields. +struct TagOrderOuterF end +struct TagOrderInnerF + x_dual::ForwardDiff.Dual{ForwardDiff.Tag{TagOrderOuterF, Float64}, Float64, 1} +end +(c::TagOrderInnerF)(y) = sin(c.x_dual * y) +(::TagOrderOuterF)(x_dual::ForwardDiff.Dual{ForwardDiff.Tag{TagOrderOuterF, Float64}, Float64, 1}) = + ForwardDiff.derivative(TagOrderInnerF(x_dual), 1.0) +ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderInnerF, Float64}) +ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderOuterF, Float64}) +@test ForwardDiff.derivative(TagOrderOuterF(), 0.5) ≈ cos(0.5) - 0.5 * sin(0.5) + +# Three-level nesting where the innermost derivative is seeded with a plain +# Float64 while the outer perturbations enter through closure captures. A +# depth-only fast path mis-orders this case; it must keep working. +let + inner_deriv(d) = ForwardDiff.derivative(y -> y^2 * d, 1.0) + middle_grad(v) = ForwardDiff.gradient(u -> sum(inner_deriv(ui) * ui for ui in u), v) + outer_fn(x) = sum(middle_grad([x, 2x])) + @test ForwardDiff.derivative(outer_fn, 0.5) ≈ 12.0 +end end # module