Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion skyrl/backends/skyrl_train/workers/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,22 @@ def __init__(
if meta_init:
with torch.device("meta"):
self.model = model_class.from_config(model_config, trust_remote_code=True)
self.model.to(torch.bfloat16 if bf16 else torch.float32)
target_dtype = torch.bfloat16 if bf16 else torch.float32
# Cast just params + persistent buffers to the target dtype;
# leave non-persistent buffers at their init dtype. For example, transformers
# builds `Qwen3RotaryEmbedding.inv_freq` via a RoPE init that hardcodes
# `dtype=torch.float`, so it is fp32 no matter the model dtype, which rank-0's
# `from_pretrained` preserves but a blanket `.to` would clobber.
# SEE: https://github.com/huggingface/transformers/blob/v5.8.0/src/transformers/modeling_rope_utils.py#L177-L178
Comment thread
jamesbraza marked this conversation as resolved.
# Otherwise, e.g., non-rank-0's `inv_freq` is bf16 while rank-0's stays fp32,
# so the init-time rank-0→all broadcast that seeds these buffers copies
# rank-0's fp32 bytes into that half-width bf16 buffer as huge garbage values
non_persistent_names = {n for n, _ in self.model.named_non_persistent_buffers()}
for p in self.model.parameters():
p.data = p.data.to(target_dtype)
for name, buf in self.model.named_buffers():
if name not in non_persistent_names:
buf.data = buf.data.to(target_dtype)
Comment thread
jamesbraza marked this conversation as resolved.
else:
self.model = model_class.from_pretrained(
pretrain_or_model,
Expand Down
107 changes: 107 additions & 0 deletions tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
uv run --isolated --extra dev --extra fsdp pytest -s -vvv tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py

Multi-rank FSDP2 meta-init correctness under sequence parallelism.
"""

import pytest
import ray
import torch

from skyrl.backends.skyrl_train.distributed.utils import get_free_port
from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import FSDPRefWorkerBase
from skyrl.train.config import AlgorithmConfig, RefConfig, TrainerConfig

# "Tied" word embeddings share one weight tensor between the input embedding and the output `lm_head`:
# https://github.com/huggingface/transformers/blob/v5.8.0/src/transformers/modeling_utils.py#L2582
# This test needs a model that is both non-tied and has a realistic head_dim:
# - Non-tied: `FSDPRefWorkerBase.init_model` gates meta-init on `not tie_word_embeddings`, so a
# tied model (e.g. Qwen3-0.6B) skips meta-init entirely and can't reproduce the bug.
# - Realistic head_dim (e.g. Qwen3-8B's 128): with that many frequencies,
# the bf16 `inv_freq` ends up holding a NaN and the forward NaNs with SP>1;
# a tiny head_dim (e.g. 4) only shows the dtype change (bf16 vs fp32), not a NaN.
MODEL_NAME = "Qwen/Qwen3-8B"
SERVER_HOST = "127.0.0.1"
WORLD_SIZE = 2
SP_SIZE = 2
SEQ_LEN = 128


@ray.remote(num_gpus=1)
class MetaInitProbeRefWorker(FSDPRefWorkerBase):
def record_inv_freq(self) -> list[dict]:
out: list[dict] = []
for name, buf in self.model.model.named_buffers():
if not name.endswith("inv_freq") or name.endswith("original_inv_freq"):
# `original_inv_freq` is transformers' pre-cast backup of `inv_freq`.
# It has the same data, so skip it to avoid duplicate records
continue
out.append(
{
"name": name,
"n_nan": int(torch.isnan(buf).sum().item()),
"dtype": str(buf.dtype),
"first5": buf.detach().float().cpu().tolist()[:5],
}
)
return out

def forward_and_count_nan(self, sequences: torch.Tensor) -> int:
sequences = sequences.to("cuda")
with torch.no_grad(), torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
log_probs = self.model(sequences, sequences.shape[1] - 1, torch.ones_like(sequences))
return int(torch.isnan(log_probs).sum().item())


@pytest.mark.usefixtures("ray_init_fixture")
@pytest.mark.parametrize("bf16", [True, False])
def test_meta_init_inv_freq_finite_under_sp(bf16: bool) -> None:
"""Meta-init under SP=2 must leave every rank's rotary `inv_freq` finite and its
forward NaN-free, for both bf16 (which triggers the dtype mismatch) and fp32."""
cfg = TrainerConfig(
algorithm=AlgorithmConfig(temperature=0.1), # Placeholder non-None temperature
ref=RefConfig(sequence_parallel_size=SP_SIZE),
bf16=bf16,
)
# All WORLD_SIZE worker actors must agree on `MASTER_PORT`; pick once on the driver
master_port = get_free_port()

actors = [
MetaInitProbeRefWorker.remote(
cfg=cfg,
world_size=WORLD_SIZE,
rank=r,
local_rank=0,
master_addr=SERVER_HOST,
master_port=master_port,
sequence_parallel_size=SP_SIZE,
)
for r in range(WORLD_SIZE)
]
# Stand up the NCCL process group + device mesh, then build the FSDP2 model
ray.get([a.init_worker_process_group.remote() for a in actors])
ray.get([a.init_model.remote(MODEL_NAME) for a in actors])

# What we're protecting against: NaN logits with SP>1.
# Non-rank-0's `inv_freq` is cast to bf16 and ends up holding a NaN,
# which poisons the SP-coupled attention so every rank's forward produces NaN logits
sequences = torch.randint(10, 10_000, (1, SEQ_LEN), dtype=torch.long)
nan_counts = ray.get([a.forward_and_count_nan.remote(sequences) for a in actors])
for rank, n_nan in enumerate(nan_counts):
assert n_nan == 0, f"rank {rank}: log_probs has {n_nan} NaN positions under SP={SP_SIZE} (bf16={bf16})"

# Also assert the dtype: the unfixed meta-init casts non-rank-0's buffers
# (including `inv_freq`) to bf16, but the forward only NaNs when those bf16 values include a NaN,
# so this dtype assertion catches the bad cast even where the forward stays finite
inv_freq_records = ray.get([a.record_inv_freq.remote() for a in actors])
for rank, records in enumerate(inv_freq_records):
assert records, f"rank {rank}: test expects a model with rotary embeddings, but found no inv_freq buffers"
for record in records:
assert record["dtype"] == "torch.float32", (
f"rank {rank}: {record['name']} is {record['dtype']} after FSDP init, "
f"expected fp32 (cast to bf16 by meta-init, diverging from rank-0)"
)
assert record["n_nan"] == 0, (
f"rank {rank}: {record['name']} non-finite after FSDP init "
f"(n_nan={record['n_nan']}, dtype={record['dtype']}, first5={record['first5']})"
)
24 changes: 24 additions & 0 deletions tests/backends/skyrl_train/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
from flash_attn.bert_padding import pad_input, unpad_input

from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper


@pytest.fixture
def input_ids():
Expand Down Expand Up @@ -65,3 +67,25 @@ def test_flash_attention_sequence_unpacking(input_ids, attention_mask, position_
assert torch.equal(unpacked_input_ids, input_ids)
# mask out the attention mask because the padding value used can differ
assert torch.equal(unpacked_position_ids * attention_mask, position_ids * attention_mask)


@pytest.mark.parametrize("bf16", [True, False])
def test_meta_init_keeps_non_persistent_buffers_fp32(bf16: bool) -> None:
"""Meta-init casts params and persistent buffers to the target dtype while leaving
non-persistent buffers (rotary `inv_freq`) at fp32."""
target_dtype = torch.bfloat16 if bf16 else torch.float32
wrapper = HFModelWrapper(
"llamafactory/tiny-random-Llama-3", # any model with a rotary `inv_freq` buffer
bf16=bf16,
meta_init=True, # We're exercising the non-rank-0 meta path
)

for name, param in wrapper.model.named_parameters():
assert param.dtype == target_dtype, f"param {name} is {param.dtype}, expected {target_dtype}"

inv_freq_seen = False
for name, buf in wrapper.model.named_non_persistent_buffers():
assert buf.dtype == torch.float32, f"non-persistent buffer {name} is {buf.dtype}, expected fp32"
if name.endswith("inv_freq"):
inv_freq_seen = True
assert inv_freq_seen, "expected at least one inv_freq buffer in the model"
Loading