Skip to content

Now, grpo.py matches grpo_fast.py on qwen3_4b_dapo_math{,_oc}.sh#1642

Open
finbarrtimbers wants to merge 127 commits into
mainfrom
finbarr/post-training-experiments
Open

Now, grpo.py matches grpo_fast.py on qwen3_4b_dapo_math{,_oc}.sh#1642
finbarrtimbers wants to merge 127 commits into
mainfrom
finbarr/post-training-experiments

Conversation

@finbarrtimbers
Copy link
Copy Markdown
Collaborator

@finbarrtimbers finbarrtimbers commented Apr 27, 2026

Summary

Fixes grpo.py (the olmo-core / FSDP2 GRPO path) on two axes: (1) the step-0 weight sync was broken in three independent ways, and (2) the per-step logprob recompute was running cross-doc attention while vLLM was running intra-doc, blowing up val/tis_clipfrac ~570× vs the HF reference even after the weight sync was working.

Bug 0 — OLMo-core trainer logprob recompute drops doc boundaries

grpo_utils.forward_for_logprobs calls the model with (input_ids, attention_mask, position_ids). HF FA3 derives intra-doc cu_seqlens from position-id resets automatically, so grpo_fast.py works. OLMo-core's Transformer does not — it requires explicit doc_lens / max_doc_lens to do intra-doc attention on packed sequences, and silently drops both attention_mask and any inferred structure otherwise. With pack_length=10240, every packed sequence holds multiple docs, so OLMo-core was computing full cross-doc attention while vLLM (and HF FA3) used intra-doc — a mismatch that surfaces as massively elevated val/tis_clipfrac.

Fix: add compute_olmo_core_doc_lens(attention_mask) helper that converts the integer-coded packed mask to (doc_lens, max_doc_lens), plus a pass_olmo_core_doc_lens: bool flag on forward_for_logprobs / compute_logprobs. OLMo-core call sites in olmo_core_train_modules.py (ref-policy recompute, old-logprobs recompute, per-step new-logprobs) set the flag; HF call sites are unchanged. See docs/grpo_divergence.md for the full investigation.

Bug 1 — FSDP2 root-module params produced bad NCCL sends (illegal memory access)

broadcast_weights_to_vllm only unsharded the FSDP2 submodules returned by _get_fsdp2_submodules, which deliberately excludes the root FSDPModule. Root-level params (e.g. model.norm.weight, lm_head.weight) were therefore still DTensors backed by only the local shard. Calling .contiguous().clone() on them produced a buffer with the global stride but only the local shard's storage, so the 399th NCCL broadcast hit unmapped memory — surfacing as Cuda failure 700: an illegal memory access in pyNcclCommunicator/enqueue.cc.

Fix: also call model.unshard() / model.reshard() on the root in the FSDP2 branch, and switch the FSDP2 branch to gather_whole_model=True (per-block iteration deadlocks the CUDA stream because reshard collectives interleave with trainer->vLLM NCCL sends).

Bug 2 — LLMRayActor.update_weights didn't match vLLM's IPC calling convention

vLLM's IPCWeightTransferEngine.trainer_send_weights calls the actor with a single positional dict: update_weights.remote(dict(update_info=...)). Our actor's signature was update_weights(names, dtype_names, shapes, packed, model_step) — so single-GPU runs (which use the IPC backend) failed at step 0 with TypeError: missing a required argument: 'dtype_names'.

Fix: unify the actor signature on update_weights(kwargs: dict, model_step=None) and have the NCCL callers in broadcast_weights_to_vllm build the same {"update_info": {...}} shape. Single path for both backends.

Bug 3 — Initial weight sync wasn't fired before trainer.fit()

grpo_fast does an explicit pre-training weight broadcast (initialize_weight_sync) so the first NCCL collective fires from a known-good state before rollouts start racing. grpo.py skipped this. Fix: add PolicyTrainerOLMoCoreProcess.run_initial_weight_sync() and call it from grpo.main() before trainer.fit(). Add torch.distributed.barrier() in setup_model_update_group so non-rank-0 ranks wait for rank 0 to finish registering NCCL weight-transfer engines.

