Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 21 additions & 29 deletions src/opentau/configs/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
- Evaluation settings and parameters
"""

import warnings
from dataclasses import dataclass, field

import draccus
Expand Down Expand Up @@ -138,15 +137,12 @@ class DatasetConfig:
tolerance_s: float | None = None
skip_timestamp_check: bool | None = None

# DEPRECATED. Set `val_split_ratio` on `DatasetMixtureConfig` instead — the
# mixture-level value is the single source of truth and is applied uniformly
# to every dataset in the mixture. This per-dataset field is retained only
# so that pre-existing JSON configs continue to parse; setting it here has
# no effect on the actual split. The default is `None` (sentinel meaning
# "user did not set this") so that
# `DatasetMixtureConfig.__post_init__` can distinguish a real per-dataset
# override from the unset default and only emit a `DeprecationWarning` in
# the former case.
# Per-dataset override for the train/validation split ratio. `None`
# inherits from `DatasetMixtureConfig.val_split_ratio`. A non-None value
# (including `0.0`, which opts this dataset out of validation) wins over the
# mixture default for this dataset only — useful when one dataset in a
# mixture wants a different validation fraction than the rest. Must be in
# `[0, 1]` when set. Only consulted when `TrainPipelineConfig.val_freq > 0`.
val_split_ratio: float | None = None

def __post_init__(self):
Expand All @@ -160,6 +156,12 @@ def __post_init__(self):
f"got {self.tolerance_s} for {self.repo_id or self.vqa}."
)

if self.val_split_ratio is not None and not (0.0 <= self.val_split_ratio <= 1.0):
raise ValueError(
f"`DatasetConfig.val_split_ratio` must be in [0, 1] (or None to inherit), "
f"got {self.val_split_ratio} for {self.repo_id or self.vqa}."
)

# If data_features_name_mapping is provided, upsert it into the global
# DATA_FEATURES_NAME_MAPPING. Register under the plain repo_id (back-compat
# fallback, last-wins) AND, when this entry carries a real control mode,
Expand Down Expand Up @@ -207,6 +209,12 @@ class DatasetMixtureConfig:
vector_resample_strategy: Resample strategy for non-image features, such
as action or state. Must be one of 'linear' or 'nearest'.
Defaults to 'nearest'.
val_split_ratio: Mixture-wide default fraction of each dataset reserved
for the validation split (only used when
``TrainPipelineConfig.val_freq > 0``). A per-dataset
``DatasetConfig.val_split_ratio`` overrides this value for that
dataset; ``None`` there inherits this mixture default. Must be in
``[0, 1]``. Defaults to 0.05.
n_obs_history: Number of historical observation steps to include. When
set to ``T``, each camera returns shape ``(T, C, H, W)`` and state
returns shape ``(T, max_state_dim)``. When ``None``, the default
Expand Down Expand Up @@ -320,9 +328,10 @@ class DatasetMixtureConfig:
image_resample_strategy: str = "nearest"
# Resample strategy for non-image features, such as action or state
vector_resample_strategy: str = "nearest"
# Ratio of the dataset to be used for validation. Please specify a value.
# Mixture-wide default ratio of each dataset to be used for validation.
# If `val_freq` is set to 0, a validation dataset will not be created and this value will be ignored.
# This value is applied to all datasets in the mixture.
# A per-dataset `DatasetConfig.val_split_ratio` overrides this for that
# dataset (`None` there inherits this value).
# Defaults to 0.05.
val_split_ratio: float = 0.05
# Number of historical observation steps. None preserves default single-step behavior.
Expand Down Expand Up @@ -398,23 +407,6 @@ def __post_init__(self):
if not 0.0 <= value <= 1.0:
raise ValueError(f"`{name}` must be in [0, 1], got {value}.")

# `DatasetConfig.val_split_ratio` is deprecated — the mixture-level
# value is the single source of truth (read by `factory.make_dataset`).
# The per-dataset field defaults to `None`; warn only when the user
# actually set a value there, since that's the case where their input
# is being silently ignored.
for dataset_cfg in self.datasets:
if dataset_cfg.val_split_ratio is not None:
warnings.warn(
"`DatasetConfig.val_split_ratio` is deprecated and ignored; "
"set `val_split_ratio` on `DatasetMixtureConfig` instead. "
f"Got dataset value {dataset_cfg.val_split_ratio} "
f"vs. mixture value {self.val_split_ratio}; the mixture "
"value will be used.",
DeprecationWarning,
stacklevel=2,
)


@dataclass
class WandBConfig:
Expand Down
14 changes: 12 additions & 2 deletions src/opentau/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def make_dataset(

A train and validation dataset are returned if `train_cfg.val_freq` is greater than 0.
The validation dataset is a subset of the train dataset, and is used for evaluation during training.
The validation dataset is created by splitting the train dataset into train and validation sets based on `train_cfg.dataset_mixture.val_split_ratio`.
The validation dataset is created by splitting the train dataset into train and validation sets based on the
effective split ratio: the per-dataset `cfg.val_split_ratio` when set, otherwise the mixture-wide
`train_cfg.dataset_mixture.val_split_ratio` (the per-dataset value `None` inherits the mixture default).

Args:
cfg (DatasetConfig): A DatasetConfig used to create a LeRobotDataset.
Expand Down Expand Up @@ -265,7 +267,15 @@ def make_dataset(
dataset.meta.stats[key][stats_type] = np.array(stats, dtype=np.float32)

if train_cfg.val_freq > 0:
val_size = int(len(dataset) * train_cfg.dataset_mixture.val_split_ratio)
# Per-dataset value wins over the mixture-wide default; `None` means
# "inherit". Mirrors the `tolerance_s` / `skip_timestamp_check`
# resolution above. See `DatasetConfig` / `DatasetMixtureConfig` docs.
effective_val_split = (
cfg.val_split_ratio
if cfg.val_split_ratio is not None
else train_cfg.dataset_mixture.val_split_ratio
)
val_size = int(len(dataset) * effective_val_split)
Comment thread
shuheng-liu marked this conversation as resolved.
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataset.meta = copy.deepcopy(dataset.meta) # type: ignore[assignment]
Expand Down
14 changes: 10 additions & 4 deletions src/opentau/scripts/fit_fast_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,15 @@ def _build_train_cfg(

from opentau.configs.train import TrainPipelineConfig

# DatasetMixtureConfig.val_split_ratio defaults to 0.05, and a per-dataset
# ``DatasetConfig.val_split_ratio`` overrides it -- if either is non-zero,
# make_dataset would return a (train, val) tuple per dataset and the assert
# in ``_build_mixture_parallel`` would (correctly) reject it. Zero out BOTH
# the mixture default and every per-dataset override so the parallel mixture
# build always returns single train datasets. Rebuild the dataset configs
# (``dataclasses.replace`` shares the ``datasets`` list reference) so the
# caller's parsed config is left untouched.
datasets_for_fit = [dataclasses.replace(dc, val_split_ratio=0.0) for dc in mixture_cfg.datasets]
# We only need action chunks for the tokenizer fit. Override mixture-side
# knobs that would otherwise force per-sample state-history loads and
# augmentation rolls (none of which affect the action column). Use
Expand All @@ -1031,11 +1040,8 @@ def _build_train_cfg(
response_drop_prob=1.0, # drop all responses
metadata_drop_all_prob=1.0, # drop all metadata
metadata_drop_each_prob=0.0,
# DatasetMixtureConfig.val_split_ratio defaults to 0.05 -- if we let it
# through, make_dataset would return a (train, val) tuple per dataset
# and the assert below would (correctly) reject it. Force 0 here so the
# parallel mixture build always returns single train datasets.
val_split_ratio=0.0,
datasets=datasets_for_fit,
)
fake_policy = SimpleNamespace(
action_delta_indices=list(range(chunk_size)),
Expand Down
22 changes: 18 additions & 4 deletions tests/configs/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,33 @@ def test_val_split_ratio_no_warning_when_only_mixture_customized():
)


def test_val_split_ratio_warns_when_child_overrides():
"""Setting `val_split_ratio` on a child `DatasetConfig` must emit a DeprecationWarning."""
def test_val_split_ratio_no_warning_when_child_overrides():
"""Setting `val_split_ratio` on a child `DatasetConfig` is a supported
per-dataset override (inherit-on-None, like `tolerance_s`), not a deprecated
field, so it must NOT emit a DeprecationWarning.
"""
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
DatasetMixtureConfig(
datasets=[DatasetConfig(repo_id="foo/bar", val_split_ratio=0.2)],
val_split_ratio=0.1,
)
assert any(
issubclass(w.category, DeprecationWarning) and "val_split_ratio" in str(w.message) for w in caught
val_split_warnings = [
w
for w in caught
if issubclass(w.category, DeprecationWarning) and "val_split_ratio" in str(w.message)
]
assert not val_split_warnings, (
f"Unexpected val_split_ratio DeprecationWarning(s): {[str(w.message) for w in val_split_warnings]}"
)


def test_dataset_config_val_split_ratio_out_of_range_raises():
    """A per-dataset `val_split_ratio` outside [0, 1] must be rejected at config time.

    Both sides of the closed interval are exercised (one value above 1 and one
    below 0) so that a range check which only guards a single bound cannot
    pass this test by accident.
    """
    expected_message = r"`DatasetConfig.val_split_ratio` must be in \[0, 1\]"
    for bad_ratio in (1.5, -0.1):
        with pytest.raises(ValueError, match=expected_message):
            DatasetConfig(repo_id="foo/bar", val_split_ratio=bad_ratio)


def test_dataset_mixture_config_tolerance_defaults():
"""Mixture-level timestamp-sync defaults match the historical `LeRobotDataset` behavior."""
cfg = DatasetMixtureConfig()
Expand Down
27 changes: 27 additions & 0 deletions tests/datasets/test_dataset_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,33 @@ def test_calculate_sample_weights_zero_weights(self, train_pipeline_config, data
assert weights is not None
assert torch.sum(weights) == 0

@pytest.mark.slow  # 2 sec
def test_calculate_sample_weights_skips_empty_member(self, train_pipeline_config, datasets_factory):
    """An empty mixture member yields no sample weights, even with a positive weight.

    The fixture shape imitates a dataset that opted out of validation with
    `DatasetConfig.val_split_ratio=0.0`: the validation half of the
    `(train, val)` pair is then a zero-length `Subset` that still ends up in
    the validation mixture. The assertions below pin that such a member adds
    nothing to the per-sample weights although it was given a 0.5 mixture
    weight, i.e. the opt-out also holds when weights are passed explicitly.
    """
    populated, donor = datasets_factory(2)

    # Zero-length Subset carrying `.meta`, standing in for the val half of a
    # 0.0 split (an empty random_split partition plus copied metadata).
    hollow = torch.utils.data.Subset(donor, [])
    hollow.meta = donor.meta

    sample_weights = WeightedDatasetMixture(
        train_pipeline_config,
        [populated, hollow],
        [0.5, 0.5],
        30.0,
    )._calculate_sample_weights()

    assert sample_weights is not None
    # One weight per sample of the populated member only; the hollow member
    # contributes none despite its 0.5 share.
    assert len(sample_weights) == len(populated)
    assert torch.all(sample_weights > 0)

def test_get_dataloader_success(self, train_pipeline_config, datasets_factory):
"""Test successful dataloader creation."""
mixture = WeightedDatasetMixture(train_pipeline_config, datasets_factory(2), [0.7, 0.3], 30.0)
Expand Down
95 changes: 95 additions & 0 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,101 @@ def test_make_dataset_per_dataset_skip_false_overrides_mixture_true(train_pipeli
cleanup()


def test_make_dataset_per_dataset_val_split_ratio_wins(train_pipeline_config):
    """The ratio set on a `DatasetConfig` takes precedence over the mixture-wide one."""
    dataset_cfg = train_pipeline_config.dataset_mixture.datasets[0]
    undo_mapping = _register_mapping(dataset_cfg.repo_id)
    try:
        # Splitting only happens when validation is enabled via `val_freq`.
        train_pipeline_config.val_freq = 1
        train_pipeline_config.dataset_mixture.val_split_ratio = 0.05
        dataset_cfg.val_split_ratio = 0.2

        # Stand-in dataset of 100 samples whose dropout copy is itself.
        fake_meta = MagicMock(features=[])
        fake_ds = MagicMock(meta=MagicMock(info={}, stats={}, camera_keys=[]))
        fake_ds.__len__.return_value = 100
        fake_ds.shallow_copy_with_dropout.return_value = fake_ds

        with (
            patch("opentau.datasets.factory.LeRobotDatasetMetadata", return_value=fake_meta),
            patch("opentau.datasets.factory.LeRobotDataset", return_value=fake_ds),
        ):
            outcome = make_dataset(dataset_cfg, train_pipeline_config)

        assert isinstance(outcome, tuple)
        val_part = outcome[1]
        # Per-dataset 0.2 beats mixture 0.05: int(100 * 0.2) == 20.
        assert len(val_part) == 20
    finally:
        undo_mapping()


def test_make_dataset_inherits_mixture_val_split_ratio(train_pipeline_config):
    """A per-dataset `val_split_ratio` of None falls back to the mixture-wide ratio."""
    dataset_cfg = train_pipeline_config.dataset_mixture.datasets[0]
    undo_mapping = _register_mapping(dataset_cfg.repo_id)
    try:
        # Precondition: the fixture leaves the per-dataset field unset (inherit).
        assert dataset_cfg.val_split_ratio is None
        train_pipeline_config.val_freq = 1
        train_pipeline_config.dataset_mixture.val_split_ratio = 0.1

        # Stand-in dataset of 100 samples whose dropout copy is itself.
        fake_meta = MagicMock(features=[])
        fake_ds = MagicMock(meta=MagicMock(info={}, stats={}, camera_keys=[]))
        fake_ds.__len__.return_value = 100
        fake_ds.shallow_copy_with_dropout.return_value = fake_ds

        with (
            patch("opentau.datasets.factory.LeRobotDatasetMetadata", return_value=fake_meta),
            patch("opentau.datasets.factory.LeRobotDataset", return_value=fake_ds),
        ):
            outcome = make_dataset(dataset_cfg, train_pipeline_config)

        assert isinstance(outcome, tuple)
        val_part = outcome[1]
        # Unset per-dataset value inherits mixture 0.1: int(100 * 0.1) == 10.
        assert len(val_part) == 10
    finally:
        undo_mapping()


def test_make_dataset_per_dataset_val_split_ratio_zero_opts_out(train_pipeline_config):
    """Setting `val_split_ratio=0.0` on one dataset removes it from validation.

    With `val_freq` enabled, `make_dataset` is still expected to hand back a
    `(train, val)` pair; the assertions pin that the val half is empty and the
    train half keeps every sample. How the validation mixture treats such an
    empty member is covered separately in
    `test_dataset_mixture.py::...skips_empty_member`.
    """
    dataset_cfg = train_pipeline_config.dataset_mixture.datasets[0]
    undo_mapping = _register_mapping(dataset_cfg.repo_id)
    try:
        train_pipeline_config.val_freq = 1
        train_pipeline_config.dataset_mixture.val_split_ratio = 0.05
        # Explicit zero (not None) is the opt-out for this dataset.
        dataset_cfg.val_split_ratio = 0.0

        # Stand-in dataset of 100 samples whose dropout copy is itself.
        fake_meta = MagicMock(features=[])
        fake_ds = MagicMock(meta=MagicMock(info={}, stats={}, camera_keys=[]))
        fake_ds.__len__.return_value = 100
        fake_ds.shallow_copy_with_dropout.return_value = fake_ds

        with (
            patch("opentau.datasets.factory.LeRobotDatasetMetadata", return_value=fake_meta),
            patch("opentau.datasets.factory.LeRobotDataset", return_value=fake_ds),
        ):
            outcome = make_dataset(dataset_cfg, train_pipeline_config)

        assert isinstance(outcome, tuple)
        train_part, val_part = outcome
        # int(100 * 0.0) == 0: nothing is held out for validation...
        assert len(val_part) == 0
        # ...so all 100 samples stay available for training.
        assert len(train_part) == 100
    finally:
        undo_mapping()


# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
d = {
Expand Down
Loading