Background
Every sample now carries dataset provenance in our standard data format. _TaggedDataset injects dataset_index (the norm-head row, dropout-immune) and dataset_repo_id (the deduplicated mixture-level name) into each item (src/opentau/datasets/dataset_mixture.py:83-135), and _emit_optional_keys (src/opentau/datasets/lerobot_dataset.py) emits the optional keys robot_type and control_mode. We can therefore compute per-(dataset, control mode) loss breakdowns from the batch itself, instead of relying on how the validation loop happens to be batched.
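To make the tagging step concrete, here is a minimal sketch of a wrapper that stamps every item with its mixture-level provenance. TaggedDatasetSketch is a hypothetical stand-in, not the actual _TaggedDataset in dataset_mixture.py. It assumes items are plain dicts and that the mixture assigns each source its dataset_index and dataset_repo_id up front. The repo id in the usage line is made up.

```python
from typing import Any, Sequence


class TaggedDatasetSketch:
    """Hypothetical stand-in for _TaggedDataset: wraps one source dataset and
    adds mixture-level provenance keys to every item it returns."""

    def __init__(self, base: Sequence[dict[str, Any]], dataset_index: int, dataset_repo_id: str):
        self.base = base
        self.dataset_index = dataset_index  # norm-head row; never dropped by key dropout
        self.dataset_repo_id = dataset_repo_id  # deduplicated mixture-level name

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, i: int) -> dict[str, Any]:
        item = dict(self.base[i])  # shallow copy so the wrapped dataset is not mutated
        item["dataset_index"] = self.dataset_index
        item["dataset_repo_id"] = self.dataset_repo_id
        return item


ds = TaggedDatasetSketch([{"action": [0.1], "control_mode": "ee"}], dataset_index=3, dataset_repo_id="org/pick_place")
print(ds[0])
# {'action': [0.1], 'control_mode': 'ee', 'dataset_index': 3, 'dataset_repo_id': 'org/pick_place'}
```

Because the provenance travels inside each item, it survives shuffling and mixing across sources, which is what lets the loop group a heterogeneous batch later.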
Two related shortcomings in the current validation path (src/opentau/scripts/train.py:844-950) motivate this issue.
Problem 1: per-dataset breakdown is an artifact of dataloader structure, not provenance
Today the only per-dataset granularity we get comes from iterating one separate dataloader per underlying dataset (src/opentau/scripts/train.py:868-918), with the dataloaders built by WeightedDatasetMixture.get_per_dataset_dataloaders() (src/opentau/datasets/dataset_mixture.py:944-977).
Consequences:
The breakdown granularity is per source dataset, never per (dataset, control mode) — even though control_mode ({"joint", "ee", "mixed"}) is exactly the axis we want to slice MSE/CE along, and it is already present in the batch.
The loss itself (policy.forward) returns a single scalar MSE/CE reduced over the whole batch (modeling_pi05.py:740, modeling_pi06.py:623, pi07/low_level/modeling_pi07_low_level.py:931). There is no per-sample / per-group disaggregation, so we cannot break a mixed batch down by provenance at all — the only reason we get any per-dataset numbers today is that each dataloader is homogeneous by construction.
Crucially, this disaggregation does not belong inside forward. The model shouldn't know about dataset taxonomy or own metric bookkeeping; grouping by (dataset_index, control_mode) should happen in the training/eval loop using the provenance keys already in the batch.
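A small worked example of why the reduced scalar cannot be split after the fact, and what grouping in the loop could look like. This is a hedged sketch: group_losses is a hypothetical helper, it assumes the policy can hand back one loss per sample (for example a tensor converted with .tolist()), and it assumes a masked control_mode arrives as None, in which case the sample is still keyed by its dataset_index.

```python
from collections import defaultdict
from typing import Optional, Sequence

GroupKey = tuple[int, str]  # (dataset_index, control_mode)


def group_losses(
    per_sample_loss: Sequence[float],
    dataset_index: Sequence[int],
    control_mode: Sequence[Optional[str]],
) -> dict[GroupKey, list]:
    """Accumulate [sum, count] per (dataset_index, control_mode) for one batch."""
    if not (len(per_sample_loss) == len(dataset_index) == len(control_mode)):
        raise ValueError("per-sample losses and provenance keys must be aligned")
    acc: dict[GroupKey, list] = defaultdict(lambda: [0.0, 0])
    for loss, ds_idx, mode in zip(per_sample_loss, dataset_index, control_mode):
        # dataset_index is dropout-immune; a masked control_mode falls back to "unknown"
        key = (int(ds_idx), mode if mode is not None else "unknown")
        acc[key][0] += float(loss)
        acc[key][1] += 1
    return dict(acc)


losses = [1.0, 3.0, 0.5, 0.5]
acc = group_losses(losses, dataset_index=[0, 0, 1, 1], control_mode=["joint", "joint", "ee", "ee"])
print(sum(losses) / len(losses))  # 1.25: the single scalar a reduced forward reports
print({k: s / c for k, (s, c) in acc.items()})  # {(0, 'joint'): 2.0, (1, 'ee'): 0.5}
```

The batch scalar of 1.25 matches neither group. Keeping (sum, count) rather than a running mean is deliberate: sums and counts add exactly across batches and ranks, while means do not.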
Problem 2: validation badly under-parallelized across ranks
get_per_dataset_dataloaders() creates one DataLoader per dataset; each is sharded across ranks by accelerator.prepare() and then iterated separately. With a heterogeneous mixture of many small validation subsets (some with only a single frame after val_split_ratio), this is pathological:
A dataset with fewer frames than world_size leaves most ranks with an empty/padded shard — they do a no-op forward (or wait) while one rank does the real work.
Because we loop datasets sequentially (for ds_name, ds_loader in ...), the idle time stacks: every tiny dataset is its own under-filled collective round, and accelerator.gather_for_metrics + the trailing accelerator.wait_for_everyone() (train.py:950) force all ranks to rendezvous on each one.
drop_last=False (dataset_mixture.py:973) is correct for not discarding val data, but combined with tiny datasets it guarantees ragged, under-utilized batches.
Net effect: validation wall-clock scales with the number of datasets rather than the total val frames, and most ranks sit idle.
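A back-of-the-envelope model of that scaling, assuming every collective round costs the same (one forward of batch_size samples per rank plus the gather and rendezvous) and ignoring data-loading time. The helper names are hypothetical, and the mixture of 32 one-frame datasets on 8 ranks is illustrative, not a measured config.

```python
import math
from typing import Sequence


def per_dataset_rounds(sizes: Sequence[int], world_size: int, batch_size: int) -> int:
    """Collective rounds when each dataset gets its own sharded dataloader."""
    per_round = world_size * batch_size
    return sum(math.ceil(n / per_round) for n in sizes if n > 0)


def combined_rounds(sizes: Sequence[int], world_size: int, batch_size: int) -> int:
    """Collective rounds for one pass over the concatenated val mixture."""
    return math.ceil(sum(sizes) / (world_size * batch_size))


def utilization(sizes: Sequence[int], world_size: int, batch_size: int, rounds: int) -> float:
    """Fraction of rank-batch slots that hold a real (non-padded) sample."""
    return sum(sizes) / (rounds * world_size * batch_size)


sizes = [1] * 32  # 32 validation subsets with one frame each
ws, bs = 8, 1
r_seq = per_dataset_rounds(sizes, ws, bs)
r_one = combined_rounds(sizes, ws, bs)
print(r_seq, utilization(sizes, ws, bs, r_seq))  # 32 0.125
print(r_one, utilization(sizes, ws, bs, r_one))  # 4 1.0
```

Under these assumptions the sequential loop needs 8x more rounds for the same 32 frames, and the gap grows with the number of datasets rather than with total frames. The acceptance criterion still calls for measuring the real wall-clock difference.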
Proposed direction (for discussion, not prescriptive)
Move per-group aggregation out of forward and into the loop, keyed on provenance. Have policies optionally return unreduced (per-sample) MSE/CE — or keep returning scalars but additionally expose a per-sample loss — and let the validation loop bucket them by (dataset_index, control_mode) from the batch, using dataset_index as the dropout-immune key (control_mode can be masked at train time; pair it with dataset_index or recover via compute_norm_key). Log Validation/{dataset}/{control_mode}/{MSE,CE} plus the existing mixture-weighted aggregate (_mixture_weighted_aggregate, train.py:932-946).
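One way to expose unreduced losses without changing what forward reports, sketched under assumptions: each sample's prediction and target are already flattened to equal-length lists, and there is no padding mask. mse_per_sample and log_dict are hypothetical helpers, repo_ids stands in for a dataset_index to dataset_repo_id lookup, and the dataset names are made up. CE would follow the same pattern with a per-sample cross-entropy.

```python
from typing import Sequence


def mse_per_sample(pred: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> list[float]:
    """Unreduced MSE: one value per sample, averaged over that sample's elements."""
    out = []
    for ps, ts in zip(pred, target):
        if len(ps) != len(ts) or not ps:
            raise ValueError("each sample needs equal-length, non-empty pred/target")
        out.append(sum((p - t) ** 2 for p, t in zip(ps, ts)) / len(ps))
    return out


def log_dict(
    stats: dict[tuple[int, str], tuple[float, int]],
    repo_ids: Sequence[str],
    metric: str,
) -> dict[str, float]:
    """Turn per-group (sum, count) into Validation/{dataset}/{control_mode}/{metric} scalars."""
    out = {}
    for (ds_idx, mode), (total, count) in sorted(stats.items()):
        out[f"Validation/{repo_ids[ds_idx]}/{mode}/{metric}"] = total / count
    return out


per = mse_per_sample([[0.0, 0.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 3.0]])
print(per, sum(per) / len(per))  # [1.0, 2.0] 1.5
print(log_dict({(0, "joint"): (3.0, 2), (1, "ee"): (0.5, 1)}, repo_ids=["pick", "wipe"], metric="MSE"))
# {'Validation/pick/joint/MSE': 1.5, 'Validation/wipe/ee/MSE': 0.5}
```

When every sample has the same number of elements, the mean of the per-sample values equals the current batch scalar, so the existing aggregate is unaffected. With per-sample padding masks that equality no longer holds, and it is safer to carry per-sample element sums and counts instead.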
Parallelize validation across all ranks regardless of per-dataset size. Instead of one sequential under-filled dataloader per dataset, run a single validation pass over the combined val mixture (so every rank's shard is full), and rely on the per-sample provenance keys for grouping rather than on homogeneous dataloaders. This decouples the breakdown from the batching and keeps every rank busy. (Open question: whether to keep get_per_dataset_dataloaders for any callers, or replace it.)
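A single-process simulation of that single pass, with every name hypothetical: the combined val set is strided across ranks, short ranks are padded by wrapping around (similar in spirit to how a distributed sampler pads), each rank accumulates per-group (sum, count) while skipping padded duplicates, and the partials are merged in a fixed order. In the real loop the equivalent would be gathering per-sample losses and provenance keys with accelerator.gather_for_metrics (which drops the duplicated tail samples) and grouping afterwards, or reducing per-group sums and counts across ranks.

```python
from collections import defaultdict
from typing import Sequence

# (dataset_index, control_mode, per-sample loss): a hypothetical flattened view of one val sample
Sample = tuple[int, str, float]


def shard(samples: Sequence[Sample], rank: int, world_size: int) -> list[tuple[Sample, bool]]:
    """Strided shard of the combined val set. Every rank gets ceil(N / world_size)
    slots; slots past the end wrap around and are flagged as padding."""
    n = len(samples)
    per_rank = -(-n // world_size)  # ceiling division
    out = []
    for j in range(per_rank):
        i = rank + j * world_size
        out.append((samples[i % n], i >= n))
    return out


def rank_stats(items: Sequence[tuple[Sample, bool]]) -> dict[tuple[int, str], list]:
    """Per-group [sum, count] on one rank, ignoring padded duplicates."""
    acc: dict[tuple[int, str], list] = defaultdict(lambda: [0.0, 0])
    for (ds_idx, mode, loss), is_pad in items:
        if is_pad:
            continue  # a padded slot repeats a real sample; counting it would double-weight it
        acc[(ds_idx, mode)][0] += loss
        acc[(ds_idx, mode)][1] += 1
    return acc


def merge(partials: Sequence[dict[tuple[int, str], list]]) -> dict[tuple[int, str], list]:
    """Sum partials in rank order, then by sorted key, so the float results are reproducible."""
    total: dict[tuple[int, str], list] = defaultdict(lambda: [0.0, 0])
    for acc in partials:
        for key in sorted(acc):
            total[key][0] += acc[key][0]
            total[key][1] += acc[key][1]
    return dict(total)


# Ten one-frame datasets on four ranks: one pass, three slots per rank.
samples = [(i, "joint" if i % 2 == 0 else "ee", float(i)) for i in range(10)]
shards = [shard(samples, r, world_size=4) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 3, 3]
merged = merge([rank_stats(s) for s in shards])
print(sum(c for _, c in merged.values()))  # 10
```

Every rank does the same amount of work in one round trip, and the per-group results do not depend on which rank saw which sample. Merging in rank order with sorted keys also fixes the floating-point summation order, which is one way to keep per-step val loss deterministic under a fixed seed.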
Acceptance
Validation logs MSE/CE broken down by (dataset, control_mode), computed in the loop from batch provenance — not inside any forward.
The grouping is keyed on dataset_index (dropout-immune) so a mixed batch can be disaggregated correctly; no reliance on one-dataset-per-dataloader homogeneity.
Validation keeps all ranks busy: a mixture of many 1-frame datasets no longer leaves most ranks idle. Quantify the wall-clock improvement on a representative multi-dataset config.
Per-step val loss remains deterministic under a fixed seed (per CLAUDE.md hard rule #3).
References
Current validation path: src/opentau/scripts/train.py:844-950
Per-dataset dataloaders (WeightedDatasetMixture.get_per_dataset_dataloaders): src/opentau/datasets/dataset_mixture.py:944-977
Provenance keys (dataset_index, dataset_repo_id): src/opentau/datasets/dataset_mixture.py:83-135
Optional keys (robot_type, control_mode): src/opentau/datasets/lerobot_dataset.py (_emit_optional_keys)
Scalar loss reduction: modeling_pi05.py:740, modeling_pi06.py:623, pi07/low_level/modeling_pi07_low_level.py:931
Metric utilities: MetricsTracker / AverageMeter in src/opentau/utils/logging_utils.py; mixture aggregate _mixture_weighted_aggregate in src/opentau/scripts/train.py