Add OutputDiscardCheckpoint #682

Draft
AkshitaB wants to merge 15 commits into main from akshitab/output-discard-checkpoint

@AkshitaB AkshitaB commented May 16, 2026

Summary

Add OutputDiscardCheckpoint, an activation-recompute primitive for cases where the output of a checkpointed region dominates memory (rather than its intermediates).

Vanilla torch.utils.checkpoint discards intermediates inside the wrapped function but can't free the output -- downstream consumers and their saved-for-backward references keep it alive. OutputDiscardCheckpoint extends that pattern: forward runs under no_grad, the output's storage can be freed after downstream forward consumes it, and a backward hook recomputes the forward and rebinds the freed storage in place (via a C++ share_storage extension, with a Python fallback). The tensor object survives so existing autograd saved-tensor references stay valid; only its underlying bytes are recycled.
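The "tensor object survives, only the bytes are recycled" idea can be illustrated with plain storage calls. This is a minimal sketch, not the PR's implementation: it assumes a CPU tensor, uses `clone()` as a stand-in for the recomputed forward, and restores by copying into the regrown storage rather than through the C++ share_storage path.

```python
import torch

# Stand-in for the output of a checkpointed region (illustrative values).
out = torch.arange(8, dtype=torch.float32)
recomputed = out.clone()  # stand-in for what a recompute of the forward yields

storage = out.untyped_storage()
full_nbytes = storage.nbytes()

# Discard: release the bytes but keep the tensor object and its metadata.
storage.resize_(0)
freed_nbytes = out.untyped_storage().nbytes()
shape_after_discard = tuple(out.shape)

# Restore: regrow the same storage and copy the recomputed bytes into it.
storage.resize_(full_nbytes)
storage.copy_(recomputed.untyped_storage())

restored_ok = torch.equal(out, recomputed)
```

Between the discard and the restore, `out` still reports its shape but must not be read, which is why the restore has to run before backward touches any reference to it.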

Useful for fat-output ops where the output is wider than the input -- precision casts (bf16 -> fp32 doubling), FFN up-projections, attention outputs before SDPA fuses them.
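To make "fat output" concrete, the sizes can be worked out by hand. The sketch below assumes the shapes from the benchmark header in this PR (batch=2, seq=2048, d_model=2048, d_ff=8192); `tensor_mib` is a hypothetical helper used only for the arithmetic.

```python
# Bytes per element for the dtypes mentioned above.
BYTES = {"bf16": 2, "fp16": 2, "fp32": 4}

def tensor_mib(shape, dtype):
    """Size of a dense tensor in MiB."""
    n = 1
    for dim in shape:
        n *= dim
    return n * BYTES[dtype] / 2**20

batch, seq, d_model, d_ff = 2, 2048, 2048, 8192

# Precision cast: same shape, twice the bytes per element.
cast_in = tensor_mib((batch, seq, d_model), "bf16")   # 16.0 MiB
cast_out = tensor_mib((batch, seq, d_model), "fp32")  # 32.0 MiB

# FFN up-projection: same dtype, d_ff / d_model times wider.
up_in = tensor_mib((batch, seq, d_model), "fp32")     # 32.0 MiB
up_out = tensor_mib((batch, seq, d_ff), "fp32")       # 128.0 MiB
```

In both cases the output is larger than the input needed to recompute it, which is the situation where discarding the output pays off.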

Contents

  • src/olmo_core/nn/output_discard_checkpoint.py -- the primitive (287 lines), with a code-block usage example and :param:/:returns: docstrings.
  • src/olmo_core/nn/__init__.py -- exports OutputDiscardCheckpoint.
  • src/test/nn/output_discard_checkpoint_test.py -- four tests: storage discard/restore, grad parity vs a non-checkpointed Sequential, Python-fallback path forced via monkeypatch, and a 3D-Linear regression test (see below).

Bug found and fixed during integration testing

Writing the FFN integration script surfaced a real bug in the Python fallback. Tensor.set_(new_storage, ...) swaps dst's StorageImpl for a new one. Any autograd-saved view of dst that still referenced the original StorageImpl -- e.g. the 2D-reshape view that MmBackward saves when Linear is called with a 3D input -- would see the original storage, which was resize_(0)'d and never refilled. Backward then hit "tensor has non-zero numel but data is not allocated".

The C++ path didn't have this bug: it mutates dst's existing StorageImpl in place via set_data_ptr, so saved views see the new data through the same StorageImpl. The fix makes the Python fallback equivalent: resize dst's existing storage and copy src's bytes into it. Costs an extra allocation + copy during recompute on machines without a C++ toolchain, but is correct for tensors with saved views (which is most real workloads).

The 3D-Linear regression test ensures this doesn't sneak back in.
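The failure mode and the fix can be reproduced outside autograd with a plain view. This is a sketch under stated assumptions: `restore_in_place` and `make_discarded` are hypothetical helpers, not the PR's `_fallback_share_storage`; the rebinding case uses the Tensor-source form of `set_` for brevity; and `src` is assumed contiguous with the same shape and dtype as `dst`.

```python
import torch

def restore_in_place(dst: torch.Tensor, src: torch.Tensor) -> None:
    """Refill dst's existing storage with src's bytes, keeping the same storage object."""
    storage = dst.untyped_storage()
    storage.resize_(src.untyped_storage().nbytes())
    storage.copy_(src.untyped_storage())

def make_discarded():
    dst = torch.rand(2, 3, 4)
    view = dst.view(6, 4)  # stand-in for the 2D view saved for backward
    src = dst.clone()      # stand-in for the recomputed output
    dst.untyped_storage().resize_(0)
    return dst, view, src

# Rebinding: dst now points at src's storage, but the view still
# references the original storage, which stays empty.
dst, view, src = make_discarded()
dst.set_(src)
view_nbytes_after_rebind = view.untyped_storage().nbytes()

# In-place refill: dst and the view share one storage, so both see the data.
dst, view, src = make_discarded()
restore_in_place(dst, src)
view_ok = torch.equal(view, src.view(6, 4))
dst_ok = torch.equal(dst, src)
```

The copy in `restore_in_place` is the extra allocation and copy cost mentioned above; the benefit is that every tensor aliasing the original storage is repaired at once.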

Standalone scripts (not run in CI)

Two scripts under src/scripts/ -- intentionally outside src/test/ so pytest never picks them up. Useful for human-driven benchmarking and verification:

  • src/scripts/benchmark_odc.py -- compares baseline / torch.utils.checkpoint / ODC (C++) / ODC (Python fallback) on a fat-output Linear -> activation -> Linear workload. Reports peak GPU memory and forward/backward wall time. Supports --scenarios grid for a shape sweep.
  • src/scripts/odc_ffn_integration_check.py -- wraps the real OLMo FeedForward (SwiGLU) with ODC around the fat activation(w1(x)) * w3(x) intermediate that w2 saves for backward. Asserts output + gradient parity vs the baseline; exits non-zero on failure.

Test plan

  • pytest -v src/test/nn/output_discard_checkpoint_test.py -- 4 tests pass (incl. forced Python fallback and 3D-Linear regression).
  • make checks -- isort / black / ruff / mypy all clean.
  • python src/scripts/odc_ffn_integration_check.py --device cpu -- PASS (output and grad parity at 0 diff in fp32).
  • On GPU: python src/scripts/benchmark_odc.py -- visually verify ODC peak memory < baseline.
  • On GPU: python src/scripts/odc_ffn_integration_check.py --dtype bf16 -- verify parity holds in bf16.
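The parity checks in this plan follow a common shape: run the same module with and without recompute and compare outputs and gradients. The sketch below shows that shape with `torch.utils.checkpoint` as a stand-in for the recompute path, since it does not depend on this PR's API; `run` is a hypothetical helper and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def run(model: nn.Module, x: torch.Tensor, recompute: bool):
    """One forward/backward pass; returns the output and all gradients."""
    model.zero_grad(set_to_none=True)
    x = x.clone().requires_grad_(True)
    out = checkpoint(model, x, use_reentrant=False) if recompute else model(x)
    out.sum().backward()
    grads = [x.grad] + [p.grad for p in model.parameters()]
    return out.detach(), grads

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, 16))
x = torch.randn(2, 8, 16)  # 3D input, so Linear goes through its reshape path

base_out, base_grads = run(model, x, recompute=False)
ckpt_out, ckpt_grads = run(model, x, recompute=True)

# Raises AssertionError on any mismatch in output or gradients.
torch.testing.assert_close(ckpt_out, base_out)
for got, want in zip(ckpt_grads, base_grads):
    torch.testing.assert_close(got, want)
parity_ok = True
```

For a bf16 run the default tolerances of `torch.testing.assert_close` widen automatically with the dtype, so the same helper can be reused.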

🤖 Generated with Claude Code

TianhuaTao and others added 15 commits May 15, 2026 15:20
A Megatron-style activation-recompute primitive. Forward runs under no_grad;
caller frees the output's storage after downstream consumption; a backward hook
recomputes and shares storage back into the original output tensor objects
without triggering autograd version errors. C++ share_storage extension built
via torch.utils.cpp_extension.load_inline, with a Python fallback for
environments without a compiler.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Monkeypatches _get_share_storage to None to exercise _fallback_share_storage
on CI machines where ninja and a C++ compiler are present and the C++
extension would otherwise always be used.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Add a code-block example to the class docstring showing the four-step
pattern (checkpoint, run downstream, discard+register, backward) and the
constraints on the choice of hook_tensor. Add :param:/:returns: docstrings
to the three public methods.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The previous Python fallback rebound dst's storage via Tensor.set_(),
which swaps dst's StorageImpl for a new one. Any autograd-saved view of
dst that still referenced the original StorageImpl (e.g. the 2D-reshape
view that MmBackward saves when Linear is called with a 3D input) would
see the original storage -- which was resize_(0)'d -- and backward would
hit "tensor has non-zero numel but data is not allocated".

The C++ extension path didn't have this bug: it mutates dst's existing
StorageImpl in place via set_data_ptr, so saved views see the new data.

Make the Python fallback equivalent by resizing dst's existing storage
and copying src's bytes into it, preserving StorageImpl identity. This
costs an extra allocation + copy during recompute on machines without a
C++ toolchain, but is correct for tensors with saved views.

Add a 3D-Linear regression test covering the failure mode.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Two standalone scripts under src/scripts/ (intentionally outside src/test/
so CI does not run them):

- benchmark_odc.py: compares baseline / torch.utils.checkpoint / ODC
  (C++ extension) / ODC (Python fallback forced) on a synthetic fat-output
  workload. Reports peak GPU memory and forward/backward wall time. Supports
  a single-shape mode and a grid sweep.

- odc_ffn_integration_check.py: wraps the real olmo_core SwiGLU FeedForward
  with ODC around the fat (activation(w1(x)) * w3(x)) intermediate that w2
  saves for backward. Runs a few iterations and asserts output and gradient
  parity vs the baseline FeedForward. Exits non-zero on any failure.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Adds two GPU matrix entries that invoke the standalone scripts from
src/scripts/ to validate OutputDiscardCheckpoint end-to-end on real GPU
hardware. Revert before merging this PR.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Rewrite benchmark_odc.py around a BenchBlock abstraction so the wrapped
region (the thing ODC discards) is interchangeable. Adds five concrete
block types covering the spectrum of ODC fit:

- fp32_cast: x.float() under ODC followed by an fp32 Linear (the MoE
  router's pattern). Recompute is trivial -> ODC should be neutral or
  positive even at N=1.
- up_proj: fat Linear up-projection with no activation inside. Linear
  saves only its input -> recompute spike is small.
- silu_up: silu(up(x)) inside; activation adds a saved intermediate to
  recompute, raising recompute peak.
- swiglu: OLMo SwiGLU FFN; three fat intermediates saved during recompute
  -> worst-case ODC recompute footprint.
- rms_norm: RMSNorm + Linear + residual; cheap recompute, modest savings.

Each block runs at N = 1 (single-layer, ODC's worst case because the
savings window is zero) and N = args.n_layers (default 4) so the
multi-layer payoff is visible. Adds --only to restrict to a subset and
--layers to override the stack depths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Chain N ODCFeedForward blocks (default N = 4) and verify per-iteration
output + gradient parity vs a baseline stack. Exercises per-block
OutputDiscardCheckpoint instances and the order in which their recompute
hooks fire as backward walks back through the stack -- a regression here
would silently corrupt gradients in multi-FFN training.

Adds --n-layers (default 4) and --layers (override). By default runs
both N = 1 and N = --n-layers so the single-layer case stays covered.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Flip the outer/inner loop in main() so output is grouped by stack depth
first, then iterates through all block types within each depth. Makes
cross-block-type comparison at a fixed depth easier to read.

Also adds an --iters flag (default 10) for the number of timed iterations,
matching the workflow's invocation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

- New FP32SoftmaxBlock: bf16 -> upcast -> fp32 softmax -> downcast pattern,
  modeling the attention/router softmax-in-fp32 path. Softmax saves its
  OUTPUT for backward, so the discarded tensor is the fp32 softmax probs.
- --dtype now accepts multiple values (default [bf16]). Each is run as a
  separate top-level group so the precision-boundary effect is visible
  side-by-side (fp32_cast / fp32_softmax discard tensor doubles at
  bf16/fp16 base, neutral at fp32 base).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
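
The note that softmax saves its output for backward follows from the gradient formula, which depends only on the output y and the incoming gradient g: dx_i = y_i * (g_i - sum_j g_j * y_j). A small dependency-free sketch that checks this against finite differences (the input values are arbitrary and the function names are illustrative):

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_backward(y, grad_out):
    """Input gradient computed from the softmax output y alone."""
    dot = sum(g * p for g, p in zip(grad_out, y))
    return [p * (g - dot) for p, g in zip(y, grad_out)]

def numeric_grad(xs, grad_out, eps=1e-6):
    """Central-difference gradient of sum_i grad_out[i] * softmax(xs)[i]."""
    def loss(vals):
        return sum(g * p for g, p in zip(grad_out, softmax(vals)))
    grads = []
    for i in range(len(xs)):
        hi = list(xs)
        lo = list(xs)
        hi[i] += eps
        lo[i] -= eps
        grads.append((loss(hi) - loss(lo)) / (2 * eps))
    return grads

xs = [0.5, -1.0, 2.0, 0.0]
grad_out = [1.0, -2.0, 0.5, 3.0]

analytic = softmax_backward(softmax(xs), grad_out)
numeric = numeric_grad(xs, grad_out)
max_abs_diff = max(abs(a - n) for a, n in zip(analytic, numeric))
```

Because the input x never appears in `softmax_backward`, the fp32 probabilities are the tensor that stays alive until backward, and hence the one worth discarding.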
So the precision-boundary effect on fp32_cast and fp32_softmax is visible
in the CI output. Revert before merge along with the other temporary
ODC steps.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

When the input is already fp32, x.float() returns x itself (no copy), so
h aliases x. ODC's discard would resize x's storage to 0, breaking
backward. Same aliasing issue for torch.utils.checkpoint -- its recompute
graph would also wrap an identity. Skip the checkpoint variants for this
degenerate case so the benchmark row truthfully reports "no benefit"
rather than crashing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
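
The aliasing described in this commit is easy to confirm on CPU. A minimal sketch; `cast_is_identity` is a hypothetical guard shown only to illustrate how a caller might detect the degenerate case, not the check the benchmark actually uses.

```python
import torch

x_fp32 = torch.randn(4, 8, dtype=torch.float32)
x_bf16 = x_fp32.to(torch.bfloat16)

alias = x_fp32.float()  # dtype already matches: no new buffer is allocated
copy = x_bf16.float()   # real upcast: new fp32 buffer

shares_memory = alias.data_ptr() == x_fp32.data_ptr()
copy_is_new = copy.data_ptr() != x_bf16.data_ptr()

def cast_is_identity(x: torch.Tensor, dtype: torch.dtype = torch.float32) -> bool:
    """True when x.to(dtype) would hand back x's own memory."""
    return x.dtype == dtype

skip_fp32 = cast_is_identity(x_fp32)
skip_bf16 = cast_is_identity(x_bf16)
```

When `shares_memory` is true, freeing the "output" would also free the input, so there is nothing safe to discard.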
@AkshitaB
Contributor Author

Benchmarking

========== n_layers=8  (d_model=2048, d_ff=8192, batch=2, seq=2048, dtype=torch.float32) ==========

  [fp32_cast] x.float() under ODC; consumer is an fp32 Linear (router pattern)
  config                          peak (MB)   fwd (ms)   bwd (ms)   total (ms)   mem saved   time delta
  ---------------------------- ------------ ---------- ---------- ------------  ----------  -----------
  baseline                            544.3       5.76      10.88        16.65  (baseline)   (baseline)
  torch.utils.checkpoint              544.3       5.79      10.89        16.68       +0.0%        +0.2%
  ODC (C++ if available)              544.3       5.73      10.88        16.60       +0.0%        -0.3%
  ODC (python fallback)               544.3       5.76      10.88        16.64       +0.0%        -0.1%

  [up_proj] fat Linear up-projection, no activation inside the discarded region
  config                          peak (MB)   fwd (ms)   bwd (ms)   total (ms)   mem saved   time delta
  ---------------------------- ------------ ---------- ---------- ------------  ----------  -----------
  baseline                           2560.3      44.07      84.83       128.90  (baseline)   (baseline)
  torch.utils.checkpoint             2560.3      44.42      84.93       129.35       +0.0%        +0.3%
  ODC (C++ if available)             2400.3      44.12     106.80       150.92       +6.2%       +17.1%
  ODC (python fallback)              2464.3      44.06     107.34       151.40       +3.7%       +17.4%

  [silu_up] silu(up(x)) inside; activation adds a saved intermediate to recompute
  config                          peak (MB)   fwd (ms)   bwd (ms)   total (ms)   mem saved   time delta
  ---------------------------- ------------ ---------- ---------- ------------  ----------  -----------
  baseline                           3584.3      44.49      85.95       130.44  (baseline)   (baseline)
  torch.utils.checkpoint             2656.3      44.88     107.87       152.75      +25.9%       +17.1%
  ODC (C++ if available)             2560.3      44.48     108.53       153.02      +28.6%       +17.3%
  ODC (python fallback)              2592.3      44.56     109.31       153.86      +27.7%       +18.0%

  [swiglu] OLMo SwiGLU FFN; three fat intermediates saved during recompute
  config                          peak (MB)   fwd (ms)   bwd (ms)   total (ms)   mem saved   time delta
  ---------------------------- ------------ ---------- ---------- ------------  ----------  -----------
  baseline                           6240.3      67.26     130.38       197.63  (baseline)   (baseline)
  torch.utils.checkpoint             3776.3      67.60     174.82       242.42      +39.5%       +22.7%
  ODC (C++ if available)             3904.3      67.31     175.86       243.17      +37.4%       +23.0%
  ODC (python fallback)              3904.3      67.32     176.54       243.86      +37.4%       +23.4%

  [rms_norm] RMSNorm + Linear + residual; cheap recompute, small fat-output savings
  config                          peak (MB)   fwd (ms)   bwd (ms)   total (ms)   mem saved   time delta
  ---------------------------- ------------ ---------- ---------- ------------  ----------  -----------
  baseline                            800.4       6.18      11.78        17.96  (baseline)   (baseline)
  torch.utils.checkpoint              800.3       6.34      12.00        18.33       +0.0%        +2.1%
  ODC (C++ if available)              560.3       6.19      11.99        18.19      +30.0%        +1.3%
  ODC (python fallback)               560.3       6.20      12.18        18.38      +30.0%        +2.3%

  [fp32_softmax] softmax in fp32 (attention/routing pattern); 2x size if base dtype is bf16/fp16
  config                          peak (MB)   fwd (ms)   bwd (ms)   total (ms)   mem saved   time delta
  ---------------------------- ------------ ---------- ---------- ------------  ----------  -----------
  baseline                            576.3       5.86      11.45        17.30  (baseline)   (baseline)
  torch.utils.checkpoint              800.3       6.10      11.69        17.79      -38.9%        +2.8%
  ODC (C++ if available)              560.3       5.93      11.69        17.61       +2.8%        +1.8%
  ODC (python fallback)               560.3       5.92      11.87        17.79       +2.8%        +2.8%
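
The "mem saved" and "time delta" columns are consistent with simple relative differences against the baseline row. The sketch below reproduces three table entries under that assumption; the function names are illustrative and are not taken from benchmark_odc.py.

```python
def mem_saved_pct(baseline_mb: float, peak_mb: float) -> float:
    """Positive when the configuration uses less peak memory than baseline."""
    return 100.0 * (baseline_mb - peak_mb) / baseline_mb

def time_delta_pct(baseline_ms: float, total_ms: float) -> float:
    """Positive when the configuration is slower than baseline."""
    return 100.0 * (total_ms - baseline_ms) / baseline_ms

# [up_proj] ODC (C++ if available) vs baseline
up_proj_mem = f"{mem_saved_pct(2560.3, 2400.3):+.1f}%"     # +6.2%
up_proj_time = f"{time_delta_pct(128.90, 150.92):+.1f}%"   # +17.1%

# [fp32_softmax] torch.utils.checkpoint vs baseline (uses more memory)
softmax_ckpt_mem = f"{mem_saved_pct(576.3, 800.3):+.1f}%"  # -38.9%
```

A negative "mem saved" value therefore means the configuration raised peak memory, as in the fp32_softmax row for torch.utils.checkpoint.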


@TianhuaTao TianhuaTao left a comment


Looks ok for now as an alpha feature.
There might be concerns regarding RNG state, autocast context, and cases where requires_grad == False, but for now we don't have them in our active paths, so it should be fine.
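
The RNG-state concern can be illustrated with dropout on CPU: a recompute that does not restore the generator state produces a different mask than the original forward. This is a sketch of the general issue only; it does not show how OutputDiscardCheckpoint behaves, and on GPU the CUDA generator state would need the same treatment.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(1024)

rng_state = torch.get_rng_state()  # snapshot the CPU generator before the forward
first = F.dropout(x, p=0.5, training=True)

# Recompute without restoring the state: a different mask is drawn.
naive = F.dropout(x, p=0.5, training=True)

# Recompute after restoring the state: the original mask is replayed.
torch.set_rng_state(rng_state)
replayed = F.dropout(x, p=0.5, training=True)

replay_matches = torch.equal(first, replayed)
naive_matches = torch.equal(first, naive)
```

torch.utils.checkpoint handles this with its `preserve_rng_state` argument; a region containing dropout or other stochastic ops would need a comparable save and restore around the recompute.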

@AkshitaB AkshitaB marked this pull request as draft May 17, 2026 15:53