Lightweight NaN detection for Flux.jl models.
NaNTracker wraps each leaf layer of a model to check its forward inputs, forward outputs, and incoming gradients. At the first NaN it throws a `DomainError` that names the exact layer path.
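The wrapping idea can be sketched in plain Julia. This is illustrative only: `NaNGuard`, `check_nan`, `guard_chain` and `run_chain` are hypothetical names, not NaNTracker's API, and the "model" is a tuple of callables instead of a Flux `Chain`. The sketch covers only the forward checks. Checking incoming gradients needs a hook in the backward pass; one common way to get one in the Flux/Zygote ecosystem is a custom `ChainRulesCore.rrule` whose pullback inspects the cotangent.

```julia
# Hypothetical sketch: NaNGuard, check_nan, guard_chain and run_chain are
# illustrative names, not NaNTracker's API. Only forward passes are checked.

# Throw a DomainError naming the layer path if `x` contains a NaN.
function check_nan(x, path::String, stage::String)
    if any(isnan, x)
        throw(DomainError(x, "NaN in $stage of layer $path"))
    end
    return x
end

# A leaf layer (any callable) paired with its path in the model.
struct NaNGuard{L}
    layer::L
    path::String
end

# Check the input, run the wrapped layer, then check its output.
function (g::NaNGuard)(x)
    check_nan(x, g.path, "forward input")
    return check_nan(g.layer(x), g.path, "forward output")
end

# Wrap each layer of a tuple-based "chain", giving paths like layers[2].
guard_chain(layers::Tuple) =
    ntuple(i -> NaNGuard(layers[i], "layers[$i]"), length(layers))

# Apply the guarded layers in order, like a minimal Chain.
run_chain(guarded, x) = foldl((h, g) -> g(h), guarded; init = x)

# Usage: the second layer computes 0/0, so its output check fails.
guarded = guard_chain((x -> 2 .* x, x -> (x .- x) ./ (x .- x), x -> x .+ 1))
try
    run_chain(guarded, [1.0, 2.0])
catch err
    println(err.msg)   # NaN in forward output of layer layers[2]
end
```

Because each guard carries its own path string, the first failing check names the offending layer directly, rather than the NaN surfacing several layers later as a NaN loss.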
```julia
using NaNTracker, Flux

model = Chain(Dense(10 => 20, relu), Dense(20 => 5))

# Wrap: every forward and backward pass is now checked for NaN
tracked = nantrack(model)

x = randn(Float32, 10, 8)
loss, grads = Flux.withgradient(tracked) do m
    sum(m(x))
end

# Unwrap when done debugging
clean = nanuntrack(tracked)
```

Enable stats tracking to record `norm` and `maxabs` at every checked layer:
```julia
enable_stats!()
# ... training step ...
dump_stats(path_contains="attention")  # show only attention layers
clear_stats!()                         # reset for next step
disable_stats!()                       # zero overhead when off
```

When a NaN is detected with stats enabled, the recent trajectory is printed automatically.
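One plausible way to keep such stats is a bounded per-layer history of `(norm, maxabs)` pairs. The sketch below assumes that design; `stats_on!`, `stats_off!`, `record!`, `print_stats`, `reset_stats!`, `STATS` and `HISTORY_LEN` are hypothetical names, not NaNTracker's internals. `record!` returns immediately when stats are off, so the disabled path costs a single branch, and `path_contains` filters by substring. A NaN check could print the failing layer's stored history before throwing, which would give the kind of recent trajectory described above.

```julia
# Hypothetical sketch of per-layer stats recording; none of these names are
# NaNTracker's API. Uses only Base and the LinearAlgebra stdlib.
using LinearAlgebra: norm

const STATS_ENABLED = Ref(false)   # global switch, off by default
const STATS = Dict{String,Vector{Tuple{Float64,Float64}}}()  # path => [(norm, maxabs), ...]
const HISTORY_LEN = 8              # how many recent entries to keep per layer

stats_on!()    = (STATS_ENABLED[] = true; nothing)
stats_off!()   = (STATS_ENABLED[] = false; nothing)
reset_stats!() = (empty!(STATS); nothing)

# Record (norm, maxabs) of `x` for one checked layer; no-op when stats are off.
function record!(path::String, x::AbstractArray)
    STATS_ENABLED[] || return nothing
    hist = get!(() -> Tuple{Float64,Float64}[], STATS, path)
    push!(hist, (Float64(norm(x)), Float64(maximum(abs, x))))
    length(hist) > HISTORY_LEN && popfirst!(hist)   # drop the oldest entry
    return nothing
end

# Print the stored history, keeping only paths that contain `path_contains`.
function print_stats(io::IO = stdout; path_contains::AbstractString = "")
    for path in sort!(collect(keys(STATS)))
        occursin(path_contains, path) || continue
        for (i, (n, m)) in enumerate(STATS[path])
            println(io, path, " #", i, ": norm=", n, " maxabs=", m)
        end
    end
    return nothing
end

# Usage
stats_on!()
record!("encoder.attention.q", randn(4, 4))
record!("encoder.mlp", randn(4))
print_stats(path_contains = "attention")   # only encoder.attention.q is printed
reset_stats!()
stats_off!()
```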