From 70a8b9f0fc3fd6ec3ee1fdde61cd1f757eb4986e Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 2 Jun 2026 00:54:01 -0700 Subject: [PATCH 1/2] feat(train): aggregate val loss per (dataset, control_mode), parallelize Validation runs a single combined pass over the whole mixture instead of one DataLoader per dataset iterated sequentially, so every rank's shard stays full (many tiny val subsets no longer leave most ranks idle). Per-(dataset, control_mode) MSE/CE breakdowns are computed in the loop from each sample's dropout-immune dataset_index / dataset_repo_id provenance, not from homogeneous per-dataset loaders, so a mixed batch can be disaggregated correctly. Policies optionally expose per-sample MSE/CE as (sum, count) via a return_per_sample flag; the training scalar path is byte-for-byte unchanged (the flag defaults False and the train step never sets it). --- src/opentau/datasets/dataset_mixture.py | 43 +++ src/opentau/policies/pi0/modeling_pi0.py | 32 +- src/opentau/policies/pi05/modeling_pi05.py | 61 +++- .../policies/pi05_mem/modeling_pi05.py | 45 ++- src/opentau/policies/pi06/modeling_pi06.py | 56 +++- .../pi07/low_level/modeling_pi07_low_level.py | 39 ++- .../low_level/modeling_pi07_low_level.py | 36 +- src/opentau/policies/utils.py | 78 ++++- src/opentau/policies/value/modeling_value.py | 37 +- src/opentau/scripts/train.py | 316 ++++++++++++------ tests/datasets/test_dataset_mixture.py | 28 ++ tests/policies/test_pi06.py | 100 ++++++ tests/scripts/test_train.py | 47 +++ 13 files changed, 763 insertions(+), 155 deletions(-) diff --git a/src/opentau/datasets/dataset_mixture.py b/src/opentau/datasets/dataset_mixture.py index db723421..29c4a121 100644 --- a/src/opentau/datasets/dataset_mixture.py +++ b/src/opentau/datasets/dataset_mixture.py @@ -975,3 +975,46 @@ def get_per_dataset_dataloaders(self) -> dict[str, DataLoader]: worker_init_fn=worker_init_fn, ) return loaders + + def get_combined_val_dataloader(self) -> DataLoader | None: + """Create one deterministic sequential DataLoader over the whole mixture. + + Unlike :meth:`get_per_dataset_dataloaders` (one loader per dataset), this + returns a single loader over the concatenated mixture so that, under + ``accelerator.prepare``, every rank's shard is full even when individual + validation subsets have fewer frames than ``world_size`` — the per-dataset + loaders leave most ranks idle on tiny subsets and stack that idle time + across datasets. Each sample still carries its ``dataset_index`` / + ``dataset_repo_id`` provenance (injected by ``_TaggedDataset``), so the + validation loop can disaggregate metrics per ``(dataset, control_mode)`` + from the batch rather than relying on homogeneous per-dataset loaders. + ``shuffle=False`` + ``drop_last=False`` make the pass order-deterministic + (seed-independent) and score every sample exactly once. + + Returns: + A single ``DataLoader`` over the mixture, or ``None`` when the mixture + is empty (mirrors the empty-dataset skip in + :meth:`get_per_dataset_dataloaders`). + """ + if len(self.concatenated_dataset) == 0: + logging.info("Combined validation DataLoader skipped: the mixture is empty.") + return None + + worker_name_mapping_overrides = self._get_worker_name_mapping_overrides() + worker_init_fn = None + if worker_name_mapping_overrides: + worker_init_fn = functools.partial( + _apply_data_feature_name_mapping_overrides, + mapping_overrides=worker_name_mapping_overrides, + ) + + return DataLoader( + self.concatenated_dataset, + batch_size=self.cfg.dataloader_batch_size, + shuffle=False, + num_workers=self.cfg.num_workers, + pin_memory=torch.cuda.is_available(), + drop_last=False, + prefetch_factor=self.cfg.prefetch_factor, + worker_init_fn=worker_init_fn, + ) diff --git a/src/opentau/policies/pi0/modeling_pi0.py b/src/opentau/policies/pi0/modeling_pi0.py index 6b814514..85850f40 100644 --- a/src/opentau/policies/pi0/modeling_pi0.py +++ b/src/opentau/policies/pi0/modeling_pi0.py @@ -25,7 +25,7 @@ import torch import torch.nn.functional as F # noqa: N812 -from einops import rearrange +from einops import rearrange, reduce from torch import Tensor, nn from transformers import AutoTokenizer @@ -37,7 +37,7 @@ PaliGemmaWithExpertModel, ) from opentau.policies.pretrained import PreTrainedPolicy -from opentau.policies.utils import log_model_loading_keys, make_action_dim_mask +from opentau.policies.utils import PerSampleLoss, log_model_loading_keys, make_action_dim_mask from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -466,14 +466,22 @@ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass to compute the loss. Args: batch: Batch of data containing environment observations, actions, and targets. noise: Optional noise tensor. time: Optional time tensor. + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. ``CE`` is a zero + stub for pi0, so ``CE_per_sample`` carries zero sum and count. Returns: A dictionary containing the loss components ("MSE" and "CE"). @@ -537,7 +545,21 @@ def forward( loss = losses.sum() / (full_mask.sum() + 1e-8) - return {"MSE": loss, "CE": torch.zeros_like(loss, requires_grad=True)} + out: dict[str, Tensor | PerSampleLoss] = { + "MSE": loss, + "CE": torch.zeros_like(loss, requires_grad=True), + } + if return_per_sample: + # ``losses`` is already masked (and AWR-weighted, if enabled), matching + # the scalar reduction; reduce over (chunk, dim) keeping the batch axis. + out["MSE_per_sample"] = PerSampleLoss( + sum=reduce(losses, "b c d -> b", "sum"), + count=reduce(full_mask.float(), "b c d -> b", "sum"), + ) + # pi0 has no CE term; emit zero sum/count so it forms no CE groups. + zeros = torch.zeros(losses.shape[0], device=losses.device) + out["CE_per_sample"] = PerSampleLoss(sum=zeros, count=zeros.clone()) + return out def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: """Apply Pi0 preprocessing to the images. diff --git a/src/opentau/policies/pi05/modeling_pi05.py b/src/opentau/policies/pi05/modeling_pi05.py index 8e6c8317..2e4e4cd2 100644 --- a/src/opentau/policies/pi05/modeling_pi05.py +++ b/src/opentau/policies/pi05/modeling_pi05.py @@ -44,7 +44,7 @@ PaliGemmaWithExpertModel, ) from opentau.policies.pretrained import PreTrainedPolicy, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -680,17 +680,26 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass to compute the loss. Args: batch: Batch of data containing environment observations, actions, and targets. noise: Optional noise tensor. time: Optional time tensor. + return_per_sample: When True, additionally return per-sample ``MSE_per_sample`` + / ``CE_per_sample`` (:class:`PerSampleLoss`) so the validation loop can + bucket the loss by ``(dataset, control_mode)`` provenance. The scalar + ``MSE``/``CE`` are unchanged, so the training path is unaffected. Returns: - A dictionary containing the loss components ("MSE" and "CE"). + A dictionary with the loss components ("MSE" and "CE"), plus + "MSE_per_sample"/"CE_per_sample" when ``return_per_sample`` is True. """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) @@ -732,12 +741,14 @@ def forward( discrete_action_masks, state=state, real_action_dim=batch.get("real_action_dim"), + return_per_sample=return_per_sample, ) - mse_loss = losses["MSE"] - ce_loss = losses["CE"] - - return {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] + return out def prepare_state(self, batch: dict[str, Tensor]) -> Tensor: """Prepares the continuous state tensor, padding or truncating to max_state_dim. @@ -1315,7 +1326,8 @@ def forward( discrete_action_masks: Tensor | None = None, state: Tensor | None = None, real_action_dim: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass and compute the loss. Args: @@ -1438,14 +1450,16 @@ def forward( # Shared masked-MSE reduction: AND-s frozen-prefix, timestep-pad, and # dim-pad masks together and divides by the unmasked-slot count. See # ``opentau.policies.utils.flow_matching_masked_mse`` for the full spec. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) # compute cross entropy loss for discrete actions batch_size, seq_len = discrete_actions.shape @@ -1467,6 +1481,11 @@ def forward( discrete_action_is_pad = ~discrete_action_masks # convert into format where value for pad is True discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + # Per-sample CE numerator/denominator over valid tokens, for the val breakdown. + discrete_action_ce_per_sample = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + # compute mean discrete_action_ce_loss = discrete_action_ce_loss.mean() @@ -1505,12 +1524,30 @@ def forward( # helps to control loss for response tokens in case of robotic data and VQA data response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice] + # Per-sample response CE (valid tokens only) for the val breakdown. + response_ce_per_sample = ( + ce_per_sample(response_ce_loss, ~response_is_pad[:, response_slice]) + if return_per_sample + else None + ) + # compute mean response_ce_loss = response_ce_loss.mean() else: response_ce_loss = torch.tensor(0.0, device=mse_loss.device) - - return {"MSE": mse_loss, "CE": discrete_action_ce_loss + response_ce_loss} + response_ce_per_sample = None + + out: dict[str, Tensor | PerSampleLoss] = { + "MSE": mse_loss, + "CE": discrete_action_ce_loss + response_ce_loss, + } + if return_per_sample: + ce_ps = discrete_action_ce_per_sample + if response_ce_per_sample is not None: + ce_ps = ce_ps + response_ce_per_sample + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_ps + return out def sample_actions( self, diff --git a/src/opentau/policies/pi05_mem/modeling_pi05.py b/src/opentau/policies/pi05_mem/modeling_pi05.py index 14c38f27..66bc620c 100644 --- a/src/opentau/policies/pi05_mem/modeling_pi05.py +++ b/src/opentau/policies/pi05_mem/modeling_pi05.py @@ -56,7 +56,7 @@ from opentau.policies.pi05_mem.configuration_pi05 import PI05MemConfig from opentau.policies.pi07.video_encoder import SpaceTimeSiglipVideoEncoder from opentau.policies.pretrained import PreTrainedPolicy, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -661,9 +661,19 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: - """Do a full training forward pass to compute the loss.""" + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: + """Do a full training forward pass to compute the loss. + + When ``return_per_sample`` is True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown; the scalar losses are + unchanged. + """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch), dataset_index)["actions"] @@ -696,12 +706,14 @@ def forward( discrete_action_masks, obs_history_is_pad=obs_history_is_pad, real_action_dim=batch.get("real_action_dim"), + return_per_sample=return_per_sample, ) - mse_loss = losses["MSE"] - ce_loss = losses["CE"] - - return {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] + return out def prepare_state(self, batch: dict[str, Tensor]) -> Tensor: """Prepares the temporal state tensor, padding or truncating to max_state_dim. @@ -1094,7 +1106,8 @@ def forward( discrete_action_masks: Tensor | None = None, obs_history_is_pad: Tensor | None = None, real_action_dim: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass and compute the loss.""" prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( videos, @@ -1178,14 +1191,16 @@ def forward( v_t = v_t.to(dtype=torch.float32) # Shared masked-MSE reduction; see pi05 for the rationale. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) assert discrete_actions is not None assert discrete_action_masks is not None @@ -1206,9 +1221,17 @@ def forward( discrete_action_is_pad = ~discrete_action_masks discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + ce_per_sample_loss = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + discrete_action_ce_loss = discrete_action_ce_loss.mean() - return {"MSE": mse_loss, "CE": discrete_action_ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": mse_loss, "CE": discrete_action_ce_loss} + if return_per_sample: + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_per_sample_loss + return out def sample_actions( self, diff --git a/src/opentau/policies/pi06/modeling_pi06.py b/src/opentau/policies/pi06/modeling_pi06.py index c6a3d272..d63ae53f 100644 --- a/src/opentau/policies/pi06/modeling_pi06.py +++ b/src/opentau/policies/pi06/modeling_pi06.py @@ -51,7 +51,7 @@ Gemma3WithExpertModel, ) from opentau.policies.pretrained import PreTrainedPolicy, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -587,9 +587,18 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: - """Full training forward pass. Returns `{"MSE": ..., "CE": ...}`.""" + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: + """Full training forward pass. Returns `{"MSE": ..., "CE": ...}`. + + When ``return_per_sample`` is True, also returns ``MSE_per_sample`` / + ``CE_per_sample`` (:class:`PerSampleLoss`) for the validation + per-(dataset, control_mode) breakdown; the scalar losses are unchanged. + """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch), dataset_index)["actions"] @@ -616,11 +625,14 @@ def forward( discrete_actions, discrete_action_masks, real_action_dim=batch.get("real_action_dim"), + return_per_sample=return_per_sample, ) - mse_loss = losses["MSE"] - ce_loss = losses["CE"] - return {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] + return out # Preprocessing helpers (state discretization, image resize, etc.) @@ -978,7 +990,8 @@ def forward( discrete_actions: Tensor | None = None, discrete_action_masks: Tensor | None = None, real_action_dim: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Full training forward pass. Returns `{"MSE": ..., "CE": ...}`.""" prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, @@ -1061,14 +1074,16 @@ def forward( v_t = self.action_out_proj(suffix_out) v_t = v_t.to(dtype=torch.float32) - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, max_action_dim=self.config.max_action_dim, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) # Discrete-action cross-entropy (FAST tokens) via the dedicated head. batch_size_da, seq_len = discrete_actions.shape @@ -1085,6 +1100,9 @@ def forward( ) discrete_action_is_pad = ~discrete_action_masks discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + discrete_action_ce_per_sample = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) discrete_action_ce_loss = discrete_action_ce_loss.mean() # Optional response-token cross-entropy (via Gemma 3's shared lm_head). @@ -1105,11 +1123,27 @@ def forward( ) response_is_pad = ~response_masks response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice] + response_ce_per_sample = ( + ce_per_sample(response_ce_loss, ~response_is_pad[:, response_slice]) + if return_per_sample + else None + ) response_ce_loss = response_ce_loss.mean() else: response_ce_loss = torch.tensor(0.0, device=mse_loss.device) - - return {"MSE": mse_loss, "CE": discrete_action_ce_loss + response_ce_loss} + response_ce_per_sample = None + + out: dict[str, Tensor | PerSampleLoss] = { + "MSE": mse_loss, + "CE": discrete_action_ce_loss + response_ce_loss, + } + if return_per_sample: + ce_ps = discrete_action_ce_per_sample + if response_ce_per_sample is not None: + ce_ps = ce_ps + response_ce_per_sample + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_ps + return out def _gemma3_lm_head(self): """Return the language-modeling head of the Gemma 3 backbone, regardless diff --git a/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py index ab3d72c5..ebda13f5 100644 --- a/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py @@ -63,7 +63,7 @@ ) from opentau.policies.pi07.video_encoder import SpaceTimeSiglipVideoEncoder from opentau.policies.pretrained import PreTrainedPolicy, ProjectionRemapError, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -859,8 +859,12 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor | list]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | list | PerSampleLoss]: """Training forward pass: normalize, prepare modalities, and compute losses. Returns a dict with ``"MSE"`` (flow-matching velocity loss) and @@ -872,9 +876,13 @@ def forward( batch: Training batch dict with observations, actions, and prompts. noise: Optional pre-sampled noise tensor. time: Optional pre-sampled flow-matching timesteps. + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. Scalars unchanged. Returns: - Dict with ``"MSE"`` and ``"CE"`` scalar loss tensors. + Dict with ``"MSE"`` and ``"CE"`` scalar loss tensors (plus per-sample + entries when ``return_per_sample`` is True). """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) @@ -923,12 +931,16 @@ def forward( response_masks=response_masks, real_action_dim=batch.get("real_action_dim"), group_index=dataset_index, + return_per_sample=return_per_sample, ) mse_loss = losses["MSE"] ce_loss = losses["CE"] - out: dict[str, Tensor | list] = {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | list | PerSampleLoss] = {"MSE": mse_loss, "CE": ce_loss} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] if outlier_records: out["outlier_records"] = outlier_records return out @@ -1916,7 +1928,8 @@ def forward( response_masks: Tensor | None = None, real_action_dim: Tensor | None = None, group_index: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Training forward pass: embed all modalities and compute losses. Runs the VLM on the prefix (video, language, response, state, subgoal @@ -2048,14 +2061,16 @@ def forward( v_t = v_t.to(dtype=torch.float32) # Shared masked-MSE reduction; see pi05 for the rationale. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) assert discrete_actions is not None assert discrete_action_masks is not None @@ -2076,9 +2091,17 @@ def forward( discrete_action_is_pad = ~discrete_action_masks discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + ce_per_sample_loss = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + discrete_action_ce_loss = discrete_action_ce_loss.mean() - return {"MSE": mse_loss, "CE": discrete_action_ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": mse_loss, "CE": discrete_action_ce_loss} + if return_per_sample: + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_per_sample_loss + return out def sample_actions( self, diff --git a/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py b/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py index 8ecd7173..db25b0ed 100644 --- a/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py @@ -52,7 +52,7 @@ PI07PaligemmaLowLevelConfig, ) from opentau.policies.pretrained import PreTrainedPolicy, ProjectionRemapError, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -1124,14 +1124,21 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor | list]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | list | PerSampleLoss]: """Do a full training forward pass to compute the loss. Args: batch: Batch of data containing environment observations, actions, and targets. noise: Optional noise tensor. time: Optional time tensor. + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. Scalars unchanged. Returns: A dictionary containing the loss components ("MSE" and "CE"). When the @@ -1170,12 +1177,16 @@ def forward( time=time, real_action_dim=batch.get("real_action_dim"), group_index=dataset_index, + return_per_sample=return_per_sample, ) mse_loss = losses["MSE"] ce_loss = losses["CE"] - out: dict[str, Tensor | list] = {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | list | PerSampleLoss] = {"MSE": mse_loss, "CE": ce_loss} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] if outlier_records: out["outlier_records"] = outlier_records return out @@ -2174,7 +2185,8 @@ def forward( time: Tensor | None = None, real_action_dim: Tensor | None = None, group_index: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass and compute the loss. Args: @@ -2281,14 +2293,16 @@ def forward( v_t = v_t.to(dtype=torch.float32) # Shared masked-MSE reduction; see pi05 for the rationale. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) # compute cross entropy loss for discrete actions batch_size, seq_len = discrete_actions.shape @@ -2310,10 +2324,18 @@ def forward( discrete_action_is_pad = ~discrete_action_masks # convert into format where value for pad is True discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + ce_per_sample_loss = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + # compute mean discrete_action_ce_loss = discrete_action_ce_loss.mean() - return {"MSE": mse_loss, "CE": discrete_action_ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": mse_loss, "CE": discrete_action_ce_loss} + if return_per_sample: + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_per_sample_loss + return out def _build_suffix_items(self, x_t: Tensor) -> list[ContextItem]: """Default suffix layout: a single ``"action"`` block for ``x_t``. diff --git a/src/opentau/policies/utils.py b/src/opentau/policies/utils.py index aff8d82d..818b634e 100644 --- a/src/opentau/policies/utils.py +++ b/src/opentau/policies/utils.py @@ -24,10 +24,11 @@ import logging from collections import deque +from dataclasses import dataclass import torch import torch.nn.functional as F # noqa: N812 -from einops import rearrange +from einops import rearrange, reduce from torch import Tensor, nn @@ -110,6 +111,29 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple: return tuple(output.shape) +@dataclass +class PerSampleLoss: + """Per-sample decomposition of a masked loss, with the batch dim kept. + + ``sum`` and ``count`` are both ``(B,)`` and hold, for each sample, the + summed unmasked loss and the number of unmasked slots that fed it. The + masked *mean* for any group of samples is ``Σsum / Σcount`` — carrying the + (numerator, denominator) pair rather than a per-sample mean is what lets a + caller (e.g. the validation loop) regroup samples by provenance and recover + an exact masked mean per group. Averaging per-sample means instead would + double-normalize and weight a 1-slot sample the same as a 200-slot one. + """ + + sum: Tensor + count: Tensor + + def __add__(self, other: "PerSampleLoss") -> "PerSampleLoss": + # Pool several loss components (e.g. discrete-action CE + response CE) + # into one (numerator, denominator); the pooled per-slot mean is then + # Σ(sum_i) / Σ(count_i) over the pooled slots. + return PerSampleLoss(sum=self.sum + other.sum, count=self.count + other.count) + + def make_action_dim_mask( real_action_dim: Tensor | None, max_action_dim: int, @@ -164,7 +188,8 @@ def flow_matching_masked_mse( prefix_mask: Tensor | None = None, actions_is_pad: Tensor | None = None, real_action_dim: Tensor | None = None, -) -> Tensor: + return_per_sample: bool = False, +) -> Tensor | tuple[Tensor, PerSampleLoss]: """Masked MSE for flow-matching velocity-field training. Shared across pi05, pi05_mem, pi06, pi07 (low_level), and pi07_paligemma @@ -195,9 +220,17 @@ def flow_matching_masked_mse( action chunk is padded (no real action target). ``None`` ⇒ all-False. real_action_dim: Optional long ``(B,)`` — real (pre-pad) action dim per sample. ``None`` ⇒ all-True (every dim is real). + return_per_sample: When True, additionally return a :class:`PerSampleLoss` + holding the per-sample ``(Σ over masked slots, #masked slots)`` so the + caller can regroup the loss by provenance. The scalar is computed + exactly as in the default path (bit-identical), so toggling this flag + never perturbs the training reduction. Returns: - Scalar tensor: masked mean of ``(u_t - v_t)**2`` over the unmasked slots. + Scalar tensor (masked mean of ``(u_t - v_t)**2`` over the unmasked slots) + when ``return_per_sample`` is False; otherwise ``(scalar, PerSampleLoss)`` + where the per-sample ``sum``/``count`` are over the same masked slots + (so each sample's mean is ``sum / count``). """ mse_loss = F.mse_loss(u_t, v_t, reduction="none") bsz, chunk_size = u_t.shape[:2] @@ -210,7 +243,44 @@ def flow_matching_masked_mse( mse_loss = mse_loss[:, :, :max_action_dim] dim_mask = make_action_dim_mask(real_action_dim, max_action_dim, batch_size=bsz, device=u_t.device) full_mask = postfix_mask & rearrange(dim_mask, "b d -> b 1 d") - return (mse_loss * full_mask).sum() / (full_mask.sum() + 1e-8) + masked = mse_loss * full_mask + scalar = masked.sum() / (full_mask.sum() + 1e-8) + if not return_per_sample: + return scalar + per_sample = PerSampleLoss( + sum=reduce(masked, "b c d -> b", "sum"), + count=reduce(full_mask.float(), "b c d -> b", "sum"), + ) + return scalar, per_sample + + +def ce_per_sample(masked_ce: Tensor, valid_mask: Tensor) -> PerSampleLoss: + """Per-sample numerator/denominator for a masked token cross-entropy. + + Policies compute their CE as ``F.cross_entropy(..., reduction="none")`` + reshaped to ``(B, S)`` and zeroed at pad positions, then reduce it with + ``.mean()`` to a scalar. This helper takes that same pad-zeroed ``(B, S)`` + tensor plus the per-token validity mask and returns the per-sample + ``(Σ over valid tokens, #valid tokens)``, so a caller can pool CE per + provenance group as ``Σsum / Σcount`` — the mean cross-entropy per valid + token. Multiple CE components (e.g. discrete-action + response) pool by + adding their :class:`PerSampleLoss` objects. + + Note this normalizes by *valid* token count, unlike the legacy scalar + ``.mean()`` which divides by the full ``B * S`` (pad slots included); the + per-group breakdown is therefore over valid tokens only. + + Args: + masked_ce: ``(B, S)`` cross-entropy already zeroed at pad positions. + valid_mask: ``(B, S)`` bool, True at non-pad (scored) tokens. + + Returns: + ``PerSampleLoss`` whose ``sum`` and ``count`` are ``(B,)``. + """ + return PerSampleLoss( + sum=reduce(masked_ce, "b s -> b", "sum"), + count=reduce(valid_mask.float(), "b s -> b", "sum"), + ) def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None: diff --git a/src/opentau/policies/value/modeling_value.py b/src/opentau/policies/value/modeling_value.py index 9493486a..945bad9a 100644 --- a/src/opentau/policies/value/modeling_value.py +++ b/src/opentau/policies/value/modeling_value.py @@ -33,7 +33,7 @@ from opentau.policies.normalize import Normalize from opentau.policies.pretrained import PreTrainedPolicy -from opentau.policies.utils import log_model_loading_keys +from opentau.policies.utils import PerSampleLoss, ce_per_sample, log_model_loading_keys from opentau.policies.value.configuration_value import ValueConfig from opentau.policies.value.siglip_gemma import ( SiglipGemmaValueConfig, @@ -365,14 +365,22 @@ def predict_value(self, batch: dict[str, Tensor]) -> Tensor: logits = self.model.get_value(images, img_masks, lang_tokens, lang_masks) return self.calculate_value(logits) - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | None]: + def forward( + self, batch: dict[str, Tensor], return_per_sample: bool = False + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass to compute the value loss. Args: batch: Dictionary containing observations and target values + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. ``MSE`` is a zero + stub here, so ``MSE_per_sample`` carries zero sum and count; the CE + pools the value-bin and response-token terms per sample. Returns: - Tuple of (loss_dict, None) where loss_dict contains the MSE loss + Dict with "MSE"/"CE"/"L1"/"Accuracy" (plus per-sample CE/MSE entries + when ``return_per_sample`` is True). """ # `ValueFunction` is a single-dataset policy (its `Normalize` was # built with `num_datasets=1`); `_resolve_dataset_index` defaults @@ -400,6 +408,11 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | # Mask CE loss if all action_is_pad are true. This is used for VQA dataset where we don't have actions tokens. value_ce_loss = value_ce_loss * (~diff_mask).float() + # Per-sample value-bin CE (one scored token per robotic sample) for the val breakdown. + value_ce_per_sample = ( + PerSampleLoss(sum=value_ce_loss, count=(~diff_mask).float()) if return_per_sample else None + ) + value_ce_loss = value_ce_loss.mean() l1_loss = F.l1_loss(values, batch["return_continuous"]) @@ -426,15 +439,31 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | # Mask response loss if all action_is_pad are true. This is used for Robotic dataset where we have at least one actions tokens. response_ce_loss = response_ce_loss * rearrange(diff_mask.float(), "b -> b 1") + # Per-sample response CE (valid VQA tokens) for the val breakdown. + response_ce_per_sample = ( + ce_per_sample( + response_ce_loss, + (~response_is_pad[:, response_slice]) & rearrange(diff_mask, "b -> b 1"), + ) + if return_per_sample + else None + ) + # compute mean response_ce_loss = response_ce_loss.mean() - return { + out: dict[str, Tensor | PerSampleLoss] = { "MSE": torch.zeros_like(value_ce_loss, requires_grad=False), "CE": value_ce_loss + response_ce_loss, "L1": l1_loss, "Accuracy": accuracy, } + if return_per_sample: + # MSE is a zero stub for the value head; emit zero sum/count. + zeros = torch.zeros(diff_mask.shape[0], device=diff_mask.device) + out["MSE_per_sample"] = PerSampleLoss(sum=zeros, count=zeros.clone()) + out["CE_per_sample"] = value_ce_per_sample + response_ce_per_sample + return out def prepare_images(self, batch): """Preprocesses images for the model. diff --git a/src/opentau/scripts/train.py b/src/opentau/scripts/train.py index e80d70ef..fc050040 100644 --- a/src/opentau/scripts/train.py +++ b/src/opentau/scripts/train.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import logging import os @@ -318,11 +319,11 @@ def _mixture_weighted_aggregate( """Mixture-weighted average of per-dataset validation metrics. Weights are taken from ``name_to_weight`` and renormalized over only the - names present in ``per_dataset_trackers`` (empty datasets are skipped - upstream by ``WeightedDatasetMixture.get_per_dataset_dataloaders`` and so - will be missing from the trackers). When the renormalization total is 0 - -- empty trackers, or all selected datasets have weight 0 -- every metric - is returned as ``0.0``. + keys present in ``per_dataset_trackers`` (groups with no validation samples + in the combined pass are simply absent from the trackers). When the + renormalization total is 0 -- empty trackers, or all selected groups have + weight 0 -- every metric is returned as ``0.0``. Keys are generic: the + validation path keys both trackers and weights by norm-head ``norm_key``. The aggregated metric keys are derived from the first tracker's meters. All per-dataset trackers share the same meter set because they're all @@ -331,11 +332,12 @@ def _mixture_weighted_aggregate( tracker's keys are representative. Args: - per_dataset_trackers: One ``MetricsTracker`` per non-empty validation - dataset, keyed by dataset name. - name_to_weight: Mapping from dataset name to its mixture weight (need - not be normalized; need not be a strict subset/superset of the - tracker keys, but must contain every tracker key). + per_dataset_trackers: One ``MetricsTracker`` per group present in the + validation pass, keyed by group identifier (the validation path keys + by norm-head ``norm_key``). + name_to_weight: Mapping from that same group identifier to its mixture + weight (need not be normalized; need not be a strict subset/superset + of the tracker keys, but must contain every tracker key). Returns: Dict mapping each metric name found on the trackers to its weighted @@ -358,6 +360,39 @@ def _mixture_weighted_aggregate( } +def _bucket_per_sample( + group_index: torch.Tensor, + mse_sum: torch.Tensor, + mse_count: torch.Tensor, + ce_sum: torch.Tensor, + ce_count: torch.Tensor, +) -> dict[int, dict[str, float]]: + """Bucket per-sample MSE/CE ``(sum, count)`` by an integer group key. + + All inputs are 1-D tensors of equal length: ``group_index`` is the per-sample + integer group (norm-head row, or per-dataset enumerate index), and the others + are the per-sample numerator/denominator from the policies' ``PerSampleLoss`` + outputs. Returns ``{group: {"mse_sum", "mse_count", "ce_sum", "ce_count"}}`` + summed over the samples in each group; the masked mean for a group is then + ``Σsum / Σcount``. + + Pure / CPU-only (no collectives), so the validation loop can call it on rank 0 + after gathering, and it is unit-testable without an accelerator. Disjoint groups + never cross-contaminate, which is exactly what lets a *mixed* batch (the combined + val pass) be disaggregated by provenance. + """ + buckets: dict[int, dict[str, float]] = {} + for g in torch.unique(group_index).tolist(): + mask = group_index == g + buckets[int(g)] = { + "mse_sum": float(mse_sum[mask].sum()), + "mse_count": float(mse_count[mask].sum()), + "ce_sum": float(ce_sum[mask].sum()), + "ce_count": float(ce_count[mask].sum()), + } + return buckets + + def _find_unused_params_from_env() -> bool: """Parse the ``FIND_UNUSED_PARAMS`` env var into a bool. @@ -641,19 +676,18 @@ def train(cfg: TrainPipelineConfig): if cfg.val_freq > 0: train_dataloader = train_dataset.get_dataloader() - # One DataLoader per underlying val dataset so we can report per-dataset - # validation losses. The aggregate is computed by averaging across all. - per_dataset_val_dataloaders = val_dataset.get_per_dataset_dataloaders() - val_names = list(per_dataset_val_dataloaders.keys()) - prepared = accelerator.prepare( - policy, - optimizer, - train_dataloader, - lr_scheduler, - *per_dataset_val_dataloaders.values(), + # A single combined validation pass over the whole mixture: under + # ``accelerator.prepare`` every rank's shard stays full even when individual + # subsets have fewer frames than ``world_size`` (the old one-loader-per-dataset + # path left most ranks idle on tiny subsets and stacked that idle time). The + # validation loop disaggregates metrics per (dataset, control_mode) from each + # sample's ``dataset_index`` / ``dataset_repo_id`` provenance instead. + val_dataloader = val_dataset.get_combined_val_dataloader() + policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, train_dataloader, lr_scheduler ) - policy, optimizer, train_dataloader, lr_scheduler = prepared[:4] - per_dataset_val_dataloaders = dict(zip(val_names, prepared[4:], strict=True)) + if val_dataloader is not None: + val_dataloader = accelerator.prepare(val_dataloader) else: train_dataloader = train_dataset.get_dataloader() policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -841,13 +875,10 @@ def _strip_norm_buffers_pre_save(models, weights, input_dir): accelerator.wait_for_everyone() - if is_val_step: + if is_val_step and val_dataloader is not None: policy.eval() def _make_val_tracker(current_step: int = step) -> MetricsTracker: - # ``l1_loss`` and ``accuracy`` are populated lazily below iff the - # policy's ``forward`` returns ``"L1"`` / ``"Accuracy"``. See - # ``update_policy`` for the symmetric training-side pattern. return MetricsTracker( cfg.batch_size * accelerator.num_processes, { @@ -858,92 +889,191 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker: initial_step=current_step, ) - per_dataset_trackers: dict[str, MetricsTracker] = { - name: _make_val_tracker() for name in per_dataset_val_dataloaders + # Per-(dataset, control_mode) breakdown needs per-sample losses. Action + # policies expose them via ``forward(..., return_per_sample=True)``; other + # policies (e.g. high-level planners) don't, so fall back to an + # aggregate-only pass for them. The single combined loader — and its + # full-shard parallelization — applies to every policy regardless. + unwrapped_policy = accelerator.unwrap_model(policy) + supports_per_sample = ( + "return_per_sample" in inspect.signature(unwrapped_policy.forward).parameters + ) + mse_w = cfg.loss_weighting["MSE"] + ce_w = cfg.loss_weighting["CE"] + outlier_threshold = getattr(unwrapped_policy.config, "warn_outlier_threshold", None) + name_to_index = val_dataset.meta.dataset_name_to_index + + # Rank-0 accumulators across the whole pass. Per-sample tensors are gathered + # every batch (a collective on all ranks) but only retained on the main + # process, then bucketed once at the end. + ps_chunks: dict[str, list[torch.Tensor]] = { + "mse_sum": [], + "mse_count": [], + "ce_sum": [], + "ce_count": [], + "norm_index": [], + "source_index": [], } + agg_tracker = _make_val_tracker() # only used on the fallback (no per-sample) path + optional_vals: dict[str, list[float]] = {"l1_loss": [], "accuracy": []} logging.info(f"Validation at step {step}...") with torch.no_grad(): - for ds_name, ds_loader in per_dataset_val_dataloaders.items(): - ds_tracker = per_dataset_trackers[ds_name] - for batch in ds_loader: - losses = policy.forward(batch) - outlier_records = losses.pop("outlier_records", []) - loss = ( - cfg.loss_weighting["MSE"] * losses["MSE"] - + cfg.loss_weighting["CE"] * losses["CE"] - ) + for batch in val_dataloader: + losses = ( + policy.forward(batch, return_per_sample=True) + if supports_per_sample + else policy.forward(batch) + ) + outlier_records = losses.pop("outlier_records", []) + + # Optional value-head metrics (scalars); aggregate-level only. Gather + # outside the per-sample branch so both paths handle them identically. + l1_val = ( + accelerator.gather_for_metrics(losses["L1"]).to(dtype=torch.float32).mean().item() + if "L1" in losses + else None + ) + accuracy_val = ( + accelerator.gather_for_metrics(losses["Accuracy"]) + .to(dtype=torch.float32) + .mean() + .item() + if "Accuracy" in losses + else None + ) - # Gather and average metrics across processes. ``L1`` / - # ``Accuracy`` are optional — see ``update_policy`` for - # the symmetric training-side gating rationale. - loss = accelerator.gather_for_metrics(loss).mean().item() - mse_loss = ( + if supports_per_sample: + mse_ps = losses["MSE_per_sample"] + ce_ps = losses["CE_per_sample"] + # Per-sample provenance keys: ``dataset_index`` is the dropout-immune + # norm-head row; ``source_index`` is the per-dataset enumerate index + # recovered from the (also dropout-immune) ``dataset_repo_id``. Both + # gathered as int tensors so the ragged last-batch de-pad stays + # row-aligned with the loss tensors (``gather_object`` would not). + norm_index = batch["dataset_index"].to(device=accelerator.device, dtype=torch.long) + source_index = torch.tensor( + [name_to_index[name] for name in batch["dataset_repo_id"]], + device=accelerator.device, + dtype=torch.long, + ) + gathered = { + "mse_sum": accelerator.gather_for_metrics(mse_ps.sum.to(dtype=torch.float32)), + "mse_count": accelerator.gather_for_metrics(mse_ps.count.to(dtype=torch.float32)), + "ce_sum": accelerator.gather_for_metrics(ce_ps.sum.to(dtype=torch.float32)), + "ce_count": accelerator.gather_for_metrics(ce_ps.count.to(dtype=torch.float32)), + "norm_index": accelerator.gather_for_metrics(norm_index), + "source_index": accelerator.gather_for_metrics(source_index), + } + else: + loss = mse_w * losses["MSE"] + ce_w * losses["CE"] + loss_val = accelerator.gather_for_metrics(loss).to(dtype=torch.float32).mean().item() + mse_val = ( accelerator.gather_for_metrics(losses["MSE"]) .to(dtype=torch.float32) .mean() .item() ) - ce_loss = ( + ce_val = ( accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item() ) - l1_loss = ( - accelerator.gather_for_metrics(losses["L1"]).to(dtype=torch.float32).mean().item() - if "L1" in losses - else None - ) - accuracy = ( - accelerator.gather_for_metrics(losses["Accuracy"]) - .to(dtype=torch.float32) - .mean() - .item() - if "Accuracy" in losses - else None - ) - # Warn about outliers on rank 0 too during eval (parity with training); see - # ``update_policy`` for the rank-uniform gating rationale. - outlier_threshold = getattr( - accelerator.unwrap_model(policy).config, "warn_outlier_threshold", None - ) - if outlier_threshold is not None and outlier_threshold > 0: - log_outlier_records_distributed(accelerator, outlier_records, outlier_threshold) - - if accelerator.is_main_process: - ds_tracker.loss = loss - ds_tracker.mse_loss = mse_loss - ds_tracker.ce_loss = ce_loss - _observe_optional(ds_tracker, "l1_loss", "val_l1_loss", ":.6f", l1_loss) - _observe_optional(ds_tracker, "accuracy", "val_accuracy", ":.3f", accuracy) + # Warn about outliers on rank 0 too during eval (parity with training); see + # ``update_policy`` for the rank-uniform gating rationale. + if outlier_threshold is not None and outlier_threshold > 0: + log_outlier_records_distributed(accelerator, outlier_records, outlier_threshold) + + if accelerator.is_main_process: + if supports_per_sample: + for key, value in gathered.items(): + ps_chunks[key].append(value.cpu()) + else: + agg_tracker.loss = loss_val + agg_tracker.mse_loss = mse_val + agg_tracker.ce_loss = ce_val + if l1_val is not None: + optional_vals["l1_loss"].append(l1_val) + if accuracy_val is not None: + optional_vals["accuracy"].append(accuracy_val) if accelerator.is_main_process: - for ds_name, ds_tracker in per_dataset_trackers.items(): - logging.info(f"Validation/{ds_name} {ds_tracker}") - ds_dict = ds_tracker.to_dict(use_avg=True) - accelerator.log({f"Validation/{ds_name}/Loss": ds_dict["loss"]}, step=step) - accelerator.log({f"Validation/{ds_name}/MSE Loss": ds_dict["mse_loss"]}, step=step) - accelerator.log({f"Validation/{ds_name}/CE Loss": ds_dict["ce_loss"]}, step=step) - if "l1_loss" in ds_tracker.metrics: - accelerator.log({f"Validation/{ds_name}/L1 Loss": ds_dict["l1_loss"]}, step=step) - if "accuracy" in ds_tracker.metrics: - accelerator.log({f"Validation/{ds_name}/Accuracy": ds_dict["accuracy"]}, step=step) - - # Mixture-weighted aggregate across the per-dataset trackers, so the - # overall scalar reflects the training mixture rather than being - # implicitly dominated by whichever val subset has the most batches. + meta = val_dataset.meta name_to_weight = dict( zip(val_dataset.dataset_names, val_dataset.dataset_weights, strict=True) ) - agg = _mixture_weighted_aggregate(per_dataset_trackers, name_to_weight) - logging.info(f"Validation/aggregate {agg}") - accelerator.log({"Validation/Loss": agg["loss"]}, step=step) - accelerator.log({"Validation/MSE Loss": agg["mse_loss"]}, step=step) - accelerator.log({"Validation/CE Loss": agg["ce_loss"]}, step=step) - if "l1_loss" in agg: - accelerator.log({"Validation/L1 Loss": agg["l1_loss"]}, step=step) - if "accuracy" in agg: - accelerator.log({"Validation/Accuracy": agg["accuracy"]}, step=step) + + if supports_per_sample and ps_chunks["norm_index"]: + norm_index = torch.cat(ps_chunks["norm_index"]) + source_index = torch.cat(ps_chunks["source_index"]) + mse_sum = torch.cat(ps_chunks["mse_sum"]) + mse_count = torch.cat(ps_chunks["mse_count"]) + ce_sum = torch.cat(ps_chunks["ce_sum"]) + ce_count = torch.cat(ps_chunks["ce_count"]) + norm_groups = _bucket_per_sample(norm_index, mse_sum, mse_count, ce_sum, ce_count) + source_groups = _bucket_per_sample(source_index, mse_sum, mse_count, ce_sum, ce_count) + + # Headline breakdown: one tracker per norm head (``robot::control_mode`` + # or a fallback dataset name), keyed by ``norm_key`` for the aggregate. + per_normkey_trackers: dict[str, MetricsTracker] = {} + for norm_row, group in norm_groups.items(): + norm_key = meta.norm_keys[norm_row] + mse_mean = group["mse_sum"] / (group["mse_count"] + 1e-8) + ce_mean = group["ce_sum"] / (group["ce_count"] + 1e-8) + loss_mean = mse_w * mse_mean + ce_w * ce_mean + tracker = _make_val_tracker() + tracker.mse_loss = mse_mean + tracker.ce_loss = ce_mean + tracker.loss = loss_mean + per_normkey_trackers[norm_key] = tracker + logging.info(f"Validation/{norm_key} {tracker}") + accelerator.log({f"Validation/{norm_key}/Loss": loss_mean}, step=step) + accelerator.log({f"Validation/{norm_key}/MSE Loss": mse_mean}, step=step) + accelerator.log({f"Validation/{norm_key}/CE Loss": ce_mean}, step=step) + + # Finer per-source-dataset breakdown (no info loss when datasets share + # a norm head). Logged-only; the aggregate stays per-norm-head. + for source_row, group in source_groups.items(): + ds_name = meta.dataset_names[source_row] + mse_mean = group["mse_sum"] / (group["mse_count"] + 1e-8) + ce_mean = group["ce_sum"] / (group["ce_count"] + 1e-8) + loss_mean = mse_w * mse_mean + ce_w * ce_mean + accelerator.log({f"Validation/by_dataset/{ds_name}/Loss": loss_mean}, step=step) + accelerator.log({f"Validation/by_dataset/{ds_name}/MSE Loss": mse_mean}, step=step) + accelerator.log({f"Validation/by_dataset/{ds_name}/CE Loss": ce_mean}, step=step) + + # Mixture-weighted aggregate over norm heads, so the overall scalar + # reflects the training mixture rather than being dominated by whichever + # head has the most val frames. Sum each head's member-dataset weights. + norm_key_weight: dict[str, float] = {} + for name, weight in name_to_weight.items(): + nk = meta.norm_keys[meta.dataset_to_norm_index[name]] + norm_key_weight[nk] = norm_key_weight.get(nk, 0.0) + weight + agg = _mixture_weighted_aggregate(per_normkey_trackers, norm_key_weight) + elif not supports_per_sample: + # Fallback: the combined-pass batch-mean aggregate; no per-group breakdown. + agg = agg_tracker.to_dict(use_avg=True) + else: + agg = {} + + if agg: + logging.info(f"Validation/aggregate {agg}") + accelerator.log({"Validation/Loss": agg["loss"]}, step=step) + accelerator.log({"Validation/MSE Loss": agg["mse_loss"]}, step=step) + accelerator.log({"Validation/CE Loss": agg["ce_loss"]}, step=step) + if optional_vals["l1_loss"]: + accelerator.log( + {"Validation/L1 Loss": sum(optional_vals["l1_loss"]) / len(optional_vals["l1_loss"])}, + step=step, + ) + if optional_vals["accuracy"]: + accelerator.log( + { + "Validation/Accuracy": sum(optional_vals["accuracy"]) + / len(optional_vals["accuracy"]) + }, + step=step, + ) # This barrier is probably necessary to ensure # other processes wait for the main process to finish saving diff --git a/tests/datasets/test_dataset_mixture.py b/tests/datasets/test_dataset_mixture.py index 55dbe500..69647a98 100644 --- a/tests/datasets/test_dataset_mixture.py +++ b/tests/datasets/test_dataset_mixture.py @@ -324,6 +324,34 @@ def test_get_dataloader_success(self, train_pipeline_config, datasets_factory): assert dataloader.batch_size == train_pipeline_config.batch_size assert dataloader.num_workers == train_pipeline_config.num_workers + def test_get_combined_val_dataloader(self, train_pipeline_config, datasets_factory): + """One deterministic sequential loader over the whole mixture (the + validation path): every sample is scored exactly once, in order. + + The loader serves ``concatenated_dataset`` — the same ``_TaggedDataset``- + wrapped mixture ``get_dataloader`` uses — so each sample carries the + ``dataset_index`` / ``dataset_repo_id`` provenance the validation loop + buckets on (the wrapper injection is covered by the integration tests + below). Coverage and order are asserted via the sampler (pure Python), + avoiding a fetch through the fake fixture's non-picklable worker dataset.""" + datasets = datasets_factory(2) + mixture = WeightedDatasetMixture(train_pipeline_config, datasets, [0.7, 0.3], 30.0) + + dataloader = mixture.get_combined_val_dataloader() + + assert dataloader is not None + assert dataloader.batch_size == train_pipeline_config.dataloader_batch_size + # Serves the tagged concat mixture (so provenance keys ride along). + assert dataloader.dataset is mixture.concatenated_dataset + # Sequential + deterministic: no weighted/random sampler, nothing dropped. + assert isinstance(dataloader.sampler, torch.utils.data.SequentialSampler) + assert dataloader.drop_last is False + + total = sum(len(ds) for ds in mixture.datasets) + assert len(dataloader.dataset) == total + # Every sample exactly once, in concat order — this is what the loader iterates. + assert list(dataloader.sampler) == list(range(total)) + @pytest.mark.slow # 1 sec def test_get_dataloader_zero_weights_error(self, train_pipeline_config, datasets_factory): """Test dataloader creation with zero weights raises error.""" diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index db1e0d43..2906c98a 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -41,6 +41,7 @@ pad_discrete_tokens, resize_with_pad, ) +from opentau.policies.utils import PerSampleLoss, ce_per_sample # Block-causal attention mask (pi05 / π0.6 prefix-LM pattern) @@ -239,6 +240,105 @@ def test_prefix_mask_excludes_frozen_steps(self): # PI06Config defaults + validators +class TestFlowMatchingMaskedMsePerSample: + """``return_per_sample=True`` exposes per-sample ``(sum, count)`` without + perturbing the scalar — the contract the validation per-(dataset, + control_mode) breakdown relies on. Bit-identical scalars keep training + determinism intact (CLAUDE.md hard rule).""" + + def test_scalar_is_bit_identical_to_default(self): + torch.manual_seed(0) + u_t = torch.randn(3, 8, 4) + v_t = torch.randn(3, 8, 4) + actions_is_pad = torch.zeros(3, 8, dtype=torch.bool) + scalar_only = flow_matching_masked_mse( + u_t=u_t, v_t=v_t, actions_is_pad=actions_is_pad, max_action_dim=4 + ) + scalar, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=4, + return_per_sample=True, + ) + # Same float ops as the default path ⇒ exactly equal, not just close. + assert torch.equal(scalar, scalar_only) + assert ps.sum.shape == (3,) + assert ps.count.shape == (3,) + + def test_regrouped_mean_matches_scalar(self): + torch.manual_seed(0) + u_t = torch.randn(3, 8, 4) + v_t = torch.randn(3, 8, 4) + actions_is_pad = torch.zeros(3, 8, dtype=torch.bool) + scalar, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=4, + return_per_sample=True, + ) + # Σsum / Σcount over the whole batch reproduces the masked mean exactly. + assert torch.isclose(ps.sum.sum() / (ps.count.sum() + 1e-8), scalar, atol=1e-6) + + def test_per_sample_count_excludes_padded_steps_and_dims(self): + torch.manual_seed(1) + u_t = torch.randn(2, 8, 6) + v_t = torch.randn(2, 8, 6) + # sample 0: all 8 steps real; sample 1: only first 4 steps real. + actions_is_pad = torch.tensor([[False] * 8, [False, False, False, False, True, True, True, True]]) + # sample 0: 4 real action dims; sample 1: all 6. + real_action_dim = torch.tensor([4, 6], dtype=torch.long) + _, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=6, + real_action_dim=real_action_dim, + return_per_sample=True, + ) + assert ps.count[0].item() == pytest.approx(8 * 4) + assert ps.count[1].item() == pytest.approx(4 * 6) + # Independent reference for sample 0's numerator (8 steps × dims [0:4]). + ref0 = ((u_t[0, :, :4] - v_t[0, :, :4]) ** 2).sum() + assert torch.isclose(ps.sum[0], ref0, atol=1e-5) + + def test_fully_padded_sample_has_zero_sum_and_count(self): + torch.manual_seed(2) + u_t = torch.randn(2, 8, 4) + v_t = torch.randn(2, 8, 4) + actions_is_pad = torch.zeros(2, 8, dtype=torch.bool) + actions_is_pad[1] = True # sample 1 fully padded ⇒ contributes nothing + _, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=4, + return_per_sample=True, + ) + assert ps.sum[1].item() == 0.0 + assert ps.count[1].item() == 0.0 + + +class TestCePerSample: + """``ce_per_sample`` / ``PerSampleLoss`` arithmetic used by the per-group CE.""" + + def test_sum_and_count_over_valid_tokens(self): + # (B=2, S=3): sample 0 has 2 valid tokens, sample 1 has 3. + ce = torch.tensor([[1.0, 2.0, 5.0], [0.5, 0.5, 1.0]]) + valid = torch.tensor([[True, True, False], [True, True, True]]) + ps = ce_per_sample(ce * valid, valid) + assert torch.equal(ps.sum, torch.tensor([3.0, 2.0])) + assert torch.equal(ps.count, torch.tensor([2.0, 3.0])) + + def test_pooling_two_components_adds_sums_and_counts(self): + a = PerSampleLoss(sum=torch.tensor([1.0, 2.0]), count=torch.tensor([1.0, 1.0])) + b = PerSampleLoss(sum=torch.tensor([3.0, 4.0]), count=torch.tensor([2.0, 2.0])) + pooled = a + b + assert torch.equal(pooled.sum, torch.tensor([4.0, 6.0])) + assert torch.equal(pooled.count, torch.tensor([3.0, 3.0])) + + class TestPI06Config: def test_defaults_match_pi06_spec(self): cfg = PI06Config() diff --git a/tests/scripts/test_train.py b/tests/scripts/test_train.py index b8a02823..4df39bbb 100644 --- a/tests/scripts/test_train.py +++ b/tests/scripts/test_train.py @@ -27,8 +27,10 @@ import accelerate import pytest +import torch from opentau.scripts.train import ( + _bucket_per_sample, _commit_wandb_step, _find_unused_params_from_env, _mixture_weighted_aggregate, @@ -261,6 +263,51 @@ def _make_partial_tracker(loss: float, mse: float, ce: float) -> MetricsTracker: assert agg["ce_loss"] == pytest.approx(0.25 * 3.0 + 0.75 * 7.0) +class TestBucketPerSample: + """``_bucket_per_sample`` disaggregates a *mixed* batch by integer group key — + the core of the per-(dataset, control_mode) validation breakdown (issue #373). + Grouping on the dropout-immune ``dataset_index`` is what makes a mixed batch + decomposable; this is pure CPU arithmetic with no accelerator.""" + + def test_disaggregates_mixed_batch_without_cross_contamination(self): + group = torch.tensor([0, 1, 0, 1, 2]) + mse_sum = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + mse_count = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) + ce_sum = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0]) + ce_count = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0]) + buckets = _bucket_per_sample(group, mse_sum, mse_count, ce_sum, ce_count) + assert set(buckets) == {0, 1, 2} + # group 0 = samples {0, 2}; groups 1/2 must not leak in. + assert buckets[0]["mse_sum"] == pytest.approx(4.0) + assert buckets[0]["mse_count"] == pytest.approx(2.0) + assert buckets[0]["ce_sum"] == pytest.approx(40.0) + assert buckets[0]["ce_count"] == pytest.approx(4.0) + assert buckets[1]["mse_sum"] == pytest.approx(6.0) + assert buckets[2]["mse_sum"] == pytest.approx(5.0) + + def test_group_mean_is_sum_over_count_not_mean_of_means(self): + # Heterogeneous counts: the regrouped mean must be Σsum/Σcount, so a + # 3-slot sample outweighs a 1-slot sample (mean-of-means would not). + group = torch.tensor([0, 0]) + s = torch.tensor([2.0, 9.0]) + c = torch.tensor([1.0, 3.0]) + buckets = _bucket_per_sample(group, s, c, s, c) + mean = buckets[0]["mse_sum"] / buckets[0]["mse_count"] + assert mean == pytest.approx(11.0 / 4.0) + + def test_zero_count_group_carries_zeros_for_caller_guard(self): + # pi0 emits zero-count CE; the bucket keeps 0/0 and the caller's +1e-8 + # turns it into a 0 mean (no NaN). + group = torch.tensor([0, 1]) + mse = torch.tensor([1.0, 2.0]) + cnt = torch.tensor([1.0, 1.0]) + zeros = torch.tensor([0.0, 0.0]) + buckets = _bucket_per_sample(group, mse, cnt, zeros, zeros) + assert buckets[0]["ce_sum"] == 0.0 + assert buckets[0]["ce_count"] == 0.0 + assert buckets[0]["ce_sum"] / (buckets[0]["ce_count"] + 1e-8) == pytest.approx(0.0) + + class TestCommitWandbStep: """``_commit_wandb_step`` issues exactly one empty, commit-tagged log.""" From 8559365e5173c98b7abf03634a9d387174a90979 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Tue, 2 Jun 2026 08:40:33 -0700 Subject: [PATCH 2/2] refactor(train): batch per-sample val gathers into one gather call MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Gather the per-sample MSE/CE (sum, count) and the provenance index tensors (norm_index, source_index) in a single accelerator.gather_for_metrics call so the ragged last-batch de-pad is applied once and identically to every entry — losses and their provenance stay row-aligned by construction rather than relying on separate de-pads landing on the same trim. Addresses a review note. --- src/opentau/scripts/train.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/opentau/scripts/train.py b/src/opentau/scripts/train.py index fc050040..3d14f26a 100644 --- a/src/opentau/scripts/train.py +++ b/src/opentau/scripts/train.py @@ -958,14 +958,20 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker: device=accelerator.device, dtype=torch.long, ) - gathered = { - "mse_sum": accelerator.gather_for_metrics(mse_ps.sum.to(dtype=torch.float32)), - "mse_count": accelerator.gather_for_metrics(mse_ps.count.to(dtype=torch.float32)), - "ce_sum": accelerator.gather_for_metrics(ce_ps.sum.to(dtype=torch.float32)), - "ce_count": accelerator.gather_for_metrics(ce_ps.count.to(dtype=torch.float32)), - "norm_index": accelerator.gather_for_metrics(norm_index), - "source_index": accelerator.gather_for_metrics(source_index), - } + # Gather all per-sample tensors in a SINGLE call so accelerate applies + # one identical ragged-last-batch de-pad to every entry: the per-sample + # losses and their provenance indices stay row-aligned by construction, + # rather than relying on six separate de-pads landing on the same trim. + gathered = accelerator.gather_for_metrics( + { + "mse_sum": mse_ps.sum.to(dtype=torch.float32), + "mse_count": mse_ps.count.to(dtype=torch.float32), + "ce_sum": ce_ps.sum.to(dtype=torch.float32), + "ce_count": ce_ps.count.to(dtype=torch.float32), + "norm_index": norm_index, + "source_index": source_index, + } + ) else: loss = mse_w * losses["MSE"] + ce_w * losses["CE"] loss_val = accelerator.gather_for_metrics(loss).to(dtype=torch.float32).mean().item()