Fix meta-init bf16 cast corrupting rotary inv_freq under sequence parallelism #1725
Code Review
This pull request modifies the meta-initialization logic in model_wrapper.py to cast only parameters and persistent buffers to the target dtype, preventing corruption of non-persistent buffers (like inv_freq in rotary embeddings) during FSDP initialization. It also adds corresponding unit and integration tests. The reviewer identified critical issues, including the use of a non-existent PyTorch method named_non_persistent_buffers() which will cause runtime AttributeErrors, the use of a non-existent model (Qwen/Qwen3-0.6B) in tests, and a similar blanket .to cast issue in fsdp_worker.py that needs to be addressed.
@jamesbraza thanks for the PR! As a sanity check I ran the test script against I expected NaN results without the fix in your PR, but tests pass. Can you confirm that the unit test is correctly simulating the issue you saw?
Yep sorry about that, you're right the GPU tests were passing on
@jamesbraza thanks for the updated test, but it would be great to get a regression test for the NaN failures with SP > 1. I re-ran the updated test and
The reproduction script you provided before in #1709 clearly produces NaNs with BF16, but then the script forces meta init for Qwen3-1.7B, which clearly uses So the reproducer is not simulating an actual FSDP init for Qwen3 1.7B with SkyRL.
…aSky-AI#1709) HFModelWrapper's meta-init cast every parameter and buffer to the target dtype, including non-persistent buffers like Qwen3RotaryEmbedding.inv_freq that from_pretrained (rank 0) leaves at fp32. The dtype divergence made the init-time rank-0->all non-persistent-buffer broadcast reinterpret rank-0's fp32 bytes into the half-width bf16 buffers, NaN-ing rotary attention on every non-rank-0 rank under sequence parallelism. Now cast only parameters and persistent buffers, matching from_pretrained's default-dtype context. Adds a CPU unit test for the dtype invariant and a 2-rank SP=2 GPU test (inv_freq finite + forward NaN-free); verified on 2-node 8xH100 that the reproducer flips bug_reproduces True->False. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The GPU test used `Qwen/Qwen3-0.6B`, which ties word embeddings, so `FSDPRefWorkerBase.init_model` gated meta-init off (`use_meta_tensor=not tie_word_embeddings`) and the test passed on `main` even without the fix — it never reached the buggy path. Switch both meta-init tests to a non-tied model (`llamafactory/tiny-random-Llama-3`) so the meta path is taken, and assert each rank's `inv_freq` stays fp32: the corrupted bf16 values can be finite garbage (e.g. ~2e7), so the finiteness/forward checks alone could pass. Confirmed on 2xH100 that both tests now fail on `main` and pass with the fix. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…divergence The prior test used a tiny non-tied model (tiny-random-Llama-3, head_dim 4), which surfaced the `inv_freq` bf16 dtype divergence but not the forward NaN: too few rotary frequencies for the corrupted bf16 buffer to land on a NaN. Switch to Qwen/Qwen3-8B (non-tied, head_dim 128) so the corruption reproduces the actual NaN logits under SP>1, and assert forward-NaN-free first (the headline symptom) with the dtype check as a deterministic backstop. Verified on 2xH100: the forward NaN fires deterministically on main (3/3 runs, bf16=True) and the test passes with the fix. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
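The byte reinterpretation described in the commit messages above can be pictured with a small standard-library sketch. It is only a simplified model, not the actual broadcast: it assumes a little-endian layout, assumes the receiving bf16 buffer keeps just the leading bytes of the fp32 payload that fit, and uses a hypothetical rotary base of 10000. The helper names are illustrative and do not come from SkyRL or transformers.

```python
import math
import struct


def rotary_inv_freq(head_dim, base=10000.0):
    """Rotary inverse frequencies base ** (-2i / head_dim); base is an assumption."""
    return [base ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


def bf16_word_to_float(word):
    """Decode a 16-bit bf16 pattern; bf16 is the upper half of an fp32 pattern."""
    return struct.unpack("<f", struct.pack("<I", word << 16))[0]


def reinterpret_fp32_as_bf16(values):
    """Model a bf16 buffer of len(values) elements filled with raw fp32 bytes."""
    n = len(values)
    fp32_bytes = struct.pack(f"<{n}f", *values)  # sender: 4 bytes per element
    kept = fp32_bytes[: 2 * n]                   # receiver: room for 2 bytes per element
    words = struct.unpack(f"<{n}H", kept)        # read the same bytes as 16-bit words
    return [bf16_word_to_float(w) for w in words]


def count_non_finite(values):
    """Number of inf/NaN entries in a list of floats."""
    return sum(1 for v in values if not math.isfinite(v))


clean = rotary_inv_freq(head_dim=128)
corrupted = reinterpret_fp32_as_bf16(clean)
print("clean[:4]    ", clean[:4])
print("corrupted[:4]", corrupted[:4])
print("non-finite entries:", count_non_finite(corrupted))
```

In this model each fp32 value splits into two bf16 slots: the high word is the bf16 truncation of the original value, and the low word is the tail of the mantissa decoded as an unrelated number. The buffer keeps its length but no longer holds the intended frequencies, and a larger `head_dim` gives more low-word slots that could decode to a huge or non-finite value.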
be23934 to ae0e8bf
Yeah thanks for pointing this out, agreed let's get this right, I just pushed a commit upgrading the test. So Also
SumanthRH left a comment
Great catch and thanks for the hard work on the regression test.
Meta-init cast non-persistent buffers (rotary `inv_freq`) to bf16 while rank-0's `from_pretrained` kept them fp32, so the init-time buffer broadcast reinterpreted rank-0's fp32 bytes as bf16 garbage and produced NaN attention under SP>1. Now we cast only params and persistent buffers, matching transformers `from_pretrained`.

Closes #1709
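A minimal sketch of the "cast only params and persistent buffers" idea, assuming PyTorch is available. This is not the code from `model_wrapper.py`: `ToyRotary` and `cast_params_and_persistent_buffers` are hypothetical names, and persistent buffers are identified by checking which buffer names appear in `state_dict()`, since PyTorch has no public `named_non_persistent_buffers()` method.

```python
import torch
import torch.nn as nn


class ToyRotary(nn.Module):
    """Hypothetical stand-in for a rotary embedding plus one linear layer."""

    def __init__(self, head_dim=8, base=10000.0):
        super().__init__()
        exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
        # Non-persistent: excluded from state_dict, like rotary inv_freq.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)
        # Persistent: saved in state_dict, so it should follow the target dtype.
        self.register_buffer("scale", torch.ones(1), persistent=True)
        self.proj = nn.Linear(head_dim, head_dim)


def cast_params_and_persistent_buffers(model, dtype):
    """Cast floating-point params and persistent buffers; skip non-persistent buffers."""
    persistent_names = set(model.state_dict().keys())
    for param in model.parameters():
        if param.is_floating_point():
            param.data = param.data.to(dtype)
    for module_name, module in model.named_modules():
        for buf_name, buf in list(module.named_buffers(recurse=False)):
            full_name = f"{module_name}.{buf_name}" if module_name else buf_name
            if full_name in persistent_names and buf.is_floating_point():
                # Assigning to an existing buffer name keeps it registered as a buffer.
                setattr(module, buf_name, buf.to(dtype))
    return model


model = cast_params_and_persistent_buffers(ToyRotary(), torch.bfloat16)
print(model.proj.weight.dtype, model.scale.dtype, model.inv_freq.dtype)
```

After the call, the parameters and the persistent `scale` buffer are bf16 while `inv_freq` stays fp32, so every rank would agree with a rank-0 copy that left that buffer in fp32.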