From 9178a2e4303e7a35c8262514e3bcb8b3f7735c5f Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 2 Jun 2026 12:01:13 -0700 Subject: [PATCH 1/2] feat(datasets): allow per-dataset override of mixture val_split_ratio `DatasetConfig.val_split_ratio` was deprecated and ignored, so the mixture-level `val_split_ratio` applied uniformly to every dataset. Restore it as a real per-dataset override: `None` inherits the mixture default, any value (incl. `0.0` to opt out of validation) wins for that dataset only. Mirrors the existing `tolerance_s` / `skip_timestamp_check` inherit-on-None pattern, resolved in `make_dataset`. Also zero out per-dataset overrides in `fit_fast_tokenizer._build_train_cfg` so its "no validation split" invariant holds regardless of the input config. --- src/opentau/configs/default.py | 50 ++++++++----------- src/opentau/datasets/factory.py | 14 +++++- src/opentau/scripts/fit_fast_tokenizer.py | 14 ++++-- tests/configs/test_default.py | 22 +++++++-- tests/datasets/test_datasets.py | 58 +++++++++++++++++++++++ 5 files changed, 119 insertions(+), 39 deletions(-) diff --git a/src/opentau/configs/default.py b/src/opentau/configs/default.py index 7b1f9a14..328f3331 100644 --- a/src/opentau/configs/default.py +++ b/src/opentau/configs/default.py @@ -22,7 +22,6 @@ - Evaluation settings and parameters """ -import warnings from dataclasses import dataclass, field import draccus @@ -138,15 +137,12 @@ class DatasetConfig: tolerance_s: float | None = None skip_timestamp_check: bool | None = None - # DEPRECATED. Set `val_split_ratio` on `DatasetMixtureConfig` instead — the - # mixture-level value is the single source of truth and is applied uniformly - # to every dataset in the mixture. This per-dataset field is retained only - # so that pre-existing JSON configs continue to parse; setting it here has - # no effect on the actual split. The default is `None` (sentinel meaning - # "user did not set this") so that - # `DatasetMixtureConfig.__post_init__` can distinguish a real per-dataset - # override from the unset default and only emit a `DeprecationWarning` in - # the former case. + # Per-dataset override for the train/validation split ratio. `None` + # inherits from `DatasetMixtureConfig.val_split_ratio`. A non-None value + # (including `0.0`, which opts this dataset out of validation) wins over the + # mixture default for this dataset only — useful when one dataset in a + # mixture wants a different validation fraction than the rest. Must be in + # `[0, 1]` when set. Only consulted when `TrainPipelineConfig.val_freq > 0`. val_split_ratio: float | None = None def __post_init__(self): @@ -160,6 +156,12 @@ def __post_init__(self): f"got {self.tolerance_s} for {self.repo_id or self.vqa}." ) + if self.val_split_ratio is not None and not (0.0 <= self.val_split_ratio <= 1.0): + raise ValueError( + f"`DatasetConfig.val_split_ratio` must be in [0, 1] (or None to inherit), " + f"got {self.val_split_ratio} for {self.repo_id or self.vqa}." + ) + # If data_features_name_mapping is provided, upsert it into the global # DATA_FEATURES_NAME_MAPPING. Register under the plain repo_id (back-compat # fallback, last-wins) AND, when this entry carries a real control mode, @@ -207,6 +209,12 @@ class DatasetMixtureConfig: vector_resample_strategy: Resample strategy for non-image features, such as action or state. Must be one of 'linear' or 'nearest'. Defaults to 'nearest'. + val_split_ratio: Mixture-wide default fraction of each dataset reserved + for the validation split (only used when + ``TrainPipelineConfig.val_freq > 0``). A per-dataset + ``DatasetConfig.val_split_ratio`` overrides this value for that + dataset; ``None`` there inherits this mixture default. Must be in + ``[0, 1]``. Defaults to 0.05. n_obs_history: Number of historical observation steps to include. When set to ``T``, each camera returns shape ``(T, C, H, W)`` and state returns shape ``(T, max_state_dim)``. When ``None``, the default @@ -320,9 +328,10 @@ class DatasetMixtureConfig: image_resample_strategy: str = "nearest" # Resample strategy for non-image features, such as action or state vector_resample_strategy: str = "nearest" - # Ratio of the dataset to be used for validation. Please specify a value. + # Mixture-wide default ratio of each dataset to be used for validation. # If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored. - # This value is applied to all datasets in the mixture. + # A per-dataset `DatasetConfig.val_split_ratio` overrides this for that + # dataset (`None` there inherits this value). # Defaults to 0.05. val_split_ratio: float = 0.05 # Number of historical observation steps. None preserves default single-step behavior. @@ -398,23 +407,6 @@ def __post_init__(self): if not 0.0 <= value <= 1.0: raise ValueError(f"`{name}` must be in [0, 1], got {value}.") - # `DatasetConfig.val_split_ratio` is deprecated — the mixture-level - # value is the single source of truth (read by `factory.make_dataset`). - # The per-dataset field defaults to `None`; warn only when the user - # actually set a value there, since that's the case where their input - # is being silently ignored. - for dataset_cfg in self.datasets: - if dataset_cfg.val_split_ratio is not None: - warnings.warn( - "`DatasetConfig.val_split_ratio` is deprecated and ignored; " - "set `val_split_ratio` on `DatasetMixtureConfig` instead. " - f"Got dataset value {dataset_cfg.val_split_ratio} " - f"vs. mixture value {self.val_split_ratio}; the mixture " - "value will be used.", - DeprecationWarning, - stacklevel=2, - ) - @dataclass class WandBConfig: diff --git a/src/opentau/datasets/factory.py b/src/opentau/datasets/factory.py index d4226b6c..87bd023c 100644 --- a/src/opentau/datasets/factory.py +++ b/src/opentau/datasets/factory.py @@ -178,7 +178,9 @@ def make_dataset( A train and validation dataset are returned if `train_cfg.val_freq` is greater than 0. The validation dataset is a subset of the train dataset, and is used for evaluation during training. - The validation dataset is created by splitting the train dataset into train and validation sets based on `train_cfg.dataset_mixture.val_split_ratio`. + The validation dataset is created by splitting the train dataset into train and validation sets based on the + effective split ratio: the per-dataset `cfg.val_split_ratio` when set, otherwise the mixture-wide + `train_cfg.dataset_mixture.val_split_ratio` (the per-dataset value `None` inherits the mixture default). Args: cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset. @@ -265,7 +267,15 @@ def make_dataset( dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32) if train_cfg.val_freq > 0: - val_size = int(len(dataset) * train_cfg.dataset_mixture.val_split_ratio) + # Per-dataset value wins over the mixture-wide default; `None` means + # "inherit". Mirrors the `tolerance_s` / `skip_timestamp_check` + # resolution above. See `DatasetConfig` / `DatasetMixtureConfig` docs. + effective_val_split = ( + cfg.val_split_ratio + if cfg.val_split_ratio is not None + else train_cfg.dataset_mixture.val_split_ratio + ) + val_size = int(len(dataset) * effective_val_split) train_size = len(dataset) - val_size train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size]) train_dataset.meta = copy.deepcopy(dataset.meta) # type: ignore[assignment] diff --git a/src/opentau/scripts/fit_fast_tokenizer.py b/src/opentau/scripts/fit_fast_tokenizer.py index 75474dbd..f77023f4 100644 --- a/src/opentau/scripts/fit_fast_tokenizer.py +++ b/src/opentau/scripts/fit_fast_tokenizer.py @@ -1017,6 +1017,15 @@ def _build_train_cfg( from opentau.configs.train import TrainPipelineConfig + # DatasetMixtureConfig.val_split_ratio defaults to 0.05, and a per-dataset + # ``DatasetConfig.val_split_ratio`` overrides it -- if either is non-zero, + # make_dataset would return a (train, val) tuple per dataset and the assert + # in ``_build_mixture_parallel`` would (correctly) reject it. Zero out BOTH + # the mixture default and every per-dataset override so the parallel mixture + # build always returns single train datasets. Rebuild the dataset configs + # (``dataclasses.replace`` shares the ``datasets`` list reference) so the + # caller's parsed config is left untouched. + datasets_for_fit = [dataclasses.replace(dc, val_split_ratio=0.0) for dc in mixture_cfg.datasets] # We only need action chunks for the tokenizer fit. Override mixture-side # knobs that would otherwise force per-sample state-history loads and # augmentation rolls (none of which affect the action column). Use @@ -1031,11 +1040,8 @@ def _build_train_cfg( response_drop_prob=1.0, # drop all responses metadata_drop_all_prob=1.0, # drop all metadata metadata_drop_each_prob=0.0, - # DatasetMixtureConfig.val_split_ratio defaults to 0.05 -- if we let it - # through, make_dataset would return a (train, val) tuple per dataset - # and the assert below would (correctly) reject it. Force 0 here so the - # parallel mixture build always returns single train datasets. val_split_ratio=0.0, + datasets=datasets_for_fit, ) fake_policy = SimpleNamespace( action_delta_indices=list(range(chunk_size)), diff --git a/tests/configs/test_default.py b/tests/configs/test_default.py index 3f3063b1..d520db6b 100644 --- a/tests/configs/test_default.py +++ b/tests/configs/test_default.py @@ -163,19 +163,33 @@ def test_val_split_ratio_no_warning_when_only_mixture_customized(): ) -def test_val_split_ratio_warns_when_child_overrides(): - """Setting `val_split_ratio` on a child `DatasetConfig` must emit a DeprecationWarning.""" +def test_val_split_ratio_no_warning_when_child_overrides(): + """Setting `val_split_ratio` on a child `DatasetConfig` is a supported + per-dataset override (inherit-on-None, like `tolerance_s`), not a deprecated + field, so it must NOT emit a DeprecationWarning. + """ with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") DatasetMixtureConfig( datasets=[DatasetConfig(repo_id="foo/bar", val_split_ratio=0.2)], val_split_ratio=0.1, ) - assert any( - issubclass(w.category, DeprecationWarning) and "val_split_ratio" in str(w.message) for w in caught + val_split_warnings = [ + w + for w in caught + if issubclass(w.category, DeprecationWarning) and "val_split_ratio" in str(w.message) + ] + assert not val_split_warnings, ( + f"Unexpected val_split_ratio DeprecationWarning(s): {[str(w.message) for w in val_split_warnings]}" ) +def test_dataset_config_val_split_ratio_out_of_range_raises(): + """A per-dataset `val_split_ratio` outside [0, 1] must be rejected at config time.""" + with pytest.raises(ValueError, match=r"`DatasetConfig.val_split_ratio` must be in \[0, 1\]"): + DatasetConfig(repo_id="foo/bar", val_split_ratio=1.5) + + def test_dataset_mixture_config_tolerance_defaults(): """Mixture-level timestamp-sync defaults match the historical `LeRobotDataset` behavior.""" cfg = DatasetMixtureConfig() diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index c2f4c493..c93770f9 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -864,6 +864,64 @@ def test_make_dataset_per_dataset_skip_false_overrides_mixture_true(train_pipeli cleanup() +def test_make_dataset_per_dataset_val_split_ratio_wins(train_pipeline_config): + """A per-dataset `DatasetConfig.val_split_ratio` overrides the mixture default.""" + dataset_cfg = train_pipeline_config.dataset_mixture.datasets[0] + cleanup = _register_mapping(dataset_cfg.repo_id) + try: + dataset_cfg.val_split_ratio = 0.2 + train_pipeline_config.dataset_mixture.val_split_ratio = 0.05 + train_pipeline_config.val_freq = 1 # enable the train/val split branch + + with ( + patch("opentau.datasets.factory.LeRobotDatasetMetadata") as mock_meta_cls, + patch("opentau.datasets.factory.LeRobotDataset") as mock_ds_cls, + ): + mock_meta_cls.return_value = MagicMock(features=[]) + mock_ds = MagicMock(meta=MagicMock(info={}, stats={}, camera_keys=[])) + mock_ds.__len__.return_value = 100 + mock_ds.shallow_copy_with_dropout.return_value = mock_ds + mock_ds_cls.return_value = mock_ds + + result = make_dataset(dataset_cfg, train_pipeline_config) + + assert isinstance(result, tuple) + _, val_dataset = result + # 0.2 (per-dataset) wins over 0.05 (mixture): int(100 * 0.2) == 20. + assert len(val_dataset) == 20 + finally: + cleanup() + + +def test_make_dataset_inherits_mixture_val_split_ratio(train_pipeline_config): + """When the per-dataset `val_split_ratio` is None, the mixture default applies.""" + dataset_cfg = train_pipeline_config.dataset_mixture.datasets[0] + cleanup = _register_mapping(dataset_cfg.repo_id) + try: + assert dataset_cfg.val_split_ratio is None # dataclass default (inherit) + train_pipeline_config.dataset_mixture.val_split_ratio = 0.1 + train_pipeline_config.val_freq = 1 + + with ( + patch("opentau.datasets.factory.LeRobotDatasetMetadata") as mock_meta_cls, + patch("opentau.datasets.factory.LeRobotDataset") as mock_ds_cls, + ): + mock_meta_cls.return_value = MagicMock(features=[]) + mock_ds = MagicMock(meta=MagicMock(info={}, stats={}, camera_keys=[])) + mock_ds.__len__.return_value = 100 + mock_ds.shallow_copy_with_dropout.return_value = mock_ds + mock_ds_cls.return_value = mock_ds + + result = make_dataset(dataset_cfg, train_pipeline_config) + + assert isinstance(result, tuple) + _, val_dataset = result + # None (per-dataset) inherits mixture 0.1: int(100 * 0.1) == 10. + assert len(val_dataset) == 10 + finally: + cleanup() + + # TODO(aliberts): Move to more appropriate location def test_flatten_unflatten_dict(): d = { From b0c392a14cbb010ec6d3189f45bd6d78f882c54d Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 2 Jun 2026 12:11:03 -0700 Subject: [PATCH 2/2] test(datasets): cover val_split_ratio=0.0 per-dataset opt-out MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address review feedback that the advertised `0.0` opt-out was untested. Add a factory-level test (0.0 yields an empty val Subset, all samples stay in train) and a mixture-level test proving an empty member contributes no samples even with a positive mixture weight — `_calculate_sample_weights` skips length-0 datasets, so the opt-out works on the explicit-weights path too, not only the inferred-weights path. --- tests/datasets/test_dataset_mixture.py | 27 +++++++++++++++++++ tests/datasets/test_datasets.py | 37 ++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/tests/datasets/test_dataset_mixture.py b/tests/datasets/test_dataset_mixture.py index 69647a98..9c47bd1d 100644 --- a/tests/datasets/test_dataset_mixture.py +++ b/tests/datasets/test_dataset_mixture.py @@ -314,6 +314,33 @@ def test_calculate_sample_weights_zero_weights(self, train_pipeline_config, data assert weights is not None assert torch.sum(weights) == 0 + @pytest.mark.slow # 2 sec + def test_calculate_sample_weights_skips_empty_member(self, train_pipeline_config, datasets_factory): + """A zero-length member contributes no samples even with a positive weight. + + This is the validation-mixture shape produced when a dataset opts out of + validation via `DatasetConfig.val_split_ratio=0.0`: `make_dataset` still + returns a `(train, val)` tuple whose val half is an empty `Subset`, which + gets appended to the val mixture. `_calculate_sample_weights` skips + length-0 datasets, so an explicit non-zero mixture weight on that member + has no effect — the opt-out works on the explicit-weights path too, not + only the inferred-weights path. + """ + full, to_empty = datasets_factory(2) + # An empty Subset with `.meta` mirrors what `make_dataset` builds for a + # 0.0 val split (random_split(..., [len, 0]) + meta deep-copy). + empty = torch.utils.data.Subset(to_empty, []) + empty.meta = to_empty.meta + + mixture = WeightedDatasetMixture(train_pipeline_config, [full, empty], [0.5, 0.5], 30.0) + weights = mixture._calculate_sample_weights() + + assert weights is not None + # Only the non-empty member contributes samples; the empty member is + # skipped despite its 0.5 weight. + assert len(weights) == len(full) + assert torch.all(weights > 0) + def test_get_dataloader_success(self, train_pipeline_config, datasets_factory): """Test successful dataloader creation.""" mixture = WeightedDatasetMixture(train_pipeline_config, datasets_factory(2), [0.7, 0.3], 30.0) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index c93770f9..8aea4ed0 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -922,6 +922,43 @@ def test_make_dataset_inherits_mixture_val_split_ratio(train_pipeline_config): cleanup() +def test_make_dataset_per_dataset_val_split_ratio_zero_opts_out(train_pipeline_config): + """A per-dataset `val_split_ratio=0.0` opts that dataset out of validation. + + `make_dataset` still returns a `(train, val)` tuple (the branch is gated on + `val_freq`, not the ratio), but the val half is empty and all samples stay + in train. The empty val `Subset` is harmless in the val mixture: it carries + no samples, and `WeightedDatasetMixture._calculate_sample_weights` skips + length-0 members regardless of their weight (covered in + `test_dataset_mixture.py::...skips_empty_member`). + """ + dataset_cfg = train_pipeline_config.dataset_mixture.datasets[0] + cleanup = _register_mapping(dataset_cfg.repo_id) + try: + dataset_cfg.val_split_ratio = 0.0 # opt this dataset out of validation + train_pipeline_config.dataset_mixture.val_split_ratio = 0.05 + train_pipeline_config.val_freq = 1 + + with ( + patch("opentau.datasets.factory.LeRobotDatasetMetadata") as mock_meta_cls, + patch("opentau.datasets.factory.LeRobotDataset") as mock_ds_cls, + ): + mock_meta_cls.return_value = MagicMock(features=[]) + mock_ds = MagicMock(meta=MagicMock(info={}, stats={}, camera_keys=[])) + mock_ds.__len__.return_value = 100 + mock_ds.shallow_copy_with_dropout.return_value = mock_ds + mock_ds_cls.return_value = mock_ds + + result = make_dataset(dataset_cfg, train_pipeline_config) + + assert isinstance(result, tuple) + train_dataset, val_dataset = result + assert len(val_dataset) == 0 # int(100 * 0.0) == 0: empty val split + assert len(train_dataset) == 100 # every sample remains for training + finally: + cleanup() + + # TODO(aliberts): Move to more appropriate location def test_flatten_unflatten_dict(): d = {