Context
Deferred follow-up from #348 (per-(robot_type, control_mode) normalization in fit_fast_tokenizer). Tracked here so the PR can land without bundling unrelated plumbing.
Problem
When a dataset entry in the mixture has empty (robot_type, control_mode) after overrides, compute_norm_key falls back to using cfg.repo_id as the head's key. If two such entries share the same repo_id, their head assignment diverges between fit and training:
- Fit time (current _aggregate_stats_per_head in src/opentau/scripts/fit_fast_tokenizer.py): both rows compute the same fallback key repo_id, so they are POOLED into one shared head.
- Training time (DatasetMixtureMetadata._build_norm_heads in src/opentau/datasets/dataset_mixture.py:340-348): the per-dataset name is _make_dataset_names's deduplicated form (X, X#0, X#1, ...), so the fallback keys differ and the rows stay as separate singleton heads.
Net: the BPE corpus the tokenizer is fit on doesn't match what the policy normalizes at runtime for that subset of rows.
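The divergence can be reproduced in a few lines. This is an illustrative sketch only: norm_key is a hypothetical stand-in for compute_norm_key (assumed here to return the fallback name when both robot_type and control_mode are empty; the "/" key format is made up), and the two name lists imitate the fit and training paths rather than calling the real code.

```python
def norm_key(robot_type: str, control_mode: str, fallback_name: str) -> str:
    """Hypothetical stand-in for compute_norm_key; the real signature may differ."""
    if robot_type or control_mode:
        return f"{robot_type}/{control_mode}"  # assumed key format, for illustration
    return fallback_name


# Two mixture entries with the same repo_id and empty (robot_type, control_mode).
repo_ids = ["X", "X"]

# Fit time: the raw repo_id is the fallback, so both rows get the same key.
fit_keys = [norm_key("", "", repo_id) for repo_id in repo_ids]

# Training time: names are deduplicated first (X#0, X#1), so the keys differ.
train_names = [f"{repo_id}#{i}" for i, repo_id in enumerate(repo_ids)]
train_keys = [norm_key("", "", name) for name in train_names]

# 1 pooled head at fit time, 2 singleton heads at training time.
print(len(set(fit_keys)), len(set(train_keys)))
```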
Mitigation already in place (PR #348, commit c755cb8)
A Counter-based warning at fit time flags duplicate repo_id values that fell through to fallback. So an operator hitting this configuration sees an explicit signal:
N fallback-keyed repo_id values appear in the mixture more than once. Fit-time POOLS these under one shared head per repo_id, but training (via _make_dataset_names's #N dedup) keeps them as separate singleton heads...
This is enough to keep the misalignment from shipping silently, but doesn't actually fix the underlying divergence.
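For readers who have not looked at c755cb8, a Counter-based check of this kind might look roughly like the sketch below. It is not the code from that commit: the function names, the logger, and the exact message are assumptions made for illustration.

```python
import logging
from collections import Counter

logger = logging.getLogger(__name__)


def duplicated_fallback_repo_ids(fallback_repo_ids: list[str]) -> list[str]:
    """Return the repo_ids that were used as a fallback key more than once."""
    counts = Counter(fallback_repo_ids)
    return sorted(name for name, n in counts.items() if n > 1)


def warn_on_duplicates(fallback_repo_ids: list[str]) -> list[str]:
    """Log a warning when any fallback-keyed repo_id is duplicated (hypothetical helper)."""
    dups = duplicated_fallback_repo_ids(fallback_repo_ids)
    if dups:
        logger.warning(
            "%d fallback-keyed repo_id values appear in the mixture more than once: %s",
            len(dups),
            dups,
        )
    return dups
```

The input is assumed to contain only the repo_ids of rows that fell through to the fallback key, so rows with a real (robot_type, control_mode) never trigger the warning.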
Proposed fix
Replicate _make_dataset_names's dedup at fit time inside _aggregate_stats_per_head so the fallback key passed to compute_norm_key is the deduplicated name, not the raw repo_id. Roughly:
    from collections import Counter

    # Mirror _make_dataset_names: names that occur more than once get a #N suffix.
    raw_names = [(dc.repo_id or dc.vqa or "<no-name>") for dc in mixture_cfg.datasets]
    counts = Counter(raw_names)
    seen: dict[str, int] = {}
    deduplicated_names: list[str] = []
    for name in raw_names:
        if counts[name] > 1:
            i = seen.get(name, 0)
            deduplicated_names.append(f"{name}#{i}")
            seen[name] = i + 1
        else:
            deduplicated_names.append(name)
Then pass deduplicated_names[i] (not per_ds_repo[i]) as the third arg to compute_norm_key. This makes the fit-time fallback keys match training's exactly.
Why deferred
The dedup logic is currently a static method on WeightedDatasetMixture that takes a TrainPipelineConfig and a list of BaseDataset. Lifting just the part that handles dedup is straightforward (the snippet above doesn't need either), but "the right thing" is probably to factor _make_dataset_names so both call sites share it. That refactor felt out of scope for #348's fit-time normalization correctness fixes; tracking here.
Acceptance
- Fit-time fallback keys match training's (a synthetic test with duplicate repo_id shows both paths producing the same (per_ds_key, per_norm_key_stats) shape).
- The c755cb8 dup-detection warning still fires (for visibility), but the underlying divergence is gone.
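A synthetic test for the first criterion could start from something like the sketch below. It assumes the dedup is lifted into a standalone helper (dedup_names is a hypothetical name), and it compares against hand-written training-side names instead of calling _build_norm_heads. A real test would run both code paths and compare their outputs.

```python
from collections import Counter


def dedup_names(raw_names: list[str]) -> list[str]:
    """Hypothetical shared helper: suffix names that occur more than once with #N."""
    counts = Counter(raw_names)
    seen: dict[str, int] = {}
    out: list[str] = []
    for name in raw_names:
        if counts[name] > 1:
            i = seen.get(name, 0)
            out.append(f"{name}#{i}")
            seen[name] = i + 1
        else:
            out.append(name)
    return out


def test_fallback_keys_match_training() -> None:
    raw = ["X", "Y", "X"]  # duplicate repo_id "X", unique "Y"
    fit_time_keys = dedup_names(raw)
    # Training-side names written out by hand from the X#0, X#1 scheme.
    training_keys = ["X#0", "Y", "X#1"]
    assert fit_time_keys == training_keys
    # No two rows share a fallback key, so nothing is pooled at fit time.
    assert len(set(fit_time_keys)) == len(raw)


test_fallback_keys_match_training()
```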