Extract _build_norm_heads to a free function so fit_fast_tokenizer can call it #350

@shuheng-liu

Context

Deferred follow-up from #348 (per-(robot_type, control_mode) normalization in fit_fast_tokenizer). That PR added a per-head aggregator in the fit script that mirrors what DatasetMixtureMetadata._build_norm_heads does at training time. The surgical fixes in #348 close every divergence the review identified, but the design is brittle: any future change to the pooling semantics of _build_norm_heads has to be mirrored manually in the fit-script aggregator.

Problem

_build_norm_heads lives in src/opentau/datasets/dataset_mixture.py:322 as an instance method on DatasetMixtureMetadata. It reads:

  • self.dataset_names (deduplicated via _make_dataset_names)
  • self.per_dataset_stats (post-_to_standard_data_format)
  • self.cfg.max_action_dim, self.cfg.max_state_dim, self.cfg.num_cams
  • The compute_norm_key free function

But DatasetMixtureMetadata.__init__ does far more than build heads (it loads parquet, builds episode indices, computes sample weights, etc.), so the fit script can't cheaply instantiate it just to read the norm-head structures off it.

The fit script's _aggregate_stats_per_head ends up re-implementing the head-building logic from scratch (~140 lines), which is why the original review found five concrete divergences. PR #348 fixed all five, but the structural risk remains: the fit aggregator and _build_norm_heads can drift again the next time either side evolves.

Proposed shape

Extract _build_norm_heads into a free function in dataset_mixture.py with an explicit signature:

def build_norm_heads(
    *,
    dataset_names: list[str],          # already-deduplicated
    per_dataset_stats: list[dict],     # already standardized (post _to_standard_data_format)
    metadatas: list[DatasetMetadata],  # source of total_frames + robot_type/control_mode
    raw_dims: list[tuple[int, int]],
) -> tuple[...]:
    ...

Then have DatasetMixtureMetadata._build_norm_heads reduce to a one-liner that forwards self.* into the free function. The fit script does its own "build per_dataset_stats from LeRobotDatasetMetadata.stats and run _to_standard_data_format on each" prep step, then calls the same build_norm_heads.
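
The forwarding arrangement described above could look roughly like the following. This is a minimal sketch under stated assumptions, not the real implementation: MetaStub and MixtureMetadataSketch are hypothetical stand-ins, the attribute names metadatas and raw_dims on the class are assumed, and the grouping body is a placeholder for whatever pooling _build_norm_heads performs today. raw_dims is passed through untouched because its meaning is not spelled out here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MetaStub:
    """Hypothetical stand-in for DatasetMetadata; only the fields used here."""
    robot_type: str
    control_mode: str
    total_frames: int


def build_norm_heads(*, dataset_names, per_dataset_stats, metadatas, raw_dims):
    """Placeholder body: group datasets by (robot_type, control_mode).

    The real pooling logic is not reproduced; the point is only that every
    input arrives as an explicit keyword argument instead of via self.
    """
    lengths = {len(dataset_names), len(per_dataset_stats), len(metadatas), len(raw_dims)}
    if len(lengths) != 1:
        raise ValueError("all per-dataset inputs must have the same length")
    heads = {}
    for name, meta in zip(dataset_names, metadatas):
        key = (meta.robot_type, meta.control_mode)
        head = heads.setdefault(key, {"datasets": [], "total_frames": 0})
        head["datasets"].append(name)
        head["total_frames"] += meta.total_frames
    return heads


class MixtureMetadataSketch:
    """Stand-in for DatasetMixtureMetadata with the expensive loading omitted."""

    def __init__(self, dataset_names, per_dataset_stats, metadatas, raw_dims):
        self.dataset_names = dataset_names
        self.per_dataset_stats = per_dataset_stats
        self.metadatas = metadatas
        self.raw_dims = raw_dims

    def _build_norm_heads(self):
        # Thin forwarder: no logic of its own, so there is nothing to drift.
        return build_norm_heads(
            dataset_names=self.dataset_names,
            per_dataset_stats=self.per_dataset_stats,
            metadatas=self.metadatas,
            raw_dims=self.raw_dims,
        )


# Made-up inputs; the names and dimensions are illustrative only.
names = ["ds_a", "ds_b", "ds_c"]
stats = [{}, {}, {}]
metas = [MetaStub("arm", "joint", 100), MetaStub("arm", "joint", 50), MetaStub("gripper", "ee", 10)]
dims = [(6, 6), (6, 6), (2, 2)]

# Training path goes through the method; the fit script calls the function directly.
training_heads = MixtureMetadataSketch(names, stats, metas, dims)._build_norm_heads()
fit_heads = build_norm_heads(
    dataset_names=names, per_dataset_stats=stats, metadatas=metas, raw_dims=dims
)
assert training_heads == fit_heads
```

Keyword-only arguments keep the two call sites easy to compare, and because both paths run the same function, a change to the pooling semantics lands in both at once.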

A secondary extraction is probably needed for the standardize-and-pad step (_to_standard_data_format at dataset_mixture.py:474) so the fit script can produce the standardized stats dict without going through a full DatasetMixtureMetadata.

Why deferred

This is a cross-file refactor touching dataset_mixture.py whose benefits are entirely "future-proofing". The surgical fixes in #348 close the currently known divergences and add TestNormalizeEquivalenceVsProduction to catch future ones, which gates the most common failure mode. The proper extraction is correctness-equivalent today, just more durable. Doing it as a focused refactor PR is cleaner than bundling it into a fit-script bug-fix PR.

Acceptance

  • The _build_norm_heads logic lives in a free function (build_norm_heads); the instance method is a thin forwarder.
  • _to_standard_data_format is either factored similarly or has a documented "build the standardized stats dict outside of a DatasetMixtureMetadata instance" entry point.
  • fit_fast_tokenizer._aggregate_stats_per_head calls the free function directly. The ~140 lines of re-implementation collapse to a thin adapter.
  • TestNormalizeEquivalenceVsProduction still passes (sanity).
  • Bonus: a new test pins that the fit-script and training paths produce byte-equal per_norm_key_stats for an end-to-end synthetic mixture.
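
One possible shape for the byte-equality check in the bonus item, as a sketch only. It assumes per_norm_key_stats is a nested mapping of norm key to stat name to a buffer-backed array (anything that supports the buffer protocol, such as a NumPy array); the helper name assert_byte_equal_stats and the sample keys and values are made up for illustration, and stdlib array.array stands in for the real stats so the sketch has no dependencies.

```python
from array import array


def assert_byte_equal_stats(fit_stats, train_stats):
    """Raise AssertionError unless both nested stats dicts match byte for byte."""
    assert fit_stats.keys() == train_stats.keys(), "norm keys differ"
    for norm_key in fit_stats:
        fit_head, train_head = fit_stats[norm_key], train_stats[norm_key]
        assert fit_head.keys() == train_head.keys(), f"stat names differ for {norm_key}"
        for stat_name in fit_head:
            a = memoryview(fit_head[stat_name])
            b = memoryview(train_head[stat_name])
            # Check element type and shape first so that, for example, int32
            # zeros and float32 zeros cannot pass on identical raw bytes.
            assert (a.format, a.shape) == (b.format, b.shape), (norm_key, stat_name)
            assert a.tobytes() == b.tobytes(), (norm_key, stat_name)


# Made-up stats for a single hypothetical (robot_type, control_mode) head.
fit = {("arm", "joint"): {"mean": array("f", [0.5, 1.5]), "std": array("f", [1.0, 2.0])}}
train = {("arm", "joint"): {"mean": array("f", [0.5, 1.5]), "std": array("f", [1.0, 2.0])}}
assert_byte_equal_stats(fit, train)
```

Comparing raw bytes is stricter than a tolerance-based comparison, which matches the "byte-equal" wording: any difference in pooling order or dtype between the two paths shows up as a failure instead of being absorbed by a tolerance.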
