@raulchen raulchen commented Jan 21, 2026

Summary

  • Add per-layer gradient checkpointing using jax.lax.scan with jax.checkpoint
  • Reduces peak activation memory by a factor of ~num_layers (only one layer's activations are held at a time during the backward pass)
  • Enable via gradient_checkpointing=True in model config

Implementation

Uses jax.lax.scan so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them.

Tradeoff: requires stacking all layer weights once per forward pass. This is acceptable because checkpointing already trades compute for memory.
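
For illustration, here is a minimal sketch of the pattern; the toy layer_fn and shapes are hypothetical stand-ins for the real decoder layers.

```python
import jax
import jax.numpy as jnp

num_layers, hidden = 4, 64

def layer_fn(x, w):
    # Stand-in for one transformer layer: dense + residual.
    return x + jnp.tanh(x @ w)

# Stack all layer weights along a leading axis so lax.scan can iterate over them.
stacked_w = jnp.stack([0.1 * jnp.eye(hidden) for _ in range(num_layers)])

def forward(x, stacked_w):
    @jax.checkpoint  # recompute this body during backward instead of storing its activations
    def body(carry, w):
        return layer_fn(carry, w), None

    out, _ = jax.lax.scan(body, x, stacked_w)
    return out

x = jnp.ones((8, hidden))
grads = jax.grad(lambda w: forward(x, w).sum())(stacked_w)
print(grads.shape)  # (4, 64, 64)
```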

Test plan

  • Forward outputs match with/without checkpointing
  • Hidden states length and values match
  • is_training=False uses standard path with KV cache
  • Gradients match between checkpointed and non-checkpointed paths

Benchmark results: #891 (comment)


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request successfully integrates per-layer gradient checkpointing using jax.lax.scan and jax.checkpoint, which is a significant improvement for memory efficiency during training. The implementation correctly separates the checkpointed and non-checkpointed forward paths, and the new tests adequately cover the correctness of outputs, hidden states, and gradients. The is_training flag and gradient_checkpointing configuration are well-handled across the models and backend. There are a few minor areas for improvement related to test clarity and robustness.

raulchen and others added 17 commits January 20, 2026 18:55
Compute lm_head projection in chunks to avoid materializing the full
[B*T, V] logits tensor. Key changes:

- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor.
For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
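
A rough sketch of the chunked-logprobs idea from this commit, assuming a plain matmul lm_head (the real path goes through the model's lm_head so LoRA can apply, per the later commits); shapes and the chunk size are toy values.

```python
import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, w_lm_head, target_ids, chunk_size=4):
    # hidden: [N, H] flattened hidden states; target_ids: [N] token ids.
    n, h = hidden.shape
    pad = (-n) % chunk_size
    hidden = jnp.pad(hidden, ((0, pad), (0, 0))).reshape(-1, chunk_size, h)
    targets = jnp.pad(target_ids, (0, pad)).reshape(-1, chunk_size)

    def per_chunk(args):
        h_chunk, t_chunk = args
        logits = h_chunk @ w_lm_head                     # only [chunk_size, V] logits are alive
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.take_along_axis(logprobs, t_chunk[:, None], axis=-1)[:, 0]

    # lax.map runs per_chunk sequentially, so the full [N, V] logits tensor is never materialized.
    out = jax.lax.map(per_chunk, (hidden, targets))
    return out.reshape(-1)[:n]
```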
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with
lm_head weight, bypassing LoRA adapters. This is incorrect when
train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
Recompute activations during backward to save memory. Only one layer's
activations are held at a time during backward pass, reducing peak
memory by ~num_layers factor.

- Add gradient_checkpointing config to ModelConfig
- Apply jax.checkpoint per-layer when is_training=True
- Rename compute_logits to is_training (controls both logits and checkpointing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…euse

Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles
ONE loop body and reuses buffers during backward recomputation. With a
Python loop, XLA unrolls N separate checkpoint regions and can't optimize
buffer reuse across them.

Only enabled when gradient_checkpointing=True. Without checkpointing,
activations are stored anyway, so fori_loop's buffer reuse doesn't help
and its weight stacking overhead makes it worse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match
- test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@raulchen raulchen force-pushed the per-layer-checkpointing branch from fd7572d to 5cf1c66 on January 21, 2026 at 02:56
raulchen and others added 9 commits January 21, 2026 13:28
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods:
  compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
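
A hypothetical skeleton of the structure described in the last two commits; names and signatures are illustrative, not the actual tx/ code.

```python
class LogitsProcessor:
    """Standalone utility; internal implementation detail of the models."""

    @staticmethod
    def compute_logits(hidden_states, lm_head, adapter_indices=None):
        ...

    @staticmethod
    def compute_logprobs(hidden_states, lm_head, target_ids,
                         adapter_indices=None, chunk_size=1024):
        ...  # picks the chunked or non-chunked path

    @staticmethod
    def logits_to_logprobs(logits, target_ids):
        ...


class CausalLMBase:
    """Models subclass this and expose thin wrappers instead of LogitsProcessor."""

    def __init__(self, lm_head):
        self.lm_head = lm_head

    def compute_logits(self, hidden_states, **kwargs):
        return LogitsProcessor.compute_logits(hidden_states, self.lm_head, **kwargs)

    def compute_logprobs(self, hidden_states, target_ids, **kwargs):
        return LogitsProcessor.compute_logprobs(
            hidden_states, self.lm_head, target_ids, **kwargs)
```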
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
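
A sketch of the runtime dispatch with jax.lax.cond, reusing the chunked_logprobs sketch from the earlier commit; the mask and argument names are illustrative.

```python
import jax
import jax.numpy as jnp

def compute_logprobs(hidden, w_lm_head, targets, adapter_indices, train_unembed_mask):
    # train_unembed_mask: [num_adapters] bool, True if that adapter applies LoRA to lm_head.
    needs_lora_lm_head = jnp.any(train_unembed_mask[adapter_indices])

    def non_chunked(h):
        # The real path applies LoRA to lm_head here; a plain matmul stands in.
        lp = jax.nn.log_softmax(h @ w_lm_head, axis=-1)
        return jnp.take_along_axis(lp, targets[:, None], axis=-1)[:, 0]

    def chunked(h):
        return chunked_logprobs(h, w_lm_head, targets)  # see the sketch above

    # Both branches return [N], so lax.cond can select one from the runtime mask.
    return jax.lax.cond(needs_lora_lm_head, non_chunked, chunked, hidden)
```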
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
raulchen and others added 24 commits January 22, 2026 16:23
- Remove get_lm_head_weight() abstract method (no longer needed)
- Chunked path now uses lm_head() directly instead of raw matmul
- Expand adapter_indices from [B] to [B*T] for per-token handling
- Remove restriction that disabled chunking with adapter_indices

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove _train_unembed_mask tracking from JaxBackend
- Simplify _model_forward to always pass adapter_indices to compute_logprobs
- Fix chunked path to reshape hidden states to [chunk_size, 1, H] for LoRA compatibility

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Instead of allocating [B*T] array via jnp.repeat, compute adapter
indices per-chunk using only a [chunk_size] buffer. This reduces
memory overhead significantly for long sequences.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
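
A sketch of the per-chunk index computation; the helper name and shapes are illustrative.

```python
import jax.numpy as jnp

def chunk_adapter_indices(adapter_indices, chunk_idx, chunk_size, seq_len):
    # adapter_indices: [B], one adapter id per sequence in the batch.
    # Token j of the flattened [B*T] stream belongs to batch row j // T, so only a
    # [chunk_size] buffer is materialized instead of a [B*T] jnp.repeat result.
    positions = chunk_idx * chunk_size + jnp.arange(chunk_size)
    # Out-of-range positions from padding in the last chunk are clamped by JAX's gather semantics.
    return adapter_indices[positions // seq_len]
```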
Load HF model, get logits, save weights, delete HF model, then load
our model. This avoids having both models in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add loss_chunk_size and gradient_checkpointing to config in tests
- Restore test_chunked_logprobs test that was lost during merge

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflicts:
- tx/utils/logits_processor.py: Keep chunked logprobs implementation
- tx/utils/generator.py: Keep left-padded sequence handling from main
- tests/models/test_models_common.py: Keep chunked logprobs test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by loading and deleting models sequentially
instead of having two models in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Restructure test to avoid OOM by creating and deleting backends
sequentially instead of having two in memory simultaneously.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Test coverage for:
- Chunk boundary cases (padding, exact division, larger than total)
- Adapter indices handling (None, per-batch, same for all)
- Gradient checkpointing flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Combines chunked lm_head memory optimization with per-layer gradient
checkpointing. Resolves conflicts by keeping both features:
- loss_chunk_size for chunked logprobs computation
- gradient_checkpointing with is_training flag for activation recomputation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Move _forward_layers_checkpointed and _forward_layers from Llama3Model
and Qwen3Model into shared utility functions in tx/models/utils.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolves conflicts with main branch changes:
- Updated KVCache API (removed cache_position from layer tuple)
- Added dot_product_attention import
- Updated DummyModel to accept loss_chunk_size parameter
- Use main's cleaner _compute_chunked_logprobs implementation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
KVCache.update() handles cache position internally, so this variable
is no longer needed after the KVCache API refactor.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

vercel bot commented Jan 26, 2026

@raulchen is attempting to deploy a commit to the Tyler's projects Team on Vercel.

A member of the Team first needs to authorize it.


hidden_states, (k, v) = layer(
if is_training and self.config.gradient_checkpointing:
hidden_states, all_hidden_states = forward_layers_checkpointed(
Collaborator


Can we unify these two methods into a single method forward_layers that does the right thing depending on the parameters passed in? The reasoning is that model authors shouldn't have to care about the difference between checkpointing and non-checkpointing; they should just be able to write the model. It would also avoid duplicating this code in every model definition.

Internally you can of course still have _forward_layers_checkpointed and _forward_layers and implement the "public" forward_layers function in terms of those.
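
One possible shape for that unified entry point (sketch only; argument names follow the snippet above):

```python
def forward_layers(layers, hidden_states, *, is_training, gradient_checkpointing, **kwargs):
    """Single public helper; model code never has to choose a path itself."""
    if is_training and gradient_checkpointing:
        return _forward_layers_checkpointed(layers, hidden_states, **kwargs)
    return _forward_layers(layers, hidden_states, **kwargs)
```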

output_hidden_states: bool | None = None,
adapter_indices: jax.Array | None = None,
kv_cache: KVCache | None = None,
is_training: bool = False,
Collaborator


An alternative would be to use https://flax.readthedocs.io/en/v0.8.3/api_reference/flax.experimental.nnx/module.html#flax.experimental.nnx.Module.train (and https://flax.readthedocs.io/en/v0.8.3/api_reference/flax.experimental.nnx/module.html#flax.experimental.nnx.Module.eval) to switch the module to train/eval mode. This would have some benefits: (a) it is the standard way to do this kind of thing (similar to PyTorch, where you see this pattern a lot), (b) if we add dropout in the future, handling training/eval like this will be much easier, and (c) we don't need to pass flags around if we want to access similar attributes. It is less explicit, though, and we would need to provide our own attribute to set. Curious about your thoughts :)
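
A rough sketch of that approach, assuming flax.nnx; the DecoderLayer class and is_training attribute are made up, and the set_attributes usage should be checked against the flax version in use (on 0.8.x the import is flax.experimental.nnx).

```python
from flax import nnx

class DecoderLayer(nnx.Module):
    def __init__(self):
        self.is_training = False  # our own mode attribute instead of a forward() flag

    def __call__(self, x):
        # Checkpointing (and future dropout) would branch on self.is_training here.
        return x

model = DecoderLayer()
model.set_attributes(is_training=True, raise_if_not_found=False)   # "train mode"
# ... training step ...
model.set_attributes(is_training=False, raise_if_not_found=False)  # back to eval
```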
