Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
name = "FunctionProperties"
uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
version = "0.1.4"
authors = ["SciML"]
version = "0.1.3"

[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"

[compat]
Cassette = "0.3.12"
ComponentArrays = "0.15"
DiffRules = "1.15"
Random = "1.10"
SafeTestsets = "0.1"
Test = "1.10"
Expand Down
170 changes: 35 additions & 135 deletions src/FunctionProperties.jl
Original file line number Diff line number Diff line change
@@ -1,160 +1,60 @@
module FunctionProperties

using Cassette, DiffRules
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot
using Core: GotoIfNot

const printbranch = false

Cassette.@context HasBranchingCtx

function Cassette.overdub(ctx::HasBranchingCtx, f, args...)
if Cassette.canrecurse(ctx, f, args...)
return Cassette.recurse(ctx, f, args...)
else
return Cassette.fallback(ctx, f, args...)
end
end
"""
is_leaf(f, args...) -> Bool

for (mod, f, n) in DiffRules.diffrules(; filter_modules = nothing)
if !(isdefined(@__MODULE__, mod) && isdefined(getfield(@__MODULE__, mod), f))
continue # Skip rules for methods not defined in the current scope
end
@eval function Cassette.overdub(
::HasBranchingCtx, f::Core.Typeof($mod.$f),
x::Vararg{Any, $n}
)
return f(x...)
end
end
Override this to exempt a function from `hasbranching` analysis.
Return `true` to treat `f` as branch-free regardless of its implementation.

function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection)
ir = reflection.code_info

if any(x -> isa(x, GotoIfNot), ir.code)
printbranch && println("GotoIfNot detected in $(reflection.method)\nir = $ir\n")
Cassette.insert_statements!(
ir.code, ir.codelocs,
(stmt, i) -> i == 1 ? 3 : nothing,
(
stmt,
i,
) -> Any[
Expr(
:call,
Expr(
:nooverdub,
GlobalRef(Base, :getfield)
),
Expr(:contextslot),
QuoteNode(:metadata)
),
Expr(
:call,
Expr(
:nooverdub,
GlobalRef(Base, :setindex!)
),
SSAValue(1), true,
QuoteNode(:has_branching)
),
stmt,
]
)
Cassette.insert_statements!(
ir.code, ir.codelocs,
(stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing,
(
stmt,
i,
) -> begin
callstmt = Meta.isexpr(stmt, :(=)) ? stmt.args[2] :
stmt
Meta.isexpr(stmt, :call) ||
Meta.isexpr(stmt, :invoke) || return Any[stmt]
callstmt = Expr(
callstmt.head,
Expr(:nooverdub, callstmt.args[1]),
callstmt.args[2:end]...
)
return Any[
Meta.isexpr(stmt, :(=)) ?
Expr(:(=), stmt.args[1], callstmt) :
callstmt,
]
end
)
end
return ir
end
## Example

const pass = Cassette.@pass _pass
```julia
FunctionProperties.is_leaf(::typeof(my_fn)) = true
```
"""
is_leaf(f, args...) = false

"""
hasbranching(f, x...)

Checks whether the function `f` has branches (if statements) that are dependent on the value x
that would be taken in a tracing system, such as during AD tracing by a package like ReverseDiff.jl.
Checks whether the function `f` has branches (if statements) that are dependent on the
value `x` that would be taken in a tracing system, such as during AD tracing by a package
like ReverseDiff.jl.

## Arguments:
## Arguments

* `f`: the function to inspect
* `x`: test arguments for the inspection. These values do not need to be the values that
would be used in the actual calls to the function but instead prototype values which
match the types that would be used in the actual function call. This is used to trace to
the correct internal dispatches.
- `f`: the function to inspect.
- `x`: test arguments. These values don't need to match the actual call values, but their
*types* must match — they are used to select the right method specialization.

## Outputs:
## Outputs

Boolean for whether the function has branches.
Boolean for whether the function's immediate IR contains a conditional branch (`GotoIfNot`).

## Customizing and Removing Dispatches from the Checks
## Customizing and Removing Functions from the Checks

Some internal functions of a package may cause false positives because a branch may be known to
resolve at compile time. If this is known, then you can add a dispatch to opt that function out
of the analysis via:
Some functions may produce false positives because their internal branches are compile-time
constants. Override `FunctionProperties.is_leaf` to opt them out:

```julia
function FunctionProperties.Cassette.overdub(::FunctionProperties.HasBranchingCtx, ::typeof(f), x...)
f(x...)
end
FunctionProperties.is_leaf(::typeof(my_fn)) = true
```
"""
function hasbranching(f, x...)
metadata = Dict(:has_branching => false)
Cassette.overdub(Cassette.disablehooks(HasBranchingCtx(; pass, metadata)), f, x...)
return metadata[:has_branching]
end

Cassette.overdub(::HasBranchingCtx, ::typeof(+), x...) = +(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(*), x...) = *(x...)
function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.materialize), x...)
return Base.materialize(x...)
end
function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.literal_pow), x...)
return Base.literal_pow(x...)
end
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.getindex), x...) = Base.getindex(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.setindex!), x...) = Base.setindex!(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Core.Typeof), x...) = Core.Typeof(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.vec), x...) = Base.vec(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.vect), x...) = Base.vect(x...)

