feat(train): aggregate val loss per (dataset, control_mode), parallelize #374
Validation runs a single combined pass over the whole mixture instead of one DataLoader per dataset iterated sequentially, so every rank's shard stays full (many tiny val subsets no longer leave most ranks idle). Per-(dataset, control_mode) MSE/CE breakdowns are computed in the loop from each sample's dropout-immune dataset_index / dataset_repo_id provenance, not from homogeneous per-dataset loaders, so a mixed batch can be disaggregated correctly. Policies optionally expose per-sample MSE/CE as (sum, count) via a return_per_sample flag; the training scalar path is byte-for-byte unchanged (the flag defaults False and the train step never sets it).
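A minimal sketch of the disaggregation idea, assuming each sample carries a scalar loss sum, a valid-element count, and an integer provenance key. The function name `bucket_by_key` and the plain-list inputs are hypothetical stand-ins for illustration; they are not the PR's actual helper or tensor layout.

```python
from collections import defaultdict


def bucket_by_key(sums, counts, keys):
    """Pool per-sample (sum, count) pairs by integer group key.

    Hypothetical helper: returns {key: pooled mean}, where the pooled
    mean is (sum of sums) / (sum of counts), not a mean of per-sample means.
    """
    totals = defaultdict(lambda: [0.0, 0])
    for s, c, k in zip(sums, counts, keys):
        totals[k][0] += s
        totals[k][1] += c
    # Assumption: groups with zero valid elements are dropped so that no
    # 0/0 division (NaN) can reach the logs.
    return {k: s / c for k, (s, c) in sorted(totals.items()) if c > 0}


# A mixed batch: samples from group 0 and group 1 are interleaved.
sums = [2.0, 9.0, 8.0, 0.0]
counts = [2, 3, 4, 0]  # the last sample is fully padded
keys = [0, 1, 0, 1]

means = bucket_by_key(sums, counts, keys)
# Group 0 pools to 10.0 / 6, whereas a mean of per-sample means
# would give (1.0 + 2.0) / 2 = 1.5.
```

Because the key travels with each sample, the grouping does not depend on how the batch was assembled.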
[claude-review] summary for commit 8559365 Re-reviewed the validation rework after the gather-batching commit. Core design remains sound and the determinism contract holds:
No blocking issues found.
Gather the per-sample MSE/CE (sum, count) and the provenance index tensors (norm_index, source_index) in a single accelerator.gather_for_metrics call so the ragged last-batch de-pad is applied once and identically to every entry — losses and their provenance stay row-aligned by construction rather than relying on separate de-pads landing on the same trim. Addresses a review note.
Thanks for the review. Addressing the two suggestions:
Re-reviewed at commit 8559365. The batched single gather_for_metrics({...}) resolves the prior row-alignment concern by construction — accelerate applies one identical ragged-last-batch trim across every entry in the dict, so per-sample losses and their norm_index/source_index provenance can no longer desync. Determinism contract still holds (train path untouched, return_per_sample defaults False). No blocking issues.
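The row-alignment argument can be illustrated with a toy model in plain Python. This is not accelerate's implementation: `trim_all` and the hard-coded `remainder` are hypothetical, and lists stand in for gathered tensors. The point is only that one trim applied to every entry of a dict cannot leave entries with different lengths.

```python
def trim_all(gathered, remainder):
    """Apply one identical trim to every entry of a gathered dict.

    Toy stand-in for the ragged-last-batch de-pad: every entry keeps
    the same leading `remainder` rows.
    """
    return {name: values[:remainder] for name, values in gathered.items()}


# Simulated last validation step: 3 real rows plus 1 duplicated pad row.
gathered = {
    "mse_sum": [0.5, 1.5, 2.5, 0.5],
    "mse_count": [4, 4, 4, 4],
    "norm_index": [0, 1, 1, 0],
    "source_index": [2, 3, 3, 2],
}

trimmed = trim_all(gathered, remainder=3)

# Re-zip the columns into rows: each loss still sits next to its own keys.
rows = list(zip(*trimmed.values()))
```

With separate gathers, each tensor would need its own trim, and alignment would depend on every trim using the same length.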
What this does
Closes #373.
Reworks the validation path so per-dataset loss breakdowns come from batch provenance rather than from how the val loop happens to be batched, and so validation stays parallel across ranks regardless of per-dataset size.
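A rough way to see why a combined pass helps, under simplifying assumptions: every step consumes `world_size * batch_size` samples, sharding is even, and a loader over a non-empty dataset needs at least one step. The function names and sizes below are illustrative only.

```python
import math


def steps_sequential(sizes, world_size, batch_size):
    """Steps needed when each dataset gets its own loader, run one after another."""
    per_step = world_size * batch_size
    # Each non-empty dataset costs at least one step, however small it is.
    return sum(math.ceil(n / per_step) for n in sizes if n > 0)


def steps_combined(sizes, world_size, batch_size):
    """Steps needed for a single pass over the concatenated datasets."""
    per_step = world_size * batch_size
    return math.ceil(sum(sizes) / per_step)


# Four tiny validation subsets on 8 ranks, one sample per rank per step.
sizes = [3, 5, 2, 6]
sequential = steps_sequential(sizes, world_size=8, batch_size=1)  # 4 steps
combined = steps_combined(sizes, world_size=8, batch_size=1)      # 2 steps
```

In the sequential case every step leaves some of the 8 ranks without real work; in the combined case both steps are full.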
Before, validation built one DataLoader per underlying dataset and iterated them sequentially. That had two problems:
- policy.forward only returned a single batch-reduced MSE/CE scalar, so a mixed batch could not be disaggregated. The only reason we got any per-dataset numbers was that each loader was homogeneous by construction. We could never slice by (dataset, control_mode).
- […] val_split_ratio), a dataset with fewer frames than world_size left most ranks idle, and looping datasets sequentially made that idle time stack: wall-clock scaled with the number of datasets, not total val frames.
This PR:
- […] PerSampleLoss(sum, count) dataclass and a return_per_sample path to flow_matching_masked_mse, plus a shared ce_per_sample helper, in policies/utils.py. Policies optionally expose per-sample (unreduced) MSE/CE; the scalar they return is computed exactly as before (bit-identical), so the training path is untouched.
- […] return_per_sample through every action policy (pi0, pi05, pi05_mem, pi06, pi07/low_level, pi07_paligemma/low_level, value). Disaggregation lives in the training loop, not in forward: the model never learns dataset taxonomy.
- […] WeightedDatasetMixture.get_combined_val_dataloader(): a single deterministic (shuffle=False, drop_last=False) sequential pass over the whole val mixture, returning None when empty.
- […] scripts/train.py to run one combined pass (every rank's shard stays full), gather per-sample (sum, count) + provenance keys via gather_for_metrics, and bucket on rank 0 keyed on the dropout-immune dataset_index (norm head) and a dataset_repo_id-derived per-source index. Per-group means are Σsum / Σcount. Logs Validation/{norm_key}/{MSE,CE,Loss} (headline) and Validation/by_dataset/{name}/{MSE,CE,Loss} (per-source), plus the existing mixture-weighted aggregate (Validation/{Loss,MSE Loss,CE Loss}). Collective counts stay aligned across ranks (the per-sample flag is rank-uniform; bucketing is rank-0-only arithmetic). Policies whose forward doesn't accept return_per_sample (the high-level planners) fall back to an aggregate-only pass; they still get the single-pass parallelization.
Behavior notes for reviewers:
- […] sum/sum). Per-group CE uses a per-valid-token mean (Σsum/Σcount, pooling discrete-action + response components), which shifts CE magnitudes modestly vs the old per-dataset numbers; the regrouping changes them regardless.
- […] Validation/{dataset_name}/… to Validation/{norm_key}/… plus the new Validation/by_dataset/{name}/…. The aggregate keys (Validation/Loss, Validation/MSE Loss, Validation/CE Loss) are unchanged.
🗃️ Feature
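The per-valid-token mean in the behavior notes can be sketched with a small (sum, count) container. This is a simplified illustration, not the PR's `PerSampleLoss`: the class name `LossSumCount`, the use of plain floats instead of tensors, and the zero-count convention in `mean()` are all assumptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossSumCount:
    """Hypothetical scalar (sum, count) pair for one loss component."""

    sum: float
    count: int

    def __add__(self, other):
        # Pooling two components adds their sums and their counts.
        return LossSumCount(self.sum + other.sum, self.count + other.count)

    def mean(self):
        # Assumption: an empty component reports 0.0 instead of dividing 0 by 0.
        return self.sum / max(self.count, 1)


# Two CE components with different numbers of valid tokens.
action_ce = LossSumCount(sum=6.0, count=3)    # component mean 2.0
response_ce = LossSumCount(sum=2.0, count=5)  # component mean 0.4

pooled = action_ce + response_ce
per_token_mean = pooled.mean()                                  # 8.0 / 8 = 1.0
mean_of_means = (action_ce.mean() + response_ce.mean()) / 2     # about 1.2
```

The two averages differ whenever the components have unequal token counts, which is one reason pooled CE values need not match numbers computed under a different reduction.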
How it was tested
- […] ruff, ruff-format, pyupgrade, typos, bandit, …).
- […] pytest -m "not gpu"): the full policy CPU subset is green (487 passed / 2 skipped), plus new unit tests:
  - tests/policies/test_pi06.py: flow_matching_masked_mse(return_per_sample=True) returns a scalar bit-identical to the default path, Σsum/Σcount reproduces the masked mean, padded steps/dims and fully-padded samples are excluded from the count; ce_per_sample and PerSampleLoss.__add__ arithmetic.
  - tests/scripts/test_train.py: _bucket_per_sample disaggregates a mixed batch by integer group key with no cross-contamination, group mean is Σsum/Σcount (not mean-of-means), and zero-count groups stay NaN-free.
  - tests/datasets/test_dataset_mixture.py: get_combined_val_dataloader is a single sequential loader over the concatenated mixture, covering every sample exactly once.
- […] pytest -m "gpu") on the GPU dev box, for every touched policy (test_pi05, test_pi05_mem_gpu, test_pi06, test_pi07_low_level, test_pi07_paligemma_low_level, test_value): 14 passed, 6 skipped (single-GPU skips), 0 failed.
- […] configs/dev/dev_config.json runs: being run separately. By construction the training path is unaffected: return_per_sample defaults False and the train step never sets it, so the train forward graph and reduction order are byte-for-byte unchanged (a CPU test asserts the scalar is torch.equal to the pre-change value). The combined val pass is shuffle=False over a fixed order with rank-0-only bucketing arithmetic, so the val series is deterministic.
- gpu_test.yml + regression tests will exercise the full GPU/regression suite.
How to checkout & try? (for the reviewer)
pytest -sx tests/scripts/test_train.py::TestBucketPerSample
pytest -sx tests/policies/test_pi06.py::TestFlowMatchingMaskedMsePerSample
pytest -sx "tests/datasets/test_dataset_mixture.py::TestWeightedDatasetMixture::test_get_combined_val_dataloader"
The new validation breakdown runs automatically whenever val_freq > 0; a smoke run logs Validation/{norm_key}/… and Validation/by_dataset/… every val_freq steps:
Checklist
Note: Before submitting this PR, please read the contributor guideline.