Context
Deferred follow-up from #348 (per-(robot_type, control_mode) normalization in fit_fast_tokenizer). The PR added a per-head aggregator in the fit script that mirrors what DatasetMixtureMetadata._build_norm_heads does at training time. Surgical fixes in #348 close every divergence the review identified, but the design is brittle: any future change to the pooling semantics of _build_norm_heads has to be mirrored manually in the fit-script aggregator.
Problem
_build_norm_heads lives in src/opentau/datasets/dataset_mixture.py:322 as an instance method on DatasetMixtureMetadata. It reads:
- self.dataset_names (deduplicated via _make_dataset_names)
- self.per_dataset_stats (post-_to_standard_data_format)
- self.cfg.max_action_dim, self.cfg.max_state_dim, self.cfg.num_cams
- the compute_norm_key free function
But DatasetMixtureMetadata.__init__ does far more than build heads — it loads parquet, builds episode indices, computes sample weights, etc. — so a fit-script user can't just instantiate it cheaply to read the norm-head structures off it.
The fit script's _aggregate_stats_per_head ends up re-implementing the head-building logic from scratch (~140 lines), which is why the original review found five concrete divergences. PR #348 fixed all five, but the structural risk remains: the fit aggregator and _build_norm_heads can drift apart again the next time either side evolves.
Proposed shape
Extract _build_norm_heads into a free function in dataset_mixture.py with an explicit signature:
```python
def build_norm_heads(
    *,
    dataset_names: list[str],  # already deduplicated
    per_dataset_stats: list[dict],  # already standardized (post _to_standard_data_format)
    metadatas: list[DatasetMetadata],  # source of total_frames + robot_type/control_mode
    raw_dims: list[tuple[int, int]],
) -> tuple[...]:
    ...
```
Then have DatasetMixtureMetadata._build_norm_heads reduce to a one-liner that forwards self.* into the free function. The fit script does its own "build per_dataset_stats from LeRobotDatasetMetadata.stats and run _to_standard_data_format on each" prep step, then calls the same build_norm_heads.
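A minimal sketch of the forwarder pattern, under heavily simplified assumptions: Meta, Mixture and aggregate_stats_per_head are hypothetical toy stand-ins for DatasetMetadata, DatasetMixtureMetadata and the fit-script adapter; stats are plain lists of floats; the key format of compute_norm_key is invented; raw_dims and the cfg dimensions are omitted; and the frame-weighted mean is only a placeholder for whatever pooling _build_norm_heads actually performs. The point is that the training path and the fit-script path call one function, so pooling semantics live in a single place.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Meta:
    # Hypothetical stand-in for DatasetMetadata: only the fields pooling needs.
    robot_type: str
    control_mode: str
    total_frames: int


def compute_norm_key(meta: Meta) -> str:
    # Hypothetical key format; the real compute_norm_key may differ.
    return f"{meta.robot_type}/{meta.control_mode}"


def build_norm_heads(
    *,
    dataset_names: list[str],
    per_dataset_stats: list[dict],
    metadatas: list[Meta],
) -> dict[str, dict]:
    """Group datasets by norm key and pool their means (placeholder pooling)."""
    if not (len(dataset_names) == len(per_dataset_stats) == len(metadatas)):
        raise ValueError("inputs must be aligned per dataset")
    grouped = defaultdict(list)
    for name, stats, meta in zip(dataset_names, per_dataset_stats, metadatas):
        grouped[compute_norm_key(meta)].append((name, stats["mean"], meta.total_frames))
    heads = {}
    for key, members in grouped.items():
        total = sum(frames for _, _, frames in members)
        dim = len(members[0][1])
        # Frame-weighted mean per dimension across every dataset in this head.
        mean = [sum(m[i] * frames for _, m, frames in members) / total for i in range(dim)]
        heads[key] = {"datasets": [n for n, _, _ in members], "mean": mean}
    return heads


class Mixture:
    # Hypothetical stand-in for DatasetMixtureMetadata after __init__ has run.
    def __init__(self, dataset_names, per_dataset_stats, metadatas):
        self.dataset_names = dataset_names
        self.per_dataset_stats = per_dataset_stats
        self.metadatas = metadatas

    def _build_norm_heads(self):
        # Thin forwarder: no pooling logic lives on the instance.
        return build_norm_heads(
            dataset_names=self.dataset_names,
            per_dataset_stats=self.per_dataset_stats,
            metadatas=self.metadatas,
        )


def aggregate_stats_per_head(names, raw_stats, metadatas, standardize):
    # Hypothetical fit-script adapter: standardize each stats dict, then
    # call the same shared function the training path uses.
    return build_norm_heads(
        dataset_names=names,
        per_dataset_stats=[standardize(s) for s in raw_stats],
        metadatas=metadatas,
    )
```

Because the adapter receives the standardization step as a callable, it stays a few lines long, and a change to pooling only ever happens inside build_norm_heads.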
A secondary extraction is probably needed for the standardize-and-pad step (_to_standard_data_format at dataset_mixture.py:474) so the fit script can produce the standardized stats dict without going through a full DatasetMixtureMetadata.
Why deferred
Cross-file refactor touching dataset_mixture.py for benefits that are entirely "future-proofing". #348's surgical fixes close the currently-known divergences and add TestNormalizeEquivalenceVsProduction to catch future ones — that gates the most common failure mode. The proper extraction is correctness-equivalent today, just more durable. Doing it as a focused refactor PR is cleaner than bundling into a fit-script bug-fix PR.
Acceptance
- _build_norm_heads lives as a free function; the instance method is a thin forwarder.
- _to_standard_data_format is either factored similarly or has a documented "build the standardized stats dict outside of a DatasetMixtureMetadata instance" entry point.
- fit_fast_tokenizer._aggregate_stats_per_head calls the free function directly. The ~140 lines of re-implementation collapse to a thin adapter.
- TestNormalizeEquivalenceVsProduction still passes (sanity).
- Bonus: a new test pins that the fit-script and training paths produce byte-equal per_norm_key_stats for an end-to-end synthetic mixture.