Cassette.overdub(::HasBranchingCtx, ::typeof(Base.vcat), x...) = Base.vcat(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.hcat), x...) = Base.hcat(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.hvcat), x...) = Base.hvcat(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.cat), x...) = Base.cat(x...)
Cassette.overdub(::HasBranchingCtx, ::typeof(Base.stack), x...) = Base.stack(x...)

function Cassette.overdub(::HasBranchingCtx, ::typeof(Base.Broadcast.broadcasted), x...)
return Base.Broadcast.broadcasted(x...)
end
function Cassette.overdub(
::HasBranchingCtx, ::Type{Base.OneTo{T}},
stop
) where {T <: Integer}
return Base.OneTo{T}(stop)
is_leaf(f, x...) && return false
argtypes = Tuple{Core.Typeof.(x)...}
results = try
code_typed(f, argtypes; optimize = false)
catch
return false
end
isempty(results) && return false
ci = first(results)[1]
return any(isa(s, GotoIfNot) for s in ci.code)
end

export hasbranching
export hasbranching, is_leaf

end
170 changes: 83 additions & 87 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,92 +13,88 @@ end

if GROUP in ("All", "Core")

@test hasbranching(1, 2) do x, y
(x < 0 ? -x : x) + exp(y)
end

@test !hasbranching(1, 2) do x, y
ifelse(x < 0, -x, x) + exp(y)
end

# Test overloading

f_branch() = true ? 1 : 0
@test FunctionProperties.hasbranching(f_branch)
function FunctionProperties.Cassette.overdub(
::FunctionProperties.HasBranchingCtx, ::typeof(f_branch), x...
)
return f_branch(x...)
end
@test !FunctionProperties.hasbranching(f_branch)

# Test simple mutating functions
function f(dx, x)
return @inbounds dx[1] = x[1]
end
x = zeros(1)
dx = zeros(1)
@test !FunctionProperties.hasbranching(f, dx, x)

# Test broadcast
function f(x)
return cos.(x .+ x .* x)
end
x = [1.0]
@test !FunctionProperties.hasbranching(f, x)

# Neural networks
#
# The relevant scenario is a neural-network-shaped ODE right-hand side (SciML/SciMLSensitivity.jl#997):
# `hasbranching` must report it as branch-free so a tracing AD like ReverseDiff can compile a tape.
# The forward pass is expressed here as explicit affine transforms plus broadcast activations, which
# is the value flow `hasbranching` actually inspects. We deliberately do not trace a real Lux layer:
# modern Lux layer dispatch routes through device-detection / type-introspection helpers that contain
# genuine (but value-independent, compile-time) `GotoIfNot` branches, which this syntactic IR scan
# cannot distinguish from value-dependent branches (SciML/FunctionProperties.jl#46).
rng = Random.default_rng()
W = randn(rng, Float32, 1, 1)
b = randn(rng, Float32, 1)
p = ComponentArray(; weight = W, bias = b)
t = [0.0]

function f(x, ps)
return ps.weight * x
end
@test !FunctionProperties.hasbranching(f, t, p)

function f(x, ps)
return x .+ x
end
@test !FunctionProperties.hasbranching(f, t, p)

