pi07 embed_prefix: hoist embed_* calls out of has_* branches to drop the global-sync all-reduces #266

@shuheng-liu

Background

Pi07's Gemma3WithExpertModel.embed_prefix has data-dependent branches that decide whether to add response / metadata / subgoal / prefix-end blocks based on whether any sample in the local batch has real (non-padded) content for that field. Each branch's body calls embed_language_tokens(...) and/or embed_image(...), which trigger forward calls into the FSDP-wrapped tree.

Under FSDP / ZeRO-3 with realistic stochastic *_drop_prob settings, different ranks roll different drop outcomes, take different branches, and end up issuing different counts of FSDP all-gather collectives → NCCL deadlock. PR #265 fixes this with a _global_any(local, device) helper that OR-reduces each branch decision across ranks via a 1-element MAX all-reduce, so every rank takes the same branch.
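To make the failure mode concrete, here is a toy, pure-Python model of per-rank branch decisions (no FSDP or NCCL involved). The names `collectives_issued` and `global_any` are hypothetical. The `max(...)` below only imitates the 1-element MAX all-reduce that `_global_any` performs; it is not the code from PR #265.

```python
from typing import List


def collectives_issued(takes_branch: List[bool], calls_per_branch: int = 1) -> List[int]:
    """Count the all-gathers each simulated rank would issue.

    Toy assumption: each embed call inside a taken branch costs one all-gather.
    """
    return [calls_per_branch if taken else 0 for taken in takes_branch]


def global_any(local_flags: List[bool]) -> List[bool]:
    """Imitate an OR-reduce: a MAX over 0/1 flags, handed back to every rank."""
    reduced = max(int(flag) for flag in local_flags)
    return [bool(reduced)] * len(local_flags)


# Rank 0 kept the response field, rank 1 dropped it.
local_has_response = [True, False]

# Local decisions: the ranks disagree on how many collectives to issue.
mismatched = collectives_issued(local_has_response)

# Synchronized decisions: every rank takes the same branch.
aligned = collectives_issued(global_any(local_has_response))

print(mismatched, aligned)
```

In the mismatched case one rank waits on a collective that its peer never issues, which is the deadlock described above. In the aligned case the counts agree, at the cost of the extra reduction.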

That fix works, but it has two warts:

  1. Three small all-reduces per embed_prefix (~tens of µs total — negligible against step time, but still extra collectives).
  2. Ranks whose local micro-batch has no data for a field still embed pad tokens through the whole prefix when some other rank has that field's content. Their pad_masks are all-False so attention/loss correctly ignore them, but every interleaved layer still spends compute carrying those pad slots.

Proposed cleanup

Hoist every embed_language_tokens(...) and embed_image(...) call out of the if has_* branches. Make the call unconditional (so every rank issues the same FSDP all-gather count) but keep the bookkeeping (embs.append, pad_masks.append, att_masks += [...]) gated on the local condition.

Sketch:
```python
# Always run: uniform FSDP all-gather count across ranks.
response_emb = (
    self.gemma3_with_expert.embed_language_tokens(response_tokens)
    if response_tokens is not None
    else None
)

# Bookkeeping stays conditional on the LOCAL data; no global sync needed.
if response_emb is not None and response_masks is not None and response_masks.any():
    embs.append(response_emb)
    pad_masks.append(response_masks)
    att_masks += [1] * response_emb.shape[1]
```

Same shape for metadata, prefix_end, subgoal_images (the subgoal body has 2× embed_language_tokens + N× embed_image; all should be hoisted).
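The counting argument behind the hoist can be checked with a small stand-in that has no PyTorch dependency. Everything here is hypothetical (`FakeEmbedder`, `build_prefix`, list-based masks). It only illustrates that the embed call count stays fixed while the appended prefix length follows the local mask.

```python
from typing import List, Optional, Tuple


class FakeEmbedder:
    """Stand-in for the FSDP-wrapped module: it only counts forward calls."""

    def __init__(self) -> None:
        self.calls = 0

    def embed_language_tokens(self, tokens: List[int]) -> List[float]:
        self.calls += 1  # in the real model, an all-gather would fire here
        return [float(t) for t in tokens]


def build_prefix(
    embedder: FakeEmbedder,
    response_tokens: Optional[List[int]],
    response_masks: Optional[List[bool]],
) -> Tuple[List[float], List[int]]:
    embs: List[float] = []
    att_masks: List[int] = []

    # Hoisted: runs whenever tokens exist, regardless of the mask contents.
    response_emb = (
        embedder.embed_language_tokens(response_tokens)
        if response_tokens is not None
        else None
    )

    # Gated on local data only.
    if response_emb is not None and response_masks is not None and any(response_masks):
        embs.extend(response_emb)
        att_masks += [1] * len(response_emb)
    return embs, att_masks


rank0, rank1 = FakeEmbedder(), FakeEmbedder()
prefix0, _ = build_prefix(rank0, [5, 6, 7], [True, True, False])    # has real content
prefix1, _ = build_prefix(rank1, [0, 0, 0], [False, False, False])  # all padding
```

Both simulated ranks make exactly one embed call, yet only `rank0` ends up with a non-empty prefix. The sketch assumes `response_tokens is None` evaluates the same way on every rank; if that could differ across ranks, the `is not None` guard would itself be a divergent branch and would need separate handling.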

Different ranks would then have different prefix lengths — that's OK because:

  • FSDP collectives are tied to Module.forward call counts, not input shapes; uniform call counts → no desync.
  • Each rank's gemma3_with_expert.forward processes its own seq length; the per-layer InterleavedDecoderLayer / SiglipEncoderLayer forwards are still uniformly called once per layer.
  • Loss / metric reductions go through accelerator.gather_for_metrics which handles per-rank scalar reductions.

Verification needed before merging

The above only holds if no downstream op depends on a uniform prefix structure across ranks. Quick checklist to walk before merging:

  • embed_prefix callers (PI07LowLevelFlowMatching.forward and friends) work with per-rank-variable prefix_embs.shape[1].
  • Any cross-rank slicing / indexing that assumes the same prefix layout per rank.
  • Determinism check: two seeded FSDP runs are still bit-identical (per CLAUDE.md hard rule #3).
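For the determinism item, one way to compare two seeded runs is bit-for-bit equality of logged scalars rather than `==` or a tolerance. A minimal sketch, assuming each run produced a list of per-step loss values as Python floats; `bit_identical` is a hypothetical helper, not something known to exist in the repo:

```python
import struct
from typing import Sequence


def float_bits(value: float) -> bytes:
    """Return the IEEE-754 double encoding of a Python float."""
    return struct.pack("<d", value)


def bit_identical(run_a: Sequence[float], run_b: Sequence[float]) -> bool:
    """True only if both runs have the same length and every value matches bitwise."""
    if len(run_a) != len(run_b):
        return False
    return all(float_bits(a) == float_bits(b) for a, b in zip(run_a, run_b))
```

Comparing encodings is stricter than `==`: it distinguishes `0.0` from `-0.0` and treats two NaNs with the same payload as equal, which is usually what a "bit-identical" rule intends.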

Win

Drops the 3 all-reduces and the wasted-pad compute. Restores per-rank-honest prefix structure. Aesthetic match to the existing pattern of "the unconditional path always runs; the optional path appends" elsewhere in the codebase.

Out of scope (already handled in PR #265)

The _global_any fix is already shipped and verified — it works, just leaves these efficiency / cleanliness tradeoffs on the table. This issue is for the follow-up cleanup.

Labels

bug, optimization