Other changes

  • VLLMWeightSyncCallback gains inflight_updates: bool and skips the blocking ray_get(refs) when set, matching grpo_fast.
  • torch.cuda.set_device(0) in post_step so the NCCL broadcast targets the right device under Ray's per-actor CUDA_VISIBLE_DEVICES.
  • _collect_weight_metadata simplified — FSDP2 path no longer needed.
  • Default --output_dir / --checkpoint_state_dir to /tmp-3m/${RUN_NAME} in the two qwen3-4b DAPO scripts.

Results

val/tis_clipfrac (Qwen3-4B-Base DAPO-Math, 8×H100 jupiter):

Run Trainer mean max
parozgke HF (grpo_fast.py) 5.7e-6 3.7e-5
il33h8fl OLMo-core, before doc_lens fix 3.2e-3 7.8e-3
i3e7d0b5 OLMo-core, after doc_lens fix 3.8e-6 3.5e-5

OLMo-core trainer is now matching vLLM bit-for-bit on logprobs (slightly better than the HF reference).

Comparing OLMo-core (i3e7d0b5, in progress) vs HF (parozgke, finished) over the first 540 steps:

Numerics — at or below HF parity with vLLM:

Metric OLMo-core HF Ratio
val/tis_clipfrac 1.13e-6 2.99e-6 0.38
val/tis_ratio 1.000 1.000 1.00
policy/clipfrac_avg 0 0
val/advantages_max 9.07 9.02 1.00

OLMo-core's tis_clipfrac is now ~2.6× lower than HF's — i.e. its trainer logprobs match vLLM more tightly than HF's do.

Speed — OLMo-core dominates:

Metric OLMo-core HF Ratio
time/training 7.23s 45.26s 6.3× faster
time/total 7.96s 45.53s 5.7× faster
time/trainer_idle_waiting_for_inference 22.7s 30.3s 1.3× less idle
time/getting_response 93.9s 107.3s 1.1× faster
time/weight_sync 0.69s 0.42s 1.6× slower

Per-step training time scales much better on OLMo-core: as response length and pack heaviness grow, HF's time/training rose from 31.8s (first 160 steps) to 45.3s (first 540), while OLMo-core's only rose from 2.8s → 7.2s. Weight sync remains the one regression (~60% slower), consistent with the OLMo-core→vLLM sync path being a bit heavier.

Learning signal — converging:

Metric OLMo-core HF Ratio
objective/verifiable_correct_rate 0.458 0.473 0.97
loss/policy_avg 0.402 0.440 0.91

Solve-rate gap closed from ~9% (step 160) to ~3% (step 540), confirming the early gap was sampling drift rather than a real regression. Both runs use seed=1 but vLLM scheduling + async_steps=4 give different prompt mixes.

Runs

  1. https://beaker.org/ex/01KQDKM58STGTSYQ4KGJDS604Zqwen3_4b_dapo_math_oc.sh with the doc_lens fix (in progress, tis_clipfrac confirmed at parity)
  2. https://beaker.org/ex/01KQ89YS7R7X8C4CY01FEYPC14qwen3_4b_dapo_math_oc.sh with weight-sync fixes (FSDP2 / NCCL multi-GPU)

Test plan

  • qwen3_4b_dapo_math_oc.sh (FSDP2 / NCCL, multi-GPU)
  • scripts/train/debug/single_gpu_on_beaker.sh (DeepSpeed / IPC) regression
  • GPU tests

GPU_TESTS=bypass

…po_math.sh image positional leak Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…pen-instruct-dev Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…val to math verifier Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…/brumo eval to math verifier Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>"

This reverts commit cf82a70.
…h dataset=math) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ed-By: Claude Opus 4.7 <noreply@anthropic.com>
… Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…tion checkpointing Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ng works Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ng Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…po.py hang Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…o-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…enabled) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ache lock contention Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…sync hang Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…to find hang location Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…gs/lm_head Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
… lowercase b) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ync hang Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…for grpo.py Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…arrier to prevent rank desync into gloo bookkeeping collective Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…npoint hang Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…o ranks aren't suppressed Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
… as numeric, inflating loss_denominator 60x in grpo.py
…po_utils to enable eval/* metric logging Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…match-grpo notes) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…n_counts between grpo_fast and grpo_utils Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…valued response_masks summed as numerics) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…locks pytest collection) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…edundant per-consumer .bool() coercions Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…uences contract) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…, drop redundant per-consumer .bool() coercions Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>"

This reverts commit 607613d.
…umer .bool() coercions Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant