[tx] add per-layer gradient checkpointing with scan for memory-efficient training #906
base: main
Conversation
Code Review
This pull request successfully integrates per-layer gradient checkpointing using jax.lax.scan and jax.checkpoint, which is a significant improvement for memory efficiency during training. The implementation correctly separates the checkpointed and non-checkpointed forward paths, and the new tests adequately cover the correctness of outputs, hidden states, and gradients. The is_training flag and gradient_checkpointing configuration are well-handled across the models and backend. There are a few minor areas for improvement related to test clarity and robustness.
Compute lm_head projection in chunks to avoid materializing the full [B*T, V] logits tensor.

Key changes:
- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for the logits tensor. For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
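For illustration, a minimal sketch of the chunked-logprobs idea described above, assuming a plain (non-LoRA) unembedding matrix; the function name, argument shapes, and padding handling are hypothetical rather than the PR's actual code:

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_weight, target_ids, chunk_size=1024):
    """Per-token log-probs without materializing the full [B*T, V] logits.

    hidden: [B*T, H] flattened hidden states
    lm_head_weight: [H, V] unembedding matrix
    target_ids: [B*T] token ids whose log-probs we want
    """
    num_tokens = hidden.shape[0]
    # Pad so the token axis divides evenly into chunks.
    pad = (-num_tokens) % chunk_size
    hidden = jnp.pad(hidden, ((0, pad), (0, 0)))
    target_ids = jnp.pad(target_ids, (0, pad))

    def per_chunk(chunk):
        h, ids = chunk                                    # [chunk_size, H], [chunk_size]
        logits = h @ lm_head_weight                       # only [chunk_size, V] is live
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, ids[:, None], axis=-1)[:, 0]

    chunks = (
        hidden.reshape(-1, chunk_size, hidden.shape[-1]),
        target_ids.reshape(-1, chunk_size),
    )
    # jax.lax.map runs the body sequentially over the leading axis, so only one
    # chunk's logits are materialized at a time.
    out = jax.lax.map(per_chunk, chunks)
    return out.reshape(-1)[:num_tokens]
```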
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with the lm_head weight, bypassing LoRA adapters. This is incorrect when train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
Recompute activations during the backward pass to save memory. Only one layer's activations are held at a time during the backward pass, reducing peak memory by a factor of ~num_layers.

- Add gradient_checkpointing config to ModelConfig
- Apply jax.checkpoint per-layer when is_training=True
- Rename compute_logits to is_training (controls both logits and checkpointing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…euse

Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them.

Only enabled when gradient_checkpointing=True. Without checkpointing, activations are stored anyway, so fori_loop's buffer reuse doesn't help and its weight-stacking overhead makes it worse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
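A rough sketch of the fori_loop pattern this commit describes (later commits switch to jax.lax.scan); the helper names, the stacked-weights pytree layout, and the toy layer_fn are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

def forward_layers_fori(layer_fn, stacked_params, hidden_states, num_layers):
    """stacked_params: pytree whose leaves have a leading [num_layers] axis."""

    @jax.checkpoint  # drop this layer's activations; recompute them during backward
    def apply_layer(layer_params, h):
        return layer_fn(layer_params, h)

    def body(i, h):
        # Slice the i-th layer's weights out of the stacked pytree.
        layer_params = jax.tree_util.tree_map(lambda p: p[i], stacked_params)
        return apply_layer(layer_params, h)

    # num_layers must be a static Python int so the loop is reverse-differentiable.
    return jax.lax.fori_loop(0, num_layers, body, hidden_states)

# Toy usage: 4 "layers" whose weights are plain [H, H] matrices.
layer_fn = lambda W, h: jnp.tanh(h @ W)
stacked_W = jnp.stack([jnp.eye(8)] * 4)
out = forward_layers_fori(layer_fn, stacked_W, jnp.ones((2, 8)), num_layers=4)
```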
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match
- test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Force-pushed from fd7572d to 5cf1c66
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to a single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods: compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use the new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
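A skeleton of the LogitsProcessor shape described in this commit; the method bodies below are simplified placeholders (no chunking or LoRA handling) and are not the repository's implementation:

```python
import jax

class LogitsProcessor:
    """Standalone utility with three static methods (simplified sketch)."""

    @staticmethod
    def compute_logits(hidden_states: jax.Array, lm_head) -> jax.Array:
        # lm_head is the model's unembedding module/callable.
        return lm_head(hidden_states)

    @staticmethod
    def logits_to_logprobs(logits: jax.Array) -> jax.Array:
        return jax.nn.log_softmax(logits, axis=-1)

    @staticmethod
    def compute_logprobs(hidden_states: jax.Array, lm_head) -> jax.Array:
        # The real version chooses between chunked and non-chunked paths here.
        logits = LogitsProcessor.compute_logits(hidden_states, lm_head)
        return LogitsProcessor.logits_to_logprobs(logits)
```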
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now an internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in the batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove get_lm_head_weight() abstract method (no longer needed)
- Chunked path now uses lm_head() directly instead of raw matmul
- Expand adapter_indices from [B] to [B*T] for per-token handling
- Remove restriction that disabled chunking with adapter_indices

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove _train_unembed_mask tracking from JaxBackend
- Simplify _model_forward to always pass adapter_indices to compute_logprobs
- Fix chunked path to reshape hidden states to [chunk_size, 1, H] for LoRA compatibility

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of allocating [B*T] array via jnp.repeat, compute adapter indices per-chunk using only a [chunk_size] buffer. This reduces memory overhead significantly for long sequences. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
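A small sketch of the per-chunk index computation, assuming per-sequence adapter ids of shape [B] and a flattened [B*T] token axis; the names and the clamping of padded positions are illustrative:

```python
import jax.numpy as jnp

def chunk_adapter_indices(adapter_indices, chunk_idx, chunk_size, seq_len):
    """Adapter id for each token in one chunk, without building a [B*T] array."""
    flat_pos = chunk_idx * chunk_size + jnp.arange(chunk_size)    # [chunk_size]
    batch_idx = flat_pos // seq_len                                # token -> sequence
    # Clamp positions that fall into the padding of the final chunk.
    batch_idx = jnp.minimum(batch_idx, adapter_indices.shape[0] - 1)
    return adapter_indices[batch_idx]                              # [chunk_size]
```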
Load HF model, get logits, save weights, delete HF model, then load our model. This avoids having both models in memory simultaneously. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add loss_chunk_size and gradient_checkpointing to config in tests
- Restore test_chunked_logprobs test that was lost during merge

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflicts:
- tx/utils/logits_processor.py: keep chunked logprobs implementation
- tx/utils/generator.py: keep left-padded sequence handling from main
- tests/models/test_models_common.py: keep chunked logprobs test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by loading and deleting models sequentially instead of having two models in memory simultaneously. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by creating and deleting backends sequentially instead of having two in memory simultaneously. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Test coverage for:
- Chunk boundary cases (padding, exact division, larger than total)
- Adapter indices handling (None, per-batch, same for all)
- Gradient checkpointing flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Combines the chunked lm_head memory optimization with per-layer gradient checkpointing. Resolves conflicts by keeping both features:
- loss_chunk_size for chunked logprobs computation
- gradient_checkpointing with is_training flag for activation recomputation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Move _forward_layers_checkpointed and _forward_layers from Llama3Model and Qwen3Model into shared utility functions in tx/models/utils.py. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolves conflicts with main branch changes:
- Updated KVCache API (removed cache_position from layer tuple)
- Added dot_product_attention import
- Updated DummyModel to accept loss_chunk_size parameter
- Use main's cleaner _compute_chunked_logprobs implementation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
KVCache.update() handles cache position internally, so this variable is no longer needed after the KVCache API refactor. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@raulchen is attempting to deploy a commit to the Tyler's projects Team on Vercel. A member of the Team first needs to authorize it.
skyrl-tx/tx/models/llama3.py (Outdated)
hidden_states, (k, v) = layer(
if is_training and self.config.gradient_checkpointing:
    hidden_states, all_hidden_states = forward_layers_checkpointed(
Can we unify these two methods into a single method forward_layers that will do the right thing depending on the parameters that are passed in? The reasoning is that model authors should not be concerned with the difference between checkpointing and non-checkpointing, they should just be able to write the model, and also we would want to avoid duplicating the code into all the model definitions.
Internally you can of course still have _forward_layers_checkpointed and _forward_layers and implement the "public" forward_layers function in terms of those.
skyrl-tx/tx/models/qwen3.py (Outdated)
output_hidden_states: bool | None = None,
adapter_indices: jax.Array | None = None,
kv_cache: KVCache | None = None,
is_training: bool = False,
An alternative way to do this would be to use https://flax.readthedocs.io/en/v0.8.3/api_reference/flax.experimental.nnx/module.html#flax.experimental.nnx.Module.train (and https://flax.readthedocs.io/en/v0.8.3/api_reference/flax.experimental.nnx/module.html#flax.experimental.nnx.Module.eval) to switch the module to train/eval mode. This would have some benefits: (a) it is the standard way to do something like this (similar to how pytorch does it, there you see this kind of pattern a lot), (b) if we are going to add dropout in the future, handling training / eval like this will be much easier and (c) we don't need to pass the flags around if we want to access similar attributes. It is less explicit though, and we would need to provide our own attribute to set. Curious about your thoughts :)
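For reference, a minimal sketch of the nnx train/eval pattern the comment points to, using the built-in deterministic attribute that Dropout already understands; a project-specific flag such as is_training would be a custom attribute we would have to define ourselves:

```python
from flax import nnx

class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(16, 16, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)

    def __call__(self, x):
        return self.dropout(self.linear(x))

model = Block(nnx.Rngs(0))
model.train()  # recursively sets deterministic=False -> dropout is active
model.eval()   # recursively sets deterministic=True  -> dropout becomes a no-op
# A custom attribute (e.g. an is_training flag on each layer) could be
# toggled the same way instead of threading it through every __call__.
```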
Summary
- Per-layer gradient checkpointing using jax.lax.scan with jax.checkpoint
- Enabled with gradient_checkpointing=True in model config

Implementation

Uses jax.lax.scan so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them.

Tradeoff: requires stacking all layer weights once per forward pass. This is acceptable because checkpointing already trades compute for memory.
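A minimal sketch of this scan-plus-checkpoint structure, assuming the per-layer weights have already been stacked into a pytree with a leading [num_layers] axis; the function and argument names are illustrative, and the real implementation also collects per-layer hidden states and threads attention/KV arguments:

```python
import jax
import jax.numpy as jnp

def forward_layers_checkpointed(layer_fn, stacked_params, hidden_states):
    @jax.checkpoint  # recompute this layer's activations during the backward pass
    def body(carry, layer_params):
        return layer_fn(layer_params, carry), None

    # One compiled loop body: XLA reuses buffers across layers instead of
    # unrolling num_layers separate checkpoint regions.
    final_hidden, _ = jax.lax.scan(body, hidden_states, stacked_params)
    return final_hidden

# Toy usage: 4 "layers" whose weights are plain [H, H] matrices.
layer_fn = lambda W, h: jnp.tanh(h @ W)
stacked_W = jnp.stack([jnp.eye(8)] * 4)
out = forward_layers_checkpointed(layer_fn, stacked_W, jnp.ones((2, 8)))
```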
Test plan
- is_training=False uses the standard path with KV cache
- Benchmark results: #891 (comment)