mashu/NaNTracker.jl

NaNTracker.jl

License: MIT

Lightweight NaN detection for Flux.jl models. `nantrack` wraps leaf layers to check forward inputs, forward outputs, and incoming gradients, and throws a `DomainError` naming the exact layer path at the first NaN.
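To illustrate the idea of a checking wrapper, here is a minimal sketch in plain Julia. It is not NaNTracker.jl's implementation: `Checked` and `check_nan` are hypothetical names, the "layers" are ordinary functions rather than Flux layers, and only the forward pass is covered (checking incoming gradients would additionally need a hook into the AD system, which is omitted here).

```julia
# Hypothetical helper (not part of NaNTracker.jl): throw if `x` contains a NaN.
function check_nan(x::AbstractArray, path::AbstractString, stage::AbstractString)
    if any(isnan, x)
        throw(DomainError(x, "NaN detected in $stage at $path"))
    end
    return x
end

# Hypothetical wrapper: remembers where the wrapped callable sits in the model.
struct Checked{L}
    layer::L
    path::String
end

# Calling the wrapper checks the input, runs the layer, then checks the output.
function (c::Checked)(x::AbstractArray)
    check_nan(x, c.path, "forward input")
    y = c.layer(x)
    return check_nan(y, c.path, "forward output")
end

double  = Checked(x -> 2 .* x, "layers[1]")
selfdiv = Checked(x -> x ./ x, "layers[2]")  # 0.0 / 0.0 yields NaN

double([1.0, 2.0])        # passes both checks
try
    selfdiv([1.0, 0.0])   # output contains NaN, so the wrapper throws
catch err
    println(err.msg)
end
```

Carrying the path inside the wrapper is what lets the error message point at a specific layer instead of only reporting that a NaN appeared somewhere in the model.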

Quick start

using NaNTracker, Flux

model = Chain(Dense(10 => 20, relu), Dense(20 => 5))

# Wrap — checks every forward and backward pass for NaN
tracked = nantrack(model)

x = randn(Float32, 10, 8)
loss, grads = Flux.withgradient(tracked) do m
    sum(m(x))
end

# Unwrap when done debugging
clean = nanuntrack(tracked)

Diagnosing gradient explosions

Enable stats tracking to record norm and maxabs at every checked layer:

enable_stats!()
# ... training step ...
dump_stats(path_contains="attention")  # show only attention layers
clear_stats!()                          # reset for next step
disable_stats!()                        # zero overhead when off

When a NaN is detected with stats enabled, the recent trajectory is printed automatically.
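For intuition about the two recorded quantities, the sketch below computes a norm and a maximum absolute value for a single array using only the standard library. `layer_stats` is a hypothetical helper, not a function exported by NaNTracker.jl, and the package may compute or store these values differently.

```julia
using LinearAlgebra: norm

# Hypothetical helper (not part of NaNTracker.jl): summarize one array
# by its Euclidean norm and its largest absolute entry.
layer_stats(x::AbstractArray) = (norm = norm(x), maxabs = maximum(abs, x))

healthy   = layer_stats(Float32[0.5, -0.25, 0.1])
exploding = layer_stats(Float32[3.0f4, -4.0f4])

println("healthy:   ", healthy)
println("exploding: ", exploding)
```

Watching these numbers grow from layer to layer, or from step to step, is one way to spot an explosion before it overflows into Inf or NaN.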
