feat(train): aggregate val loss per (dataset, control_mode), parallelize #374

Merged: shuheng-liu merged 2 commits into main from claude/thirsty-blackwell-184416 on Jun 2, 2026

@shuheng-liu
Member

What this does

Closes #373.

Reworks the validation path so per-dataset loss breakdowns come from batch provenance rather than from how the val loop happens to be batched, and so validation stays parallel across ranks regardless of per-dataset size.

Before, validation built one DataLoader per underlying dataset and iterated them sequentially. That had two problems:

  1. Granularity was an artifact of batching, not provenance. policy.forward only returned a single batch-reduced MSE/CE scalar, so a mixed batch could not be disaggregated — the only reason we got any per-dataset numbers was that each loader was homogeneous by construction. We could never slice by (dataset, control_mode).
  2. Validation was badly under-parallelized. With a heterogeneous mixture of many small val subsets (some a single frame after val_split_ratio), a dataset with fewer frames than world_size left most ranks idle, and looping datasets sequentially made that idle time stack — wall-clock scaled with the number of datasets, not total val frames.
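
The step-count arithmetic behind the second problem can be sketched as follows. This is a minimal illustration, not code from the PR: `val_steps_sketch`, the frame counts, and the batch size are all hypothetical, and it assumes each loader step consumes `world_size * per_rank_batch` frames.

```python
import math


def val_steps_sketch(frame_counts, world_size, per_rank_batch):
    """Return (sequential_steps, combined_steps) for a validation mixture.

    Hypothetical helper: assumes one loader step consumes
    world_size * per_rank_batch frames across all ranks.
    """
    global_batch = world_size * per_rank_batch
    # One loader per dataset: every non-empty dataset costs at least one step,
    # even when it has fewer frames than there are ranks.
    sequential = sum(math.ceil(n / global_batch) for n in frame_counts if n > 0)
    # One combined loader: the step count depends only on the total frame count.
    combined = math.ceil(sum(frame_counts) / global_batch)
    return sequential, combined


# Five tiny val subsets (12 frames total) on 8 ranks with 4 frames per rank.
seq, comb = val_steps_sketch([1, 3, 2, 5, 1], world_size=8, per_rank_batch=4)
print(seq, comb)  # prints "5 1"
```

With per-dataset loaders the five subsets cost five mostly idle steps; pooled, the same 12 frames fit in one.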

This PR:

  • Adds a PerSampleLoss(sum, count) dataclass and a return_per_sample path to flow_matching_masked_mse, plus a shared ce_per_sample helper, in policies/utils.py. Policies optionally expose per-sample (unreduced) MSE/CE; the scalar they return is computed exactly as before (bit-identical), so the training path is untouched.
  • Threads return_per_sample through every action policy (pi0, pi05, pi05_mem, pi06, pi07/low_level, pi07_paligemma/low_level, value). Disaggregation lives in the training loop, not in forward — the model never learns dataset taxonomy.
  • Adds WeightedDatasetMixture.get_combined_val_dataloader(): a single deterministic (shuffle=False, drop_last=False) sequential pass over the whole val mixture, returning None when empty.
  • Rewrites the validation loop in scripts/train.py to run one combined pass (every rank's shard stays full), gather per-sample (sum, count) + provenance keys via gather_for_metrics, and bucket on rank 0 keyed on the dropout-immune dataset_index (norm head) and a dataset_repo_id-derived per-source index. Per-group means are Σsum / Σcount. Logs Validation/{norm_key}/{MSE,CE,Loss} (headline) and Validation/by_dataset/{name}/{MSE,CE,Loss} (per-source), plus the existing mixture-weighted aggregate (Validation/{Loss,MSE Loss,CE Loss}). Collective counts stay aligned across ranks (the per-sample flag is rank-uniform; bucketing is rank-0-only arithmetic). Policies whose forward doesn't accept return_per_sample (the high-level planners) fall back to an aggregate-only pass — they still get the single-pass parallelization.
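
As a rough sketch of the (sum, count) container described in the first bullet: the class below is a hypothetical stand-in, not the `PerSampleLoss` in policies/utils.py. It uses plain Python lists where the real code presumably holds tensors, and it assumes `__add__` pools two loss components sample by sample.

```python
from __future__ import annotations

import math
from dataclasses import dataclass


@dataclass(frozen=True)
class PerSampleLossSketch:
    """Hypothetical stand-in: one (sum, count) pair per sample in the batch."""

    sum: list[float]    # per-sample sum of unreduced loss terms
    count: list[float]  # per-sample number of valid (unmasked) terms

    def __add__(self, other: PerSampleLossSketch) -> PerSampleLossSketch:
        # Assumption: adding two losses pools their components per sample,
        # e.g. discrete-action CE plus response CE.
        if len(self.sum) != len(other.sum):
            raise ValueError("batch sizes differ")
        return PerSampleLossSketch(
            [a + b for a, b in zip(self.sum, other.sum)],
            [a + b for a, b in zip(self.count, other.count)],
        )

    def mean(self) -> float:
        # Pooled mean over the batch: total sum / total count, 0.0 if empty.
        total = math.fsum(self.count)
        return math.fsum(self.sum) / total if total > 0 else 0.0


discrete = PerSampleLossSketch(sum=[2.0, 0.0], count=[4.0, 0.0])
response = PerSampleLossSketch(sum=[1.0, 3.0], count=[1.0, 3.0])
pooled = discrete + response
print(pooled.mean())  # prints 0.75
```

Carrying the count alongside the sum is what lets the training loop regroup samples later without losing the correct denominator.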

Behavior notes for reviewers:

  • Per-group MSE is exactly consistent with the legacy normalization (both are sum/sum). Per-group CE uses a per-valid-token mean (Σsum/Σcount, pooling discrete-action + response components), which shifts CE magnitudes modestly vs the old per-dataset numbers — the regrouping changes them regardless.
  • The per-subset wandb namespace moves from Validation/{dataset_name}/… to Validation/{norm_key}/… plus the new Validation/by_dataset/{name}/…. The aggregate keys (Validation/Loss, Validation/MSE Loss, Validation/CE Loss) are unchanged.
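
The Σsum / Σcount grouping referred to above can be illustrated with a small sketch. `bucket_per_sample_sketch` is a hypothetical helper, not the `_bucket_per_sample` in scripts/train.py, and returning 0.0 for a zero-count group is an assumption about how such groups stay NaN-free.

```python
from collections import defaultdict


def bucket_per_sample_sketch(sums, counts, keys):
    """Group per-sample (sum, count) pairs by integer key.

    Returns {key: group_sum / group_count}. Hypothetical helper; a group
    whose count is zero maps to 0.0 (an assumption) instead of NaN.
    """
    group_sum = defaultdict(float)
    group_count = defaultdict(float)
    for s, c, k in zip(sums, counts, keys):
        group_sum[k] += s
        group_count[k] += c
    return {
        k: (group_sum[k] / group_count[k] if group_count[k] > 0 else 0.0)
        for k in group_sum
    }


# A mixed batch: two samples from group 0, one from group 1, and one
# fully padded sample from group 2.
means = bucket_per_sample_sketch(
    sums=[8.0, 1.0, 6.0, 0.0],
    counts=[4.0, 1.0, 2.0, 0.0],
    keys=[0, 0, 1, 2],
)
print(means)  # prints {0: 1.8, 1: 3.0, 2: 0.0}
```

For group 0 the pooled mean is 9 / 5 = 1.8, whereas averaging the two per-sample means (2.0 and 1.0) would give 1.5, which is why the grouping keeps sums and counts separate until the end.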

🗃️ Feature

How it was tested

  • Pre-commit: clean (ruff, ruff-format, pyupgrade, typos, bandit, …).
  • CPU (pytest -m "not gpu"): the full policy CPU subset is green (487 passed / 2 skipped), plus new unit tests:
    • tests/policies/test_pi06.py: flow_matching_masked_mse(return_per_sample=True) returns a scalar bit-identical to the default path, Σsum/Σcount reproduces the masked mean, padded steps/dims and fully-padded samples are excluded from the count; ce_per_sample and PerSampleLoss.__add__ arithmetic.
    • tests/scripts/test_train.py: _bucket_per_sample disaggregates a mixed batch by integer group key with no cross-contamination, group mean is Σsum/Σcount (not mean-of-means), and zero-count groups stay NaN-free.
    • tests/datasets/test_dataset_mixture.py: get_combined_val_dataloader is a single sequential loader over the concatenated mixture, covering every sample exactly once.
  • GPU (pytest -m "gpu") on the GPU dev box, for every touched policy (test_pi05, test_pi05_mem_gpu, test_pi06, test_pi07_low_level, test_pi07_paligemma_low_level, test_value): 14 passed, 6 skipped (single-GPU skips), 0 failed.
  • Determinism (CLAUDE.md rule 3): the per-step validation series should be verified bit-identical across two same-seed configs/dev/dev_config.json runs — being run separately. By construction the training path is unaffected: return_per_sample defaults False and the train step never sets it, so the train forward graph and reduction order are byte-for-byte unchanged (a CPU test asserts the scalar is torch.equal to the pre-change value). The combined val pass is shuffle=False over a fixed order with rank-0-only bucketing arithmetic, so the val series is deterministic.
  • Nightly gpu_test.yml + regression tests will exercise the full GPU/regression suite.
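
The masked-mean properties listed in the unit tests above can be sketched in plain Python. `masked_mse_per_sample_sketch` is hypothetical and does not reproduce the signature of `flow_matching_masked_mse`; nested lists stand in for tensors, and the scalar here is derived from the per-sample pairs for brevity, whereas the PR keeps the original scalar reduction untouched.

```python
def masked_mse_per_sample_sketch(pred, target, mask):
    """Masked MSE with a per-sample (sum, count) breakdown.

    pred, target, mask: lists of equal-length rows, one row per sample.
    mask entries are 1.0 for valid elements and 0.0 for padding.
    Returns (scalar, sums, counts). Hypothetical helper for illustration.
    """
    sums, counts = [], []
    for p_row, t_row, m_row in zip(pred, target, mask):
        # Padded elements are multiplied by 0.0 and so drop out of the sum.
        sums.append(sum(m * (p - t) ** 2 for p, t, m in zip(p_row, t_row, m_row)))
        # Only valid elements contribute to the count.
        counts.append(float(sum(m_row)))
    total = sum(counts)
    scalar = sum(sums) / total if total > 0 else 0.0
    return scalar, sums, counts


scalar, sums, counts = masked_mse_per_sample_sketch(
    pred=[[1.0, 2.0, 9.0], [5.0, 5.0, 5.0]],
    target=[[0.0, 2.0, 0.0], [0.0, 0.0, 0.0]],
    mask=[[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],  # second sample is fully padded
)
print(scalar, sums, counts)  # prints 0.5 [1.0, 0.0] [2.0, 0.0]
```

The fully padded second sample contributes a count of zero, so it cannot dilute the mean of whichever group it lands in.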

How to checkout & try? (for the reviewer)

pytest -sx tests/scripts/test_train.py::TestBucketPerSample
pytest -sx tests/policies/test_pi06.py::TestFlowMatchingMaskedMsePerSample
pytest -sx "tests/datasets/test_dataset_mixture.py::TestWeightedDatasetMixture::test_get_combined_val_dataloader"

The new validation breakdown runs automatically whenever val_freq > 0; a smoke run logs Validation/{norm_key}/… and Validation/by_dataset/… every val_freq steps:

accelerate launch --num_processes 1 src/opentau/scripts/train.py --config_path=configs/dev/dev_config.json

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

Note: Before submitting this PR, please read the contributor guideline.

Validation runs a single combined pass over the whole mixture instead of one
DataLoader per dataset iterated sequentially, so every rank's shard stays full
(many tiny val subsets no longer leave most ranks idle). Per-(dataset,
control_mode) MSE/CE breakdowns are computed in the loop from each sample's
dropout-immune dataset_index / dataset_repo_id provenance, not from homogeneous
per-dataset loaders, so a mixed batch can be disaggregated correctly.

Policies optionally expose per-sample MSE/CE as (sum, count) via a
return_per_sample flag; the training scalar path is byte-for-byte unchanged
(the flag defaults False and the train step never sets it).
@shuheng-liu shuheng-liu added the feature New feature or request label Jun 2, 2026
@shuheng-liu shuheng-liu self-assigned this Jun 2, 2026
Contributor

@claude claude Bot left a comment

Reviewed the per-(dataset, control_mode) validation rework. Core logic is sound; two low-confidence suggestions left inline. No blocking issues.

@claude
Contributor

claude Bot commented Jun 2, 2026

[claude-review] summary for commit 8559365

Re-reviewed the validation rework after the gather-batching commit. Core design remains sound and the determinism contract holds: return_per_sample defaults False, the scalar reduction is bit-identical (CPU-asserted), branch selection is rank-uniform, and every gather_for_metrics runs on all ranks with rank-0-only bucketing — collective counts stay aligned.

  • resolved: src/opentau/scripts/train.py:957. The six separate gather_for_metrics calls are now a single gather_for_metrics({...}) over a dict, so the ragged last-batch de-pad is applied once and identically to losses + provenance indices. Row-alignment is now guaranteed by construction (was a suggestion in the prior review). The multi-rank de-pad path itself is still only exercisable on >1-rank hardware (being run separately per the author).
  • resolved (intentional) — per-group L1/Accuracy logging: confirmed out of scope — the only producer (value head) is single-norm-head (num_datasets=1), so its one group already equals the aggregate; L1/Accuracy are batch-reduced scalars, not per-sample quantities.

No blocking issues found.

Gather the per-sample MSE/CE (sum, count) and the provenance index tensors
(norm_index, source_index) in a single accelerator.gather_for_metrics call so
the ragged last-batch de-pad is applied once and identically to every entry —
losses and their provenance stay row-aligned by construction rather than
relying on separate de-pads landing on the same trim. Addresses a review note.
@shuheng-liu
Member Author

Thanks for the review. Addressing the two suggestions:

train.py — per-sample gather alignment (fixed in 8559365). Switched the six separate gather_for_metrics calls to a single gather_for_metrics({...}) over a dict, so accelerate applies one identical ragged-last-batch de-pad to every entry — the per-sample losses and their norm_index / source_index provenance now stay row-aligned by construction rather than relying on six independent de-pads landing on the same trim. A true >1-rank run (plus the same-seed determinism check) is being exercised separately on multi-GPU hardware, since the CPU suite can't cover the multi-rank de-pad path.
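
A toy model of the alignment argument, for readers without multi-GPU hardware. It does not call accelerate or reproduce its internals: `gather_and_trim_sketch`, the key names, and the padding layout (the last rank's final row is a duplicate added to even out the batch) are all assumptions made for illustration.

```python
def gather_and_trim_sketch(per_rank, num_real):
    """Toy gather: concatenate each key's rank shards, then trim all keys alike.

    per_rank: list of dicts (one per rank) mapping key -> list of rows.
    num_real: number of real (non-padding) rows across all ranks.
    Hypothetical helper; not accelerate's implementation.
    """
    keys = per_rank[0].keys()
    gathered = {k: [row for shard in per_rank for row in shard[k]] for k in keys}
    # One trim, applied identically to every entry, so row i of the losses
    # still refers to the same sample as row i of the provenance indices.
    return {k: rows[:num_real] for k, rows in gathered.items()}


# Two ranks, three real samples; rank 1's last row is padding.
rank0 = {"mse_sum": [0.5, 1.5], "norm_index": [0, 1]}
rank1 = {"mse_sum": [2.5, 0.5], "norm_index": [1, 0]}
out = gather_and_trim_sketch([rank0, rank1], num_real=3)
print(out)  # prints {'mse_sum': [0.5, 1.5, 2.5], 'norm_index': [0, 1, 1]}
```

Because every entry passes through the same trim, the lists cannot end up with different lengths, which is the property the dict-based gather relies on.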

train.py — per-group L1/Accuracy (intentional; keeping aggregate-only). L1/Accuracy are emitted only by the value head, which is a single-norm-head policy (num_datasets=1) — so its one norm-key group already equals the aggregate, making per-norm-key L1/Accuracy redundant with Validation/L1 Loss / Validation/Accuracy. Issue #373 scopes the per-(dataset, control_mode) breakdown to MSE/CE, and L1/Accuracy aren't per-sample quantities today (they're batch-reduced scalars), so bucketing them would require adding a separate per-sample path to the value head. I'm happy to add per-source L1/Accuracy as a follow-up if value-function multi-source training ever needs it, but it's out of scope for this PR.

Contributor

@claude claude Bot left a comment

Re-reviewed at commit 8559365. The batched single gather_for_metrics({...}) resolves the prior row-alignment concern by construction — accelerate applies one identical ragged-last-batch trim across every entry in the dict, so per-sample losses and their norm_index/source_index provenance can no longer desync. Determinism contract still holds (train path untouched, return_per_sample defaults False). No blocking issues.

@shuheng-liu shuheng-liu marked this pull request as ready for review June 2, 2026 15:53
@shuheng-liu shuheng-liu merged commit 144c9f6 into main Jun 2, 2026
16 checks passed
@shuheng-liu shuheng-liu deleted the claude/thirsty-blackwell-184416 branch June 2, 2026 16:05
@claude claude Bot mentioned this pull request Jun 2, 2026
3 tasks
shuheng-liu added a commit that referenced this pull request Jun 2, 2026
Co-authored-by: Shuheng Liu <wish1104@icloud.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>

Labels

feature New feature or request

Development

Successfully merging this pull request may close these issues.

Aggregate val loss per (dataset, control_mode) outside forward + parallelize validation

1 participant