Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/opentau/datasets/dataset_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,3 +975,46 @@ def get_per_dataset_dataloaders(self) -> dict[str, DataLoader]:
worker_init_fn=worker_init_fn,
)
return loaders

def get_combined_val_dataloader(self) -> DataLoader | None:
"""Create one deterministic sequential DataLoader over the whole mixture.

Unlike :meth:`get_per_dataset_dataloaders` (one loader per dataset), this
returns a single loader over the concatenated mixture so that, under
``accelerator.prepare``, every rank's shard is full even when individual
validation subsets have fewer frames than ``world_size`` — the per-dataset
loaders leave most ranks idle on tiny subsets and stack that idle time
across datasets. Each sample still carries its ``dataset_index`` /
``dataset_repo_id`` provenance (injected by ``_TaggedDataset``), so the
validation loop can disaggregate metrics per ``(dataset, control_mode)``
from the batch rather than relying on homogeneous per-dataset loaders.
``shuffle=False`` + ``drop_last=False`` make the pass order-deterministic
(seed-independent) and score every sample exactly once.

Returns:
A single ``DataLoader`` over the mixture, or ``None`` when the mixture
is empty (mirrors the empty-dataset skip in
:meth:`get_per_dataset_dataloaders`).
"""
if len(self.concatenated_dataset) == 0:
logging.info("Combined validation DataLoader skipped: the mixture is empty.")
return None

worker_name_mapping_overrides = self._get_worker_name_mapping_overrides()
worker_init_fn = None
if worker_name_mapping_overrides:
worker_init_fn = functools.partial(
_apply_data_feature_name_mapping_overrides,
mapping_overrides=worker_name_mapping_overrides,
)

return DataLoader(
self.concatenated_dataset,
batch_size=self.cfg.dataloader_batch_size,
shuffle=False,
num_workers=self.cfg.num_workers,
pin_memory=torch.cuda.is_available(),
drop_last=False,
prefetch_factor=self.cfg.prefetch_factor,
worker_init_fn=worker_init_fn,
)
32 changes: 27 additions & 5 deletions src/opentau/policies/pi0/modeling_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import torch
import torch.nn.functional as F # noqa: N812
from einops import rearrange
from einops import rearrange, reduce
from torch import Tensor, nn
from transformers import AutoTokenizer

Expand All @@ -37,7 +37,7 @@
PaliGemmaWithExpertModel,
)
from opentau.policies.pretrained import PreTrainedPolicy
from opentau.policies.utils import log_model_loading_keys, make_action_dim_mask
from opentau.policies.utils import PerSampleLoss, log_model_loading_keys, make_action_dim_mask
from opentau.utils.accelerate_utils import get_proc_accelerator
from opentau.utils.utils import get_safe_dtype

Expand Down Expand Up @@ -466,14 +466,22 @@ def sample_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None)
return actions

def forward(
self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None
) -> dict[str, Tensor]:
self,
batch: dict[str, Tensor],
noise: Tensor | None = None,
time: Tensor | None = None,
return_per_sample: bool = False,
) -> dict[str, Tensor | PerSampleLoss]:
"""Do a full training forward pass to compute the loss.

Args:
batch: Batch of data containing environment observations, actions, and targets.
noise: Optional noise tensor.
time: Optional time tensor.
return_per_sample: When True, also returns per-sample
``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the
validation per-(dataset, control_mode) breakdown. ``CE`` is a zero
stub for pi0, so ``CE_per_sample`` carries zero sum and count.

Returns:
A dictionary containing the loss components ("MSE" and "CE").
Expand Down Expand Up @@ -537,7 +545,21 @@ def forward(

loss = losses.sum() / (full_mask.sum() + 1e-8)

return {"MSE": loss, "CE": torch.zeros_like(loss, requires_grad=True)}
out: dict[str, Tensor | PerSampleLoss] = {
"MSE": loss,
"CE": torch.zeros_like(loss, requires_grad=True),
}
if return_per_sample:
# ``losses`` is already masked (and AWR-weighted, if enabled), matching
# the scalar reduction; reduce over (chunk, dim) keeping the batch axis.
out["MSE_per_sample"] = PerSampleLoss(
sum=reduce(losses, "b c d -> b", "sum"),
count=reduce(full_mask.float(), "b c d -> b", "sum"),
)
# pi0 has no CE term; emit zero sum/count so it forms no CE groups.
zeros = torch.zeros(losses.shape[0], device=losses.device)
out["CE_per_sample"] = PerSampleLoss(sum=zeros, count=zeros.clone())
return out

def prepare_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Apply Pi0 preprocessing to the images.
Expand Down
61 changes: 49 additions & 12 deletions src/opentau/policies/pi05/modeling_pi05.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
PaliGemmaWithExpertModel,
)
from opentau.policies.pretrained import PreTrainedPolicy, T
from opentau.policies.utils import flow_matching_masked_mse
from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse
from opentau.utils.accelerate_utils import get_proc_accelerator
from opentau.utils.utils import get_safe_dtype

Expand Down Expand Up @@ -680,17 +680,26 @@ def sample_actions(
return actions

def forward(
self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None
) -> dict[str, Tensor]:
self,
batch: dict[str, Tensor],
noise: Tensor | None = None,
time: Tensor | None = None,
return_per_sample: bool = False,
) -> dict[str, Tensor | PerSampleLoss]:
"""Do a full training forward pass to compute the loss.

