
Warn when FSDP auto-wrap policy splits tied weights#21613

Merged
ethanwharris merged 5 commits into Lightning-AI:master from c-pozzi:fix/fsdp-tied-weights-warning
Apr 13, 2026
Conversation

Contributor

@c-pozzi c-pozzi commented Mar 26, 2026

Summary

  • Adds detection of shared (tied) parameters that would be placed in separate FSDP units by the auto_wrap_policy
  • Emits a rank_zero_warn before wrapping, turning a cryptic RuntimeError: size mismatch into an actionable message
  • Applies to both Fabric (setup_module) and PyTorch Lightning (_setup_model) FSDP strategies

Motivation

Models like Llama, GPT-2, and Mistral tie their input embedding and output head weights. When users include torch.nn.Embedding in their FSDP auto-wrap policy, the embedding gets its own FSDP unit while the tied lm_head stays in the root unit. FSDP shards each unit independently, so lm_head sees a flat/sharded tensor instead of the expected 2D weight — causing a size mismatch deep in torch with no indication of the real cause.
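The detection described above can be sketched roughly as follows. This is an illustrative stand-in, not Lightning's actual implementation: the `Param`/`Mod` toy classes and the function names are invented for the sketch, and "wrapped" stands in for whatever the `auto_wrap_policy` would match. The core idea is just grouping parameters by object identity and checking whether one tensor ends up in more than one wrapping unit.

```python
from collections import defaultdict


class Param:
    """Stand-in for a parameter tensor; only object identity matters here."""


class Mod:
    def __init__(self, name, params=(), children=(), wrapped=False):
        self.name = name
        self.params = list(params)      # Param objects owned by this module
        self.children = list(children)
        self.wrapped = wrapped          # True if the auto-wrap policy wraps it


def params_with_units(mod, unit="root"):
    """Yield (param, unit) pairs; a wrapped module starts a new FSDP unit."""
    if mod.wrapped:
        unit = mod.name
    for p in mod.params:
        yield p, unit
    for child in mod.children:
        yield from params_with_units(child, unit)


def split_tied_params(root):
    """Return params that would land in more than one FSDP unit."""
    units = defaultdict(set)
    objs = {}
    for p, unit in params_with_units(root):
        units[id(p)].add(unit)
        objs[id(p)] = p
    return [objs[i] for i, u in units.items() if len(u) > 1]


# GPT-2-style tying: the embedding and lm_head share one weight, but only
# the embedding is matched by the policy, so the tie is split across units.
w = Param()
model = Mod("model", children=[
    Mod("embed", params=[w], wrapped=True),   # gets its own FSDP unit
    Mod("lm_head", params=[w]),               # stays in the root unit
])
print(len(split_tied_params(model)))  # → 1 split tie: warn before wrapping
```

If both aliases of a tied weight fall inside the same wrapped module, the parameter maps to a single unit and no warning is needed, which matches the "tied weights in same FSDP unit → no warning" case in the test plan.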

Test plan

  • Tied weights across FSDP units → warning emitted
  • Tied weights in same FSDP unit → no warning
  • No shared params → no warning
  • No policy set → no warning

Closes #21403

🤖 Generated with Claude Code


📚 Documentation preview 📚: https://pytorch-lightning--21613.org.readthedocs.build/en/21613/

c-pozzi added 2 commits March 26, 2026 08:26
Detect shared parameters that would be placed in separate FSDP units
by the auto-wrap policy and emit a warning before wrapping. This turns
a cryptic RuntimeError (size mismatch) into an actionable message.

Applies to both Fabric and PyTorch Lightning FSDP strategies.

Closes Lightning-AI#21403

Cover four scenarios: tied weights across units (warns), tied weights
in same unit (no warn), no shared params (no warn), no policy (no warn).
@github-actions bot added the fabric (lightning.fabric.Fabric) and pl (Generic label for PyTorch Lightning package) labels Mar 26, 2026
@c-pozzi c-pozzi marked this pull request as ready for review March 26, 2026 09:15
@c-pozzi
Contributor Author

c-pozzi commented Mar 26, 2026

@justusschock

@codecov

codecov bot commented Mar 26, 2026

Codecov Report

❌ Patch coverage is 97.50000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 87%. Comparing base (bb7820f) to head (ac456ab).
⚠️ Report is 9 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21613   +/-   ##
=======================================
  Coverage      87%      87%           
=======================================
  Files         270      270           
  Lines       23934    23974   +40     
=======================================
+ Hits        20713    20749   +36     
- Misses       3221     3225    +4     

@ethanwharris ethanwharris merged commit 4ea9b01 into Lightning-AI:master Apr 13, 2026
142 checks passed

Labels

fabric (lightning.fabric.Fabric), pl (Generic label for PyTorch Lightning package)

Development

Successfully merging this pull request may close these issues.

tensor size mismatch/crash trying FSDP with Llama 3.x model(s)

4 participants