From 675a29ee8ed272148c740e63ba9268d5207c14f6 Mon Sep 17 00:00:00 2001 From: James Braza Date: Thu, 28 May 2026 15:21:38 -0700 Subject: [PATCH 1/3] [fsdp] Keep meta-init from casting rotary inv_freq to bf16 (fixes #1709) HFModelWrapper's meta-init cast every parameter and buffer to the target dtype, including non-persistent buffers like Qwen3RotaryEmbedding.inv_freq that from_pretrained (rank 0) leaves at fp32. The dtype divergence made the init-time rank-0->all non-persistent-buffer broadcast reinterpret rank-0's fp32 bytes into the half-width bf16 buffers, NaN-ing rotary attention on every non-rank-0 rank under sequence parallelism. Now cast only parameters and persistent buffers, matching from_pretrained's default-dtype context. Adds a CPU unit test for the dtype invariant and a 2-rank SP=2 GPU test (inv_freq finite + forward NaN-free); verified on 2-node 8xH100 that the reproducer flips bug_reproduces True->False. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../skyrl_train/workers/model_wrapper.py | 17 +++- .../skyrl_train/gpu/gpu_ci/test_meta_init.py | 89 +++++++++++++++++++ .../skyrl_train/models/test_models.py | 24 +++++ 3 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py diff --git a/skyrl/backends/skyrl_train/workers/model_wrapper.py b/skyrl/backends/skyrl_train/workers/model_wrapper.py index ba2069f27a..6ed9b5b8d1 100644 --- a/skyrl/backends/skyrl_train/workers/model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/model_wrapper.py @@ -138,7 +138,22 @@ def __init__( if meta_init: with torch.device("meta"): self.model = model_class.from_config(model_config, trust_remote_code=True) - self.model.to(torch.bfloat16 if bf16 else torch.float32) + target_dtype = torch.bfloat16 if bf16 else torch.float32 + # Cast just params + persistent buffers to the target dtype; + # leave non-persistent buffers at their init dtype. For example, transformers + # builds `Qwen3RotaryEmbedding.inv_freq` via a RoPE init that hardcodes + # `dtype=torch.float`, so it is fp32 no matter the model dtype, which rank-0's + # `from_pretrained` preserves but a blanket `.to` would clobber. + # SEE: https://github.com/huggingface/transformers/blob/v5.8.0/src/transformers/modeling_rope_utils.py#L177-L178 + # Otherwise, e.g., non-rank-0's `inv_freq` is bf16 while rank-0's stays fp32, + # so the init-time rank-0→all broadcast that seeds these buffers copies + # rank-0's fp32 bytes into that half-width bf16 buffer as huge garbage values + non_persistent_names = {n for n, _ in self.model.named_non_persistent_buffers()} + for p in self.model.parameters(): + p.data = p.data.to(target_dtype) + for name, buf in self.model.named_buffers(): + if name not in non_persistent_names: + buf.data = buf.data.to(target_dtype) else: self.model = model_class.from_pretrained( pretrain_or_model, diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py new file mode 100644 index 0000000000..e9dbdf675e --- /dev/null +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py @@ -0,0 +1,89 @@ +""" +uv run --isolated --extra dev --extra fsdp pytest -s -vvv tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py + +Multi-rank FSDP2 meta-init correctness under sequence parallelism. +""" + +import pytest +import ray +import torch + +from skyrl.backends.skyrl_train.distributed.utils import get_free_port +from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import FSDPRefWorkerBase +from skyrl.train.config import AlgorithmConfig, RefConfig, TrainerConfig + +MODEL_NAME = "Qwen/Qwen3-0.6B" # Small placeholder model with a non-persistent rotary inv_freq buffer +SERVER_HOST = "127.0.0.1" +WORLD_SIZE = 2 +SP_SIZE = 2 +SEQ_LEN = 128 + + +@ray.remote(num_gpus=1) +class MetaInitProbeRefWorker(FSDPRefWorkerBase): + def record_inv_freq(self) -> list[dict]: + out: list[dict] = [] + for name, buf in self.model.model.named_buffers(): + if not name.endswith("inv_freq") or name.endswith("original_inv_freq"): + # `original_inv_freq` is transformers' pre-cast backup of `inv_freq`. + # It has the same data, so skip it to avoid duplicate records + continue + out.append( + { + "name": name, + "n_nan": int(torch.isnan(buf).sum().item()), + "dtype": str(buf.dtype), + "first5": buf.detach().float().cpu().tolist()[:5], + } + ) + return out + + def forward_and_count_nan(self, sequences: torch.Tensor) -> int: + sequences = sequences.to("cuda") + with torch.no_grad(), torch.autocast(dtype=torch.bfloat16, device_type="cuda"): + log_probs = self.model(sequences, sequences.shape[1] - 1, torch.ones_like(sequences)) + return int(torch.isnan(log_probs).sum().item()) + + +@pytest.mark.usefixtures("ray_init_fixture") +@pytest.mark.parametrize("bf16", [True, False]) +def test_meta_init_inv_freq_finite_under_sp(bf16: bool) -> None: + """Meta-init under SP=2 must leave every rank's rotary `inv_freq` finite and its + forward NaN-free, for both bf16 (which triggers the dtype mismatch) and fp32.""" + cfg = TrainerConfig( + algorithm=AlgorithmConfig(temperature=0.1), # Placeholder non-None temperature + ref=RefConfig(sequence_parallel_size=SP_SIZE), + bf16=bf16, + ) + # All WORLD_SIZE worker actors must agree on `MASTER_PORT`; pick once on the driver + master_port = get_free_port() + + actors = [ + MetaInitProbeRefWorker.remote( + cfg=cfg, + world_size=WORLD_SIZE, + rank=r, + local_rank=0, + master_addr=SERVER_HOST, + master_port=master_port, + sequence_parallel_size=SP_SIZE, + ) + for r in range(WORLD_SIZE) + ] + # Stand up the NCCL process group + device mesh, then build the FSDP2 model + ray.get([a.init_worker_process_group.remote() for a in actors]) + ray.get([a.init_model.remote(MODEL_NAME) for a in actors]) + + inv_freq_records = ray.get([a.record_inv_freq.remote() for a in actors]) + for rank, records in enumerate(inv_freq_records): + assert records, f"rank {rank}: no inv_freq buffers found — test expects a model with rotary embeddings" + for record in records: + assert record["n_nan"] == 0, ( + f"rank {rank}: {record['name']} non-finite after FSDP init " + f"(n_nan={record['n_nan']}, dtype={record['dtype']}, first5={record['first5']})" + ) + + sequences = torch.randint(10, 10_000, (1, SEQ_LEN), dtype=torch.long) + nan_counts = ray.get([a.forward_and_count_nan.remote(sequences) for a in actors]) + for rank, n_nan in enumerate(nan_counts): + assert n_nan == 0, f"rank {rank}: log_probs has {n_nan} NaN positions under SP={SP_SIZE} (bf16={bf16})" diff --git a/tests/backends/skyrl_train/models/test_models.py b/tests/backends/skyrl_train/models/test_models.py index 546b4c3b25..7f631ab016 100644 --- a/tests/backends/skyrl_train/models/test_models.py +++ b/tests/backends/skyrl_train/models/test_models.py @@ -2,6 +2,8 @@ import torch from flash_attn.bert_padding import pad_input, unpad_input +from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper + @pytest.fixture def input_ids(): @@ -65,3 +67,25 @@ def test_flash_attention_sequence_unpacking(input_ids, attention_mask, position_ assert torch.equal(unpacked_input_ids, input_ids) # mask out the attention mask because the padding value used can differ assert torch.equal(unpacked_position_ids * attention_mask, position_ids * attention_mask) + + +@pytest.mark.parametrize("bf16", [True, False]) +def test_meta_init_keeps_non_persistent_buffers_fp32(bf16: bool) -> None: + """Meta-init casts params and persistent buffers to the target dtype while leaving + non-persistent buffers (e.g. `Qwen3RotaryEmbedding.inv_freq`) at fp32.""" + target_dtype = torch.bfloat16 if bf16 else torch.float32 + wrapper = HFModelWrapper( + "Qwen/Qwen3-0.6B", # any rotary model works + bf16=bf16, + meta_init=True, # We're exercising the non-rank-0 meta path + ) + + for name, param in wrapper.model.named_parameters(): + assert param.dtype == target_dtype, f"param {name} is {param.dtype}, expected {target_dtype}" + + inv_freq_seen = False + for name, buf in wrapper.model.named_non_persistent_buffers(): + assert buf.dtype == torch.float32, f"non-persistent buffer {name} is {buf.dtype}, expected fp32" + if name.endswith("inv_freq"): + inv_freq_seen = True + assert inv_freq_seen, "expected at least one inv_freq buffer in Qwen3" From d6d217fd3a05a798629a5e0e8644db19afb7ef3a Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 29 May 2026 10:47:38 -0700 Subject: [PATCH 2/3] [test] Make the meta-init regression test actually exercise meta-init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The GPU test used `Qwen/Qwen3-0.6B`, which ties word embeddings, so `FSDPRefWorkerBase.init_model` gated meta-init off (`use_meta_tensor=not tie_word_embeddings`) and the test passed on `main` even without the fix — it never reached the buggy path. Switch both meta-init tests to a non-tied model (`llamafactory/tiny-random-Llama-3`) so the meta path is taken, and assert each rank's `inv_freq` stays fp32: the corrupted bf16 values can be finite garbage (e.g. ~2e7), so the finiteness/forward checks alone could pass. Confirmed on 2xH100 that both tests now fail on `main` and pass with the fix. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../skyrl_train/gpu/gpu_ci/test_meta_init.py | 12 +++++++++++- tests/backends/skyrl_train/models/test_models.py | 6 +++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py index e9dbdf675e..1eb55dbdb2 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py @@ -12,7 +12,12 @@ from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import FSDPRefWorkerBase from skyrl.train.config import AlgorithmConfig, RefConfig, TrainerConfig -MODEL_NAME = "Qwen/Qwen3-0.6B" # Small placeholder model with a non-persistent rotary inv_freq buffer +# Use a model that does not tie its embeddings (shares one weight tensor between +# the input embedding and the output `lm_head`), since `FSDPRefWorkerBase.init_model` +# gates meta-init on `not tie_word_embeddings`, leading to a 'tied' model (e.g. Qwen3-0.6B) +# skipping meta-init entirely, so it can't catch the regression we're testing against +# https://github.com/huggingface/transformers/blob/v5.8.0/src/transformers/modeling_utils.py#L2582 +MODEL_NAME = "llamafactory/tiny-random-Llama-3" SERVER_HOST = "127.0.0.1" WORLD_SIZE = 2 SP_SIZE = 2 @@ -78,6 +83,11 @@ def test_meta_init_inv_freq_finite_under_sp(bf16: bool) -> None: for rank, records in enumerate(inv_freq_records): assert records, f"rank {rank}: no inv_freq buffers found — test expects a model with rotary embeddings" for record in records: + assert record["dtype"] == "torch.float32", ( + f"rank {rank}: {record['name']} is {record['dtype']} after FSDP init, expected fp32 " + f"(non-rank-0 meta-init cast it to bf16, diverging from rank-0's fp32). Checked in addition to " + f"the finiteness assert below, which a finite-but-garbage value (e.g. ~2e7) would pass." + ) assert record["n_nan"] == 0, ( f"rank {rank}: {record['name']} non-finite after FSDP init " f"(n_nan={record['n_nan']}, dtype={record['dtype']}, first5={record['first5']})" diff --git a/tests/backends/skyrl_train/models/test_models.py b/tests/backends/skyrl_train/models/test_models.py index 7f631ab016..d2ff394deb 100644 --- a/tests/backends/skyrl_train/models/test_models.py +++ b/tests/backends/skyrl_train/models/test_models.py @@ -72,10 +72,10 @@ def test_flash_attention_sequence_unpacking(input_ids, attention_mask, position_ @pytest.mark.parametrize("bf16", [True, False]) def test_meta_init_keeps_non_persistent_buffers_fp32(bf16: bool) -> None: """Meta-init casts params and persistent buffers to the target dtype while leaving - non-persistent buffers (e.g. `Qwen3RotaryEmbedding.inv_freq`) at fp32.""" + non-persistent buffers (rotary `inv_freq`) at fp32.""" target_dtype = torch.bfloat16 if bf16 else torch.float32 wrapper = HFModelWrapper( - "Qwen/Qwen3-0.6B", # any rotary model works + "llamafactory/tiny-random-Llama-3", # any model with a rotary `inv_freq` buffer bf16=bf16, meta_init=True, # We're exercising the non-rank-0 meta path ) @@ -88,4 +88,4 @@ def test_meta_init_keeps_non_persistent_buffers_fp32(bf16: bool) -> None: assert buf.dtype == torch.float32, f"non-persistent buffer {name} is {buf.dtype}, expected fp32" if name.endswith("inv_freq"): inv_freq_seen = True - assert inv_freq_seen, "expected at least one inv_freq buffer in Qwen3" + assert inv_freq_seen, "expected at least one inv_freq buffer in the model" From ae0e8bf6a26d0bd15fee8b7c96a123034e65a814 Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 29 May 2026 17:56:28 -0700 Subject: [PATCH 3/3] [test] Reproduce the actual NaN-under-SP failure, not just the dtype divergence The prior test used a tiny non-tied model (tiny-random-Llama-3, head_dim 4), which surfaced the `inv_freq` bf16 dtype divergence but not the forward NaN: too few rotary frequencies for the corrupted bf16 buffer to land on a NaN. Switch to Qwen/Qwen3-8B (non-tied, head_dim 128) so the corruption reproduces the actual NaN logits under SP>1, and assert forward-NaN-free first (the headline symptom) with the dtype check as a deterministic backstop. Verified on 2xH100: the forward NaN fires deterministically on main (3/3 runs, bf16=True) and the test passes with the fix. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../skyrl_train/gpu/gpu_ci/test_meta_init.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py index 1eb55dbdb2..fbf298dbdf 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py @@ -12,12 +12,15 @@ from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import FSDPRefWorkerBase from skyrl.train.config import AlgorithmConfig, RefConfig, TrainerConfig -# Use a model that does not tie its embeddings (shares one weight tensor between -# the input embedding and the output `lm_head`), since `FSDPRefWorkerBase.init_model` -# gates meta-init on `not tie_word_embeddings`, leading to a 'tied' model (e.g. Qwen3-0.6B) -# skipping meta-init entirely, so it can't catch the regression we're testing against +# "Tied" word embeddings share one weight tensor between the input embedding and the output `lm_head`: # https://github.com/huggingface/transformers/blob/v5.8.0/src/transformers/modeling_utils.py#L2582 -MODEL_NAME = "llamafactory/tiny-random-Llama-3" +# This test needs a model that is both non-tied and has a realistic head_dim: +# - Non-tied: `FSDPRefWorkerBase.init_model` gates meta-init on `not tie_word_embeddings`, so a +# tied model (e.g. Qwen3-0.6B) skips meta-init entirely and can't reproduce the bug. +# - Realistic head_dim (e.g. Qwen3-8B's 128): with that many frequencies, +# the bf16 `inv_freq` ends up holding a NaN and the forward NaNs with SP>1; +# a tiny head_dim (e.g. 4) only shows the dtype change (bf16 vs fp32), not a NaN. +MODEL_NAME = "Qwen/Qwen3-8B" SERVER_HOST = "127.0.0.1" WORLD_SIZE = 2 SP_SIZE = 2 @@ -79,21 +82,26 @@ def test_meta_init_inv_freq_finite_under_sp(bf16: bool) -> None: ray.get([a.init_worker_process_group.remote() for a in actors]) ray.get([a.init_model.remote(MODEL_NAME) for a in actors]) + # What we're protecting against: NaN logits with SP>1. + # Non-rank-0's `inv_freq` is cast to bf16 and ends up holding a NaN, + # which poisons the SP-coupled attention so every rank's forward produces NaN logits + sequences = torch.randint(10, 10_000, (1, SEQ_LEN), dtype=torch.long) + nan_counts = ray.get([a.forward_and_count_nan.remote(sequences) for a in actors]) + for rank, n_nan in enumerate(nan_counts): + assert n_nan == 0, f"rank {rank}: log_probs has {n_nan} NaN positions under SP={SP_SIZE} (bf16={bf16})" + + # Also assert the dtype: the unfixed meta-init casts non-rank-0's buffers + # (including `inv_freq`) to bf16, but the forward only NaNs when those bf16 values include a NaN, + # so this dtype assertion catches the bad cast even where the forward stays finite inv_freq_records = ray.get([a.record_inv_freq.remote() for a in actors]) for rank, records in enumerate(inv_freq_records): - assert records, f"rank {rank}: no inv_freq buffers found — test expects a model with rotary embeddings" + assert records, f"rank {rank}: test expects a model with rotary embeddings, but found no inv_freq buffers" for record in records: assert record["dtype"] == "torch.float32", ( - f"rank {rank}: {record['name']} is {record['dtype']} after FSDP init, expected fp32 " - f"(non-rank-0 meta-init cast it to bf16, diverging from rank-0's fp32). Checked in addition to " - f"the finiteness assert below, which a finite-but-garbage value (e.g. ~2e7) would pass." + f"rank {rank}: {record['name']} is {record['dtype']} after FSDP init, " + f"expected fp32 (cast to bf16 by meta-init, diverging from rank-0)" ) assert record["n_nan"] == 0, ( f"rank {rank}: {record['name']} non-finite after FSDP init " f"(n_nan={record['n_nan']}, dtype={record['dtype']}, first5={record['first5']})" ) - - sequences = torch.randint(10, 10_000, (1, SEQ_LEN), dtype=torch.long) - nan_counts = ray.get([a.forward_and_count_nan.remote(sequences) for a in actors]) - for rank, n_nan in enumerate(nan_counts): - assert n_nan == 0, f"rank {rank}: log_probs has {n_nan} NaN positions under SP={SP_SIZE} (bf16={bf16})"