Add OutputDiscardCheckpoint #682
Draft
AkshitaB wants to merge 15 commits into
A Megatron-style activation-recompute primitive. Forward runs under no_grad; caller frees the output's storage after downstream consumption; a backward hook recomputes and shares storage back into the original output tensor objects without triggering autograd version errors. C++ share_storage extension built via torch.utils.cpp_extension.load_inline, with a Python fallback for environments without a compiler. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Monkeypatches _get_share_storage to None to exercise _fallback_share_storage on CI machines where ninja and a C++ compiler are present and the C++ extension would otherwise always be used. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a code-block example to the class docstring showing the four-step pattern (checkpoint, run downstream, discard+register, backward) and the constraints on the choice of hook_tensor. Add :param:/:returns: docstrings to the three public methods. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous Python fallback rebound dst's storage via Tensor.set_(), which swaps dst's StorageImpl for a new one. Any autograd-saved view of dst that still referenced the original StorageImpl (e.g. the 2D-reshape view that MmBackward saves when Linear is called with a 3D input) would see the original storage -- which was resize_(0)'d -- and backward would hit "tensor has non-zero numel but data is not allocated". The C++ extension path didn't have this bug: it mutates dst's existing StorageImpl in place via set_data_ptr, so saved views see the new data. Make the Python fallback equivalent by resizing dst's existing storage and copying src's bytes into it, preserving StorageImpl identity. This costs an extra allocation + copy during recompute on machines without a C++ toolchain, but is correct for tensors with saved views. Add a 3D-Linear regression test covering the failure mode. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two standalone scripts under src/scripts/ (intentionally outside src/test/ so CI does not run them):
- benchmark_odc.py: compares baseline / torch.utils.checkpoint / ODC (C++ extension) / ODC (Python fallback forced) on a synthetic fat-output workload. Reports peak GPU memory and forward/backward wall time. Supports a single-shape mode and a grid sweep.
- odc_ffn_integration_check.py: wraps the real olmo_core SwiGLU FeedForward with ODC around the fat (activation(w1(x)) * w3(x)) intermediate that w2 saves for backward. Runs a few iterations and asserts output and gradient parity vs the baseline FeedForward. Exits non-zero on any failure.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two GPU matrix entries that invoke the standalone scripts from src/scripts/ to validate OutputDiscardCheckpoint end-to-end on real GPU hardware. Revert before merging this PR. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Rewrite benchmark_odc.py around a BenchBlock abstraction so the wrapped region (the thing ODC discards) is interchangeable. Adds five concrete block types covering the spectrum of ODC fit:
- fp32_cast: x.float() under ODC followed by an fp32 Linear (the MoE router's pattern). Recompute is trivial -> ODC should be neutral or positive even at N=1.
- up_proj: fat Linear up-projection with no activation inside. Linear saves only its input -> recompute spike is small.
- silu_up: silu(up(x)) inside; activation adds a saved intermediate to recompute, raising recompute peak.
- swiglu: OLMo SwiGLU FFN; three fat intermediates saved during recompute -> worst-case ODC recompute footprint.
- rms_norm: RMSNorm + Linear + residual; cheap recompute, modest savings.
Each block runs at N = 1 (single-layer, ODC's worst case because the savings window is zero) and N = args.n_layers (default 4) so the multi-layer payoff is visible. Adds --only to restrict to a subset and --layers to override the stack depths.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Chain N ODCFeedForward blocks (default N = 4) and verify per-iteration output + gradient parity vs a baseline stack. Exercises per-block OutputDiscardCheckpoint instances and the order in which their recompute hooks fire as backward walks back through the stack -- a regression here would silently corrupt gradients in multi-FFN training. Adds --n-layers (default 4) and --layers (override). By default runs both N = 1 and N = --n-layers so the single-layer case stays covered. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Flip the outer/inner loop in main() so output is grouped by stack depth first, then iterates through all block types within each depth. Makes cross-block-type comparison at a fixed depth easier to read. Also adds an --iters flag (default 10) for the number of timed iterations, matching the workflow's invocation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- New FP32SoftmaxBlock: bf16 -> upcast -> fp32 softmax -> downcast pattern, modeling the attention/router softmax-in-fp32 path. Softmax saves its OUTPUT for backward, so the discarded tensor is the fp32 softmax probs.
- --dtype now accepts multiple values (default [bf16]). Each is run as a separate top-level group so the precision-boundary effect is visible side-by-side (fp32_cast / fp32_softmax discard tensor doubles at bf16/fp16 base, neutral at fp32 base).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
So the precision-boundary effect on fp32_cast and fp32_softmax is visible in the CI output. Revert before merge along with the other temporary ODC steps. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When the input is already fp32, x.float() returns x itself (no copy), so h aliases x. ODC's discard would resize x's storage to 0, breaking backward. Same aliasing issue for torch.utils.checkpoint -- its recompute graph would also wrap an identity. Skip the checkpoint variants for this degenerate case so the benchmark row truthfully reports "no benefit" rather than crashing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Benchmarking |
TianhuaTao (Contributor) approved these changes on May 17, 2026 and left a comment:
Looks ok for now as an alpha feature.
There might be concerns regarding RNG state, autocast context, and cases where requires_grad == False, but for now we don't have them in our active paths, so it should be fine.
Summary
Add `OutputDiscardCheckpoint`, an activation-recompute primitive for cases where the output of a checkpointed region dominates memory (rather than its intermediates).
Vanilla `torch.utils.checkpoint` discards intermediates inside the wrapped function but can't free the output -- downstream consumers and their saved-for-backward references hold it live. `OutputDiscardCheckpoint` extends that pattern: forward runs under `no_grad`, the output's storage can be freed after downstream forward consumes it, and a backward hook recomputes the forward and rebinds the freed storage in place (via a C++ `share_storage` extension, with a Python fallback). The tensor object survives so existing autograd saved-tensor references stay valid; only its underlying bytes are recycled.
Useful for fat-output ops where the output is wider than the input -- precision casts (bf16 -> fp32 doubling), FFN up-projections, attention outputs before SDPA fuses them.
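To make "the output is wider than the input" concrete, here is a small back-of-the-envelope sketch in plain Python. The shapes (`batch`, `seq_len`, `d_model`, `d_hidden`) are hypothetical values chosen only for illustration and are not taken from this PR or from any OLMo config; the helper `tensor_bytes` is likewise made up for this example.

```python
from math import prod

# Bytes per element for the dtypes mentioned in the summary.
BYTES_PER_ELEMENT = {"bf16": 2, "fp16": 2, "fp32": 4}


def tensor_bytes(shape, dtype):
    """Size in bytes of a dense tensor with the given shape and dtype."""
    return prod(shape) * BYTES_PER_ELEMENT[dtype]


# Hypothetical shapes, for illustration only.
batch, seq_len, d_model, d_hidden = 4, 2048, 4096, 16384

# Precision cast: same shape, wider dtype (bf16 -> fp32).
cast_in = tensor_bytes((batch, seq_len, d_model), "bf16")
cast_out = tensor_bytes((batch, seq_len, d_model), "fp32")

# FFN up-projection: same dtype, wider last dimension.
up_in = tensor_bytes((batch, seq_len, d_model), "bf16")
up_out = tensor_bytes((batch, seq_len, d_hidden), "bf16")

MIB = 1024 ** 2
print(f"cast:    {cast_in / MIB:.0f} MiB in -> {cast_out / MIB:.0f} MiB out")
print(f"up_proj: {up_in / MIB:.0f} MiB in -> {up_out / MIB:.0f} MiB out")
```

With these assumed shapes the cast output is twice the size of its input and the up-projection output is four times the size of its input. In both cases the region's input is the small tensor, so recomputing from it is cheap relative to what holding the output costs.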
Contents
- `src/olmo_core/nn/output_discard_checkpoint.py` -- the primitive (287 lines), with a code-block usage example and `:param:`/`:returns:` docstrings.
- `src/olmo_core/nn/__init__.py` -- exports `OutputDiscardCheckpoint`.
- `src/test/nn/output_discard_checkpoint_test.py` -- four tests: storage discard/restore, grad parity vs a non-checkpointed `Sequential`, Python-fallback path forced via monkeypatch, and a 3D-Linear regression test (see below).
Bug found and fixed during integration testing
Writing the FFN integration script surfaced a real bug in the Python fallback. The previous fallback rebound `dst`'s storage via `Tensor.set_(new_storage, ...)`, which swaps `dst`'s `StorageImpl` for a new one. Any autograd-saved view of `dst` that still referenced the original `StorageImpl` -- e.g. the 2D-reshape view that `MmBackward` saves when `Linear` is called with a 3D input -- would see the original storage, which was `resize_(0)`'d and never refilled. Backward then hit "tensor has non-zero numel but data is not allocated".
The C++ path didn't have this bug: it mutates `dst`'s existing `StorageImpl` in place via `set_data_ptr`, so saved views see the new data through the same `StorageImpl`. The fix makes the Python fallback equivalent: resize `dst`'s existing storage and copy `src`'s bytes into it. This costs an extra allocation + copy during recompute on machines without a C++ toolchain, but is correct for tensors with saved views (which is most real workloads).
The 3D-Linear regression test ensures this doesn't sneak back in.
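The difference between rebinding a tensor to a new storage and refilling its existing storage can be shown with a toy model in plain Python. This is only a sketch of the aliasing behaviour described above, not PyTorch code and not the implementation in this PR: `ToyStorage`, `ToyTensor`, `refill_in_place`, and `run` are hypothetical names invented for the illustration.

```python
class ToyStorage:
    """Stand-in for a StorageImpl: a resizable byte buffer with its own identity."""

    def __init__(self, data=b""):
        self.data = bytearray(data)

    def resize_(self, nbytes):
        # Truncate or zero-extend in place; the object identity is unchanged.
        if nbytes < len(self.data):
            del self.data[nbytes:]
        else:
            self.data.extend(bytes(nbytes - len(self.data)))


class ToyTensor:
    """Stand-in for a tensor: a handle that points at a storage object."""

    def __init__(self, storage):
        self.storage = storage

    def view(self):
        # A view is a new handle onto the same storage object.
        return ToyTensor(self.storage)

    def set_(self, other):
        # Rebind this handle to other's storage; existing views are not updated.
        self.storage = other.storage

    def refill_in_place(self, other):
        # Keep this handle's storage object; grow it and copy other's bytes in.
        self.storage.resize_(len(other.storage.data))
        self.storage.data[:] = other.storage.data


def run(restore):
    dst = ToyTensor(ToyStorage(b"\x01\x02\x03\x04"))
    saved_view = dst.view()   # plays the role of an autograd-saved view
    dst.storage.resize_(0)    # discard the output's bytes
    recomputed = ToyTensor(ToyStorage(b"\x01\x02\x03\x04"))
    restore(dst, recomputed)
    return bytes(dst.storage.data), bytes(saved_view.storage.data)


rebind_result = run(ToyTensor.set_)
in_place_result = run(ToyTensor.refill_in_place)
print("rebind:  ", rebind_result)
print("in place:", in_place_result)
```

In the rebind case `dst` sees the recomputed bytes but the saved view still points at the emptied storage, which mirrors the failure mode above. In the in-place case both handles see the recomputed bytes because the storage object was never replaced.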
Standalone scripts (not run in CI)
Two scripts under `src/scripts/` -- intentionally outside `src/test/` so pytest never picks them up. Useful for human-driven benchmarking and verification:
- `src/scripts/benchmark_odc.py` -- compares baseline / `torch.utils.checkpoint` / ODC (C++) / ODC (Python fallback) on a fat-output `Linear -> activation -> Linear` workload. Reports peak GPU memory and forward/backward wall time. Supports `--scenarios grid` for a shape sweep.
- `src/scripts/odc_ffn_integration_check.py` -- wraps the real OLMo `FeedForward` (SwiGLU) with ODC around the fat `activation(w1(x)) * w3(x)` intermediate that `w2` saves for backward. Asserts output + gradient parity vs the baseline; exits non-zero on failure.
Test plan
- `pytest -v src/test/nn/output_discard_checkpoint_test.py` -- 4 tests pass (incl. forced Python fallback and 3D-Linear regression).
- `make checks` -- isort / black / ruff / mypy all clean.
- `python src/scripts/odc_ffn_integration_check.py --device cpu` -- PASS (output and grad parity at 0 diff in fp32).
- `python src/scripts/benchmark_odc.py` -- visually verify ODC peak memory < baseline.
- `python src/scripts/odc_ffn_integration_check.py --dtype bf16` -- verify parity holds in bf16.
🤖 Generated with Claude Code