# Affine transform followed by a broadcast activation (the original `apply_activation` intent).
function f2(x, ps)
return identity.(ps.weight * x .+ vec(ps.bias))
end
@test !FunctionProperties.hasbranching(f2, t, p)

# A multi-layer perceptron forward pass built from broadcast `tanh` activations.
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)
W1 = randn(rng, Float32, 32, 2)
b1 = randn(rng, Float32, 32)
W2 = randn(rng, Float32, 32, 32)
b2 = randn(rng, Float32, 32)
W3 = randn(rng, Float32, 1, 32)
b3 = randn(rng, Float32, 1)
p = ComponentArray(; W1, b1, W2, b2, W3, b3)
θ, ax = getdata(p), getaxes(p)

ann(x, p) = p.W3 * tanh.(p.W2 * tanh.(p.W1 * x .+ p.b1) .+ p.b2) .+ p.b3

function dxdt_(dx, x, p, t)
x1, x2 = x
dx[1] = x[2] + first(ann(x, p))
return dx[2] = first(ann([t, t], p))
end
x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
@test !FunctionProperties.hasbranching(dxdt_, copy(x0), x0, p, tspan[1])
@test hasbranching(1, 2) do x, y
(x < 0 ? -x : x) + exp(y)
end

@test !hasbranching(1, 2) do x, y
ifelse(x < 0, -x, x) + exp(y)
end

# Test overloading via is_leaf

f_branch() = true ? 1 : 0
@test FunctionProperties.hasbranching(f_branch)
FunctionProperties.is_leaf(::typeof(f_branch)) = true
@test !FunctionProperties.hasbranching(f_branch)

# Test simple mutating functions
function f(dx, x)
return @inbounds dx[1] = x[1]
end
x = zeros(1)
dx = zeros(1)
@test !FunctionProperties.hasbranching(f, dx, x)

# Test broadcast
function f(x)
return cos.(x .+ x .* x)
end
x = [1.0]
@test !FunctionProperties.hasbranching(f, x)

# Neural networks
#
# The relevant scenario is a neural-network-shaped ODE right-hand side (SciML/SciMLSensitivity.jl#997):
# `hasbranching` must report it as branch-free so a tracing AD like ReverseDiff can compile a tape.
# The forward pass is expressed here as explicit affine transforms plus broadcast activations, which
# is the value flow `hasbranching` actually inspects. We deliberately do not trace a real Lux layer:
# modern Lux layer dispatch routes through device-detection / type-introspection helpers that contain
# genuine (but value-independent, compile-time) `GotoIfNot` branches, which this syntactic IR scan
# cannot distinguish from value-dependent branches (SciML/FunctionProperties.jl#46).
rng = Random.default_rng()
W = randn(rng, Float32, 1, 1)
b = randn(rng, Float32, 1)
p = ComponentArray(; weight = W, bias = b)
t = [0.0]

function f(x, ps)
return ps.weight * x
end
@test !FunctionProperties.hasbranching(f, t, p)

function f(x, ps)
return x .+ x
end
@test !FunctionProperties.hasbranching(f, t, p)

# Affine transform followed by a broadcast activation (the original `apply_activation` intent).
function f2(x, ps)
return identity.(ps.weight * x .+ vec(ps.bias))
end
@test !FunctionProperties.hasbranching(f2, t, p)

# A multi-layer perceptron forward pass built from broadcast `tanh` activations.
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)
W1 = randn(rng, Float32, 32, 2)
b1 = randn(rng, Float32, 32)
W2 = randn(rng, Float32, 32, 32)
b2 = randn(rng, Float32, 32)
W3 = randn(rng, Float32, 1, 32)
b3 = randn(rng, Float32, 1)
p = ComponentArray(; W1, b1, W2, b2, W3, b3)
θ, ax = getdata(p), getaxes(p)

ann(x, p) = p.W3 * tanh.(p.W2 * tanh.(p.W1 * x .+ p.b1) .+ p.b2) .+ p.b3

function dxdt_(dx, x, p, t)
x1, x2 = x
dx[1] = x[2] + first(ann(x, p))
return dx[2] = first(ann([t, t], p))
end
x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
@test !FunctionProperties.hasbranching(dxdt_, copy(x0), x0, p, tspan[1])

end
Loading