Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,43 @@ end

Tag(::Nothing, ::Type{V}) where {V} = nothing

# A tag `T1` provably nests a tag `T2` when `T2` appears inside `T1`'s type
# structure: in the seeded value type `V1`, or captured by the function type `F1`
# (closure/struct fields and type parameters). In that case `T1` was necessarily
# created while `T2`'s derivative was already in progress, so `Dual{T1}` must be
# composed outside `Dual{T2}` — independent of `tagcount`, whose values can be
# baked in an arbitrary order by precompilation (see #714).
function _containstag(@nospecialize(T), @nospecialize(target), seen::Base.IdSet{Any}, depth::Int)
T === target && return true
depth <= 0 && return false
if T isa Union
return _containstag(T.a, target, seen, depth - 1) ||
_containstag(T.b, target, seen, depth - 1)
elseif T isa UnionAll
return _containstag(Base.unwrap_unionall(T), target, seen, depth - 1)
end
T isa DataType || return false
T in seen && return false
push!(seen, T)
for p in T.parameters
p isa Type && _containstag(p, target, seen, depth - 1) && return true
end
if isconcretetype(T) && isstructtype(T)
for ft in fieldtypes(T)
ft isa Type && _containstag(ft, target, seen, depth - 1) && return true
end
end
return false
end

@generated function containstag(::Type{T1}, ::Type{T2}) where {T1,T2}
return _containstag(T1, T2, Base.IdSet{Any}(), 32) ? :(true) : :(false)
end

@inline function ≺(::Type{Tag{F1,V1}}, ::Type{Tag{F2,V2}}) where {F1,V1,F2,V2}
tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2})
containstag(Tag{F1,V1}, Tag{F2,V2}) && return false
containstag(Tag{F2,V2}, Tag{F1,V1}) && return true
return tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2})
end

struct InvalidTagException{E,O} <: Exception
Expand Down
51 changes: 51 additions & 0 deletions test/ConfusionTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,56 @@ end
end
end == 0.0

# Nested differentiation must not depend on tagcount instantiation order (#714) #
#-------------------------------------------------------------------------------#

# containstag: nesting through the seeded value type V
struct TagOrderOuterMarker end
struct TagOrderInnerMarker end
let Tag = ForwardDiff.Tag, Dual = ForwardDiff.Dual
Touter = Tag{TagOrderOuterMarker, Float64}
Tinner = Tag{TagOrderInnerMarker, Dual{Touter, Float64, 1}}
@test ForwardDiff.containstag(Tinner, Touter)
@test !ForwardDiff.containstag(Touter, Tinner)
# bake tagcounts in inverted order: containment must win regardless
ForwardDiff.tagcount(Tinner)
ForwardDiff.tagcount(Touter)
@test ForwardDiff.:(≺)(Touter, Tinner)
@test !ForwardDiff.:(≺)(Tinner, Touter)
end

# Second derivative with tagcount baked in inverted order, as precompilation can
# do: the inner tag nests the outer through V, so ordering must not consult
# tagcount at all.
struct TagOrderInnerV end
(::TagOrderInnerV)(y) = y^3
struct TagOrderOuterV end
(::TagOrderOuterV)(x) = ForwardDiff.derivative(TagOrderInnerV(), x)
ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderInnerV, ForwardDiff.Dual{ForwardDiff.Tag{TagOrderOuterV, Float64}, Float64, 1}})
ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderOuterV, Float64})
@test ForwardDiff.derivative(TagOrderOuterV(), 2.0) ≈ 12.0

# Same with the outer perturbation entering through a capture in F (both tags
# have V === Float64): nesting is only visible through the callable's fields.
struct TagOrderOuterF end
struct TagOrderInnerF
x_dual::ForwardDiff.Dual{ForwardDiff.Tag{TagOrderOuterF, Float64}, Float64, 1}
end
(c::TagOrderInnerF)(y) = sin(c.x_dual * y)
(::TagOrderOuterF)(x_dual::ForwardDiff.Dual{ForwardDiff.Tag{TagOrderOuterF, Float64}, Float64, 1}) =
ForwardDiff.derivative(TagOrderInnerF(x_dual), 1.0)
ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderInnerF, Float64})
ForwardDiff.tagcount(ForwardDiff.Tag{TagOrderOuterF, Float64})
@test ForwardDiff.derivative(TagOrderOuterF(), 0.5) ≈ cos(0.5) - 0.5 * sin(0.5)

# Three-level nesting where the innermost derivative is seeded with a plain
# Float64 while the outer perturbations enter through closure captures. A
# depth-only fast path mis-orders this case; it must keep working.
let
inner_deriv(d) = ForwardDiff.derivative(y -> y^2 * d, 1.0)
middle_grad(v) = ForwardDiff.gradient(u -> sum(inner_deriv(ui) * ui for ui in u), v)
outer_fn(x) = sum(middle_grad([x, 2x]))
@test ForwardDiff.derivative(outer_fn, 0.5) ≈ 12.0
end

end # module
Loading