diff --git a/skyrl/backends/skyrl_train/workers/model_wrapper.py b/skyrl/backends/skyrl_train/workers/model_wrapper.py
index ba2069f27a..6ed9b5b8d1 100644
--- a/skyrl/backends/skyrl_train/workers/model_wrapper.py
+++ b/skyrl/backends/skyrl_train/workers/model_wrapper.py
@@ -138,7 +138,25 @@ def __init__(
             if meta_init:
                 with torch.device("meta"):
                     self.model = model_class.from_config(model_config, trust_remote_code=True)
-                self.model.to(torch.bfloat16 if bf16 else torch.float32)
+                target_dtype = torch.bfloat16 if bf16 else torch.float32
+                # Cast just params + persistent buffers to the target dtype;
+                # leave non-persistent buffers at their init dtype. For example, transformers
+                # builds `Qwen3RotaryEmbedding.inv_freq` via a RoPE init that hardcodes
+                # `dtype=torch.float`, so it is fp32 no matter the model dtype, which rank-0's
+                # `from_pretrained` preserves but a blanket `.to` would clobber.
+                # SEE: https://github.com/huggingface/transformers/blob/v5.8.0/src/transformers/modeling_rope_utils.py#L177-L178
+                # Otherwise, e.g., non-rank-0's `inv_freq` is bf16 while rank-0's stays fp32,
+                # so the init-time rank-0→all broadcast that seeds these buffers copies
+                # rank-0's fp32 bytes into that half-width bf16 buffer as huge garbage values
+                # Only floating-point tensors are cast: the `Module.to(dtype)` call this replaces
+                # left integer/bool tensors untouched, and an unconditional `.data.to` would not.
+                non_persistent_names = {n for n, _ in self.model.named_non_persistent_buffers()}
+                for p in self.model.parameters():
+                    if p.is_floating_point():
+                        p.data = p.data.to(target_dtype)
+                for name, buf in self.model.named_buffers():
+                    if name not in non_persistent_names and buf.is_floating_point():
+                        buf.data = buf.data.to(target_dtype)
             else:
                 self.model = model_class.from_pretrained(
                     pretrain_or_model,
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py
new file mode 100644
index 0000000000..fbf298dbdf
--- /dev/null
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py
@@ -0,0 +1,107 @@
+"""
+uv run --isolated --extra dev --extra fsdp pytest -s -vvv tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py
+
+Multi-rank FSDP2 meta-init correctness under sequence parallelism.
+"""
+
+import pytest
+import ray
+import torch
+
+from skyrl.backends.skyrl_train.distributed.utils import get_free_port
+from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import FSDPRefWorkerBase
+from skyrl.train.config import AlgorithmConfig, RefConfig, TrainerConfig
+
+# "Tied" word embeddings share one weight tensor between the input embedding and the output `lm_head`:
+# https://github.com/huggingface/transformers/blob/v5.8.0/src/transformers/modeling_utils.py#L2582
+# This test needs a model that is both non-tied and has a realistic head_dim:
+# - Non-tied: `FSDPRefWorkerBase.init_model` gates meta-init on `not tie_word_embeddings`, so a
+#   tied model (e.g. Qwen3-0.6B) skips meta-init entirely and can't reproduce the bug.
+# - Realistic head_dim (e.g. Qwen3-8B's 128): with that many frequencies,
+#   the bf16 `inv_freq` ends up holding a NaN and the forward NaNs with SP>1;
+#   a tiny head_dim (e.g. 4) only shows the dtype change (bf16 vs fp32), not a NaN.
+MODEL_NAME = "Qwen/Qwen3-8B"
+SERVER_HOST = "127.0.0.1"
+WORLD_SIZE = 2
+SP_SIZE = 2
+SEQ_LEN = 128
+
+
+@ray.remote(num_gpus=1)
+class MetaInitProbeRefWorker(FSDPRefWorkerBase):
+    def record_inv_freq(self) -> list[dict]:
+        out: list[dict] = []
+        for name, buf in self.model.model.named_buffers():
+            if not name.endswith("inv_freq") or name.endswith("original_inv_freq"):
+                # `original_inv_freq` is transformers' pre-cast backup of `inv_freq`.
+                # It has the same data, so skip it to avoid duplicate records
+                continue
+            out.append(
+                {
+                    "name": name,
+                    "n_nan": int(torch.isnan(buf).sum().item()),
+                    "dtype": str(buf.dtype),
+                    "first5": buf.detach().float().cpu().tolist()[:5],
+                }
+            )
+        return out
+
+    def forward_and_count_nan(self, sequences: torch.Tensor) -> int:
+        sequences = sequences.to("cuda")
+        with torch.no_grad(), torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
+            log_probs = self.model(sequences, sequences.shape[1] - 1, torch.ones_like(sequences))
+        return int(torch.isnan(log_probs).sum().item())
+
+
+@pytest.mark.usefixtures("ray_init_fixture")
+@pytest.mark.parametrize("bf16", [True, False])
+def test_meta_init_inv_freq_finite_under_sp(bf16: bool) -> None:
+    """Meta-init under SP=2 must leave every rank's rotary `inv_freq` finite and its
+    forward NaN-free, for both bf16 (which triggers the dtype mismatch) and fp32."""
+    cfg = TrainerConfig(
+        algorithm=AlgorithmConfig(temperature=0.1),  # Placeholder non-None temperature
+        ref=RefConfig(sequence_parallel_size=SP_SIZE),
+        bf16=bf16,
+    )
+    # All WORLD_SIZE worker actors must agree on `MASTER_PORT`; pick once on the driver
+    master_port = get_free_port()
+
+    actors = [
+        MetaInitProbeRefWorker.remote(
+            cfg=cfg,
+            world_size=WORLD_SIZE,
+            rank=r,
+            local_rank=0,
+            master_addr=SERVER_HOST,
+            master_port=master_port,
+            sequence_parallel_size=SP_SIZE,
+        )
+        for r in range(WORLD_SIZE)
+    ]
+    # Stand up the NCCL process group + device mesh, then build the FSDP2 model
+    ray.get([a.init_worker_process_group.remote() for a in actors])
+    ray.get([a.init_model.remote(MODEL_NAME) for a in actors])
+
+    # What we're protecting against: NaN logits with SP>1.
+    # Non-rank-0's `inv_freq` is cast to bf16 and ends up holding a NaN,
+    # which poisons the SP-coupled attention so every rank's forward produces NaN logits
+    sequences = torch.randint(10, 10_000, (1, SEQ_LEN), dtype=torch.long)
+    nan_counts = ray.get([a.forward_and_count_nan.remote(sequences) for a in actors])
+    for rank, n_nan in enumerate(nan_counts):
+        assert n_nan == 0, f"rank {rank}: log_probs has {n_nan} NaN positions under SP={SP_SIZE} (bf16={bf16})"
+
+    # Also assert the dtype: the unfixed meta-init casts non-rank-0's buffers
+    # (including `inv_freq`) to bf16, but the forward only NaNs when those bf16 values include a NaN,
+    # so this dtype assertion catches the bad cast even where the forward stays finite
+    inv_freq_records = ray.get([a.record_inv_freq.remote() for a in actors])
+    for rank, records in enumerate(inv_freq_records):
+        assert records, f"rank {rank}: test expects a model with rotary embeddings, but found no inv_freq buffers"
+        for record in records:
+            assert record["dtype"] == "torch.float32", (
+                f"rank {rank}: {record['name']} is {record['dtype']} after FSDP init, "
+                f"expected fp32 (cast to bf16 by meta-init, diverging from rank-0)"
+            )
+            assert record["n_nan"] == 0, (
+                f"rank {rank}: {record['name']} non-finite after FSDP init "
+                f"(n_nan={record['n_nan']}, dtype={record['dtype']}, first5={record['first5']})"
+            )
diff --git a/tests/backends/skyrl_train/models/test_models.py b/tests/backends/skyrl_train/models/test_models.py
index 546b4c3b25..d2ff394deb 100644
--- a/tests/backends/skyrl_train/models/test_models.py
+++ b/tests/backends/skyrl_train/models/test_models.py
@@ -2,6 +2,8 @@
 import torch
 from flash_attn.bert_padding import pad_input, unpad_input
 
+from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper
+
 
 @pytest.fixture
 def input_ids():
@@ -65,3 +67,25 @@ def test_flash_attention_sequence_unpacking(input_ids, attention_mask, position_
     assert torch.equal(unpacked_input_ids, input_ids)
     # mask out the attention mask because the padding value used can differ
     assert torch.equal(unpacked_position_ids * attention_mask, position_ids * attention_mask)
+
+
+@pytest.mark.parametrize("bf16", [True, False])
+def test_meta_init_keeps_non_persistent_buffers_fp32(bf16: bool) -> None:
+    """Meta-init casts params and persistent buffers to the target dtype while leaving
+    non-persistent buffers (rotary `inv_freq`) at fp32."""
+    target_dtype = torch.bfloat16 if bf16 else torch.float32
+    wrapper = HFModelWrapper(
+        "llamafactory/tiny-random-Llama-3",  # any model with a rotary `inv_freq` buffer
+        bf16=bf16,
+        meta_init=True,  # We're exercising the non-rank-0 meta path
+    )
+
+    for name, param in wrapper.model.named_parameters():
+        assert param.dtype == target_dtype, f"param {name} is {param.dtype}, expected {target_dtype}"
+
+    inv_freq_seen = False
+    for name, buf in wrapper.model.named_non_persistent_buffers():
+        assert buf.dtype == torch.float32, f"non-persistent buffer {name} is {buf.dtype}, expected fp32"
+        if name.endswith("inv_freq"):
+            inv_freq_seen = True
+    assert inv_freq_seen, "expected at least one inv_freq buffer in the model"