Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

Mirrors the vision SFT stack (PackingDataLoader + RankPartitionedDataLoader),
but feeds the DROID action dataset (``joint_pos`` 8D + ``use_state``, raw/
un-normalized — same as the internal ``droid_lerobot_8b_policy`` run) through
``ActionTransformPipeline``, and trains the generation + action heads from the
public ``nvidia/Cosmos3-Nano`` base.
un-normalized) through ``ActionTransformPipeline``, and trains the generation +
action heads from the public ``nvidia/Cosmos3-Nano`` base.

Usage (1 node, 8 GPU)::

Expand Down Expand Up @@ -41,13 +40,10 @@
{"override /model": "mot_fsdp"},
{"override /data_train": None},
{"override /data_val": None},
# Match internal droid_lerobot_8b_policy: apex FusedAdam with fp32
# master_weights + eps 1e-8. adamw + fused + eps 1e-6 (bf16, no fp32
# master) under-steps the small 5x-lr action heads and leaves the action
# loss on a noisy high plateau; an exact-match forward/optimizer test
# confirmed the convergence gap was the optimizer, not the model.
# FusedAdam with fp32 master_weights + eps 1e-8 (bf16 params + eps 1e-6
# diverged on the action loss).
{"override /optimizer": "fusedadamw"},
{"override /scheduler": "lambdalinear"}, # matches internal droid_lerobot_8b (was lambdacosine)
{"override /scheduler": "lambdalinear"}, # linear LR decay
{"override /checkpoint": "s3"},
{
"override /callbacks": [
Expand Down Expand Up @@ -77,7 +73,7 @@
betas=[0.9, 0.99],
eps=1.0e-08,
fused=True, # popped by build_optimizer for FusedAdam (fused by construction)
# Generation + action heads (mirrors internal droid_lerobot_8b_policy).
# Train the generation + action heads.
keys_to_select=[
"moe_gen",
"time_embedder",
Expand All @@ -87,7 +83,7 @@
"llm2action",
"action_modality_embed",
],
lr=2.0e-04, # matches internal droid_lerobot_8b_policy submit (--lr 2e-4)
lr=2.0e-04, # for the 8192 global batch
lr_multipliers={
"action2llm": 5.0,
"llm2action": 5.0,
Expand All @@ -97,7 +93,7 @@
weight_decay=0.05,
),
scheduler=dict(
lr_scheduler_type="LambdaLinear", # matches internal droid_lerobot_8b (was LambdaCosine)
lr_scheduler_type="LambdaLinear",
cycle_lengths=[100], # smoke: 100 iters (real run sets via TOML)
f_max=[0.4],
f_min=[0.0],
Expand Down Expand Up @@ -126,7 +122,7 @@
device_monitor=dict(
every_n=200, log_memory_detail=True, save_s3=False, step_size=1, upload_every_n_mul=5
),
grad_clip=dict(clip_norm=1.0, force_finite=True), # matches internal make_8b
grad_clip=dict(clip_norm=1.0, force_finite=True),
heart_beat=dict(every_n=200, save_s3=False, step_size=1, update_interval_in_minute=20),
iter_speed=dict(every_n=1, hit_thres=50, save_s3=False, save_s3_every_log_n=500),
low_precision=dict(update_iter=1),
Expand All @@ -141,10 +137,9 @@
dcp_async_mode_enabled=False,
enable_gcs_patch_in_boto3=True,
keys_not_to_resume=[],
# Skip net_ema. (→ EMA warm-start copies net→net_ema, see dcp.py) AND the
# action heads, so they init fresh from the base — matches internal
# make_8b _DEFAULT_KEYS_TO_SKIP (Cosmos3-Nano's action heads are not
# DROID-policy-trained).
# Skip net_ema. (EMA warm-starts from net, see dcp.py) and the action
# heads, so they init fresh from the base (the base has no DROID-trained
# action heads).
keys_to_skip_loading=[
"net_ema.",
"action2llm",
Expand Down Expand Up @@ -172,7 +167,7 @@
dataloader_train=L(PackingDataLoader)(
audio_sample_rate=48000,
dataset_name="action_droid",
max_samples_per_batch=128, # count-based batch (matches internal res480 8B)
max_samples_per_batch=128, # per rank -> 8192 global batch at 64 ranks (16 nodes, shard 8 x replicate 8)
max_sequence_length=None, # None disables token packing (TOML can't express null)
patch_spatial=2,
sound_latent_fps=0,
Expand All @@ -186,6 +181,13 @@
pin_memory=True,
prefetch_factor=4,
sampler=None,
# Shuffling is handled by the dataset (iterable_shuffle=True below):
# ActionIterableShuffleDataset streams rank x worker-sharded, episode-order-
# shuffled, sequential-within-episode. The map-style dataset has no internal
# shuffle, so a SequentialSampler would feed every rank the SAME consecutive
# overlapping windows -> global batch ~1 episode -> unstable grad-norm; a plain
# RandomSampler decorrelates but does random-access I/O -> slow + OOM. The
# iterable gives decorrelation with sequential reads.
datasets=dict(
droid=dict(
ratio=1,
Expand All @@ -194,15 +196,21 @@
fps=15.0,
chunk_length=32,
action_space="joint_pos",
# Policy-only task mode. "joint" would randomly pick
# forward_dynamics/inverse_dynamics/policy per sample (multi-task),
# which dilutes each per-task loss by ~1/3.
mode="policy",
use_state=True,
iterable_shuffle=True, # rank x worker episode-shuffle stream
episode_shuffle_seed=42,
use_image_augmentation=True, # SR boost (random crop+rescale + color jitter)
# Keep-ranges window filter (drops idle/non-task frames). Off by default;
# the launcher sets use_filter_dict=True + filter_dict_path for internal parity.
# set use_filter_dict=True + filter_dict_path to enable.
use_filter_dict=False,
filter_dict_path=None,
action_normalization=None,
viewpoint="concat_view", # wrist 480p (top) + L/R shoulder 320x180 (bottom)
resolution="480", # 640x360 data @ 480p (matches internal res480 run)
resolution="480", # 640x360 data @ 480p
max_action_dim="${model.config.max_action_dim}",
cfg_dropout_rate=0.1,
tokenizer_config="${model.config.vlm_config.tokenizer}",
Expand All @@ -218,12 +226,18 @@
)


# chunk_length=32 → 33 observation frames; pin the VAE encode duration to match
# (internal used [17] for chunk_length=16). Set post-construction so it lands on
# the deep-copied NANO_MODEL_CONFIG.tokenizer.
# chunk_length=32 -> 33 observation frames; pin the VAE encode duration to match.
# Set post-construction so it lands on the deep-copied NANO_MODEL_CONFIG.tokenizer.
action_policy_droid_nano["model"]["config"]["tokenizer"]["encode_exact_durations"] = [33]


# Uncap the packed-sequence length. The NANO default (45056) caps the packed sequence,
# truncating long DROID windows to ~1/4 of their natural length; -1 (uncapped) processes
# the full vision sequence per step. Does not change the per-token loss; widens the
# effective vision context per step.
action_policy_droid_nano["model"]["config"]["max_num_tokens_after_packing"] = -1


for _item in [action_policy_droid_nano]:
_name = [k for k, v in globals().items() if v is _item][0]
cs.store(group="experiment", package="_global_", name=_name, node=_item)
67 changes: 61 additions & 6 deletions cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import Any

from torch.utils.data import Dataset
from torch.utils.data import Dataset, IterableDataset, get_worker_info

from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset
from cosmos_framework.data.vfm.action.transforms import ActionTransformPipeline
Expand All @@ -37,13 +37,63 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> dict[str, Any]:
return self._transform(self._dataset[idx], self._resolution)

def get_shuffle_blocks(self):
"""Delegate to the inner DROIDLeRobotDataset (per-episode/segment flat-index blocks)."""
return self._dataset.get_shuffle_blocks()



class ActionIterableShuffleDataset(IterableDataset):
"""Streaming view of a map-style ``ActionSFTDataset``.

Each ``(rank, worker)`` is assigned a DISJOINT subset of episodes (sharded over
``shard_world_size * num_workers``), shuffles its episode ORDER, and streams the
windows WITHIN each episode sequentially -> within-rank batch diversity (the N
workers of a rank stream N different episodes) AND cross-rank diversity, while
keeping reads sequential (I/O locality + COW; no RandomSampler random-access OOM).
Re-shuffles each epoch and streams indefinitely (the trainer stops at ``max_iter``).

``shard_world_size`` / ``shard_rank`` are set by ``RankPartitionedDataLoader``.
"""

def __init__(self, dataset: "ActionSFTDataset", seed: int = 42):
super().__init__()
self._dataset = dataset
self._seed = int(seed)
self.shard_world_size = 1
self.shard_rank = 0

def __len__(self) -> int: # informational only; iteration is infinite
return len(self._dataset)

def __iter__(self):
import torch

blocks = self._dataset.get_shuffle_blocks()
wi = get_worker_info()
wid = wi.id if wi is not None else 0
nw = wi.num_workers if wi is not None else 1
global_shard = int(self.shard_rank) * nw + wid
total_shards = max(1, int(self.shard_world_size) * nw)
epoch = 0
while True:
g = torch.Generator()
g.manual_seed(self._seed + epoch) # same permutation across all (rank,worker) -> disjoint shard
order = torch.randperm(len(blocks), generator=g).tolist()
for b in order[global_shard::total_shards]:
Comment thread
lfengad marked this conversation as resolved.
start, length = blocks[b]
for idx in range(start, start + length):
yield self._dataset[idx]
epoch += 1


def get_action_droid_sft_dataset(
*,
root: str,
fps: float = 15.0,
chunk_length: int = 32,
action_space: str = "joint_pos",
mode: str = "policy",
use_state: bool = True,
action_normalization: str | None = None,
viewpoint: str = "concat_view",
Expand All @@ -58,16 +108,18 @@ def get_action_droid_sft_dataset(
append_duration_fps_timestamps: bool = True,
append_resolution_info: bool = True,
append_idle_frames: bool = False,
) -> ActionSFTDataset:
"""Build the DROID action SFT dataset (joint_pos 8D by default), matching the
internal ``droid_lerobot_8b_policy`` data: ``action_space='joint_pos'`` +
``use_state`` (8D, raw/un-normalized), concat_view, chunk_length 32."""
iterable_shuffle: bool = False,
episode_shuffle_seed: int = 42,
) -> Dataset:
"""Build the DROID action SFT dataset: ``action_space='joint_pos'`` (8D) +
``use_state`` (raw/un-normalized), concat_view, chunk_length 32."""
dataset = DROIDLeRobotDataset(
root=root,
fps=fps,
chunk_length=chunk_length,
viewpoint=viewpoint,
action_space=action_space,
mode=mode,
use_state=use_state,
action_normalization=action_normalization,
use_image_augmentation=use_image_augmentation,
Expand All @@ -83,4 +135,7 @@ def get_action_droid_sft_dataset(
append_resolution_info=append_resolution_info,
append_idle_frames=append_idle_frames,
)
return ActionSFTDataset(dataset, transform, resolution)
sft = ActionSFTDataset(dataset, transform, resolution)
if iterable_shuffle:
return ActionIterableShuffleDataset(sft, seed=episode_shuffle_seed)
return sft
53 changes: 34 additions & 19 deletions cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@
"right": "observation.image.exterior_image_2_left",
}
_STATE_FEATURE = "observation.state.cartesian_position"
# joint_pos (8D = 7 arm joints + gripper) features, matching the internal
# DROIDLeRobotDataset(action_space="joint_pos", use_state=...). These are
# absolute joint commands/states (no normalization is applied for joint_pos,
# matching the internal canonical run which leaves action_normalization=None).
# joint_pos (8D = 7 arm joints + gripper) features. These are absolute joint
# commands/states (no normalization is applied for joint_pos; action_normalization=None).
_JOINT_ACTION_FEATURE = "action.joint_position" # [7] commanded joints
_ACTION_GRIPPER_FEATURE = "action.gripper_position" # [1] commanded gripper
_JOINT_STATE_FEATURE = "observation.state.joint_positions" # [7] observed joints
Expand Down Expand Up @@ -67,8 +65,7 @@ class DROIDLeRobotDataset(Dataset):
gripper(1)]``, quantile-normalized (the v1.2 midtrain default).
* ``action_space="joint_pos"``: 8D ``[joint(7), gripper(1)]`` absolute joint
commands, NOT normalized, with ``use_state=True`` prepending the initial
observed joint+gripper state → ``(chunk+1, 8)`` — matching the internal
``Cosmos3-Nano-Policy-DROID`` post-training run.
observed joint+gripper state → ``(chunk+1, 8)``.
Filter dictionaries, temporal-segment validation, and image augmentation from
the production wrapper are intentionally omitted.
"""
Expand All @@ -78,7 +75,7 @@ def __init__(
root: str = "/path/to/cosmos3_action_datasets/droid_plus_lerobot_640x360_20260412",
fps: float = 15.0,
chunk_length: int = 16,
mode: str = "joint",
mode: str = "policy",
pose_convention: PoseConvention = "backward_framewise",
tolerance_s: float = 2e-4,
viewpoint: Viewpoint = "concat_view",
Expand Down Expand Up @@ -112,15 +109,15 @@ def __init__(
# to all views with shared params (temporally + cross-view consistent). Lazy-built.
self._use_image_augmentation = bool(use_image_augmentation)
self._image_augmentor: T.Compose | None = None
# Keep-ranges window filter (internal use_filter_dict): restrict training windows
# to curated active segments, dropping idle/non-task frames. Off by default; the
# keep-ranges JSON is supplied via filter_dict_path (an internal data artifact).
# Keep-ranges window filter: restrict training windows to curated active segments,
# dropping idle/non-task frames. Off by default; the keep-ranges JSON is supplied
# via filter_dict_path (published at HF KarlP/droid).
self._use_filter_dict = bool(use_filter_dict)
self._filter_dict_path = filter_dict_path
if self._use_filter_dict and not self._filter_dict_path:
raise ValueError("use_filter_dict=True requires filter_dict_path")
# joint_pos trains on raw 8D joint values (the internal canonical run
# leaves action_normalization=None); ee_pose keeps quantile normalization.
# joint_pos trains on raw 8D joint values (action_normalization=None);
# ee_pose keeps quantile normalization.
self._action_normalization = None if action_space == "joint_pos" else action_normalization
self._domain_id = get_domain_id("droid_lerobot")
self._norm_stats: dict[str, torch.Tensor] | None = None
Expand Down Expand Up @@ -185,7 +182,7 @@ def __init__(
self._valid_cum = np.cumsum(np.maximum(0, ep_counts - self._chunk_length)).astype(np.int64)

# Keep-ranges filter: build a per-segment index over only the kept windows.
# Mirrors internal _append_index_records (use_filter_dict): the filter dict maps a
# Keep-ranges (use_filter_dict): the filter dict maps a
# gs:// trajectory key -> list of [start, end] frame ranges; keep windows whose start
# is in [max(start,0), min(end-chunk, valid)). Episodes absent from the dict are dropped.
if self._use_filter_dict:
Expand Down Expand Up @@ -309,11 +306,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
)

def _build_joint_action(self, observation_rows: list[dict[str, Any]]) -> torch.Tensor:
"""8D joint-position action ``[joint(7), gripper(1)]`` over the chunk, matching
the internal ``action_space='joint_pos'``. The window is ``chunk+1`` frames:
``row[0]`` is the initial observed state (prepended when ``use_state``), and
``rows[1:]`` are the ``chunk`` commanded actions. Gripper is flipped (1 - g).
No normalization is applied (internal canonical run uses raw joint values)."""
"""8D joint-position action ``[joint(7), gripper(1)]`` over the chunk. The window
is ``chunk+1`` frames: ``row[0]`` is the initial observed state (prepended when
``use_state``), and ``rows[1:]`` are the ``chunk`` commanded actions. Gripper is
flipped (1 - g). No normalization is applied (raw joint values)."""
action_rows = observation_rows[1:]
joints = np.asarray([r[_JOINT_ACTION_FEATURE] for r in action_rows], dtype=np.float32) # [chunk, 7]
gripper = np.asarray([r[_ACTION_GRIPPER_FEATURE] for r in action_rows], dtype=np.float32).reshape(-1, 1)
Expand Down Expand Up @@ -350,7 +346,7 @@ def _load_concat_video(
# Random crop+rescale (spatial jitter) + color jitter, BEFORE the concat.
# All three views are stacked so one sampled set of params is applied
# uniformly across every frame and view (temporally + cross-view consistent),
# while each __getitem__ resamples. Matches the internal DROID recipe.
# while each __getitem__ resamples.
if self._image_augmentor is None:
_, _, h, w = wrist.shape
self._image_augmentor = T.Compose(
Expand Down Expand Up @@ -454,6 +450,25 @@ def _load_norm_stats(self) -> dict[str, torch.Tensor]:
}
return self._norm_stats


def get_shuffle_blocks(self) -> list[tuple[int, int]]:
"""Per-episode (or per kept-segment, when use_filter_dict) flat-index blocks
``(start, length)``. ActionIterableShuffleDataset shuffles the ORDER of these blocks
and shards them disjointly across ranks, while keeping windows *within* a block
sequential -> decorrelates batches across ranks without random-access I/O (preserves
locality + copy-on-write memory sharing across workers)."""
import numpy as _np

cum = self._seg_cum if self._use_filter_dict else self._valid_cum
blocks: list[tuple[int, int]] = []
prev = 0
for c in _np.asarray(cum).tolist():
c = int(c)
if c > prev:
blocks.append((prev, c - prev))
prev = c
return blocks

def __len__(self) -> int:
if self._use_filter_dict:
return int(self._seg_cum[-1]) if self._seg_cum.size else 0
Expand Down
Loading