Args:
batch: Batch of data containing environment observations, actions, and targets.
noise: Optional noise tensor.
time: Optional time tensor.
return_per_sample: When True, additionally return per-sample ``MSE_per_sample``
/ ``CE_per_sample`` (:class:`PerSampleLoss`) so the validation loop can
bucket the loss by ``(dataset, control_mode)`` provenance. The scalar
``MSE``/``CE`` are unchanged, so the training path is unaffected.

Returns:
A dictionary containing the loss components ("MSE" and "CE").
A dictionary with the loss components ("MSE" and "CE"), plus
"MSE_per_sample"/"CE_per_sample" when ``return_per_sample`` is True.
"""
dataset_index = self._resolve_dataset_index(batch)
batch = self.normalize_inputs(batch, dataset_index)
Expand Down Expand Up @@ -732,12 +741,14 @@ def forward(
discrete_action_masks,
state=state,
real_action_dim=batch.get("real_action_dim"),
return_per_sample=return_per_sample,
)

mse_loss = losses["MSE"]
ce_loss = losses["CE"]

return {"MSE": mse_loss, "CE": ce_loss}
out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]}
if return_per_sample:
out["MSE_per_sample"] = losses["MSE_per_sample"]
out["CE_per_sample"] = losses["CE_per_sample"]
return out

def prepare_state(self, batch: dict[str, Tensor]) -> Tensor:
"""Prepares the continuous state tensor, padding or truncating to max_state_dim.
Expand Down Expand Up @@ -1315,7 +1326,8 @@ def forward(
discrete_action_masks: Tensor | None = None,
state: Tensor | None = None,
real_action_dim: Tensor | None = None,
) -> dict[str, Tensor]:
return_per_sample: bool = False,
) -> dict[str, Tensor | PerSampleLoss]:
"""Do a full training forward pass and compute the loss.

Args:
Expand Down Expand Up @@ -1438,14 +1450,16 @@ def forward(
# Shared masked-MSE reduction: AND-s frozen-prefix, timestep-pad, and
# dim-pad masks together and divides by the unmasked-slot count. See
# ``opentau.policies.utils.flow_matching_masked_mse`` for the full spec.
mse_loss = flow_matching_masked_mse(
mse_result = flow_matching_masked_mse(
u_t=u_t,
v_t=v_t,
max_action_dim=self.config.max_action_dim,
prefix_mask=prefix_mask,
actions_is_pad=actions_is_pad,
real_action_dim=real_action_dim,
return_per_sample=return_per_sample,
)
mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None)

# compute cross entropy loss for discrete actions
batch_size, seq_len = discrete_actions.shape
Expand All @@ -1467,6 +1481,11 @@ def forward(
discrete_action_is_pad = ~discrete_action_masks # convert into format where value for pad is True
discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad

# Per-sample CE numerator/denominator over valid tokens, for the val breakdown.
discrete_action_ce_per_sample = (
ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None
)

# compute mean
discrete_action_ce_loss = discrete_action_ce_loss.mean()

Expand Down Expand Up @@ -1505,12 +1524,30 @@ def forward(
# helps to control loss for response tokens in case of robotic data and VQA data
response_ce_loss = response_ce_loss * ~response_is_pad[:, response_slice]

# Per-sample response CE (valid tokens only) for the val breakdown.
response_ce_per_sample = (
ce_per_sample(response_ce_loss, ~response_is_pad[:, response_slice])
if return_per_sample
else None
)

# compute mean
response_ce_loss = response_ce_loss.mean()
else:
response_ce_loss = torch.tensor(0.0, device=mse_loss.device)

return {"MSE": mse_loss, "CE": discrete_action_ce_loss + response_ce_loss}
response_ce_per_sample = None

out: dict[str, Tensor | PerSampleLoss] = {
"MSE": mse_loss,
"CE": discrete_action_ce_loss + response_ce_loss,
}
if return_per_sample:
ce_ps = discrete_action_ce_per_sample
if response_ce_per_sample is not None:
ce_ps = ce_ps + response_ce_per_sample
out["MSE_per_sample"] = mse_per_sample
out["CE_per_sample"] = ce_ps
return out

def sample_actions(
self,
Expand Down
45 changes: 34 additions & 11 deletions src/opentau/policies/pi05_mem/modeling_pi05.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from opentau.policies.pi05_mem.configuration_pi05 import PI05MemConfig
from opentau.policies.pi07.video_encoder import SpaceTimeSiglipVideoEncoder
from opentau.policies.pretrained import PreTrainedPolicy, T
from opentau.policies.utils import flow_matching_masked_mse
from opentau.policies.utils import PerSampleLoss, ce_per_sample, flow_matching_masked_mse
from opentau.utils.accelerate_utils import get_proc_accelerator
from opentau.utils.utils import get_safe_dtype

Expand Down Expand Up @@ -661,9 +661,19 @@ def sample_actions(
return actions

def forward(
self, batch: dict[str, Tensor], noise: Tensor | None = None, time: Tensor | None = None
) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss."""
self,
batch: dict[str, Tensor],
noise: Tensor | None = None,
time: Tensor | None = None,
return_per_sample: bool = False,
) -> dict[str, Tensor | PerSampleLoss]:
"""Do a full training forward pass to compute the loss.

When ``return_per_sample`` is True, also returns per-sample
``MSE_per_sample``/``CE_per_sample`` (:class:`PerSampleLoss`) for the
validation per-(dataset, control_mode) breakdown; the scalar losses are
unchanged.
"""
dataset_index = self._resolve_dataset_index(batch)
batch = self.normalize_inputs(batch, dataset_index)
batch["discrete_actions"] = self.normalize_discrete_actions(dict(batch), dataset_index)["actions"]
Expand Down Expand Up @@ -696,12 +706,14 @@ def forward(
discrete_action_masks,
obs_history_is_pad=obs_history_is_pad,
real_action_dim=batch.get("real_action_dim"),
return_per_sample=return_per_sample,
)

