Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ These override defaults — read them before running anything.
- Exceptions where plain torch ops are fine: simple `.flatten()` of a contiguous suffix, single-axis `.sum()` over the last dim, contiguity calls (`.contiguous()`), and shape-preserving ops (`.to(...)`, `.float()`). Don't rewrite these just for the sake of einops.
- When matching a reference (HuggingFace `transformers`, upstream LeRobot, the original π model code), preserve the existing op style verbatim inside that block — readability gains are not worth diff churn against an upstream reference.

5. **Distributed forward/backward must keep collective counts aligned across ranks.** FSDP / ZeRO-3 hang at NCCL with mismatched all-gather sizes when ranks disagree on what `forward` does. The failure mode is silent at CPU smoke-test time and only manifests after collectives diverge on a real run, so the patterns below are mandatory — copy them when extending the relevant code, don't reinvent them:
5. **Distributed forward/backward and metric gathers must stay consistent across ranks.** FSDP / ZeRO-3 hang at NCCL with mismatched all-gather sizes when ranks disagree on what `forward` does; gathered per-sample results silently desync rows when paired tensors are de-padded independently. Both failure modes are silent at CPU smoke-test time and only manifest on real multi-rank runs, so the patterns below are mandatory — copy them when extending the relevant code, don't reinvent them:
- **Per-rank branch decisions that fire collectives must be OR-reduced first.** When a `forward` takes a Python-level branch based on what the local micro-batch contains (e.g. `if has_response: embed_language_tokens(...)` in `embed_prefix`), use `_global_or_branch_decisions` in `src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py` — one SUM all-reduce that both OR-reduces the per-rank decisions and asserts cross-rank presence agreement. Adding a new optional branch in distributed `forward` without going through it (or an equivalent pre-branch all-reduce) is the same bug.
- **Composite forward units must be a single `nn.Module`.** Bundle multi-component decoder steps (e.g. a backbone layer paired with an action-expert layer) into one `nn.Module` so FSDP's all-gather hook prefetches every sub-component together — like `InterleavedDecoderLayer` in `src/opentau/policies/pi07/gemma3_with_expert.py`. Calling sub-components directly on a separately-wrapped layer (`layer.input_layernorm(...)`, `layer.self_attn.q_proj(...)`) bypasses the hook and triggers mismatched all-gather sizes across ranks.
- **Per-sample tensors that must stay row-aligned across ranks belong in one `gather_for_metrics({...})` dict call of int tensors.** `accelerate.gather_for_metrics` de-pads the ragged final batch (total samples not divisible by `world_size`) by trimming the gathered result to `gradient_state.remainder`. Splitting the gather per quantity breaks row alignment two ways: (1) tensors and non-tensors trim differently — `gather` trims to `remainder`, the `gather_object` path used for lists (e.g. `dataset_repo_id` strings) does not — so provenance desyncs from the loss rows; (2) even all-tensor calls stay aligned only by relying on every independent trim being identical, an accelerate implementation detail. Build the provenance as int tensors and gather everything in one dict call so a single trim hits every entry, aligned by construction — like the validation per-sample gather in `src/opentau/scripts/train.py` (per-sample MSE/CE sums + counts alongside `dataset_index` and a source index). The failure is silent (per-dataset metrics bucketed under the wrong dataset, not a hang) and the CPU suite can't catch it: the multi-rank de-pad path only fires on >1-rank hardware.

## Project overview

Expand Down
Loading