diff --git a/src/opentau/datasets/dataset_mixture.py b/src/opentau/datasets/dataset_mixture.py index db72342..29c4a12 100644 --- a/src/opentau/datasets/dataset_mixture.py +++ b/src/opentau/datasets/dataset_mixture.py @@ -975,3 +975,46 @@ def get_per_dataset_dataloaders(self) -> dict[str, DataLoader]: worker_init_fn=worker_init_fn, ) return loaders + + def get_combined_val_dataloader(self) -> DataLoader | None: + """Create one deterministic sequential DataLoader over the whole mixture. + + Unlike :meth:`get_per_dataset_dataloaders` (one loader per dataset), this + returns a single loader over the concatenated mixture so that, under + ``accelerator.prepare``, every rank's shard is full even when individual + validation subsets have fewer frames than ``world_size`` — the per-dataset + loaders leave most ranks idle on tiny subsets and stack that idle time + across datasets. Each sample still carries its ``dataset_index`` / + ``dataset_repo_id`` provenance (injected by ``_TaggedDataset``), so the + validation loop can disaggregate metrics per ``(dataset, control_mode)`` + from the batch rather than relying on homogeneous per-dataset loaders. + ``shuffle=False`` + ``drop_last=False`` make the pass order-deterministic + (seed-independent) and score every sample exactly once. + + Returns: + A single ``DataLoader`` over the mixture, or ``None`` when the mixture + is empty (mirrors the empty-dataset skip in + :meth:`get_per_dataset_dataloaders`). + """ + if len(self.concatenated_dataset) == 0: + logging.info("Combined validation DataLoader skipped: the mixture is empty.") + return None + + worker_name_mapping_overrides = self._get_worker_name_mapping_overrides() + worker_init_fn = None + if worker_name_mapping_overrides: + worker_init_fn = functools.partial( + _apply_data_feature_name_mapping_overrides, + mapping_overrides=worker_name_mapping_overrides, + ) + + return DataLoader( + self.concatenated_dataset, + batch_size=self.cfg.dataloader_batch_size, + shuffle=False, + num_workers=self.cfg.num_workers, + pin_memory=torch.cuda.is_available(), + drop_last=False, + prefetch_factor=self.cfg.prefetch_factor, + worker_init_fn=worker_init_fn, + ) diff --git a/src/opentau/policies/pi0/modeling_pi0.py b/src/opentau/policies/pi0/modeling_pi0.py index 6b81451..85850f4 100644 --- a/src/opentau/policies/pi0/modeling_pi0.py +++ b/src/opentau/policies/pi0/modeling_pi0.py @@ -25,7 +25,7 @@ import torch import torch.nn.functional as F # noqa: N812 -from einops import rearrange +from einops import rearrange, reduce from torch import Tensor, nn from transformers import AutoTokenizer @@ -37,7 +37,7 @@ PaliGemmaWithExpertModel, ) from opentau.policies.pretrained import PreTrainedPolicy -from opentau.policies.utils import log_model_loading_keys, make_action_dim_mask +from opentau.policies.utils import PerSampleLoss, log_model_loading_keys, make_action_dim_mask from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -466,14 +466,22 @@ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass to compute the loss. Args: batch: Batch of data containing environment observations, actions, and targets. noise: Optional noise tensor. time: Optional time tensor. + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. ``CE`` is a zero + stub for pi0, so ``CE_per_sample`` carries zero sum and count. Returns: A dictionary containing the loss components ("MSE" and "CE"). @@ -537,7 +545,21 @@ def forward( loss = losses.sum() / (full_mask.sum() + 1e-8) - return {"MSE": loss, "CE": torch.zeros_like(loss, requires_grad=True)} + out: dict[str, Tensor | PerSampleLoss] = { + "MSE": loss, + "CE": torch.zeros_like(loss, requires_grad=True), + } + if return_per_sample: + # ``losses`` is already masked (and AWR-weighted, if enabled), matching + # the scalar reduction; reduce over (chunk, dim) keeping the batch axis. + out["MSE_per_sample"] = PerSampleLoss( + sum=reduce(losses, "b c d -> b", "sum"), + count=reduce(full_mask.float(), "b c d -> b", "sum"), + ) + # pi0 has no CE term; emit zero sum/count so it forms no CE groups. + zeros = torch.zeros(losses.shape[0], device=losses.device) + out["CE_per_sample"] = PerSampleLoss(sum=zeros, count=zeros.clone()) + return out def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: """Apply Pi0 preprocessing to the images. diff --git a/src/opentau/policies/pi05/modeling_pi05.py b/src/opentau/policies/pi05/modeling_pi05.py index 8e6c831..2e4e4cd 100644 --- a/src/opentau/policies/pi05/modeling_pi05.py +++ b/src/opentau/policies/pi05/modeling_pi05.py @@ -44,7 +44,7 @@ PaliGemmaWithExpertModel, ) from opentau.policies.pretrained import PreTrainedPolicy, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -680,17 +680,26 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass to compute the loss. Args: batch: Batch of data containing environment observations, actions, and targets. noise: Optional noise tensor. time: Optional time tensor. + return_per_sample: When True, additionally return per-sample ``MSE_per_sample`` + / ``CE_per_sample`` (:class:`PerSampleLoss`) so the validation loop can + bucket the loss by ``(dataset, control_mode)`` provenance. The scalar + ``MSE``/``CE`` are unchanged, so the training path is unaffected. Returns: - A dictionary containing the loss components ("MSE" and "CE"). + A dictionary with the loss components ("MSE" and "CE"), plus + "MSE_per_sample"/"CE_per_sample" when ``return_per_sample`` is True. """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) @@ -732,12 +741,14 @@ def forward( discrete_action_masks, state=state, real_action_dim=batch.get("real_action_dim"), + return_per_sample=return_per_sample, ) - mse_loss = losses["MSE"] - ce_loss = losses["CE"] - - return {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] + return out def prepare_state(self, batch: dict[str, Tensor]) -> Tensor: """Prepares the continuous state tensor, padding or truncating to max_state_dim. @@ -1315,7 +1326,8 @@ def forward( discrete_action_masks: Tensor | None = None, state: Tensor | None = None, real_action_dim: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass and compute the loss. Args: @@ -1438,14 +1450,16 @@ def forward( # Shared masked-MSE reduction: AND-s frozen-prefix, timestep-pad, and # dim-pad masks together and divides by the unmasked-slot count. See # ``opentau.policies.utils.flow_matching_masked_mse`` for the full spec. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) # compute cross entropy loss for discrete actions batch_size, seq_len = discrete_actions.shape @@ -1467,6 +1481,11 @@ def forward( discrete_action_is_pad = ~discrete_action_masks # convert into format where value for pad is True discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + # Per-sample CE numerator/denominator over valid tokens, for the val breakdown. + discrete_action_ce_per_sample = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + # compute mean discrete_action_ce_loss = discrete_action_ce_loss.mean() @@ -1505,12 +1524,30 @@ def forward( # helps to control loss for response tokens in case of robotic data and VQA data response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice] + # Per-sample response CE (valid tokens only) for the val breakdown. + response_ce_per_sample = ( + ce_per_sample(response_ce_loss, ~response_is_pad[:, response_slice]) + if return_per_sample + else None + ) + # compute mean response_ce_loss = response_ce_loss.mean() else: response_ce_loss = torch.tensor(0.0, device=mse_loss.device) - - return {"MSE": mse_loss, "CE": discrete_action_ce_loss + response_ce_loss} + response_ce_per_sample = None + + out: dict[str, Tensor | PerSampleLoss] = { + "MSE": mse_loss, + "CE": discrete_action_ce_loss + response_ce_loss, + } + if return_per_sample: + ce_ps = discrete_action_ce_per_sample + if response_ce_per_sample is not None: + ce_ps = ce_ps + response_ce_per_sample + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_ps + return out def sample_actions( self, diff --git a/src/opentau/policies/pi05_mem/modeling_pi05.py b/src/opentau/policies/pi05_mem/modeling_pi05.py index 14c38f2..66bc620 100644 --- a/src/opentau/policies/pi05_mem/modeling_pi05.py +++ b/src/opentau/policies/pi05_mem/modeling_pi05.py @@ -56,7 +56,7 @@ from opentau.policies.pi05_mem.configuration_pi05 import PI05MemConfig from opentau.policies.pi07.video_encoder import SpaceTimeSiglipVideoEncoder from opentau.policies.pretrained import PreTrainedPolicy, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -661,9 +661,19 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: - """Do a full training forward pass to compute the loss.""" + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: + """Do a full training forward pass to compute the loss. + + When ``return_per_sample`` is True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown; the scalar losses are + unchanged. + """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch), dataset_index)["actions"] @@ -696,12 +706,14 @@ def forward( discrete_action_masks, obs_history_is_pad=obs_history_is_pad, real_action_dim=batch.get("real_action_dim"), + return_per_sample=return_per_sample, ) - mse_loss = losses["MSE"] - ce_loss = losses["CE"] - - return {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] + return out def prepare_state(self, batch: dict[str, Tensor]) -> Tensor: """Prepares the temporal state tensor, padding or truncating to max_state_dim. @@ -1094,7 +1106,8 @@ def forward( discrete_action_masks: Tensor | None = None, obs_history_is_pad: Tensor | None = None, real_action_dim: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass and compute the loss.""" prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( videos, @@ -1178,14 +1191,16 @@ def forward( v_t = v_t.to(dtype=torch.float32) # Shared masked-MSE reduction; see pi05 for the rationale. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) assert discrete_actions is not None assert discrete_action_masks is not None @@ -1206,9 +1221,17 @@ def forward( discrete_action_is_pad = ~discrete_action_masks discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + ce_per_sample_loss = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + discrete_action_ce_loss = discrete_action_ce_loss.mean() - return {"MSE": mse_loss, "CE": discrete_action_ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": mse_loss, "CE": discrete_action_ce_loss} + if return_per_sample: + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_per_sample_loss + return out def sample_actions( self, diff --git a/src/opentau/policies/pi06/modeling_pi06.py b/src/opentau/policies/pi06/modeling_pi06.py index c6a3d27..d63ae53 100644 --- a/src/opentau/policies/pi06/modeling_pi06.py +++ b/src/opentau/policies/pi06/modeling_pi06.py @@ -51,7 +51,7 @@ Gemma3WithExpertModel, ) from opentau.policies.pretrained import PreTrainedPolicy, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -587,9 +587,18 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor]: - """Full training forward pass. Returns `{"MSE": ..., "CE": ...}`.""" + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: + """Full training forward pass. Returns `{"MSE": ..., "CE": ...}`. + + When ``return_per_sample`` is True, also returns ``MSE_per_sample`` / + ``CE_per_sample`` (:class:`PerSampleLoss`) for the validation + per-(dataset, control_mode) breakdown; the scalar losses are unchanged. + """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch), dataset_index)["actions"] @@ -616,11 +625,14 @@ def forward( discrete_actions, discrete_action_masks, real_action_dim=batch.get("real_action_dim"), + return_per_sample=return_per_sample, ) - mse_loss = losses["MSE"] - ce_loss = losses["CE"] - return {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] + return out # Preprocessing helpers (state discretization, image resize, etc.) @@ -978,7 +990,8 @@ def forward( discrete_actions: Tensor | None = None, discrete_action_masks: Tensor | None = None, real_action_dim: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Full training forward pass. Returns `{"MSE": ..., "CE": ...}`.""" prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, @@ -1061,14 +1074,16 @@ def forward( v_t = self.action_out_proj(suffix_out) v_t = v_t.to(dtype=torch.float32) - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, max_action_dim=self.config.max_action_dim, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) # Discrete-action cross-entropy (FAST tokens) via the dedicated head. batch_size_da, seq_len = discrete_actions.shape @@ -1085,6 +1100,9 @@ def forward( ) discrete_action_is_pad = ~discrete_action_masks discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + discrete_action_ce_per_sample = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) discrete_action_ce_loss = discrete_action_ce_loss.mean() # Optional response-token cross-entropy (via Gemma 3's shared lm_head). @@ -1105,11 +1123,27 @@ def forward( ) response_is_pad = ~response_masks response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice] + response_ce_per_sample = ( + ce_per_sample(response_ce_loss, ~response_is_pad[:, response_slice]) + if return_per_sample + else None + ) response_ce_loss = response_ce_loss.mean() else: response_ce_loss = torch.tensor(0.0, device=mse_loss.device) - - return {"MSE": mse_loss, "CE": discrete_action_ce_loss + response_ce_loss} + response_ce_per_sample = None + + out: dict[str, Tensor | PerSampleLoss] = { + "MSE": mse_loss, + "CE": discrete_action_ce_loss + response_ce_loss, + } + if return_per_sample: + ce_ps = discrete_action_ce_per_sample + if response_ce_per_sample is not None: + ce_ps = ce_ps + response_ce_per_sample + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_ps + return out def _gemma3_lm_head(self): """Return the language-modeling head of the Gemma 3 backbone, regardless diff --git a/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py b/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py index ab3d72c..ebda13f 100644 --- a/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07/low_level/modeling_pi07_low_level.py @@ -63,7 +63,7 @@ ) from opentau.policies.pi07.video_encoder import SpaceTimeSiglipVideoEncoder from opentau.policies.pretrained import PreTrainedPolicy, ProjectionRemapError, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -859,8 +859,12 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor | list]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | list | PerSampleLoss]: """Training forward pass: normalize, prepare modalities, and compute losses. Returns a dict with ``"MSE"`` (flow-matching velocity loss) and @@ -872,9 +876,13 @@ def forward( batch: Training batch dict with observations, actions, and prompts. noise: Optional pre-sampled noise tensor. time: Optional pre-sampled flow-matching timesteps. + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. Scalars unchanged. Returns: - Dict with ``"MSE"`` and ``"CE"`` scalar loss tensors. + Dict with ``"MSE"`` and ``"CE"`` scalar loss tensors (plus per-sample + entries when ``return_per_sample`` is True). """ dataset_index = self._resolve_dataset_index(batch) batch = self.normalize_inputs(batch, dataset_index) @@ -923,12 +931,16 @@ def forward( response_masks=response_masks, real_action_dim=batch.get("real_action_dim"), group_index=dataset_index, + return_per_sample=return_per_sample, ) mse_loss = losses["MSE"] ce_loss = losses["CE"] - out: dict[str, Tensor | list] = {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | list | PerSampleLoss] = {"MSE": mse_loss, "CE": ce_loss} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] if outlier_records: out["outlier_records"] = outlier_records return out @@ -1916,7 +1928,8 @@ def forward( response_masks: Tensor | None = None, real_action_dim: Tensor | None = None, group_index: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Training forward pass: embed all modalities and compute losses. Runs the VLM on the prefix (video, language, response, state, subgoal @@ -2048,14 +2061,16 @@ def forward( v_t = v_t.to(dtype=torch.float32) # Shared masked-MSE reduction; see pi05 for the rationale. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) assert discrete_actions is not None assert discrete_action_masks is not None @@ -2076,9 +2091,17 @@ def forward( discrete_action_is_pad = ~discrete_action_masks discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + ce_per_sample_loss = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + discrete_action_ce_loss = discrete_action_ce_loss.mean() - return {"MSE": mse_loss, "CE": discrete_action_ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": mse_loss, "CE": discrete_action_ce_loss} + if return_per_sample: + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_per_sample_loss + return out def sample_actions( self, diff --git a/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py b/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py index 8ecd717..db25b0e 100644 --- a/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py +++ b/src/opentau/policies/pi07_paligemma/low_level/modeling_pi07_low_level.py @@ -52,7 +52,7 @@ PI07PaligemmaLowLevelConfig, ) from opentau.policies.pretrained import PreTrainedPolicy, ProjectionRemapError, T -from opentau.policies.utils import flow_matching_masked_mse +from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse from opentau.utils.accelerate_utils import get_proc_accelerator from opentau.utils.utils import get_safe_dtype @@ -1124,14 +1124,21 @@ def sample_actions( return actions def forward( - self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None - ) -> dict[str, Tensor | list]: + self, + batch: dict[str, Tensor], + noise: Tensor | None = None, + time: Tensor | None = None, + return_per_sample: bool = False, + ) -> dict[str, Tensor | list | PerSampleLoss]: """Do a full training forward pass to compute the loss. Args: batch: Batch of data containing environment observations, actions, and targets. noise: Optional noise tensor. time: Optional time tensor. + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. Scalars unchanged. Returns: A dictionary containing the loss components ("MSE" and "CE"). When the @@ -1170,12 +1177,16 @@ def forward( time=time, real_action_dim=batch.get("real_action_dim"), group_index=dataset_index, + return_per_sample=return_per_sample, ) mse_loss = losses["MSE"] ce_loss = losses["CE"] - out: dict[str, Tensor | list] = {"MSE": mse_loss, "CE": ce_loss} + out: dict[str, Tensor | list | PerSampleLoss] = {"MSE": mse_loss, "CE": ce_loss} + if return_per_sample: + out["MSE_per_sample"] = losses["MSE_per_sample"] + out["CE_per_sample"] = losses["CE_per_sample"] if outlier_records: out["outlier_records"] = outlier_records return out @@ -2174,7 +2185,8 @@ def forward( time: Tensor | None = None, real_action_dim: Tensor | None = None, group_index: Tensor | None = None, - ) -> dict[str, Tensor]: + return_per_sample: bool = False, + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass and compute the loss. Args: @@ -2281,14 +2293,16 @@ def forward( v_t = v_t.to(dtype=torch.float32) # Shared masked-MSE reduction; see pi05 for the rationale. - mse_loss = flow_matching_masked_mse( + mse_result = flow_matching_masked_mse( u_t=u_t, v_t=v_t, max_action_dim=self.config.max_action_dim, prefix_mask=prefix_mask, actions_is_pad=actions_is_pad, real_action_dim=real_action_dim, + return_per_sample=return_per_sample, ) + mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None) # compute cross entropy loss for discrete actions batch_size, seq_len = discrete_actions.shape @@ -2310,10 +2324,18 @@ def forward( discrete_action_is_pad = ~discrete_action_masks # convert into format where value for pad is True discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad + ce_per_sample_loss = ( + ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None + ) + # compute mean discrete_action_ce_loss = discrete_action_ce_loss.mean() - return {"MSE": mse_loss, "CE": discrete_action_ce_loss} + out: dict[str, Tensor | PerSampleLoss] = {"MSE": mse_loss, "CE": discrete_action_ce_loss} + if return_per_sample: + out["MSE_per_sample"] = mse_per_sample + out["CE_per_sample"] = ce_per_sample_loss + return out def _build_suffix_items(self, x_t: Tensor) -> list[ContextItem]: """Default suffix layout: a single ``"action"`` block for ``x_t``. diff --git a/src/opentau/policies/utils.py b/src/opentau/policies/utils.py index aff8d82..818b634 100644 --- a/src/opentau/policies/utils.py +++ b/src/opentau/policies/utils.py @@ -24,10 +24,11 @@ import logging from collections import deque +from dataclasses import dataclass import torch import torch.nn.functional as F # noqa: N812 -from einops import rearrange +from einops import rearrange, reduce from torch import Tensor, nn @@ -110,6 +111,29 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple: return tuple(output.shape) +@dataclass +class PerSampleLoss: + """Per-sample decomposition of a masked loss, with the batch dim kept. + + ``sum`` and ``count`` are both ``(B,)`` and hold, for each sample, the + summed unmasked loss and the number of unmasked slots that fed it. The + masked *mean* for any group of samples is ``Σsum / Σcount`` — carrying the + (numerator, denominator) pair rather than a per-sample mean is what lets a + caller (e.g. the validation loop) regroup samples by provenance and recover + an exact masked mean per group. Averaging per-sample means instead would + double-normalize and weight a 1-slot sample the same as a 200-slot one. + """ + + sum: Tensor + count: Tensor + + def __add__(self, other: "PerSampleLoss") -> "PerSampleLoss": + # Pool several loss components (e.g. discrete-action CE + response CE) + # into one (numerator, denominator); the pooled per-slot mean is then + # Σ(sum_i) / Σ(count_i) over the pooled slots. + return PerSampleLoss(sum=self.sum + other.sum, count=self.count + other.count) + + def make_action_dim_mask( real_action_dim: Tensor | None, max_action_dim: int, @@ -164,7 +188,8 @@ def flow_matching_masked_mse( prefix_mask: Tensor | None = None, actions_is_pad: Tensor | None = None, real_action_dim: Tensor | None = None, -) -> Tensor: + return_per_sample: bool = False, +) -> Tensor | tuple[Tensor, PerSampleLoss]: """Masked MSE for flow-matching velocity-field training. Shared across pi05, pi05_mem, pi06, pi07 (low_level), and pi07_paligemma @@ -195,9 +220,17 @@ def flow_matching_masked_mse( action chunk is padded (no real action target). ``None`` ⇒ all-False. real_action_dim: Optional long ``(B,)`` — real (pre-pad) action dim per sample. ``None`` ⇒ all-True (every dim is real). + return_per_sample: When True, additionally return a :class:`PerSampleLoss` + holding the per-sample ``(Σ over masked slots, #masked slots)`` so the + caller can regroup the loss by provenance. The scalar is computed + exactly as in the default path (bit-identical), so toggling this flag + never perturbs the training reduction. Returns: - Scalar tensor: masked mean of ``(u_t - v_t)**2`` over the unmasked slots. + Scalar tensor (masked mean of ``(u_t - v_t)**2`` over the unmasked slots) + when ``return_per_sample`` is False; otherwise ``(scalar, PerSampleLoss)`` + where the per-sample ``sum``/``count`` are over the same masked slots + (so each sample's mean is ``sum / count``). """ mse_loss = F.mse_loss(u_t, v_t, reduction="none") bsz, chunk_size = u_t.shape[:2] @@ -210,7 +243,44 @@ def flow_matching_masked_mse( mse_loss = mse_loss[:, :, :max_action_dim] dim_mask = make_action_dim_mask(real_action_dim, max_action_dim, batch_size=bsz, device=u_t.device) full_mask = postfix_mask & rearrange(dim_mask, "b d -> b 1 d") - return (mse_loss * full_mask).sum() / (full_mask.sum() + 1e-8) + masked = mse_loss * full_mask + scalar = masked.sum() / (full_mask.sum() + 1e-8) + if not return_per_sample: + return scalar + per_sample = PerSampleLoss( + sum=reduce(masked, "b c d -> b", "sum"), + count=reduce(full_mask.float(), "b c d -> b", "sum"), + ) + return scalar, per_sample + + +def ce_per_sample(masked_ce: Tensor, valid_mask: Tensor) -> PerSampleLoss: + """Per-sample numerator/denominator for a masked token cross-entropy. + + Policies compute their CE as ``F.cross_entropy(..., reduction="none")`` + reshaped to ``(B, S)`` and zeroed at pad positions, then reduce it with + ``.mean()`` to a scalar. This helper takes that same pad-zeroed ``(B, S)`` + tensor plus the per-token validity mask and returns the per-sample + ``(Σ over valid tokens, #valid tokens)``, so a caller can pool CE per + provenance group as ``Σsum / Σcount`` — the mean cross-entropy per valid + token. Multiple CE components (e.g. discrete-action + response) pool by + adding their :class:`PerSampleLoss` objects. + + Note this normalizes by *valid* token count, unlike the legacy scalar + ``.mean()`` which divides by the full ``B * S`` (pad slots included); the + per-group breakdown is therefore over valid tokens only. + + Args: + masked_ce: ``(B, S)`` cross-entropy already zeroed at pad positions. + valid_mask: ``(B, S)`` bool, True at non-pad (scored) tokens. + + Returns: + ``PerSampleLoss`` whose ``sum`` and ``count`` are ``(B,)``. + """ + return PerSampleLoss( + sum=reduce(masked_ce, "b s -> b", "sum"), + count=reduce(valid_mask.float(), "b s -> b", "sum"), + ) def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None: diff --git a/src/opentau/policies/value/modeling_value.py b/src/opentau/policies/value/modeling_value.py index 9493486..945bad9 100644 --- a/src/opentau/policies/value/modeling_value.py +++ b/src/opentau/policies/value/modeling_value.py @@ -33,7 +33,7 @@ from opentau.policies.normalize import Normalize from opentau.policies.pretrained import PreTrainedPolicy -from opentau.policies.utils import log_model_loading_keys +from opentau.policies.utils import PerSampleLoss, ce_per_sample, log_model_loading_keys from opentau.policies.value.configuration_value import ValueConfig from opentau.policies.value.siglip_gemma import ( SiglipGemmaValueConfig, @@ -365,14 +365,22 @@ def predict_value(self, batch: dict[str, Tensor]) -> Tensor: logits = self.model.get_value(images, img_masks, lang_tokens, lang_masks) return self.calculate_value(logits) - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | None]: + def forward( + self, batch: dict[str, Tensor], return_per_sample: bool = False + ) -> dict[str, Tensor | PerSampleLoss]: """Do a full training forward pass to compute the value loss. Args: batch: Dictionary containing observations and target values + return_per_sample: When True, also returns per-sample + ``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the + validation per-(dataset, control_mode) breakdown. ``MSE`` is a zero + stub here, so ``MSE_per_sample`` carries zero sum and count; the CE + pools the value-bin and response-token terms per sample. Returns: - Tuple of (loss_dict, None) where loss_dict contains the MSE loss + Dict with "MSE"/"CE"/"L1"/"Accuracy" (plus per-sample CE/MSE entries + when ``return_per_sample`` is True). """ # `ValueFunction` is a single-dataset policy (its `Normalize` was # built with `num_datasets=1`); `_resolve_dataset_index` defaults @@ -400,6 +408,11 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | # Mask CE loss if all action_is_pad are true. This is used for VQA dataset where we don't have actions tokens. value_ce_loss = value_ce_loss * (~diff_mask).float() + # Per-sample value-bin CE (one scored token per robotic sample) for the val breakdown. + value_ce_per_sample = ( + PerSampleLoss(sum=value_ce_loss, count=(~diff_mask).float()) if return_per_sample else None + ) + value_ce_loss = value_ce_loss.mean() l1_loss = F.l1_loss(values, batch["return_continuous"]) @@ -426,15 +439,31 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor] | # Mask response loss if all action_is_pad are true. This is used for Robotic dataset where we have at least one actions tokens. response_ce_loss = response_ce_loss * rearrange(diff_mask.float(), "b -> b 1") + # Per-sample response CE (valid VQA tokens) for the val breakdown. + response_ce_per_sample = ( + ce_per_sample( + response_ce_loss, + (~response_is_pad[:, response_slice]) & rearrange(diff_mask, "b -> b 1"), + ) + if return_per_sample + else None + ) + # compute mean response_ce_loss = response_ce_loss.mean() - return { + out: dict[str, Tensor | PerSampleLoss] = { "MSE": torch.zeros_like(value_ce_loss, requires_grad=False), "CE": value_ce_loss + response_ce_loss, "L1": l1_loss, "Accuracy": accuracy, } + if return_per_sample: + # MSE is a zero stub for the value head; emit zero sum/count. + zeros = torch.zeros(diff_mask.shape[0], device=diff_mask.device) + out["MSE_per_sample"] = PerSampleLoss(sum=zeros, count=zeros.clone()) + out["CE_per_sample"] = value_ce_per_sample + response_ce_per_sample + return out def prepare_images(self, batch): """Preprocesses images for the model. diff --git a/src/opentau/scripts/train.py b/src/opentau/scripts/train.py index e80d70e..3d14f26 100644 --- a/src/opentau/scripts/train.py +++ b/src/opentau/scripts/train.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import logging import os @@ -318,11 +319,11 @@ def _mixture_weighted_aggregate( """Mixture-weighted average of per-dataset validation metrics. Weights are taken from ``name_to_weight`` and renormalized over only the - names present in ``per_dataset_trackers`` (empty datasets are skipped - upstream by ``WeightedDatasetMixture.get_per_dataset_dataloaders`` and so - will be missing from the trackers). When the renormalization total is 0 - -- empty trackers, or all selected datasets have weight 0 -- every metric - is returned as ``0.0``. + keys present in ``per_dataset_trackers`` (groups with no validation samples + in the combined pass are simply absent from the trackers). When the + renormalization total is 0 -- empty trackers, or all selected groups have + weight 0 -- every metric is returned as ``0.0``. Keys are generic: the + validation path keys both trackers and weights by norm-head ``norm_key``. The aggregated metric keys are derived from the first tracker's meters. All per-dataset trackers share the same meter set because they're all @@ -331,11 +332,12 @@ def _mixture_weighted_aggregate( tracker's keys are representative. Args: - per_dataset_trackers: One ``MetricsTracker`` per non-empty validation - dataset, keyed by dataset name. - name_to_weight: Mapping from dataset name to its mixture weight (need - not be normalized; need not be a strict subset/superset of the - tracker keys, but must contain every tracker key). + per_dataset_trackers: One ``MetricsTracker`` per group present in the + validation pass, keyed by group identifier (the validation path keys + by norm-head ``norm_key``). + name_to_weight: Mapping from that same group identifier to its mixture + weight (need not be normalized; need not be a strict subset/superset + of the tracker keys, but must contain every tracker key). Returns: Dict mapping each metric name found on the trackers to its weighted @@ -358,6 +360,39 @@ def _mixture_weighted_aggregate( } +def _bucket_per_sample( + group_index: torch.Tensor, + mse_sum: torch.Tensor, + mse_count: torch.Tensor, + ce_sum: torch.Tensor, + ce_count: torch.Tensor, +) -> dict[int, dict[str, float]]: + """Bucket per-sample MSE/CE ``(sum, count)`` by an integer group key. + + All inputs are 1-D tensors of equal length: ``group_index`` is the per-sample + integer group (norm-head row, or per-dataset enumerate index), and the others + are the per-sample numerator/denominator from the policies' ``PerSampleLoss`` + outputs. Returns ``{group: {"mse_sum", "mse_count", "ce_sum", "ce_count"}}`` + summed over the samples in each group; the masked mean for a group is then + ``Σsum / Σcount``. + + Pure / CPU-only (no collectives), so the validation loop can call it on rank 0 + after gathering, and it is unit-testable without an accelerator. Disjoint groups + never cross-contaminate, which is exactly what lets a *mixed* batch (the combined + val pass) be disaggregated by provenance. + """ + buckets: dict[int, dict[str, float]] = {} + for g in torch.unique(group_index).tolist(): + mask = group_index == g + buckets[int(g)] = { + "mse_sum": float(mse_sum[mask].sum()), + "mse_count": float(mse_count[mask].sum()), + "ce_sum": float(ce_sum[mask].sum()), + "ce_count": float(ce_count[mask].sum()), + } + return buckets + + def _find_unused_params_from_env() -> bool: """Parse the ``FIND_UNUSED_PARAMS`` env var into a bool. @@ -641,19 +676,18 @@ def train(cfg: TrainPipelineConfig): if cfg.val_freq > 0: train_dataloader = train_dataset.get_dataloader() - # One DataLoader per underlying val dataset so we can report per-dataset - # validation losses. The aggregate is computed by averaging across all. - per_dataset_val_dataloaders = val_dataset.get_per_dataset_dataloaders() - val_names = list(per_dataset_val_dataloaders.keys()) - prepared = accelerator.prepare( - policy, - optimizer, - train_dataloader, - lr_scheduler, - *per_dataset_val_dataloaders.values(), + # A single combined validation pass over the whole mixture: under + # ``accelerator.prepare`` every rank's shard stays full even when individual + # subsets have fewer frames than ``world_size`` (the old one-loader-per-dataset + # path left most ranks idle on tiny subsets and stacked that idle time). The + # validation loop disaggregates metrics per (dataset, control_mode) from each + # sample's ``dataset_index`` / ``dataset_repo_id`` provenance instead. + val_dataloader = val_dataset.get_combined_val_dataloader() + policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, train_dataloader, lr_scheduler ) - policy, optimizer, train_dataloader, lr_scheduler = prepared[:4] - per_dataset_val_dataloaders = dict(zip(val_names, prepared[4:], strict=True)) + if val_dataloader is not None: + val_dataloader = accelerator.prepare(val_dataloader) else: train_dataloader = train_dataset.get_dataloader() policy, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -841,13 +875,10 @@ def _strip_norm_buffers_pre_save(models, weights, input_dir): accelerator.wait_for_everyone() - if is_val_step: + if is_val_step and val_dataloader is not None: policy.eval() def _make_val_tracker(current_step: int = step) -> MetricsTracker: - # ``l1_loss`` and ``accuracy`` are populated lazily below iff the - # policy's ``forward`` returns ``"L1"`` / ``"Accuracy"``. See - # ``update_policy`` for the symmetric training-side pattern. return MetricsTracker( cfg.batch_size * accelerator.num_processes, { @@ -858,92 +889,197 @@ def _make_val_tracker(current_step: int = step) -> MetricsTracker: initial_step=current_step, ) - per_dataset_trackers: dict[str, MetricsTracker] = { - name: _make_val_tracker() for name in per_dataset_val_dataloaders + # Per-(dataset, control_mode) breakdown needs per-sample losses. Action + # policies expose them via ``forward(..., return_per_sample=True)``; other + # policies (e.g. high-level planners) don't, so fall back to an + # aggregate-only pass for them. The single combined loader — and its + # full-shard parallelization — applies to every policy regardless. + unwrapped_policy = accelerator.unwrap_model(policy) + supports_per_sample = ( + "return_per_sample" in inspect.signature(unwrapped_policy.forward).parameters + ) + mse_w = cfg.loss_weighting["MSE"] + ce_w = cfg.loss_weighting["CE"] + outlier_threshold = getattr(unwrapped_policy.config, "warn_outlier_threshold", None) + name_to_index = val_dataset.meta.dataset_name_to_index + + # Rank-0 accumulators across the whole pass. Per-sample tensors are gathered + # every batch (a collective on all ranks) but only retained on the main + # process, then bucketed once at the end. + ps_chunks: dict[str, list[torch.Tensor]] = { + "mse_sum": [], + "mse_count": [], + "ce_sum": [], + "ce_count": [], + "norm_index": [], + "source_index": [], } + agg_tracker = _make_val_tracker() # only used on the fallback (no per-sample) path + optional_vals: dict[str, list[float]] = {"l1_loss": [], "accuracy": []} logging.info(f"Validation at step {step}...") with torch.no_grad(): - for ds_name, ds_loader in per_dataset_val_dataloaders.items(): - ds_tracker = per_dataset_trackers[ds_name] - for batch in ds_loader: - losses = policy.forward(batch) - outlier_records = losses.pop("outlier_records", []) - loss = ( - cfg.loss_weighting["MSE"] * losses["MSE"] - + cfg.loss_weighting["CE"] * losses["CE"] - ) + for batch in val_dataloader: + losses = ( + policy.forward(batch, return_per_sample=True) + if supports_per_sample + else policy.forward(batch) + ) + outlier_records = losses.pop("outlier_records", []) + + # Optional value-head metrics (scalars); aggregate-level only. Gather + # outside the per-sample branch so both paths handle them identically. + l1_val = ( + accelerator.gather_for_metrics(losses["L1"]).to(dtype=torch.float32).mean().item() + if "L1" in losses + else None + ) + accuracy_val = ( + accelerator.gather_for_metrics(losses["Accuracy"]) + .to(dtype=torch.float32) + .mean() + .item() + if "Accuracy" in losses + else None + ) - # Gather and average metrics across processes. ``L1`` / - # ``Accuracy`` are optional — see ``update_policy`` for - # the symmetric training-side gating rationale. - loss = accelerator.gather_for_metrics(loss).mean().item() - mse_loss = ( - accelerator.gather_for_metrics(losses["MSE"]) - .to(dtype=torch.float32) - .mean() - .item() + if supports_per_sample: + mse_ps = losses["MSE_per_sample"] + ce_ps = losses["CE_per_sample"] + # Per-sample provenance keys: ``dataset_index`` is the dropout-immune + # norm-head row; ``source_index`` is the per-dataset enumerate index + # recovered from the (also dropout-immune) ``dataset_repo_id``. Both + # gathered as int tensors so the ragged last-batch de-pad stays + # row-aligned with the loss tensors (``gather_object`` would not). + norm_index = batch["dataset_index"].to(device=accelerator.device, dtype=torch.long) + source_index = torch.tensor( + [name_to_index[name] for name in batch["dataset_repo_id"]], + device=accelerator.device, + dtype=torch.long, ) - ce_loss = ( - accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item() - ) - l1_loss = ( - accelerator.gather_for_metrics(losses["L1"]).to(dtype=torch.float32).mean().item() - if "L1" in losses - else None + # Gather all per-sample tensors in a SINGLE call so accelerate applies + # one identical ragged-last-batch de-pad to every entry: the per-sample + # losses and their provenance indices stay row-aligned by construction, + # rather than relying on six separate de-pads landing on the same trim. + gathered = accelerator.gather_for_metrics( + { + "mse_sum": mse_ps.sum.to(dtype=torch.float32), + "mse_count": mse_ps.count.to(dtype=torch.float32), + "ce_sum": ce_ps.sum.to(dtype=torch.float32), + "ce_count": ce_ps.count.to(dtype=torch.float32), + "norm_index": norm_index, + "source_index": source_index, + } ) - accuracy = ( - accelerator.gather_for_metrics(losses["Accuracy"]) + else: + loss = mse_w * losses["MSE"] + ce_w * losses["CE"] + loss_val = accelerator.gather_for_metrics(loss).to(dtype=torch.float32).mean().item() + mse_val = ( + accelerator.gather_for_metrics(losses["MSE"]) .to(dtype=torch.float32) .mean() .item() - if "Accuracy" in losses - else None ) - - # Warn about outliers on rank 0 too during eval (parity with training); see - # ``update_policy`` for the rank-uniform gating rationale. - outlier_threshold = getattr( - accelerator.unwrap_model(policy).config, "warn_outlier_threshold", None + ce_val = ( + accelerator.gather_for_metrics(losses["CE"]).to(dtype=torch.float32).mean().item() ) - if outlier_threshold is not None and outlier_threshold > 0: - log_outlier_records_distributed(accelerator, outlier_records, outlier_threshold) - if accelerator.is_main_process: - ds_tracker.loss = loss - ds_tracker.mse_loss = mse_loss - ds_tracker.ce_loss = ce_loss - _observe_optional(ds_tracker, "l1_loss", "val_l1_loss", ":.6f", l1_loss) - _observe_optional(ds_tracker, "accuracy", "val_accuracy", ":.3f", accuracy) + # Warn about outliers on rank 0 too during eval (parity with training); see + # ``update_policy`` for the rank-uniform gating rationale. + if outlier_threshold is not None and outlier_threshold > 0: + log_outlier_records_distributed(accelerator, outlier_records, outlier_threshold) + + if accelerator.is_main_process: + if supports_per_sample: + for key, value in gathered.items(): + ps_chunks[key].append(value.cpu()) + else: + agg_tracker.loss = loss_val + agg_tracker.mse_loss = mse_val + agg_tracker.ce_loss = ce_val + if l1_val is not None: + optional_vals["l1_loss"].append(l1_val) + if accuracy_val is not None: + optional_vals["accuracy"].append(accuracy_val) if accelerator.is_main_process: - for ds_name, ds_tracker in per_dataset_trackers.items(): - logging.info(f"Validation/{ds_name} {ds_tracker}") - ds_dict = ds_tracker.to_dict(use_avg=True) - accelerator.log({f"Validation/{ds_name}/Loss": ds_dict["loss"]}, step=step) - accelerator.log({f"Validation/{ds_name}/MSE Loss": ds_dict["mse_loss"]}, step=step) - accelerator.log({f"Validation/{ds_name}/CE Loss": ds_dict["ce_loss"]}, step=step) - if "l1_loss" in ds_tracker.metrics: - accelerator.log({f"Validation/{ds_name}/L1 Loss": ds_dict["l1_loss"]}, step=step) - if "accuracy" in ds_tracker.metrics: - accelerator.log({f"Validation/{ds_name}/Accuracy": ds_dict["accuracy"]}, step=step) - - # Mixture-weighted aggregate across the per-dataset trackers, so the - # overall scalar reflects the training mixture rather than being - # implicitly dominated by whichever val subset has the most batches. + meta = val_dataset.meta name_to_weight = dict( zip(val_dataset.dataset_names, val_dataset.dataset_weights, strict=True) ) - agg = _mixture_weighted_aggregate(per_dataset_trackers, name_to_weight) - logging.info(f"Validation/aggregate {agg}") - accelerator.log({"Validation/Loss": agg["loss"]}, step=step) - accelerator.log({"Validation/MSE Loss": agg["mse_loss"]}, step=step) - accelerator.log({"Validation/CE Loss": agg["ce_loss"]}, step=step) - if "l1_loss" in agg: - accelerator.log({"Validation/L1 Loss": agg["l1_loss"]}, step=step) - if "accuracy" in agg: - accelerator.log({"Validation/Accuracy": agg["accuracy"]}, step=step) + + if supports_per_sample and ps_chunks["norm_index"]: + norm_index = torch.cat(ps_chunks["norm_index"]) + source_index = torch.cat(ps_chunks["source_index"]) + mse_sum = torch.cat(ps_chunks["mse_sum"]) + mse_count = torch.cat(ps_chunks["mse_count"]) + ce_sum = torch.cat(ps_chunks["ce_sum"]) + ce_count = torch.cat(ps_chunks["ce_count"]) + norm_groups = _bucket_per_sample(norm_index, mse_sum, mse_count, ce_sum, ce_count) + source_groups = _bucket_per_sample(source_index, mse_sum, mse_count, ce_sum, ce_count) + + # Headline breakdown: one tracker per norm head (``robot::control_mode`` + # or a fallback dataset name), keyed by ``norm_key`` for the aggregate. + per_normkey_trackers: dict[str, MetricsTracker] = {} + for norm_row, group in norm_groups.items(): + norm_key = meta.norm_keys[norm_row] + mse_mean = group["mse_sum"] / (group["mse_count"] + 1e-8) + ce_mean = group["ce_sum"] / (group["ce_count"] + 1e-8) + loss_mean = mse_w * mse_mean + ce_w * ce_mean + tracker = _make_val_tracker() + tracker.mse_loss = mse_mean + tracker.ce_loss = ce_mean + tracker.loss = loss_mean + per_normkey_trackers[norm_key] = tracker + logging.info(f"Validation/{norm_key} {tracker}") + accelerator.log({f"Validation/{norm_key}/Loss": loss_mean}, step=step) + accelerator.log({f"Validation/{norm_key}/MSE Loss": mse_mean}, step=step) + accelerator.log({f"Validation/{norm_key}/CE Loss": ce_mean}, step=step) + + # Finer per-source-dataset breakdown (no info loss when datasets share + # a norm head). Logged-only; the aggregate stays per-norm-head. + for source_row, group in source_groups.items(): + ds_name = meta.dataset_names[source_row] + mse_mean = group["mse_sum"] / (group["mse_count"] + 1e-8) + ce_mean = group["ce_sum"] / (group["ce_count"] + 1e-8) + loss_mean = mse_w * mse_mean + ce_w * ce_mean + accelerator.log({f"Validation/by_dataset/{ds_name}/Loss": loss_mean}, step=step) + accelerator.log({f"Validation/by_dataset/{ds_name}/MSE Loss": mse_mean}, step=step) + accelerator.log({f"Validation/by_dataset/{ds_name}/CE Loss": ce_mean}, step=step) + + # Mixture-weighted aggregate over norm heads, so the overall scalar + # reflects the training mixture rather than being dominated by whichever + # head has the most val frames. Sum each head's member-dataset weights. + norm_key_weight: dict[str, float] = {} + for name, weight in name_to_weight.items(): + nk = meta.norm_keys[meta.dataset_to_norm_index[name]] + norm_key_weight[nk] = norm_key_weight.get(nk, 0.0) + weight + agg = _mixture_weighted_aggregate(per_normkey_trackers, norm_key_weight) + elif not supports_per_sample: + # Fallback: the combined-pass batch-mean aggregate; no per-group breakdown. + agg = agg_tracker.to_dict(use_avg=True) + else: + agg = {} + + if agg: + logging.info(f"Validation/aggregate {agg}") + accelerator.log({"Validation/Loss": agg["loss"]}, step=step) + accelerator.log({"Validation/MSE Loss": agg["mse_loss"]}, step=step) + accelerator.log({"Validation/CE Loss": agg["ce_loss"]}, step=step) + if optional_vals["l1_loss"]: + accelerator.log( + {"Validation/L1 Loss": sum(optional_vals["l1_loss"]) / len(optional_vals["l1_loss"])}, + step=step, + ) + if optional_vals["accuracy"]: + accelerator.log( + { + "Validation/Accuracy": sum(optional_vals["accuracy"]) + / len(optional_vals["accuracy"]) + }, + step=step, + ) # This barrier is probably necessary to ensure # other processes wait for the main process to finish saving diff --git a/tests/datasets/test_dataset_mixture.py b/tests/datasets/test_dataset_mixture.py index 55dbe50..69647a9 100644 --- a/tests/datasets/test_dataset_mixture.py +++ b/tests/datasets/test_dataset_mixture.py @@ -324,6 +324,34 @@ def test_get_dataloader_success(self, train_pipeline_config, datasets_factory): assert dataloader.batch_size == train_pipeline_config.batch_size assert dataloader.num_workers == train_pipeline_config.num_workers + def test_get_combined_val_dataloader(self, train_pipeline_config, datasets_factory): + """One deterministic sequential loader over the whole mixture (the + validation path): every sample is scored exactly once, in order. + + The loader serves ``concatenated_dataset`` — the same ``_TaggedDataset``- + wrapped mixture ``get_dataloader`` uses — so each sample carries the + ``dataset_index`` / ``dataset_repo_id`` provenance the validation loop + buckets on (the wrapper injection is covered by the integration tests + below). Coverage and order are asserted via the sampler (pure Python), + avoiding a fetch through the fake fixture's non-picklable worker dataset.""" + datasets = datasets_factory(2) + mixture = WeightedDatasetMixture(train_pipeline_config, datasets, [0.7, 0.3], 30.0) + + dataloader = mixture.get_combined_val_dataloader() + + assert dataloader is not None + assert dataloader.batch_size == train_pipeline_config.dataloader_batch_size + # Serves the tagged concat mixture (so provenance keys ride along). + assert dataloader.dataset is mixture.concatenated_dataset + # Sequential + deterministic: no weighted/random sampler, nothing dropped. + assert isinstance(dataloader.sampler, torch.utils.data.SequentialSampler) + assert dataloader.drop_last is False + + total = sum(len(ds) for ds in mixture.datasets) + assert len(dataloader.dataset) == total + # Every sample exactly once, in concat order — this is what the loader iterates. + assert list(dataloader.sampler) == list(range(total)) + @pytest.mark.slow # 1 sec def test_get_dataloader_zero_weights_error(self, train_pipeline_config, datasets_factory): """Test dataloader creation with zero weights raises error.""" diff --git a/tests/policies/test_pi06.py b/tests/policies/test_pi06.py index db1e0d4..2906c98 100644 --- a/tests/policies/test_pi06.py +++ b/tests/policies/test_pi06.py @@ -41,6 +41,7 @@ pad_discrete_tokens, resize_with_pad, ) +from opentau.policies.utils import PerSampleLoss, ce_per_sample # Block-causal attention mask (pi05 / π0.6 prefix-LM pattern) @@ -239,6 +240,105 @@ def test_prefix_mask_excludes_frozen_steps(self): # PI06Config defaults + validators +class TestFlowMatchingMaskedMsePerSample: + """``return_per_sample=True`` exposes per-sample ``(sum, count)`` without + perturbing the scalar — the contract the validation per-(dataset, + control_mode) breakdown relies on. Bit-identical scalars keep training + determinism intact (CLAUDE.md hard rule).""" + + def test_scalar_is_bit_identical_to_default(self): + torch.manual_seed(0) + u_t = torch.randn(3, 8, 4) + v_t = torch.randn(3, 8, 4) + actions_is_pad = torch.zeros(3, 8, dtype=torch.bool) + scalar_only = flow_matching_masked_mse( + u_t=u_t, v_t=v_t, actions_is_pad=actions_is_pad, max_action_dim=4 + ) + scalar, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=4, + return_per_sample=True, + ) + # Same float ops as the default path ⇒ exactly equal, not just close. + assert torch.equal(scalar, scalar_only) + assert ps.sum.shape == (3,) + assert ps.count.shape == (3,) + + def test_regrouped_mean_matches_scalar(self): + torch.manual_seed(0) + u_t = torch.randn(3, 8, 4) + v_t = torch.randn(3, 8, 4) + actions_is_pad = torch.zeros(3, 8, dtype=torch.bool) + scalar, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=4, + return_per_sample=True, + ) + # Σsum / Σcount over the whole batch reproduces the masked mean exactly. + assert torch.isclose(ps.sum.sum() / (ps.count.sum() + 1e-8), scalar, atol=1e-6) + + def test_per_sample_count_excludes_padded_steps_and_dims(self): + torch.manual_seed(1) + u_t = torch.randn(2, 8, 6) + v_t = torch.randn(2, 8, 6) + # sample 0: all 8 steps real; sample 1: only first 4 steps real. + actions_is_pad = torch.tensor([[False] * 8, [False, False, False, False, True, True, True, True]]) + # sample 0: 4 real action dims; sample 1: all 6. + real_action_dim = torch.tensor([4, 6], dtype=torch.long) + _, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=6, + real_action_dim=real_action_dim, + return_per_sample=True, + ) + assert ps.count[0].item() == pytest.approx(8 * 4) + assert ps.count[1].item() == pytest.approx(4 * 6) + # Independent reference for sample 0's numerator (8 steps × dims [0:4]). + ref0 = ((u_t[0, :, :4] - v_t[0, :, :4]) ** 2).sum() + assert torch.isclose(ps.sum[0], ref0, atol=1e-5) + + def test_fully_padded_sample_has_zero_sum_and_count(self): + torch.manual_seed(2) + u_t = torch.randn(2, 8, 4) + v_t = torch.randn(2, 8, 4) + actions_is_pad = torch.zeros(2, 8, dtype=torch.bool) + actions_is_pad[1] = True # sample 1 fully padded ⇒ contributes nothing + _, ps = flow_matching_masked_mse( + u_t=u_t, + v_t=v_t, + actions_is_pad=actions_is_pad, + max_action_dim=4, + return_per_sample=True, + ) + assert ps.sum[1].item() == 0.0 + assert ps.count[1].item() == 0.0 + + +class TestCePerSample: + """``ce_per_sample`` / ``PerSampleLoss`` arithmetic used by the per-group CE.""" + + def test_sum_and_count_over_valid_tokens(self): + # (B=2, S=3): sample 0 has 2 valid tokens, sample 1 has 3. + ce = torch.tensor([[1.0, 2.0, 5.0], [0.5, 0.5, 1.0]]) + valid = torch.tensor([[True, True, False], [True, True, True]]) + ps = ce_per_sample(ce * valid, valid) + assert torch.equal(ps.sum, torch.tensor([3.0, 2.0])) + assert torch.equal(ps.count, torch.tensor([2.0, 3.0])) + + def test_pooling_two_components_adds_sums_and_counts(self): + a = PerSampleLoss(sum=torch.tensor([1.0, 2.0]), count=torch.tensor([1.0, 1.0])) + b = PerSampleLoss(sum=torch.tensor([3.0, 4.0]), count=torch.tensor([2.0, 2.0])) + pooled = a + b + assert torch.equal(pooled.sum, torch.tensor([4.0, 6.0])) + assert torch.equal(pooled.count, torch.tensor([3.0, 3.0])) + + class TestPI06Config: def test_defaults_match_pi06_spec(self): cfg = PI06Config() diff --git a/tests/scripts/test_train.py b/tests/scripts/test_train.py index b8a0282..4df39bb 100644 --- a/tests/scripts/test_train.py +++ b/tests/scripts/test_train.py @@ -27,8 +27,10 @@ import accelerate import pytest +import torch from opentau.scripts.train import ( + _bucket_per_sample, _commit_wandb_step, _find_unused_params_from_env, _mixture_weighted_aggregate, @@ -261,6 +263,51 @@ def _make_partial_tracker(loss: float, mse: float, ce: float) -> MetricsTracker: assert agg["ce_loss"] == pytest.approx(0.25 * 3.0 + 0.75 * 7.0) +class TestBucketPerSample: + """``_bucket_per_sample`` disaggregates a *mixed* batch by integer group key — + the core of the per-(dataset, control_mode) validation breakdown (issue #373). + Grouping on the dropout-immune ``dataset_index`` is what makes a mixed batch + decomposable; this is pure CPU arithmetic with no accelerator.""" + + def test_disaggregates_mixed_batch_without_cross_contamination(self): + group = torch.tensor([0, 1, 0, 1, 2]) + mse_sum = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + mse_count = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) + ce_sum = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0]) + ce_count = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0]) + buckets = _bucket_per_sample(group, mse_sum, mse_count, ce_sum, ce_count) + assert set(buckets) == {0, 1, 2} + # group 0 = samples {0, 2}; groups 1/2 must not leak in. + assert buckets[0]["mse_sum"] == pytest.approx(4.0) + assert buckets[0]["mse_count"] == pytest.approx(2.0) + assert buckets[0]["ce_sum"] == pytest.approx(40.0) + assert buckets[0]["ce_count"] == pytest.approx(4.0) + assert buckets[1]["mse_sum"] == pytest.approx(6.0) + assert buckets[2]["mse_sum"] == pytest.approx(5.0) + + def test_group_mean_is_sum_over_count_not_mean_of_means(self): + # Heterogeneous counts: the regrouped mean must be Σsum/Σcount, so a + # 3-slot sample outweighs a 1-slot sample (mean-of-means would not). + group = torch.tensor([0, 0]) + s = torch.tensor([2.0, 9.0]) + c = torch.tensor([1.0, 3.0]) + buckets = _bucket_per_sample(group, s, c, s, c) + mean = buckets[0]["mse_sum"] / buckets[0]["mse_count"] + assert mean == pytest.approx(11.0 / 4.0) + + def test_zero_count_group_carries_zeros_for_caller_guard(self): + # pi0 emits zero-count CE; the bucket keeps 0/0 and the caller's +1e-8 + # turns it into a 0 mean (no NaN). + group = torch.tensor([0, 1]) + mse = torch.tensor([1.0, 2.0]) + cnt = torch.tensor([1.0, 1.0]) + zeros = torch.tensor([0.0, 0.0]) + buckets = _bucket_per_sample(group, mse, cnt, zeros, zeros) + assert buckets[0]["ce_sum"] == 0.0 + assert buckets[0]["ce_count"] == 0.0 + assert buckets[0]["ce_sum"] / (buckets[0]["ce_count"] + 1e-8) == pytest.approx(0.0) + + class TestCommitWandbStep: """``_commit_wandb_step`` issues exactly one empty, commit-tagged log."""