mse_loss = losses["MSE"]
ce_loss = losses["CE"]

return {"MSE": mse_loss, "CE": ce_loss}
out: dict[str, Tensor | PerSampleLoss] = {"MSE": losses["MSE"], "CE": losses["CE"]}
if return_per_sample:
out["MSE_per_sample"] = losses["MSE_per_sample"]
out["CE_per_sample"] = losses["CE_per_sample"]
return out

def prepare_state(self, batch: dict[str, Tensor]) -> Tensor:
"""Prepares the temporal state tensor, padding or truncating to max_state_dim.
Expand Down Expand Up @@ -1094,7 +1106,8 @@ def forward(
discrete_action_masks: Tensor | None = None,
obs_history_is_pad: Tensor | None = None,
real_action_dim: Tensor | None = None,
) -> dict[str, Tensor]:
return_per_sample: bool = False,
) -> dict[str, Tensor | PerSampleLoss]:
"""Do a full training forward pass and compute the loss."""
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
videos,
Expand Down Expand Up @@ -1178,14 +1191,16 @@ def forward(
v_t = v_t.to(dtype=torch.float32)

# Shared masked-MSE reduction; see pi05 for the rationale.
mse_loss = flow_matching_masked_mse(
mse_result = flow_matching_masked_mse(
u_t=u_t,
v_t=v_t,
max_action_dim=self.config.max_action_dim,
prefix_mask=prefix_mask,
actions_is_pad=actions_is_pad,
real_action_dim=real_action_dim,
return_per_sample=return_per_sample,
)
mse_loss, mse_per_sample = mse_result if return_per_sample else (mse_result, None)

assert discrete_actions is not None
assert discrete_action_masks is not None
Expand All @@ -1206,9 +1221,17 @@ def forward(
discrete_action_is_pad = ~discrete_action_masks
discrete_action_ce_loss = discrete_action_ce_loss * ~discrete_action_is_pad

ce_per_sample_loss = (
ce_per_sample(discrete_action_ce_loss, ~discrete_action_is_pad) if return_per_sample else None
)

discrete_action_ce_loss = discrete_action_ce_loss.mean()

return {"MSE": mse_loss, "CE": discrete_action_ce_loss}
out: dict[str, Tensor | PerSampleLoss] = {"MSE": mse_loss, "CE": discrete_action_ce_loss}
if return_per_sample:
out["MSE_per_sample"] = mse_per_sample
out["CE_per_sample"] = ce_per_sample_loss
return out

def sample_actions(
self,
Expand Down
Loading
Loading