From 3addb7184421815595c2459d14dd0add551cef5d Mon Sep 17 00:00:00 2001 From: lfengad Date: Tue, 9 Jun 2026 14:54:19 +0800 Subject: [PATCH 01/16] fix: add ci test download cache with refined retry mechanism (#26) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary CI tests download input assets (e.g. action/video inputs) over the network, and these intermittently fail with transient gateway errors (502/503/504), flaking the run. This PR makes those downloads robust and avoids re-fetching the same assets every run. ### Changes - **Backoff retry** (`inference/common/args.py`): wrap each input download in an outer retry with exponential backoff + jitter (6 attempts, env-overridable via `COSMOS_DOWNLOAD_*`). Permanent errors (400/401/403/404) fail fast. - **Opt-in download cache**: when `COSMOS_DOWNLOAD_CACHE_DIR` is set, downloads are cached by URL and reused across runs; unset → unchanged behavior. Concurrent writers use an atomic move. - **CI wiring** (`gpu-tests.yml`): the `unittest` and `inference-smoke` jobs point at a shared persistent cache dir (`$RUNNER_WORKSPACE/cosmos_input_cache`, outside the repo tree so cleanup keeps it), reused across runs and PRs on the same runner. ### Impact - Production/local behavior unchanged: cache is off unless the env var is set; retry is transparent on success and only adds resilience on failure. - Only new persisted artifact is the cache dir; replaces previously-leaked `/tmp` temp dirs in those jobs. 
--------- Co-authored-by: Claude Opus 4.8 (1M context) --- .github/workflows/gpu-tests.yml | 5 ++ cosmos_framework/inference/common/args.py | 77 ++++++++++++++++++- .../inference/common/args_test.py | 7 +- 3 files changed, 84 insertions(+), 5 deletions(-) diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml index 20e984f..9245dd2 100644 --- a/.github/workflows/gpu-tests.yml +++ b/.github/workflows/gpu-tests.yml @@ -121,9 +121,11 @@ jobs: # One inference call over t2vs (+sound), action policy, and forward_dynamics; checks each output. # MAX_GPUS defaults to 8. -s streams the live process log. + # Reuse the same input-asset cache dir as the unittest job. - name: Nano inference smoke (t2vs + action policy + forward_dynamics, 8 GPU) run: | export LD_LIBRARY_PATH= + export COSMOS_DOWNLOAD_CACHE_DIR="$RUNNER_WORKSPACE/cosmos_input_cache" uv run --all-extras --group=cu128-train python -m pytest -v -s \ tests/nano_inference_smoke_test.py --num-gpus=8 --levels=2 -o addopts= @@ -193,9 +195,12 @@ jobs: # is absent (via RunIf / pytest.skip guards), so this is green without # internal credentials; provide the credential file on the runner to # exercise them. New tests are picked up automatically (no markers/lists). + # Cache downloaded input assets in a persistent dir (outside the repo tree, + # so the cleanup step keeps it) and reuse it across runs. 
- name: Unit tests run: | export LD_LIBRARY_PATH= + export COSMOS_DOWNLOAD_CACHE_DIR="$RUNNER_WORKSPACE/cosmos_input_cache" uv run --all-extras --group=cu128-train python -m pytest -v -s \ cosmos_framework/ -o addopts= diff --git a/cosmos_framework/inference/common/args.py b/cosmos_framework/inference/common/args.py index 8d7cd5a..8ac4ec7 100644 --- a/cosmos_framework/inference/common/args.py +++ b/cosmos_framework/inference/common/args.py @@ -3,11 +3,14 @@ import contextlib import glob +import hashlib import itertools import json import os +import random import re import tempfile +import time from abc import ABC, abstractmethod from pathlib import Path from typing import ( @@ -52,7 +55,47 @@ MEDIA_EXTENSIONS = IMAGE_EXTENSIONS + VIDEO_EXTENSIONS +# Retry transient download errors with exponential backoff (env-overridable). +_DOWNLOAD_MAX_ATTEMPTS = int(os.environ.get("COSMOS_DOWNLOAD_MAX_ATTEMPTS", "6")) +_DOWNLOAD_BACKOFF_BASE_S = float(os.environ.get("COSMOS_DOWNLOAD_BACKOFF_S", "4")) +_DOWNLOAD_BACKOFF_CAP_S = float(os.environ.get("COSMOS_DOWNLOAD_BACKOFF_CAP_S", "60")) + +# Statuses not worth retrying. 
+_PERMANENT_HTTP_MARKERS = ("400 Bad Request", "401 Unauthorized", "403 Forbidden", "404 Not Found") + + +def _is_permanent_download_error(exc: BaseException) -> bool: + if type(exc).__name__ in {"NotFoundError", "PermissionError"}: + return True + msg = str(exc) + return any(marker in msg for marker in _PERMANENT_HTTP_MARKERS) + + def _download_file_url(url: str, path: Path): + """Download ``url`` to ``path``, retrying transient network/server errors.""" + from cosmos_framework.utils import log + + last_exc: BaseException | None = None + for attempt in range(1, _DOWNLOAD_MAX_ATTEMPTS + 1): + try: + _download_file_url_once(url, path) + return + except Exception as exc: # noqa: BLE001 + last_exc = exc + if _is_permanent_download_error(exc) or attempt == _DOWNLOAD_MAX_ATTEMPTS: + break + delay = min(_DOWNLOAD_BACKOFF_CAP_S, _DOWNLOAD_BACKOFF_BASE_S * 2 ** (attempt - 1)) + delay += random.uniform(0, delay * 0.25) # jitter + log.warning( + f"Download attempt {attempt}/{_DOWNLOAD_MAX_ATTEMPTS} for {url} failed " + f"({type(exc).__name__}: {exc}); retrying in {delay:.1f}s." + ) + time.sleep(delay) + + raise RuntimeError(f"Failed to download {url} after {_DOWNLOAD_MAX_ATTEMPTS} attempt(s)") from last_exc + + +def _download_file_url_once(url: str, path: Path): if "huggingface.co" in url: _download_file_hf(url, path) else: @@ -85,6 +128,33 @@ def _download_file_hf(url: str, path: Path): f.write(chunk) +def _resolve_url_download(url: str, name: str) -> Path: + """Fetch ``url`` to a local file and return its path. + + When ``COSMOS_DOWNLOAD_CACHE_DIR`` is set, downloads are cached there by URL + and reused across runs; otherwise a fresh temp dir is used per download. 
+ """ + cache_root = os.environ.get("COSMOS_DOWNLOAD_CACHE_DIR") + if not cache_root: + local_path = Path(tempfile.mkdtemp()) / name + _download_file_url(url, local_path) + return local_path + + cache_dir = Path(cache_root) + digest = hashlib.sha256(url.encode()).hexdigest()[:16] + cache_path = cache_dir / f"{digest}-{name}" + done_marker = Path(f"{cache_path}.done") + if cache_path.exists() and done_marker.exists(): + return cache_path + cache_dir.mkdir(parents=True, exist_ok=True) + # Atomic move so concurrent writers never observe a half-written file. + tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp") + _download_file_url(url, tmp_path) + os.replace(tmp_path, cache_path) + done_marker.write_text(url) + return cache_path + + def _download_file(url: str, path: Path): if "://" not in url and Path(url).resolve() == path.resolve(): return @@ -94,10 +164,9 @@ def _download_file(url: str, path: Path): return if "://" in url: - # Download to a temporary directory and symlink to the final path. - # This keeps the output directory small. - local_path = Path(tempfile.TemporaryDirectory(delete=False).name) / path.name - _download_file_url(url, local_path) + # Download (optionally via the persistent cache) and symlink to the final + # path. This keeps the output directory small. 
+ local_path = _resolve_url_download(url, path.name) else: local_path = Path(url) diff --git a/cosmos_framework/inference/common/args_test.py b/cosmos_framework/inference/common/args_test.py index 73facad..ae56295 100644 --- a/cosmos_framework/inference/common/args_test.py +++ b/cosmos_framework/inference/common/args_test.py @@ -5,6 +5,8 @@ import os from pathlib import Path +import pytest + from cosmos_framework.inference.args import DEFAULT_CHECKPOINT, DEFAULT_CHECKPOINT_NAME from cosmos_framework.inference.common.args import CheckpointConfig, CheckpointOverrides, download_file @@ -13,7 +15,10 @@ } -def test_download_file(tmp_path: Path): +def test_download_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + # Disable the URL cache; this test asserts each download is independent. + monkeypatch.delenv("COSMOS_DOWNLOAD_CACHE_DIR", raising=False) + download_url_1 = ( "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/2b17a2413bd86b2cf9b03823637108851e4ddf2d/inputs/vision/robot_153.jpg" ) From 55c6276119503f5024fd400e4283e768913f1243 Mon Sep 17 00:00:00 2001 From: yy-code-nv Date: Tue, 9 Jun 2026 20:44:42 +0800 Subject: [PATCH 02/16] Remove unused code; Add golden for GB200 (#28) - Remove unused code for config.py (used for old toml config system) - Add vision_sft_nano golden for GB200 --- cosmos_framework/utils/config.py | 62 ++------------------------------ tests/launch_regression_test.py | 18 +++++++--- 2 files changed, 16 insertions(+), 64 deletions(-) diff --git a/cosmos_framework/utils/config.py b/cosmos_framework/utils/config.py index 33fc9a9..441d1b8 100644 --- a/cosmos_framework/utils/config.py +++ b/cosmos_framework/utils/config.py @@ -518,70 +518,14 @@ def validate(self) -> None: assert self.job.name != "" -def _reload_make_config_for_registrations(root_cfg: "Config") -> None: - """Run ``make_config()`` once for import-time registrations (same intent as loading ``config.py``). 
- - Deserialized YAML/TOML instantiates attrs ``Config`` with ``__class__.__module__`` set to the - module that defines the class (often ``…defaults.config``). ``load_callable`` splits on the - last dot, which turns that into ``import …defaults`` + ``getattr(..., "config")`` — the - ``defaults.config`` submodule, which often has no ``make_config``. The entrypoint with - ``make_config`` is typically the sibling module ``….config``. - """ - from cosmos_framework.utils.serialization import load_callable - - cls_mod = type(root_cfg).__module__ - - def _try_make_config(mod: object) -> bool: - mk = getattr(mod, "make_config", None) - if mk is None: - return False - _ = mk() - return True - - if cls_mod.endswith(".defaults.config"): - sibling = cls_mod[: -len(".defaults.config")] + ".config" - try: - if _try_make_config(importlib.import_module(sibling)): - return - except ModuleNotFoundError: - pass - - try: - if _try_make_config(load_callable(cls_mod)): - return - except (AssertionError, AttributeError, ModuleNotFoundError): - pass - - try: - if _try_make_config(importlib.import_module(cls_mod)): - return - except ModuleNotFoundError: - pass - - raise AttributeError( - f"No make_config() found for Config class module {cls_mod!r}. " - "YAML/TOML export must match a tree whose Python package exposes make_config " - "(e.g. cosmos_framework.configs.base.vlm.config next to cosmos_framework.configs.base.vlm.defaults.config)." - ) - - def load_config(config_path: str, opts: list[str], enable_one_logger: bool = False) -> Config: - from cosmos_framework.utils.serialization import from_toml, from_yaml + from cosmos_framework.utils.serialization import from_yaml, load_callable t1 = time.monotonic_ns() if config_path.endswith(".yaml"): config = from_yaml(config_path) - # Import-time registrations (dataloaders, experiments, …): YAML root class - # typically lives in …defaults.config; make_config() is on sibling …config. 
- _reload_make_config_for_registrations(config) - - from cosmos_framework.utils.config_helper import override - - config = override(config, opts, remove_defaults=True) - elif config_path.endswith(".toml"): - config = from_toml(config_path) - # TOML is the same exported structured schema as YAML. - _reload_make_config_for_registrations(config) + # for registration of dataloaders, etc. + _ = load_callable(config.__module__).make_config() from cosmos_framework.utils.config_helper import override diff --git a/tests/launch_regression_test.py b/tests/launch_regression_test.py index e1b65f3..0a1a9f2 100644 --- a/tests/launch_regression_test.py +++ b/tests/launch_regression_test.py @@ -431,10 +431,8 @@ def h100_inputs(tmp_path_factory: pytest.TempPathFactory): as a Hydra backbone override. """ arch = _detect_arch() - if arch == "gb200": - pytest.skip("gb200 inputs not in OSS layout; goldens kept for historical reference only.") - if arch != "h100": - pytest.skip(f"no regression goldens for GPU arch {arch!r}; only h100 supported") + if arch not in ("h100", "gb200"): + pytest.skip(f"no regression goldens for GPU arch {arch!r}; only h100/gb200 supported") if shutil.which("uvx") is None: pytest.skip("uvx not on PATH -- required to prepare regression inputs") @@ -466,7 +464,7 @@ def _make_dcp() -> Path: _ensure("BASE_CHECKPOINT_PATH", _make_dcp) try: - yield {"vlm_model_path": os.environ["MODEL_PATH"]} + yield {"vlm_model_path": os.environ.get("MODEL_PATH", "")} finally: for var in set_vars: os.environ.pop(var, None) @@ -588,6 +586,16 @@ def test_launch_regression_8gpu(spec_key: str, tmp_path: Path, h100_inputs: dict 39.70305, 48.52226, 52.18334, 22.77521, 25.06970, ], }, + # Captured 2026-06-09 on a 4 × NVIDIA GB200 node with seed 42 against the + # current TOML-config pipeline (inputs prepared in-test by ``h100_inputs``, + # which now also serves gb200). 
Runs under ``--deterministic`` so loss + # reproduces bit-exact across all 10 iters; loss matches the h100 nano + # series within ~1e-3. grad_norm is non-det because ``compile.enabled=true`` + # makes the all-rank reduction not bit-exact, so None (same as h100). + "vision_sft_nano": { + "loss": [0.2269, 0.2181, 0.2026, 0.2309, 0.2178, 0.273, 0.2871, 0.2164, 0.2059, 0.264], + "grad_norm": None, + }, }, # Recaptured 2026-06-03 on a 4 × NVIDIA H100 80GB HBM3 node with seed 42 and # transformers==4.57.6. VLM model is ``Qwen/Qwen3-VL-8B-Instruct``; inputs are From bbda32163c1600307aa492567c9b3b2fae0886cc Mon Sep 17 00:00:00 2001 From: LiangHao Date: Thu, 11 Jun 2026 10:31:03 +0800 Subject: [PATCH 03/16] Add DROID action-policy SFT recipe (Cosmos3-Nano, joint_pos) (#24) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Adds a **DROID action-policy SFT recipe** for `nvidia/Cosmos3-Nano`, mirroring the internal `droid_lerobot_8b` policy run, so users can post-train the action-generation + action heads on DROID (LeRobot v3.0) data. ## What's included - **`data/vfm/action/datasets/droid_lerobot_dataset.py`** — DROID LeRobot dataset: compact columnar load + episode-aware windowing (replaces an eager full-table materialization), plus `joint_pos` (8D: 7 joints + gripper) and `use_state` support. - **`data/vfm/action/datasets/action_sft_dataset.py`** (new) — `get_action_droid_sft_dataset(...)` wrapping the dataset through `ActionTransformPipeline`. - **`configs/.../action/posttrain_config/action_policy_droid_nano.py`** (new) — registered `action_policy_droid_nano` experiment (Cosmos3-Nano / 8B MoT): optimizer trains gen+action heads (5× LR on action heads), `LambdaLinear` schedule, count-based batch, res480, `encode_exact_durations=[33]` (chunk 32 → 33 frames). 
- **`checkpoint/dcp.py`** — EMA warm-start: when `keys_to_skip_loading` contains a `net_ema` prefix (so `net_ema.*` is skipped on load), copy `net.*` into `net_ema.*` after loading the base weights, so EMA starts from the loaded init rather than from its construction-time (possibly random) values. - **`examples/toml/sft_config/action_policy_droid_{nano,repro}.toml`** — 1-GPU smoke + scaled (res480) configs. - **`examples/launch_sft_action_policy_droid.sh`** + **`docs/action_policy_droid_posttrain.md`** — runnable launcher and walkthrough. ## Validation End-to-end on H200: - **1 node / 8×H200** — dry-run + training at res480, `max_samples_per_batch=32` (64 OOMs at 139 GiB; internal used 128 on GB200). - **2 nodes / 16 ranks** — HSDP `shard 8 × replicate 2`, `TRAIN_EXIT=0`. - Recipe faithful to internal `droid_lerobot_8b`: lr 1e-4 / betas / wd, 5× action-head LR, `LambdaLinear`, shift `{256:3,480:5,720:10}`, `concat_view`, `chunk_length=32`. ## Notes - Count-based batch (`max_samples_per_batch`, `max_sequence_length=None`) lives in the experiment Python — TOML cannot express `null`, and the loader only overrides keys present in the TOML. - Base checkpoint: convert `nvidia/Cosmos3-Nano` → DCP and pass via `BASE_CHECKPOINT_PATH`; action heads init fresh (skipped on load). 
--------- Signed-off-by: Hao Liang Co-authored-by: Claude Opus 4.8 Co-authored-by: lfengad Co-authored-by: Yu-Wei Chao <82182961+ychao-nvidia@users.noreply.github.com> --- cosmos_framework/checkpoint/dcp.py | 14 + cosmos_framework/configs/base/config.py | 1 + .../action_policy_droid_nano.py | 229 ++++++++++++++++ .../vfm/action/datasets/action_sft_dataset.py | 86 ++++++ .../action/datasets/droid_lerobot_dataset.py | 246 ++++++++++++++++-- docs/action_policy_droid_posttrain.md | 96 +++++++ examples/launch_sft_action_policy_droid.sh | 39 +++ .../sft_config/action_policy_droid_repro.toml | 53 ++++ sitecustomize.py | 14 + 9 files changed, 750 insertions(+), 28 deletions(-) create mode 100644 cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py create mode 100644 cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py create mode 100644 docs/action_policy_droid_posttrain.md create mode 100755 examples/launch_sft_action_policy_droid.sh create mode 100644 examples/toml/sft_config/action_policy_droid_repro.toml diff --git a/cosmos_framework/checkpoint/dcp.py b/cosmos_framework/checkpoint/dcp.py index 9fa64bb..4e036c9 100644 --- a/cosmos_framework/checkpoint/dcp.py +++ b/cosmos_framework/checkpoint/dcp.py @@ -539,6 +539,20 @@ def load( "Ensure the model has net_ema submodule." ) _state_dict[sd_key] = _state_dict[key_ema] + elif warm_start and any(str(s).startswith("net_ema") for s in self.keys_to_skip_loading): + # Only when net_ema.* is explicitly skipped on load (e.g. an HF->DCP + # init from convert_model_to_dcp that has only net.*): the skipped + # net_ema.* keep build_net() construction values (random init when + # vlm_config.pretrained_weights.enabled=False), which would seed EMA + # from random weights -> copy net.* -> net_ema.* so EMA starts from the + # freshly-loaded init. When net_ema.* IS loaded (e.g. a training DCP + # that carries a trained EMA), do NOT clobber it. + log.info("Warm start: net_ema. 
skipped on load -> resetting net_ema = net.") + for sd_key in list(_state_dict.keys()): + if sd_key.startswith("net."): + key_ema = "net_ema." + sd_key.removeprefix("net.") + if key_ema in _state_dict: + _state_dict[key_ema] = _state_dict[sd_key] results = _model_wrapper.load_state_dict(_state_dict) if results is not None: if len(results.missing_keys) > 0: diff --git a/cosmos_framework/configs/base/config.py b/cosmos_framework/configs/base/config.py index f7fd14e..391df1c 100644 --- a/cosmos_framework/configs/base/config.py +++ b/cosmos_framework/configs/base/config.py @@ -98,4 +98,5 @@ def make_config() -> Config: # Register shipped experiments explicitly. import cosmos_framework.configs.base.experiment.sft.vision_sft_nano # noqa: F401 import cosmos_framework.configs.base.experiment.sft.vision_sft_super # noqa: F401 + import cosmos_framework.configs.base.experiment.action.posttrain_config.action_policy_droid_nano # noqa: F401 return c diff --git a/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py new file mode 100644 index 0000000..4c3a6fb --- /dev/null +++ b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""``action_policy_droid_nano`` — Cosmos3-Nano DROID action policy SFT recipe. + +Mirrors the vision SFT stack (PackingDataLoader + RankPartitionedDataLoader), +but feeds the DROID action dataset (``joint_pos`` 8D + ``use_state``, raw/ +un-normalized — same as the internal ``droid_lerobot_8b_policy`` run) through +``ActionTransformPipeline``, and trains the generation + action heads from the +public ``nvidia/Cosmos3-Nano`` base. 
+ +Usage (1 node, 8 GPU):: + + DROID_ROOT=/path/to/droid_lerobot_640x360/success \\ + BASE_CHECKPOINT_PATH= \\ + WAN_VAE_PATH= \\ + torchrun --nproc_per_node=8 -m cosmos_framework.scripts.train \\ + --sft-toml examples/toml/sft_config/action_policy_droid_repro.toml +""" + +import copy + +from hydra.core.config_store import ConfigStore + +from cosmos_framework.utils.lazy_config import LazyCall as L +from cosmos_framework.utils.lazy_config import LazyDict + +from cosmos_framework.configs.base.experiment.sft.models.nano_model_config import NANO_MODEL_CONFIG +from cosmos_framework.data.vfm.joint_dataloader import ( + PackingDataLoader, + RankPartitionedDataLoader, +) +from cosmos_framework.data.vfm.action.datasets.action_sft_dataset import get_action_droid_sft_dataset + +cs = ConfigStore.instance() + + +action_policy_droid_nano = LazyDict( + dict( + defaults=[ + {"override /model": "mot_fsdp"}, + {"override /data_train": None}, + {"override /data_val": None}, + # Match internal droid_lerobot_8b_policy: apex FusedAdam with fp32 + # master_weights + eps 1e-8. adamw + fused + eps 1e-6 (bf16, no fp32 + # master) under-steps the small 5x-lr action heads and leaves the action + # loss on a noisy high plateau; an exact-match forward/optimizer test + # confirmed the convergence gap was the optimizer, not the model. 
+ {"override /optimizer": "fusedadamw"}, + {"override /scheduler": "lambdalinear"}, # matches internal droid_lerobot_8b (was lambdacosine) + {"override /checkpoint": "s3"}, + { + "override /callbacks": [ + "basic", + "optimization", + "job_monitor", + ] + }, + {"override /ema": "power"}, + {"override /tokenizer": "wan2pt2_tokenizer"}, + {"override /sound_tokenizer": None}, + {"override /cluster": None}, + {"override /vlm_config": None}, + {"override /ckpt_type": "dcp"}, + "_self_", + ], + job=dict( + project="cosmos3", + group="action_sft", + name="action_policy_droid_nano", + wandb_mode="disabled", + ), + model=dict( + config=copy.deepcopy(NANO_MODEL_CONFIG), # action_gen=True, max_action_dim=64 + ), + optimizer=dict( + betas=[0.9, 0.99], + eps=1.0e-08, + fused=True, # popped by build_optimizer for FusedAdam (fused by construction) + # Generation + action heads (mirrors internal droid_lerobot_8b_policy). + keys_to_select=[ + "moe_gen", + "time_embedder", + "vae2llm", + "llm2vae", + "action2llm", + "llm2action", + "action_modality_embed", + ], + lr=2.0e-04, # matches internal droid_lerobot_8b_policy submit (--lr 2e-4) + lr_multipliers={ + "action2llm": 5.0, + "llm2action": 5.0, + "action_modality_embed": 5.0, + }, + optimizer_type="FusedAdam", + weight_decay=0.05, + ), + scheduler=dict( + lr_scheduler_type="LambdaLinear", # matches internal droid_lerobot_8b (was LambdaCosine) + cycle_lengths=[100], # smoke: 100 iters (real run sets via TOML) + f_max=[0.4], + f_min=[0.0], + f_start=[0.0], + verbosity_interval=0, + warm_up_steps=[0], + ), + trainer=dict( + distributed_parallelism="fsdp", + grad_accum_iter=1, + logging_iter=1, + max_iter=100, # smoke + max_val_iter=None, + run_validation=False, + run_validation_on_start=False, + save_zero_checkpoint=False, + seed=42, + timeout_period=999999999, + validation_iter=100, + compile_config=dict(recompile_limit=8, use_duck_shape=False), + cudnn=dict(benchmark=True, deterministic=False), + ddp=dict(broadcast_buffers=True, 
find_unused_parameters=False, static_graph=True), + grad_scaler_args=dict(enabled=False), + callbacks=dict( + dataloader_speed=dict(every_n=100, save_s3=False, step_size=1), + device_monitor=dict( + every_n=200, log_memory_detail=True, save_s3=False, step_size=1, upload_every_n_mul=5 + ), + grad_clip=dict(clip_norm=1.0, force_finite=True), # matches internal make_8b + heart_beat=dict(every_n=200, save_s3=False, step_size=1, update_interval_in_minute=20), + iter_speed=dict(every_n=1, hit_thres=50, save_s3=False, save_s3_every_log_n=500), + low_precision=dict(update_iter=1), + manual_gc=dict(every_n=5, gc_level=1, warm_up=1), + param_count=dict(save_s3=False), + skip_nan_step=dict(max_consecutive_nan=100), + training_stats=dict(log_freq=100), + ), + ), + checkpoint=dict( + broadcast_via_filesystem=False, + dcp_async_mode_enabled=False, + enable_gcs_patch_in_boto3=True, + keys_not_to_resume=[], + # Skip net_ema. (→ EMA warm-start copies net→net_ema, see dcp.py) AND the + # action heads, so they init fresh from the base — matches internal + # make_8b _DEFAULT_KEYS_TO_SKIP (Cosmos3-Nano's action heads are not + # DROID-policy-trained). 
+ keys_to_skip_loading=[ + "net_ema.", + "action2llm", + "llm2action", + "action_modality_embed", + "action_pos_embed", + ], + load_ema_to_reg=False, + load_path="???", # Cosmos3-Nano DCP dir; supply via TOML/env + load_training_state=False, + only_load_scheduler_state=False, + save_iter=100, + strict_resume=False, # base init: tolerate key set differences + verbose=True, + hf_export=dict( + enabled=False, + export_every_n=1, + hf_repo_id=None, + upload_to_object_store=dict(bucket="", credentials="", enabled=False), + ), + jit=dict(device="cuda", dtype="bfloat16", enabled=False, input_shape=None, strict=True), + load_from_object_store=dict(bucket="", credentials="", enabled=False), + save_to_object_store=dict(bucket="", credentials="", enabled=False), + ), + dataloader_train=L(PackingDataLoader)( + audio_sample_rate=48000, + dataset_name="action_droid", + max_samples_per_batch=128, # count-based batch (matches internal res480 8B) + max_sequence_length=None, # None disables token packing (TOML can't express null) + patch_spatial=2, + sound_latent_fps=0, + tokenizer_spatial_compression_factor=16, + tokenizer_temporal_compression_factor=4, + dataloader=L(RankPartitionedDataLoader)( + batch_size=1, + in_order=False, + num_workers=4, + persistent_workers=True, + pin_memory=True, + prefetch_factor=4, + sampler=None, + datasets=dict( + droid=dict( + ratio=1, + dataset=L(get_action_droid_sft_dataset)( + root="${oc.env:DROID_ROOT}", + fps=15.0, + chunk_length=32, + action_space="joint_pos", + use_state=True, + use_image_augmentation=True, # SR boost (random crop+rescale + color jitter) + # Keep-ranges window filter (drops idle/non-task frames). Off by default; + # the launcher sets use_filter_dict=True + filter_dict_path for internal parity. 
+ use_filter_dict=False, + filter_dict_path=None, + action_normalization=None, + viewpoint="concat_view", # wrist 480p (top) + L/R shoulder 320x180 (bottom) + resolution="480", # 640x360 data @ 480p (matches internal res480 run) + max_action_dim="${model.config.max_action_dim}", + cfg_dropout_rate=0.1, + tokenizer_config="${model.config.vlm_config.tokenizer}", + ), + ), + ), + ), + ), + dataloader_val=None, + upload_reproducible_setup=False, + ), + flags={"allow_objects": True}, +) + + +# chunk_length=32 → 33 observation frames; pin the VAE encode duration to match +# (internal used [17] for chunk_length=16). Set post-construction so it lands on +# the deep-copied NANO_MODEL_CONFIG.tokenizer. +action_policy_droid_nano["model"]["config"]["tokenizer"]["encode_exact_durations"] = [33] + + +for _item in [action_policy_droid_nano]: + _name = [k for k, v in globals().items() if v is _item][0] + cs.store(group="experiment", package="_global_", name=_name, node=_item) diff --git a/cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py b/cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py new file mode 100644 index 0000000..5d5b74e --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/action_sft_dataset.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Map-style action SFT dataset: ``DROIDLeRobotDataset`` → ``ActionTransformPipeline``. + +The base ``DROIDLeRobotDataset.__getitem__`` returns the raw sample +(``video``/``action``/``ai_caption``/``viewpoint``/``mode``/``domain_id``/ +``idle_frames``). The model expects each sample to be passed through +``ActionTransformPipeline`` (spatial resize/pad, text tokenization, action +padding to ``max_action_dim``, and ``sequence_plan`` construction). 
This thin +wrapper composes the two so the experiment can hand a single map-style dataset +to ``RankPartitionedDataLoader`` (mirroring how the vision recipe uses +``get_sft_dataset``). +""" +from __future__ import annotations + +from typing import Any + +from torch.utils.data import Dataset + +from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset +from cosmos_framework.data.vfm.action.transforms import ActionTransformPipeline + + +class ActionSFTDataset(Dataset): + """Wraps a map-style action dataset and applies ``ActionTransformPipeline`` per sample.""" + + def __init__(self, dataset: Dataset, transform: ActionTransformPipeline, resolution: str | int | None): + super().__init__() + self._dataset = dataset + self._transform = transform + self._resolution = resolution + + def __len__(self) -> int: + return len(self._dataset) + + def __getitem__(self, idx: int) -> dict[str, Any]: + return self._transform(self._dataset[idx], self._resolution) + + +def get_action_droid_sft_dataset( + *, + root: str, + fps: float = 15.0, + chunk_length: int = 32, + action_space: str = "joint_pos", + use_state: bool = True, + action_normalization: str | None = None, + viewpoint: str = "concat_view", + use_image_augmentation: bool = False, + use_filter_dict: bool = False, + filter_dict_path: str | None = None, + resolution: str | int = "256", + max_action_dim: int = 64, + tokenizer_config: dict | None = None, + cfg_dropout_rate: float = 0.1, + append_viewpoint_info: bool = True, + append_duration_fps_timestamps: bool = True, + append_resolution_info: bool = True, + append_idle_frames: bool = False, +) -> ActionSFTDataset: + """Build the DROID action SFT dataset (joint_pos 8D by default), matching the + internal ``droid_lerobot_8b_policy`` data: ``action_space='joint_pos'`` + + ``use_state`` (8D, raw/un-normalized), concat_view, chunk_length 32.""" + dataset = DROIDLeRobotDataset( + root=root, + fps=fps, + chunk_length=chunk_length, + 
viewpoint=viewpoint, + action_space=action_space, + use_state=use_state, + action_normalization=action_normalization, + use_image_augmentation=use_image_augmentation, + use_filter_dict=use_filter_dict, + filter_dict_path=filter_dict_path, + ) + transform = ActionTransformPipeline( + tokenizer_config=tokenizer_config, + cfg_dropout_rate=cfg_dropout_rate, + max_action_dim=max_action_dim, + append_viewpoint_info=append_viewpoint_info, + append_duration_fps_timestamps=append_duration_fps_timestamps, + append_resolution_info=append_resolution_info, + append_idle_frames=append_idle_frames, + ) + return ActionSFTDataset(dataset, transform, resolution) diff --git a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py index 07d0ebe..204df69 100644 --- a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py +++ b/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py @@ -14,11 +14,12 @@ import pyarrow.parquet as pq import torch import torch.nn.functional as F +import torchvision.transforms as T from lerobot.datasets.video_utils import decode_video_frames from torch.utils.data import Dataset from cosmos_framework.data.vfm.action.action_normalization import load_action_stats, normalize_action -from cosmos_framework.data.vfm.action.action_spec import Gripper, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.action_spec import Gripper, Joint, Pos, Rot, build_action_spec from cosmos_framework.data.vfm.action.domain_utils import get_domain_id from cosmos_framework.data.vfm.action.pose_utils import ( build_abs_pose_from_components, @@ -35,6 +36,17 @@ "right": "observation.image.exterior_image_2_left", } _STATE_FEATURE = "observation.state.cartesian_position" +# joint_pos (8D = 7 arm joints + gripper) features, matching the internal +# DROIDLeRobotDataset(action_space="joint_pos", use_state=...). 
These are +# absolute joint commands/states (no normalization is applied for joint_pos, +# matching the internal canonical run which leaves action_normalization=None). +_JOINT_ACTION_FEATURE = "action.joint_position" # [7] commanded joints +_ACTION_GRIPPER_FEATURE = "action.gripper_position" # [1] commanded gripper +_JOINT_STATE_FEATURE = "observation.state.joint_positions" # [7] observed joints +_GRIPPER_STATE_FEATURE = "observation.state.gripper_position" # [1] observed gripper +# Columns whose parquet dtype is a list (need to_pylist -> stacked array). +_LIST_COLUMNS = {_STATE_FEATURE, _JOINT_ACTION_FEATURE, _JOINT_STATE_FEATURE} +_ACTION_SPACES = ("ee_pose", "joint_pos") # 90-degree clockwise rotation about the Z axis in the local frame. This matches # the production DROID wrapper conversion from Franka panda_link8 to OpenCV. @@ -48,12 +60,17 @@ class DROIDLeRobotDataset(Dataset): - """DROID Action dataset matching the v1.2 midtrain config default. - - The supported action layout is 10D ``[pos_delta(3), rot6d_delta(6), gripper(1)]``. - Unsupported branches from the production wrapper, such as joint-space - actions, filter dictionaries, temporal-segment validation, state prefixing, - and image augmentation, are intentionally omitted. + """DROID Action dataset. + + Two action layouts: + * ``action_space="ee_pose"`` (default): 10D ``[pos_delta(3), rot6d_delta(6), + gripper(1)]``, quantile-normalized (the v1.2 midtrain default). + * ``action_space="joint_pos"``: 8D ``[joint(7), gripper(1)]`` absolute joint + commands, NOT normalized, with ``use_state=True`` prepending the initial + observed joint+gripper state → ``(chunk+1, 8)`` — matching the internal + ``Cosmos3-Nano-Policy-DROID`` post-training run. + Filter dictionaries and image augmentation are opt-in constructor flags (off by + default); temporal-segment validation from the production wrapper is omitted. 
""" def __init__( @@ -65,12 +82,22 @@ def __init__( pose_convention: PoseConvention = "backward_framewise", tolerance_s: float = 2e-4, viewpoint: Viewpoint = "concat_view", + action_space: str = "ee_pose", + use_state: bool = False, + action_normalization: str | None = "quantile", + use_image_augmentation: bool = False, + use_filter_dict: bool = False, + filter_dict_path: str | None = None, ) -> None: super().__init__() if pose_convention != "backward_framewise": raise NotImplementedError("This minimal DROID dataset only supports backward_framewise pose deltas.") if viewpoint != "concat_view": raise NotImplementedError("This minimal DROID dataset only supports concat_view.") + if action_space not in _ACTION_SPACES: + raise NotImplementedError(f"action_space must be one of {_ACTION_SPACES}, got {action_space!r}.") + if use_state and action_space != "joint_pos": + raise NotImplementedError("use_state is only supported with action_space='joint_pos'.") self._fps = float(fps) self._dt = 1.0 / self._fps @@ -79,6 +106,22 @@ def __init__( self._pose_convention = pose_convention self._tolerance_s = float(tolerance_s) self._viewpoint = viewpoint + self._action_space = action_space + self._use_state = bool(use_state) + # Per-sample image augmentation (random crop+rescale + color jitter), applied + # to all views with shared params (temporally + cross-view consistent). Lazy-built. + self._use_image_augmentation = bool(use_image_augmentation) + self._image_augmentor: T.Compose | None = None + # Keep-ranges window filter (internal use_filter_dict): restrict training windows + # to curated active segments, dropping idle/non-task frames. Off by default; the + # keep-ranges JSON is supplied via filter_dict_path (an internal data artifact). 
+ self._use_filter_dict = bool(use_filter_dict) + self._filter_dict_path = filter_dict_path + if self._use_filter_dict and not self._filter_dict_path: + raise ValueError("use_filter_dict=True requires filter_dict_path") + # joint_pos trains on raw 8D joint values (the internal canonical run + # leaves action_normalization=None); ee_pose keeps quantile normalization. + self._action_normalization = None if action_space == "joint_pos" else action_normalization self._domain_id = get_domain_id("droid_lerobot") self._norm_stats: dict[str, torch.Tensor] | None = None @@ -93,14 +136,84 @@ def __init__( int(row["task_index"]): str(row["task"]) for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist() } - self._rows = sorted( - ( - row - for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")) - for row in pq.read_table(path).to_pylist() - ), - key=lambda row: int(row["index"]), - ) + # Compact, lazy frame index. Materializing every frame as a Python dict + # (``sorted(... pq.read_table(path).to_pylist() ...)``) does not scale: + # the full DROID success shard is ~18M frames, which is tens of GB of + # dicts plus an 18M-element Python sort at construction, and each + # DataLoader worker faults in its own copy. Instead we read only the + # columns the sample builder needs into contiguous numpy arrays + # (~1 GB total) -- read-only after init, so worker forks share them + # copy-on-write. 
+ if action_space == "joint_pos": + feature_cols = [_JOINT_ACTION_FEATURE, _ACTION_GRIPPER_FEATURE, _JOINT_STATE_FEATURE, _GRIPPER_STATE_FEATURE] + else: + feature_cols = [_STATE_FEATURE, _ACTION_GRIPPER_FEATURE] + columns = ["index", "episode_index", "task_index", "timestamp", *feature_cols] + index_parts, episode_parts, task_parts, ts_parts = [], [], [], [] + feature_parts: dict[str, list] = {c: [] for c in feature_cols} + for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")): + table = pq.read_table(path, columns=columns) + index_parts.append(table["index"].to_numpy()) + episode_parts.append(table["episode_index"].to_numpy()) + task_parts.append(table["task_index"].to_numpy()) + ts_parts.append(table["timestamp"].to_numpy()) + for c in feature_cols: + if c in _LIST_COLUMNS: + feature_parts[c].append(np.asarray(table[c].to_pylist(), dtype=np.float32)) + else: + feature_parts[c].append(np.asarray(table[c].to_numpy(), dtype=np.float32)) + order = np.argsort(np.concatenate(index_parts).astype(np.int64), kind="stable") + self._row_episode = np.concatenate(episode_parts).astype(np.int64)[order] + self._row_task = np.concatenate(task_parts).astype(np.int64)[order] + self._row_timestamp = np.concatenate(ts_parts).astype(np.float64)[order] + # Per-feature arrays keyed by parquet column name (read-only after init). + self._feat = { + c: np.concatenate(feature_parts[c], axis=0).astype(np.float32)[order] for c in feature_cols + } + + # Group frames into episodes and keep only within-episode chunk windows. + # The global frame index is ordered by episode in LeRobot v3, so episodes + # are contiguous blocks once sorted by ``index``. The previous code sliced + # the flat row list (``rows[idx : idx + chunk + 1]``) with no boundary + # guard, so ~one chunk of samples per episode silently mixed two episodes; + # restricting to in-episode windows yields ``total - n_episodes * chunk`` + # valid samples (matching the production dataset). 
+ assert np.all(np.diff(self._row_episode) >= 0), "episode_index is not contiguous after sorting by frame index" + ep_vals, ep_starts, ep_counts = np.unique(self._row_episode, return_index=True, return_counts=True) + self._ep_vals = ep_vals.astype(np.int64) + self._ep_starts = ep_starts.astype(np.int64) + self._valid_cum = np.cumsum(np.maximum(0, ep_counts - self._chunk_length)).astype(np.int64) + + # Keep-ranges filter: build a per-segment index over only the kept windows. + # Mirrors internal _append_index_records (use_filter_dict): the filter dict maps a + # gs:// trajectory key -> list of [start, end] frame ranges; keep windows whose start + # is in [max(start,0), min(end-chunk, valid)). Episodes absent from the dict are dropped. + if self._use_filter_dict: + with open(self._filter_dict_path) as f: + filter_dict = json.load(f) + seg_ep_pos, seg_win_start, seg_len = [], [], [] + for pos in range(len(self._ep_vals)): + valid = int(max(0, ep_counts[pos] - self._chunk_length)) + if valid <= 0: + continue + ep_id = str(self._episodes[int(self._ep_vals[pos])]["episode_id"]) + key = ( + f"gs://xembodiment_data/r2d2/r2d2-data-full/{ep_id}/recordings/" + f"MP4--gs://xembodiment_data/r2d2/r2d2-data-full/{ep_id}/trajectory.h5" + ) + ranges = filter_dict.get(key) + if ranges is None: + continue + for s, e in ranges: + ws = max(int(s), 0) + we = min(int(e) - self._chunk_length, valid) + if we - ws > 0: + seg_ep_pos.append(pos) + seg_win_start.append(ws) + seg_len.append(we - ws) + self._seg_ep_pos = np.asarray(seg_ep_pos, dtype=np.int64) + self._seg_win_start = np.asarray(seg_win_start, dtype=np.int64) + self._seg_cum = np.cumsum(seg_len).astype(np.int64) if seg_len else np.zeros(0, dtype=np.int64) @property def fps(self) -> float: @@ -124,28 +237,62 @@ def domain_id(self) -> int: @property def action_dim(self) -> int: - return 10 + return 8 if self._action_space == "joint_pos" else 10 + + def _action_spec(self): + if self._action_space == "joint_pos": + return 
build_action_spec(Joint(n=7, label="joint"), Gripper()) + return build_action_spec(Pos(), Rot("rot6d"), Gripper()) @property def action_names(self) -> list[str]: - return build_action_spec(Pos(), Rot("rot6d"), Gripper()).names + return self._action_spec().names def _choose_mode(self) -> str: if self._mode == "joint": return random.choice(_MODE_CHOICES) return self._mode + def _window_rows(self, start: int, stop: int, episode_index: int) -> list[dict[str, Any]]: + """Reconstruct the per-frame dicts the sample builder consumes for the + half-open frame window ``[start, stop)`` from the compact column arrays. + ``start``/``stop`` are guaranteed to lie within a single episode.""" + return [ + { + "episode_index": episode_index, + "task_index": int(self._row_task[j]), + "timestamp": float(self._row_timestamp[j]), + **{c: self._feat[c][j] for c in self._feat}, + } + for j in range(start, stop) + ] + def __getitem__(self, idx: int) -> dict[str, Any]: mode = self._choose_mode() idx = int(idx) - first_row = self._rows[idx] - episode = self._episodes[int(first_row["episode_index"])] - - observation_rows = self._rows[idx : idx + self._chunk_length + 1] - action_rows = observation_rows[: self._chunk_length] + # Map the flat sample index to a within-episode frame window. 
+ if self._use_filter_dict: + seg = int(np.searchsorted(self._seg_cum, idx, side="right")) + base = int(self._seg_cum[seg - 1]) if seg > 0 else 0 + ep = int(self._seg_ep_pos[seg]) + start = int(self._ep_starts[ep]) + int(self._seg_win_start[seg]) + (idx - base) + else: + ep = int(np.searchsorted(self._valid_cum, idx, side="right")) + prev = int(self._valid_cum[ep - 1]) if ep > 0 else 0 + start = int(self._ep_starts[ep]) + (idx - prev) + episode_index = int(self._ep_vals[ep]) + episode = self._episodes[episode_index] + + observation_rows = self._window_rows(start, start + self._chunk_length + 1, episode_index) video = self._load_concat_video(episode, observation_rows) - raw_action, initial_pose = self._build_raw_action(observation_rows, action_rows) + if self._action_space == "joint_pos": + raw_action = self._build_joint_action(observation_rows) + extras: dict[str, Any] = {} + else: + action_rows = observation_rows[: self._chunk_length] + raw_action, initial_pose = self._build_raw_action(observation_rows, action_rows) + extras = {"initial_pose": initial_pose} task = self._tasks[int(observation_rows[0]["task_index"])] ai_caption = random.choice(task.split(" | ")) @@ -154,13 +301,32 @@ def __getitem__(self, idx: int) -> dict[str, Any]: video=video, action=raw_action, ai_caption=ai_caption, - initial_pose=initial_pose, additional_view_description=( "The top row is from the wrist-mounted camera. " "The bottom row contains two horizontally concatenated third-person perspective views of the scene from opposite sides, with the robot visible." ), + **extras, ) + def _build_joint_action(self, observation_rows: list[dict[str, Any]]) -> torch.Tensor: + """8D joint-position action ``[joint(7), gripper(1)]`` over the chunk, matching + the internal ``action_space='joint_pos'``. The window is ``chunk+1`` frames: + ``row[0]`` is the initial observed state (prepended when ``use_state``), and + ``rows[1:]`` are the ``chunk`` commanded actions. Gripper is flipped (1 - g). 
+ No normalization is applied (internal canonical run uses raw joint values).""" + action_rows = observation_rows[1:] + joints = np.asarray([r[_JOINT_ACTION_FEATURE] for r in action_rows], dtype=np.float32) # [chunk, 7] + gripper = np.asarray([r[_ACTION_GRIPPER_FEATURE] for r in action_rows], dtype=np.float32).reshape(-1, 1) + gripper = 1.0 - gripper + action = np.concatenate([joints, gripper], axis=-1) # [chunk, 8] + if self._use_state: + init = observation_rows[0] + init_joint = np.asarray(init[_JOINT_STATE_FEATURE], dtype=np.float32) # [7] + init_gripper = np.asarray([1.0 - float(init[_GRIPPER_STATE_FEATURE])], dtype=np.float32) # [1] + initial_state = np.concatenate([init_joint, init_gripper])[None, :] # [1, 8] + action = np.concatenate([initial_state, action], axis=0) # [chunk + 1, 8] + return torch.from_numpy(action).float() + def _load_concat_video( self, episode: dict[str, Any], @@ -179,6 +345,25 @@ def _load_concat_video( wrist = frames_by_view["wrist"] left = frames_by_view["left"] right = frames_by_view["right"] + + if self._use_image_augmentation: + # Random crop+rescale (spatial jitter) + color jitter, BEFORE the concat. + # All three views are stacked so one sampled set of params is applied + # uniformly across every frame and view (temporally + cross-view consistent), + # while each __getitem__ resamples. Matches the internal DROID recipe. 
+ if self._image_augmentor is None: + _, _, h, w = wrist.shape + self._image_augmentor = T.Compose( + [ + T.RandomCrop((int(h * 0.95), int(w * 0.95))), + T.Resize((h, w), antialias=True), + T.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08), + ] + ) + n, m = wrist.shape[0], wrist.shape[0] + left.shape[0] + combined = self._image_augmentor(torch.cat([wrist, left, right], dim=0)) + wrist, left, right = combined[:n], combined[n:m], combined[m:] + _, _, h_w, w_w = wrist.shape half_h, half_w = h_w // 2, w_w // 2 left = F.interpolate(left, size=(half_h, half_w), mode="bilinear", align_corners=False) @@ -233,7 +418,7 @@ def _build_result( ai_caption: str, **extras: Any, ) -> dict[str, Any]: - spec = build_action_spec(Pos(), Rot("rot6d"), Gripper()) + spec = self._action_spec() idle_frames = compute_idle_frames( action, spec, @@ -243,12 +428,15 @@ def _build_result( joint_threshold=5e-3 / self._fps, min_streak=3, ) - normalized_action = normalize_action(action, "quantile", self._load_norm_stats()) + if self._action_normalization is None: + out_action = action + else: + out_action = normalize_action(action, self._action_normalization, self._load_norm_stats()) formatted_video = (video * 255.0).clamp(0.0, 255.0).to(torch.uint8).permute(1, 0, 2, 3) return { "ai_caption": ai_caption, "video": formatted_video, - "action": normalized_action, + "action": out_action, "conditioning_fps": torch.tensor(self._fps, dtype=torch.long), "mode": mode, "domain_id": torch.tensor(self._domain_id, dtype=torch.long), @@ -267,4 +455,6 @@ def _load_norm_stats(self) -> dict[str, torch.Tensor]: return self._norm_stats def __len__(self) -> int: - return max(0, len(self._rows) - self._chunk_length) + if self._use_filter_dict: + return int(self._seg_cum[-1]) if self._seg_cum.size else 0 + return int(self._valid_cum[-1]) if self._valid_cum.size else 0 diff --git a/docs/action_policy_droid_posttrain.md b/docs/action_policy_droid_posttrain.md new file mode 100644 index 0000000..ac5c3fb 
--- /dev/null +++ b/docs/action_policy_droid_posttrain.md @@ -0,0 +1,96 @@ + + + +# DROID Action-Policy Post-Training — `Cosmos3-Nano-Policy-DROID` + +> **STATUS: recipe ships in this package.** The registered experiment, the DROID action +> dataset class (`joint_pos` 8D + `use_state`), and the EMA warm-start fix land here. +> To run it you supply two external inputs — a prepared **DROID LeRobot v3.0** dataset and +> a **DCP base checkpoint** converted from `nvidia/Cosmos3-Nano` (see +> [Inputs you provide](#inputs-you-provide)). Validated end-to-end on H200: 1 node / 8 GPU +> and 2 nodes / 16 ranks (HSDP). + +Fine-tune `Cosmos3-Nano` (the 8B MoT) into an action policy on the **DROID LeRobot** dataset, +reproducing `Cosmos3-Nano-Policy-DROID`. The policy is initialized from **`nvidia/Cosmos3-Nano`** +(public Hugging Face repo) and trained with absolute joint-position actions + proprioceptive +state at 480p. + +______________________________________________________________________ + +## Inputs you provide + +This package ships the training stack — the registered `action_policy_droid_nano` experiment, +the DROID action dataset class with the recipe knobs (`action_space=joint_pos`, `use_state`, +`concat_view`), and the EMA warm-start in `checkpoint/dcp.py`. Two inputs are external and must +be provided per environment: + +1. **Prepared DROID LeRobot v3.0 dataset** — the LeRobot v2.0→v3.0 conversion + success + filtering is run out-of-band (not yet in this repo). Point `DROID_ROOT` at the resulting + `…/droid_lerobot/success` directory (must contain `meta/info.json`). +2. **DCP base checkpoint** — convert `nvidia/Cosmos3-Nano` to DCP and point + `BASE_CHECKPOINT_PATH` at it (see [Full reproduction](#full-reproduction)). Action heads are + not loaded from it (they init fresh). + +## Dataset — DROID LeRobot + +To be released. 
+ +## Recipe + +| knob | value | +| ----------------- | ------------------------------------------------------------------- | +| init | `nvidia/Cosmos3-Nano` (public Hugging Face repo) | +| action space | `joint_pos` (absolute joint position, 8-D incl. gripper) | +| state | `use_state=true` (proprioception; valid only with `joint_pos`) | +| resolution | `480` | +| viewpoint / video | `concat_view` / `video_mode=null` | +| chunk length | `32` (tokenizer `encode_exact_durations=[33]`) | +| lr | `2e-4` | +| samples/rank | `32` (H200-safe; 64 OOMs at 480p). global batch = `32 × world_size` | +| eval | disabled for the reproduction run | + +## Full reproduction + +The OSS flow mirrors the other recipes (see [docs/training.md](./training.md)): + +```shell +# Step 1: prepare DROID LeRobot v3.0 success split -> $DATASET_PATH (see "Inputs you provide") + +# Step 2: convert the base checkpoint -> $BASE_CHECKPOINT_PATH +python -m cosmos_framework.scripts.convert_model_to_dcp \ + --checkpoint-path Cosmos3-Nano \ + -o $BASE_CHECKPOINT_PATH + +# Step 3: launch. The TOML selects the experiment + scalars; the dataset/action +# knobs come from the registered experiment. +export DATASET_PATH=/path/to/dataset/success +export BASE_CHECKPOINT_PATH=/path/to/base_checkpoint +export WAN_VAE_PATH=/path/to/Wan2.2_VAE.pth +export NPROC_PER_NODE=8 +bash examples/launch_sft_action_policy_droid.sh +``` + +The recipe TOML (`examples/toml/sft_config/action_policy_droid_repro.toml`) sets the scalar +knobs (`max_iter`, `save_iter`, parallelism, wandb, VAE path); `grad_clip` and the dataset/action knobs +(`joint_pos`, `use_state`, `concat_view`, 480p, chunk 32, count-based batch) live in the +registered `action_policy_droid_nano` experiment per the schema's design. For multi-node HSDP, +set `model.parallelism.data_parallel_replicate_degree = <num_nodes>` (intra-node shard stays 8). 
+ +## Smoke reproduction + +Config/import/data sanity without burning a full run: small node count + a handful of iters via +`--config-overrides "trainer.max_iter=10" "checkpoint.save_iter=10"` (and a small +`data_parallel_shard_degree`). Use this to validate the recipe composes and the dataset opens +before any large allocation. + +## Checkpoints + +- Saved every `save_iter` iters (1000 in the validated run) to the object store, at + `<output_root>/<project>/<group>/<name>/checkpoints/iter_<N>/`. +- The run is **resumable** from the latest checkpoint (re-launch with the same `job.name`). +- Export to HF safetensors via `cosmos_framework.scripts.export_model` (see [docs/training.md](./training.md)). + +## Non-goals + +- **Closed-loop / action evaluation is out of scope** for this reproduction pass (training + reproduction only), unless explicitly expanded. diff --git a/examples/launch_sft_action_policy_droid.sh b/examples/launch_sft_action_policy_droid.sh new file mode 100755 index 0000000..6ab0bc9 --- /dev/null +++ b/examples/launch_sft_action_policy_droid.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +# ============================================================================ +# Structured-TOML launch for DROID action-policy SFT on Cosmos3-Nano (8B MoT). +# Drives cosmos_framework.scripts.train against +# examples/toml/sft_config/action_policy_droid_repro.toml (selects the +# registered `action_policy_droid_nano` experiment; res480, joint_pos 8D + +# use_state, trains the generation + action heads). See +# docs/action_policy_droid_posttrain.md. 
+# +# Env vars (override for your filesystem): +# DATASET_PATH DROID LeRobot v3.0 success split (…/droid_lerobot/success) +# BASE_CHECKPOINT_PATH DCP of nvidia/Cosmos3-Nano (convert_model_to_dcp; see docs) +# WAN_VAE_PATH Wan2.2 VAE .pth (Wan-AI/Wan2.2-TI2V-5B) +# WANDB_API_KEY for online logging (TOML wandb_mode="online") +# NPROC_PER_NODE torchrun --nproc_per_node (default 8) +# +# Single-node smoke (config/data sanity, a few iters): +# TAIL_OVERRIDES=(trainer.max_iter=10 checkpoint.save_iter=10 \ +# dataloader_train.max_samples_per_batch=32) +# bash examples/launch_sft_action_policy_droid.sh +# +# Multi-node: launch on every worker; the trainer reads torchrun's +# --nnodes/--node_rank. For HSDP set +# model.parallelism.data_parallel_replicate_degree = <num_nodes> (shard stays 8). +# ============================================================================ + +TOML_FILE="examples/toml/sft_config/action_policy_droid_repro.toml" +: "${DATASET_PATH:=examples/data/lerobot_v30/droid_lerobot/success}" +: "${BASE_CHECKPOINT_PATH:=examples/checkpoints/Cosmos3-Nano}" + +# The experiment reads ${oc.env:DROID_ROOT}; bridge the launcher's DATASET_PATH to it. +export DROID_ROOT="${DROID_ROOT:-$DATASET_PATH}" + +EXTRA_DATASET_CHECK='[[ -f "$DROID_ROOT/meta/info.json" ]] || { echo "ERROR: missing $DROID_ROOT/meta/info.json (prepare DROID LeRobot v3.0 — see docs/action_policy_droid_posttraining.md)" >&2; exit 1; }' + +source "$(dirname "${BASH_SOURCE[0]}")/_sft_launcher_common.sh" diff --git a/examples/toml/sft_config/action_policy_droid_repro.toml b/examples/toml/sft_config/action_policy_droid_repro.toml new file mode 100644 index 0000000..a929a7c --- /dev/null +++ b/examples/toml/sft_config/action_policy_droid_repro.toml @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +# ============================================================================ +# DROID action-policy SFT — run config for the `action_policy_droid_nano` +# experiment. The recipe knobs (optimizer/lr, scheduler type, grad_clip, +# count-based batch, action-head skip-on-load, dataset knobs) live in the +# registered experiment; this file only sets run-level scalars (iters, ckpt +# cadence, parallelism shape, wandb, VAE path). +# +# Env required: +# DROID_ROOT=/path/to/droid_lerobot_640x360/success +# BASE_CHECKPOINT_PATH= +# WAN_VAE_PATH= +# IMAGINAIRE_OUTPUT_ROOT=/path/to/output_root # persist checkpoints +# ============================================================================ + +[job] +task = "vfm" +experiment = "action_policy_droid_nano" +project = "cosmos3_action" +group = "action_sft" +name = "action_policy_droid_repro" +wandb_mode = "online" + +[model] +precision = "bfloat16" + +[model.parallelism] +data_parallel_shard_degree = 8 # intra-node (8x H200); set replicate for multi-node HSDP +data_parallel_replicate_degree = 1 + +[model.activation_checkpointing] +mode = "full" +save_ops_regex = ["fmha"] + +[model.tokenizer] +vae_path = "${oc.env:WAN_VAE_PATH}" + +[scheduler] +cycle_lengths = [10000] # match max_iter + +[trainer] +max_iter = 10000 +logging_iter = 50 + +[checkpoint] +load_path = "${oc.env:BASE_CHECKPOINT_PATH}" +save_iter = 1000 + +# Per-rank batch is 128 in the experiment (res480; matches the reference recipe). +# 128 OOMs on a 139 GiB H200 — override at launch for H200: +# --opts dataloader_train.max_samples_per_batch=32 diff --git a/sitecustomize.py b/sitecustomize.py index 246bf6a..c185145 100644 --- a/sitecustomize.py +++ b/sitecustomize.py @@ -15,6 +15,20 @@ import os import sys +# Opt-in (COSMOS_DL_FILE_SYSTEM_SHARING=1): switch torch's DataLoader IPC from the +# default 'file_descriptor' strategy (which stages worker tensors in /dev/shm) to +# 'file_system'. 
On shm-constrained containers, large video batches overflow the +# small /dev/shm tmpfs and a worker dies mid-transfer -> the main process then sees +# "unable to open shared memory object ... No such file or directory". 'file_system' +# passes shm names, not fds (it still uses /dev/shm). Guarded so non-training processes never import torch. +if os.environ.get("COSMOS_DL_FILE_SYSTEM_SHARING") == "1": + try: + import torch.multiprocessing as _tmp + + _tmp.set_sharing_strategy("file_system") + except Exception: + pass + _DIR = os.environ.get("LOAD_TRACE_DIR", "") if _DIR: _TAG = os.environ.get("LOAD_TRACE_TAG", "default") From 016c96de846b0c2c17f2212045d269956262efeb Mon Sep 17 00:00:00 2001 From: Muneer Ali Date: Thu, 11 Jun 2026 11:40:49 +0530 Subject: [PATCH 04/16] Fix join_path type inconsistency: return Path when Path inputs are provided (#33) ## Summary `LocalBackend.join_path` accepted `Union[str, Path]` inputs but always returned `str` (via `os.path.join`), even when `Path` objects were passed. This violated the type contract and could cause `AttributeError` downstream. ## Changes - **local_backend.py**: Now checks if any input is a `Path` and returns `Path(result)` accordingly. Removed the stale TODO that acknowledged this issue. - **base_backend.py, easy_io.py, file_client.py**: Updated return type from `str` to `Union[str, Path]`. - **boto3_backend.py, msc_backend.py, http_backend.py**: Updated return type signature for consistency with the abstract base class. 
## Related Issue Closes #32 Co-authored-by: Maosheng Liao --- .../utils/easy_io/backends/base_backend.py | 2 +- .../utils/easy_io/backends/boto3_backend.py | 4 ++-- .../utils/easy_io/backends/http_backend.py | 2 +- .../utils/easy_io/backends/local_backend.py | 10 ++++++---- cosmos_framework/utils/easy_io/backends/msc_backend.py | 4 ++-- cosmos_framework/utils/easy_io/easy_io.py | 2 +- cosmos_framework/utils/easy_io/file_client.py | 4 ++-- 7 files changed, 15 insertions(+), 13 deletions(-) diff --git a/cosmos_framework/utils/easy_io/backends/base_backend.py b/cosmos_framework/utils/easy_io/backends/base_backend.py index 94484bb..1eb5009 100644 --- a/cosmos_framework/utils/easy_io/backends/base_backend.py +++ b/cosmos_framework/utils/easy_io/backends/base_backend.py @@ -70,7 +70,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: pass @abstractmethod - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> Union[str, Path]: pass @abstractmethod diff --git a/cosmos_framework/utils/easy_io/backends/boto3_backend.py b/cosmos_framework/utils/easy_io/backends/boto3_backend.py index 86b1ab0..ce4ba5f 100644 --- a/cosmos_framework/utils/easy_io/backends/boto3_backend.py +++ b/cosmos_framework/utils/easy_io/backends/boto3_backend.py @@ -284,7 +284,7 @@ def join_path( self, filepath: Union[str, Path], *filepaths: Union[str, Path], - ) -> str: + ) -> Union[str, Path]: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value @@ -294,7 +294,7 @@ def join_path( filepath (str or Path): Path to be concatenated. Returns: - str: The result after concatenation. + str or Path: The result after concatenation. 
Examples: >>> backend = Boto3Backend() diff --git a/cosmos_framework/utils/easy_io/backends/http_backend.py b/cosmos_framework/utils/easy_io/backends/http_backend.py index 32be908..593c9ac 100644 --- a/cosmos_framework/utils/easy_io/backends/http_backend.py +++ b/cosmos_framework/utils/easy_io/backends/http_backend.py @@ -112,7 +112,7 @@ def isdir(self, filepath: Union[str, Path]) -> bool: def isfile(self, filepath: Union[str, Path]) -> bool: raise NotImplementedError(f"isfile not supported in {self.name}") - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> Union[str, Path]: raise NotImplementedError(f"join_path not supported in {self.name}") @contextmanager diff --git a/cosmos_framework/utils/easy_io/backends/local_backend.py b/cosmos_framework/utils/easy_io/backends/local_backend.py index 80d05b8..7599314 100644 --- a/cosmos_framework/utils/easy_io/backends/local_backend.py +++ b/cosmos_framework/utils/easy_io/backends/local_backend.py @@ -187,7 +187,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return osp.isfile(filepath) - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> Union[str, Path]: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value @@ -197,7 +197,7 @@ def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> filepath (str or Path): Path to be concatenated. Returns: - str: The result of concatenation. + str or Path: The result of concatenation. Returns a Path if any input is a Path. 
Examples: >>> backend = LocalBackend() @@ -207,8 +207,10 @@ def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> >>> backend.join_path(filepath1, filepath2, filepath3) '/path/of/dir/dir2/path/of/file' """ - # TODO, if filepath or filepaths are Path, should return Path - return osp.join(filepath, *filepaths) + result = osp.join(filepath, *filepaths) + if isinstance(filepath, Path) or any(isinstance(p, Path) for p in filepaths): + return Path(result) + return result @contextmanager def get_local_path( diff --git a/cosmos_framework/utils/easy_io/backends/msc_backend.py b/cosmos_framework/utils/easy_io/backends/msc_backend.py index 83cc56f..72ac654 100644 --- a/cosmos_framework/utils/easy_io/backends/msc_backend.py +++ b/cosmos_framework/utils/easy_io/backends/msc_backend.py @@ -554,7 +554,7 @@ def join_path( self, filepath: Union[str, Path], *filepaths: Union[str, Path], - ) -> str: + ) -> Union[str, Path]: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value @@ -564,7 +564,7 @@ def join_path( filepath (str or Path): Path to be concatenated. Returns: - str: The result after concatenation. + str or Path: The result after concatenation. Examples: >>> backend = MSCBackend() diff --git a/cosmos_framework/utils/easy_io/easy_io.py b/cosmos_framework/utils/easy_io/easy_io.py index 686ae30..1963764 100644 --- a/cosmos_framework/utils/easy_io/easy_io.py +++ b/cosmos_framework/utils/easy_io/easy_io.py @@ -424,7 +424,7 @@ def join_path( backend_key (str, optional): The key to get the backend from register. Returns: - str: The result of concatenation. + str or Path: The result of concatenation. Returns a Path if any input is a Path. 
Examples: >>> filepath1 = '/path/of/dir1' diff --git a/cosmos_framework/utils/easy_io/file_client.py b/cosmos_framework/utils/easy_io/file_client.py index 650489f..4a33328 100644 --- a/cosmos_framework/utils/easy_io/file_client.py +++ b/cosmos_framework/utils/easy_io/file_client.py @@ -375,7 +375,7 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return self.client.isfile(filepath) - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> Union[str, Path]: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value @@ -385,7 +385,7 @@ def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> filepath (str or Path): Path to be concatenated. Returns: - str: The result of concatenation. + str or Path: The result of concatenation. Returns a Path if any input is a Path. """ return self.client.join_path(filepath, *filepaths) From b64eca2eebb220d30a8fec7b529a35ad54e4e666 Mon Sep 17 00:00:00 2001 From: Yu-Wei Chao <82182961+ychao-nvidia@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:34:32 -0700 Subject: [PATCH 05/16] Update Cosmos3-Nano-Policy-DROID server doc (#35) ### Summary Documents the Cosmos3-Nano-Policy-DROID policy server and aligns it with the [cosmos cookbook](https://github.com/NVIDIA/cosmos/blob/main/cookbooks/cosmos3/generator/action/run_policy_with_cosmos_framework.md) so the two stay consistent. Replaces the prior RoboLab/OpenPI WebSocket guide with a Docker-based server-client workflow. ### Changes - **`docs/action_policy_droid_server.md`** (new): full guide for serving Cosmos3-Nano-Policy-DROID via a policy **Server** that streams actions to a RoboLab **Client**, using a Docker-based setup (clone, build image, launch container). 
- **`docs/action_policy_robolab_server.md`** (removed): superseded by the above; the old uv/OpenPI WebSocket flow no longer matches the cookbook. - **`README.md`**: add a TOC entry, a Policy Server section, and a reference-table row linking the new guide. ### Impact Docs-only change; no code paths affected. Co-authored-by: Claude Opus 4.8 (1M context) --- README.md | 6 ++ docs/action_policy_droid_server.md | 109 +++++++++++++++++++++++++++ docs/action_policy_robolab_server.md | 63 ---------------- 3 files changed, 115 insertions(+), 63 deletions(-) create mode 100644 docs/action_policy_droid_server.md delete mode 100644 docs/action_policy_robolab_server.md diff --git a/README.md b/README.md index d40c8dd..eb53701 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ - [Training (Supervised Fine-Tuning)](./docs/training.md) - [JSONL Dataset](./docs/dataset_jsonl.md) - [Inference](./docs/inference.md) +- [Policy Server](./docs/action_policy_droid_server.md) - Reference - [Code Structure](./docs/code_structure.md) - [Environment Variables](./docs/environment_variables.md) @@ -82,6 +83,10 @@ python -m cosmos_framework.scripts.inference \ --seed=0 ``` +## Policy Server + +See [Policy Server](./docs/action_policy_droid_server.md) for the full guide. + ## Reference | Topic | What it covers | @@ -90,4 +95,5 @@ python -m cosmos_framework.scripts.inference \ | [Code Structure](./docs/code_structure.md) | Repository layout and a per-subpackage tour of `cosmos_framework/` — where each concern lives and where to add new code. | | [Training](./docs/training.md) | Launching multi-GPU and multi-node runs; parallelism strategies; mixed precision; resuming. | | [Inference (from a trained checkpoint)](./docs/inference.md) | Loading a trained checkpoint into one of the inference backends. | +| [Policy Server](./docs/action_policy_droid_server.md) | Running the server-client pipeline for Cosmos3-Nano-Policy-DROID. 
| | [FAQ](./docs/faq.md) | Troubleshooting (OOM, NCCL hangs, slow training), environment variables, and common pitfalls. | diff --git a/docs/action_policy_droid_server.md b/docs/action_policy_droid_server.md new file mode 100644 index 0000000..60e0042 --- /dev/null +++ b/docs/action_policy_droid_server.md @@ -0,0 +1,109 @@ +# Cosmos3-Nano-Policy-DROID Server + +[Cosmos3-Nano-Policy-DROID](https://huggingface.co/nvidia/Cosmos3-Nano-Policy-DROID) is served by a policy **Server** that streams actions to a **Client** driving a simulated or real robot. This example uses [`RoboLab`](https://github.com/NVlabs/RoboLab), a simulation benchmark for task-generalist policies, as the client. Start the server first, then connect the client. + + + +______________________________________________________________________ + +**Table of Contents** + +- [Policy Server](#policy-server) +- [Simulation Client](#simulation-client) + +______________________________________________________________________ + + + +## Policy Server + +First, clone [`cosmos-framework`](https://github.com/NVIDIA/cosmos-framework): + +```bash +git clone https://github.com/NVIDIA/cosmos-framework.git +cd cosmos-framework +``` + +Build the Docker image: + +```bash +docker build \ + -t cosmos-framework:latest \ + . +``` + +Set your Hugging Face token and launch the container, which installs the dependencies: + +```bash +# Set your Hugging Face token (https://huggingface.co/settings/tokens): +export HF_TOKEN= + +docker run \ + -it \ + -e HF_HOME=/workspace/.cache/huggingface \ + -e HF_TOKEN=$HF_TOKEN \ + --net host \ + --rm \ + --runtime nvidia \ + -v .:/workspace \ + -v /workspace/.venv \ + -v $HOME/.cache/huggingface:/root/.cache/huggingface \ + cosmos-framework:latest \ + bash -c '\ + uv sync \ + --all-extras \ + --group=cu130-train \ + --group=policy-server && \ + exec bash; \ + ' +``` + +The `--group=cu130-train` line targets CUDA 13.x drivers. 
On CUDA 12.x systems, replace it with `--group=cu128-train` (see the [Cosmos3 Cookbooks: Environment Setup](https://github.com/NVIDIA/cosmos/blob/main/cookbooks/cosmos3/README.md) for details). + +Inside the container, start the policy server: + +``` +python -m cosmos_framework.scripts.action_policy_server_robolab \ + --port 8000 +``` + +## Simulation Client + +Clone [`RoboLab`](https://github.com/NVlabs/RoboLab): + +```bash +git clone https://github.com/NVlabs/RoboLab.git +cd RoboLab +``` + +Build the Docker image: + +```bash +./docker/build_docker.sh latest +``` + +Launch the container: + +```bash +./docker/run_docker.sh latest +``` + +Run a task against the policy server. This opens a viewer window for real-time visualization of the simulation: + +```bash +python policies/cosmos3/run.py \ + --task BananaInBowlTask +``` + +To evaluate across multiple sub-environments in parallel in headless mode: + +```bash +python policies/cosmos3/run.py \ + --task BananaInBowlTask \ + --num-envs 10 \ + --headless +``` + +Example output: + + diff --git a/docs/action_policy_robolab_server.md b/docs/action_policy_robolab_server.md deleted file mode 100644 index 5c25a86..0000000 --- a/docs/action_policy_robolab_server.md +++ /dev/null @@ -1,63 +0,0 @@ -# Action Policy RoboLab Server - -Use the RoboLab server when your client uses the openpi-style WebSocket protocol. The server accepts msgpack-encoded observation dictionaries with NumPy arrays and returns msgpack-encoded dictionaries containing `action` and, when `--decode-video` is set, `video`. - -The server delegates WebSocket protocol handling to OpenPI's `WebsocketPolicyServer`. Install OpenPI's lightweight server package in the Cosmos3 environment before launching: - -```shell -uv sync --all-extras --group=cu130-train --group=policy-server -source .venv/bin/activate -``` - -The `policy-server` group installs `openpi-server`. 
Alternatively, install the full `Physical-Intelligence/openpi` package if you manage dependencies in a separate environment. Run `uv sync` once for a fresh checkout or when dependency groups change; you do not need to rerun it before every server launch. If your GPU driver does not support CUDA 13, use the matching CUDA group for your node, for example `--group=cu128-train`. - -## Start the Server - -The primary OSS flow is to serve the consolidated DROID policy checkpoint from Hugging Face. - -```shell -python -m cosmos_framework.scripts.action_policy_server_robolab --port 8000 -``` - -By default, the server uses the released DROID RoboLab serving config: `nvidia/Cosmos3-Nano-Policy-DROID` on `main`, `droid_lerobot`, `480` resolution, 15 FPS conditioning, 540x640 input images, 32 action steps, 8-dimensional `joint_pos` actions, guidance 3.0, 4 denoising steps, shift 5.0, and per-request NumPy RNG seeds initialized from seed 0. - -You can also pass a local consolidated safetensors directory produced by `cosmos_framework.scripts.export_model`. - -```shell -python -m cosmos_framework.scripts.action_policy_server_robolab \ - --checkpoint-path /path/to/consolidated/model \ - --port 8000 -``` - -For other checkpoints, set `--resolution`, `--action-chunk-size`, and seed behavior to the values used by that policy's serving config. - -The server sends an empty metadata dictionary when a client connects, matching openpi's WebSocket policy server. Each request is an observation dictionary with a `prompt`, `observation/image`, `observation/gripper_position`, and either joint state fields for `--action-space joint_pos` or end-effector pose fields for `--action-space midtrain`. 
- -## Common Options - -| Argument | Default | Description | -| -------------------------------------------------- | ------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------- | -| `--checkpoint-path` | `nvidia/Cosmos3-Nano-Policy-DROID` | `nvidia/Cosmos3-Nano-Policy-DROID`, `Cosmos3-Nano-Policy-DROID`, or a consolidated local safetensors checkpoint directory. | -| `--hf-revision` | `main` | Hugging Face revision to download for the public DROID policy checkpoint. | -| `--allow-dcp-checkpoint` | disabled | Permit direct DCP checkpoint loading for parity/debugging. | -| `--domain-name` | `droid_lerobot` | Action domain passed to `get_domain_id()`. | -| `--decode-video` | disabled | Return decoded rollout video as a uint8 NumPy array. | -| `--action-space` | `joint_pos` | Use `joint_pos` or `midtrain` RoboLab postprocessing. | -| `--resolution` | `480` | Action transform resolution. Use `480` for `nvidia/Cosmos3-Nano-Policy-DROID`. | -| `--conditioning-fps` | `15.0` | Conditioning FPS used by the action transform. | -| `--action-chunk-size` | `32` | Number of action steps to predict per request. Use `32` for `nvidia/Cosmos3-Nano-Policy-DROID`. | -| `--image-height` | `540` | Input observation image height before action transform preprocessing. | -| `--image-width` | `640` | Input observation image width before action transform preprocessing. | -| `--action-dim` | `8` | Raw action dimension. Use `8` for DROID `joint_pos`; set explicitly for other action spaces. | -| `--history-length` | `1` | Number of state/history action rows to trim from the generated action output. | -| `--guidance` | `3.0` | Classifier-free guidance scale. | -| `--num-steps` | `4` | Number of denoising steps. | -| `--shift` | `5.0` | UniPC sampler shift. | -| `--seed` | `0` | Base generation seed used to initialize the request RNG. 
| -| `--deterministic-seed` / `--no-deterministic-seed` | deterministic disabled | Use the same seed for every request, or advance a seeded RNG per request. The default advances the RNG for RoboLab parity with the internal server. | - -Health check: - -```shell -curl http://localhost:8000/healthz -``` From c6fee8df8be7c70a4bdeb8455d9d7b5f53ae6383 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" Date: Fri, 12 Jun 2026 12:55:31 +0800 Subject: [PATCH 06/16] Make T2I and T2V compatible with `transformers>=5.0.0`. (#27) Hi Cosmos team, We are fixing some CVE issues found in `transformers<=5.0.0`. This PR makes minor updates so the codebase works seamlessly with both pinned `4.57.6` and `>=5.0.0` for T2I and T2V. Signed-off-by: Hong-Yu Chiu --- cosmos_framework/model/vfm/mot/unified_mot.py | 2 +- cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cosmos_framework/model/vfm/mot/unified_mot.py b/cosmos_framework/model/vfm/mot/unified_mot.py index 4908e7a..a03e4aa 100644 --- a/cosmos_framework/model/vfm/mot/unified_mot.py +++ b/cosmos_framework/model/vfm/mot/unified_mot.py @@ -691,7 +691,7 @@ def _impl_init( ``Nemotron3DenseVLTextModel``. Sub-layer classes (MLP, RMSNorm, rotary embedding) are dispatched through ``layer_types``. 
""" - self.padding_idx = config.pad_token_id + self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py index c2ccc3f..46f6f98 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py @@ -320,7 +320,11 @@ def __init__(self, config: Qwen3VLTextConfig): self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + rope_type = self.rope_type + if rope_type not in ROPE_INIT_FUNCTIONS and rope_type == "default": + # transformers>=5 renamed "default" RoPE entry to "proportional". + rope_type = "proportional" + self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] self.mrope_section = ( config.rope_scaling.get("mrope_section", [24, 20, 20]) if config.rope_scaling is not None else [24, 20, 20] From 82f82293ffd8983651cd51d8191287da3973f534 Mon Sep 17 00:00:00 2001 From: Maosheng Liao Date: Fri, 12 Jun 2026 13:33:03 +0800 Subject: [PATCH 07/16] Refactor datapackerdataloader to be modular (#19) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Refactors the training data layer from the monolithic `DataPackerDataLoader` / `DataPacker` / `PackingIterableDataset` into a modular, four-role abstraction wired by a single loader. Behavior is preserved (golden-batch byte-identical to the legacy loader; resume validated live), and all existing recipes are migrated. DataDistributor (shard/shuffle/resume) → RawItemProcessor (raw item → one sample dict) → SampleBatcher (samples → batch groups) → BatchCollator (group → one batch dict) Each role is a small ABC with one required method; pick a built-in per slot or write your own.
`CosmosDataLoader` is a `torch.utils.data.DataLoader` subclass, so it drops into the existing training loop. ## What changed ### New dataflow package — `cosmos_framework/data/vfm/dataflow/` - **Loaders:** `CosmosDataLoader` (+ `batch_size=` sugar → `SimpleBatcher` + `DefaultBatchCollator`), `JointCosmosDataLoader` (ratio-weighted heterogeneous join). - **Distributors:** `IterableDistributor`, `MapDistributor` (resumable), `RankPartitionedDistributor`, `MixtureDistributor`. - **Processors:** `IdentityProcessor` (+ recipe-specific `VLMProcessor`, `VideoPhy2Processor`). - **Batchers:** `SimpleBatcher`, `PoolPackingBatcher`, `SequentialPackingBatcher`. - **Collators:** `DefaultBatchCollator`, `VFMListCollator` (+ recipe `VLMCollator`). ### Legacy removal - Deleted `data_packer.py`, `data_packer_dataloader.py`, `packing_iterable_dataset.py`, `test_dp_state_distributed.py` (+ old tests). ### Experiment migrations - VLM `llava_ov` (renamed from `llava_ov_datapacker`, streaming `IterableDistributor`). - VLM `videophy2_sft_nano`. - VFM: existing path unchanged; added `vision_sft_nano_mapstyle_dataloader` (new-loader variant). - Added `llava_ov_mapresume` — map-style (`load_dataset(streaming=False)` + `MapDistributor`) resumable example. ### Config / TOML - `PATH_REMAPS["vlm"]`: route `dataloader_train.{max_samples_per_batch, max_sequence_length}` → nested `batcher.{max_batch_size, max_tokens}`. ### Checkpoint / resume - Renamed the resume-state selector value `"data_packer"` → `"cosmos_dataloader"` and env prefix `DP_STATE_` → `COSMOS_DL_STATE_` (`DataLoaderStateCallback`, `JointDataLoaderStateCallback`, `MapDistributor`). On-disk format unchanged. ### CI / tests / docs - Updated `tests/launch_regression_test.py` + launch scripts for the `llava_ov` rename (golden loss keyed by `llava_ov`; workflow `-k llava_ov`). - Added golden-batch, resume, and per-role unit tests. - Replaced `docs/custom_dataset.md` with the `CosmosDataLoader` tutorial; removed `docs/dataflow.md`.
## Validation - **Golden-batch equality:** VLM / videophy2 / VFM batches byte-identical to the legacy loader. - **Live save→stop→resume** on `pre_exp012_llava_ov_mapresume` (8 dp ranks, `save_iter=100`): per-rank `input_ids` shapes identical across the resume boundary — **792 `(iter, rank)` keys, 0 mismatches** — and loss curves match. No duplicated/skipped samples on any rank. - **No CI risk:** the `llava_ov` golden recipe and its streaming data path are unchanged; the remap only affects the 3 VLM TOMLs, all of which compose cleanly onto a real `PoolPackingBatcher`. ## Hard invariant Dataloader resume + checkpoint saving must not regress. Held: resume is preserved through the existing `DataLoaderStateCallback`, with map-style fast-forward and the multi-sample contiguity guard, and validated end-to-end above. --------- Co-authored-by: Claude Opus 4.8 (1M context) --- .github/workflows/gpu-tests.yml | 8 +- conftest.py | 5 + .../callbacks/cosmos_dataloader_state.py | 224 ++++++ .../callbacks/dataloader_state.py | 112 +-- cosmos_framework/configs/base/config.py | 3 +- .../base/experiment/sft/vision_sft_nano.py | 64 +- .../base/vlm/experiment/dataflow_roles.py | 114 +++ .../llava_ov_datapacker_experiment.py | 376 --------- .../base/vlm/experiment/llava_ov_vlm.py | 292 +++++++ .../experiment/videophy2_dataflow_roles.py | 124 +++ .../base/vlm/experiment/videophy2_sft_nano.py | 219 +----- .../configs/toml_config/sft_config.py | 14 +- .../configs/toml_config/toml_config_helper.py | 6 +- cosmos_framework/data/vfm/data_packer.py | 107 --- .../data/vfm/data_packer_dataloader.py | 623 --------------- .../data/vfm/dataflow/__init__.py | 38 + cosmos_framework/data/vfm/dataflow/base.py | 65 ++ .../data/vfm/dataflow/batchers.py | 350 +++++++++ .../data/vfm/dataflow/collators.py | 211 +++++ .../data/vfm/dataflow/distributors.py | 175 +++++ .../data/vfm/dataflow/golden_vfm_test.py | 276 +++++++ cosmos_framework/data/vfm/dataflow/loader.py | 261 +++++++ 
.../data/vfm/dataflow/processors.py | 17 + .../data/vfm/dataflow/resume_test.py | 58 ++ .../data/vfm/packing_iterable_dataset.py | 276 ------- .../data/vfm/packing_iterable_dataset_test.py | 78 -- .../data/vfm/test_dp_state_distributed.py | 683 ---------------- docs/custom_dataset.md | 736 ++++++++---------- examples/launch_sft_llava_ov.sh | 8 +- ...launch_sft_llava_ov_mapstyle_dataloader.sh | 44 ++ examples/launch_sft_videophy2_nano.sh | 2 +- ...launch_sft_vision_nano_cosmosdataloader.sh | 28 + ...llava_ov_datapacker.toml => llava_ov.toml} | 17 +- .../llava_ov_mapstyle_dataloader.toml | 117 +++ .../toml/sft_config/videophy2_sft_nano.toml | 6 +- .../vision_sft_nano_mapstyle_dataloader.toml | 91 +++ tests/launch_regression_test.py | 16 +- 37 files changed, 2982 insertions(+), 2862 deletions(-) create mode 100644 cosmos_framework/callbacks/cosmos_dataloader_state.py create mode 100644 cosmos_framework/configs/base/vlm/experiment/dataflow_roles.py delete mode 100644 cosmos_framework/configs/base/vlm/experiment/llava_ov_datapacker_experiment.py create mode 100644 cosmos_framework/configs/base/vlm/experiment/llava_ov_vlm.py create mode 100644 cosmos_framework/configs/base/vlm/experiment/videophy2_dataflow_roles.py delete mode 100644 cosmos_framework/data/vfm/data_packer.py delete mode 100644 cosmos_framework/data/vfm/data_packer_dataloader.py create mode 100644 cosmos_framework/data/vfm/dataflow/__init__.py create mode 100644 cosmos_framework/data/vfm/dataflow/base.py create mode 100644 cosmos_framework/data/vfm/dataflow/batchers.py create mode 100644 cosmos_framework/data/vfm/dataflow/collators.py create mode 100644 cosmos_framework/data/vfm/dataflow/distributors.py create mode 100644 cosmos_framework/data/vfm/dataflow/golden_vfm_test.py create mode 100644 cosmos_framework/data/vfm/dataflow/loader.py create mode 100644 cosmos_framework/data/vfm/dataflow/processors.py create mode 100644 cosmos_framework/data/vfm/dataflow/resume_test.py delete mode 100644 
cosmos_framework/data/vfm/packing_iterable_dataset.py delete mode 100644 cosmos_framework/data/vfm/packing_iterable_dataset_test.py delete mode 100644 cosmos_framework/data/vfm/test_dp_state_distributed.py create mode 100755 examples/launch_sft_llava_ov_mapstyle_dataloader.sh create mode 100755 examples/launch_sft_vision_nano_cosmosdataloader.sh rename examples/toml/sft_config/{llava_ov_datapacker.toml => llava_ov.toml} (87%) create mode 100644 examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml create mode 100644 examples/toml/sft_config/vision_sft_nano_mapstyle_dataloader.toml diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml index 9245dd2..3a92614 100644 --- a/.github/workflows/gpu-tests.yml +++ b/.github/workflows/gpu-tests.yml @@ -10,7 +10,7 @@ # * training-smoke — Nano SFT pipeline (convert -> train 5 -> export -> t2i) # * generator-regression — vision_sft_nano loss vs goldens (4-GPU subset) # * inference-smoke — Nano multi-modality inference (t2vs + policy + forward_dynamics) -# * reasoner-regression — llava_ov_datapacker loss vs goldens (4-GPU subset) +# * reasoner-regression — llava_ov loss vs goldens (4-GPU subset) # # Requires: # * a self-hosted runner labelled [self-hosted, gpu, h200] with 8 GPUs, @@ -153,12 +153,12 @@ jobs: - name: Sync environment (cu128-train) run: uv sync --all-extras --group=cu128-train - # Reasoner (llava_ov_datapacker) loss vs the h100 goldens. -s streams the live log. - - name: Reasoner regression (llava_ov_datapacker, 4-GPU subset) + # Reasoner (llava_ov) loss vs the h100 goldens. -s streams the live log. 
+ - name: Reasoner regression (llava_ov, 4-GPU subset) run: | export LD_LIBRARY_PATH= uv run --all-extras --group=cu128-train python -m pytest -v -s \ - tests/launch_regression_test.py -k llava_ov_datapacker \ + tests/launch_regression_test.py -k llava_ov \ --num-gpus=4 --levels=2 -o addopts= # The h100_inputs fixture removes its DCP stage on teardown; clear the diff --git a/conftest.py b/conftest.py index 6613358..32ca4e8 100644 --- a/conftest.py +++ b/conftest.py @@ -252,6 +252,11 @@ def init_torch_test(): _WHITELIST_ENV_VARS = { "LD_LIBRARY_PATH", + # Set as a side-effect of importing TransformerEngine (via NANO_MODEL_CONFIG / + # SUPER_MODEL_CONFIG). Any SFT experiment config test that imports a model config + # will trigger this; whitelisting avoids a spurious teardown error that is + # unrelated to the test logic. + "NVTE_CUDA_INCLUDE_DIR", "QT_QPA_FONTDIR", "QT_QPA_PLATFORM_PLUGIN_PATH", "TORCHINDUCTOR_CACHE_DIR", diff --git a/cosmos_framework/callbacks/cosmos_dataloader_state.py b/cosmos_framework/callbacks/cosmos_dataloader_state.py new file mode 100644 index 0000000..a5553b0 --- /dev/null +++ b/cosmos_framework/callbacks/cosmos_dataloader_state.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Checkpoint / resume callbacks for ``CosmosDataLoader``. + +Two public classes: + +* ``CosmosDataLoaderStateCallback`` — for a single ``CosmosDataLoader`` whose + distributor is a ``MapDistributor``. Saves per-worker ``(epoch, index)`` to + the DCP checkpoint and, on resume, sets ``COSMOS_DL_STATE_*`` env vars so + that ``MapDistributor.stream`` fast-forwards each worker to the saved + position. + +* ``JointCosmosDataLoaderStateCallback`` — for ``JointCosmosDataLoader``. + Persists the outer ``global_id`` (dataset-selection sequence cursor) plus + inner per-dataset per-worker state via one ``CosmosDataLoaderStateCallback`` + per inner loader. 
+ +Usage (single loader):: + + exp["trainer"]["callbacks"]["dataloader_state"] = CosmosDataLoaderStateCallback() + +Usage (joint loader):: + + joint_loader = JointCosmosDataLoader(dataloaders={...}, seed=42) + exp["dataloader_train"] = joint_loader + exp["trainer"]["callbacks"]["dataloader_state"] = JointCosmosDataLoaderStateCallback( + outer_loader=joint_loader, + ) +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any + +import torch + +from cosmos_framework.model._base import ImaginaireModel +from cosmos_framework.utils import log +from cosmos_framework.utils.callback import Callback + + +@dataclass +class _WorkerState: + epoch: int = 0 + index: int = 0 + + +class CosmosDataLoaderStateCallback(Callback): + """Checkpoint/resume for a single ``CosmosDataLoader(MapDistributor)``. + + Tracks the highest-seen ``(epoch, index)`` per worker from batch metadata + fields ``sample_worker_id``, ``sample_epoch``, ``sample_index`` (injected + by ``MapDistributor``). + + On ``state_dict()`` the per-worker positions are serialised into the DCP + checkpoint (``checkpoint_component = "dataloader"``). + + On ``load_state_dict()`` the positions are written to env vars:: + + COSMOS_DL_STATE_WORKER_{id}_EPOCH + COSMOS_DL_STATE_WORKER_{id}_INDEX + + (or ``COSMOS_DL_STATE_{name}_WORKER_{id}_*`` when ``name`` is set, for + multi-loader namespacing). ``MapDistributor.stream`` pops these on first + iteration and resumes from ``index + 1``. + """ + + checkpoint_component: str = "dataloader" + + def __init__(self, name: str = "", distributor_type: str | None = None) -> None: + # distributor_type is accepted but unused — it exists only so that Hydra + # struct-merging over the legacy DataLoaderStateCallback entry (which + # carries distributor_type="${data_setting.distributor_type}") does not + # raise an unexpected-keyword-argument error at instantiation time. 
+ super().__init__() + self.name = name + self.config: Any = None + self.state: dict[int, _WorkerState] = {} + + @property + def _env_prefix(self) -> str: + return f"COSMOS_DL_STATE_{self.name}_" if self.name else "COSMOS_DL_STATE_" + + def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: + if "sample_worker_id" not in data_batch: + return # IterableDistributor / no position metadata + worker_ids = data_batch["sample_worker_id"].tolist() + epochs = data_batch["sample_epoch"].tolist() + indices = data_batch["sample_index"].tolist() + for worker_id, epoch, index in zip(worker_ids, epochs, indices, strict=True): + cur = self.state.get(worker_id) + if cur is None: + self.state[worker_id] = _WorkerState(epoch=epoch, index=index) + elif epoch > cur.epoch or (epoch == cur.epoch and index > cur.index): + self.state[worker_id] = _WorkerState(epoch=epoch, index=index) + + def on_training_step_batch_end( + self, + model: ImaginaireModel, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self._update_state_from_batch(data_batch) + + def on_training_step_end( + self, + model: ImaginaireModel, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + if self.config and iteration % self.config.trainer.logging_iter == 0: + msg = "\n" + for wid, s in self.state.items(): + msg += f"worker {wid}: epoch={s.epoch}, index={s.index}\n" + log.info(msg) + + def has_checkpoint_state(self) -> bool: + return True + + def state_dict(self) -> dict[int, dict[str, int]]: + result: dict[int, dict[str, int]] = {} + for worker_id, s in self.state.items(): + result[worker_id] = {"epoch": s.epoch, "index": s.index} + log.info(f"Saved CosmosDataLoader state for worker {worker_id}: epoch={s.epoch}, index={s.index}") + return result + + def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: 
+ if not state_dict: + log.info("No CosmosDataLoader state found in checkpoint") + return + + pfx = self._env_prefix + self.state = {} + for worker_id, per_worker in state_dict.items(): + epoch = per_worker["epoch"] + index = per_worker["index"] + self.state[worker_id] = _WorkerState(epoch=epoch, index=index) + os.environ[f"{pfx}WORKER_{worker_id}_EPOCH"] = str(epoch) + os.environ[f"{pfx}WORKER_{worker_id}_INDEX"] = str(index) + log.info(f"Loaded CosmosDataLoader state for worker {worker_id}: epoch={epoch}, index={index}") + + +class JointCosmosDataLoaderStateCallback(Callback): + """Checkpoint/resume for ``JointCosmosDataLoader``. + + Manages two levels of state in a single DCP checkpoint entry: + + 1. **Outer** ``global_id`` — how many batches the outer loader has yielded. + Restored via ``outer_loader.set_start_iteration(global_id)`` so the + deterministic dataset-selection sequence resumes from the right step. + + 2. **Inner** per-dataset, per-worker ``(epoch, index)`` — one + ``CosmosDataLoaderStateCallback`` per inner loader, keyed by name. + + The ``checkpoint_component = "dataloader"`` class attribute ensures the DCP + checkpointer's ``_DataloaderWrapper`` discovers exactly this callback. Do + **not** also register standalone ``CosmosDataLoaderStateCallback`` instances + for the inner loaders — this class already handles them all. 
+ """ + + checkpoint_component: str = "dataloader" + + def __init__(self, outer_loader: Any) -> None: + super().__init__() + self._outer = outer_loader + self._inner: dict[str, CosmosDataLoaderStateCallback] = { + name: CosmosDataLoaderStateCallback(name=name) + for name in outer_loader._names + } + self.config: Any = None + + def _update_state_from_batch(self, batch: dict) -> None: + name = batch.get("dataset_name") + if name in self._inner: + self._inner[name]._update_state_from_batch(batch) + + def on_training_step_batch_end( + self, + model: Any, + data_batch: dict, + output_batch: dict, + loss: Any, + iteration: int = 0, + ) -> None: + self._update_state_from_batch(data_batch) + + def on_training_step_end( + self, + model: Any, + data_batch: dict, + output_batch: dict, + loss: Any, + iteration: int = 0, + ) -> None: + if self.config and iteration % self.config.trainer.logging_iter == 0: + msg = f"\nJointCosmosDataLoader global_id={self._outer._global_id}\n" + for name, cb in self._inner.items(): + for wid, s in cb.state.items(): + msg += f" [{name}] worker {wid}: epoch={s.epoch}, index={s.index}\n" + log.info(msg) + + def has_checkpoint_state(self) -> bool: + return True + + def state_dict(self) -> dict: + return { + "global_id": self._outer._global_id, + **{name: cb.state_dict() for name, cb in self._inner.items()}, + } + + def load_state_dict(self, state: dict) -> None: + global_id = state.get("global_id", 0) + self._outer.set_start_iteration(global_id) + log.info(f"JointCosmosDataLoaderStateCallback: resumed outer global_id={global_id}") + for name, cb in self._inner.items(): + if name in state: + cb.load_state_dict(state[name]) diff --git a/cosmos_framework/callbacks/dataloader_state.py b/cosmos_framework/callbacks/dataloader_state.py index fee45b9..ec20eea 100644 --- a/cosmos_framework/callbacks/dataloader_state.py +++ b/cosmos_framework/callbacks/dataloader_state.py @@ -50,7 +50,7 @@ def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) 
-> None: ): self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - _ACTIVE_DISTRIBUTOR_TYPES = ("no_replace", "data_packer") + _ACTIVE_DISTRIBUTOR_TYPES = ("no_replace",) def on_training_step_batch_end( self, @@ -104,114 +104,10 @@ def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: return self.state = {} - # Build env var prefix. For data_packer, namespacing avoids conflicts - # when multiple DataPackerDataLoader instances share the same process - # (e.g. inside JointDataPackerDataLoader). name="" → original format. - _dp_pfx = f"DP_STATE_{self.name}_" if self.name else "DP_STATE_" for worker_id, per_worker_state in state_dict.items(): epoch = per_worker_state["epoch"] index = per_worker_state["index"] self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - if self.distributor_type == "data_packer": - os.environ[f"{_dp_pfx}WORKER_{worker_id}_EPOCH"] = str(epoch) - os.environ[f"{_dp_pfx}WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded data_packer dataloader state for worker {worker_id}: epoch={epoch}, index={index}") - else: - os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) - os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") - - -class JointDataLoaderStateCallback(Callback): - """Checkpoint/resume state for ``JointDataPackerDataLoader``. - - Manages two levels of state in a single DCP checkpoint entry - (``checkpoint_component = "dataloader"``): - - 1. **Outer** ``global_id`` — the number of batches the outer loader has - yielded. Restored via ``outer_loader.set_start_iteration(global_id)`` - so the deterministic dataset-selection sequence resumes from the correct - step. - - 2. **Inner** per-dataset, per-worker ``(epoch, index)`` — one - ``DataLoaderStateCallback`` per inner loader, keyed by the dataset name. 
- Each inner callback sets namespaced env vars on ``load_state_dict`` so - workers fast-forward to the saved sample position. - - Usage in experiment configs:: - - joint_loader = JointDataPackerDataLoader(dataloaders={...}, seed=42) - exp["dataloader_train"] = joint_loader - exp["trainer"]["callbacks"]["dataloader_state"] = JointDataLoaderStateCallback( - outer_loader=joint_loader, - distributor_type="data_packer", - ) - - The ``checkpoint_component = "dataloader"`` class attribute ensures the DCP - checkpointer's ``_DataloaderWrapper`` discovers exactly this callback (it - picks the first matching callback). Do **not** also register standalone - ``DataLoaderStateCallback`` instances for the inner loaders — this class - already handles them all. - """ - - checkpoint_component: str = "dataloader" - - def __init__( - self, - outer_loader: Any, - distributor_type: str = "data_packer", - ) -> None: - super().__init__() - self._outer = outer_loader - self._inner: dict[str, DataLoaderStateCallback] = { - name: DataLoaderStateCallback(distributor_type=distributor_type, name=name) - for name in outer_loader._names - } - self.config: Any = None - - def _update_state_from_batch(self, batch: dict) -> None: - name = batch.get("dataset_name") - if name in self._inner: - self._inner[name]._update_state_from_batch(batch) - - def on_training_step_batch_end( - self, - model: Any, - data_batch: dict, - output_batch: dict, - loss: Any, - iteration: int = 0, - ) -> None: - self._update_state_from_batch(data_batch) - - def on_training_step_end( - self, - model: Any, - data_batch: dict, - output_batch: dict, - loss: Any, - iteration: int = 0, - ) -> None: - if self.config and iteration % self.config.trainer.logging_iter == 0: - msg = f"\nJointDataPackerDataLoader global_id={self._outer._global_id}\n" - for name, cb in self._inner.items(): - for wid, state in cb.state.items(): - msg += f" [{name}] worker {wid}: epoch={state.epoch}, index={state.index}\n" - log.info(msg) - - def 
has_checkpoint_state(self) -> bool: - return True - - def state_dict(self) -> dict: - return { - "global_id": self._outer._global_id, - **{name: cb.state_dict() for name, cb in self._inner.items()}, - } - - def load_state_dict(self, state: dict) -> None: - global_id = state.get("global_id", 0) - self._outer.set_start_iteration(global_id) - log.info(f"JointDataLoaderStateCallback: resumed outer global_id={global_id}") - for name, cb in self._inner.items(): - if name in state: - cb.load_state_dict(state[name]) + os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) + os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) + log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") diff --git a/cosmos_framework/configs/base/config.py b/cosmos_framework/configs/base/config.py index 391df1c..d1f975d 100644 --- a/cosmos_framework/configs/base/config.py +++ b/cosmos_framework/configs/base/config.py @@ -95,7 +95,8 @@ def make_config() -> Config: register_cluster() register_vlm() - # Register shipped experiments explicitly. + # Register shipped experiments explicitly. (vision_sft_nano also defines + # vision_sft_nano_mapstyle_dataloader — the CosmosDataLoader variant — in the same module.) 
import cosmos_framework.configs.base.experiment.sft.vision_sft_nano # noqa: F401 import cosmos_framework.configs.base.experiment.sft.vision_sft_super # noqa: F401 import cosmos_framework.configs.base.experiment.action.posttrain_config.action_policy_droid_nano # noqa: F401 diff --git a/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py b/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py index 0bf8e71..fc78994 100644 --- a/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py +++ b/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py @@ -39,6 +39,13 @@ PackingDataLoader, RankPartitionedDataLoader, ) +from cosmos_framework.data.vfm.dataflow import ( + CosmosDataLoader, + IdentityProcessor, + RankPartitionedDistributor, + SequentialPackingBatcher, + VFMListCollator, +) from cosmos_framework.data.vfm.local_datasets.sft_dataset import get_sft_dataset from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.lazy_config import LazyDict @@ -279,6 +286,61 @@ ) -for _item in [vision_sft_nano]: +# ``vision_sft_nano_mapstyle_dataloader`` — identical to ``vision_sft_nano`` except the training +# dataloader uses the four-role ``CosmosDataLoader`` stack +# (``RankPartitionedDistributor`` → ``IdentityProcessor`` → +# ``SequentialPackingBatcher`` → ``VFMListCollator``) instead of the legacy +# ``PackingDataLoader`` + ``RankPartitionedDataLoader``. Every other block is reused +# verbatim by deep-copying the base recipe and overriding only ``job.name`` and +# ``dataloader_train``. 
+vision_sft_nano_mapstyle_dataloader = copy.deepcopy(vision_sft_nano) +vision_sft_nano_mapstyle_dataloader.job.name = "vision_sft_nano_mapstyle_dataloader" +vision_sft_nano_mapstyle_dataloader.dataloader_train = L(CosmosDataLoader)( + distributor=L(RankPartitionedDistributor)( + datasets=dict( + video=dict( + ratio=1, + dataset=L(get_sft_dataset)( + append_duration_fps_timestamps=True, + append_resolution_info=True, + caption_suffix="", + cfg_dropout_keep_metadata=False, + cfg_dropout_rate=0.1, + # 70% T2V, 20% I2V (first frame), 10% V2V (first 5 frames / 2 latent frames) + conditioning_config={0: 0.7, 1: 0.2, 2: 0.1}, + conditioning_fps=-1, + conditioning_fps_noise_std=0.0, + frame_selection_mode="first", + jsonl_paths=["${oc.env:DATASET_PATH}/train/video_dataset_file.jsonl"], + min_short_edge=0, + num_video_frames=-1, + resolution="256", + sample_by_window=False, + temporal_compression_factor=4, + temporal_interval_mode="max_30fps", + use_system_prompt=False, + tokenizer_config="${model.config.vlm_config.tokenizer}", + ), + ), + ), + ), + processor=L(IdentityProcessor)(), + batcher=L(SequentialPackingBatcher)( + max_sequence_length=45056, + tokenizer_spatial_compression_factor=16, + tokenizer_temporal_compression_factor=4, + patch_spatial=2, + max_samples_per_batch=None, + sound_latent_fps=0, + audio_sample_rate=48000, + ), + collator=L(VFMListCollator)(), + num_workers=4, + persistent_workers=True, + prefetch_factor=4, +) + + +for _item in [vision_sft_nano, vision_sft_nano_mapstyle_dataloader]: _name = [k for k, v in globals().items() if v is _item][0] cs.store(group="experiment", package="_global_", name=_name, node=_item) diff --git a/cosmos_framework/configs/base/vlm/experiment/dataflow_roles.py b/cosmos_framework/configs/base/vlm/experiment/dataflow_roles.py new file mode 100644 index 0000000..e14d441 --- /dev/null +++ b/cosmos_framework/configs/base/vlm/experiment/dataflow_roles.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""VLM dataflow roles (RawItemProcessor + BatchCollator) extracted 1:1 from +VLMDataPacker (llava_ov_vlm.py). Behavior-preserving.""" + +from __future__ import annotations + +from typing import Any + +import torch + +from cosmos_framework.data.vfm.dataflow.base import BatchCollator, RawItemProcessor +from cosmos_framework.utils.vlm.constant import IGNORE_INDEX, PROCESSOR_KEYS_TO_ADD + + +class VLMProcessor(RawItemProcessor): + """ShareGPT image+conversation record -> VLM training tensors.""" + + def __init__(self, processor: Any, ignore_index: int = IGNORE_INDEX) -> None: + self._processor = processor + self._ignore_index = ignore_index + + @staticmethod + def _decode_image(image: Any) -> Any: + """Decode a HuggingFace streaming image to PIL. + + In streaming mode HuggingFace delivers images as + ``{"bytes": bytes, "path": str}`` dicts rather than decoded PIL Images. + """ + if isinstance(image, dict): + import io + + from PIL import Image + + raw = image.get("bytes") + if raw: + return Image.open(io.BytesIO(raw)).convert("RGB") + path = image.get("path") + if path: + return Image.open(path).convert("RGB") + return None + return image + + def _sharegpt_to_openai(self, item: dict) -> list[dict]: + """Convert ShareGPT conversation to OpenAI message format. + + LLaVA-OneVision-Data records use ``from``/``value`` pairs where the + human turn may contain a ```` placeholder. We strip the + placeholder and attach the PIL image as a separate content block. 
+ """ + conversations = item.get("conversations", []) + image = self._decode_image(item.get("image")) # PIL.Image or None + messages: list[dict] = [] + image_inserted = False + + for turn in conversations: + role = "user" if turn["from"] == "human" else "assistant" + text = turn["value"].replace("", "").strip() + + if role == "user" and not image_inserted and image is not None: + content: Any = [ + {"type": "image", "image": image}, + {"type": "text", "text": text}, + ] + image_inserted = True + else: + content = text + + messages.append({"role": role, "content": content}) + + return messages + + def process(self, item: dict) -> dict: + messages = self._sharegpt_to_openai(item) + inputs = self._processor.apply_chat_template( + messages, tokenize=True, add_generation_prompt=False + ) + input_ids = inputs["input_ids"] + token_mask = self._processor.add_assistant_tokens_mask(input_ids) + labels = input_ids.clone() + labels[~token_mask] = self._ignore_index + result: dict = {"input_ids": input_ids, "labels": labels} + for key in PROCESSOR_KEYS_TO_ADD: + if key in inputs and inputs[key] is not None: + result[key] = inputs[key] + return result + + +class VLMCollator(BatchCollator): + """max_batch_size=1 collation: batch-dim sequence tensors, keep vision tensors + flat, stamp resume meta (zeros — streaming source has no position).""" + + def collate(self, samples: list[dict]) -> dict: + assert len(samples) == 1, f"VLMCollator expects max_batch_size=1, got {len(samples)}" + s = samples[0] + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + batch: dict = { + "input_ids": s["input_ids"].unsqueeze(0), + "labels": s["labels"].unsqueeze(0), + "sample_worker_id": torch.tensor([worker_id]), + "sample_epoch": torch.tensor([0]), + "sample_index": torch.tensor([0]), + } + if "attention_mask" in s and s["attention_mask"] is not None: + batch["attention_mask"] = s["attention_mask"].unsqueeze(0) + for key in ( + 
"pixel_values", "pixel_values_videos", "image_grid_thw", + "video_grid_thw", "second_per_grid_ts", + ): + if key in s and s[key] is not None: + batch[key] = s[key] + return batch diff --git a/cosmos_framework/configs/base/vlm/experiment/llava_ov_datapacker_experiment.py b/cosmos_framework/configs/base/vlm/experiment/llava_ov_datapacker_experiment.py deleted file mode 100644 index 723e1b4..0000000 --- a/cosmos_framework/configs/base/vlm/experiment/llava_ov_datapacker_experiment.py +++ /dev/null @@ -1,376 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""VLM training on lmms-lab/LLaVA-OneVision-Data via DataPackerDataLoader. - -Self-contained — inlines the Phase 2 VLMModel/FSDP2 base (formerly the -``pre_exp012_000_phase2_vlm_smoke_4gpu_8b`` smoke recipe in -``pre_exp012_phase2_vlm_smoke.py``) and replaces the dataloader with the -OSS-facing DataPackerDataLoader + VLMDataPacker pattern. Hydra defaults -below pin the VLM model (``vlm_fsdp`` / ``qwen3_vl_8b_instruct``), the -checkpoint backend, and callbacks. - -The dataset is loaded in streaming mode from the HuggingFace Hub so no local -download is required. Each record is converted from ShareGPT conversation -format to the OpenAI message format expected by Qwen3-VL's processor, then -tokenized in the DataLoader worker via ``processor.apply_chat_template``. - -Resume semantics ----------------- -The streaming HF dataset is a ``datasets.IterableDataset``, which -``DataPackerDataLoader`` flags with ``_has_dp_meta=False`` (see -``data_packer_dataloader.py:317-321`` — "Stateful resume is not supported for -IterableDataset sources"). On checkpoint save the dataloader shard stores -placeholder ``(epoch=0, index=0)`` per worker — VLMDataPacker.sft_collate_fn -stamps these zeros explicitly because the stream has no meaningful position -to record. 
When resuming with ``checkpoint.load_training_state=true``: - - - model / optim / scheduler / trainer state restore correctly (iter - counter, optimizer momentum, LR schedule position all continue). - - dataloader stream position does NOT restore; the streamed dataset - re-yields from the beginning, so the first N resumed iters see the - same samples as the first N iters of the original run. - -For a true position-stateful resume, swap the data_source to a map-style -dataset (``load_dataset(..., streaming=False)``). - -Usage (smoke test):: - - torchrun --nproc_per_node=4 --master_port=12344 -m cosmos_framework.scripts.train \\ - --config=cosmos_framework/configs/base/vlm/config.py -- \\ - experiment=pre_exp012_llava_ov_datapacker \\ - "model.config.policy.backbone.model_name=/path/to/Siglip2-Qwen3-1.7B-BF16-Alignment" \\ - trainer.max_iter=10 trainer.logging_iter=1 \\ - job.wandb_mode=disabled ckpt_type=dummy - -See ``launch_vlm_llava_ov.sh`` for a ready-to-run shell script. -""" - -from __future__ import annotations - -from typing import Any - -from hydra.core.config_store import ConfigStore - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.lazy_config import LazyDict, instantiate -from cosmos_framework.data.vfm.data_packer import DataPacker -from cosmos_framework.data.vfm.data_packer_dataloader import DataPackerDataLoader -from cosmos_framework.data.vfm.processors import build_processor -from cosmos_framework.utils.vlm.constant import IGNORE_INDEX, PROCESSOR_KEYS_TO_ADD - -cs = ConfigStore.instance() - - -# --------------------------------------------------------------------------- -# LLaVA-OneVision-Data source factory -# -# Loads lmms-lab/LLaVA-OneVision-Data in streaming mode so no local download -# is needed. streaming=True returns an IterableDataset which DataPackerDataLoader -# wraps directly. 
-# --------------------------------------------------------------------------- - - -def build_vlm_datapacker_dataloader(**kwargs) -> "DataPackerDataLoader": - """Thin wrapper around DataPackerDataLoader that drops schema keys injected by - OmegaConf when the parent experiment's VLMRecipeDataLoader schema merges with - our DataPackerDataLoader config (e.g. ``storage_type``). - """ - for _spurious in ("storage_type",): - kwargs.pop(_spurious, None) - return DataPackerDataLoader(**kwargs) - - -def get_llava_ov_streaming( - subset: str = "si", - split: str = "train", -) -> Any: - """Load lmms-lab/LLaVA-OneVision-Data as a streaming HuggingFace IterableDataset. - - Args: - subset: Dataset config/subset name. ``"si"`` (single-image, ~1M samples) - is the standard choice; pass any valid config name from the Hub. - split: Dataset split (default ``"train"``). - - Returns: - A streaming ``datasets.IterableDataset`` whose items have keys: - ``id``, ``image`` (PIL.Image), ``conversations`` (ShareGPT format). - """ - try: - from datasets import load_dataset - except ImportError as exc: - raise ImportError("pip install datasets to use lmms-lab/LLaVA-OneVision-Data") from exc - - ds = load_dataset( - "lmms-lab/LLaVA-OneVision-Data", - name=subset, - split=split, - streaming=True, - ) - # Pre-filter to remove records without an image or conversations so - # sft_process_sample never receives unparseable samples (DataPacker's - # packing engine does not tolerate None returns from sft_process_sample). - return ds.filter(lambda x: x.get("image") is not None and len(x.get("conversations") or []) >= 2) - - -# --------------------------------------------------------------------------- -# VLMDataPacker -# -# Bridges lmms-lab/LLaVA-OneVision-Data (ShareGPT format) into the -# VLMModel training loop. -# -# Three-step pipeline per sample: -# 1. Convert ShareGPT (from/value) → OpenAI messages (role/content). -# 2. Apply processor.apply_chat_template → input_ids, pixel_values, etc. -# 3. 
Build labels by masking non-assistant tokens with IGNORE_INDEX. -# --------------------------------------------------------------------------- - - -class VLMDataPacker(DataPacker): - """DataPacker adapter for lmms-lab/LLaVA-OneVision-Data + Qwen3-VL processor. - - Converts ShareGPT-format image+conversation samples into the - ``input_ids / labels / pixel_values / image_grid_thw`` batch dict that - ``VLMModel.training_step`` expects. - - Designed for ``max_batch_size=1`` — each packed batch is a single sample. - The ``sft_collate_fn`` adds a leading batch dimension to 1-D tensors - (``input_ids``, ``labels``, ``attention_mask``) while leaving - ``pixel_values`` and ``image_grid_thw`` in their native flat shapes, - matching what Qwen3-VL's forward pass expects. - """ - - def __init__( - self, - tokenizer_config: Any, - max_seq_len: int = 16000, - ignore_index: int = IGNORE_INDEX, - ) -> None: - self._max_seq_len = max_seq_len - self._ignore_index = ignore_index - # Instantiate if tokenizer_config is a Hydra LazyCall; use directly if already built. - self._processor = ( - tokenizer_config if hasattr(tokenizer_config, "apply_chat_template") else instantiate(tokenizer_config) - ) - - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - @staticmethod - def _decode_image(image: Any) -> Any: - """Decode a HuggingFace streaming image to PIL. - - In streaming mode HuggingFace delivers images as - ``{"bytes": bytes, "path": str}`` dicts rather than decoded PIL Images. - """ - if isinstance(image, dict): - import io - - from PIL import Image - - raw = image.get("bytes") - if raw: - return Image.open(io.BytesIO(raw)).convert("RGB") - path = image.get("path") - if path: - return Image.open(path).convert("RGB") - return None - return image - - def _sharegpt_to_openai(self, item: dict) -> list[dict]: - """Convert ShareGPT conversation to OpenAI message format. 
- - LLaVA-OneVision-Data records use ``from``/``value`` pairs where the - human turn may contain a ```` placeholder. We strip the - placeholder and attach the PIL image as a separate content block. - """ - conversations = item.get("conversations", []) - image = self._decode_image(item.get("image")) # PIL.Image or None - messages: list[dict] = [] - image_inserted = False - - for turn in conversations: - role = "user" if turn["from"] == "human" else "assistant" - text = turn["value"].replace("", "").strip() - - if role == "user" and not image_inserted and image is not None: - content: Any = [ - {"type": "image", "image": image}, - {"type": "text", "text": text}, - ] - image_inserted = True - else: - content = text - - messages.append({"role": role, "content": content}) - - return messages - - # ------------------------------------------------------------------ - # DataPacker protocol - # ------------------------------------------------------------------ - - def sft_process_sample(self, item: dict) -> dict: - """Convert one LLaVA-OV record to VLM training tensors.""" - messages = self._sharegpt_to_openai(item) - inputs = self._processor.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=False, - ) - input_ids = inputs["input_ids"] # [N] - - token_mask = self._processor.add_assistant_tokens_mask(input_ids) # [N] bool - labels = input_ids.clone() # [N] - labels[~token_mask] = self._ignore_index - - result: dict = { - "input_ids": input_ids, - "labels": labels, - } - for key in PROCESSOR_KEYS_TO_ADD: - if key in inputs and inputs[key] is not None: - result[key] = inputs[key] - - return result - - def compute_num_tokens(self, sample: dict) -> int: - """Token count = sequence length (input_ids).""" - return int(sample["input_ids"].shape[0]) # [N] → scalar - - def sft_collate_fn( - self, - samples: list[dict], - max_len: int, - ignore_label_id: int = IGNORE_INDEX, - ) -> dict: - """Assemble one VLM training batch. 
- - Designed for ``max_batch_size=1``. 1-D sequence tensors get an - unsqueezed batch dimension; ``pixel_values`` / ``image_grid_thw`` - stay in the flat format Qwen3-VL expects. - """ - assert len(samples) == 1, f"VLMDataPacker expects max_batch_size=1, got {len(samples)}" - s = samples[0] - - import torch - - worker_info = torch.utils.data.get_worker_info() - worker_id = worker_info.id if worker_info is not None else 0 - - batch: dict = { - "input_ids": s["input_ids"].unsqueeze(0), # [1,N] - "labels": s["labels"].unsqueeze(0), # [1,N] - "sample_worker_id": torch.tensor([worker_id]), # [1] - "sample_epoch": torch.tensor([0]), # [1] streaming has no epoch concept - "sample_index": torch.tensor([0]), # [1] streaming has no global index - } - - if "attention_mask" in s and s["attention_mask"] is not None: - batch["attention_mask"] = s["attention_mask"].unsqueeze(0) # [1,N] - - # Vision tensors: pixel_values [P,C] and image_grid_thw [1,3] stay flat. - for key in ("pixel_values", "pixel_values_videos", "image_grid_thw", "video_grid_thw", "second_per_grid_ts"): - if key in s and s[key] is not None: - batch[key] = s[key] - - return batch - - -# --------------------------------------------------------------------------- -# Experiment registration -# --------------------------------------------------------------------------- - - -pre_exp012_llava_ov_datapacker = LazyDict( - dict( - # Hydra defaults — inlined from the former pre_exp012_000_phase2_vlm_smoke_4gpu_8b - # smoke recipe. data_train/data_val intentionally omitted because the - # dataloader_train below is a self-contained DataPackerDataLoader; pulling in - # the smoke's s3 webdataset defaults would let storage_type schema bleed into - # our DataPackerDataLoader config. 
- defaults=[ - {"override /checkpoint": "s3"}, - {"override /model": "vlm_fsdp"}, - {"override /vlm_policy": "qwen3_vl_8b_instruct"}, - {"override /callbacks": ["basic_vlm", "basic_log"]}, - "_self_", - ], - job=dict( - name="pre_exp012_llava_ov_datapacker_${now:%Y-%m-%d}_${now:%H-%M-%S}", - group="vlm_llava_ov_demo", - wandb_mode="disabled", - ), - trainer=dict( - max_iter=10, - logging_iter=1, - run_validation=False, - ), - optimizer=dict( - lr=1e-5, - fused=True, - ), - model=dict( - config=dict( - # Phase 2 requires a trainable_params regex; ".*" = full fine-tune. - freeze=dict( - trainable_params=[".*"], - ), - parallelism=dict( - data_parallel_shard_degree=4, - data_parallel_replicate_degree=-1, - ), - ), - ), - # Local-only mode: disable the parent's object-store IO and clear the - # S3 credentials/bucket so maybe_download_hf_model_from_s3 falls back - # to HuggingFace Hub (avoids opening credentials/s3_training.secret in - # OSS smoke runs). Pattern mirrors vision_sft_nano.py. - checkpoint=dict( - # Don't save checkpoints during smoke runs. - save_iter=100000, - load_from_object_store=dict(enabled=False, credentials="", bucket=""), - save_to_object_store=dict(enabled=False, credentials="", bucket=""), - ), - # Replace the S3 WebDataset-based dataloader with DataPackerDataLoader - # pointing at lmms-lab/LLaVA-OneVision-Data streamed from HuggingFace Hub. - dataloader_train=L(build_vlm_datapacker_dataloader)( - data_source=L(get_llava_ov_streaming)( - subset="ai2d(gpt4v)", - split="train", - ), - data_packer=L(VLMDataPacker)( - tokenizer_config=L(build_processor)( - tokenizer_type="${model.config.policy.backbone.model_name}", - # OSS smoke mode: route the processor download through the - # HF Hub fallback rather than the S3 default (which would - # try to open credentials/s3_training.secret). 
- config_variant="hf", - ), - max_seq_len="${dataloader_train.max_tokens}", - ignore_index=IGNORE_INDEX, - ), - max_tokens=16000, - max_batch_size=1, - pool_size=16, - num_workers=2, - prefetch_factor=2, - persistent_workers=True, - pin_memory=True, - ), - dataloader_val=None, - # Suppress S3 uploads in callbacks (iter_speed.save_s3, param_count.save_s3, - # wandb_*.save_s3 all interpolate from ${upload_reproducible_setup}). Mirrors - # the VFM SFT experiments under cosmos/configs/base/experiment/sft/. - upload_reproducible_setup=False, - ), - flags={"allow_objects": True}, -) - -cs.store( - group="experiment", - package="_global_", - name="pre_exp012_llava_ov_datapacker", - node=pre_exp012_llava_ov_datapacker, -) diff --git a/cosmos_framework/configs/base/vlm/experiment/llava_ov_vlm.py b/cosmos_framework/configs/base/vlm/experiment/llava_ov_vlm.py new file mode 100644 index 0000000..2d80828 --- /dev/null +++ b/cosmos_framework/configs/base/vlm/experiment/llava_ov_vlm.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""VLM training on lmms-lab/LLaVA-OneVision-Data via CosmosDataLoader. + +Self-contained — inlines the Phase 2 VLMModel/FSDP2 base (formerly the +``pre_exp012_000_phase2_vlm_smoke_4gpu_8b`` smoke recipe in +``pre_exp012_phase2_vlm_smoke.py``) and replaces the dataloader with the +OSS-facing CosmosDataLoader + four-role dataflow pattern. Hydra defaults +below pin the VLM model (``vlm_fsdp`` / ``qwen3_vl_8b_instruct``), the +checkpoint backend, and callbacks. + +The dataset is loaded in streaming mode from the HuggingFace Hub so no local +download is required. Each record is converted from ShareGPT conversation +format to the OpenAI message format expected by Qwen3-VL's processor, then +tokenized in the DataLoader worker via ``processor.apply_chat_template``. 
+ +Resume semantics +---------------- +The streaming HF dataset is a ``datasets.IterableDataset``, which +CosmosDataLoader's IterableDistributor flags with no stateful resume +(streaming has no meaningful position to record). On checkpoint save the +dataloader shard stores placeholder ``(epoch=0, index=0)`` per worker. +When resuming with ``checkpoint.load_training_state=true``: + + - model / optim / scheduler / trainer state restore correctly (iter + counter, optimizer momentum, LR schedule position all continue). + - dataloader stream position does NOT restore; the streamed dataset + re-yields from the beginning, so the first N resumed iters see the + same samples as the first N iters of the original run. + +For a true position-stateful resume, swap the data_source to a map-style +dataset (``load_dataset(..., streaming=False)``). + +Usage (smoke test):: + + torchrun --nproc_per_node=4 --master_port=12344 -m cosmos_framework.scripts.train \\ + --config=cosmos_framework/configs/base/vlm/config.py -- \\ + experiment=pre_exp012_llava_ov \\ + "model.config.policy.backbone.model_name=/path/to/Siglip2-Qwen3-1.7B-BF16-Alignment" \\ + trainer.max_iter=10 trainer.logging_iter=1 \\ + job.wandb_mode=disabled ckpt_type=dummy + +See ``launch_vlm_llava_ov.sh`` for a ready-to-run shell script. 
+""" + +from __future__ import annotations + +import copy +from typing import Any + +from hydra.core.config_store import ConfigStore + +from cosmos_framework.utils.lazy_config import LazyCall as L +from cosmos_framework.utils.lazy_config import LazyDict +from cosmos_framework.data.vfm.dataflow import ( + CosmosDataLoader, + IterableDistributor, + MapDistributor, + PoolPackingBatcher, +) +from cosmos_framework.data.vfm.processors import build_processor +from cosmos_framework.utils.vlm.constant import IGNORE_INDEX +from cosmos_framework.configs.base.vlm.experiment.dataflow_roles import VLMProcessor, VLMCollator +from cosmos_framework.callbacks.cosmos_dataloader_state import CosmosDataLoaderStateCallback + +cs = ConfigStore.instance() + + +# --------------------------------------------------------------------------- +# LLaVA-OneVision-Data source factory +# +# Loads lmms-lab/LLaVA-OneVision-Data in streaming mode so no local download +# is needed. streaming=True returns an IterableDataset which CosmosDataLoader +# wraps directly via IterableDistributor. +# --------------------------------------------------------------------------- + + +def get_llava_ov_streaming( + subset: str = "si", + split: str = "train", +) -> Any: + """Load lmms-lab/LLaVA-OneVision-Data as a streaming HuggingFace IterableDataset. + + Args: + subset: Dataset config/subset name. ``"si"`` (single-image, ~1M samples) + is the standard choice; pass any valid config name from the Hub. + split: Dataset split (default ``"train"``). + + Returns: + A streaming ``datasets.IterableDataset`` whose items have keys: + ``id``, ``image`` (PIL.Image), ``conversations`` (ShareGPT format). 
+ """ + try: + from datasets import load_dataset + except ImportError as exc: + raise ImportError("pip install datasets to use lmms-lab/LLaVA-OneVision-Data") from exc + + ds = load_dataset( + "lmms-lab/LLaVA-OneVision-Data", + name=subset, + split=split, + streaming=True, + ) + # Pre-filter to remove records without an image or conversations so + # sft_process_sample never receives unparseable samples (the packing + # engine does not tolerate None returns from the processor). + return ds.filter(lambda x: x.get("image") is not None and len(x.get("conversations") or []) >= 2) + + +# --------------------------------------------------------------------------- +# Map-style data source factory (for the resumable variant below) +# +# Loads a subset as a real on-disk ``datasets.Dataset`` (streaming=False — +# random-access), filters it, and caps it to ``n`` rows so ``MapDistributor`` +# can checkpoint exact ``(epoch, index)`` positions per worker. +# --------------------------------------------------------------------------- + + +def get_llava_ov_map( + subset: str = "ai2d(gpt4v)", + split: str = "train", + n: int = 4000, +) -> Any: + """Load a filtered LLaVA-OV subset as a real map-style ``datasets.Dataset``. + + Uses ``load_dataset(..., streaming=False)`` so the result is a genuine + random-access (map-style) Dataset — exactly the case ``MapDistributor`` is + built to shard + resume. The subset is filtered to valid image/conversation + rows and capped to ``n`` rows (via ``.select``) so a ``save_iter=100`` run + saves/resumes well inside one epoch (mid-epoch resume, no epoch-wrap). + + Args: + subset: Dataset config/subset name (e.g. ``"ai2d(gpt4v)"``). + split: Dataset split (default ``"train"``). + n: Max number of rows to keep after filtering. + + Returns: + A ``datasets.Dataset`` (map-style) with columns from LLaVA-OV. 
+ """ + from datasets import load_dataset + + ds = load_dataset("lmms-lab/LLaVA-OneVision-Data", name=subset, split=split, streaming=False) + ds = ds.filter(lambda x: x.get("image") is not None and len(x.get("conversations") or []) >= 2) + if n is not None and n < len(ds): + ds = ds.select(range(n)) + return ds + + +# --------------------------------------------------------------------------- +# Experiment registration +# --------------------------------------------------------------------------- + + +pre_exp012_llava_ov = LazyDict( + dict( + # Hydra defaults — inlined from the former pre_exp012_000_phase2_vlm_smoke_4gpu_8b + # smoke recipe. data_train/data_val intentionally omitted because the + # dataloader_train below is a self-contained CosmosDataLoader; pulling in + # the smoke's s3 webdataset defaults would let storage_type schema bleed into + # our CosmosDataLoader config. + defaults=[ + {"override /checkpoint": "s3"}, + {"override /model": "vlm_fsdp"}, + {"override /vlm_policy": "qwen3_vl_8b_instruct"}, + {"override /callbacks": ["basic_vlm", "basic_log"]}, + "_self_", + ], + job=dict( + name="pre_exp012_llava_ov_${now:%Y-%m-%d}_${now:%H-%M-%S}", + group="vlm_llava_ov_demo", + wandb_mode="disabled", + ), + trainer=dict( + max_iter=10, + logging_iter=1, + run_validation=False, + ), + optimizer=dict( + lr=1e-5, + fused=True, + ), + model=dict( + config=dict( + # Phase 2 requires a trainable_params regex; ".*" = full fine-tune. + freeze=dict( + trainable_params=[".*"], + ), + parallelism=dict( + data_parallel_shard_degree=4, + data_parallel_replicate_degree=-1, + ), + ), + ), + # Local-only mode: disable the parent's object-store IO and clear the + # S3 credentials/bucket so maybe_download_hf_model_from_s3 falls back + # to HuggingFace Hub (avoids opening credentials/s3_training.secret in + # OSS smoke runs). Pattern mirrors vision_sft_nano.py. + checkpoint=dict( + # Don't save checkpoints during smoke runs. 
+ save_iter=100000, + load_from_object_store=dict(enabled=False, credentials="", bucket=""), + save_to_object_store=dict(enabled=False, credentials="", bucket=""), + ), + # Replace the S3 WebDataset-based dataloader with CosmosDataLoader + # pointing at lmms-lab/LLaVA-OneVision-Data streamed from HuggingFace Hub, + # wired through the four-role dataflow (IterableDistributor, VLMProcessor, + # PoolPackingBatcher, VLMCollator). + dataloader_train=L(CosmosDataLoader)( + distributor=L(IterableDistributor)( + iterable=L(get_llava_ov_streaming)(subset="ai2d(gpt4v)", split="train"), + ), + processor=L(VLMProcessor)( + processor=L(build_processor)( + tokenizer_type="${model.config.policy.backbone.model_name}", + # OSS smoke mode: route the processor download through the + # HF Hub fallback rather than the S3 default (which would + # try to open credentials/s3_training.secret). + config_variant="hf", + ), + ignore_index=IGNORE_INDEX, + ), + batcher=L(PoolPackingBatcher)( + max_tokens=16000, pool_size=16, max_batch_size=1, long_threshold=6400, + ), + collator=L(VLMCollator)(), + num_workers=2, + ), + dataloader_val=None, + # Suppress S3 uploads in callbacks (iter_speed.save_s3, param_count.save_s3, + # wandb_*.save_s3 all interpolate from ${upload_reproducible_setup}). Mirrors + # the VFM SFT experiments under cosmos/configs/base/experiment/sft/. + upload_reproducible_setup=False, + ), + flags={"allow_objects": True}, +) + +cs.store( + group="experiment", + package="_global_", + name="pre_exp012_llava_ov", + node=pre_exp012_llava_ov, +) + + +# --------------------------------------------------------------------------- +# pre_exp012_llava_ov_mapstyle_dataloader — map-style, resumable variant. +# +# Identical to pre_exp012_llava_ov except it swaps the streaming +# IterableDistributor for a MapDistributor over a real on-disk Dataset +# (get_llava_ov_map, streaming=False), which gives exact per-worker (epoch, +# index) checkpoint/resume. 
It therefore also: wires the dataloader_state +# CosmosDataLoaderStateCallback (sets COSMOS_DL_STATE_* env vars on resume so +# MapDistributor fast-forwards), enables checkpoint saving (save_iter=100), and +# uses num_workers=0 to keep worker bookkeeping simple. Every other block is +# reused verbatim from pre_exp012_llava_ov. +# --------------------------------------------------------------------------- +pre_exp012_llava_ov_mapstyle_dataloader = copy.deepcopy(pre_exp012_llava_ov) +pre_exp012_llava_ov_mapstyle_dataloader.job.name = ( + "pre_exp012_llava_ov_mapstyle_dataloader_${now:%Y-%m-%d}_${now:%H-%M-%S}" +) +pre_exp012_llava_ov_mapstyle_dataloader.trainer.callbacks = dict( + dataloader_state=L(CosmosDataLoaderStateCallback)(), +) +pre_exp012_llava_ov_mapstyle_dataloader.checkpoint.save_iter = 100 +pre_exp012_llava_ov_mapstyle_dataloader.dataloader_train = L(CosmosDataLoader)( + distributor=L(MapDistributor)( + dataset=L(get_llava_ov_map)(subset="ai2d(gpt4v)", split="train", n=4000), + shuffle=True, + seed=42, + name="", + ), + processor=L(VLMProcessor)( + processor=L(build_processor)( + tokenizer_type="${model.config.policy.backbone.model_name}", + config_variant="hf", + ), + ignore_index=IGNORE_INDEX, + ), + batcher=L(PoolPackingBatcher)( + max_tokens=16000, pool_size=16, max_batch_size=1, long_threshold=6400, + ), + collator=L(VLMCollator)(), + num_workers=0, +) + +cs.store( + group="experiment", + package="_global_", + name="pre_exp012_llava_ov_mapstyle_dataloader", + node=pre_exp012_llava_ov_mapstyle_dataloader, +) diff --git a/cosmos_framework/configs/base/vlm/experiment/videophy2_dataflow_roles.py b/cosmos_framework/configs/base/vlm/experiment/videophy2_dataflow_roles.py new file mode 100644 index 0000000..1341c2a --- /dev/null +++ b/cosmos_framework/configs/base/vlm/experiment/videophy2_dataflow_roles.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +"""videophy2 RawItemProcessor extracted 1:1 from VideoPhy2DataPacker.""" + +from __future__ import annotations + +import io +from typing import Any + +from cosmos_framework.data.vfm.dataflow.base import RawItemProcessor +from cosmos_framework.utils.vlm.constant import IGNORE_INDEX, PROCESSOR_KEYS_TO_ADD + +_MAX_VIDEO_FRAMES = 32 +_TARGET_VIDEO_FPS = 2.0 + + +def _decode_video_to_pil_frames(video_bytes: bytes) -> tuple[list, float]: + from torchcodec.decoders import VideoDecoder + from PIL import Image + import numpy as np + + decoder = VideoDecoder(video_bytes) + total_frames = decoder.metadata.num_frames or 0 + source_fps = float(decoder.metadata.average_fps or 0.0) or 30.0 + + if total_frames <= 0: + raise ValueError("video has zero frames") + + stride = max(1, int(round(source_fps / _TARGET_VIDEO_FPS))) + indices = list(range(0, total_frames, stride)) + if len(indices) > _MAX_VIDEO_FRAMES: + step = len(indices) / _MAX_VIDEO_FRAMES + indices = [indices[int(i * step)] for i in range(_MAX_VIDEO_FRAMES)] + + frames_tensor = decoder.get_frames_at(indices=indices).data + frames_np = frames_tensor.permute(0, 2, 3, 1).contiguous().cpu().numpy().astype(np.uint8) + frames = [Image.fromarray(f) for f in frames_np] + + effective_fps = source_fps / stride if stride > 0 else source_fps + return frames, float(effective_fps) + + +class VideoPhy2Processor(RawItemProcessor): + """LocalSFT {"texts","media"} record -> VLM training tensors.""" + + def __init__(self, processor: Any, ignore_index: int = IGNORE_INDEX) -> None: + self._processor = processor + self._ignore_index = ignore_index + + def _materialize_media_in_conversation( + self, + conversation: list, + media_bytes_by_key: dict, + ) -> list: + # Resolve "video": "" / "image": "" references against + # data_dict["media"] (bytes); decode each unique key once. 
+ decoded_cache: dict[str, tuple[list, float]] = {} + new_messages: list[dict] = [] + for message in conversation: + if not isinstance(message, dict): + continue + content = message.get("content") + if isinstance(content, str): + new_messages.append({"role": message.get("role", "user"), "content": content}) + continue + if not isinstance(content, list): + continue + new_content: list[dict] = [] + for item in content: + if not isinstance(item, dict): + continue + kind = item.get("type") + if kind == "video": + key = item.get("video") + if not isinstance(key, str): + new_content.append(item) + continue + if key not in media_bytes_by_key: + raise KeyError( + f"conversation references video key {key!r} not present in " + f"sample['media'] (keys: {list(media_bytes_by_key)})" + ) + if key not in decoded_cache: + decoded_cache[key] = _decode_video_to_pil_frames(media_bytes_by_key[key]) + frames, fps = decoded_cache[key] + new_content.append({"type": "video", "video": frames, "fps": fps}) + elif kind == "image": + key = item.get("image") + if not isinstance(key, str): + new_content.append(item) + continue + if key not in media_bytes_by_key: + raise KeyError( + f"conversation references image key {key!r} not present in " + f"sample['media'] (keys: {list(media_bytes_by_key)})" + ) + from PIL import Image + + img = Image.open(io.BytesIO(media_bytes_by_key[key])).convert("RGB") + new_content.append({"type": "image", "image": img}) + else: + new_content.append(item) + new_messages.append({"role": message.get("role", "user"), "content": new_content}) + return new_messages + + def process(self, item: dict) -> dict: + conversation = item.get("texts") + if not isinstance(conversation, list): + raise TypeError( + f"LocalSFTDataset sample expected 'texts' to be a list, got {type(conversation).__name__}" + ) + media_bytes_by_key = item.get("media") or {} + messages = self._materialize_media_in_conversation(conversation, media_bytes_by_key) + inputs = 
self._processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=False) + input_ids = inputs["input_ids"] + token_mask = self._processor.add_assistant_tokens_mask(input_ids) + labels = input_ids.clone() + labels[~token_mask] = self._ignore_index + result: dict = {"input_ids": input_ids, "labels": labels} + for key in PROCESSOR_KEYS_TO_ADD: + if key in inputs and inputs[key] is not None: + result[key] = inputs[key] + return result diff --git a/cosmos_framework/configs/base/vlm/experiment/videophy2_sft_nano.py b/cosmos_framework/configs/base/vlm/experiment/videophy2_sft_nano.py index 077278a..358b2c6 100644 --- a/cosmos_framework/configs/base/vlm/experiment/videophy2_sft_nano.py +++ b/cosmos_framework/configs/base/vlm/experiment/videophy2_sft_nano.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -"""VideoPhy-2 SFT recipe: LocalSFTDataset + DataPackerDataLoader on Qwen3-VL. +"""VideoPhy-2 SFT recipe: LocalSFTDataset + CosmosDataLoader on Qwen3-VL. Launch via examples/launch_sft_videophy2_nano.sh after running prepare_videophy2_from_hf to populate $VIDEOPHYSICS_ROOT. 
@@ -18,14 +18,15 @@ from hydra.core.config_store import ConfigStore from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.lazy_config import LazyDict, instantiate -from cosmos_framework.data.vfm.data_packer import DataPacker -from cosmos_framework.data.vfm.data_packer_dataloader import DataPackerDataLoader +from cosmos_framework.utils.lazy_config import LazyDict +from cosmos_framework.data.vfm.dataflow import CosmosDataLoader, IterableDistributor, PoolPackingBatcher from cosmos_framework.data.vfm.processors import build_processor from cosmos_framework.data.vlm.local_sft_dataset import LocalSFTDataset from cosmos_framework.data.vlm.data_sources_videophy2.videophy2 import DATAINFO from cosmos_framework.utils import log from cosmos_framework.utils.vlm.constant import IGNORE_INDEX, PROCESSOR_KEYS_TO_ADD +from cosmos_framework.configs.base.vlm.experiment.dataflow_roles import VLMCollator +from cosmos_framework.configs.base.vlm.experiment.videophy2_dataflow_roles import VideoPhy2Processor cs = ConfigStore.instance() @@ -33,7 +34,7 @@ class _UnshardedLocalSFTDataset(LocalSFTDataset): """Yield the full shuffled manifest per iteration. - Why: ``DataPackerDataLoader._IterableWrapper`` already shards by + Why: ``CosmosDataLoader``'s IterableDistributor already shards by ``dp_rank * num_workers + worker_id``; stock ``LocalSFTDataset`` shards again inside ``__iter__``, double-sharding to ``1 / (world*workers)^2``. """ @@ -53,7 +54,7 @@ def build_videophy2_local_dataset( dataset_key: str, split: str, ) -> _UnshardedLocalSFTDataset: - # augmentor_config=None: the DataPacker decodes+tokenizes inline; the + # augmentor_config=None: the Processor decodes+tokenizes inline; the # BytesToMedia/TokenizeData augmentors aren't shipped in OSS. 
source = DATAINFO[dataset_key] if split not in source.manifest_path: @@ -75,12 +76,6 @@ def build_videophy2_local_dataset( ) -def build_videophy2_datapacker_dataloader(**kwargs) -> DataPackerDataLoader: - for _spurious in ("storage_type",): - kwargs.pop(_spurious, None) - return DataPackerDataLoader(**kwargs) - - _MAX_VIDEO_FRAMES = 32 _TARGET_VIDEO_FPS = 2.0 @@ -111,138 +106,30 @@ def _decode_video_to_pil_frames(video_bytes: bytes) -> tuple[list, float]: return frames, float(effective_fps) -class VideoPhy2DataPacker(DataPacker): - """LocalSFTDataset + Qwen3-VL processor adapter; max_batch_size=1.""" - - def __init__( - self, - tokenizer_config: Any, - max_seq_len: int = 16000, - ignore_index: int = IGNORE_INDEX, - ) -> None: - self._max_seq_len = max_seq_len - self._ignore_index = ignore_index - self._processor = ( - tokenizer_config if hasattr(tokenizer_config, "apply_chat_template") else instantiate(tokenizer_config) - ) - - def _materialize_media_in_conversation( - self, - conversation: list, - media_bytes_by_key: dict, - ) -> list: - # Resolve "video": "" / "image": "" references against - # data_dict["media"] (bytes); decode each unique key once. 
- decoded_cache: dict[str, tuple[list, float]] = {} - new_messages: list[dict] = [] - for message in conversation: - if not isinstance(message, dict): - continue - content = message.get("content") - if isinstance(content, str): - new_messages.append({"role": message.get("role", "user"), "content": content}) - continue - if not isinstance(content, list): - continue - new_content: list[dict] = [] - for item in content: - if not isinstance(item, dict): - continue - kind = item.get("type") - if kind == "video": - key = item.get("video") - if not isinstance(key, str): - new_content.append(item) - continue - if key not in media_bytes_by_key: - raise KeyError( - f"conversation references video key {key!r} not present in " - f"sample['media'] (keys: {list(media_bytes_by_key)})" - ) - if key not in decoded_cache: - decoded_cache[key] = _decode_video_to_pil_frames(media_bytes_by_key[key]) - frames, fps = decoded_cache[key] - new_content.append({"type": "video", "video": frames, "fps": fps}) - elif kind == "image": - key = item.get("image") - if not isinstance(key, str): - new_content.append(item) - continue - if key not in media_bytes_by_key: - raise KeyError( - f"conversation references image key {key!r} not present in " - f"sample['media'] (keys: {list(media_bytes_by_key)})" - ) - from PIL import Image - - img = Image.open(io.BytesIO(media_bytes_by_key[key])).convert("RGB") - new_content.append({"type": "image", "image": img}) - else: - new_content.append(item) - new_messages.append({"role": message.get("role", "user"), "content": new_content}) - return new_messages - - def sft_process_sample(self, item: dict) -> dict: - conversation = item.get("texts") - if not isinstance(conversation, list): - raise TypeError( - f"LocalSFTDataset sample expected 'texts' to be a list, got {type(conversation).__name__}" - ) - media_bytes_by_key = item.get("media") or {} - messages = self._materialize_media_in_conversation(conversation, media_bytes_by_key) - - inputs = 
self._processor.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=False, - ) - input_ids = inputs["input_ids"] # [N] - - token_mask = self._processor.add_assistant_tokens_mask(input_ids) # [N] bool - labels = input_ids.clone() - labels[~token_mask] = self._ignore_index - - result: dict = { - "input_ids": input_ids, - "labels": labels, - } - for key in PROCESSOR_KEYS_TO_ADD: - if key in inputs and inputs[key] is not None: - result[key] = inputs[key] - - return result - - def compute_num_tokens(self, sample: dict) -> int: - return int(sample["input_ids"].shape[0]) - - def sft_collate_fn( - self, - samples: list[dict], - max_len: int, - ignore_label_id: int = IGNORE_INDEX, - ) -> dict: - assert len(samples) == 1, f"VideoPhy2DataPacker expects max_batch_size=1, got {len(samples)}" - s = samples[0] - - worker_info = torch.utils.data.get_worker_info() - worker_id = worker_info.id if worker_info is not None else 0 - - batch: dict = { - "input_ids": s["input_ids"].unsqueeze(0), - "labels": s["labels"].unsqueeze(0), - "sample_worker_id": torch.tensor([worker_id]), - "sample_epoch": torch.tensor([0]), - "sample_index": torch.tensor([0]), - } - - if "attention_mask" in s and s["attention_mask"] is not None: - batch["attention_mask"] = s["attention_mask"].unsqueeze(0) - - for key in ("pixel_values", "pixel_values_videos", "image_grid_thw", "video_grid_thw", "second_per_grid_ts"): - if key in s and s[key] is not None: - batch[key] = s[key] - - return batch +def _dl(dataset_key, split, num_workers, persistent_workers=False, pin_memory=False, prefetch_factor=None): + return L(CosmosDataLoader)( + distributor=L(IterableDistributor)( + iterable=L(build_videophy2_local_dataset)(dataset_key=dataset_key, split=split), + ), + processor=L(VideoPhy2Processor)( + processor=L(build_processor)( + tokenizer_type="${model.config.policy.backbone.model_name}", + config_variant="hf", + ), + ignore_index=IGNORE_INDEX, + ), + batcher=L(PoolPackingBatcher)( + 
max_tokens="${data_setting.max_tokens}", + pool_size=16, + max_batch_size=1, + long_threshold=6400, + ), + collator=L(VLMCollator)(), + num_workers=num_workers, + persistent_workers=persistent_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + ) videophy2_sft_nano = LazyDict( @@ -312,48 +199,8 @@ def sft_collate_fn( hf_export=dict(enabled=True), ), upload_reproducible_setup=False, - dataloader_train=L(build_videophy2_datapacker_dataloader)( - data_source=L(build_videophy2_local_dataset)( - dataset_key="videophy2_train", - split="train", - ), - data_packer=L(VideoPhy2DataPacker)( - tokenizer_config=L(build_processor)( - tokenizer_type="${model.config.policy.backbone.model_name}", - config_variant="hf", - ), - max_seq_len="${data_setting.max_tokens}", - ignore_index=IGNORE_INDEX, - ), - max_tokens="${data_setting.max_tokens}", - max_batch_size=1, - pool_size=16, - num_workers=2, - prefetch_factor=2, - persistent_workers=True, - pin_memory=True, - ), - dataloader_val=L(build_videophy2_datapacker_dataloader)( - data_source=L(build_videophy2_local_dataset)( - dataset_key="videophy2_val", - split="val", - ), - data_packer=L(VideoPhy2DataPacker)( - tokenizer_config=L(build_processor)( - tokenizer_type="${model.config.policy.backbone.model_name}", - config_variant="hf", - ), - max_seq_len="${data_setting.max_tokens}", - ignore_index=IGNORE_INDEX, - ), - max_tokens="${data_setting.max_tokens}", - max_batch_size=1, - pool_size=16, - num_workers=0, - prefetch_factor=None, - persistent_workers=False, - pin_memory=True, - ), + dataloader_train=_dl("videophy2_train", "train", 2, persistent_workers=True, pin_memory=True, prefetch_factor=2), + dataloader_val=_dl("videophy2_val", "val", 0, persistent_workers=False, pin_memory=True, prefetch_factor=None), ), flags={"allow_objects": True}, ) diff --git a/cosmos_framework/configs/toml_config/sft_config.py b/cosmos_framework/configs/toml_config/sft_config.py index d02e365..a15d00a 100644 --- 
a/cosmos_framework/configs/toml_config/sft_config.py +++ b/cosmos_framework/configs/toml_config/sft_config.py @@ -614,7 +614,7 @@ class DataloaderTrainConfig(BaseModel): """Top-level dataloader scalars only. The dataloader's class (LazyCall) and full pipeline wiring (datasets, packers, …) stay in the experiment Python — they vary too much between VFM IterativeJointDataLoader, - PackingDataLoader, and VLM DataPackerDataLoader to model uniformly. + PackingDataLoader, and VLM CosmosDataLoader to model uniformly. """ model_config = _PYDANTIC_MODEL_CONFIG @@ -622,16 +622,18 @@ class DataloaderTrainConfig(BaseModel): max_samples_per_batch: Optional[int] = Field( default=None, description=( - "Cap on samples per micro-batch. Remapped to 'max_batch_size' " - "on the VLM DataPackerDataLoader. None = no per-count cap " + "Cap on samples per micro-batch. Remapped to " + "'dataloader_train.batcher.max_batch_size' on the VLM CosmosDataLoader " + "(its PoolPackingBatcher). None = no per-count cap " "(the packer's token budget is what limits batch size)." ), ) max_sequence_length: Optional[int] = Field( default=None, description=( - "Cap on tokens per packed sequence. Remapped to 'max_tokens' " - "on the VLM DataPackerDataLoader. None = no per-token cap." + "Cap on tokens per packed sequence. Remapped to " + "'dataloader_train.batcher.max_tokens' on the VLM CosmosDataLoader " + "(its PoolPackingBatcher). None = no per-token cap." ), ) max_caption_tokens: Optional[int] = Field( @@ -646,7 +648,7 @@ class DataloaderTrainConfig(BaseModel): seed: int = Field( default=42, description=( - "Dataloader RNG seed. Skipped on VLM (DataPackerDataLoader has " + "Dataloader RNG seed. Skipped on VLM (CosmosDataLoader has " "no seed ctor kwarg there)." 
), ) diff --git a/cosmos_framework/configs/toml_config/toml_config_helper.py b/cosmos_framework/configs/toml_config/toml_config_helper.py index 55e2cc7..ac54696 100644 --- a/cosmos_framework/configs/toml_config/toml_config_helper.py +++ b/cosmos_framework/configs/toml_config/toml_config_helper.py @@ -79,8 +79,10 @@ ("model", "attn_implementation"): ("model", "config", "policy", "attn_implementation"), ("model", "ema"): ("model", "config", "ema"), ("model", "backbone"): ("model", "config", "policy", "backbone"), - ("dataloader_train", "max_samples_per_batch"): ("dataloader_train", "max_batch_size"), - ("dataloader_train", "max_sequence_length"): ("dataloader_train", "max_tokens"), + # VLM uses CosmosDataLoader whose batch/token caps live on the nested + # PoolPackingBatcher (dataloader_train.batcher.*), not flat on the loader. + ("dataloader_train", "max_samples_per_batch"): ("dataloader_train", "batcher", "max_batch_size"), + ("dataloader_train", "max_sequence_length"): ("dataloader_train", "batcher", "max_tokens"), ("dataloader_train", "max_caption_tokens"): None, # VFM-only knob — VLM packer caps via max_sequence_length # Catch-all for any other model.* sub-keys ("model",): ("model", "config"), diff --git a/cosmos_framework/data/vfm/data_packer.py b/cosmos_framework/data/vfm/data_packer.py deleted file mode 100644 index 17dfc08..0000000 --- a/cosmos_framework/data/vfm/data_packer.py +++ /dev/null @@ -1,107 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -Storage-agnostic sample-level transform protocol for OSS-compatible training. 
- -Interface is name-compatible with cosmos-rl's ``BaseDataPacker`` SFT methods: - - sft_process_sample ↔ sft_process_sample (identical) - sft_collate_fn ↔ sft_collate_fn (identical) - compute_num_tokens ← NEW single-sample token cost for packing budget - -After adding a one-line ``compute_num_tokens`` default to cosmos-rl's -``BaseDataPacker``, existing cosmos-rl packers (``HFVLMDataPacker``, -``Qwen3_VL_DataPacker``, etc.) become directly usable here with no other changes. - -Usage ------ -Subclass ``DataPacker`` and implement three methods, then plug into -``DataPackerDataLoader``:: - - class MyPacker(DataPacker): - def sft_process_sample(self, item): - return {"input_ids": tokenizer(item["text"]).input_ids} - - def compute_num_tokens(self, sample): - return len(sample["input_ids"]) - - def sft_collate_fn(self, samples, max_len, ignore_label_id=-100): - # pad and stack - ... - - loader = DataPackerDataLoader( - data_source=my_dataset, - data_packer=MyPacker(), - max_tokens=16000, - ) -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any - - -class DataPacker(ABC): - """Storage-agnostic protocol for transforming dataset items into training batches. - - OSS users subclass this to support any model or dataset format. - Three abstract methods are required; the rest of the training infrastructure - (packing, worker management, Hydra config) is inherited automatically. - """ - - @abstractmethod - def sft_process_sample(self, item: Any) -> dict: - """Convert one raw dataset item into a training-ready sample dict. - - Parameters - ---------- - item: - Whatever the user's ``data_source`` iterable yields — - a HuggingFace record, a ``{"image": PIL.Image, "text": str}`` dict, - or any other format. - - Returns - ------- - dict - Must contain at minimum the keys expected by ``sft_collate_fn`` - and must have a token-countable representation for - ``compute_num_tokens``. 
- """ - - @abstractmethod - def compute_num_tokens(self, sample: dict) -> int: - """Return the token cost of one sample for the packing budget. - - For VLM/text models this is typically ``len(sample["input_ids"])``. - For VFM models override with the VAE spatial/temporal formula. - - This method corresponds to the *per-sample* granularity needed by - ``PackingIterableDataset._best_fit_batch``. It differs from - cosmos-rl's ``sft_compute_max_len`` (batch-level) intentionally. - """ - - @abstractmethod - def sft_collate_fn( - self, - samples: list[dict], - max_len: int, - ignore_label_id: int = -100, - ) -> dict: - """Collate a list of packed samples into one training batch. - - Parameters - ---------- - samples: - List of dicts returned by ``sft_process_sample``. - max_len: - Maximum token length in this batch (for padding). - ignore_label_id: - Label value for masked/padding positions (default ``-100``). - - Returns - ------- - dict - Batch ready for ``model.forward()``. - """ diff --git a/cosmos_framework/data/vfm/data_packer_dataloader.py b/cosmos_framework/data/vfm/data_packer_dataloader.py deleted file mode 100644 index 5b4777e..0000000 --- a/cosmos_framework/data/vfm/data_packer_dataloader.py +++ /dev/null @@ -1,623 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -OSS-facing dataloader that wires any Python iterable + DataPacker into the -shared PackingIterableDataset engine. 
- -Follows the same two-layer pattern as the internal path: - private _DataPackerIterableDataset ↔ private _JointIterableDataset - public DataPackerDataLoader ↔ public JointDatasetDynamicBatchingWebLoader - -Data-parallel sharding ----------------------- -When ``torch.distributed`` is initialized, ``DataPackerDataLoader`` automatically -shards ``data_source`` across ranks **and** DataLoader workers using round-robin -filtering — the same pattern as ``SFTDataset`` in -``projects/cosmos3/vfm/datasets/local_datasets/sft_dataset.py``. - -Each ``(dp_rank, worker_id)`` pair sees every -``dp_world_size × num_workers``-th item, giving disjoint coverage. - -Usage ------ -Pass a pre-built iterable directly:: - - loader = DataPackerDataLoader( - data_source=my_dataset, # any Python iterable - data_packer=MyDataPacker(...), - max_tokens=16000, - num_workers=4, - ) - -Or load a HuggingFace / local dataset via ``load_data_source`` — compatible -with Hydra ``LazyCall`` so CLI overrides work without editing Python files:: - - from cosmos_framework.utils.lazy_config import LazyCall as L - from cosmos_framework.data.vfm.data_packer_dataloader import ( - DataPackerDataLoader, - load_data_source, - ) - - dataloader_train = L(DataPackerDataLoader)( - data_source=L(load_data_source)( - name="liuhaotian/LLaVA-Instruct-150K", - split=["train"], - ), - data_packer=L(MyDataPacker)(...), - max_tokens=16000, - ) - - # CLI override (no Python file edit needed): - # dataloader_train.data_source.name=my-org/my-dataset - # dataloader_train.data_source.split=[train,validation] - - # FSDP + TP/PP (pass parallel_dims for correct DP rank): - loader = DataPackerDataLoader( - data_source=..., - data_packer=..., - max_tokens=16000, - parallel_dims=parallel_dims, # uses parallel_dims.dp_coord - ) -""" - -from __future__ import annotations - -import os -from typing import Any - -import numpy as np -import torch -import torch.utils.data - -from cosmos_framework.utils import log -from 
cosmos_framework.data.vfm.data_packer import DataPacker -from cosmos_framework.data.vfm.packing_iterable_dataset import PackingIterableDataset - - -def load_data_source( - name: str, - split: str | list[str] = "train", - subset: str | None = None, - revision: str | None = None, -) -> Any: - """Load a HuggingFace or local dataset for use as ``data_source``. - - Designed to be used as a ``LazyCall`` in Hydra experiment configs so that - dataset name and split can be overridden from the CLI without editing Python - files (see module docstring for an example). - - Parameters - ---------- - name: - HuggingFace dataset name (e.g. ``"liuhaotian/LLaVA-Instruct-150K"``) or - a local directory path to a dataset saved with ``dataset.save_to_disk()``. - Local paths are detected via ``os.path.isdir`` and loaded with - ``load_from_disk``; all other values go through ``load_dataset``. - split: - Split name or list of split names to load. When a list is given the - splits are concatenated into a single dataset. - subset: - HuggingFace dataset subset / config name (optional). - revision: - Git revision / commit hash of the dataset (optional). - - Returns - ------- - datasets.Dataset - A concatenated ``datasets.Dataset`` ready to be passed to - ``DataPackerDataLoader`` as ``data_source``. - - Raises - ------ - ImportError - If the ``datasets`` package is not installed. - """ - try: - from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk - except ImportError as exc: - raise ImportError( - "The 'datasets' package is required by load_data_source. Install it with: pip install datasets" - ) from exc - - import os - - if os.path.isdir(name): - # Dataset saved with dataset.save_to_disk() — use load_from_disk. - raw = load_from_disk(name) - else: - # HuggingFace Hub name or other format supported by load_dataset. 
- raw = load_dataset(name, subset, revision=revision) - - if isinstance(raw, Dataset): - # load_from_disk on a single Dataset (not DatasetDict) — return as-is. - return raw - - # DatasetDict: select and concatenate requested splits. - splits = [split] if isinstance(split, str) else split - return concatenate_datasets([raw[s] for s in splits]) - - -class _IterableWrapper(torch.utils.data.IterableDataset): - """Wraps any Python iterable as a ``torch.utils.data.IterableDataset`` - with built-in data-parallel + multi-worker sharding. - - Sharding follows the same ``(dp_rank × num_workers)`` formula as - ``SFTDataset`` — each ``(dp_rank, worker_id)`` pair receives every - ``dp_world_size × num_workers``-th item starting at - ``dp_rank * num_workers + worker_id``. - - .. warning:: - For ``num_workers=0``, worker-level sharding is skipped automatically. - """ - - def __init__(self, iterable: Any, dp_rank: int = 0, dp_world_size: int = 1): - super().__init__() - self._iterable = iterable - self._dp_rank = dp_rank - self._dp_world_size = dp_world_size - - def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - num_workers = worker_info.num_workers - worker_id = worker_info.id - else: - num_workers, worker_id = 1, 0 - - # Total independent streams = dp_world_size × num_workers. - # Each (rank, worker) pair owns stream = rank * num_workers + worker_id. - total_streams = self._dp_world_size * num_workers - my_stream = self._dp_rank * num_workers + worker_id - - for i, item in enumerate(self._iterable): - if i % total_streams == my_stream: - yield item - - -class _ShuffledMapIterableDataset(torch.utils.data.IterableDataset): - """Stateful, sharded wrapper for map-style ``torch.utils.data.Dataset``. - - Used for ALL map-style ``data_source`` inputs, regardless of ``shuffle``. - Handles DP × worker sharding and stateful checkpoint/resume. 
- - - Shuffle (``shuffle=True``): per-epoch ``torch.randperm(n)`` seeded with - ``base_seed + epoch``, giving a different but reproducible ordering every epoch. - - No shuffle (``shuffle=False``): sequential iteration ``[0, 1, ..., n-1]`` - each epoch — deterministic and resumable at the exact position. - - Sharding: ``stream_id = dp_rank * num_workers + worker_id``; each stream - yields ``perm[stream_id :: total_streams]`` — disjoint, full coverage. - - Resume: reads ``DP_STATE_WORKER_{worker_id}_EPOCH`` / - ``DP_STATE_WORKER_{worker_id}_INDEX`` env vars set by - ``DataLoaderStateCallback.load_state_dict`` before workers start. - When a dataset ``name`` is provided (non-empty), env vars are namespaced - as ``DP_STATE_{name}_WORKER_{worker_id}_EPOCH`` to avoid conflicts when - multiple ``DataPackerDataLoader`` instances share the same process (e.g. - inside ``JointDataPackerDataLoader``). - - The generator body is lazy: ``worker_info`` (and env vars) are read on the - first ``next()`` call inside the worker process, not at construction time. - Requires ``persistent_workers=True`` and ``fork`` start method (Linux/CUDA - default) — both enforced / documented by ``DataPackerDataLoader``. 
- """ - - def __init__( - self, - dataset: torch.utils.data.Dataset, - seed: int, - dp_rank: int, - dp_world_size: int, - shuffle: bool = True, - name: str = "", - ) -> None: - super().__init__() - self._dataset = dataset - self._seed = seed - self._dp_rank = dp_rank - self._dp_world_size = dp_world_size - self._shuffle = shuffle - self._name = name - - def __len__(self) -> int: - return len(self._dataset) # type: ignore[arg-type] - - def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - num_workers = worker_info.num_workers if worker_info is not None else 1 - worker_id = worker_info.id if worker_info is not None else 0 - - stream_id = self._dp_rank * num_workers + worker_id - total_streams = self._dp_world_size * num_workers - n = len(self._dataset) # type: ignore[arg-type] - - # os.environ.pop: consume once so a hypothetical second __iter__ call - # in the same worker process defaults to a fresh-start sentinel instead of - # re-fast-forwarding. -1 means "no items seen yet" → start = 0. - # DataLoaderStateCallback always saves index ≥ 0, so -1 is unambiguous. - # When self._name is non-empty, env vars are namespaced to avoid conflicts - # between multiple DataPackerDataLoader instances in the same process. - _pfx = f"DP_STATE_{self._name}_" if self._name else "DP_STATE_" - resume_epoch = int(os.environ.pop(f"{_pfx}WORKER_{worker_id}_EPOCH", 0)) - resume_pos = int(os.environ.pop(f"{_pfx}WORKER_{worker_id}_INDEX", -1)) - - epoch = resume_epoch - while True: - if self._shuffle: - g = torch.Generator().manual_seed(self._seed + epoch) - perm = torch.randperm(n, generator=g).tolist() - else: - perm = list(range(n)) - stream_slice = perm[stream_id::total_streams] - - # resume_pos is the last index successfully included in a training - # batch, so start one past it. On new epochs start from 0. 
- start = (resume_pos + 1) if epoch == resume_epoch else 0 - for pos in range(start, len(stream_slice)): - item = self._dataset[stream_slice[pos]] - # Attach position metadata; _DataPackerIterableDataset strips - # these before sft_process_sample and re-attaches after so they - # survive through the pool to collate_batch. - yield {"_dp_epoch": epoch, "_dp_stream_pos": pos, **item} - - epoch += 1 - - -class _DataPackerIterableDataset(PackingIterableDataset): - """Private: injects a DataPacker into the shared packing engine. - - Not registered in Hydra directly. Use ``DataPackerDataLoader`` instead. - """ - - def __init__( - self, - data_source: Any, - data_packer: DataPacker, - max_tokens: int, - pool_size: int, - max_batch_size: int, - long_threshold: int, - batching_strategy: str, - dp_rank: int = 0, - dp_world_size: int = 1, - shuffle: bool = False, - seed: int = 0, - name: str = "", - apply_long_sample_halving: bool = True, - ): - is_map = isinstance(data_source, torch.utils.data.Dataset) and not isinstance( - data_source, torch.utils.data.IterableDataset - ) - is_iterable = isinstance(data_source, torch.utils.data.IterableDataset) - if not is_map and not is_iterable: - raise TypeError( - f"data_source must be a torch.utils.data.Dataset or " - f"torch.utils.data.IterableDataset, got {type(data_source).__name__}" - ) - - if is_map: - # All map-style datasets go through _ShuffledMapIterableDataset, - # which handles sharding and position metadata regardless of shuffle. - # This enables stateful checkpoint/resume even when shuffle=False. - data_source = _ShuffledMapIterableDataset( - dataset=data_source, - seed=seed, - dp_rank=dp_rank, - dp_world_size=dp_world_size, - shuffle=shuffle, - name=name, - ) - self._has_dp_meta = True - else: - # Iterable-style: wrap with _IterableWrapper for sharding only. - # Stateful resume is not supported for IterableDataset sources. 
- data_source = _IterableWrapper(data_source, dp_rank=dp_rank, dp_world_size=dp_world_size) - self._has_dp_meta = False - - datasets_cfg = {"default": {"dataset": data_source, "ratio": 1.0}} - super().__init__( - datasets_cfg=datasets_cfg, - max_tokens=max_tokens, - pool_size=pool_size, - max_batch_size=max_batch_size, - long_threshold=long_threshold, - batching_strategy=batching_strategy, - apply_long_sample_halving=apply_long_sample_halving, - ) - self._data_packer = data_packer - - def _get_next_sample(self) -> dict: - raw_item = super()._get_next_sample() - if self._has_dp_meta: - # Strip _dp_* keys before sft_process_sample so the user's packer - # receives a clean item, then re-attach so metadata survives the pool. - dp_meta = {k: raw_item.pop(k) for k in list(raw_item) if k.startswith("_dp_")} - processed = self._data_packer.sft_process_sample(raw_item) - processed.update(dp_meta) - return processed - return self._data_packer.sft_process_sample(raw_item) - - def compute_sample_tokens(self, sample: dict) -> int: - return self._data_packer.compute_num_tokens(sample) - - def collate_batch(self, samples: list) -> dict: - max_len = max(self.compute_sample_tokens(s) for s in samples) - - if self._has_dp_meta and "_dp_epoch" in samples[0]: - max_epoch = max(s["_dp_epoch"] for s in samples) - max_pos = max(s["_dp_stream_pos"] for s in samples) - clean = [{k: v for k, v in s.items() if not k.startswith("_dp_")} for s in samples] - batch = self._data_packer.sft_collate_fn(clean, max_len) - worker_info = torch.utils.data.get_worker_info() - worker_id = worker_info.id if worker_info is not None else 0 - batch["sample_worker_id"] = torch.tensor([worker_id] * len(samples)) - batch["sample_epoch"] = torch.tensor([max_epoch] * len(samples)) - batch["sample_index"] = torch.tensor([max_pos] * len(samples)) - else: - batch = self._data_packer.sft_collate_fn(samples, max_len) - - return batch - - -class DataPackerDataLoader(torch.utils.data.DataLoader): - """Public OSS entry 
point for bringing any dataset into i4 training. - - Wraps ``_DataPackerIterableDataset`` in a standard - ``torch.utils.data.DataLoader`` — no WebDataset dependency required. - OSS users' data can be HuggingFace datasets, local files, generators, - or any Python iterable. - - Data-parallel sharding is automatic when ``torch.distributed`` is - initialized. Each ``(dp_rank, worker_id)`` pair receives a disjoint - subset of ``data_source``. - - Parameters - ---------- - data_source: - ``torch.utils.data.Dataset`` (map-style) or - ``torch.utils.data.IterableDataset`` — HuggingFace datasets, custom - datasets, or generators wrapped in an ``IterableDataset``. Plain - lists / generators are not accepted; wrap them in an ``IterableDataset`` - first. - data_packer: - A ``DataPacker`` subclass instance. Provides sample-level transform - (``sft_process_sample``), token counting (``compute_num_tokens``), and - batch collation (``sft_collate_fn``). - max_tokens: - Token budget per batch. - pool_size: - Samples to buffer before bin-packing. - max_batch_size: - Hard cap on items per batch. - long_threshold: - Samples with token count >= this are emitted as singleton batches. - batching_strategy: - ``"prefer_closest"`` (default) or ``"prefer_first"``. - shuffle: - If ``True`` and ``data_source`` is a map-style ``Dataset``, shuffle - samples with a per-epoch ``torch.randperm`` seeded by ``seed + epoch``. - Enables stateful checkpoint/resume via ``DataLoaderStateCallback`` - (``distributor_type="data_packer"``). Has no effect for - ``IterableDataset`` inputs — a warning is logged in that case. - seed: - Base seed for the per-epoch shuffle permutation. Epoch ``e`` uses - ``seed + e`` as the generator seed. Ignored when ``shuffle=False``. - num_workers, prefetch_factor, persistent_workers, pin_memory: - Forwarded to ``torch.utils.data.DataLoader``. 
When ``shuffle=True`` - and ``num_workers > 0``, ``persistent_workers`` is automatically - promoted to ``True`` (required for correct resume behaviour). - parallel_dims: - Optional ``ParallelDims`` instance (from cosmos-rl). When provided, - ``parallel_dims.dp_coord`` supplies the data-parallel rank and world - size, which is correct for FSDP+TP/PP where the DP degree differs from - the global world size. When ``None`` (default), rank info is read from - ``torch.distributed`` if initialized, else defaults to ``(0, 1)``. - name: - Optional identifier used to namespace resume env vars when multiple - ``DataPackerDataLoader`` instances share the same process (e.g. inside - ``JointDataPackerDataLoader``). When non-empty, env vars become - ``DP_STATE_{name}_WORKER_{id}_EPOCH/INDEX`` instead of the default - ``DP_STATE_WORKER_{id}_EPOCH/INDEX``. Must match the ``name`` passed - to the corresponding ``DataLoaderStateCallback`` or - ``JointDataLoaderStateCallback``. Leave empty (default) for - single-loader configurations. - apply_long_sample_halving: - When ``True`` (default), the inner ``PackingIterableDataset._max_tokens`` - halves the budget for any batch whose largest sample has >= 1000 tokens - — a memory-safety heuristic. Set ``False`` to use the literal - ``max_tokens`` budget unconditionally; only do this when memory - headroom at the un-halved budget has been validated for the recipe - (large MoT + LoRA recipes can OOM at the literal budget — see - ``packing_iterable_dataset.py::_max_tokens``). 
- """ - - def __init__( - self, - data_source: Any, - data_packer: DataPacker, - max_tokens: int, - pool_size: int = 16, - max_batch_size: int = 1, - long_threshold: int = 6400, - batching_strategy: str = "prefer_closest", - shuffle: bool = False, - seed: int = 0, - num_workers: int = 0, - prefetch_factor: int | None = None, - persistent_workers: bool = False, - pin_memory: bool = False, - parallel_dims=None, - name: str = "", - apply_long_sample_halving: bool = True, - ): - is_map = isinstance(data_source, torch.utils.data.Dataset) and not isinstance( - data_source, torch.utils.data.IterableDataset - ) - is_iterable = isinstance(data_source, torch.utils.data.IterableDataset) - if shuffle and is_iterable: - log.warning( - "DataPackerDataLoader: shuffle=True has no effect for IterableDataset " - "data_source. Shuffle the dataset before passing it in.", - rank0_only=True, - ) - - # Correctness requirement: map-style datasets use _ShuffledMapIterableDataset - # which reads resume env vars on the first __iter__ call inside each worker. - # With persistent_workers=False, workers re-spawn each iteration and - # re-inherit the env vars, causing incorrect fast-forward on every epoch - # boundary. Enforce persistent_workers=True for all map-style datasets. - if is_map and num_workers > 0 and not persistent_workers: - log.info( - "DataPackerDataLoader: map-style data_source requires persistent_workers=True " - "for correct stateful resume behaviour. Overriding persistent_workers to True.", - rank0_only=True, - ) - persistent_workers = True - - # Resolve data-parallel rank and world-size. - # Priority: explicit parallel_dims > torch.distributed > single-GPU default. - if parallel_dims is not None: - dp_rank, dp_world_size = parallel_dims.dp_coord - elif torch.distributed.is_initialized(): - dp_rank = torch.distributed.get_rank() - dp_world_size = torch.distributed.get_world_size() - - # rank/world_size differ from the data-parallel rank/world_size. 
- # Pass `parallel_dims` to use the correct DP coordinates; otherwise - # data sharding will be incorrect (each logical DP group sees the - # same shard as another group). - if dp_world_size > 1: - log.info( - "DataPackerDataLoader: using global rank for DP sharding. " - "For FSDP+TP/PP setups pass parallel_dims= to use the correct " - "DP rank/world_size.", - rank0_only=True, - ) - else: - dp_rank, dp_world_size = 0, 1 - - dataset = _DataPackerIterableDataset( - data_source=data_source, - data_packer=data_packer, - max_tokens=max_tokens, - pool_size=pool_size, - max_batch_size=max_batch_size, - long_threshold=long_threshold, - batching_strategy=batching_strategy, - dp_rank=dp_rank, - dp_world_size=dp_world_size, - shuffle=shuffle, - seed=seed, - name=name, - apply_long_sample_halving=apply_long_sample_halving, - ) - loader_kwargs: dict = dict( - num_workers=num_workers, - persistent_workers=persistent_workers and num_workers > 0, - pin_memory=pin_memory, - ) - if num_workers > 0 and prefetch_factor is not None: - loader_kwargs["prefetch_factor"] = prefetch_factor - # batch_size=None disables PyTorch's automatic batching/collation. - # _DataPackerIterableDataset.__iter__ already yields fully-collated batch dicts; - # letting the DataLoader re-collate them adds spurious batch dimensions. - super().__init__(dataset, batch_size=None, **loader_kwargs) - - -class JointDataPackerDataLoader: - """Wraps multiple ``DataPackerDataLoader`` instances with ratio-based seeded selection. - - Mirrors the design of ``IterativeJointDataLoader``: one output batch = one - inner loader, selected deterministically by ratio at each step. Adds a - ``"dataset_name"`` key to every yielded batch so downstream callbacks can - route state updates to the correct inner loader. - - Parameters - ---------- - dataloaders: - ``{name: {"dataloader": DataPackerDataLoader, "ratio": int}}`` mapping. - Entries with ``ratio <= 0`` are silently skipped. 
- seed: - Base seed for the per-step dataset selection. Step ``i`` uses - ``np.random.RandomState(seed + i)`` to pick the inner loader index, - giving the same sequence on every rank (assuming synchronized - ``set_start_iteration`` calls) and fully reproducible resume. - - Stateful checkpoint/resume - -------------------------- - Pair with ``JointDataLoaderStateCallback`` (from - ``cosmos_framework.callbacks.dataloader_state``). That callback saves the outer - ``global_id`` and each inner loader's per-worker ``(epoch, index)`` state - in a single DCP checkpoint entry. On resume: - - 1. ``JointDataLoaderStateCallback.load_state_dict`` calls - ``set_start_iteration(global_id)`` to restore the selection sequence. - 2. Each inner ``DataLoaderStateCallback.load_state_dict`` sets namespaced - env vars so inner-loader workers fast-forward to the saved position. - - Each ``DataPackerDataLoader`` must be constructed with a unique ``name`` - that matches the key used in this ``dataloaders`` dict so env vars are - namespaced correctly (see ``DataPackerDataLoader`` ``name`` parameter). - """ - - def __init__( - self, - dataloaders: dict[str, dict], - seed: int = 42, - ) -> None: - entries = [ - (name, cfg["dataloader"], cfg["ratio"]) - for name, cfg in dataloaders.items() - if cfg.get("ratio", 0) > 0 - ] - if not entries: - raise ValueError("JointDataPackerDataLoader: no dataloaders with ratio > 0") - - self._names: list[str] = [e[0] for e in entries] - if "global_id" in self._names: - raise ValueError( - "JointDataPackerDataLoader: dataset name 'global_id' is reserved " - "by the checkpoint state format; use a different name." 
- ) - self._loaders: list[DataPackerDataLoader] = [e[1] for e in entries] - ratios = np.array([e[2] for e in entries], dtype=float) - self._probs: np.ndarray = ratios / ratios.sum() - self._seed = seed - self._global_id = 0 - # Iterators are created lazily on the first __iter__ call so that - # DataLoaderStateCallback.load_state_dict can install resume env vars - # before workers are spawned (for num_workers > 0, iter(DataLoader) - # forks workers immediately; env vars must be set in the parent first). - self._iterators: list | None = None - - total = ratios.sum() - lines = [f"JointDataPackerDataLoader: {len(self._names)} streams"] - for name, ratio in zip(self._names, ratios): - lines.append(f" {name}: ratio={ratio:.4g} ({ratio / total:.1%})") - log.info("\n".join(lines)) - - def set_start_iteration(self, iteration: int) -> None: - """Restore deterministic selection sequence after checkpoint resume. - - Called by ``JointDataLoaderStateCallback.load_state_dict`` and by the - trainer (if present) via ``hasattr`` guard. - """ - self._global_id = iteration - - def __iter__(self): - # Lazy init: create iterators here (not in __init__) so that - # load_state_dict can set resume env vars before workers fork. - if self._iterators is None: - self._iterators = [iter(loader) for loader in self._loaders] - while True: - rng = np.random.RandomState(self._seed + self._global_id) - idx = int(rng.choice(len(self._loaders), p=self._probs)) - try: - batch = next(self._iterators[idx]) - except StopIteration: - # Inner DataPackerDataLoaders are infinite; this guard handles - # the unlikely case of a finite IterableDataset inner source. 
- self._iterators[idx] = iter(self._loaders[idx]) - batch = next(self._iterators[idx]) - batch["dataset_name"] = self._names[idx] - self._global_id += 1 - yield batch diff --git a/cosmos_framework/data/vfm/dataflow/__init__.py b/cosmos_framework/data/vfm/dataflow/__init__.py new file mode 100644 index 0000000..50b300b --- /dev/null +++ b/cosmos_framework/data/vfm/dataflow/__init__.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Modular training dataflow: DataDistributor -> RawItemProcessor -> +SampleBatcher -> BatchCollator, wired by CosmosDataLoader.""" + +from __future__ import annotations + +from cosmos_framework.data.vfm.dataflow.base import ( + BatchCollator, + DataDistributor, + RawItemProcessor, + SampleBatcher, +) +from cosmos_framework.data.vfm.dataflow.batchers import PoolPackingBatcher, SequentialPackingBatcher, SimpleBatcher +from cosmos_framework.data.vfm.dataflow.collators import DefaultBatchCollator, VFMListCollator +from cosmos_framework.data.vfm.dataflow.distributors import IterableDistributor, MapDistributor, MixtureDistributor, RankPartitionedDistributor +from cosmos_framework.data.vfm.dataflow.loader import CosmosDataLoader, JointCosmosDataLoader +from cosmos_framework.data.vfm.dataflow.processors import IdentityProcessor + +__all__ = [ + "BatchCollator", + "CosmosDataLoader", + "JointCosmosDataLoader", + "DataDistributor", + "DefaultBatchCollator", + "IdentityProcessor", + "IterableDistributor", + "MapDistributor", + "MixtureDistributor", + "RankPartitionedDistributor", + "PoolPackingBatcher", + "RawItemProcessor", + "SampleBatcher", + "SequentialPackingBatcher", + "SimpleBatcher", + "VFMListCollator", +] diff --git a/cosmos_framework/data/vfm/dataflow/base.py b/cosmos_framework/data/vfm/dataflow/base.py new file mode 100644 index 0000000..0be77a5 --- /dev/null +++ b/cosmos_framework/data/vfm/dataflow/base.py @@ -0,0 +1,65 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""The four dataflow role ABCs. + +A raw item flows through four independently-swappable roles in a fixed order +enforced by the loader: + + DataDistributor -> RawItemProcessor -> SampleBatcher -> BatchCollator + +See docs/superpowers/specs/2026-06-04-modular-dataflow-refactor-design.md. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Iterator + + +class DataDistributor(ABC): + """Owns the raw dataset, shards it disjointly across DP ranks x workers, + shuffles, and (later) carries checkpoint/resume state.""" + + @abstractmethod + def stream( + self, dp_rank: int, dp_world_size: int, worker_id: int, num_workers: int + ) -> Iterator[Any]: + """Yield this (rank, worker)'s disjoint slice of raw items, indefinitely.""" + + def state_dict(self) -> dict: + """Resume state. No-op default; resumable distributors override.""" + return {} + + def load_state_dict(self, state: dict) -> None: + """Restore resume state. No-op default; resumable distributors override.""" + return None + + +class RawItemProcessor(ABC): + """Transforms one raw dataset item into one training-ready sample dict.""" + + @abstractmethod + def process(self, item: Any) -> dict: + ... + + +class SampleBatcher(ABC): + """Consumes a stream of samples and yields groups (the selection strategy).""" + + @abstractmethod + def batches(self, samples: Iterator[dict]) -> Iterator[list[dict]]: + """Pull from ``samples``; yield one ``list[dict]`` per batch.""" + + def sample_size(self, sample: dict) -> int: + """Per-sample token cost for packing batchers. 
Non-packing batchers + never call this; packing batchers override it (or inject a size_fn).""" + raise NotImplementedError + + +class BatchCollator(ABC): + """Collates one group of samples into one batch dict for ``model.forward()``.""" + + @abstractmethod + def collate(self, samples: list[dict]) -> dict: + ... diff --git a/cosmos_framework/data/vfm/dataflow/batchers.py b/cosmos_framework/data/vfm/dataflow/batchers.py new file mode 100644 index 0000000..b6c63dd --- /dev/null +++ b/cosmos_framework/data/vfm/dataflow/batchers.py @@ -0,0 +1,350 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Built-in SampleBatcher implementations.""" + +from __future__ import annotations + +from collections import deque +from enum import Enum +from typing import Callable, Iterator, Optional + +from cosmos_framework.data.vfm.dataflow.base import SampleBatcher + + +class SimpleBatcher(SampleBatcher): + """Fixed-size batching — stock DataLoader behavior. Never needs sample_size.""" + + def __init__(self, batch_size: int, drop_last: bool = False): + if batch_size < 1: + raise ValueError(f"batch_size must be >= 1, got {batch_size}") + self.batch_size = batch_size + self.drop_last = drop_last + + def batches(self, samples: Iterator[dict]) -> Iterator[list[dict]]: + buf: list[dict] = [] + for s in samples: + buf.append(s) + if len(buf) == self.batch_size: + yield buf + buf = [] + if buf and not self.drop_last: + yield buf + + +class _Modality(Enum): + IMAGE = "image" + VIDEO = "video" + TEXT = "text" + + +class PoolPackingBatcher(SampleBatcher): + """Pool-based greedy bin-packing batcher (re-homed from PackingIterableDataset). + + Buffers ``pool_size`` samples and assembles each batch by greedily selecting + candidates that fit within the padded token budget, never mixing modalities + within a batch. 
``sample_size`` defaults to ``len(sample["input_ids"])``; + pass ``size_fn`` to override, or subclass and override the method. + """ + + def __init__( + self, + max_tokens: int, + pool_size: int = 16, + max_batch_size: int = 1, + long_threshold: int = 6400, + batching_strategy: str = "prefer_closest", + apply_long_sample_halving: bool = True, + size_fn: Optional[Callable[[dict], int]] = None, + ): + assert batching_strategy in ("prefer_first", "prefer_closest"), ( + f"batching_strategy must be 'prefer_first' or 'prefer_closest', got {batching_strategy!r}" + ) + self.max_tokens = max_tokens + self.pool_size = pool_size + self.max_batch_size = max_batch_size + self.long_threshold = long_threshold + self.batching_strategy = batching_strategy + self.apply_long_sample_halving = apply_long_sample_halving + self._size_fn = size_fn + + def sample_size(self, sample: dict) -> int: + if self._size_fn is not None: + return self._size_fn(sample) + # len() == shape[0] for a 1-D tensor and also works for list input_ids. 
+ return len(sample["input_ids"]) + + def batches(self, samples: Iterator[dict]) -> Iterator[list[dict]]: + pool: deque[dict] = deque() + src = iter(samples) + exhausted = False + while True: + while not exhausted and len(pool) < self.pool_size: + try: + pool.append(next(src)) + except StopIteration: + exhausted = True + if not pool: + return + yield self._best_fit_batch(pool) + + def _max_tokens(self, cur_max: int) -> int: + if not self.apply_long_sample_halving: + return self.max_tokens + if cur_max < 1000: + return self.max_tokens + return self.max_tokens // 2 + + @staticmethod + def _get_modality(sample: dict) -> "_Modality": + if "pixel_values" in sample: + return _Modality.IMAGE + elif "pixel_values_videos" in sample: + return _Modality.VIDEO + return _Modality.TEXT + + @staticmethod + def _padded_cost(cur_max: int, k: int) -> int: + return cur_max * k + + def _best_fit_batch(self, pool: deque) -> list[dict]: + seed = pool.popleft() + seed_modality = self._get_modality(seed) + L0 = self.sample_size(seed) + if L0 >= self.long_threshold or L0 >= self._max_tokens(L0): + return [seed] + chosen = [seed] + cur_max = L0 + while pool: + if self.max_batch_size and len(chosen) >= self.max_batch_size: + break + best_idx = self._find_best_candidate(pool, cur_max, len(chosen), seed_modality) + if best_idx is None: + break + cand = self._remove_from_pool(pool, best_idx) + chosen.append(cand) + cur_max = max(cur_max, self.sample_size(cand)) + return chosen + + def _find_best_candidate(self, pool, cur_max, num_chosen, seed_modality): + if self.batching_strategy == "prefer_first": + return self._find_best_candidate_prefer_first(pool, cur_max, num_chosen, seed_modality) + return self._find_best_candidate_prefer_closest(pool, cur_max, num_chosen, seed_modality) + + def _find_best_candidate_prefer_first(self, pool, cur_max, num_chosen, seed_modality): + best_idx = None + best_new_tokens = None + for idx, cand in enumerate(pool): + if self._get_modality(cand) != seed_modality: + 
continue + L = self.sample_size(cand) + new_max = max(cur_max, L) + new_tokens = self._padded_cost(new_max, num_chosen + 1) + if new_tokens <= self._max_tokens(cur_max): + if best_new_tokens is None or new_tokens < best_new_tokens: + best_new_tokens = new_tokens + best_idx = idx + return best_idx + + def _find_best_candidate_prefer_closest(self, pool, cur_max, num_chosen, seed_modality): + best_idx = None + best_new_tokens = None + smallest_length_diff = None + for idx, cand in enumerate(pool): + if self._get_modality(cand) != seed_modality: + continue + L = self.sample_size(cand) + new_max = max(cur_max, L) + new_tokens = self._padded_cost(new_max, num_chosen + 1) + if new_tokens <= self._max_tokens(cur_max): + length_diff = abs(L - cur_max) + if ( + best_new_tokens is None + or new_tokens < best_new_tokens + or (new_tokens == best_new_tokens and length_diff < smallest_length_diff) + ): + best_new_tokens = new_tokens + best_idx = idx + smallest_length_diff = length_diff + return best_idx + + @staticmethod + def _remove_from_pool(pool: deque, idx: int) -> dict: + if idx == 0: + return pool.popleft() + elif idx == len(pool) - 1: + return pool.pop() + else: + pool.rotate(-idx) + item = pool.popleft() + pool.rotate(idx) + return item + + +from collections import deque as _deque + +from cosmos_framework.utils import log + + +class SequentialPackingBatcher(SampleBatcher): + """Order-preserving pull-until-budget packing (port of PackingDataLoader.__iter__). + + Accumulates samples in stream order until `max_sequence_length` (or + `max_samples_per_batch`); a sample that would overflow a non-empty batch is + carried to the next batch (bounded by `lookahead_limit`); a sample that alone + exceeds the budget is discarded with a log. `sample_size` ports the VFM VAE + token formula (needs the tokenizer compression factors + patch size + optional + sound params). 
+ """ + + def __init__( + self, + max_sequence_length: Optional[int] = None, + tokenizer_spatial_compression_factor: int = 16, + tokenizer_temporal_compression_factor: int = 4, + patch_spatial: int = 2, + max_samples_per_batch=None, + lookahead_limit: int = 10, + sound_latent_fps: float = 0, + audio_sample_rate: int = 48000, + ): + self.max_sequence_length = max_sequence_length + self.tokenizer_spatial_compression_factor = tokenizer_spatial_compression_factor + self.tokenizer_temporal_compression_factor = tokenizer_temporal_compression_factor + self.patch_spatial = patch_spatial + self.max_samples_per_batch = max_samples_per_batch + self.lookahead_limit = lookahead_limit + self.sound_latent_fps = sound_latent_fps + self.audio_sample_rate = audio_sample_rate + assert (self.max_sequence_length is None) != (self.max_samples_per_batch is None), ( + "Exactly one of max_sequence_length or max_samples_per_batch must be set " + "(token-budget mode vs count-only mode), matching legacy PackingDataLoader." + ) + + def sample_size(self, sample: dict) -> int: + # PORT of _compute_num_tokens_per_sample (joint_dataloader.py:325-400), + # operating on a SINGLE sample. 
+ # + # In the original batched method: + # - text_token_ids is a list of tensors → num_text_tokens = text_token_ids[0].shape[0] + # - text_token_ids is a 2-D tensor [B,S] → num_text_tokens = text_token_ids.shape[1] + # For a single sample: + # - 1-D tensor [S] → shape[0] (torch.arange(N) case from tests) + # - list of tensors → text_token_ids[0].shape[0] + # - list of ints → len(text_token_ids) + # - 2-D tensor [1,S] → shape[1] (mirrors original .shape[1] branch) + import torch as _torch + text_token_ids = sample["text_token_ids"] + if isinstance(text_token_ids, list): + if len(text_token_ids) > 0 and isinstance(text_token_ids[0], _torch.Tensor): + num_text_tokens = text_token_ids[0].shape[0] + else: + num_text_tokens = len(text_token_ids) + else: + # tensor: 1-D → shape[0], 2-D → shape[1] + if text_token_ids.ndim == 1: + num_text_tokens = text_token_ids.shape[0] + else: + num_text_tokens = text_token_ids.shape[1] + + num_tokens = num_text_tokens + 1 + + # Vision part — single sample has "images" or "video" as a tensor, + # not a list. Wrap in [media] to mirror the original's iteration loop. + is_image_batch = "images" in sample + input_images_or_videos = sample["images" if is_image_batch else "video"] + + for media in input_images_or_videos if isinstance(input_images_or_videos, list) else [input_images_or_videos]: + if is_image_batch: + _, H, W = media.shape + T = 1 + else: + _, T, H, W = media.shape + + vae_spatial_downsample = self.tokenizer_spatial_compression_factor * self.patch_spatial + vae_temporal_downsample = self.tokenizer_temporal_compression_factor + + latent_h_shape = H // vae_spatial_downsample + latent_w_shape = W // vae_spatial_downsample + latent_t_shape = 1 + (T - 1) // vae_temporal_downsample + + num_vision_tokens = latent_h_shape * latent_w_shape * latent_t_shape + 2 + num_tokens += num_vision_tokens + + # Action part — single sample: action is a tensor [T_action, D] or None, + # not wrapped in a list. 
Mirror the original: iterate as list for uniform handling. + if "action" in sample: + action = sample["action"] + action_list = action if isinstance(action, list) else [action] + for act in action_list: + if act is None: + continue + num_tokens += act.shape[0] + + # Sound part — estimate sound tokens from audio waveform length + if self.sound_latent_fps > 0 and "sound" in sample: + sound_data = sample["sound"] + if isinstance(sound_data, list) and len(sound_data) > 0: + first_sound = sound_data[0] + if isinstance(first_sound, list): + first_sound = first_sound[0] + if first_sound is not None and isinstance(first_sound, _torch.Tensor): + num_audio_samples = first_sound.shape[-1] + audio_duration = num_audio_samples / self.audio_sample_rate + num_sound_tokens = int(audio_duration * self.sound_latent_fps) + num_tokens += num_sound_tokens + elif isinstance(sound_data, _torch.Tensor): + num_audio_samples = sound_data.shape[-1] + audio_duration = num_audio_samples / self.audio_sample_rate + num_sound_tokens = int(audio_duration * self.sound_latent_fps) + num_tokens += num_sound_tokens + + return num_tokens + + def batches(self, samples): + src = iter(samples) + carry = _deque() + exhausted = False + while True: + current_len = 0 + num_samples = 0 + group = [] + skipped = _deque() + lookahead = 0 + + def _next(): + if carry: + return carry.popleft() + return next(src) + + while True: + if self.max_samples_per_batch is not None and num_samples >= self.max_samples_per_batch: + break + if group and lookahead >= self.lookahead_limit: + break + try: + s = _next() + except StopIteration: + exhausted = True + break + n = self.sample_size(s) + if self.max_sequence_length is not None and current_len + n >= self.max_sequence_length: + if not group: + log.error( + f"SequentialPackingBatcher: discarding oversized sample with {n} " + f"tokens (max_sequence_length={self.max_sequence_length})", + rank0_only=False, + ) + continue + skipped.append(s) + lookahead += 1 + continue + 
current_len += n + num_samples += 1 + group.append(s) + for s in reversed(skipped): + carry.appendleft(s) + if group: + yield group + if exhausted and not carry: + return diff --git a/cosmos_framework/data/vfm/dataflow/collators.py b/cosmos_framework/data/vfm/dataflow/collators.py new file mode 100644 index 0000000..81b3782 --- /dev/null +++ b/cosmos_framework/data/vfm/dataflow/collators.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Built-in BatchCollator implementations.""" + +from __future__ import annotations + +import torch +import torch.utils.data +from torch.utils.data.dataloader import default_collate + +from cosmos_framework.data.vfm.dataflow.base import BatchCollator + + +class DefaultBatchCollator(BatchCollator): + """Stacks samples with torch's default_collate — stock DataLoader behavior.""" + + def collate(self, samples: list[dict]) -> dict: + return torch.utils.data.default_collate(samples) + + +# --------------------------------------------------------------------------- +# VFMListCollator — reproduces the legacy PackingDataLoader packed-batch structure: +# for each raw sample, inner-collate at batch_size=1 (custom_collate_fn copy), +# split per _get_next_sample i=0 rules, then _update_output_batch-accumulate. +# +# This produces list[list[Tensor]] for _MULTI_ITEM_KEYS (not flat list[Tensor]). 
+# --------------------------------------------------------------------------- + +_TIMING_KEYS = {"_sample_time", "_aug_time", "_pre_aug_time", "_aug_step_times"} +_BATCH_TIMING_KEYS = { + "_worker_batch_time", + "_worker_aug_time", + "_worker_io_time", + "_worker_aug_step_times", + "_worker_id", +} + +# Verbatim copy of JointDataLoader._MULTI_ITEM_KEYS +_MULTI_ITEM_KEYS = {"text_token_ids", "images", "video", "action", "sound"} + +# Verbatim copy of JointDataLoader._FLATTEN_LIST_KEYS +_FLATTEN_LIST_KEYS = {"image_size"} + + +def _vfm_inner_collate(batch): + """ + Verbatim copy of custom_collate_fn from joint_dataloader.py. + + Collate function that works like default_collate for all keys other than "text_token_ids", "images", and "video". + For "text_token_ids", "images", and "video" it simply returns them in a list, instead of stacking them as a tensor. + """ + list_collate_keys = { + "text_token_ids", + "images", + "video", + "action", + "domain_id", + "sequence_plan", + "sound", + "raw_action_dim", + "image_size", + } + + # Data keys where a per-sample value of ``None`` is a meaningful signal + # (e.g. audio extraction failed for that sample → ``sound=None`` paired + # with ``plan.has_sound=False``). These keys must be kept as a list with + # ``None`` placeholders so the model can align per-sample data 1:1 with + # per-sample plans. Dropping the entire key on any None would leave the + # remaining sound tensors mis-aligned with the plans whose ``has_sound`` + # flag was set BEFORE collation, causing ``sequence_packing`` to index + # past the end of ``x0_tokens_sound``. + sparse_data_keys = {"sound"} + + # Handle the case where the batch is already a dictionary (e.g. 
column-wise batching) + if isinstance(batch, dict): + return {key: (value if key in list_collate_keys else default_collate(value)) for key, value in batch.items()} + + # Handle standard list of samples + elem = batch[0] + if isinstance(elem, dict): + + # Some Action datasets add optional metadata keys (for example + # ``additional_view_description`` for concat-view captions) only for a + # subset of samples. PyTorch can batch such samples together when + # DataLoader batch_size > 1; collating only elem's keys and indexing + # every sample by that key turns the optional field into a fatal + # KeyError. Use the union of keys and skip optional keys that are not + # present in every sample. Required training keys still fail loudly via + # downstream assertions if actually missing. + result = {} + keys = set().union(*(d.keys() for d in batch)) + for key in keys: + if key in _TIMING_KEYS: + continue + values = [d.get(key) for d in batch] + if any(value is None for value in values): + # Sparse data keys keep their None placeholders to preserve + # 1:1 alignment with sequence_plan. Other (optional metadata) + # keys not present in every sample are dropped. 
+ if key in sparse_data_keys: + result[key] = values + continue + if key in list_collate_keys: + result[key] = values + else: + result[key] = default_collate(values) + result.update(_aggregate_worker_timing(batch)) + return result + else: + return default_collate(batch) + + +def _aggregate_worker_timing(samples: list[dict]) -> dict: + """Extract per-sample timing keys, aggregate into per-batch scalars.""" + info: dict[str, float | int] = {} + if "_sample_time" in samples[0]: + info["_worker_batch_time"] = sum(s.get("_sample_time", 0.0) for s in samples) + if "_aug_time" in samples[0]: + aug_total = sum(s.get("_aug_time", 0.0) for s in samples) + info["_worker_aug_time"] = aug_total + if "_worker_batch_time" in info: + info["_worker_io_time"] = info["_worker_batch_time"] - aug_total + if "_aug_step_times" in samples[0]: + agg: dict[str, float] = {} + for s in samples: + for step_name, t in s.get("_aug_step_times", {}).items(): + agg[step_name] = agg.get(step_name, 0.0) + t + info["_worker_aug_step_times"] = agg + worker_info = torch.utils.data.get_worker_info() + info["_worker_id"] = worker_info.id if worker_info is not None else 0 + return info + + +def _split_one(batch: dict) -> dict: + """Port of _get_next_sample split rules for i=0 (verbatim from joint_dataloader.py lines 470-490). + + Splitting rules: + - _BATCH_TIMING_KEYS: passed through as-is. + - _MULTI_ITEM_KEYS with list value: elem = v[0]; if elem is a list → sample[k]=elem, + else → sample[k]=v[0:1] (single-element list wrapping the tensor). + - Other list values: sample[k] = v[0] (bare element, direct-indexed). + - Non-list (tensor/scalar) values: sample[k] = v[0:1] (preserve batch dim). 
def _accumulate(output_batch: dict, output: dict) -> None:
    """Fold one split sample into the running packed batch, in place.

    Port of ``_update_output_batch`` from joint_dataloader.py lines 405-418.

    Merge rules, applied per key of ``output``:
      - key in ``_BATCH_TIMING_KEYS``: the first value seen is kept; later
        values for the same key are ignored.
      - key in ``_FLATTEN_LIST_KEYS`` and the value is a list: the elements
        are concatenated into one flat list.
      - anything else: the value is appended to a per-key list.

    Args:
        output_batch: accumulator dict; mutated in place.
        output: one sample's key/value pairs; left untouched.
    """
    for key, value in output.items():
        if key in _BATCH_TIMING_KEYS:
            output_batch.setdefault(key, value)
        elif key in _FLATTEN_LIST_KEYS and isinstance(value, list):
            # Always extend a list owned by the accumulator. Storing the first
            # ``value`` by reference (as the port originally did) meant later
            # extend() calls grew the caller's own list object.
            output_batch.setdefault(key, []).extend(value)
        else:
            output_batch.setdefault(key, []).append(value)
Byte-identical to the legacy packer. + """ + + def collate(self, samples: list[dict]) -> dict: + # Reproduce the legacy PackingDataLoader packed batch: for each sample, + # inner-collate at batch_size=1, split sample 0 per _MULTI_ITEM_KEYS / + # list / tensor rules, then _update_output_batch-accumulate across the group. + output_batch: dict = {} + for s in samples: + collated = _vfm_inner_collate([s]) # verbatim custom_collate_fn copy + split = _split_one(collated) # i=0 split (rules from _get_next_sample) + _accumulate(output_batch, split) # _update_output_batch copy + return output_batch diff --git a/cosmos_framework/data/vfm/dataflow/distributors.py b/cosmos_framework/data/vfm/dataflow/distributors.py new file mode 100644 index 0000000..e9fd6ad --- /dev/null +++ b/cosmos_framework/data/vfm/dataflow/distributors.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Built-in DataDistributor implementations. + +IterableDistributor wraps any Python iterable / IterableDataset with +round-robin DP x worker sharding (no resume). MapDistributor wraps a map-style +Dataset with per-epoch shuffle + slice sharding (resume lands in a later plan). +""" + +from __future__ import annotations + +from typing import Any, Iterator + +from cosmos_framework.data.vfm.dataflow.base import DataDistributor + + +class IterableDistributor(DataDistributor): + """Round-robin shard of an iterable: each (rank, worker) sees every + ``dp_world_size * num_workers``-th item starting at + ``dp_rank * num_workers + worker_id``. 
class MapDistributor(DataDistributor):
    """Shard a map-style Dataset across (DP rank, worker) streams, epoch after epoch.

    Each epoch builds an index order (a seeded ``torch.randperm`` when
    ``shuffle`` is true, otherwise ``0..n-1``) and every stream takes the
    strided slice ``order[stream_id::total_streams]``.

    Resume: on the first call to ``stream`` the starting epoch and the last
    consumed position are read (and removed) from the environment variables
    ``COSMOS_DL_STATE_[<name>_]WORKER_<worker_id>_EPOCH`` and
    ``..._INDEX``. When they are absent, iteration starts at epoch 0,
    position 0.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        seed: int = 0,
        shuffle: bool = True,
        name: str = "",
    ):
        self._dataset = dataset
        self._seed = seed
        self._shuffle = shuffle
        self._name = name

    def __len__(self) -> int:
        return len(self._dataset)  # type: ignore[arg-type]

    def stream(self, dp_rank, dp_world_size, worker_id, num_workers):
        """Yield this stream's items forever, one epoch after another.

        Dict items are yielded as a new dict with ``_dp_epoch`` and
        ``_dp_stream_pos`` placed before the item's own keys; non-dict items
        are yielded unchanged. Yields nothing when the dataset is empty or
        there are more streams than items for this stream to get one.
        """
        import os

        num_streams = dp_world_size * num_workers
        my_stream = dp_rank * num_workers + worker_id
        size = len(self._dataset)  # type: ignore[arg-type]
        if size == 0 or my_stream >= size:
            return

        prefix = "COSMOS_DL_STATE_"
        if self._name:
            prefix = f"COSMOS_DL_STATE_{self._name}_"
        # pop() so the resume point is applied once, not on every later stream().
        first_epoch = int(os.environ.pop(f"{prefix}WORKER_{worker_id}_EPOCH", 0))
        last_pos = int(os.environ.pop(f"{prefix}WORKER_{worker_id}_INDEX", -1))

        epoch = first_epoch
        while True:
            if self._shuffle:
                # Seed depends on the epoch, so every stream derives the same order.
                gen = torch.Generator().manual_seed(self._seed + epoch)
                order = torch.randperm(size, generator=gen).tolist()
            else:
                order = list(range(size))
            mine = order[my_stream::num_streams]
            # Only the resumed epoch skips ahead; later epochs start from 0.
            begin = last_pos + 1 if epoch == first_epoch else 0
            for pos in range(begin, len(mine)):
                item = self._dataset[mine[pos]]
                if isinstance(item, dict):
                    yield {"_dp_epoch": epoch, "_dp_stream_pos": pos, **item}
                else:
                    yield item
            epoch += 1
+ +class RankPartitionedDistributor(DataDistributor): + """Allocate whole DP ranks to datasets by ratio; the chosen dataset self-shards. + Ports RankPartitionedDataLoader (joint_dataloader.py:660-757) minus the inner + torch DataLoader (CosmosDataLoader owns workers/collation).""" + + def __init__(self, datasets: dict): + self._datasets_cfg = datasets + self._cached = None # built dataset for this rank, set on first stream() + + def stream(self, dp_rank, dp_world_size, worker_id, num_workers): + if self._cached is None: + self._cached = self._allocate_and_build(dp_rank, dp_world_size) + yield from iter(self._cached) + + def _allocate_and_build(self, rank, world_size): + names, dataset_configs, ratios = [], [], [] + for name, cfg in self._datasets_cfg.items(): + if cfg["ratio"] <= 0: + continue + names.append(name) + dataset_configs.append(cfg["dataset"]) + ratios.append(cfg["ratio"]) + assert len(names) > 0, "No datasets with positive ratios" + assert world_size >= len(names), f"world_size {world_size} < num datasets {len(names)}" + # PORT the allocation verbatim from joint_dataloader.py:707-744: + total_ratio = sum(ratios) + ideal = [r / total_ratio * world_size for r in ratios] + allocations = [max(1, int(q)) for q in ideal] + remaining = world_size - sum(allocations) + if remaining > 0: + order = sorted(range(len(ratios)), key=lambda i: ideal[i] - allocations[i], reverse=True) + for j in range(remaining): + allocations[order[j]] += 1 + elif remaining < 0: + deficit = -remaining + while deficit > 0: + best = max( + (i for i in range(len(allocations)) if allocations[i] > 1), + key=lambda i: (allocations[i] - ideal[i], allocations[i]), + ) + allocations[best] -= 1 + deficit -= 1 + cumulative = 0 + idx = -1 + for i, a in enumerate(allocations): + if rank < cumulative + a: + idx = i + break + cumulative += a + assert idx >= 0 + shard_rank = rank - cumulative + shard_world_size = allocations[idx] + cfg = dataset_configs[idx] + ds = cfg if isinstance(cfg, 
class MixtureDistributor(DataDistributor):
    """Ratio-weighted merge of multiple distributors into one stream (homogeneous
    join). Generalizes PackingIterableDataset's weighted _get_next_sample.

    Each yielded item comes from one source, picked by a seeded weighted
    draw. A source that runs out is restarted by calling its ``stream``
    again with the same coordinates.
    """

    def __init__(self, sources: dict, seed: int = 0):
        # sources: {name: (DataDistributor, ratio_float)}
        self._names = list(sources.keys())
        self._dists = [sources[n][0] for n in self._names]
        self._ratios = [float(sources[n][1]) for n in self._names]
        self._seed = seed

    def stream(self, dp_rank, dp_world_size, worker_id, num_workers):
        """Yield items forever, drawing the source of each item by ratio.

        Raises:
            RuntimeError: if a source yields nothing even right after being
                restarted (it is empty for this rank/worker).
        """
        # Distinct seed per (rank, worker) so streams do not draw in lockstep.
        rng = _random_mod.Random(self._seed + dp_rank * 100003 + worker_id)
        iters = [d.stream(dp_rank, dp_world_size, worker_id, num_workers) for d in self._dists]
        candidates = range(len(iters))
        exhausted = object()  # sentinel: lets next() report exhaustion without raising
        while True:
            idx = rng.choices(candidates, weights=self._ratios, k=1)[0]
            item = next(iters[idx], exhausted)
            if item is exhausted:
                iters[idx] = self._dists[idx].stream(dp_rank, dp_world_size, worker_id, num_workers)
                item = next(iters[idx], exhausted)
                if item is exhausted:
                    # A StopIteration escaping a generator body is converted to an
                    # opaque "generator raised StopIteration" RuntimeError; raise
                    # the same exception type with a message naming the source.
                    raise RuntimeError(
                        f"MixtureDistributor: source {self._names[idx]!r} yielded no items "
                        f"for dp_rank={dp_rank}, worker_id={worker_id}"
                    )
            yield item
def _make_fixed_samples():
    """Build one SFT-shaped sample dict per entry of ``_SAMPLE_SPECS``.

    Every tensor is fully determined by the spec and its position, so two
    calls return equal content:
      - ``video``: float32 tensor of shape ``(3, T, H, W)`` filled with the
        sample's index.
      - ``text_token_ids``: ``arange(text_len)`` as int64.
      - ``image_size``: int64 tensor ``[H, W]`` (exercises the
        ``_FLATTEN_LIST_KEYS`` path).
    """
    return [
        {
            "video": torch.full((3, frames, height, width), float(index), dtype=torch.float32),
            "text_token_ids": torch.arange(text_len, dtype=torch.long),
            "image_size": torch.tensor([height, width], dtype=torch.long),
        }
        for index, (text_len, frames, height, width) in enumerate(_SAMPLE_SPECS)
    ]
+_BUDGET = 80 + +_PACKER_KWARGS = dict( + tokenizer_spatial_compression_factor=16, + tokenizer_temporal_compression_factor=4, + patch_spatial=2, + sound_latent_fps=0, +) + + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────────── + +def _setup_dist(monkeypatch): + """Init a single-process gloo group; return True if we used gloo, False for monkeypatch. + + Uses monkeypatch.setenv (auto-restored at teardown) so the test does not leave + MASTER_ADDR/MASTER_PORT dirtied in os.environ — the repo conftest enforces this. + """ + monkeypatch.setenv("MASTER_ADDR", "127.0.0.1") + monkeypatch.setenv("MASTER_PORT", "29557") + if not dist.is_initialized(): + try: + dist.init_process_group(backend="gloo", rank=0, world_size=1) + return True + except Exception: + pass + # Fallback: monkeypatch so RankPartitionedDataLoader.__init__ succeeds. + return False + + +def _monkeypatch_dist(monkeypatch): + """Patch the three distributed calls used by RankPartitionedDataLoader.""" + monkeypatch.setattr(dist, "is_initialized", lambda: True) + monkeypatch.setattr(dist, "get_world_size", lambda: 1) + monkeypatch.setattr(dist, "get_rank", lambda: 0) + + +def _drain(loader, n: int) -> list[dict]: + it = iter(loader) + return [next(it) for _ in range(n)] + + +# ────────────────────────────────────────────────────────────────────────────── +# Exact structural comparison helpers +# ────────────────────────────────────────────────────────────────────────────── + +_SKIP_KEYS = {"dataset_name"} # bookkeeping only, not in new stack + + +def _payload_keys(batch: dict) -> set[str]: + """Return non-bookkeeping keys for comparison.""" + return {k for k in batch if not k.startswith("_") and k not in _SKIP_KEYS} + + +def _assert_tensors_equal(a, b, label: str) -> None: + assert isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor), ( + f"{label}: expected two Tensors, got {type(a)} 
and {type(b)}" + ) + assert torch.equal(a, b), f"{label}: tensor mismatch:\n legacy={a}\n new={b}" + + +def _assert_exact(legacy_val, new_val, key: str) -> None: + """Recursively assert exact structural and value equality.""" + if isinstance(legacy_val, list) and isinstance(new_val, list): + assert len(legacy_val) == len(new_val), ( + f"key={key}: list length mismatch: legacy={len(legacy_val)}, new={len(new_val)}" + ) + for i, (a, b) in enumerate(zip(legacy_val, new_val)): + _assert_exact(a, b, f"{key}[{i}]") + elif isinstance(legacy_val, torch.Tensor) and isinstance(new_val, torch.Tensor): + _assert_tensors_equal(legacy_val, new_val, f"key={key}") + else: + assert type(legacy_val) == type(new_val), ( + f"key={key}: type mismatch: legacy={type(legacy_val)}, new={type(new_val)}" + ) + assert legacy_val == new_val, ( + f"key={key}: value mismatch: legacy={legacy_val!r}, new={new_val!r}" + ) + + +def _compare_batches_exact(legacy: dict, new: dict, batch_idx: int) -> None: + """Assert EXACT structural identity (including list[list[Tensor]] nesting).""" + lk = _payload_keys(legacy) + nk = _payload_keys(new) + assert lk == nk, ( + f"Batch {batch_idx}: key mismatch: legacy={sorted(lk)}, new={sorted(nk)}" + ) + for key in sorted(lk): + _assert_exact(legacy[key], new[key], f"batch[{batch_idx}][{key}]") + + +# ────────────────────────────────────────────────────────────────────────────── +# The golden-batch EXACT equality test +# ────────────────────────────────────────────────────────────────────────────── + +N_BATCHES = 5 + + +def test_vfm_golden_batches_match(monkeypatch): + """New four-role stack yields STRUCTURALLY IDENTICAL packed batches as legacy PackingDataLoader. + + Asserts exact list[list[Tensor]] nesting for _MULTI_ITEM_KEYS (video, text_token_ids) + and flat list[Tensor] for image_size — no nesting normalization. 
+ """ + from cosmos_framework.data.vfm.joint_dataloader import ( + PackingDataLoader, + RankPartitionedDataLoader, + ) + from cosmos_framework.data.vfm.dataflow import ( + CosmosDataLoader, + RankPartitionedDistributor, + SequentialPackingBatcher, + VFMListCollator, + IdentityProcessor, + ) + + # ── distributed bootstrap ────────────────────────────────────────────── + used_gloo = _setup_dist(monkeypatch) + if not used_gloo: + _monkeypatch_dist(monkeypatch) + + try: + # ── legacy stack ────────────────────────────────────────────────── + stub_legacy = _FixedSFTDataset() + legacy = PackingDataLoader( + dataloader=RankPartitionedDataLoader( + datasets={"video": {"dataset": stub_legacy, "ratio": 1}}, + batch_size=1, + ), + max_sequence_length=_BUDGET, + max_samples_per_batch=None, + **_PACKER_KWARGS, + ) + + # ── new stack ───────────────────────────────────────────────────── + stub_new = _FixedSFTDataset() + new = CosmosDataLoader( + distributor=RankPartitionedDistributor( + {"video": {"dataset": stub_new, "ratio": 1}} + ), + processor=IdentityProcessor(), + batcher=SequentialPackingBatcher( + max_sequence_length=_BUDGET, + max_samples_per_batch=None, + audio_sample_rate=48000, + **_PACKER_KWARGS, + ), + collator=VFMListCollator(), + num_workers=0, + ) + + # ── drain N batches and compare ─────────────────────────────────── + legacy_batches = _drain(legacy, N_BATCHES) + new_batches = _drain(new, N_BATCHES) + + assert len(legacy_batches) == N_BATCHES, f"Expected {N_BATCHES} legacy batches" + assert len(new_batches) == N_BATCHES, f"Expected {N_BATCHES} new batches" + + for i, (lb, nb) in enumerate(zip(legacy_batches, new_batches)): + _compare_batches_exact(lb, nb, batch_idx=i) + + finally: + if used_gloo and dist.is_initialized(): + dist.destroy_process_group() diff --git a/cosmos_framework/data/vfm/dataflow/loader.py b/cosmos_framework/data/vfm/dataflow/loader.py new file mode 100644 index 0000000..3a2a6a4 --- /dev/null +++ 
class _DataflowIterableDataset(torch.utils.data.IterableDataset):
    """Wires distributor -> processor -> batcher -> collator inside a worker.

    Iterating this dataset yields already-collated batches: raw items come
    from ``distributor.stream(...)``, each is passed through
    ``processor.process``, the processed samples are grouped by
    ``batcher.batches`` and each group is turned into a batch by
    ``collator.collate``.
    """

    def __init__(
        self,
        distributor: DataDistributor,
        processor: RawItemProcessor,
        batcher: SampleBatcher,
        collator: BatchCollator,
        dp_rank: int,
        dp_world_size: int,
    ):
        """Store the four roles and the data-parallel coordinates used for sharding."""
        super().__init__()
        self._distributor = distributor
        self._processor = processor
        self._batcher = batcher
        self._collator = collator
        self._dp_rank = dp_rank
        self._dp_world_size = dp_world_size

    def __iter__(self):
        """Yield collated batches for the current worker.

        When the samples of a group carry ``_dp_epoch`` / ``_dp_stream_pos``
        metadata, those keys are stripped before collation and the batch is
        stamped with ``sample_worker_id``, ``sample_epoch`` and
        ``sample_index`` tensors (one entry per sample in the group, all
        holding the worker id, the group's max epoch and max position).

        Raises:
            ValueError: if a multi-sample group with metadata spans epochs or
                has non-contiguous stream positions.
        """
        # Outside a DataLoader worker get_worker_info() is None: act as worker 0 of 1.
        info = torch.utils.data.get_worker_info()
        worker_id, num_workers = (info.id, info.num_workers) if info else (0, 1)
        raw = self._distributor.stream(self._dp_rank, self._dp_world_size, worker_id, num_workers)

        def _processed():
            """Run each raw item through the processor, carrying ``_dp_*`` keys across."""
            for item in raw:
                if isinstance(item, dict):
                    # Pop the bookkeeping keys off the item (mutating it) so the
                    # processor never sees them.
                    meta = {k: item.pop(k) for k in list(item) if k.startswith("_dp_")}
                else:
                    meta = {}
                s = self._processor.process(item)
                # Re-attach the bookkeeping only when the processed sample is a dict.
                if meta and isinstance(s, dict):
                    s.update(meta)
                yield s

        for group in self._batcher.batches(_processed()):
            # Presence of metadata is decided from the first sample of the group only.
            has_meta = bool(group) and isinstance(group[0], dict) and "_dp_epoch" in group[0]
            if has_meta:
                epochs = [s["_dp_epoch"] for s in group]
                positions = [s["_dp_stream_pos"] for s in group]
                max_epoch = max(epochs)
                max_pos = max(positions)
                # Resume records (max_epoch, max_pos) and fast-forwards to max_pos+1 —
                # bit-for-bit with the legacy collate_batch. That is gap-free only when
                # this batch is a single sample (max_batch_size=1, all live recipes) or a
                # single-epoch contiguous run (sequential packing). A reordering batcher
                # (pool packing) at batch_size>1, or a batch spanning an epoch boundary,
                # would leave buffered lower positions unrecorded and skip them on resume.
                # Fail loudly rather than silently drop samples in that unsupported combo.
                if len(group) > 1:
                    contiguous = min(epochs) == max_epoch and sorted(positions) == list(
                        range(min(positions), max_pos + 1)
                    )
                    if not contiguous:
                        raise ValueError(
                            "Map-style resume cannot safely stamp a multi-sample batch whose "
                            "_dp_stream_pos values are non-contiguous or span epochs (reordering "
                            "batcher + batch_size>1). Use max_batch_size=1 with pool packing, a "
                            "sequential (order-preserving) batcher, or an iterable (non-resumable) "
                            "source."
                        )
                # Collate copies of the samples without the bookkeeping keys.
                clean = [{k: v for k, v in s.items() if not k.startswith("_dp_")} for s in group]
                batch = self._collator.collate(clean)
                # NOTE(review): item assignment assumes the collator returns a dict-like batch.
                batch["sample_worker_id"] = torch.tensor([worker_id] * len(group))
                batch["sample_epoch"] = torch.tensor([max_epoch] * len(group))
                batch["sample_index"] = torch.tensor([max_pos] * len(group))
            else:
                batch = self._collator.collate(group)
            yield batch
+ + DP coordinates: ``parallel_dims.dp_coord`` > ``torch.distributed`` > (0, 1). + """ + + def __init__( + self, + distributor: DataDistributor, + processor: RawItemProcessor, + batcher: SampleBatcher | None = None, + collator: BatchCollator | None = None, + batch_size: int | None = None, + num_workers: int = 0, + prefetch_factor: int | None = None, + persistent_workers: bool = False, + pin_memory: bool = False, + parallel_dims=None, + ): + if batch_size is not None and batcher is not None: + raise ValueError( + "Pass either batch_size= (sugar) or an explicit batcher=, not both." + ) + if batch_size is None and batcher is None: + raise ValueError("Provide either a batcher= or a batch_size=.") + if batch_size is not None: + batcher = SimpleBatcher(batch_size=batch_size) + if collator is None: + collator = DefaultBatchCollator() + + if parallel_dims is not None: + dp_rank, dp_world_size = parallel_dims.dp_coord + elif torch.distributed.is_initialized(): + dp_rank = torch.distributed.get_rank() + dp_world_size = torch.distributed.get_world_size() + if dp_world_size > 1: + log.info( + "CosmosDataLoader: using global rank for DP sharding. 
class JointCosmosDataLoader:
    """Wraps multiple ``CosmosDataLoader`` instances with ratio-based seeded selection.

    One output batch = one inner loader, selected deterministically by ratio at each
    step. Adds a ``"dataset_name"`` key to every yielded batch so downstream callbacks
    can route state updates to the correct inner loader.

    Parameters
    ----------
    dataloaders:
        ``{name: {"dataloader": CosmosDataLoader, "ratio": number}}`` mapping.
        Entries with ``ratio <= 0`` (or no ``"ratio"`` key) are silently skipped.
    seed:
        Base seed for per-step dataset selection. Step ``i`` uses
        ``np.random.RandomState(seed + i)`` — fully reproducible on resume via
        ``set_start_iteration``.

    Raises
    ------
    ValueError
        If no entry has a positive ratio, or an entry is named ``"global_id"``.
    """

    def __init__(
        self,
        dataloaders: dict,
        seed: int = 42,
    ) -> None:
        entries = [
            (name, cfg["dataloader"], cfg["ratio"])
            for name, cfg in dataloaders.items()
            if cfg.get("ratio", 0) > 0
        ]
        if not entries:
            raise ValueError("JointCosmosDataLoader: no dataloaders with ratio > 0")

        self._names: list[str] = [e[0] for e in entries]
        if "global_id" in self._names:
            raise ValueError(
                "JointCosmosDataLoader: dataset name 'global_id' is reserved "
                "by the checkpoint state format; use a different name."
            )
        self._loaders: list[CosmosDataLoader] = [e[1] for e in entries]
        ratios = np.array([e[2] for e in entries], dtype=float)
        # Normalised selection probabilities, aligned with self._names.
        self._probs: np.ndarray = ratios / ratios.sum()
        self._seed = seed
        self._global_id = 0
        # Iterators are created lazily on the first __iter__ call so that
        # DataLoaderStateCallback.load_state_dict can install resume env vars
        # before workers are spawned (for num_workers > 0, iter(DataLoader)
        # forks workers immediately; env vars must be set in the parent first).
        self._iterators: list | None = None

        total = ratios.sum()
        lines = [f"JointCosmosDataLoader: {len(self._names)} streams"]
        for name, ratio in zip(self._names, ratios):
            lines.append(f" {name}: ratio={ratio:.4g} ({ratio / total:.1%})")
        log.info("\n".join(lines))

    def set_start_iteration(self, iteration: int) -> None:
        """Restore deterministic selection sequence after checkpoint resume.

        Called by ``JointCosmosDataLoaderStateCallback.load_state_dict`` and by the
        trainer (if present) via ``hasattr`` guard.
        """
        self._global_id = iteration

    def __iter__(self):
        """Yield batches forever, choosing the inner loader for each step by ratio.

        The inner iterators are created once and kept on the instance, so a
        second ``iter()`` call continues from where the first one stopped.

        Raises:
            RuntimeError: if the selected inner loader yields nothing even
                right after being restarted.
        """
        # Lazy init: create iterators here (not in __init__) so that
        # load_state_dict can set resume env vars before workers fork.
        if self._iterators is None:
            self._iterators = [iter(loader) for loader in self._loaders]
        exhausted = object()  # sentinel: lets next() report exhaustion without raising
        while True:
            # A fresh RandomState per step makes the choice a pure function of
            # (seed, step), which is what set_start_iteration relies on.
            rng = np.random.RandomState(self._seed + self._global_id)
            idx = int(rng.choice(len(self._loaders), p=self._probs))
            batch = next(self._iterators[idx], exhausted)
            if batch is exhausted:
                # Inner CosmosDataLoaders are infinite; this guard handles
                # the unlikely case of a finite IterableDataset inner source.
                self._iterators[idx] = iter(self._loaders[idx])
                batch = next(self._iterators[idx], exhausted)
                if batch is exhausted:
                    # A StopIteration escaping a generator body is converted to an
                    # opaque "generator raised StopIteration" RuntimeError; raise
                    # the same exception type with a message naming the loader.
                    raise RuntimeError(
                        f"JointCosmosDataLoader: dataloader {self._names[idx]!r} "
                        "yielded no batches after being restarted"
                    )
            batch["dataset_name"] = self._names[idx]
            self._global_id += 1
            yield batch
Single process, num_workers=0.""" + +from __future__ import annotations + +import torch + +from cosmos_framework.callbacks.cosmos_dataloader_state import CosmosDataLoaderStateCallback +from cosmos_framework.data.vfm.dataflow import ( + CosmosDataLoader, + IdentityProcessor, + MapDistributor, +) + + +class _IdDS(torch.utils.data.Dataset): + def __init__(self, n): + self._n = n + + def __len__(self): + return self._n + + def __getitem__(self, idx): + return {"id": torch.tensor(idx)} + + +def _build(seed=0): + return CosmosDataLoader( + distributor=MapDistributor(_IdDS(20), shuffle=False, seed=seed), + processor=IdentityProcessor(), + batch_size=1, + num_workers=0, + ) + + +def test_resume_continues_without_dup_or_skip(): + cb = CosmosDataLoaderStateCallback() + loader = _build() + it = iter(loader) + seen_ids = [] + for _ in range(5): + b = next(it) + cb._update_state_from_batch(b) + seen_ids.append(b["id"].item()) + assert seen_ids == [0, 1, 2, 3, 4] + + state = cb.state_dict() + assert state[0]["index"] == 4 + cb2 = CosmosDataLoaderStateCallback() + cb2.load_state_dict(state) + + loader2 = _build() + it2 = iter(loader2) # one iterator: env-var fast-forward happens once, then continues + resumed = [next(it2)["id"].item() for _ in range(3)] + assert resumed == [5, 6, 7] diff --git a/cosmos_framework/data/vfm/packing_iterable_dataset.py b/cosmos_framework/data/vfm/packing_iterable_dataset.py deleted file mode 100644 index 847b85a..0000000 --- a/cosmos_framework/data/vfm/packing_iterable_dataset.py +++ /dev/null @@ -1,276 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -Abstract base class for pool-based token-budget bin-packing over multiple datasets. - -Extracted from ``projects.cosmos3.vfm.datasets.vlm.joint_dataset_dynamic_batch_webloader`` -so that both the VLM and VFM internal dataloaders can share a single packing implementation. 
- -Usage ------ -Subclass and implement ``compute_sample_tokens(sample) -> int``. -Optionally override ``collate_batch(samples) -> Any`` for custom collation. -""" - -from __future__ import annotations - -import random -from abc import ABC, abstractmethod -from collections import deque -from enum import Enum -from typing import Any, Union - -import torch - -from cosmos_framework.utils.lazy_config import instantiate -from cosmos_framework.utils import log - - -class Modality(Enum): - IMAGE = "image" - VIDEO = "video" - TEXT = "text" - - -class PackingIterableDataset(torch.utils.data.IterableDataset, ABC): - """Pool-based greedy bin-packing IterableDataset. - - Maintains a pool of ``pool_size`` samples and assembles batches by - greedily selecting candidates that fit within the token budget - ``max_tokens``. Subclasses supply two hooks: - - * ``compute_sample_tokens(sample)`` — token cost of one sample (abstract). - * ``collate_batch(samples)`` — assemble a packed list into a batch - (default: identity, returns the list unchanged). - - Parameters - ---------- - datasets_cfg: - Mapping ``{name: {"dataset": , "ratio": }}``. - The *dataset* value may be a Hydra lazy config, an already-constructed - ``IterableDataset``, or a plain ``DataLoader`` (its ``.dataset`` is - unwrapped automatically). - max_tokens: - Token budget per batch (padded cost = ``cur_max_len * batch_size``). - pool_size: - Number of samples to buffer before selecting a batch. - max_batch_size: - Hard cap on items per batch (0 or None = no cap). - long_threshold: - Samples with token count ``>= long_threshold`` are emitted as - singletons regardless of budget. - batching_strategy: - ``"prefer_closest"`` (default) or ``"prefer_first"``. - apply_long_sample_halving: - When ``True`` (default), ``_max_tokens`` halves the budget for any - batch whose largest sample exceeds 1000 tokens — a memory-safety - heuristic. 
Set ``False`` only when memory headroom at the literal - ``max_tokens`` budget has been validated for the recipe. - """ - - def __init__( - self, - datasets_cfg: dict[str, dict[str, Union[int, object]]], - max_tokens: int, - pool_size: int, - max_batch_size: int, - long_threshold: int, - batching_strategy: str, - apply_long_sample_halving: bool = True, - ): - super().__init__() - - assert batching_strategy in ("prefer_first", "prefer_closest"), ( - f"batching_strategy must be 'prefer_first' or 'prefer_closest', got {batching_strategy!r}" - ) - - self.max_tokens = max_tokens - self.pool_size = pool_size - self.long_threshold = long_threshold - self.max_batch_size = max_batch_size - self.batching_strategy = batching_strategy - self.apply_long_sample_halving = apply_long_sample_halving - - self._pool: deque[dict] = deque() - self._dataset_names: list[str] = [] - self._ratios: list[float] = [] - self._datasets: list[torch.utils.data.IterableDataset] = [] - - for name, cfg in datasets_cfg.items(): - assert {"ratio", "dataset"} <= cfg.keys(), ( - f"Each entry must have 'dataset' and 'ratio' keys: {name} -> {cfg.keys()}" - ) - ratio = cfg["ratio"] - if ratio == 0: - log.info(f"Skipping dataset {name} with ratio {ratio}") - continue - dataset_cfg = cfg["dataset"] - - ds = ( - instantiate(dataset_cfg) - if not isinstance(dataset_cfg, (torch.utils.data.IterableDataset, torch.utils.data.DataLoader)) - else dataset_cfg - ) - if isinstance(ds, torch.utils.data.DataLoader): - ds = ds.dataset - if hasattr(ds, "build_dataset") and callable(getattr(ds, "build_dataset")): - ds = ds.build_dataset() - - assert isinstance(ds, torch.utils.data.IterableDataset), ( - f"Expected an IterableDataset, got {type(ds)} for {name}" - ) - - self._dataset_names.append(name) - self._ratios.append(float(ratio)) - self._datasets.append(ds) - log.info(f"Added dataset {name} with ratio {ratio}") - - log.info(f"added data: {list(datasets_cfg.keys())}") - assert len(self._datasets) > 0, "No datasets 
added" - self._data_len: int = sum(int(getattr(ds, "total_images", 0)) for ds in self._datasets) - if self._data_len == 0: - self._data_len = 10**12 - self.iterators = [iter(ds) for ds in self._datasets] - - # ------------------------------------------------------------------ - # Abstract / overridable hooks - # ------------------------------------------------------------------ - - @abstractmethod - def compute_sample_tokens(self, sample: dict) -> int: - """Return the token cost of one sample for packing budget accounting.""" - - def collate_batch(self, samples: list[dict]) -> Any: - """Assemble a packed list of samples into one batch. - - Default implementation returns the list unchanged (identity). - Override to pad, stack, or transform samples into tensors. - """ - return samples - - # ------------------------------------------------------------------ - # PyTorch Dataset API - # ------------------------------------------------------------------ - - def __len__(self) -> int: - return self._data_len - - def __iter__(self): - while True: - batch = self._best_fit_batch() - yield self.collate_batch(batch) - - # ------------------------------------------------------------------ - # Internal packing helpers (moved verbatim from _JointIterableDataset) - # ------------------------------------------------------------------ - - def _max_tokens(self, cur_max: int) -> int: - if not self.apply_long_sample_halving: - return self.max_tokens - if cur_max < 1000: - return self.max_tokens - return self.max_tokens // 2 - - def _get_next_sample(self) -> dict: - index_id = random.choices(range(len(self.iterators)), weights=self._ratios, k=1)[0] - curr_dataset = self.iterators[index_id] - try: - output = next(curr_dataset) - except StopIteration: - log.critical(f"dataset {self._dataset_names[index_id]} exhausted") - self.iterators[index_id] = iter(self._datasets[index_id]) - output = next(self.iterators[index_id]) - return output - - def _fill_pool(self): - while len(self._pool) < 
self.pool_size: - self._pool.append(self._get_next_sample()) - - def _padded_cost(self, cur_max: int, k: int) -> int: - return cur_max * k - - def _get_modality(self, sample: dict) -> Modality: - if "pixel_values" in sample: - return Modality.IMAGE - elif "pixel_values_videos" in sample: - return Modality.VIDEO - return Modality.TEXT - - def _best_fit_batch(self) -> list[dict]: - """Build one batch using the configured token-budget strategy.""" - self._fill_pool() - seed = self._pool.popleft() - seed_modality = self._get_modality(seed) - L0 = self.compute_sample_tokens(seed) - - if L0 >= self.long_threshold or L0 >= self._max_tokens(L0): - return [seed] - - chosen = [seed] - cur_max = L0 - - while self._pool: - if self.max_batch_size and len(chosen) >= self.max_batch_size: - break - best_idx = self._find_best_candidate(cur_max, len(chosen), seed_modality) - if best_idx is None: - break - cand = self._remove_from_pool(best_idx) - chosen.append(cand) - cur_max = max(cur_max, self.compute_sample_tokens(cand)) - - return chosen - - def _find_best_candidate(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None: - if self.batching_strategy == "prefer_first": - return self._find_best_candidate_prefer_first(cur_max, num_chosen, seed_modality) - return self._find_best_candidate_prefer_closest(cur_max, num_chosen, seed_modality) - - def _find_best_candidate_prefer_first(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None: - best_idx = None - best_new_tokens = None - for idx, cand in enumerate(self._pool): - if self._get_modality(cand) != seed_modality: - continue - L = self.compute_sample_tokens(cand) - new_max = max(cur_max, L) - new_tokens = self._padded_cost(new_max, num_chosen + 1) - if new_tokens <= self._max_tokens(cur_max): - if best_new_tokens is None or new_tokens < best_new_tokens: - best_new_tokens = new_tokens - best_idx = idx - return best_idx - - def _find_best_candidate_prefer_closest(self, cur_max: int, 
num_chosen: int, seed_modality: Modality) -> int | None: - best_idx = None - best_new_tokens = None - smallest_length_diff = None - for idx, cand in enumerate(self._pool): - if self._get_modality(cand) != seed_modality: - continue - L = self.compute_sample_tokens(cand) - new_max = max(cur_max, L) - new_tokens = self._padded_cost(new_max, num_chosen + 1) - if new_tokens <= self._max_tokens(cur_max): - length_diff = abs(L - cur_max) - if ( - best_new_tokens is None - or new_tokens < best_new_tokens - or (new_tokens == best_new_tokens and length_diff < smallest_length_diff) - ): - best_new_tokens = new_tokens - best_idx = idx - smallest_length_diff = length_diff - return best_idx - - def _remove_from_pool(self, idx: int) -> dict: - if idx == 0: - return self._pool.popleft() - elif idx == len(self._pool) - 1: - return self._pool.pop() - else: - self._pool.rotate(-idx) - item = self._pool.popleft() - self._pool.rotate(idx) - return item diff --git a/cosmos_framework/data/vfm/packing_iterable_dataset_test.py b/cosmos_framework/data/vfm/packing_iterable_dataset_test.py deleted file mode 100644 index 2fb7f4f..0000000 --- a/cosmos_framework/data/vfm/packing_iterable_dataset_test.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""Unit tests for the ``apply_long_sample_halving`` knob added to -:class:`cosmos_framework.data.vfm.packing_iterable_dataset.PackingIterableDataset` with -the wandb + DataPacker spec (2026-05-21-toml-interface-wandb-datapacker-design.md). 
-""" - -from __future__ import annotations - -import torch - -from cosmos_framework.data.vfm.packing_iterable_dataset import PackingIterableDataset - - -class _StubIterable(torch.utils.data.IterableDataset): - """Trivial finite IterableDataset; the halving tests don't iterate it.""" - - def __iter__(self): - yield from () - - -class _StubPacking(PackingIterableDataset): - """Minimal concrete subclass; ``_max_tokens`` is the SUT, so we just need - a constructable instance — ``compute_sample_tokens`` isn't called by these - tests.""" - - def compute_sample_tokens(self, sample: dict) -> int: # pragma: no cover - unused - return 0 - - -def _make(apply_long_sample_halving: bool = True) -> _StubPacking: - return _StubPacking( - datasets_cfg={"default": {"dataset": _StubIterable(), "ratio": 1.0}}, - max_tokens=45056, - pool_size=16, - max_batch_size=1, - long_threshold=6400, - batching_strategy="prefer_closest", - apply_long_sample_halving=apply_long_sample_halving, - ) - - -# ----- halving heuristic ---------------------------------------------------- - - -def test_default_applies_halving_above_threshold(): - """Default behavior: cur_max >= 1000 triggers ``max_tokens // 2``.""" - ds = _make() - assert ds.apply_long_sample_halving is True - assert ds._max_tokens(999) == 45056 # below threshold → full budget - assert ds._max_tokens(1000) == 22528 # at threshold → halved - assert ds._max_tokens(5000) == 22528 # well above → halved - - -def test_halving_disabled_keeps_full_budget(): - """``apply_long_sample_halving=False`` returns ``max_tokens`` literally.""" - ds = _make(apply_long_sample_halving=False) - assert ds.apply_long_sample_halving is False - assert ds._max_tokens(999) == 45056 - assert ds._max_tokens(1000) == 45056 # would have been halved with default - assert ds._max_tokens(50_000) == 45056 - - -def test_halving_default_is_true_when_unspecified(): - """Backwards compat: constructing without the new kwarg keeps the original - (halving-active) behavior bit-for-bit 
— every existing recipe is unchanged. - """ - ds = _StubPacking( - datasets_cfg={"default": {"dataset": _StubIterable(), "ratio": 1.0}}, - max_tokens=10_000, - pool_size=16, - max_batch_size=1, - long_threshold=6400, - batching_strategy="prefer_closest", - ) - assert ds.apply_long_sample_halving is True - assert ds._max_tokens(2000) == 5000 diff --git a/cosmos_framework/data/vfm/test_dp_state_distributed.py b/cosmos_framework/data/vfm/test_dp_state_distributed.py deleted file mode 100644 index 732addd..0000000 --- a/cosmos_framework/data/vfm/test_dp_state_distributed.py +++ /dev/null @@ -1,683 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -Distributed dataloader state checkpoint/resume test. - -Runs with torchrun --nproc_per_node=4. Tests the full path for both -shuffle=True and shuffle=False: - - 1. Each rank trains N batches with DataPackerDataLoader. - 2. DataLoaderStateCallback collects per-worker (epoch, index) state — - both epoch and position within the epoch are saved. - 3. Each rank saves its state to rank_{rank}.pkl via pickle - (same format as DistributedCheckpointer._save_as_pkl). - 4. Rank 0 verifies all 4 pkl files are non-empty, contain both epoch - and index, and that each rank will resume to a distinct item ID - (confirming disjoint sharding). - 5. All ranks load their pkl back and call load_state_dict() - -> sets DP_STATE_WORKER_*_EPOCH/INDEX env vars. - 6. Each rank creates a new DataPackerDataLoader and verifies the first - item resumes from the correct (epoch, position). 
- -shuffle=True: per-epoch randperm — expected next id from perm[rank::world_size][saved_index+1] -shuffle=False: sequential — expected next id == saved_index+1 (within this rank's stride) - -Usage: - torchrun --nproc_per_node=4 --master_port=50025 \ - cosmos_framework/data/vfm/test_dp_state_distributed.py -""" - -import os -import pickle -import shutil -import tempfile - -import numpy as np -import torch -import torch.distributed as dist -import torch.utils.data - -from cosmos_framework.data.vfm.data_packer import DataPacker -from cosmos_framework.data.vfm.data_packer_dataloader import DataPackerDataLoader, JointDataPackerDataLoader -from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback, JointDataLoaderStateCallback - - -# --------------------------------------------------------------------------- -# Minimal fixtures -# --------------------------------------------------------------------------- - -class SimplePacker(DataPacker): - def sft_process_sample(self, item): - return item - - def compute_num_tokens(self, sample): - return 1 - - def sft_collate_fn(self, samples, max_len, ignore_label_id=-100): - return {"ids": torch.tensor([s["id"] for s in samples])} - - -class SimpleDataset(torch.utils.data.Dataset): - """Map-style dataset: items are {'id': i} for i in range(n).""" - def __init__(self, n=10_000): - self.n = n - def __len__(self): - return self.n - def __getitem__(self, i): - return {"id": i} - - -# --------------------------------------------------------------------------- -# Reusable test helper -# --------------------------------------------------------------------------- - -def run_state_test(rank, world_size, shuffle, seed, tmp_dir, n_batches=5, dataset_size=10_000): - """Run the full train → pkl-save → verify → resume cycle for one shuffle mode.""" - - label = f"shuffle={'True' if shuffle else 'False'}" - - class FakeParallelDims: - @property - def dp_coord(self): - return (rank, world_size) - - # 
------------------------------------------------------------------ - # Phase 1: train n_batches, collect state via callback - # ------------------------------------------------------------------ - cb = DataLoaderStateCallback(distributor_type="data_packer") - loader = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=4, - shuffle=shuffle, - seed=seed, - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - - for i, batch in enumerate(loader): - cb._update_state_from_batch(batch) - if i + 1 >= n_batches: - break - - state = cb.state_dict() - assert state, f"[rank {rank}][{label}] empty state after {n_batches} batches" - saved_epoch = state[0]["epoch"] - saved_index = state[0]["index"] - print(f"[rank {rank}][{label}] phase1: epoch={saved_epoch}, index={saved_index}", flush=True) - - # Save pkl (one file per rank, matching DistributedCheckpointer._save_as_pkl) - pkl_path = os.path.join(tmp_dir, f"rank_{rank}.pkl") - with open(pkl_path, "wb") as f: - pickle.dump(state, f) - - dist.barrier() - - # ------------------------------------------------------------------ - # Phase 2: rank 0 verifies all pkl files and disjoint sharding - # ------------------------------------------------------------------ - if rank == 0: - print(f"\n[rank 0][{label}] Verifying all rank pkl files...", flush=True) - all_states = {} - for r in range(world_size): - path = os.path.join(tmp_dir, f"rank_{r}.pkl") - assert os.path.exists(path), f"missing {path}" - with open(path, "rb") as f: - s = pickle.load(f) - assert s, f"rank {r}: empty state" - assert 0 in s, f"rank {r}: worker_id 0 missing" - assert "epoch" in s[0] and "index" in s[0], \ - f"rank {r}: state missing epoch or index keys — got {s[0].keys()}" - all_states[r] = s - print( - f" rank_{r}.pkl: epoch={s[0]['epoch']}, index={s[0]['index']}", - flush=True, - ) - - # Reconstruct ground-truth next id for each rank and verify disjoint - if shuffle: - g = 
torch.Generator().manual_seed(seed) - perm = torch.randperm(dataset_size, generator=g).tolist() - else: - perm = list(range(dataset_size)) - - first_ids = [] - for r in range(world_size): - saved_idx = all_states[r][0]["index"] - stream_slice = perm[r::world_size] # num_workers=0 → stream_id=r - first_ids.append(stream_slice[saved_idx + 1]) - print(f" rank_{r}: index={saved_idx}, next_id={first_ids[-1]}", flush=True) - - assert len(set(first_ids)) == world_size, \ - f"ranks share next item ids — sharding broken: {first_ids}" - print( - f" All {world_size} ranks will resume to distinct item IDs: {first_ids} OK", - flush=True, - ) - - dist.barrier() - - # ------------------------------------------------------------------ - # Phase 3: each rank loads its pkl and resumes - # ------------------------------------------------------------------ - with open(pkl_path, "rb") as f: - loaded_state = pickle.load(f) - - cb2 = DataLoaderStateCallback(distributor_type="data_packer") - cb2.load_state_dict(loaded_state) - - assert os.environ.get("DP_STATE_WORKER_0_EPOCH") == str(saved_epoch), \ - f"[rank {rank}][{label}] env EPOCH mismatch" - assert os.environ.get("DP_STATE_WORKER_0_INDEX") == str(saved_index), \ - f"[rank {rank}][{label}] env INDEX mismatch" - - loader2 = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=1, - shuffle=shuffle, - seed=seed, - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - first_batch = next(iter(loader2)) - actual_pos = first_batch["sample_index"][0].item() - actual_id = first_batch["ids"][0].item() - - # Ground truth: reconstruct this rank's stream slice - if shuffle: - g = torch.Generator().manual_seed(seed + saved_epoch) - perm = torch.randperm(dataset_size, generator=g).tolist() - else: - perm = list(range(dataset_size)) - stream_slice = perm[rank::world_size] - expected_pos = saved_index + 1 - expected_id = stream_slice[expected_pos] - - assert actual_pos == 
expected_pos, \ - f"[rank {rank}][{label}] position mismatch: expected {expected_pos}, got {actual_pos}" - assert actual_id == expected_id, \ - f"[rank {rank}][{label}] id mismatch: expected {expected_id}, got {actual_id}" - - print( - f"[rank {rank}][{label}] resume: pos={actual_pos} (expected {expected_pos}), " - f"id={actual_id} (expected {expected_id}) OK", - flush=True, - ) - - dist.barrier() - - # Clean up pkl files for next test run - os.remove(pkl_path) - dist.barrier() - - -def run_state_test_multi_worker( - rank, world_size, shuffle, seed, tmp_dir, n_batches=20, dataset_size=10_000, num_workers=2 -): - """State checkpoint/resume test with num_workers > 1. - - With num_workers workers per rank, DataLoaderStateCallback tracks state - for each worker_id (0..num_workers-1) independently. The saved pkl - contains entries for all worker_ids; on resume each worker reads its own - env var and fast-forwards to the correct position. - - Verification: after resume, every (worker_id, sample_index) pair seen in - the first resumed batches must have sample_index >= saved_index_for_that_worker + 1. 
- """ - label = f"shuffle={'True' if shuffle else 'False'}, num_workers={num_workers}" - - class FakeParallelDims: - @property - def dp_coord(self): - return (rank, world_size) - - # ------------------------------------------------------------------ - # Phase 1: train n_batches, collect per-worker state - # ------------------------------------------------------------------ - cb = DataLoaderStateCallback(distributor_type="data_packer") - loader = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=4, - shuffle=shuffle, - seed=seed, - num_workers=num_workers, - prefetch_factor=2, - parallel_dims=FakeParallelDims(), - ) - - for i, batch in enumerate(loader): - cb._update_state_from_batch(batch) - if i + 1 >= n_batches: - break - - state = cb.state_dict() - assert state, f"[rank {rank}][{label}] empty state after {n_batches} batches" - assert len(state) == num_workers, \ - f"[rank {rank}][{label}] expected {num_workers} worker entries, got {len(state)}" - for wid in range(num_workers): - assert wid in state, f"[rank {rank}][{label}] worker_id {wid} missing from state" - assert "epoch" in state[wid] and "index" in state[wid], \ - f"[rank {rank}][{label}] worker {wid} state missing epoch/index" - print( - f"[rank {rank}][{label}] phase1: " - + ", ".join(f"w{wid}=(epoch={state[wid]['epoch']},idx={state[wid]['index']})" - for wid in sorted(state)), - flush=True, - ) - - # Save pkl - pkl_path = os.path.join(tmp_dir, f"rank_{rank}.pkl") - with open(pkl_path, "wb") as f: - pickle.dump(state, f) - - dist.barrier() - - # ------------------------------------------------------------------ - # Phase 2: rank 0 verifies env vars will be set for all worker_ids - # ------------------------------------------------------------------ - if rank == 0: - print(f"\n[rank 0][{label}] Verifying pkl files contain all worker_ids...", flush=True) - for r in range(world_size): - path = os.path.join(tmp_dir, 
f"rank_{r}.pkl") - with open(path, "rb") as f: - s = pickle.load(f) - for wid in range(num_workers): - assert wid in s, f"rank {r}: worker_id {wid} missing" - print(f" rank_{r}: workers {sorted(s.keys())} — OK", flush=True) - - dist.barrier() - - # ------------------------------------------------------------------ - # Phase 3: load pkl, verify env vars set for all workers, resume - # ------------------------------------------------------------------ - with open(pkl_path, "rb") as f: - loaded_state = pickle.load(f) - - cb2 = DataLoaderStateCallback(distributor_type="data_packer") - cb2.load_state_dict(loaded_state) - - for wid in range(num_workers): - saved_epoch = loaded_state[wid]["epoch"] - saved_index = loaded_state[wid]["index"] - assert os.environ.get(f"DP_STATE_WORKER_{wid}_EPOCH") == str(saved_epoch), \ - f"[rank {rank}][{label}] w{wid} env EPOCH mismatch" - assert os.environ.get(f"DP_STATE_WORKER_{wid}_INDEX") == str(saved_index), \ - f"[rank {rank}][{label}] w{wid} env INDEX mismatch" - - # Resume: iterate until we have seen the first batch from each worker, then - # verify exact (position, item_id) matches the deterministic permutation. - # This confirms the ordering is identical to what would have been produced - # without a checkpoint, not merely that positions are monotonically increasing. - loader2 = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=1, - shuffle=shuffle, - seed=seed, - num_workers=num_workers, - prefetch_factor=2, - parallel_dims=FakeParallelDims(), - ) - - # Collect the first batch produced by each worker after resume. 
- first_per_worker: dict = {} - for batch in loader2: - wid = int(batch["sample_worker_id"][0].item()) - if wid not in first_per_worker: - first_per_worker[wid] = ( - int(batch["sample_index"][0].item()), - int(batch["ids"][0].item()), - ) - if len(first_per_worker) == num_workers: - break - - # Reconstruct the ground-truth permutation for this epoch. - saved_epoch0 = loaded_state[0]["epoch"] # all workers share the same epoch - if shuffle: - g = torch.Generator().manual_seed(seed + saved_epoch0) - perm = torch.randperm(dataset_size, generator=g).tolist() - else: - perm = list(range(dataset_size)) - - for wid in range(num_workers): - saved_index = loaded_state[wid]["index"] - # stream_id for this worker on this rank: - # stream_id = rank * num_workers + wid - stream_id = rank * num_workers + wid - stream_slice = perm[stream_id::(world_size * num_workers)] - expected_pos = saved_index + 1 - expected_id = stream_slice[expected_pos] - - actual_pos, actual_id = first_per_worker[wid] - assert actual_pos == expected_pos, \ - f"[rank {rank}][{label}] w{wid} pos mismatch: expected {expected_pos}, got {actual_pos}" - assert actual_id == expected_id, \ - f"[rank {rank}][{label}] w{wid} id mismatch: expected {expected_id}, got {actual_id}" - print( - f"[rank {rank}][{label}] w{wid}: resume pos={actual_pos} (expected {expected_pos}), " - f"id={actual_id} (expected {expected_id}) OK", - flush=True, - ) - - dist.barrier() - - os.remove(pkl_path) - dist.barrier() - - -# --------------------------------------------------------------------------- -# JointDataPackerDataLoader tests -# --------------------------------------------------------------------------- - -def run_joint_selection_test(rank, world_size, seed=42, n_batches=20, dataset_size=10_000): - """Verify JointDataPackerDataLoader produces the expected deterministic selection sequence. - - Reconstructs the expected dataset_name sequence using the same - np.random.RandomState(seed + global_id) formula and asserts it matches. 
- Each rank runs independently (selection is identical across ranks since - it depends only on seed + global_id, not on rank). - """ - - class FakeParallelDims: - @property - def dp_coord(self): - return (rank, world_size) - - loader_a = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=4, - name="ds_a", - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - loader_b = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=4, - name="ds_b", - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - joint = JointDataPackerDataLoader( - dataloaders={ - "ds_a": {"dataloader": loader_a, "ratio": 3}, - "ds_b": {"dataloader": loader_b, "ratio": 1}, - }, - seed=seed, - ) - - observed = [] - for i, batch in enumerate(joint): - assert "dataset_name" in batch, f"[rank {rank}] dataset_name key missing from batch" - observed.append(batch["dataset_name"]) - if i + 1 >= n_batches: - break - - # Reconstruct expected sequence - probs = np.array([3, 1], dtype=float) / 4.0 - expected = [] - for i in range(n_batches): - rng = np.random.RandomState(seed + i) - idx = int(rng.choice(2, p=probs)) - expected.append(["ds_a", "ds_b"][idx]) - - assert observed == expected, ( - f"[rank {rank}] selection mismatch:\n observed={observed}\n expected={expected}" - ) - print(f"[rank {rank}][TEST 5] deterministic selection OK: {observed}", flush=True) - - dist.barrier() - - -def run_joint_state_test(rank, world_size, shuffle, seed, tmp_dir, n_batches=10, dataset_size=10_000): - """Full checkpoint/resume cycle for JointDataPackerDataLoader. - - Phase 1: train n_batches, collect state via JointDataLoaderStateCallback. - Phase 2: save to pkl, reload, verify global_id. 
- Phase 3: create fresh joint loader, call load_state_dict, verify: - - first batch dataset_name matches expected selection at global_id step - - first batch sample_index == saved_index + 1 for that dataset - """ - label = f"shuffle={'True' if shuffle else 'False'}" - - class FakeParallelDims: - @property - def dp_coord(self): - return (rank, world_size) - - # ------------------------------------------------------------------ - # Phase 1: train n_batches, collect state - # ------------------------------------------------------------------ - loader_a = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=4, - shuffle=shuffle, - seed=seed, - name="ds_a", - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - loader_b = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=4, - shuffle=shuffle, - seed=seed + 1, - name="ds_b", - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - joint = JointDataPackerDataLoader( - dataloaders={ - "ds_a": {"dataloader": loader_a, "ratio": 3}, - "ds_b": {"dataloader": loader_b, "ratio": 1}, - }, - seed=seed, - ) - cb = JointDataLoaderStateCallback(outer_loader=joint, distributor_type="data_packer") - - for i, batch in enumerate(joint): - cb._update_state_from_batch(batch) - if i + 1 >= n_batches: - break - - saved_state = cb.state_dict() - saved_global_id = saved_state["global_id"] - assert saved_global_id == n_batches, \ - f"[rank {rank}][{label}] expected global_id={n_batches}, got {saved_global_id}" - - print( - f"[rank {rank}][{label}] phase1: global_id={saved_global_id}, " - + ", ".join( - f"{name}=w0(epoch={saved_state[name][0]['epoch']},idx={saved_state[name][0]['index']})" - for name in ("ds_a", "ds_b") - if name in saved_state and saved_state[name] - ), - flush=True, - ) - - pkl_path = os.path.join(tmp_dir, f"joint_rank_{rank}.pkl") - with open(pkl_path, "wb") as f: - 
pickle.dump(saved_state, f) - - dist.barrier() - - # ------------------------------------------------------------------ - # Phase 2: reload, verify global_id - # ------------------------------------------------------------------ - with open(pkl_path, "rb") as f: - loaded_state = pickle.load(f) - - assert loaded_state["global_id"] == saved_global_id, \ - f"[rank {rank}][{label}] global_id mismatch after reload" - - # ------------------------------------------------------------------ - # Phase 3: resume, verify first batch - # ------------------------------------------------------------------ - loader2_a = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=1, - shuffle=shuffle, - seed=seed, - name="ds_a", - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - loader2_b = DataPackerDataLoader( - data_source=SimpleDataset(dataset_size), - data_packer=SimplePacker(), - max_tokens=50, - max_batch_size=1, - shuffle=shuffle, - seed=seed + 1, - name="ds_b", - num_workers=0, - parallel_dims=FakeParallelDims(), - ) - joint2 = JointDataPackerDataLoader( - dataloaders={ - "ds_a": {"dataloader": loader2_a, "ratio": 3}, - "ds_b": {"dataloader": loader2_b, "ratio": 1}, - }, - seed=seed, - ) - cb2 = JointDataLoaderStateCallback(outer_loader=joint2, distributor_type="data_packer") - cb2.load_state_dict(loaded_state) - - # Determine which dataset the first resumed batch should come from - probs = np.array([3, 1], dtype=float) / 4.0 - rng = np.random.RandomState(seed + saved_global_id) - expected_first_ds = ["ds_a", "ds_b"][int(rng.choice(2, p=probs))] - - first_batch = next(iter(joint2)) - actual_ds = first_batch["dataset_name"] - assert actual_ds == expected_first_ds, \ - f"[rank {rank}][{label}] first dataset mismatch: expected={expected_first_ds}, got={actual_ds}" - - # Verify sample position within the selected dataset matches saved_index + 1 - ds_inner_state = loaded_state.get(actual_ds, {}) - if 
ds_inner_state: - saved_index = ds_inner_state[0]["index"] - saved_epoch = ds_inner_state[0]["epoch"] - actual_pos = int(first_batch["sample_index"][0].item()) - expected_pos = saved_index + 1 - - # Also verify exact item id matches the deterministic permutation - seed_for_ds = seed if actual_ds == "ds_a" else seed + 1 - if shuffle: - g = torch.Generator().manual_seed(seed_for_ds + saved_epoch) - perm = torch.randperm(dataset_size, generator=g).tolist() - else: - perm = list(range(dataset_size)) - stream_slice = perm[rank::world_size] - expected_id = stream_slice[expected_pos] - actual_id = int(first_batch["ids"][0].item()) - - assert actual_pos == expected_pos, \ - f"[rank {rank}][{label}][{actual_ds}] pos mismatch: expected {expected_pos}, got {actual_pos}" - assert actual_id == expected_id, \ - f"[rank {rank}][{label}][{actual_ds}] id mismatch: expected {expected_id}, got {actual_id}" - - print( - f"[rank {rank}][{label}] resume OK: global_id={saved_global_id}, " - f"first_ds={actual_ds} (expected {expected_first_ds}), " - f"pos={first_batch['sample_index'][0].item()}, id={first_batch['ids'][0].item()}", - flush=True, - ) - - dist.barrier() - os.remove(pkl_path) - dist.barrier() - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - -def main(): - dist.init_process_group(backend="nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - - # All ranks need to agree on the scratch dir, so pick a deterministic - # location under the system tempdir rather than mkdtemp() (which would - # return a different path on each rank). 
- tmp_dir = os.path.join(tempfile.gettempdir(), "cosmos_dp_state_test_tmp") - if rank == 0: - os.makedirs(tmp_dir, exist_ok=True) - print(f"[rank 0] Using tmp_dir: {tmp_dir}", flush=True) - dist.barrier() - - # Test 1: shuffle=True, num_workers=0 - if rank == 0: - print("\n" + "=" * 60, flush=True) - print("[rank 0] TEST 1: shuffle=True, num_workers=0", flush=True) - print("=" * 60, flush=True) - run_state_test(rank, world_size, shuffle=True, seed=99, tmp_dir=tmp_dir) - - # Test 2: shuffle=False, num_workers=0 - if rank == 0: - print("\n" + "=" * 60, flush=True) - print("[rank 0] TEST 2: shuffle=False, num_workers=0", flush=True) - print("=" * 60, flush=True) - run_state_test(rank, world_size, shuffle=False, seed=0, tmp_dir=tmp_dir) - - # Test 3: shuffle=True, num_workers=2 - if rank == 0: - print("\n" + "=" * 60, flush=True) - print("[rank 0] TEST 3: shuffle=True, num_workers=2", flush=True) - print("=" * 60, flush=True) - run_state_test_multi_worker(rank, world_size, shuffle=True, seed=77, tmp_dir=tmp_dir, num_workers=2) - - # Test 4: shuffle=False, num_workers=2 - if rank == 0: - print("\n" + "=" * 60, flush=True) - print("[rank 0] TEST 4: shuffle=False, num_workers=2", flush=True) - print("=" * 60, flush=True) - run_state_test_multi_worker(rank, world_size, shuffle=False, seed=0, tmp_dir=tmp_dir, num_workers=2) - - # Test 5: JointDataPackerDataLoader — deterministic selection - if rank == 0: - print("\n" + "=" * 60, flush=True) - print("[rank 0] TEST 5: JointDataPackerDataLoader deterministic selection", flush=True) - print("=" * 60, flush=True) - run_joint_selection_test(rank, world_size, seed=42) - - # Test 6a: JointDataPackerDataLoader — stateful resume, shuffle=True - if rank == 0: - print("\n" + "=" * 60, flush=True) - print("[rank 0] TEST 6a: JointDataPackerDataLoader state resume, shuffle=True", flush=True) - print("=" * 60, flush=True) - run_joint_state_test(rank, world_size, shuffle=True, seed=99, tmp_dir=tmp_dir) - - # Test 6b: 
JointDataPackerDataLoader — stateful resume, shuffle=False - if rank == 0: - print("\n" + "=" * 60, flush=True) - print("[rank 0] TEST 6b: JointDataPackerDataLoader state resume, shuffle=False", flush=True) - print("=" * 60, flush=True) - run_joint_state_test(rank, world_size, shuffle=False, seed=0, tmp_dir=tmp_dir) - - if rank == 0: - shutil.rmtree(tmp_dir, ignore_errors=True) - print("\n=== ALL DISTRIBUTED STATE TESTS PASSED ===", flush=True) - - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/docs/custom_dataset.md b/docs/custom_dataset.md index 590a474..5fbdae3 100644 --- a/docs/custom_dataset.md +++ b/docs/custom_dataset.md @@ -1,498 +1,454 @@ # Custom Datasets for Generator and Reasoner Training -This guide explains how to bring your own dataset into Cosmos training using -`DataPackerDataLoader` and `JointDataPackerDataLoader` — the OSS-facing data -layer that works without any internal infrastructure (no WebDataset, no -object-store credentials). +Bring your own dataset into Cosmos training with **`CosmosDataLoader`** — the +OSS-facing data layer that works without any internal infrastructure (no +WebDataset, no object-store credentials). + +`CosmosDataLoader` turns any dataset into training batches by composing four +small, swappable roles. Pick a built-in for each slot, or write your own — the +loader wires them together in a fixed, safe order. + +``` +DataDistributor → RawItemProcessor → SampleBatcher → BatchCollator +(raw items, sharded (one raw item (a stream of (a group of + across DP ranks × → one sample samples → groups samples → one + workers, shuffled, dict) that form a batch) batch dict for + resumable) model.forward) +``` + +- **`DataDistributor`** owns the dataset and yields *this* rank/worker's disjoint + slice of raw items (sharding, shuffle, checkpoint/resume). +- **`RawItemProcessor`** turns one raw item into one training-ready sample dict + (decode, tokenize, etc.). 
+- **`SampleBatcher`** pulls from the sample stream and decides *which* samples go + together in a batch (fixed size, token-budget packing, …). +- **`BatchCollator`** turns a chosen group of samples into one batch dict. + +Everything lives in `cosmos_framework.data.vfm.dataflow`. The loader is a +`torch.utils.data.DataLoader` subclass, so it drops into existing training loops. --- ## Contents -1. [Overview](#overview) -2. [DataPackerDataLoader](#datapackerdataloader) - - [Step 1 — Prepare your data source](#step-1--prepare-your-data-source) - - [Step 2 — Write your DataPacker](#step-2--write-your-datapacker) - - [Step 3 — Wire into an experiment config](#step-3--wire-everything-into-an-experiment-config) - - [Key parameters](#key-parameters) - - [Shuffle and stateful checkpoint/resume](#shuffle-and-stateful-checkpointresume) - - [Data-parallel sharding](#data-parallel-sharding) -3. [JointDataPackerDataLoader](#jointdatapackerdataloader) - - [When to use it](#when-to-use-it) - - [How to wire it up](#how-to-wire-it-up) - - [Stateful checkpoint/resume](#stateful-checkpointresume) -4. [Real-world examples](#real-world-examples) -5. [Checklist](#checklist-for-a-new-dataset) +1. [Quickstart (60 seconds)](#1-quickstart-60-seconds) +2. [The four roles](#2-the-four-roles) +3. [Recipes by use-case](#3-recipes-by-use-case) +4. [Wiring into a training recipe (Hydra)](#4-wiring-into-a-training-recipe-hydra) +5. [Checkpoint / resume](#5-checkpoint--resume) +6. [Distributed & sharding](#6-distributed--sharding) +7. [Troubleshooting / FAQ](#7-troubleshooting--faq) +8. [End-to-end worked example](#8-end-to-end-worked-example-custom-dataset--training) +9. [Real-world examples](#9-real-world-examples) +10. [Checklist for a new dataset](#10-checklist-for-a-new-dataset) --- -## Overview - -The data pipeline has two parts you control: +## 1. 
Quickstart (60 seconds) -``` -Your dataset (Dataset or IterableDataset) - │ - ▼ -┌────────────────────────────────────────────────┐ -│ DataPackerDataLoader │ -│ │ -│ map-style Dataset (any shuffle setting): │ -│ ┌──────────────────────────────────────────┐ │ -│ │ _ShuffledMapIterableDataset │ │ -│ │ • per-epoch randperm (shuffle=True) │ │ -│ │ or sequential (shuffle=False) │ │ -│ │ • DP × worker sharding │ │ -│ │ • position metadata for stateful resume │ │ -│ └──────────────────┬───────────────────────┘ │ -│ │ │ -│ IterableDataset: │ -│ ┌──────────────────────────────────────────┐ │ -│ │ _IterableWrapper │ │ -│ │ • DP × worker sharding only │ │ -│ │ • no stateful resume │ │ -│ └──────────────────┬───────────────────────┘ │ -│ │ raw item │ -│ ┌──────────────────▼───────────────────────┐ │ -│ │ _DataPackerIterableDataset │ │ -│ │ (subclass of PackingIterableDataset) │ │ -│ │ │ │ -│ │ • fill pool (pool_size samples) │ │ -│ │ • greedy bin-pack within max_tokens │ │ -│ │ • cap at max_batch_size │ │ -│ │ │ │ -│ │ → DataPacker.sft_process_sample ← you │ │ -│ │ → DataPacker.compute_num_tokens ← you │ │ -│ │ → DataPacker.sft_collate_fn ← you │ │ -│ └──────────────────────────────────────────┘ │ -└────────────────────────────────────────────────┘ - │ fully-collated batch dict - ▼ - Trainer / model.forward() -``` +"I have a map-style dataset and just want normal, shuffled, resumable batches": -Key point: **all map-style datasets** (whether `shuffle=True` or `shuffle=False`) -are routed through `_ShuffledMapIterableDataset`, which attaches position -metadata to every sample. This means stateful checkpoint/resume works regardless -of whether shuffle is enabled. 
+```python +from cosmos_framework.data.vfm.dataflow import ( + CosmosDataLoader, MapDistributor, IdentityProcessor, +) ---- +loader = CosmosDataLoader( + distributor=MapDistributor(my_dataset, shuffle=True, seed=0), # any torch map Dataset + processor=IdentityProcessor(), # dataset already yields samples + batch_size=32, # sugar → SimpleBatcher + DefaultBatchCollator +) -## DataPackerDataLoader +for batch in loader: + out = model(**batch) +``` -### Step 1 — Prepare your data source +`batch_size=N` is convenience sugar: when you don't pass an explicit `batcher`, +the loader builds a `SimpleBatcher(N)` + `DefaultBatchCollator()` (stock +`torch.utils.data` stacking). Pass an explicit `batcher`/`collator` for anything +fancier. (Passing both `batch_size=` and `batcher=` is an error.) -`DataPackerDataLoader` accepts either a **map-style** `torch.utils.data.Dataset` -or an **iterable-style** `torch.utils.data.IterableDataset`. Plain lists and -generators are rejected with a `TypeError`. +--- -| Type | Notes | -| ----------------------------------------- | ------------------------------------------------------------------------------------- | -| `torch.utils.data.Dataset` (map-style) | Pass directly. Supports `shuffle=True/False` and stateful checkpoint/resume. | -| `torch.utils.data.IterableDataset` | Pass directly. No shuffle, no stateful resume — shuffle externally if needed. | -| HuggingFace `Dataset` | Is a `torch.utils.data.Dataset` subclass — pass directly, `shuffle=True` works. | -| HuggingFace `IterableDataset` (streaming) | Is a `torch.utils.data.IterableDataset` — pass directly, use `.shuffle()` externally. | +## 2. The four roles -#### Loading from HuggingFace (simplest) +Each role is a tiny ABC in `cosmos_framework.data.vfm.dataflow.base`. Implement +the one method (plus, for distributors, optional resume hooks). 
```python -from cosmos_framework.data.vfm.data_packer_dataloader import load_data_source +class DataDistributor(ABC): + def stream(self, dp_rank, dp_world_size, worker_id, num_workers) -> Iterator[Any]: + """Yield this (rank, worker)'s disjoint slice of raw items, indefinitely.""" + def state_dict(self) -> dict: ... # optional, for resume + def load_state_dict(self, state) -> None: ... -# HuggingFace Hub dataset (downloaded, map-style) -data_source = load_data_source("liuhaotian/LLaVA-Instruct-150K", split="train") +class RawItemProcessor(ABC): + def process(self, item) -> dict: ... # one raw item → one sample dict -# Dataset saved with dataset.save_to_disk() -data_source = load_data_source("/path/to/my_saved_dataset", split="train") +class SampleBatcher(ABC): + def batches(self, samples: Iterator[dict]) -> Iterator[list[dict]]: ... + def sample_size(self, sample) -> int: ... # only packing batchers need this -# Then pass with shuffle for per-epoch shuffling + stateful resume -DataPackerDataLoader(data_source=data_source, ..., shuffle=True, seed=42) +class BatchCollator(ABC): + def collate(self, samples: list[dict]) -> dict: ... # group → batch dict ``` -#### Streaming from HuggingFace (no disk space) +The loader passes the rank/worker coordinates *into* `stream()` — you never read +`get_worker_info()` yourself. The fixed order (`distribute → process → batch → +collate`) is enforced by the loader, so the stages can't be misordered. 
+ +### Built-ins + +| Role | Built-in | Use it when | +| ----------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------- | +| Distributor | `IterableDistributor(iterable)` | streaming / `IterableDataset` source (round-robin shard; **not** resumable) | +| | `MapDistributor(dataset, seed=0, shuffle=True, name="")` | map-style `Dataset` (per-epoch shuffle, slice shard, **resumable**) | +| | `RankPartitionedDistributor({name: {"dataset":…, "ratio":…}})` | assign whole DP ranks to different datasets by ratio | +| | `MixtureDistributor({name: (distributor, ratio)}, seed=0)` | mix several distributors into one stream at the sample level | +| Processor | `IdentityProcessor()` | the dataset already yields finished sample dicts | +| | *(write your own)* | decode/tokenize/transform a raw record | +| Batcher | `SimpleBatcher(batch_size, drop_last=False)` | fixed-size batches | +| | `PoolPackingBatcher(max_tokens, pool_size=16, max_batch_size=1, long_threshold=6400, batching_strategy="prefer_closest", apply_long_sample_halving=True, size_fn=None)` | token-budget bin-packing (reorders within a pool to minimize padding) | +| | `SequentialPackingBatcher(max_sequence_length, …, max_samples_per_batch=None)` | order-preserving pack-until-budget (no reordering) | +| Collator | `DefaultBatchCollator()` | stack with `torch.utils.data.default_collate` | +| | `VFMListCollator()` | VFM packed batches (media kept as per-sample lists) | + +Recipe-specific roles live next to their recipes, e.g. `VLMProcessor` / +`VLMCollator` (in `configs/base/vlm/experiment/dataflow_roles.py`) and +`VideoPhy2Processor`. 
-```python -from datasets import load_dataset +--- -data_source = load_dataset( - "lmms-lab/LLaVA-OneVision-Data", name="si", split="train", streaming=True -) -# shuffle before passing — IterableDataset does not support internal shuffle -data_source = data_source.shuffle(seed=42, buffer_size=10_000) -``` +## 3. Recipes by use-case -#### Custom map-style dataset +### Bring your own map-style dataset (shuffle + resume) ```python -class MyMapDataset(torch.utils.data.Dataset): - def __len__(self): return 10_000 - def __getitem__(self, idx): return {"video": ..., "text": ...} - -# Pass directly — DataPackerDataLoader handles sharding and shuffle internally -DataPackerDataLoader(data_source=MyMapDataset(), ..., shuffle=True, seed=42) +class MyImageCaptionDataset(torch.utils.data.Dataset): + def __len__(self): return len(self.records) + def __getitem__(self, i): return self.records[i] # a plain dict + +loader = CosmosDataLoader( + distributor=MapDistributor(MyImageCaptionDataset(...), shuffle=True, seed=42), + processor=MyProcessor(), # turns the record into {"input_ids": …, "pixel_values": …} + batch_size=8, + num_workers=4, +) ``` ---- - -### Step 2 — Write your DataPacker +Map-style sources are **resumable** (see §5). -`DataPacker` is an abstract base class. Implement all three methods, then place -the class in the same experiment config file that uses it. +### Bring your own streaming / iterable dataset ```python -from cosmos_framework.data.vfm.data_packer import DataPacker - -class MyDataPacker(DataPacker): - - def sft_process_sample(self, item: dict) -> dict: - """ - Convert one raw item from data_source into a training-ready sample. - Called inside DataLoader workers — tokenization, decoding, transforms go here. - """ - ... - return {"input_ids": ..., "labels": ..., ...} - - def compute_num_tokens(self, sample: dict) -> int: - """ - Return the token cost of one processed sample. - Used by the packing engine to decide how many samples fit in a batch. 
- """ - return int(sample["input_ids"].shape[0]) - - def sft_collate_fn(self, samples: list[dict], max_len: int, - ignore_label_id: int = -100) -> dict: - """ - Collate a list of processed samples into one batch dict. - max_len is the longest token sequence in this batch (for padding). - """ - ... - return {"input_ids": ..., "labels": ..., ...} +hf_stream = load_dataset("some/dataset", split="train", streaming=True) +loader = CosmosDataLoader( + distributor=IterableDistributor(hf_stream), # round-robin shard across rank×worker + processor=MyProcessor(), + batch_size=8, +) ``` -> **Note on extra batch keys**: For map-style datasets, `DataPackerDataLoader` -> automatically appends `sample_worker_id`, `sample_epoch`, and `sample_index` to -> every batch dict. These are used by `DataLoaderStateCallback` for stateful -> checkpoint/resume and are transparent to the model as long as `training_step` -> accesses the batch by key (not `**kwargs` unpack). +Iterable sources are **not** resumable (you can't random-access to fast-forward). -#### Token counting for Generator models +### Token-budget packing for variable-length sequences ```python -import math - -def compute_num_tokens(self, sample: dict) -> int: - tokens = 1 + len(sample.get("text_token_ids", [])) - v = sample.get("video") # shape [C, T, H, W] - if v is not None: - _, T, H, W = v.shape - latent_h = math.ceil(H / (self.spatial_compression * self.patch_spatial)) - latent_w = math.ceil(W / (self.spatial_compression * self.patch_spatial)) - latent_t = 1 + (T - 1) // self.temporal_compression - tokens += latent_h * latent_w * latent_t + 2 - return tokens +loader = CosmosDataLoader( + distributor=IterableDistributor(stream), + processor=MyProcessor(), # yields {"input_ids": Tensor[L], …} + batcher=PoolPackingBatcher(max_tokens=16000, pool_size=16, max_batch_size=1), + collator=MyCollator(), +) ``` -Typical values: `spatial_compression=16`, `temporal_compression=4`, `patch_spatial=2`. 
- ---- +`PoolPackingBatcher.sample_size` defaults to `len(sample["input_ids"])`; pass +`size_fn=lambda s: …` (or subclass and override `sample_size`) for a custom cost. -### Step 3 — Wire everything into an experiment config +### Order-preserving sequence packing ```python -from cosmos_framework.utils.lazy_config import LazyCall as L, LazyDict -from cosmos_framework.data.vfm.data_packer_dataloader import DataPackerDataLoader, load_data_source -from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback -from hydra.core.config_store import ConfigStore - -cs = ConfigStore.instance() - -my_experiment = LazyDict(dict( - defaults=[...], # inherit model, optimizer, scheduler from a base - - trainer=dict( - callbacks=dict( - # Tracks per-worker (epoch, position) for checkpoint/resume. - # Works with both shuffle=True and shuffle=False for map-style datasets. - dataloader_state=L(DataLoaderStateCallback)(distributor_type="data_packer"), - ), - ), - - dataloader_train=L(DataPackerDataLoader)( - data_source=L(load_data_source)(name="my-org/my-dataset", split="train"), - data_packer=L(MyDataPacker)(...), - max_tokens=16000, - pool_size=16, - max_batch_size=1, - shuffle=True, # per-epoch randperm, different order every epoch - seed=42, # epoch e uses seed+e → reproducible permutations - num_workers=4, - prefetch_factor=4, - persistent_workers=True, - pin_memory=True, - ), - dataloader_val=None, -), flags={"allow_objects": True}) - -cs.store(group="experiment", package="_global_", name="my_experiment", node=my_experiment) +batcher=SequentialPackingBatcher( + max_sequence_length=45056, + tokenizer_spatial_compression_factor=16, + tokenizer_temporal_compression_factor=4, + patch_spatial=2, +) ``` -Launch: +Packs samples in stream order until the token budget is hit (no reordering). +Exactly one of `max_sequence_length` (token-budget mode) or +`max_samples_per_batch` (count-only mode) must be set. 
-```bash -torchrun --nproc_per_node=8 -m cosmos_framework.scripts.train \ - --config=cosmos_framework/configs/base/config.py -- \ - experiment=my_experiment \ - trainer.max_iter=1000 -``` +### Mix multiple datasets by ratio (one pipeline) ---- +```python +distributor=MixtureDistributor( + {"webvid": (IterableDistributor(webvid), 3.0), + "internal": (IterableDistributor(internal), 1.0)}, + seed=0, +) +``` -### Key parameters - -| Parameter | What it controls | -| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `max_tokens` | Token budget per batch. Packing stops when adding one more sample would exceed this. For Generator, counts video latent tokens; for Reasoner, counts `input_ids` length. | -| `pool_size` | Samples to buffer before bin-packing. Larger pool → better packing efficiency, more memory. Default: 16. | -| `max_batch_size` | Hard cap on samples per batch regardless of token budget. Use `1` for Reasoner (one image per step), `128`–`256` for action policy training. | -| `shuffle` | `True` → per-epoch `randperm` shuffle for map-style datasets (no effect on `IterableDataset`, a warning is logged). `False` → sequential, still resumable. | -| `seed` | Base seed for the shuffle permutation. Epoch `e` uses `seed + e` → reproducible, different ordering every epoch. Default: `0`. | -| `name` | Optional string that namespaces resume env vars. **Required** when multiple `DataPackerDataLoader` instances share the same process (i.e., inside `JointDataPackerDataLoader`). Each inner loader must have a unique `name` matching its key in the `dataloaders` dict. Leave empty (default) for single-loader setups. 
| -| `long_threshold` | Samples with token count ≥ this are emitted as singleton batches, bypassing packing. Default: 6400. | -| `batching_strategy` | `"prefer_closest"` (default) picks candidates nearest in token length. `"prefer_first"` picks the first that fits. | -| `num_workers` | DataLoader workers for `sft_process_sample`. Use `0` for debugging. | -| `persistent_workers` | Automatically promoted to `True` for all map-style datasets when `num_workers > 0` (required for correct resume behaviour). | +Ratio-weighted merge into a single stream — use when the datasets share one +processor/batcher/collator (homogeneous join). ---- +### Interleave heterogeneous pipelines (different processors/collators) -### Shuffle and stateful checkpoint/resume +```python +from cosmos_framework.data.vfm.dataflow import JointCosmosDataLoader -For map-style datasets, `DataPackerDataLoader` tracks each worker's position and -resumes training from **exactly** where it left off after a checkpoint. This works -for both `shuffle=True` and `shuffle=False`. +joint = JointCosmosDataLoader( + dataloaders={ + "vlm": {"dataloader": vlm_loader, "ratio": 1}, + "vfm": {"dataloader": vfm_loader, "ratio": 3}, + }, + seed=42, +) +``` -#### How it works +Each output batch comes from one selected inner `CosmosDataLoader` (ratio-weighted, +seeded). Use when the joined datasets need *different* processing — each inner +loader is a full four-role pipeline. Every yielded batch is tagged with +`"dataset_name"`. (`"global_id"` is reserved by the checkpoint state and cannot be +used as a dataset name.) -1. Each epoch, a permutation is generated with `torch.randperm(n, generator=torch.Generator().manual_seed(seed + epoch))` (or `list(range(n))` when `shuffle=False`). -2. Each `(dp_rank, worker_id)` pair sees a disjoint stride: `perm[stream_id :: total_streams]` where `stream_id = dp_rank * num_workers + worker_id`. -3. 
After each training step, `DataLoaderStateCallback` reads `sample_epoch` and `sample_index` from the batch and tracks the high-water mark per worker. -4. At checkpoint, the DCP checkpointer saves the state to `iter_XXXXXXXXX/dataloader/rank_{rank}.pkl`. -5. On resume, `load_state_dict` sets `DP_STATE_WORKER_{worker_id}_EPOCH/INDEX` env vars before workers start, and workers fast-forward past already-seen samples. +--- -**At most `pool_size` (default 16) samples are re-processed** at each resume (they pass through `sft_process_sample` again but are trained on only once). +## 4. Wiring into a training recipe (Hydra) -#### Required wiring +Recipes build the loader with `LazyCall` so CLI overrides work: ```python -from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback +from cosmos_framework.utils.lazy_config import LazyCall as L -exp["trainer"]["callbacks"]["dataloader_state"] = L(DataLoaderStateCallback)( - distributor_type="data_packer" +dataloader_train = L(CosmosDataLoader)( + distributor=L(MapDistributor)(dataset=L(my_dataset_factory)(...), shuffle=True), + processor=L(MyProcessor)(...), + batcher=L(PoolPackingBatcher)(max_tokens="${data_setting.max_tokens}", max_batch_size=1), + collator=L(MyCollator)(), + num_workers=2, ) ``` -Use `ckpt_type=dcp` (the default) — not `ckpt_type=dummy` which disables all checkpointing. +Override from the CLI like any Hydra node, e.g. +`dataloader_train.batcher.max_tokens=8000`. See the live recipes for full examples: +`pre_exp012_llava_ov` (VLM), `videophy2_sft_nano` (videophy2), +`pre_exp012_llava_ov_mapstyle_dataloader` (map-style resumable VLM), and +`vision_sft_nano_mapstyle_dataloader` (VFM, alongside the legacy `vision_sft_nano`). -#### Limitations - -- **Map-style datasets only.** Stateful resume is not supported for `IterableDataset` sources. -- **`fork` start method required** (the default for Linux/CUDA). `spawn` is not supported. 
-- **`persistent_workers=True` required** when `num_workers > 0` (auto-enforced for all map-style datasets). +> **Structured-TOML launches.** When you launch a VLM recipe via `--sft-toml`, +> the flat `[dataloader_train]` knobs `max_samples_per_batch` and +> `max_sequence_length` are routed onto the loader's nested batcher +> (`dataloader_train.batcher.max_batch_size` and `…batcher.max_tokens`) by +> `PATH_REMAPS["vlm"]` in `configs/toml_config/toml_config_helper.py`. This only +> works for experiments whose batcher actually has those fields (e.g. +> `PoolPackingBatcher`). --- -### Data-parallel sharding - -`DataPackerDataLoader` automatically shards `data_source` across ranks **and** -DataLoader workers. Each `(dp_rank, worker_id)` pair receives a disjoint subset — -a strided slice of the (shuffled) permutation. +## 5. Checkpoint / resume -**If your dataset already shards internally** (like `SFTDataset`), disable its -sharding before passing it to `DataPackerDataLoader`: +Resume is handled by the existing `DataLoaderStateCallback`: ```python -def get_my_dataset_no_dp(**kwargs): - dataset = MyDataset(**kwargs) - dataset.shard_world_size = 1 # disable internal sharding - dataset.shard_rank = 0 - return dataset +from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback +cb = DataLoaderStateCallback(distributor_type="cosmos_dataloader") ``` -**For FSDP + TP/PP**: pass `parallel_dims` so the correct DP rank is used -(global rank ≠ DP rank in these setups): - -```python -DataPackerDataLoader(..., parallel_dims=parallel_dims) -``` +- Use a **`MapDistributor`** source. On save, the callback records each worker's + `(epoch, index)` from the per-batch `sample_worker_id`/`sample_epoch`/`sample_index` + tensors the loader stamps. On load, it sets `COSMOS_DL_STATE_*` env vars *before* + workers fork; `MapDistributor.stream` reads them and fast-forwards to the exact + next sample — no duplicated or skipped samples. 
+- **Iterable** sources are not resumable (no position to fast-forward to); the + stream restarts from the beginning. +- For multiple loaders sharing a process (e.g. inside `JointCosmosDataLoader`), + give each a distinct `name=` so resume env vars are namespaced + (`COSMOS_DL_STATE_{name}_WORKER_{id}_{EPOCH,INDEX}`), and use a single + `JointDataLoaderStateCallback(outer_loader=joint_loader, distributor_type="cosmos_dataloader")` + instead of one `DataLoaderStateCallback` per inner loader. +- Use `ckpt_type=dcp` (the default) — not `ckpt_type=dummy`, which disables all + checkpointing. The on-disk checkpoint format is unchanged. + +> **Validated:** a live save→stop→resume on `pre_exp012_llava_ov_mapstyle_dataloader` +> (8 dp ranks, `save_iter=100`) reproduces the original run's per-rank +> `input_ids` shapes exactly across the resume boundary — no duplicated or +> skipped samples on any rank. --- -## JointDataPackerDataLoader - -### When to use it +## 6. Distributed & sharding + +- The loader resolves the data-parallel coordinates as + `parallel_dims.dp_coord` > `torch.distributed` > `(0, 1)`. For FSDP+TP/PP, pass + `parallel_dims=` so sharding uses the correct DP rank (not the global rank). +- `IterableDistributor`/`MapDistributor` give each `(dp_rank, worker_id)` pair a + disjoint, complete slice: item `i` of the source is taken iff + `i % (dp_world_size × num_workers) == dp_rank × num_workers + worker_id`. +- `RankPartitionedDistributor` instead assigns *whole ranks* to datasets by ratio + and sets `shard_world_size`/`shard_rank` on the chosen dataset (which then + self-shards across the ranks sharing it). If your dataset already shards + internally, disable that (`dataset.shard_world_size = 1; dataset.shard_rank = 0`) + before handing it to a per-`(rank,worker)`-sharding distributor. +- A `MapDistributor` source with `num_workers > 0` is automatically promoted to + `persistent_workers=True` (required for correct stateful resume). 
The `fork` + start method (the Linux/CUDA default) is required; `spawn` is not supported. -`JointDataPackerDataLoader` wraps **multiple** `DataPackerDataLoader` instances -with ratio-based seeded selection. Use it when training on multiple datasets with -different modalities or formats — for example, video + action data at a 3:1 ratio. +--- -Semantics mirror `IterativeJointDataLoader`: +## 7. Troubleshooting / FAQ + +- **OOM on large packed batches.** `PoolPackingBatcher(apply_long_sample_halving=True)` + (default) halves the token budget for any batch whose largest sample is ≥ 1000 + tokens. Set `False` only after validating memory headroom at the literal budget. +- **`ValueError: Provide either a batcher= or a batch_size=`.** You passed neither; + give one. Passing both is also an error. +- **`ValueError: Map-style resume cannot safely stamp a multi-sample batch …`.** + A reordering batcher (pool packing) with `max_batch_size > 1` on a resumable + `MapDistributor` can't record a gap-free resume position. Use `max_batch_size=1` + with pool packing, a sequential (order-preserving) batcher, or an iterable + (non-resumable) source. +- **`'int' object is not iterable` / wrong tensor shapes in the model.** Your + `BatchCollator` is producing a different batch structure than the model expects. + Match the structure the model consumes (for VFM, that's `VFMListCollator`, which + keeps media as per-sample lists). +- **Oversized samples silently dropped (sequential packing).** A single sample + larger than `max_sequence_length` is discarded with a logged error — increase + the budget or filter upstream. +- **`num_workers` / `persistent_workers`.** `persistent_workers=True` is ignored + (with a log) when `num_workers=0`. -- **One batch = one dataset** — samples from different datasets never share a packed batch. -- Ratios control how frequently each dataset is visited (per batch, not per sample). 
-- Selection is deterministic: step `i` always picks the same dataset given the same `seed`. -- Stateful checkpoint/resume: both the outer step counter (`global_id`) and each inner - loader's per-worker position are saved and restored. +--- -### How to wire it up +## 8. End-to-end worked example (custom dataset → training) -Each inner `DataPackerDataLoader` must be given a unique `name` that matches its -key in the `dataloaders` dict. The `name` namespaces the resume env vars to -prevent conflicts between concurrent loaders. +A local image-caption folder, fully custom processor, normal batching: ```python -from cosmos_framework.data.vfm.data_packer_dataloader import DataPackerDataLoader, JointDataPackerDataLoader -from cosmos_framework.callbacks.dataloader_state import JointDataLoaderStateCallback -from cosmos_framework.utils.lazy_config import LazyCall as L - -# Build the joint loader -joint_loader = JointDataPackerDataLoader( - dataloaders={ - "video": { - "dataloader": DataPackerDataLoader( - data_source=MyVideoDataset(...), - data_packer=MyVideoDataPacker(...), - max_tokens=45056, - shuffle=True, - seed=0, - name="video", # must match the key above - num_workers=4, - persistent_workers=True, - pin_memory=True, - ), - "ratio": 3, # video 3×, action 1× - }, - "action": { - "dataloader": DataPackerDataLoader( - data_source=MyActionDataset(...), - data_packer=MyActionDataPacker(...), - max_tokens=999_999, - max_batch_size=128, - shuffle=True, - seed=0, - name="action", # must match the key above - num_workers=4, - persistent_workers=True, - pin_memory=True, - ), - "ratio": 1, - }, - }, - seed=42, # controls outer dataset selection sequence +import torch +from cosmos_framework.data.vfm.dataflow import ( + CosmosDataLoader, MapDistributor, RawItemProcessor, DefaultBatchCollator, SimpleBatcher, ) -# Wire into the experiment config -exp["dataloader_train"] = joint_loader -exp["trainer"]["callbacks"]["dataloader_state"] = JointDataLoaderStateCallback( - 
outer_loader=joint_loader, - distributor_type="data_packer", +class ImageCaptionFolder(torch.utils.data.Dataset): + def __init__(self, records): self.records = records # [{"image_path":…, "caption":…}, …] + def __len__(self): return len(self.records) + def __getitem__(self, i): return self.records[i] + +class ImageCaptionProcessor(RawItemProcessor): + def __init__(self, tokenizer, image_loader): + self.tokenizer, self.image_loader = tokenizer, image_loader + def process(self, item): + return { + "pixel_values": self.image_loader(item["image_path"]), # Tensor[C,H,W] + "input_ids": self.tokenizer(item["caption"]), # Tensor[L] + } + +loader = CosmosDataLoader( + distributor=MapDistributor(ImageCaptionFolder(records), shuffle=True, seed=0), + processor=ImageCaptionProcessor(tokenizer, image_loader), + batcher=SimpleBatcher(batch_size=16), + collator=DefaultBatchCollator(), + num_workers=4, ) -``` - -> **Reserved name**: `"global_id"` cannot be used as a dataset name — it is -> reserved by the checkpoint state format. - -#### `JointDataPackerDataLoader` parameters - -| Parameter | What it controls | -| ------------- | ------------------------------------------------------------------------------------------------------------------------ | -| `dataloaders` | Dict mapping dataset name → `{"dataloader": DataPackerDataLoader, "ratio": int}`. Entries with `ratio <= 0` are skipped. | -| `seed` | Base seed for outer dataset selection. Step `i` uses `np.random.RandomState(seed + i)` → same sequence on every rank. | - -#### `JointDataLoaderStateCallback` - -This single callback replaces the per-inner-loader `DataLoaderStateCallback` -instances. It saves: - -- `global_id` — the outer step counter, which determines which dataset fires at each step on resume. -- Per-dataset, per-worker `(epoch, index)` — each inner loader's position. -All state is written to a single DCP checkpoint entry (`checkpoint_component="dataloader"`). 
- -### Stateful checkpoint/resume +for batch in loader: # batch = {"pixel_values": [16,C,H,W], "input_ids": [16,L]} + loss = model(**batch) + loss.backward() +``` -At checkpoint step `N`: +To pack variable-length captions by token budget instead of fixed size, swap the +batcher for `PoolPackingBatcher(max_tokens=…, max_batch_size=1)` and provide a +collator that pads/stacks accordingly — nothing else changes. -- `global_id = N` is saved. -- Each inner loader saves its per-worker `(epoch, index)` under its `name` key. +--- -On resume: +## 9. Real-world examples -1. `JointDataLoaderStateCallback.load_state_dict` calls `set_start_iteration(N)` on the outer loader → selection sequence resumes from step `N`. -2. Each inner `DataLoaderStateCallback.load_state_dict` sets namespaced env vars (`DP_STATE_{name}_WORKER_{id}_EPOCH/INDEX`) → workers fast-forward to the saved position. +### Reasoner (VLM) — HuggingFace image-text dataset, streaming -Inner loader iterators are created lazily on the **first** `__iter__` call (not at -`__init__` time), ensuring workers fork **after** env vars have been set. +**File**: `cosmos_framework/configs/base/vlm/experiment/llava_ov_vlm.py` +(`pre_exp012_llava_ov`) ---- +``` +distributor: IterableDistributor(get_llava_ov_streaming(...)) # lmms-lab/LLaVA-OneVision-Data +processor: VLMProcessor (ShareGPT → OpenAI messages → Qwen3-VL processor) +batcher: PoolPackingBatcher(max_tokens≈16000, max_batch_size=1) +collator: VLMCollator +``` -## Real-world examples +Streaming source → **not** resumable. For a resumable variant of the same recipe, +see `llava_ov_mapstyle_dataloader_experiment.py` (`pre_exp012_llava_ov_mapstyle_dataloader`): it loads +the subset as a real map-style `Dataset` (`load_dataset(..., streaming=False)`) and +wraps it in a `MapDistributor`, so checkpoint/resume works (see §5). 
-### Reasoner — HuggingFace image-text dataset +### Reasoner (VLM) — local video dialog dataset -**File**: `cosmos_framework/configs/base/vlm/experiment/llava_ov_datapacker_experiment.py` +**File**: `cosmos_framework/configs/base/vlm/experiment/videophy2_sft_nano.py` +(`videophy2_sft_nano`) ``` -data_source: lmms-lab/LLaVA-OneVision-Data (streaming IterableDataset) -DataPacker: VLMDataPacker - sft_process_sample: ShareGPT → OpenAI messages → Qwen3-VL processor - compute_num_tokens: len(input_ids) - sft_collate_fn: unsqueeze batch dim, keep pixel_values flat -max_batch_size: 1 -max_tokens: ~16000 -shuffle: False (streaming IterableDataset — use .shuffle() externally) +distributor: IterableDistributor(build_videophy2_local_dataset(...)) +processor: VideoPhy2Processor +batcher: PoolPackingBatcher(max_tokens≈16000, max_batch_size=1) +collator: VLMCollator ``` -### Action Policy — Robot learning (LIBERO) +### Generator (VFM) — Cosmos video SFT -**File**: `cosmos_framework/configs/base/experiment/action/posttrain_config/libero_policy_datapacker_experiment.py` +**File**: the `vision_sft_nano_mapstyle_dataloader` experiment (the new-loader VFM variant, +alongside the legacy `vision_sft_nano`). ``` -data_source: LIBERODataset (map-style Dataset, passed directly) -DataPacker: ActionDataPacker - sft_process_sample: full ActionTransformPipeline (resize, tokenize, pad action) - compute_num_tokens: VAE video tokens + text tokens - sft_collate_fn: action/domain_id/sequence_plan fields + video + text -max_batch_size: 128 (token budget disabled — batch bounded by max_batch_size) -max_tokens: 999999 -shuffle: True, seed=0 +distributor: IterableDistributor over the Cosmos video dataset +processor: recipe processor (decode + tokenize) +batcher: SequentialPackingBatcher(max_sequence_length=…) # order-preserving token packing +collator: VFMListCollator # media kept as per-sample lists ``` -`LIBERODataset` is a map-style `Dataset` passed directly. 
`shuffle=True` enables -per-epoch shuffling and stateful checkpoint/resume. This pattern (high `max_tokens` +--- -- bounded `max_batch_size`) is standard for action policy training where you want -a fixed number of demonstrations per step. +## 10. Checklist for a new dataset + +### Single dataset (`CosmosDataLoader`) + +- [ ] Pick a **distributor**: `MapDistributor` (map-style `Dataset`, shuffle + + **resume**) or `IterableDistributor` (streaming, not resumable). +- [ ] Write a **`RawItemProcessor`** (or use `IdentityProcessor` if your dataset + already yields finished sample dicts). +- [ ] Pick a **batcher**: `batch_size=N` sugar, `SimpleBatcher`, + `PoolPackingBatcher` (token-budget, reorders), or `SequentialPackingBatcher` + (token-budget, order-preserving). +- [ ] Pick a **collator**: `DefaultBatchCollator`, `VFMListCollator`, or your own + (must match the structure the model consumes). +- [ ] For real resume: use a `MapDistributor`, add + `DataLoaderStateCallback(distributor_type="cosmos_dataloader")`, and + `ckpt_type=dcp` (not `dummy`). +- [ ] For FSDP+TP/PP, pass `parallel_dims=` so the correct DP rank is used. +- [ ] Register the experiment in the Hydra ConfigStore + (`cs.store(group="experiment", …)`). +- [ ] Smoke-test with `--dryrun` (config build) then `trainer.max_iter=10` before a + full run. + +### Multiple datasets (`JointCosmosDataLoader`) + +- [ ] Build each inner pipeline as its own `CosmosDataLoader`; give each a unique + `name=` matching its key in `dataloaders` (namespaces resume env vars). +- [ ] Set each dataset's `ratio` (controls how often it is visited, per batch). +- [ ] Use a single + `JointDataLoaderStateCallback(outer_loader=joint_loader, distributor_type="cosmos_dataloader")` + — do **not** also register a standalone `DataLoaderStateCallback` per inner + loader. +- [ ] Avoid `"global_id"` as a dataset name (reserved by the checkpoint state). +- [ ] Use `ckpt_type=dcp` for real checkpoint/resume. 
--- -## Checklist for a new dataset - -### Single dataset (`DataPackerDataLoader`) - -- [ ] Choose a `data_source`: map-style `Dataset` or `IterableDataset` (no plain lists/generators) -- [ ] For map-style: pass directly; use `shuffle=True, seed=` for per-epoch shuffle -- [ ] For iterable: shuffle externally before passing (e.g. `.shuffle(buffer_size=N)`) -- [ ] If dataset has internal DP sharding, disable it (`shard_world_size=1`) -- [ ] Subclass `DataPacker` and implement `sft_process_sample`, `compute_num_tokens`, `sft_collate_fn` -- [ ] Choose `max_tokens` and `max_batch_size` for your modality -- [ ] Add `DataLoaderStateCallback(distributor_type="data_packer")` to the experiment's callbacks (works for both `shuffle=True` and `shuffle=False` on map-style datasets) -- [ ] Use `ckpt_type=dcp` (not `dummy`) for real checkpoint/resume -- [ ] Register in Hydra ConfigStore with `cs.store(group="experiment", ...)` -- [ ] Smoke-test with `ckpt_type=dummy trainer.max_iter=10` before a full run - -### Multiple datasets (`JointDataPackerDataLoader`) - -- [ ] Give each inner `DataPackerDataLoader` a unique `name` matching its key in `dataloaders` -- [ ] Set appropriate `ratio` for each dataset (controls visit frequency per batch) -- [ ] Use `JointDataLoaderStateCallback(outer_loader=joint_loader)` instead of `DataLoaderStateCallback` -- [ ] Do **not** also register standalone `DataLoaderStateCallback` for inner loaders — `JointDataLoaderStateCallback` handles all of them -- [ ] Avoid using `"global_id"` as a dataset name (reserved) -- [ ] Use `ckpt_type=dcp` for real checkpoint/resume +## Reference: where things live + +- ABCs + built-ins: `cosmos_framework/data/vfm/dataflow/` (`base.py`, + `distributors.py`, `batchers.py`, `collators.py`, `processors.py`, `loader.py`). +- Public symbols are re-exported from `cosmos_framework.data.vfm.dataflow`. 
+- Live recipes using the loader: `pre_exp012_llava_ov`, + `pre_exp012_llava_ov_mapstyle_dataloader`, `videophy2_sft_nano`, and `vision_sft_nano_mapstyle_dataloader`. diff --git a/examples/launch_sft_llava_ov.sh b/examples/launch_sft_llava_ov.sh index 69ecfa9..7027a58 100755 --- a/examples/launch_sft_llava_ov.sh +++ b/examples/launch_sft_llava_ov.sh @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -# Structured-TOML launch for llava_ov_datapacker (VLM SFT on -# lmms-lab/LLaVA-OneVision-Data via DataPackerDataLoader). Drives +# Structured-TOML launch for llava_ov (VLM SFT on +# lmms-lab/LLaVA-OneVision-Data via CosmosDataLoader). Drives # cosmos_framework.scripts.train against -# examples/toml/sft_config/llava_ov_datapacker.toml. +# examples/toml/sft_config/llava_ov.toml. # # [job].task = "vlm" — picks cosmos_framework/configs/base/vlm/config.py as the base config. # @@ -20,6 +20,6 @@ # Usage (8-GPU allocation, inside the training container, from the repo root): # bash examples/launch_sft_llava_ov.sh -TOML_FILE="examples/toml/sft_config/llava_ov_datapacker.toml" +TOML_FILE="examples/toml/sft_config/llava_ov.toml" source "$(dirname "${BASH_SOURCE[0]}")/_sft_launcher_common.sh" diff --git a/examples/launch_sft_llava_ov_mapstyle_dataloader.sh b/examples/launch_sft_llava_ov_mapstyle_dataloader.sh new file mode 100755 index 0000000..32e67ef --- /dev/null +++ b/examples/launch_sft_llava_ov_mapstyle_dataloader.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +# MapDistributor-backed VLM SFT (llava_ov_mapstyle_dataloader) with +# per-worker (epoch, index) checkpoint/resume via CosmosDataLoaderStateCallback. +# +# Optional env vars: +# RUN_NAME — job name, also used to derive the checkpoint output path. 
+# Default: llava_ov_mapstyle_dataloader_<YYYYmmdd_HHMMSS> +# RESUME_FROM_CKPT — if set, resumes from this checkpoint directory (supplies +# checkpoint.load_path and sets load_training_state=true). +# Expected format: +# <output_root>/<project>/<group>/<name>/checkpoints/iter_<iteration> +# +# Usage — fresh run: +# bash examples/launch_sft_llava_ov_mapstyle_dataloader.sh +# +# Usage — resume: +# RESUME_FROM_CKPT=<checkpoint_dir> \ +# RUN_NAME=<run_name> \ +# bash examples/launch_sft_llava_ov_mapstyle_dataloader.sh + +TOML_FILE="examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml" +: "${RUN_NAME:=llava_ov_mapstyle_dataloader_$(date +%Y%m%d_%H%M%S)}" +MASTER_PORT="${MASTER_PORT:-50016}" + +TAIL_OVERRIDES=( + "data_setting.max_tokens=16000" + "trainer.max_iter=60" + "checkpoint.save_iter=50" + "job.wandb_mode=online" + "job.name=${RUN_NAME}" +) + +if [[ -n "${RESUME_FROM_CKPT:-}" ]]; then + echo ">>> Resuming from checkpoint: ${RESUME_FROM_CKPT}" + TAIL_OVERRIDES+=( + "checkpoint.load_path=${RESUME_FROM_CKPT}" + "checkpoint.load_training_state=true" + ) +fi + +source "$(dirname "${BASH_SOURCE[0]}")/_sft_launcher_common.sh" diff --git a/examples/launch_sft_videophy2_nano.sh b/examples/launch_sft_videophy2_nano.sh index 2191f76..b0818fb 100755 --- a/examples/launch_sft_videophy2_nano.sh +++ b/examples/launch_sft_videophy2_nano.sh @@ -3,7 +3,7 @@ # SPDX-License-Identifier: OpenMDW-1.1 # Structured-TOML launch for videophy2_sft_nano (VLM dialog SFT on VideoPhy-2 -# via DataPackerDataLoader). Drives cosmos_framework.scripts.train against +# via CosmosDataLoader). Drives cosmos_framework.scripts.train against # examples/toml/sft_config/videophy2_sft_nano.toml. # # [job].task = "vlm" — picks cosmos_framework/configs/base/vlm/config.py as the base config. 
diff --git a/examples/launch_sft_vision_nano_cosmosdataloader.sh b/examples/launch_sft_vision_nano_cosmosdataloader.sh new file mode 100755 index 0000000..a6f5797 --- /dev/null +++ b/examples/launch_sft_vision_nano_cosmosdataloader.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +# Dataflow-loader mirror of the VFM vision_sft_nano recipe (vision_sft_nano_mapstyle_dataloader) +# for loss-curve regression vs the baseline launched by launch_sft_vision_nano.sh. +# +# Optional env vars (defaults below point under examples/; override to put +# data or checkpoints on a different filesystem): +# DATASET_PATH default: examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge +# (must contain train/video_dataset_file.jsonl) +# BASE_CHECKPOINT_PATH default: examples/checkpoints/Cosmos3-Nano +# RUN_NAME default: vision_sft_nano_mapstyle_dataloader_<YYYYmmdd_HHMMSS> + +TOML_FILE="examples/toml/sft_config/vision_sft_nano_mapstyle_dataloader.toml" +: "${DATASET_PATH:=examples/data/bridge-v2-subset-synthetic-captions/sft_dataset_bridge}" +: "${BASE_CHECKPOINT_PATH:=examples/checkpoints/Cosmos3-Nano}" + +EXTRA_DATASET_CHECK='[[ -f "$DATASET_PATH/train/video_dataset_file.jsonl" ]] || { echo "ERROR: missing $DATASET_PATH/train/video_dataset_file.jsonl" >&2; exit 1; }' + +: "${RUN_NAME:=vision_sft_nano_mapstyle_dataloader_$(date +%Y%m%d_%H%M%S)}" + +TAIL_OVERRIDES=( + "trainer.logging_iter=1" "trainer.max_iter=500" + "job.project=cosmos_oss_alignment" "job.wandb_mode=online" "job.name=${RUN_NAME}" +) + +source "$(dirname "${BASH_SOURCE[0]}")/_sft_launcher_common.sh" diff --git a/examples/toml/sft_config/llava_ov_datapacker.toml b/examples/toml/sft_config/llava_ov.toml similarity index 87% rename from examples/toml/sft_config/llava_ov_datapacker.toml rename to examples/toml/sft_config/llava_ov.toml index 3c45f11..f07a07b 100644 --- 
a/examples/toml/sft_config/llava_ov_datapacker.toml +++ b/examples/toml/sft_config/llava_ov.toml @@ -1,8 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -# pre_exp012_llava_ov_datapacker — VLM training on lmms-lab/LLaVA-OneVision-Data -# via DataPackerDataLoader. Base config = cosmos_framework/configs/base/vlm/config.py +# pre_exp012_llava_ov — VLM training on lmms-lab/LLaVA-OneVision-Data +# via CosmosDataLoader. Base config = cosmos_framework/configs/base/vlm/config.py # (selected by [job].task="vlm"). # # One knob that the SFTExperimentConfig dataclass does NOT model — supply @@ -14,7 +14,7 @@ # # Example launch: # torchrun --nproc_per_node=4 -m cosmos_framework.scripts.train \ -# --sft-toml examples/toml/sft_config/llava_ov_datapacker.toml -- \ +# --sft-toml examples/toml/sft_config/llava_ov.toml -- \ # data_setting.max_tokens=16000 # # Per-task remap (see _PATH_REMAPS["vlm"]): @@ -25,16 +25,15 @@ # model.attn_implementation -> model.config.policy.attn_implementation # model.backbone.* -> model.config.policy.backbone.* # model.ema.* -> model.config.ema.* -# dataloader_train.max_samples_per_batch -> dataloader_train.max_batch_size # model.{max_num_tokens_after_packing, joint_attn_implementation, lora_*, # tokenizer.*} and dataloader_train.{max_sequence_length, seed} -> SKIPPED [job] task = "vlm" -experiment = "pre_exp012_llava_ov_datapacker" +experiment = "pre_exp012_llava_ov" project = "cosmos3" # matches legacy group = "vlm_llava_ov_demo" -name = "pre_exp012_llava_ov_datapacker" +name = "pre_exp012_llava_ov" wandb_mode = "disabled" [model] @@ -102,4 +101,8 @@ load_path = "???" 
# MISSING sentinel; skipped by save_iter = 100 [dataloader_train] -max_samples_per_batch = 1 # → dataloader_train.max_batch_size on VLM +# Routed by PATH_REMAPS["vlm"] onto the CosmosDataLoader's nested PoolPackingBatcher: +# max_samples_per_batch -> dataloader_train.batcher.max_batch_size +# max_sequence_length -> dataloader_train.batcher.max_tokens +max_samples_per_batch = 1 +max_sequence_length = 16000 diff --git a/examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml b/examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml new file mode 100644 index 0000000..5b1a8d9 --- /dev/null +++ b/examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +# pre_exp012_llava_ov_mapstyle_dataloader — map-style resumable VLM recipe +# Base config = cosmos_framework/configs/base/vlm/config.py (task="vlm"). +# +# Loads LLaVA-OneVision-Data ai2d(gpt4v) as a real map-style Dataset +# (load_dataset(streaming=False)), filters + caps to 4000 rows, and shards it +# with MapDistributor — enabling deterministic per-worker (epoch, index) +# checkpointing and resume. With dp_shard=8 that is 500 samples/rank/epoch, so +# the iter-100 checkpoint (100 samples/rank) lands mid-epoch. 
+# +# dryrun:: +# python -m cosmos_framework.scripts.train \ +# --sft-toml=examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml --dryrun -- \ +# data_setting.max_tokens=16000 +# +# Fresh run:: +# torchrun --nproc_per_node=8 --master_port=12344 \ +# -m cosmos_framework.scripts.train \ +# --sft-toml=examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml -- \ +# data_setting.max_tokens=16000 +# +# Resume run (from iter 100 checkpoint):: +# torchrun --nproc_per_node=8 --master_port=12344 \ +# -m cosmos_framework.scripts.train \ +# --sft-toml=examples/toml/sft_config/llava_ov_mapstyle_dataloader.toml -- \ +# data_setting.max_tokens=16000 \ +# checkpoint.load_path=/tmp/imaginaire4-output/cosmos_oss_alignment/vlm_llava_ov_demo/pre_exp012_llava_ov_mapstyle_dataloader/checkpoints/iter_000000100 \ +# checkpoint.load_training_state=true +# +# Per-task remap (see _PATH_REMAPS["vlm"]): +# model.parallelism.* -> model.config.parallelism.* +# model.compile.* -> model.config.compile.* +# model.activation_checkpointing.* -> model.config.activation_checkpointing.* +# model.precision -> model.config.precision +# model.attn_implementation -> model.config.policy.attn_implementation +# model.backbone.* -> model.config.policy.backbone.* +# model.ema.* -> model.config.ema.* +# dataloader_train.{max_sequence_length, seed} -> SKIPPED (handled by batcher) + +[job] +task = "vlm" +experiment = "pre_exp012_llava_ov_mapstyle_dataloader" +project = "cosmos_oss_alignment" +group = "vlm_llava_ov_demo" +name = "pre_exp012_llava_ov_mapstyle_dataloader" +wandb_mode = "online" + +[model] +attn_implementation = "cosmos" +precision = "bfloat16" + +[model.backbone] +model_name = "Qwen/Qwen3-VL-8B-Instruct" + +[model.ema] +enabled = false +rate = 0.1 +iteration_shift = 0 + +[model.parallelism] +data_parallel_shard_degree = 8 +data_parallel_replicate_degree = -1 +context_parallel_shard_degree = 1 +cfg_parallel_shard_degree = 1 + +[model.compile] +enabled = false +compile_dynamic = true + 
+[model.activation_checkpointing] +mode = "full" +save_ops_regex = ["fmha"] +preserve_rng_state = true +determinism_check = "default" + +[optimizer] +betas = [0.9, 0.95] +eps = 1.0e-8 +fused = true +lr = 1.0e-5 +weight_decay = 0.1 + +[scheduler] +cycle_lengths = [500] +f_max = [1.0] +f_min = [0.5] +f_start = [0.05] +verbosity_interval = 0 +warm_up_steps = [1000] + +[trainer] +distributed_parallelism = "fsdp" +grad_accum_iter = 1 +logging_iter = 1 +max_iter = 500 + +[trainer.callbacks.compile_tokenizer] +compile_after_iterations = 3 +enabled = false + +[trainer.callbacks.grad_clip] +clip_norm = 1.0 +force_finite = false + +[checkpoint] +keys_to_skip_loading = [] +load_path = "???" # MISSING sentinel; supply at resume time +save_iter = 100 + +[dataloader_train] +# Routed by PATH_REMAPS["vlm"] onto the CosmosDataLoader's nested PoolPackingBatcher: +# max_samples_per_batch -> dataloader_train.batcher.max_batch_size +# max_sequence_length -> dataloader_train.batcher.max_tokens +max_samples_per_batch = 1 +max_sequence_length = 16000 diff --git a/examples/toml/sft_config/videophy2_sft_nano.toml b/examples/toml/sft_config/videophy2_sft_nano.toml index 8cb3bf3..a183db3 100644 --- a/examples/toml/sft_config/videophy2_sft_nano.toml +++ b/examples/toml/sft_config/videophy2_sft_nano.toml @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -# videophy2_sft_nano — VLM dialog SFT on VideoPhy-2 via DataPackerDataLoader. +# videophy2_sft_nano — VLM dialog SFT on VideoPhy-2 via CosmosDataLoader. # Base config = cosmos_framework/configs/base/vlm/config.py (selected by [job].task="vlm"). # # Dataset prep: @@ -84,4 +84,8 @@ load_path = "???" 
save_iter = 100 [dataloader_train] +# Routed by PATH_REMAPS["vlm"] onto the CosmosDataLoader's nested PoolPackingBatcher: +# max_samples_per_batch -> dataloader_train.batcher.max_batch_size +# max_sequence_length -> dataloader_train.batcher.max_tokens max_samples_per_batch = 1 +max_sequence_length = 16000 diff --git a/examples/toml/sft_config/vision_sft_nano_mapstyle_dataloader.toml b/examples/toml/sft_config/vision_sft_nano_mapstyle_dataloader.toml new file mode 100644 index 0000000..5d76be8 --- /dev/null +++ b/examples/toml/sft_config/vision_sft_nano_mapstyle_dataloader.toml @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +# vision_sft_nano_mapstyle_dataloader — T2V / I2V / V2V vision-only SFT (Qwen3-VL-8B / nano) +# Dataflow-loader mirror of vision_sft_nano: uses CosmosDataLoader + +# RankPartitionedDistributor + SequentialPackingBatcher + VFMListCollator. +# +# NOTE: the [dataloader_train] block from vision_sft_nano.toml is intentionally +# REMOVED here. In vision_sft_nano.toml, [dataloader_train.max_sequence_length] +# mapped onto PackingDataLoader's constructor kwarg. CosmosDataLoader has no +# top-level max_sequence_length attribute — the batcher (SequentialPackingBatcher) +# owns it, and max_sequence_length=45056 is already hardcoded in the experiment +# LazyDict (vision_sft_nano_mapstyle_dataloader.py). Applying the key onto the CosmosDataLoader node +# would raise ConfigAttributeError, so it is omitted here. 
+ +[job] +task = "vfm" +experiment = "vision_sft_nano_mapstyle_dataloader" +project = "cosmos3" +group = "sft" +name = "vision_sft_nano_mapstyle_dataloader" +wandb_mode = "disabled" + +[model] +max_num_tokens_after_packing = 45056 +joint_attn_implementation = "two_way" +precision = "bfloat16" # was [model.parallelism].precision + +[model.ema] +enabled = true +rate = 0.1 +iteration_shift = 0 + +[model.parallelism] +data_parallel_shard_degree = -1 # -1 = auto from WORLD_SIZE (matches legacy) +data_parallel_replicate_degree = 1 + +[model.compile] +enabled = true # was [model.parallelism].use_torch_compile +compile_dynamic = true + +[model.activation_checkpointing] +mode = "full" +save_ops_regex = ["fmha"] +preserve_rng_state = true +determinism_check = "default" + +[model.tokenizer] +vae_path = "${oc.env:WAN_VAE_PATH}" + +[optimizer] +betas = [0.9, 0.95] +eps = 1.0e-6 +fused = true +keys_to_select = [ + "moe_gen", + "time_embedder", + "vae2llm", + "llm2vae", +] +lr = 2.0e-5 +weight_decay = 0 # int matches legacy YAML repr +# lr_multipliers intentionally empty for vision SFT (Hydra default {} stands). + +[scheduler] +cycle_lengths = [1000] +f_max = [1.0] +f_min = [0.0] +f_start = [0.0] +verbosity_interval = 0 +warm_up_steps = [50] + +[trainer] +distributed_parallelism = "fsdp" +grad_accum_iter = 2 +logging_iter = 1 +max_iter = 500 + +[trainer.callbacks.compile_tokenizer] +compile_after_iterations = 3 +enabled = false +# warmup_resolutions omitted (None at experiment level) + +[trainer.callbacks.grad_clip] +clip_norm = 0.1 +force_finite = true + +[checkpoint] +keys_to_skip_loading = ["net_ema."] +load_path = "${oc.env:BASE_CHECKPOINT_PATH}" +save_iter = 100 diff --git a/tests/launch_regression_test.py b/tests/launch_regression_test.py index 0a1a9f2..0106a2c 100644 --- a/tests/launch_regression_test.py +++ b/tests/launch_regression_test.py @@ -47,7 +47,7 @@ ``pytest-custom-exit-code`` plugin (not installed in the training image). 
Determinism notes: - * ``llava_ov_datapacker`` runs **without** ``--deterministic`` on H100 AND + * ``llava_ov`` runs **without** ``--deterministic`` on H100 AND overrides ``model.config.deterministic=false``: the Qwen3-VL text path uses an attention backend whose Hopper FMHA backward kernel has no deterministic mode (raises ``NotImplementedError`` under PyTorch's @@ -169,7 +169,7 @@ def _detect_arch() -> str: # --- log parsers ------------------------------------------------------------- # -# VLM (``pre_exp012_llava_ov_datapacker``) logs the DP-reduced loss on rank 0:: +# VLM (``pre_exp012_llava_ov``) logs the DP-reduced loss on rank 0:: # # train/loss_avg: 1.32225 (iteration 0) # @@ -218,7 +218,7 @@ class LaunchSpec: # ``test_launch_regression_8gpu`` (the ``gpus`` marker carries only one value, # so the test functions are split). _SPEC_KEYS = ( - "llava_ov_datapacker", + "llava_ov", "vision_sft_nano", ) _SPEC_KEYS_8GPU = ("vision_sft_super",) @@ -234,10 +234,10 @@ def _build_specs(paths: dict[str, str]) -> dict[str, LaunchSpec]: super_extra_env["BASE_CHECKPOINT_PATH"] = super_ckpt return { - "llava_ov_datapacker": LaunchSpec( + "llava_ov": LaunchSpec( # Replicates launch_sft_llava_ov.sh, capped to 10 iters. - key="llava_ov_datapacker", - sft_toml="examples/toml/sft_config/llava_ov_datapacker.toml", + key="llava_ov", + sft_toml="examples/toml/sft_config/llava_ov.toml", extra_hydra_args=( # TAIL_OVERRIDES from launch_sft_llava_ov.sh — fields not modeled # by SFTExperimentConfig. @@ -579,7 +579,7 @@ def test_launch_regression_8gpu(spec_key: str, tmp_path: Path, h100_inputs: dict # and seed 42 against the legacy training pipeline. VLM backbone is not # part of the OSS layout. 
"gb200": { - "llava_ov_datapacker": { + "llava_ov": { "loss": [1.32208, 1.20886, 1.39254, 1.40460, 1.16652, 1.24852, 1.38463, 1.22766, 0.96263, 1.14468], "grad_norm": [ 38.62454, 23.61477, 30.53218, 36.46255, 25.06240, @@ -606,7 +606,7 @@ def test_launch_regression_8gpu(spec_key: str, tmp_path: Path, h100_inputs: dict # and the loss_tol_bands tiers). Centered on the midpoint of two H200 CI # runs (CI runs on H200) so the tiered bands keep maximum margin; iter-0 # is bit-exact across H100/H200 runs. grad-norm is non-det, so None. - "llava_ov_datapacker": { + "llava_ov": { "loss": [1.06924, 0.88399, 1.09293, 1.16314, 1.03592, 0.99041, 1.11041, 0.97001, 0.81246, 0.98548], "grad_norm": None, }, From e40b5774287a94a8f403dce4714d50892a8d48f8 Mon Sep 17 00:00:00 2001 From: Liangkai Zhang Date: Fri, 12 Jun 2026 07:17:50 -0700 Subject: [PATCH 08/16] [Cosmos3 OSS]Add more action dataset (#34) 1. Add more action datasets, covering bridge/AgibotWorldBeta/Robomind Franka Dual 2. Add action stats for action datasets. 3. Add action denormalization. 
--- .../data/vfm/action/action_normalization.py | 29 +- .../vfm/action/action_normalization_test.py | 145 ++ cosmos_framework/data/vfm/action/agibot_fk.py | 398 +++++ .../data/vfm/action/agibot_spec.py | 129 ++ .../data/vfm/action/datasets/__init__.py | 19 +- .../agibotworld_beta_lerobot_dataset.py | 277 ++++ .../data/vfm/action/datasets/base_dataset.py | 204 +++ .../datasets/bridge_orig_lerobot_dataset.py | 152 ++ .../action/datasets/droid_lerobot_dataset.py | 163 +- .../datasets/robomind_franka_dataset.py | 192 +++ .../stats/agibotworld_beta_lerobot_stats.json | 4 + .../stats/bridge_orig_lerobot_stats.json | 4 + .../droid_lerobot_stats.json} | 0 .../datasets/stats/robomind_franka_stats.json | 4 + .../G1_omnipicker_calibrated.urdf | 1350 +++++++++++++++++ 15 files changed, 2924 insertions(+), 146 deletions(-) create mode 100644 cosmos_framework/data/vfm/action/action_normalization_test.py create mode 100644 cosmos_framework/data/vfm/action/agibot_fk.py create mode 100644 cosmos_framework/data/vfm/action/agibot_spec.py create mode 100644 cosmos_framework/data/vfm/action/datasets/agibotworld_beta_lerobot_dataset.py create mode 100644 cosmos_framework/data/vfm/action/datasets/base_dataset.py create mode 100644 cosmos_framework/data/vfm/action/datasets/bridge_orig_lerobot_dataset.py create mode 100644 cosmos_framework/data/vfm/action/datasets/robomind_franka_dataset.py create mode 100644 cosmos_framework/data/vfm/action/datasets/stats/agibotworld_beta_lerobot_stats.json create mode 100644 cosmos_framework/data/vfm/action/datasets/stats/bridge_orig_lerobot_stats.json rename cosmos_framework/data/vfm/action/datasets/{droid_lerobot_normalization.json => stats/droid_lerobot_stats.json} (100%) create mode 100644 cosmos_framework/data/vfm/action/datasets/stats/robomind_franka_stats.json create mode 100644 cosmos_framework/data/vfm/action/urdf_visualizer/G1_omnipicker_calibrated.urdf diff --git a/cosmos_framework/data/vfm/action/action_normalization.py 
b/cosmos_framework/data/vfm/action/action_normalization.py index c58bb90..8504cf0 100644 --- a/cosmos_framework/data/vfm/action/action_normalization.py +++ b/cosmos_framework/data/vfm/action/action_normalization.py @@ -12,7 +12,7 @@ from cosmos_framework.utils import log -def load_action_stats(stats_path: str, stats_key: str = "global") -> dict[str, np.ndarray]: +def load_action_stats(stats_path: str) -> dict[str, np.ndarray]: """Load pre-computed action normalization stats from a JSON file.""" path = Path(stats_path) if not path.exists(): @@ -20,12 +20,6 @@ def load_action_stats(stats_path: str, stats_key: str = "global") -> dict[str, n log.info(f"Loading action normalization stats from {stats_path}") with path.open("r") as f: raw = json.load(f) - if stats_key in raw: - raw = raw[stats_key] - if not isinstance(raw, dict): - raise TypeError(f"Action normalization stats block {stats_key!r} in {stats_path} must be a dict.") - elif stats_key != "global": - raise KeyError(f"Action normalization stats block {stats_key!r} not found in {stats_path}.") stat_keys = {"mean", "std", "min", "max", "q01", "q99"} return {key: np.array(value, dtype=np.float32) for key, value in raw.items() if key in stat_keys} @@ -39,11 +33,28 @@ def normalize_action( if method == "quantile": q01, q99 = stats["q01"], stats["q99"] denom = (q99 - q01).clamp(min=1e-8) - return (2.0 * (action - q01) / denom - 1.0).clamp(-1.0, 1.0) + return 2.0 * (action - q01) / denom - 1.0 if method == "meanstd": return (action - stats["mean"]) / stats["std"].clamp(min=1e-8) if method == "minmax": lo, hi = stats["min"], stats["max"] denom = (hi - lo).clamp(min=1e-8) - return (2.0 * (action - lo) / denom - 1.0).clamp(-1.0, 1.0) + return 2.0 * (action - lo) / denom - 1.0 + raise ValueError(f"Unknown normalization method: {method!r}") + + +def denormalize_action( + action: torch.Tensor, + method: str, + stats: dict[str, torch.Tensor], +) -> torch.Tensor: + """Denormalize action tensor.""" + if method == "quantile": + 
q01, q99 = stats["q01"], stats["q99"] + return 0.5 * (action + 1.0) * (q99 - q01) + q01 + if method == "meanstd": + return action * stats["std"] + stats["mean"] + if method == "minmax": + lo, hi = stats["min"], stats["max"] + return 0.5 * (action + 1.0) * (hi - lo) + lo raise ValueError(f"Unknown normalization method: {method!r}") diff --git a/cosmos_framework/data/vfm/action/action_normalization_test.py b/cosmos_framework/data/vfm/action/action_normalization_test.py new file mode 100644 index 0000000..bbfff16 --- /dev/null +++ b/cosmos_framework/data/vfm/action/action_normalization_test.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import json + +import numpy as np +import pytest +import torch + +from cosmos_framework.data.vfm.action.action_normalization import ( + denormalize_action, + load_action_stats, + normalize_action, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_RAW_STATS = { + "mean": [0.0, 1.0, -1.0], + "std": [1.0, 2.0, 0.5], + "min": [-2.0, -1.0, -3.0], + "max": [2.0, 3.0, 1.0], + "q01": [-1.0, 0.0, -2.0], + "q99": [1.0, 2.0, 0.0], +} + + +def _tensor_stats(raw=_RAW_STATS) -> dict[str, torch.Tensor]: + return {k: torch.tensor(v, dtype=torch.float32) for k, v in raw.items()} + + +def _action() -> torch.Tensor: + return torch.tensor([[0.0, 1.0, -1.0], [1.0, 2.0, 0.0]], dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# load_action_stats +# --------------------------------------------------------------------------- + + +def test_load_action_stats_flat(tmp_path): + p = tmp_path / "stats.json" + p.write_text(json.dumps(_RAW_STATS)) + result = load_action_stats(str(p)) + assert set(result) == set(_RAW_STATS) + for key, value in result.items(): + assert 
isinstance(value, np.ndarray) + assert value.dtype == np.float32 + np.testing.assert_array_equal(value, np.array(_RAW_STATS[key], dtype=np.float32)) + + +def test_load_action_stats_filters_unknown_keys(tmp_path): + raw = {**_RAW_STATS, "extra_field": [1.0, 2.0]} + p = tmp_path / "stats.json" + p.write_text(json.dumps(raw)) + result = load_action_stats(str(p)) + assert "extra_field" not in result + + +def test_load_action_stats_missing_file(): + with pytest.raises(FileNotFoundError): + load_action_stats("/nonexistent/path/stats.json") + + +# --------------------------------------------------------------------------- +# normalize_action / denormalize_action — round-trip identity +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("method", ["quantile", "meanstd", "minmax"]) +def test_round_trip(method): + action = _action() + stats = _tensor_stats() + normalized = normalize_action(action, method, stats) + recovered = denormalize_action(normalized, method, stats) + torch.testing.assert_close(recovered, action, atol=1e-5, rtol=1e-5) + + +# --------------------------------------------------------------------------- +# normalize_action — endpoint correctness +# --------------------------------------------------------------------------- + + +def test_normalize_quantile_endpoints(): + stats = _tensor_stats() + q01, q99 = stats["q01"], stats["q99"] + assert torch.allclose(normalize_action(q01.unsqueeze(0), "quantile", stats), torch.full((1, 3), -1.0)) + assert torch.allclose(normalize_action(q99.unsqueeze(0), "quantile", stats), torch.full((1, 3), 1.0)) + + +def test_normalize_minmax_endpoints(): + stats = _tensor_stats() + lo, hi = stats["min"], stats["max"] + assert torch.allclose(normalize_action(lo.unsqueeze(0), "minmax", stats), torch.full((1, 3), -1.0)) + assert torch.allclose(normalize_action(hi.unsqueeze(0), "minmax", stats), torch.full((1, 3), 1.0)) + + +def test_normalize_meanstd_zero_mean(): + stats = 
_tensor_stats() + result = normalize_action(stats["mean"].unsqueeze(0), "meanstd", stats) + assert torch.allclose(result, torch.zeros(1, 3)) + + +# --------------------------------------------------------------------------- +# denormalize_action — endpoint correctness +# --------------------------------------------------------------------------- + + +def test_denormalize_quantile_endpoints(): + stats = _tensor_stats() + q01, q99 = stats["q01"], stats["q99"] + assert torch.allclose(denormalize_action(torch.full((1, 3), -1.0), "quantile", stats), q01.unsqueeze(0)) + assert torch.allclose(denormalize_action(torch.full((1, 3), 1.0), "quantile", stats), q99.unsqueeze(0)) + + +def test_denormalize_minmax_endpoints(): + stats = _tensor_stats() + lo, hi = stats["min"], stats["max"] + assert torch.allclose(denormalize_action(torch.full((1, 3), -1.0), "minmax", stats), lo.unsqueeze(0)) + assert torch.allclose(denormalize_action(torch.full((1, 3), 1.0), "minmax", stats), hi.unsqueeze(0)) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +def test_normalize_zero_range_no_nan(): + stats = {k: torch.zeros(3) for k in ("q01", "q99", "mean", "std", "min", "max")} + action = torch.ones(1, 3) + for method in ("quantile", "meanstd", "minmax"): + result = normalize_action(action, method, stats) + assert torch.isfinite(result).all(), f"{method} produced non-finite output with zero range" + + +def test_normalize_unknown_method_raises(): + with pytest.raises(ValueError, match="Unknown normalization method"): + normalize_action(_action(), "unknown_method", _tensor_stats()) + + +def test_denormalize_unknown_method_raises(): + with pytest.raises(ValueError, match="Unknown normalization method"): + denormalize_action(_action(), "unknown_method", _tensor_stats()) diff --git a/cosmos_framework/data/vfm/action/agibot_fk.py 
b/cosmos_framework/data/vfm/action/agibot_fk.py new file mode 100644 index 0000000..1038284 --- /dev/null +++ b/cosmos_framework/data/vfm/action/agibot_fk.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Lightweight AgiBot World forward kinematics for datasets and viewers.""" + +from __future__ import annotations + +import xml.etree.ElementTree as ET +from functools import lru_cache + +import numpy as np + +from cosmos_framework.data.vfm.action.agibot_spec import ( + AGIBOT_WORLD_ARM_JOINT_NAMES_LEFT, + AGIBOT_WORLD_ARM_JOINT_NAMES_RIGHT, + AGIBOT_WORLD_ARM_STATE_SLICE, + AGIBOT_WORLD_EXT_ARM_STATE_SLICE, + AGIBOT_WORLD_EXT_STATE_HEAD_PITCH_IDX, + AGIBOT_WORLD_EXT_STATE_HEAD_YAW_IDX, + AGIBOT_WORLD_EXT_STATE_LEFT_HAND_SLICE, + AGIBOT_WORLD_EXT_STATE_RIGHT_HAND_SLICE, + AGIBOT_WORLD_EXT_STATE_ROBOT_ORIENTATION_SLICE, + AGIBOT_WORLD_EXT_STATE_ROBOT_POSITION_SLICE, + AGIBOT_WORLD_EXT_STATE_WAIST_LIFT_IDX, + AGIBOT_WORLD_EXT_STATE_WAIST_PITCH_IDX, + AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG, + AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD, + AGIBOT_WORLD_HEAD_CAMERA_LINK_NAME, + AGIBOT_WORLD_HEAD_PITCH_JOINT_NAME, + AGIBOT_WORLD_HEAD_YAW_JOINT_NAME, + AGIBOT_WORLD_LEFT_EE_LINK_NAME, + AGIBOT_WORLD_LEFT_GRIPPER_JOINT_MIMICS, + AGIBOT_WORLD_RIGHT_EE_LINK_NAME, + AGIBOT_WORLD_RIGHT_GRIPPER_JOINT_MIMICS, + AGIBOT_WORLD_STATE_HEAD_PITCH_IDX, + AGIBOT_WORLD_STATE_HEAD_YAW_IDX, + AGIBOT_WORLD_STATE_WAIST_LIFT_IDX, + AGIBOT_WORLD_STATE_WAIST_PITCH_IDX, + AGIBOT_WORLD_WAIST_LIFT_JOINT_NAME, + AGIBOT_WORLD_WAIST_PITCH_JOINT_NAME, + get_agibot_world_embodiment_spec, + get_agibot_world_kind_spec, + get_agibot_world_urdf_path, +) +from cosmos_framework.data.vfm.action.pose_utils import convert_rotation + +_GRIPPER_VALUE_EPS = 1e-4 +_QUATERNION_NORM_EPS = 1e-8 +_GRIPPER_ACTUATOR_OVERSHOOT_DEG = 5.0 +# Main-branch wrist rotations composed with one extra local-Z 180 degree 
rotation. +AGIBOT_WORLD_LEFT_GRIPPER_TO_OPENCV: np.ndarray = np.asarray( + [ + [0.0, 1.0, 0.0], + [-1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, +) +AGIBOT_WORLD_RIGHT_GRIPPER_TO_OPENCV: np.ndarray = np.asarray( + [ + [0.0, -1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, +) +AGIBOT_WORLD_GRIPPER_TO_OPENCV_BY_WRIST: dict[str, np.ndarray] = { + "right_wrist": AGIBOT_WORLD_RIGHT_GRIPPER_TO_OPENCV, + "left_wrist": AGIBOT_WORLD_LEFT_GRIPPER_TO_OPENCV, +} + + +def _scale_to_unit_interval(values: np.ndarray, scale: float) -> np.ndarray: + """Scale non-negative gripper actuator values to ``[0,1]``.""" + + return np.clip(values / scale, 0.0, 1.0).astype(np.float32, copy=False) + + +def _scale_negative_to_unit_interval(values: np.ndarray, scale: float) -> np.ndarray: + """Scale URDF-style negative gripper angles to ``[0,1]`` open fractions.""" + + return np.clip(-values / scale, 0.0, 1.0).astype(np.float32, copy=False) + + +def _normalize_quaternions_xyzw(quaternions: np.ndarray) -> np.ndarray: + """Normalize ``xyzw`` quaternions, treating all-zero rows as identity.""" + + normalized = np.asarray(quaternions, dtype=np.float32).copy() # [T,4] + norms = np.linalg.norm(normalized, axis=-1, keepdims=True) # [T,1] + valid = norms[:, 0] >= _QUATERNION_NORM_EPS # [T] + normalized[valid] = normalized[valid] / norms[valid] # [T_valid,4] + normalized[~valid] = np.asarray([0.0, 0.0, 0.0, 1.0], dtype=np.float32) # [T_invalid,4] + return normalized + + +def _quat_xyzw_to_rotation_matrix(quaternions: np.ndarray) -> np.ndarray: + """Convert ``xyzw`` quaternions to rotation matrices.""" + + normalized = _normalize_quaternions_xyzw(quaternions) # [T,4] + rotations = convert_rotation( + normalized, + input_format="quat_xyzw", + output_format="matrix", + normalize_matrix=True, + ) + return np.asarray(rotations, dtype=np.float32) + + +def build_robot_base_transforms(positions: np.ndarray, quaternions: np.ndarray) -> np.ndarray: + """Build robot-base 
poses from position and ``xyzw`` quaternion arrays.""" + + positions = np.asarray(positions, dtype=np.float32) # [T,3] + quaternions = np.asarray(quaternions, dtype=np.float32) # [T,4] + if positions.ndim != 2 or positions.shape[1] != 3: + raise ValueError(f"robot base positions must have shape [T,3], got {positions.shape}.") + if quaternions.ndim != 2 or quaternions.shape[1] != 4: + raise ValueError(f"robot base quaternions must have shape [T,4], got {quaternions.shape}.") + if positions.shape[0] != quaternions.shape[0]: + raise ValueError( + f"robot base positions/quaternions must share T, got {positions.shape[0]} and {quaternions.shape[0]}." + ) + + transforms = np.tile(np.eye(4, dtype=np.float32), (positions.shape[0], 1, 1)) # [T,4,4] + transforms[:, :3, :3] = _quat_xyzw_to_rotation_matrix(quaternions) # [T,3,3] + transforms[:, :3, 3] = positions # [T,3] + return transforms + + +def _invert_rigid_transform(transform: np.ndarray) -> np.ndarray: + """Invert one homogeneous rigid transform.""" + + inverse = np.eye(4, dtype=np.float32) # [4,4] + rotation_t = transform[:3, :3].T.astype(np.float32, copy=False) # [3,3] + inverse[:3, :3] = rotation_t + inverse[:3, 3] = -(rotation_t @ transform[:3, 3]) # [3] + return inverse + + +def apply_robot_base_motion_to_poses( + poses_by_name: dict[str, np.ndarray], + positions: np.ndarray, + quaternions: np.ndarray, +) -> dict[str, np.ndarray]: + """Apply mobile-base motion to FK poses, normalized to the first frame.""" + + base_poses = build_robot_base_transforms(positions, quaternions) # [T,4,4] + initial_base_inv = _invert_rigid_transform(base_poses[0]) # [4,4] + base_motion = np.einsum("ij,tjk->tik", initial_base_inv, base_poses).astype(np.float32, copy=False) # [T,4,4] + return { + name: np.einsum("tij,tjk->tik", base_motion, poses).astype(np.float32, copy=False) # [T,4,4] + for name, poses in poses_by_name.items() + } + + +def _apply_ext_base_motion_to_poses( + poses_by_name: dict[str, np.ndarray], + states: np.ndarray, + 
embodiment_type: str, +) -> dict[str, np.ndarray]: + """Apply ext mobile-base motion to FK poses, normalized to the first frame.""" + + if embodiment_type != "agibot_world_gripper_ext": + return poses_by_name + if states.shape[1] < AGIBOT_WORLD_EXT_STATE_ROBOT_ORIENTATION_SLICE.stop: + raise ValueError( + f"agibot_world_gripper_ext state must include robot pose through index " + f"{AGIBOT_WORLD_EXT_STATE_ROBOT_ORIENTATION_SLICE.stop - 1}, got shape {states.shape}." + ) + + positions = states[:, AGIBOT_WORLD_EXT_STATE_ROBOT_POSITION_SLICE].astype(np.float32, copy=False) # [T,3] + quaternions = states[:, AGIBOT_WORLD_EXT_STATE_ROBOT_ORIENTATION_SLICE].astype(np.float32, copy=False) # [T,4] + return apply_robot_base_motion_to_poses(poses_by_name, positions, quaternions) + + +def apply_agibot_gripper_to_opencv( + poses_by_name: dict[str, np.ndarray], + to_opencv_by_wrist: dict[str, np.ndarray], +) -> dict[str, np.ndarray]: + """Post-rotate AgiBot gripper wrist poses into OpenCV convention.""" + + aligned = {name: poses.astype(np.float32, copy=True) for name, poses in poses_by_name.items()} # {name:[...,4,4]} + for wrist_name, wrist_to_opencv in to_opencv_by_wrist.items(): + poses = aligned.get(wrist_name) + if poses is None: + continue + aligned[wrist_name][..., :3, :3] = poses[..., :3, :3] @ wrist_to_opencv.astype(poses.dtype) # [...,3,3] + return aligned + + +def _get_agibot_world_mujoco_kinematics_xml() -> str: + """Build a MuJoCo-loadable kinematics-only XML string from the committed URDF.""" + + root = ET.parse(get_agibot_world_urdf_path()).getroot() + mujoco_element = root.find("mujoco") + if mujoco_element is None: + mujoco_element = ET.Element("mujoco") + root.insert(0, mujoco_element) + compiler_element = mujoco_element.find("compiler") + if compiler_element is None: + compiler_element = ET.SubElement(mujoco_element, "compiler") + compiler_element.attrib["fusestatic"] = "false" + + for link_element in root.findall("link"): + for child_element in 
list(link_element): + if child_element.tag in {"visual", "collision"}: + link_element.remove(child_element) + + return ET.tostring(root, encoding="unicode") + + +class _MujocoFk: + """MuJoCo-backed FK engine for the committed AgiBot G1 omnipicker URDF.""" + + def __init__(self) -> None: + import mujoco + + self._mujoco = mujoco + self.model = mujoco.MjModel.from_xml_string(_get_agibot_world_mujoco_kinematics_xml()) + self.data = mujoco.MjData(self.model) + self._joint_qpos_addresses: dict[str, int] = {} + for joint_id in range(self.model.njnt): + joint_name = mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_JOINT, joint_id) + if joint_name is not None: + self._joint_qpos_addresses[joint_name] = int(self.model.jnt_qposadr[joint_id]) + + def link_poses(self, joint_values: dict[str, float]) -> dict[str, np.ndarray]: + """Return world transforms for every named body in the MuJoCo model.""" + + self.data.qpos[:] = 0.0 + for joint_name, joint_value in joint_values.items(): + qpos_address = self._joint_qpos_addresses.get(joint_name) + if qpos_address is not None: + self.data.qpos[qpos_address] = float(joint_value) + self._mujoco.mj_forward(self.model, self.data) + + poses: dict[str, np.ndarray] = {} + for body_id in range(1, self.model.nbody): + body_name = self._mujoco.mj_id2name(self.model, self._mujoco.mjtObj.mjOBJ_BODY, body_id) + if body_name is None: + continue + transform = np.eye(4, dtype=np.float32) + transform[:3, :3] = self.data.xmat[body_id].reshape(3, 3).astype(np.float32, copy=False) + transform[:3, 3] = self.data.xpos[body_id].astype(np.float32, copy=False) + poses[body_name] = transform + return poses + + +@lru_cache(maxsize=1) +def _get_fk_engine() -> _MujocoFk: + """Return a cached MuJoCo FK engine for the committed AgiBot URDF.""" + + return _MujocoFk() + + +def _extract_joint_values_from_state(state: np.ndarray, embodiment_type: str) -> dict[str, float]: + """Map one observation.state vector to the URDF joint names used for FK.""" + + if 
embodiment_type == "agibot_world_gripper_ext": + # Ext layout: 94-dim state with joints at different offsets. + arm_state = state[AGIBOT_WORLD_EXT_ARM_STATE_SLICE] + head_yaw = float(state[AGIBOT_WORLD_EXT_STATE_HEAD_YAW_IDX]) + head_pitch = float(state[AGIBOT_WORLD_EXT_STATE_HEAD_PITCH_IDX]) + waist_lift = float(state[AGIBOT_WORLD_EXT_STATE_WAIST_LIFT_IDX]) + waist_pitch = float(state[AGIBOT_WORLD_EXT_STATE_WAIST_PITCH_IDX]) + else: + arm_state = state[AGIBOT_WORLD_ARM_STATE_SLICE] + head_yaw = float(state[AGIBOT_WORLD_STATE_HEAD_YAW_IDX]) + head_pitch = float(state[AGIBOT_WORLD_STATE_HEAD_PITCH_IDX]) + waist_pitch = float(state[AGIBOT_WORLD_STATE_WAIST_PITCH_IDX]) + waist_lift = float(state[AGIBOT_WORLD_STATE_WAIST_LIFT_IDX]) + + joint_values = { + AGIBOT_WORLD_WAIST_LIFT_JOINT_NAME: float(waist_lift), + AGIBOT_WORLD_WAIST_PITCH_JOINT_NAME: float(waist_pitch), + AGIBOT_WORLD_HEAD_YAW_JOINT_NAME: float(head_yaw), + AGIBOT_WORLD_HEAD_PITCH_JOINT_NAME: float(head_pitch), + } + joint_values.update({name: float(arm_state[idx]) for idx, name in enumerate(AGIBOT_WORLD_ARM_JOINT_NAMES_LEFT)}) + joint_values.update({name: float(arm_state[7 + idx]) for idx, name in enumerate(AGIBOT_WORLD_ARM_JOINT_NAMES_RIGHT)}) + _set_gripper_joint_values_from_state(joint_values, state, embodiment_type) + return joint_values + + +def _set_gripper_joint_values_from_state( + joint_values: dict[str, float], + state: np.ndarray, + embodiment_type: str, +) -> None: + """Map observed scalar gripper state into all omnipicker finger joints.""" + + embodiment_spec = get_agibot_world_embodiment_spec(embodiment_type) + if embodiment_spec.kind != "gripper": + return + + if embodiment_type == "agibot_world_gripper_ext": + left_raw = float(state[AGIBOT_WORLD_EXT_STATE_LEFT_HAND_SLICE][0]) + right_raw = float(state[AGIBOT_WORLD_EXT_STATE_RIGHT_HAND_SLICE][0]) + else: + kind_spec = get_agibot_world_kind_spec(embodiment_type) + state_hand_slice = kind_spec.state_hand_slice + left_raw = 
float(state[state_hand_slice.start]) + right_raw = float(state[state_hand_slice.start + 1]) + + left_open = float(convert_gripper_state_to_open_fraction(np.asarray([left_raw], dtype=np.float32))[0]) # [1] + right_open = float(convert_gripper_state_to_open_fraction(np.asarray([right_raw], dtype=np.float32))[0]) # [1] + for opening, joint_mimics in ( + (left_open, AGIBOT_WORLD_LEFT_GRIPPER_JOINT_MIMICS), + (right_open, AGIBOT_WORLD_RIGHT_GRIPPER_JOINT_MIMICS), + ): + primary_angle = -float(np.clip(opening, 0.0, 1.0)) * AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD + for joint_name, multiplier, offset in joint_mimics: + joint_values[joint_name] = multiplier * primary_angle + offset + + +def compute_fk_transforms( + state: np.ndarray, + embodiment_type: str, +) -> dict[str, np.ndarray]: + """Compute native-frame calibrated head-camera and gripper-base transforms for one state.""" + + fk_engine = _get_fk_engine() + link_poses = fk_engine.link_poses(_extract_joint_values_from_state(state, embodiment_type)) + + return { + "head_camera": link_poses[AGIBOT_WORLD_HEAD_CAMERA_LINK_NAME].astype(np.float32, copy=False), + "right_wrist": link_poses[AGIBOT_WORLD_RIGHT_EE_LINK_NAME].astype(np.float32, copy=False), + "left_wrist": link_poses[AGIBOT_WORLD_LEFT_EE_LINK_NAME].astype(np.float32, copy=False), + } + + +def compute_fk_transforms_batch( + states: np.ndarray, + embodiment_type: str, +) -> dict[str, np.ndarray]: + """Compute absolute transforms for a batch of AgiBot observation states.""" + + num_steps = int(states.shape[0]) + head_camera = np.empty((num_steps, 4, 4), dtype=np.float32) + right_wrist = np.empty((num_steps, 4, 4), dtype=np.float32) + left_wrist = np.empty((num_steps, 4, 4), dtype=np.float32) + + for step in range(num_steps): + transforms = compute_fk_transforms(states[step], embodiment_type) + head_camera[step] = transforms["head_camera"] + right_wrist[step] = transforms["right_wrist"] + left_wrist[step] = transforms["left_wrist"] + + transforms_by_name = { + 
"head_camera": head_camera, + "right_wrist": right_wrist, + "left_wrist": left_wrist, + } + return _apply_ext_base_motion_to_poses(transforms_by_name, states, embodiment_type) + + +def convert_gripper_state_to_open_fraction(values: np.ndarray) -> np.ndarray: + """Convert observed AgiBot gripper state to viewer/dataset open fractions. + + The shared viewer/action convention is ``0=closed`` and ``1=open``. + Observed AgiBot gripper state uses actuator-close angle units: ``0`` is + open and ``120`` is closed. Some episodes contain small closed-state + overshoot above ``120``; those values are accepted and clipped to fully + closed. Small open-state sensor jitter such as ``0.217`` must therefore + remain nearly fully open, not be interpreted as a normalized close fraction. + """ + + values = np.asarray(values, dtype=np.float32) + if values.size == 0: + return values + if not np.isfinite(values).all(): + raise ValueError("AgiBot gripper values contain NaN or Inf values.") + + min_value = float(np.min(values)) + max_value = float(np.max(values)) + if ( + min_value < -_GRIPPER_VALUE_EPS + and min_value >= -AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD - _GRIPPER_VALUE_EPS + and max_value <= _GRIPPER_VALUE_EPS + ): + return _scale_negative_to_unit_interval(values, AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD) + if ( + min_value < -_GRIPPER_VALUE_EPS + and min_value >= -np.degrees(AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD) - _GRIPPER_VALUE_EPS + and max_value <= _GRIPPER_VALUE_EPS + ): + return _scale_negative_to_unit_interval(values, np.degrees(AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD)) + max_actuator_value = AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG + _GRIPPER_ACTUATOR_OVERSHOOT_DEG + if min_value >= -_GRIPPER_VALUE_EPS and max_value <= max_actuator_value + _GRIPPER_VALUE_EPS: + close_fraction = _scale_to_unit_interval(values, AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG) # [*] + return (1.0 - close_fraction).astype(np.float32, copy=False) # [*] + + raise ValueError( + f"Unsupported AgiBot gripper value range; 
min={min_value:.4f}, max={max_value:.4f}. " + f"Expected URDF angle [-pi/4,0] or actuator-close degrees [0,{max_actuator_value:.1f}] " + f"(values above {AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG:.1f} are clipped closed)." + ) + + diff --git a/cosmos_framework/data/vfm/action/agibot_spec.py b/cosmos_framework/data/vfm/action/agibot_spec.py new file mode 100644 index 0000000..0abfe22 --- /dev/null +++ b/cosmos_framework/data/vfm/action/agibot_spec.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Shared AgiBot metadata used by datasets and visualizers.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +AgibotWorldKind = Literal["gripper"] + +AGIBOT_WORLD_URDF_FILENAME = "G1_omnipicker_calibrated.urdf" +AGIBOT_WORLD_ARM_STATE_SLICE = slice(0, 14) +AGIBOT_WORLD_STATE_HEAD_YAW_IDX = 16 +AGIBOT_WORLD_STATE_HEAD_PITCH_IDX = 17 +AGIBOT_WORLD_STATE_WAIST_PITCH_IDX = 18 +AGIBOT_WORLD_STATE_WAIST_LIFT_IDX = 19 +AGIBOT_WORLD_HEAD_PITCH_JOINT_NAME = "idx04_head_pitch_joint" + +# -- Ext layout constants (94-dim state) ------------------------------------- +# The ext split stores joints at different offsets from the standard layout. 
+AGIBOT_WORLD_EXT_ARM_STATE_SLICE = slice(54, 68) +AGIBOT_WORLD_EXT_STATE_HEAD_YAW_IDX = 82 +AGIBOT_WORLD_EXT_STATE_HEAD_PITCH_IDX = 83 +AGIBOT_WORLD_EXT_STATE_WAIST_PITCH_IDX = 84 +AGIBOT_WORLD_EXT_STATE_WAIST_LIFT_IDX = 85 +AGIBOT_WORLD_EXT_STATE_ROBOT_POSITION_SLICE = slice(86, 89) +AGIBOT_WORLD_EXT_STATE_ROBOT_ORIENTATION_SLICE = slice(89, 93) +AGIBOT_WORLD_EXT_STATE_LEFT_HAND_SLICE = slice(0, 1) +AGIBOT_WORLD_EXT_STATE_RIGHT_HAND_SLICE = slice(1, 2) +AGIBOT_WORLD_HEAD_CAMERA_LINK_NAME = "head_camera_link" +AGIBOT_WORLD_LEFT_EE_LINK_NAME = "gripper_l_base_link" +AGIBOT_WORLD_RIGHT_EE_LINK_NAME = "gripper_r_base_link" +AGIBOT_WORLD_ARM_JOINT_NAMES_LEFT = tuple(f"idx{4 + i:02d}_left_arm_joint{i}" for i in range(1, 8)) +AGIBOT_WORLD_ARM_JOINT_NAMES_RIGHT = tuple(f"idx{11 + i:02d}_right_arm_joint{i}" for i in range(1, 8)) +AGIBOT_WORLD_WAIST_LIFT_JOINT_NAME = "idx01_waist_lift_joint" +AGIBOT_WORLD_WAIST_PITCH_JOINT_NAME = "idx02_waist_pitch_joint" +AGIBOT_WORLD_HEAD_YAW_JOINT_NAME = "idx03_head_yaw_joint" +AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD = math.pi / 4.0 +AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG = 120.0 +AGIBOT_WORLD_LEFT_GRIPPER_JOINT_MIMICS = ( + ("idx31_gripper_l_inner_joint1", 1.0, 0.0), + ("idx32_gripper_l_inner_joint3", 0.1, 0.0), + ("idx33_gripper_l_inner_joint4", 0.25, 0.0), + ("idx39_gripper_l_inner_joint0", -0.7, 0.0), + ("idx41_gripper_l_outer_joint1", -1.0, 0.0), + ("idx42_gripper_l_outer_joint3", 0.1, 0.0), + ("idx43_gripper_l_outer_joint4", -0.25, 0.0), + ("idx49_gripper_l_outer_joint0", 0.7, 0.0), +) +AGIBOT_WORLD_RIGHT_GRIPPER_JOINT_MIMICS = ( + ("idx71_gripper_r_inner_joint1", 1.0, 0.0), + ("idx72_gripper_r_inner_joint3", 0.1, 0.0), + ("idx73_gripper_r_inner_joint4", 0.25, 0.0), + ("idx79_gripper_r_inner_joint0", -0.7, 0.0), + ("idx81_gripper_r_outer_joint1", -1.0, 0.0), + ("idx82_gripper_r_outer_joint3", 0.1, 0.0), + ("idx83_gripper_r_outer_joint4", -0.25, 0.0), + ("idx89_gripper_r_outer_joint0", 0.7, 0.0), +) + + +@dataclass(frozen=True) +class 
AgibotWorldKindSpec: + """Layout metadata shared across all embodiments of one hand kind.""" + + kind: AgibotWorldKind + state_hand_slice: slice + + +@dataclass(frozen=True) +class AgibotWorldEmbodimentSpec: + """Per-embodiment metadata shared by training and visualization code.""" + + embodiment_type: str + kind: AgibotWorldKind + + +AGIBOT_WORLD_KIND_SPECS: dict[AgibotWorldKind, AgibotWorldKindSpec] = { + "gripper": AgibotWorldKindSpec( + kind="gripper", + state_hand_slice=slice(14, 16), + ), +} + +AGIBOT_WORLD_EMBODIMENT_SPECS: dict[str, AgibotWorldEmbodimentSpec] = { + "agibot_world_gripper": AgibotWorldEmbodimentSpec( + embodiment_type="agibot_world_gripper", + kind="gripper", + ), + "agibot_world_gripper_ext": AgibotWorldEmbodimentSpec( + embodiment_type="agibot_world_gripper_ext", + kind="gripper", + ), +} + + +def get_agibot_world_embodiment_spec(embodiment_type: str) -> AgibotWorldEmbodimentSpec: + """Return the registered spec for one AgiBot embodiment.""" + + try: + return AGIBOT_WORLD_EMBODIMENT_SPECS[embodiment_type] + except KeyError as exc: + raise ValueError( + f"Unknown AgiBot World embodiment_type={embodiment_type!r}. " + f"Expected one of {sorted(AGIBOT_WORLD_EMBODIMENT_SPECS)}." 
+ ) from exc + + +def get_agibot_world_kind_spec(embodiment_type: str | AgibotWorldKind) -> AgibotWorldKindSpec: + """Resolve an embodiment type or kind to its shared layout metadata.""" + + kind = embodiment_type if embodiment_type in AGIBOT_WORLD_KIND_SPECS else get_agibot_world_kind(embodiment_type) + return AGIBOT_WORLD_KIND_SPECS[kind] + + +def get_agibot_world_kind(embodiment_type: str) -> AgibotWorldKind: + """Return the hand kind used by one AgiBot embodiment.""" + + return get_agibot_world_embodiment_spec(embodiment_type).kind + + +def get_agibot_world_urdf_path() -> Path: + """Return the committed AgiBot G1 omnipicker URDF path.""" + + return Path(__file__).resolve().parent / "urdf_visualizer" / AGIBOT_WORLD_URDF_FILENAME diff --git a/cosmos_framework/data/vfm/action/datasets/__init__.py b/cosmos_framework/data/vfm/action/datasets/__init__.py index 6828c76..0b01e6b 100644 --- a/cosmos_framework/data/vfm/action/datasets/__init__.py +++ b/cosmos_framework/data/vfm/action/datasets/__init__.py @@ -1,8 +1,23 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -"""Minimal Action dataset wrappers.""" +"""Action dataset wrappers for Cosmos Action. +All concrete datasets inherit from :class:`ActionBaseDataset` and expose a +``load_action_stats()`` classmethod for retrieving pre-computed normalization +statistics without instantiating the dataset. 
+""" + +from cosmos_framework.data.vfm.action.datasets.agibotworld_beta_lerobot_dataset import AgiBotWorldBetaLeRobotDataset +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset +from cosmos_framework.data.vfm.action.datasets.bridge_orig_lerobot_dataset import BridgeOrigLeRobotDataset from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset +from cosmos_framework.data.vfm.action.datasets.robomind_franka_dataset import RoboMINDFrankaDataset -__all__ = ["DROIDLeRobotDataset"] +__all__ = [ + "ActionBaseDataset", + "AgiBotWorldBetaLeRobotDataset", + "BridgeOrigLeRobotDataset", + "DROIDLeRobotDataset", + "RoboMINDFrankaDataset", +] diff --git a/cosmos_framework/data/vfm/action/datasets/agibotworld_beta_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/agibotworld_beta_lerobot_dataset.py new file mode 100644 index 0000000..f95feea --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/agibotworld_beta_lerobot_dataset.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +"""AgiBotWorld-Beta LeRobot dataset.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import torch +import torch.nn.functional as F +from lerobot.datasets.video_utils import decode_video_frames + +from cosmos_framework.data.vfm.action.agibot_fk import ( + AGIBOT_WORLD_GRIPPER_TO_OPENCV_BY_WRIST, + apply_agibot_gripper_to_opencv, + apply_robot_base_motion_to_poses, + compute_fk_transforms_batch, + convert_gripper_state_to_open_fraction, +) +from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset +from cosmos_framework.data.vfm.action.pose_utils import pose_abs_to_rel + +PoseConvention = Literal["backward_framewise"] +Viewpoint = Literal["concat_view", "ego_view"] + +_HEAD_KEY = "observation.images.head" +_HAND_LEFT_KEY = "observation.images.hand_left" +_HAND_RIGHT_KEY = "observation.images.hand_right" +_CONCAT_KEY = "observation.images.video_concat_view" + +_EFFECTOR_KEY = "observation.states.effector.position" +_JOINT_KEY = "observation.states.joint.position" +_HEAD_STATE_KEY = "observation.states.head.position" +_WAIST_KEY = "observation.states.waist.position" +_ROBOT_POSITION_KEY = "observation.states.robot.position" +_ROBOT_ORIENTATION_KEY = "observation.states.robot.orientation" + +_NORMALIZER_PATH = Path(__file__).parent / "stats/agibotworld_beta_lerobot_stats.json" + + +def _split_task_for_caption(task: str) -> tuple[str, str]: + ai_caption, separator, debug_caption = task.partition("|") + if not separator: + return task.strip(), "" + return ai_caption.strip(), debug_caption.strip() + + +def _assemble_agibot_world_state( + effector_pos: np.ndarray, + joint_pos: np.ndarray, + head_pos: np.ndarray, + waist_pos: np.ndarray, +) -> np.ndarray: + """Assemble standard 20D gripper state from Beta decomposed 
fields.""" + + body_head = np.stack( + [head_pos[:, 0], head_pos[:, 1], waist_pos[:, 0], waist_pos[:, 1]], + axis=-1, + ) + return np.concatenate([joint_pos, effector_pos, body_head], axis=-1).astype(np.float32, copy=False) + + +def _compute_idle_frames_agibot(action: torch.Tensor) -> int: + """Small local idle-frame helper for the 29D AgiBot FK layout. + + The shared `compute_idle_frames` expects one rotation group after each + position block; AgiBot's action spec has three such groups plus grippers. + For cookbook inference, idle frames are metadata only, so this conservative + implementation marks the initial low-motion streak length. + """ + + if action.numel() == 0: + return 0 + abs_action = action.detach().abs() + motion = torch.cat( + [ + abs_action[:, 0:3], + abs_action[:, 9:12], + abs_action[:, 18:21], + abs_action[:, 18:19].diff(dim=0, prepend=abs_action[0:1, 18:19]), + abs_action[:, 28:29].diff(dim=0, prepend=abs_action[0:1, 28:29]), + ], + dim=-1, + ).amax(dim=-1) + below = motion < 1e-3 + count = 0 + for value in below.tolist(): + if not value: + break + count += 1 + return count + + +class AgiBotWorldBetaLeRobotDataset(ActionBaseDataset): + """AgiBotWorld-Beta dataset with FK-pose 29D actions. + + Action layout matches the AgiBot World gripper normalizer: + + [head_pos+rot6d(9), right_pos+rot6d(9), right_gripper(1), + left_pos+rot6d(9), left_gripper(1)] + + The local cookbook asset provides head, left wrist, and right wrist videos. + By default this wrapper uses `concat_view`: head view on top, left/right + wrist views resized and concatenated on the bottom. 
+ """ + + + def __init__( + self, + root: str, + fps: float = 10.0, + chunk_length: int = 16, + mode: str = "joint", + pose_convention: PoseConvention = "backward_framewise", + tolerance_s: float = 3e-4, + viewpoint: Viewpoint = "concat_view", + action_normalization: str | None = "quantile", + sample_stride: int = 1, + ) -> None: + if viewpoint not in ("concat_view", "ego_view"): + raise NotImplementedError("Supported viewpoints are concat_view and ego_view.") + super().__init__( + root=root, + domain_name="agibotworld", + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=action_normalization, + sample_stride=sample_stride, + ) + self._rows_by_episode: dict[int, list[dict[str, Any]]] = {} + for row in self._rows: + self._rows_by_episode.setdefault(int(row["episode_index"]), []).append(row) + self._timestamps_by_episode = { + episode_id: np.asarray([float(row["timestamp"]) for row in rows], dtype=np.float64) + for episode_id, rows in self._rows_by_episode.items() + } + + @property + def action_dim(self) -> int: + return 29 + + def _action_spec(self) -> ActionSpec: + return build_action_spec( + Pos(prefix="head"), + Rot("rot6d", prefix="head"), + Pos(prefix="right"), + Rot("rot6d", prefix="right"), + Gripper(prefix="right"), + Pos(prefix="left"), + Rot("rot6d", prefix="left"), + Gripper(prefix="left"), + ) + + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH + + def _compute_idle_frames(self, action: torch.Tensor) -> int: + return _compute_idle_frames_agibot(action) + + def __len__(self) -> int: + return max(0, (len(self._rows) - self._chunk_length + self._sample_stride - 1) // self._sample_stride) + + def __getitem__(self, idx: int) -> dict[str, Any]: + mode = self._choose_mode() + row_idx = int(idx) * self._sample_stride + start_row = self._rows[row_idx] + observation_rows = self._select_observation_rows(start_row) + episode = 
self._episodes[int(observation_rows[0]["episode_index"])] + task = self._tasks[int(observation_rows[0]["task_index"])] + ai_caption, debug_caption = _split_task_for_caption(task) + + video = self._load_video(episode, observation_rows) + action, extras = self._build_fk_action(observation_rows) + if self._viewpoint == "concat_view": + extras["additional_view_description"] = ( + "The top row shows the head-mounted camera view looking down at the workspace. " + "The bottom row contains two horizontally concatenated wrist-mounted camera views: " + "the left hand camera on the left and the right hand camera on the right." + ) + if debug_caption: + extras["debug_caption"] = debug_caption + + return self._build_result( + mode=mode, + video=video, + action=action, + ai_caption=ai_caption, + action_spec_names=self.action_names, + **extras, + ) + + def _select_observation_rows(self, start_row: dict[str, Any]) -> list[dict[str, Any]]: + """Select T+1 rows at this wrapper's target FPS within one episode.""" + + episode_id = int(start_row["episode_index"]) + rows = self._rows_by_episode[episode_id] + timestamps = self._timestamps_by_episode[episode_id] + start_frame = int(start_row["frame_index"]) + start_ts = float(start_row["timestamp"]) + target_ts = start_ts + np.arange(self._chunk_length + 1, dtype=np.float64) / self._fps + indices = np.searchsorted(timestamps, target_ts, side="left") + indices = np.minimum(indices, len(rows) - 1) + prev = np.maximum(indices - 1, 0) + choose_prev = np.abs(timestamps[prev] - target_ts) <= np.abs(timestamps[indices] - target_ts) + indices = np.where(choose_prev, prev, indices) + if int(indices[-1]) <= start_frame: + raise IndexError(f"Could not select {self._chunk_length + 1} frames from episode {episode_id} at fps={self._fps}.") + return [rows[int(i)] for i in indices] + + def _load_video(self, episode: dict[str, Any], observation_rows: list[dict[str, Any]]) -> torch.Tensor: + if self._viewpoint == "ego_view": + return 
self._load_video_key(episode, observation_rows, _HEAD_KEY) + + # Prefer a pre-rendered concat view if present. The local asset includes + # metadata for this key but not the public mp4, so the fallback composes + # it from the three camera streams. + concat_path = self._video_path(episode, _CONCAT_KEY) + if concat_path.exists(): + return self._load_video_key(episode, observation_rows, _CONCAT_KEY) + top = self._load_video_key(episode, observation_rows, _HEAD_KEY) + left = self._load_video_key(episode, observation_rows, _HAND_LEFT_KEY) + right = self._load_video_key(episode, observation_rows, _HAND_RIGHT_KEY) + return self._compose_multi_view(top, left, right) + + def _load_video_key(self, episode: dict[str, Any], observation_rows: list[dict[str, Any]], key: str) -> torch.Tensor: + timestamps = [float(row["timestamp"]) for row in observation_rows] + return decode_video_frames( + self._video_path(episode, key), + [float(episode.get(f"videos/{key}/from_timestamp", 0.0)) + ts for ts in timestamps], + self._tolerance_s, + ) + + def _compose_multi_view(self, top: torch.Tensor, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: + # Inputs are [T,C,H,W] float tensors in [0,1]. 
+ _, _, h_top, w_top = top.shape + half_h, half_w = h_top // 2, w_top // 2 + left = F.interpolate(left, size=(half_h, half_w), mode="bilinear", align_corners=False) + right = F.interpolate(right, size=(half_h, half_w), mode="bilinear", align_corners=False) + bottom = torch.cat([left, right], dim=-1) + return torch.cat([top, bottom], dim=-2) + + def _build_fk_action(self, rows: list[dict[str, Any]]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + effector_pos = np.asarray([row[_EFFECTOR_KEY] for row in rows], dtype=np.float32) + joint_pos = np.asarray([row[_JOINT_KEY] for row in rows], dtype=np.float32) + head_pos = np.asarray([row[_HEAD_STATE_KEY] for row in rows], dtype=np.float32) + waist_pos = np.asarray([row[_WAIST_KEY] for row in rows], dtype=np.float32) + robot_pos = np.asarray([row[_ROBOT_POSITION_KEY] for row in rows], dtype=np.float32) + robot_quat = np.asarray([row[_ROBOT_ORIENTATION_KEY] for row in rows], dtype=np.float32) + states_np = _assemble_agibot_world_state(effector_pos, joint_pos, head_pos, waist_pos) + + native_fk = compute_fk_transforms_batch(states_np, "agibot_world_gripper") + native_fk = apply_robot_base_motion_to_poses(native_fk, robot_pos, robot_quat) + fk = apply_agibot_gripper_to_opencv(native_fk, AGIBOT_WORLD_GRIPPER_TO_OPENCV_BY_WRIST) + + head_rel = pose_abs_to_rel(fk["head_camera"], rotation_format="rot6d", pose_convention=self._pose_convention) + right_rel = pose_abs_to_rel(fk["right_wrist"], rotation_format="rot6d", pose_convention=self._pose_convention) + left_rel = pose_abs_to_rel(fk["left_wrist"], rotation_format="rot6d", pose_convention=self._pose_convention) + right_gripper = convert_gripper_state_to_open_fraction(effector_pos[1:, 1:2]) + left_gripper = convert_gripper_state_to_open_fraction(effector_pos[1:, 0:1]) + action_np = np.concatenate([head_rel, right_rel, right_gripper, left_rel, left_gripper], axis=-1).astype( + np.float32 + ) + extras = { + "initial_pose": torch.from_numpy(fk["head_camera"][0].copy()).float(), + 
"initial_pose_right": torch.from_numpy(fk["right_wrist"][0].copy()).float(), + "initial_pose_left": torch.from_numpy(fk["left_wrist"][0].copy()).float(), + } + return torch.from_numpy(action_np).float(), extras diff --git a/cosmos_framework/data/vfm/action/datasets/base_dataset.py b/cosmos_framework/data/vfm/action/datasets/base_dataset.py new file mode 100644 index 0000000..564d48e --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/base_dataset.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Abstract base class for Action LeRobot datasets.""" + +from __future__ import annotations + +import json +import random +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import numpy as np +import pyarrow.parquet as pq +import torch +from torch.utils.data import Dataset + +from cosmos_framework.data.vfm.action.action_normalization import load_action_stats, normalize_action +from cosmos_framework.data.vfm.action.action_spec import ActionSpec +from cosmos_framework.data.vfm.action.domain_utils import get_domain_id +from cosmos_framework.data.vfm.action.pose_utils import compute_idle_frames + +_MODE_CHOICES = ("forward_dynamics", "inverse_dynamics", "policy") + + +class ActionBaseDataset(ABC, Dataset): + """Abstract base for Action LeRobot datasets. + + Subclasses must implement the abstract methods listed below. 
+ """ + + def __init__( + self, + root: str, + domain_name: str, + fps: float, + chunk_length: int, + mode: str, + pose_convention: str, + tolerance_s: float, + viewpoint: str, + action_normalization: str | None = "quantile", + sample_stride: int = 1, + ) -> None: + super().__init__() + if pose_convention != "backward_framewise": + raise NotImplementedError(f"{type(self).__name__} only supports backward_framewise pose deltas.") + + self._fps = float(fps) + self._dt = 1.0 / self._fps + self._chunk_length = int(chunk_length) + self._sample_stride = int(sample_stride) + if self._sample_stride < 1: + raise ValueError(f"sample_stride must be >= 1, got {self._sample_stride}") + self._mode = mode + self._pose_convention = pose_convention + self._tolerance_s = float(tolerance_s) + self._viewpoint = viewpoint + self._domain_id = get_domain_id(domain_name) + self._action_normalization = action_normalization + self._norm_stats: dict[str, torch.Tensor] | None = None + + self._root = Path(root) + self._info = json.loads((self._root / "meta" / "info.json").read_text()) + self._episodes = { + int(row["episode_index"]): row + for path in sorted((self._root / "meta" / "episodes").glob("chunk-*/file-*.parquet")) + for row in pq.read_table(path).to_pylist() + } + self._tasks = { + int(row["task_index"]): str(row["task"]) + for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist() + } + self._rows = sorted( + ( + row + for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")) + for row in pq.read_table(path).to_pylist() + ), + key=lambda row: int(row["index"]), + ) + + @property + def fps(self) -> float: + return self._fps + + @property + def chunk_length(self) -> int: + return self._chunk_length + + @property + def mode(self) -> str: + return self._mode + + @mode.setter + def mode(self, value: str) -> None: + self._mode = value + + @property + def domain_id(self) -> int: + return self._domain_id + + @property + def action_normalization(self) -> str: + 
return self._action_normalization + + @property + @abstractmethod + def action_dim(self) -> int: ... + + @abstractmethod + def _action_spec(self) -> ActionSpec: ... + + @property + def action_names(self) -> list[str]: + return self._action_spec().names + + @classmethod + @abstractmethod + def _stats_path(cls) -> Path: + """Return the path to the stats JSON file for this dataset.""" + ... + + @classmethod + def load_action_stats(cls) -> dict[str, torch.Tensor]: + """Return action normalization stats for this dataset as torch tensors.""" + return { + key: torch.from_numpy(value).float() + for key, value in load_action_stats(str(cls._stats_path())).items() + } + + @abstractmethod + def __getitem__(self, idx: int) -> dict[str, Any]: ... + + def _compute_idle_frames(self, action: torch.Tensor) -> int: + return compute_idle_frames( + action, + self._action_spec(), + eps_t=5e-3 / self._fps, + eps_r=np.deg2rad(1.5) / self._fps, + eps_g=1e-2, + joint_threshold=5e-3 / self._fps, + min_streak=3, + ) + + def _choose_mode(self) -> str: + if self._mode == "joint": + return random.choice(_MODE_CHOICES) + return self._mode + + def _video_path(self, episode: dict[str, Any], video_key: str) -> Path: + chunk_idx = int( + episode.get( + f"videos/{video_key}/chunk_index", + episode.get(f"videos/{video_key}/episode_chunk", episode.get("data/chunk_index", 0)), + ) + ) + file_idx = int( + episode.get( + f"videos/{video_key}/file_index", + episode.get(f"videos/{video_key}/episode_file", episode.get("data/file_index", 0)), + ) + ) + rel = self._info["video_path"].format( + video_key=video_key, + chunk_index=chunk_idx, + file_index=file_idx, + episode_chunk=chunk_idx, + episode_file=file_idx, + ) + return self._root / rel + + def _load_norm_stats(self) -> dict[str, torch.Tensor]: + if self._norm_stats is None: + self._norm_stats = self.load_action_stats() + return self._norm_stats + + def _build_result( + self, + *, + mode: str, + video: torch.Tensor, + action: torch.Tensor, + ai_caption: 
str, + **extras: Any, + ) -> dict[str, Any]: + idle_frames = self._compute_idle_frames(action) + normalized_action = normalize_action(action, self.action_normalization, self._load_norm_stats()) + formatted_video = (video * 255.0).clamp(0.0, 255.0).to(torch.uint8).permute(1, 0, 2, 3) + return { + "ai_caption": ai_caption, + "video": formatted_video, + "action": normalized_action, + "conditioning_fps": torch.tensor(self._fps, dtype=torch.long), + "mode": mode, + "domain_id": torch.tensor(self._domain_id, dtype=torch.long), + "viewpoint": self._viewpoint, + "idle_frames": torch.tensor(idle_frames, dtype=torch.long), + **extras, + } + + def __len__(self) -> int: + return max(0, (len(self._rows) - self._chunk_length + self._sample_stride - 1) // self._sample_stride) diff --git a/cosmos_framework/data/vfm/action/datasets/bridge_orig_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/bridge_orig_lerobot_dataset.py new file mode 100644 index 0000000..5992ce6 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/bridge_orig_lerobot_dataset.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +"""Bridge Orig LeRobot dataset.""" + +from __future__ import annotations + +import random +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import torch +from lerobot.datasets.video_utils import decode_video_frames + +from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset +from cosmos_framework.data.vfm.action.pose_utils import ( + build_abs_pose_from_components, + pose_abs_to_rel, +) + +PoseConvention = Literal["backward_framewise"] +Viewpoint = Literal["ego_view"] + +_IMAGE_FEATURE = "observation.images.image_0" +_STATE_FEATURE = "observation.state" +_ACTION_FEATURE = "action" + +# Raw Bridge state -> kinematics frame. The WidowX controller records +# R_state = R_fk @ DEFAULT_ROTATION.T, so R_fk = R_state @ DEFAULT_ROTATION. +_DEFAULT_ROTATION = np.array( + [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]], + dtype=np.float32, +) + +# Kinematics frame -> OpenCV frame used by Cosmos action. +_BRIDGE_TO_OPENCV = np.array( + [[0.0, 0.0, 1.0], [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + dtype=np.float32, +) + +# Re-reference from ee_gripper_link to gripper_link in the kinematics frame. +_TCP_TO_FLANGE = np.array( + [ + [1.0, 0.0, 0.0, -0.093575], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, +) + +_NORMALIZER_PATH = Path(__file__).parent / "stats/bridge_orig_lerobot_stats.json" + + +class BridgeOrigLeRobotDataset(ActionBaseDataset): + """Bridge Orig dataset with 10D cartesian actions: + + [pos_delta(3), rot6d_delta(6), gripper(1)] + + Uses a single ``image_0`` ego-view video, backward-framewise rot6d actions, + and quantile normalization. 
+ """ + + + def __init__( + self, + root: str, + fps: float = 5.0, + chunk_length: int = 16, + mode: str = "joint", + pose_convention: PoseConvention = "backward_framewise", + tolerance_s: float = 1e-4, + viewpoint: Viewpoint = "ego_view", + action_normalization: str | None = "quantile", + sample_stride: int = 1, + ) -> None: + if viewpoint != "ego_view": + raise NotImplementedError("This minimal Bridge dataset only supports ego_view.") + super().__init__( + root=root, + domain_name="bridge_orig_lerobot", + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=action_normalization, + sample_stride=sample_stride, + ) + + @property + def action_dim(self) -> int: + return 10 + + def _action_spec(self) -> ActionSpec: + return build_action_spec(Pos(), Rot("rot6d"), Gripper()) + + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH + + def __getitem__(self, idx: int) -> dict[str, Any]: + mode = self._choose_mode() + idx = int(idx) + first_row = self._rows[idx] + episode = self._episodes[int(first_row["episode_index"])] + + row_idx = idx * self._sample_stride + observation_rows = self._rows[row_idx : row_idx + self._chunk_length + 1] + action_rows = observation_rows[: self._chunk_length] + + video = self._load_video(episode, observation_rows) + raw_action, initial_pose = self._build_raw_action(observation_rows, action_rows) + task = self._tasks[int(observation_rows[0]["task_index"])] + ai_caption = random.choice([part.strip() for part in task.split(" | ") if part.strip()] or [task]) + + return self._build_result( + mode=mode, + video=video, + action=raw_action, + ai_caption=ai_caption, + initial_pose=initial_pose, + ) + + def _load_video(self, episode: dict[str, Any], observation_rows: list[dict[str, Any]]) -> torch.Tensor: + timestamps = [float(row["timestamp"]) for row in observation_rows] + return decode_video_frames( + self._video_path(episode, 
_IMAGE_FEATURE), + [float(episode.get(f"videos/{_IMAGE_FEATURE}/from_timestamp", 0.0)) + ts for ts in timestamps], + self._tolerance_s, + ) + + def _build_raw_action( + self, + observation_rows: list[dict[str, Any]], + action_rows: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor]: + state = np.asarray([row[_STATE_FEATURE] for row in observation_rows], dtype=np.float32) + poses_abs = build_abs_pose_from_components(state[:, 0:3], state[:, 3:6], "euler_xyz") + + poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ _DEFAULT_ROTATION.astype(poses_abs.dtype) + poses_abs = poses_abs @ _TCP_TO_FLANGE.astype(poses_abs.dtype) + poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ _BRIDGE_TO_OPENCV.astype(poses_abs.dtype) + + initial_pose = torch.from_numpy(poses_abs[0].copy()).float() + poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention) + gripper = np.asarray([row[_ACTION_FEATURE][6] for row in action_rows], dtype=np.float32).reshape(-1, 1) + action = np.concatenate([poses_rel[-self._chunk_length :], gripper[-self._chunk_length :]], axis=-1) + return torch.from_numpy(action).float(), initial_pose diff --git a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py index 204df69..631f1e9 100644 --- a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py +++ b/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py @@ -16,14 +16,11 @@ import torch.nn.functional as F import torchvision.transforms as T from lerobot.datasets.video_utils import decode_video_frames -from torch.utils.data import Dataset -from cosmos_framework.data.vfm.action.action_normalization import load_action_stats, normalize_action -from cosmos_framework.data.vfm.action.action_spec import Gripper, Joint, Pos, Rot, build_action_spec -from cosmos_framework.data.vfm.action.domain_utils import get_domain_id +from cosmos_framework.data.vfm.action.action_spec 
import ActionSpec, Gripper, Joint, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset from cosmos_framework.data.vfm.action.pose_utils import ( build_abs_pose_from_components, - compute_idle_frames, pose_abs_to_rel, ) @@ -55,11 +52,10 @@ dtype=np.float32, ) -_NORMALIZER_PATH = Path(__file__).parent / "droid_lerobot_normalization.json" -_MODE_CHOICES = ("forward_dynamics", "inverse_dynamics", "policy") +_NORMALIZER_PATH = Path(__file__).parent / "stats/droid_lerobot_stats.json" -class DROIDLeRobotDataset(Dataset): +class DROIDLeRobotDataset(ActionBaseDataset): """DROID Action dataset. Two action layouts: @@ -75,7 +71,7 @@ class DROIDLeRobotDataset(Dataset): def __init__( self, - root: str = "/path/to/cosmos3_action_datasets/droid_plus_lerobot_640x360_20260412", + root: str, fps: float = 15.0, chunk_length: int = 16, mode: str = "joint", @@ -89,23 +85,28 @@ def __init__( use_filter_dict: bool = False, filter_dict_path: str | None = None, ) -> None: - super().__init__() - if pose_convention != "backward_framewise": - raise NotImplementedError("This minimal DROID dataset only supports backward_framewise pose deltas.") if viewpoint != "concat_view": raise NotImplementedError("This minimal DROID dataset only supports concat_view.") if action_space not in _ACTION_SPACES: raise NotImplementedError(f"action_space must be one of {_ACTION_SPACES}, got {action_space!r}.") if use_state and action_space != "joint_pos": raise NotImplementedError("use_state is only supported with action_space='joint_pos'.") + if use_filter_dict and not filter_dict_path: + raise ValueError("use_filter_dict=True requires filter_dict_path") + + # joint_pos uses raw joint values — disable normalization at the base level. 
+ super().__init__( + root=root, + domain_name="droid_lerobot", + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=None if action_space == "joint_pos" else action_normalization, + ) - self._fps = float(fps) - self._dt = 1.0 / self._fps - self._chunk_length = int(chunk_length) - self._mode = mode - self._pose_convention = pose_convention - self._tolerance_s = float(tolerance_s) - self._viewpoint = viewpoint self._action_space = action_space self._use_state = bool(use_state) # Per-sample image augmentation (random crop+rescale + color jitter), applied @@ -117,25 +118,7 @@ def __init__( # keep-ranges JSON is supplied via filter_dict_path (an internal data artifact). self._use_filter_dict = bool(use_filter_dict) self._filter_dict_path = filter_dict_path - if self._use_filter_dict and not self._filter_dict_path: - raise ValueError("use_filter_dict=True requires filter_dict_path") - # joint_pos trains on raw 8D joint values (the internal canonical run - # leaves action_normalization=None); ee_pose keeps quantile normalization. - self._action_normalization = None if action_space == "joint_pos" else action_normalization - self._domain_id = get_domain_id("droid_lerobot") - self._norm_stats: dict[str, torch.Tensor] | None = None - - self._root = Path(root) - self._info = json.loads((self._root / "meta" / "info.json").read_text()) - self._episodes = { - int(row["episode_index"]): row - for path in sorted((self._root / "meta" / "episodes").glob("chunk-*/file-*.parquet")) - for row in pq.read_table(path).to_pylist() - } - self._tasks = { - int(row["task_index"]): str(row["task"]) - for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist() - } + # Compact, lazy frame index. Materializing every frame as a Python dict # (``sorted(... 
pq.read_table(path).to_pylist() ...)``) does not scale: # the full DROID success shard is ~18M frames, which is tens of GB of @@ -215,43 +198,18 @@ def __init__( self._seg_win_start = np.asarray(seg_win_start, dtype=np.int64) self._seg_cum = np.cumsum(seg_len).astype(np.int64) if seg_len else np.zeros(0, dtype=np.int64) - @property - def fps(self) -> float: - return self._fps - - @property - def chunk_length(self) -> int: - return self._chunk_length - - @property - def mode(self) -> str: - return self._mode - - @mode.setter - def mode(self, value: str) -> None: - self._mode = value - - @property - def domain_id(self) -> int: - return self._domain_id - @property def action_dim(self) -> int: return 8 if self._action_space == "joint_pos" else 10 - def _action_spec(self): + def _action_spec(self) -> ActionSpec: if self._action_space == "joint_pos": return build_action_spec(Joint(n=7, label="joint"), Gripper()) return build_action_spec(Pos(), Rot("rot6d"), Gripper()) - @property - def action_names(self) -> list[str]: - return self._action_spec().names - - def _choose_mode(self) -> str: - if self._mode == "joint": - return random.choice(_MODE_CHOICES) - return self._mode + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH def _window_rows(self, start: int, stop: int, episode_index: int) -> list[dict[str, Any]]: """Reconstruct the per-frame dicts the sample builder consumes for the @@ -371,28 +329,6 @@ def _load_concat_video( bottom = torch.cat([left, right], dim=-1) return torch.cat([wrist, bottom], dim=-2) - def _video_path(self, episode: dict[str, Any], video_key: str) -> Path: - chunk_idx = int( - episode.get( - f"videos/{video_key}/chunk_index", - episode.get(f"videos/{video_key}/episode_chunk", episode.get("data/chunk_index", 0)), - ) - ) - file_idx = int( - episode.get( - f"videos/{video_key}/file_index", - episode.get(f"videos/{video_key}/episode_file", episode.get("data/file_index", 0)), - ) - ) - rel = self._info["video_path"].format( - 
video_key=video_key, - chunk_index=chunk_idx, - file_index=file_idx, - episode_chunk=chunk_idx, - episode_file=file_idx, - ) - return self._root / rel - def _build_raw_action( self, observation_rows: list[dict[str, Any]], @@ -404,56 +340,13 @@ def _build_raw_action( initial_pose = torch.from_numpy(poses_abs[0].copy()).float() poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention) - gripper = np.asarray([row["action.gripper_position"] for row in action_rows], dtype=np.float32).reshape(-1, 1) + gripper = np.asarray( + [row[_ACTION_GRIPPER_FEATURE] for row in action_rows], dtype=np.float32 + ).reshape(-1, 1) gripper = 1.0 - gripper action = np.concatenate([poses_rel[-self._chunk_length :], gripper[-self._chunk_length :]], axis=-1) return torch.from_numpy(action).float(), initial_pose - def _build_result( - self, - *, - mode: str, - video: torch.Tensor, - action: torch.Tensor, - ai_caption: str, - **extras: Any, - ) -> dict[str, Any]: - spec = self._action_spec() - idle_frames = compute_idle_frames( - action, - spec, - eps_t=5e-3 / self._fps, - eps_r=np.deg2rad(1.5) / self._fps, - eps_g=1e-2, - joint_threshold=5e-3 / self._fps, - min_streak=3, - ) - if self._action_normalization is None: - out_action = action - else: - out_action = normalize_action(action, self._action_normalization, self._load_norm_stats()) - formatted_video = (video * 255.0).clamp(0.0, 255.0).to(torch.uint8).permute(1, 0, 2, 3) - return { - "ai_caption": ai_caption, - "video": formatted_video, - "action": out_action, - "conditioning_fps": torch.tensor(self._fps, dtype=torch.long), - "mode": mode, - "domain_id": torch.tensor(self._domain_id, dtype=torch.long), - "viewpoint": self._viewpoint, - "idle_frames": torch.tensor(idle_frames, dtype=torch.long), - **extras, - } - - def _load_norm_stats(self) -> dict[str, torch.Tensor]: - if self._norm_stats is not None: - return self._norm_stats - self._norm_stats = { - key: torch.from_numpy(value).float() - for 
key, value in load_action_stats(str(_NORMALIZER_PATH)).items() - } - return self._norm_stats - def __len__(self) -> int: if self._use_filter_dict: return int(self._seg_cum[-1]) if self._seg_cum.size else 0 diff --git a/cosmos_framework/data/vfm/action/datasets/robomind_franka_dataset.py b/cosmos_framework/data/vfm/action/datasets/robomind_franka_dataset.py new file mode 100644 index 0000000..136cd6c --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/robomind_franka_dataset.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""RoboMIND Franka LeRobot dataset.""" + +from __future__ import annotations + +import random +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import torch +import torch.nn.functional as F +from lerobot.datasets.video_utils import decode_video_frames + +from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset +from cosmos_framework.data.vfm.action.pose_utils import ( + build_abs_pose_from_components, + pose_abs_to_rel, +) + +PoseConvention = Literal["backward_framewise"] +Viewpoint = Literal["concat_view"] + +_IMAGE_FEATURES = { + "front": "observation.images.camera_front", + "left": "observation.images.camera_left", + "right": "observation.images.camera_right", +} +_STATE_FEATURE = "observation.states.end_effector" +_ACTION_FEATURE = "actions.joint_position" + +# 90-degree clockwise rotation about the Z axis in the local frame. This matches +# the production RoboMIND Franka wrapper conversion to OpenCV coordinates. 
+_ROBOMIND_FRANKA_TO_OPENCV: np.ndarray = np.array( + [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + dtype=np.float32, +) + +_NORMALIZER_PATH = Path(__file__).parent / "stats/robomind_franka_stats.json" + + +def _dual_arm_action_spec(): + return build_action_spec( + Pos(prefix="left"), + Rot("rot6d", prefix="left"), + Gripper(prefix="left"), + Pos(prefix="right"), + Rot("rot6d", prefix="right"), + Gripper(prefix="right"), + ) + + +class RoboMINDFrankaDataset(ActionBaseDataset): + """RoboMIND Franka dual-arm dataset with 20D cartesian actions:: + + [left_pos_delta(3), left_rot6d_delta(6), left_gripper(1), + right_pos_delta(3), right_rot6d_delta(6), right_gripper(1)] + + Single-arm shards, split/filter logic, image augmentation, fast + initialization, and alternate viewpoints are omitted. + """ + + + def __init__( + self, + root: str, + fps: float = 10.0, + chunk_length: int = 16, + mode: str = "joint", + embodiment_type: str = "robomind-franka-dual", + pose_convention: PoseConvention = "backward_framewise", + tolerance_s: float = 1e-4, + viewpoint: Viewpoint = "concat_view", + action_normalization: str | None = "quantile", + sample_stride: int = 1, + ) -> None: + if embodiment_type != "robomind-franka-dual": + raise NotImplementedError("This minimal RoboMIND dataset only supports robomind-franka-dual.") + if viewpoint != "concat_view": + raise NotImplementedError("This minimal RoboMIND dataset only supports concat_view.") + self._embodiment_type = embodiment_type + super().__init__( + root=root, + domain_name=embodiment_type, + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=action_normalization, + sample_stride=sample_stride, + ) + + @property + def action_dim(self) -> int: + return 20 + + def _action_spec(self) -> ActionSpec: + return _dual_arm_action_spec() + + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH + + def 
__getitem__(self, idx: int) -> dict[str, Any]: + mode = self._choose_mode() + idx = int(idx) + first_row = self._rows[idx] + episode = self._episodes[int(first_row["episode_index"])] + + row_idx = idx * self._sample_stride + observation_rows = self._rows[row_idx : row_idx + self._chunk_length + 1] + action_rows = observation_rows[: self._chunk_length] + + video = self._load_concat_video(episode, observation_rows) + raw_action, initial_pose_left, initial_pose_right = self._build_raw_action(observation_rows, action_rows) + task = self._tasks[int(observation_rows[0]["task_index"])] + ai_caption = random.choice([part.strip() for part in task.split(" | ") if part.strip()] or [task]) + + return self._build_result( + mode=mode, + video=video, + action=raw_action, + ai_caption=ai_caption, + initial_pose=initial_pose_left, + initial_pose_right=initial_pose_right, + additional_view_description=( + "The top row shows a third-person perspective looking towards the dual-arm Franka robot from the front. " + "The bottom-left view looks at the scene from the left side, and the bottom-right view looks at the scene from the right side." 
+ ), + ) + + def _load_concat_video( + self, + episode: dict[str, Any], + observation_rows: list[dict[str, Any]], + ) -> torch.Tensor: + timestamps = [float(row["timestamp"]) for row in observation_rows] + frames_by_view = { + name: decode_video_frames( + self._video_path(episode, video_key), + [float(episode.get(f"videos/{video_key}/from_timestamp", 0.0)) + ts for ts in timestamps], + self._tolerance_s, + ) + for name, video_key in _IMAGE_FEATURES.items() + } + + front = frames_by_view["front"] + left = frames_by_view["left"] + right = frames_by_view["right"] + _, _, h_front, w_front = front.shape + half_h, half_w = h_front // 2, w_front // 2 + left = F.interpolate(left, size=(half_h, half_w), mode="bilinear", align_corners=False) + right = F.interpolate(right, size=(half_h, half_w), mode="bilinear", align_corners=False) + bottom = torch.cat([left, right], dim=-1) + return torch.cat([front, bottom], dim=-2) + + def _build_relative_poses( + self, + positions: np.ndarray, + euler_xyz: np.ndarray, + ) -> tuple[np.ndarray, torch.Tensor]: + poses_abs = build_abs_pose_from_components(positions, euler_xyz, "euler_xyz") + poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ _ROBOMIND_FRANKA_TO_OPENCV + initial_pose = torch.from_numpy(poses_abs[0].copy()).float() + poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention) + return poses_rel, initial_pose + + def _build_raw_action( + self, + observation_rows: list[dict[str, Any]], + action_rows: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + state = np.asarray([row[_STATE_FEATURE] for row in observation_rows], dtype=np.float32) + gripper = np.asarray([row[_ACTION_FEATURE] for row in action_rows], dtype=np.float32) + + poses_rel_left, initial_pose_left = self._build_relative_poses(state[:, 0:3], state[:, 3:6]) + poses_rel_right, initial_pose_right = self._build_relative_poses(state[:, 6:9], state[:, 9:12]) + action = np.concatenate( + [ + 
poses_rel_left[-self._chunk_length :], + 1.0 - gripper[-self._chunk_length :, [7]], + poses_rel_right[-self._chunk_length :], + 1.0 - gripper[-self._chunk_length :, [15]], + ], + axis=-1, + ) + return torch.from_numpy(action).float(), initial_pose_left, initial_pose_right diff --git a/cosmos_framework/data/vfm/action/datasets/stats/agibotworld_beta_lerobot_stats.json b/cosmos_framework/data/vfm/action/datasets/stats/agibotworld_beta_lerobot_stats.json new file mode 100644 index 0000000..970ac30 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/stats/agibotworld_beta_lerobot_stats.json @@ -0,0 +1,4 @@ +{ + "q01": [-0.000167, -0.007272, -0.014935, 0.999999, -0.000306, -0.000594, -0.000260, 0.999227, -0.025516, -0.012912, -0.017163, -0.017614, 0.994613, -0.064506, -0.053231, -0.066267, 0.994383, -0.051163, 0.000000, -0.011640, -0.015508, -0.013880, 0.996511, -0.050126, -0.040305, -0.047330, 0.996618, -0.038303, 0.000000], + "q99": [ 0.000164, 0.004822, 0.013706, 1.000000, 0.000240, 0.000703, 0.000278, 1.000000, 0.030090, 0.013182, 0.016960, 0.016101, 1.000000, 0.066268, 0.053905, 0.064357, 1.000000, 0.052547, 1.000000, 0.010890, 0.015347, 0.012968, 1.000000, 0.047482, 0.042217, 0.050173, 1.000000, 0.041428, 1.000000] +} diff --git a/cosmos_framework/data/vfm/action/datasets/stats/bridge_orig_lerobot_stats.json b/cosmos_framework/data/vfm/action/datasets/stats/bridge_orig_lerobot_stats.json new file mode 100644 index 0000000..66d1d79 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/stats/bridge_orig_lerobot_stats.json @@ -0,0 +1,4 @@ +{ + "q01": [-0.038884, -0.028667, -0.037840, 0.976292, -0.163098, -0.081545, -0.160193, 0.976322, -0.078872, 0.000000], + "q99": [ 0.039722, 0.029068, 0.026702, 1.000000, 0.160195, 0.081655, 0.163227, 1.000000, 0.095189, 1.000000] +} diff --git a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_normalization.json b/cosmos_framework/data/vfm/action/datasets/stats/droid_lerobot_stats.json similarity index 
100% rename from cosmos_framework/data/vfm/action/datasets/droid_lerobot_normalization.json rename to cosmos_framework/data/vfm/action/datasets/stats/droid_lerobot_stats.json diff --git a/cosmos_framework/data/vfm/action/datasets/stats/robomind_franka_stats.json b/cosmos_framework/data/vfm/action/datasets/stats/robomind_franka_stats.json new file mode 100644 index 0000000..66e3c3c --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/stats/robomind_franka_stats.json @@ -0,0 +1,4 @@ +{ + "q01": [-0.051367, -0.031964, -0.046482, 0.988101, -0.053179, -0.128603, -0.075432, 0.994427, -0.059973, 0.000000, -0.035108, -0.021212, -0.029788, 0.986086, -0.098043, -0.111441, -0.093441, 0.991492, -0.058030, 0.000000], + "q99": [ 0.043729, 0.021737, 0.036738, 1.000000, 0.075612, 0.102791, 0.053223, 1.000000, 0.077057, 1.000000, 0.047581, 0.021270, 0.025712, 1.000000, 0.095525, 0.126049, 0.098778, 1.000000, 0.041914, 0.995443] +} diff --git a/cosmos_framework/data/vfm/action/urdf_visualizer/G1_omnipicker_calibrated.urdf b/cosmos_framework/data/vfm/action/urdf_visualizer/G1_omnipicker_calibrated.urdf new file mode 100644 index 0000000..bd83679 --- /dev/null +++ b/cosmos_framework/data/vfm/action/urdf_visualizer/G1_omnipicker_calibrated.urdf @@ -0,0 +1,1350 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 
b3967db33f4e1d565947031e4fb722286a37cf52 Mon Sep 17 00:00:00 2001 From: lfengad Date: Sat, 13 Jun 2026 10:09:20 +0800 Subject: [PATCH 09/16] A2v sound encoder inference (#39) Co-authored-by: Claude Opus 4.8 (1M context) --- cosmos_framework/inference/args.py | 25 ++++++ cosmos_framework/inference/args_test.py | 52 +++++++++++ .../inference/common/checkpoints.py | 10 +-- .../audio_image2video/sample_args.json | 22 +++++ cosmos_framework/inference/inference.py | 41 +++++++-- cosmos_framework/inference/sound.py | 63 +++++++++++++- cosmos_framework/inference/sound_test.py | 87 +++++++++++++++++++ 7 files changed, 288 insertions(+), 12 deletions(-) create mode 100644 cosmos_framework/inference/defaults/audio_image2video/sample_args.json create mode 100644 cosmos_framework/inference/sound_test.py diff --git a/cosmos_framework/inference/args.py b/cosmos_framework/inference/args.py index 3c234f7..62c3c89 100644 --- a/cosmos_framework/inference/args.py +++ b/cosmos_framework/inference/args.py @@ -160,6 +160,7 @@ class ModelMode(StrEnum): IMAGE2IMAGE = "image2image" IMAGE2VIDEO = "image2video" VIDEO2VIDEO = "video2video" + AUDIO_IMAGE2VIDEO = "audio_image2video" # Action FORWARD_DYNAMICS = "forward_dynamics" @@ -176,6 +177,10 @@ def is_action(self) -> bool: def is_reasoner(self) -> bool: return self in REASONER_MODEL_MODES + @property + def is_sound_condition(self) -> bool: + return self in SOUND_CONDITION_MODEL_MODES + # Image-output modes: ``num_frames`` defaults to 1 and the output is saved as a still image. _IMAGE_OUTPUT_MODES: frozenset[ModelMode] = frozenset({ModelMode.TEXT2IMAGE, ModelMode.IMAGE2IMAGE}) @@ -187,6 +192,10 @@ def is_reasoner(self) -> bool: REASONER_MODEL_MODES: frozenset[ModelMode] = frozenset({ModelMode.REASONER}) +# Modes that condition generation on a real input audio clip (require a model +# with ``sound_gen=True`` and a ``sound_path``). 
+SOUND_CONDITION_MODEL_MODES: frozenset[ModelMode] = frozenset({ModelMode.AUDIO_IMAGE2VIDEO}) + class VisionMode(StrEnum): IMAGE = "image" @@ -513,6 +522,7 @@ def _build_vision_data(self, model_config: "OmniMoTModelConfig", sample_meta: Sa class SoundDataArgs(ArgsBase): enable_sound: bool = False + sound_path: ResolvedFilePath | None = None class SoundDataOverrides(OverridesBase): @@ -520,8 +530,23 @@ class SoundDataOverrides(OverridesBase): enable_sound: Training[bool | None] = None """Enable joint video+sound generation (t2vs mode). Requires a checkpoint with sound modules.""" + sound_path: ResolvedFilePathOrUrl | None = None + """Path or URL to a conditioning audio clip (e.g. .wav/.mp3/.flac). Required for + audio_image2video; the clip is encoded by the AVAE and used as a clean condition.""" + + @override + def download(self, output_dir: Path): + super().download(output_dir) + self.sound_path = download_file(self.sound_path, output_dir, "sound") def _build_sound_data(self, model_config: "OmniMoTModelConfig", sample_meta: SampleMeta): + if sample_meta.model_mode.is_sound_condition: + if self.sound_path is None: + raise ValueError( + f"model_mode={sample_meta.model_mode.value} requires a `sound_path` " + "(a conditioning audio clip)" + ) + self.enable_sound = True if self.enable_sound is None: self.enable_sound = False if self.enable_sound and not model_config.sound_gen: diff --git a/cosmos_framework/inference/args_test.py b/cosmos_framework/inference/args_test.py index 3bf3703..631f87e 100644 --- a/cosmos_framework/inference/args_test.py +++ b/cosmos_framework/inference/args_test.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: OpenMDW-1.1 import json +import types from pathlib import Path import omegaconf @@ -15,6 +16,7 @@ ModelMode, OmniSampleOverrides, OmniSetupOverrides, + SoundDataOverrides, ) from cosmos_framework.inference.common.config import structure_config @@ -156,3 +158,53 @@ def test_sample_args(tmp_path: Path): assert text2image_args.num_steps == 50 
assert text2image_args.guidance == 4.0 assert text2image_args.shift == 3.0 + + +def test_build_sound_data_requires_sound_path_for_a2v(): + model_config = types.SimpleNamespace(sound_gen=True) + sample_meta = types.SimpleNamespace(model_mode=ModelMode.AUDIO_IMAGE2VIDEO) + + overrides = SoundDataOverrides(sound_path=None) + with pytest.raises(ValueError, match="sound_path"): + overrides._build_sound_data(model_config=model_config, sample_meta=sample_meta) + + overrides = SoundDataOverrides(sound_path="https://example.com/clip.wav") + overrides._build_sound_data(model_config=model_config, sample_meta=sample_meta) + assert overrides.enable_sound is True + + +def test_build_sound_data_rejects_model_without_sound_gen(): + model_config = types.SimpleNamespace(sound_gen=False) + sample_meta = types.SimpleNamespace(model_mode=ModelMode.AUDIO_IMAGE2VIDEO) + overrides = SoundDataOverrides(sound_path="https://example.com/clip.wav") + with pytest.raises(ValueError, match="sound tokenizer"): + overrides._build_sound_data(model_config=model_config, sample_meta=sample_meta) + + +def test_audio_image2video_conditions_image_and_sound(tmp_path: Path): + import omegaconf + from cosmos_framework.inference.common.config import structure_config + + setup_args = OmniSetupOverrides( + checkpoint_path=DEFAULT_CHECKPOINT_NAME, + output_dir=tmp_path / "outputs", + ).build_setup() + model_dict = structure_config(setup_args.load_model_config_dict(), omegaconf.DictConfig) + + img = tmp_path / "robot.jpg" + img.write_bytes(b"\xff\xd8\xff\xe0") # minimal non-empty file; not actually decoded here + clip = tmp_path / "clip.wav" + clip.write_bytes(b"RIFF") + + args = OmniSampleOverrides( + name="a2v", + output_dir=tmp_path / "a2v", + model_mode=ModelMode.AUDIO_IMAGE2VIDEO, + vision_path=str(img), + sound_path=str(clip), + ).build_sample(model_config=model_dict.config) + + assert args.condition_vision_mode.value == "image" + assert args.condition_frame_indexes_vision == [0] + assert args.enable_sound 
is True + assert Path(args.sound_path).name == "clip.wav" diff --git a/cosmos_framework/inference/common/checkpoints.py b/cosmos_framework/inference/common/checkpoints.py index 4041cb9..0a10aae 100644 --- a/cosmos_framework/inference/common/checkpoints.py +++ b/cosmos_framework/inference/common/checkpoints.py @@ -76,8 +76,8 @@ def _materialize_avae_ckpt(local_dir: str) -> None: ``[C]`` and loads via ``load_state_dict(strict=False)`` — so without remapping the keys, none match and every decoder weight is silently left at init (noise). We invert the forward conversion (key remap + snake reshape) and wrap the result - under ``state_dict``. Decoder-only is sufficient: generation only decodes sound - latents to a waveform. Idempotent. + under ``state_dict``. Native ``encoder.layers.*`` keys pass through + ``_avae_block_key_to_legacy`` unchanged. Idempotent. """ import torch from safetensors.torch import load_file @@ -249,9 +249,9 @@ def register_checkpoints(): revision="main", subdirectory="sound_tokenizer", ), - # The sound_tokenizer/ safetensors are decoder-only and use the diffusers - # OobleckDecoder key layout; _materialize_avae_ckpt remaps them back to the - # legacy decoder.layers.* layout the native AVAE loader expects. + # _materialize_avae_ckpt remaps the diffusers OobleckDecoder keys + # (decoder.block.*) back to the legacy decoder.layers.* layout the native AVAE + # loader expects; native encoder.layers.* keys pass through unchanged. 
post_download=_materialize_avae_ckpt, ), ) diff --git a/cosmos_framework/inference/defaults/audio_image2video/sample_args.json b/cosmos_framework/inference/defaults/audio_image2video/sample_args.json new file mode 100644 index 0000000..1fe7360 --- /dev/null +++ b/cosmos_framework/inference/defaults/audio_image2video/sample_args.json @@ -0,0 +1,22 @@ +{ + "num_steps": 35, + "guidance": 6.0, + "shift": 10.0, + "sigma_max": 80.0, + "normalize_cfg": false, + "autoregressive": false, + "negative_prompt": null, + "negative_prompt_file": "neg_prompts.json", + "duration_template": "The video is {duration:.1f} seconds long and is of {fps:.0f} FPS.", + "resolution_template": "This video is of {height}x{width} resolution.", + "negative_metadata_mode": "none", + "inverse_duration_template": "The video is not {duration:.1f} seconds long and is not of {fps:.0f} FPS.", + "inverse_resolution_template": "This video is not of {height}x{width} resolution.", + "negative_prompt_keep_metadata": true, + "aspect_ratio": "16,9", + "fps": 24, + "num_frames": 189, + "video_save_quality": 10, + "image_save_quality": 95, + "enable_sound": true +} diff --git a/cosmos_framework/inference/inference.py b/cosmos_framework/inference/inference.py index 83ff665..564da97 100644 --- a/cosmos_framework/inference/inference.py +++ b/cosmos_framework/inference/inference.py @@ -612,17 +612,29 @@ def get_sample_data( create_placeholder_audio, get_audio_tokenizer_info, inject_sound_into_batch, + load_conditioning_audio, ) audio_info = get_audio_tokenizer_info(model) if not audio_info.has_sound: raise ValueError("enable_sound=True but model has no sound tokenizer") - audio_placeholder = create_placeholder_audio( - num_frames=sample_args.num_frames, - conditioning_fps=sample_args.fps, - audio_info=audio_info, - ) - inject_sound_into_batch(out, audio_placeholder, model) + + condition_sound = sample_args.sound_path is not None + if condition_sound: + num_samples = int(sample_args.num_frames / sample_args.fps * 
audio_info.sample_rate) + audio = load_conditioning_audio( + Path(sample_args.sound_path), + sample_rate=audio_info.sample_rate, + audio_channels=getattr(audio_info.tokenizer, "audio_channels", 2), + num_samples=num_samples, + ) + else: + audio = create_placeholder_audio( + num_frames=sample_args.num_frames, + conditioning_fps=sample_args.fps, + audio_info=audio_info, + ) + inject_sound_into_batch(out, audio, model, condition_sound=condition_sound) return out @@ -1062,6 +1074,23 @@ def _create(cls, setup_args: SetupArgs, **kwargs: Any) -> Self: tokenizer_cfg.pop("revision", None) tokenizer_cfg.pop("subdir", None) tokenizer_cfg["tokenizer_type"] = str(checkpoint_path) + # AVAE source: the configured ``avae_path`` when set, else the loaded + # checkpoint's bundled ``sound_tokenizer/``. The inference-only + # ``from_checkpoint`` key (default False) forces bundled; pop it so it + # never reaches AVAEInterface. + sound_cfg = model_dict["config"].get("sound_tokenizer") + if sound_cfg is not None: + from_checkpoint = sound_cfg.pop("from_checkpoint", False) + sound_tokenizer_dir = Path(checkpoint_path) / "sound_tokenizer" + if sound_tokenizer_dir.is_dir() and (from_checkpoint or not sound_cfg.get("avae_path")): + from cosmos_framework.inference.common.checkpoints import ( + _AVAE_LEGACY_CKPT_NAME, + _materialize_avae_ckpt, + ) + + _materialize_avae_ckpt(str(sound_tokenizer_dir)) + sound_cfg["bucket_name"] = "" + sound_cfg["avae_path"] = str(sound_tokenizer_dir / _AVAE_LEGACY_CKPT_NAME) config = Cosmos3OmniConfig(model=model_dict) model = Cosmos3OmniModel.from_pretrained_dcp( checkpoint_path, diff --git a/cosmos_framework/inference/sound.py b/cosmos_framework/inference/sound.py index 1ecf3fb..59d008e 100644 --- a/cosmos_framework/inference/sound.py +++ b/cosmos_framework/inference/sound.py @@ -59,10 +59,68 @@ def create_placeholder_audio( return torch.zeros(1, sound_channels, sound_num_samples) # [1,C_audio,N_samples] +def load_conditioning_audio( + path: Path, + *, + 
sample_rate: int, + audio_channels: int, + num_samples: int, +) -> torch.Tensor: + """Decode an audio file into a conditioning waveform aligned to the video. + + Reads ``path`` with soundfile, resamples to ``sample_rate``, conforms the + channel count to ``audio_channels`` (mono->stereo duplicate, stereo->mono + mean), and trims or zero-pads to exactly ``num_samples`` so the audio and + video latent streams cover the same duration. + + Returns: + Audio tensor of shape (1, C, N) where C == audio_channels and + N == num_samples, dtype float32. + """ + import soundfile as sf # type: ignore[import-not-found] + + data, src_sr = sf.read(str(path), dtype="float32", always_2d=True) # [N, C] + waveform = torch.from_numpy(data).transpose(0, 1).contiguous() # [C, N] + + # Resample with scipy (torchaudio is not a project dependency). + if src_sr != sample_rate: + from math import gcd + + import scipy.signal + + g = gcd(int(src_sr), int(sample_rate)) + up, down = int(sample_rate) // g, int(src_sr) // g + resampled = scipy.signal.resample_poly(waveform.numpy(), up, down, axis=-1) # [C, N'] + waveform = torch.from_numpy(resampled.astype("float32")).contiguous() + + # Conform channels. + cur_channels = waveform.shape[0] + if cur_channels != audio_channels: + if cur_channels == 1 and audio_channels == 2: + waveform = waveform.repeat(2, 1) + elif cur_channels == 2 and audio_channels == 1: + waveform = waveform.mean(dim=0, keepdim=True) + else: + raise ValueError( + f"Cannot convert {cur_channels}-channel audio to {audio_channels} channels" + ) + + # Trim or zero-pad to num_samples. 
+ n = waveform.shape[-1] + if n > num_samples: + waveform = waveform[:, :num_samples] + elif n < num_samples: + waveform = torch.nn.functional.pad(waveform, (0, num_samples - n)) + + return waveform.unsqueeze(0).to(dtype=torch.float32) # [1, C, N] + + def inject_sound_into_batch( data_batch: dict[str, Any], audio_tensor: torch.Tensor | None, model: Any, + *, + condition_sound: bool = False, ) -> dict[str, Any]: """Add sound data and upgrade the SequencePlan in an existing data batch. @@ -73,6 +131,9 @@ def inject_sound_into_batch( data_batch: Existing data batch (from get_video_sample_batch or build_conditioned_video_batch). audio_tensor: Audio waveform tensor (1, C, N) or None. model: The OmniMoTModel instance. + condition_sound: When True, the provided audio is used as a clean + condition (mode "ts2v") and the video is generated from it. When + False (default), sound is generated jointly (mode "t2vs"). Returns: The same data_batch dict, mutated in-place with sound fields added. @@ -103,7 +164,7 @@ def inject_sound_into_batch( # existing vision conditioning is preserved in the sequence plan for i2v and v2v modes sequence_plan = build_sequence_plan_for_sound( - mode="t2vs", + mode="ts2v" if condition_sound else "t2vs", video_latent_length=video_latent_t, sound_latent_length=sound_latent_t, ) diff --git a/cosmos_framework/inference/sound_test.py b/cosmos_framework/inference/sound_test.py new file mode 100644 index 0000000..0666429 --- /dev/null +++ b/cosmos_framework/inference/sound_test.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +from pathlib import Path + +import soundfile as sf +import torch + +from cosmos_framework.inference.sound import load_conditioning_audio + + +def _write_wav(path: Path, sample_rate: int, channels: int, num_samples: int) -> None: + if channels > 1: + data = torch.zeros(num_samples, channels).numpy() + else: + data = torch.zeros(num_samples).numpy() + sf.write(str(path), data, sample_rate) + + +def test_load_conditioning_audio_resamples_and_pads(tmp_path: Path): + src = tmp_path / "in.wav" + _write_wav(src, sample_rate=44100, channels=1, num_samples=44100) # 1.0s mono @44.1k + + out = load_conditioning_audio(src, sample_rate=48000, audio_channels=2, num_samples=96000) + + assert out.shape == (1, 2, 96000) # [1, C, N]; stereo, padded to 2.0s @48k + assert out.dtype == torch.float32 + + +def test_load_conditioning_audio_trims(tmp_path: Path): + src = tmp_path / "in.wav" + _write_wav(src, sample_rate=48000, channels=2, num_samples=48000 * 4) # 4s stereo @48k + + out = load_conditioning_audio(src, sample_rate=48000, audio_channels=2, num_samples=48000 * 2) + + assert out.shape == (1, 2, 48000 * 2) # trimmed to 2s + + +import types + +from cosmos_framework.data.vfm.sequence_packing import SequencePlan +from cosmos_framework.inference.sound import inject_sound_into_batch + + +def _fake_model(sound_latent_t: int, temporal_cf: int = 4): + sound_tok = types.SimpleNamespace( + get_latent_num_samples=lambda n: sound_latent_t, + audio_channels=2, + ) + vision_tok = types.SimpleNamespace(temporal_compression_factor=temporal_cf) + return types.SimpleNamespace(tokenizer_sound_gen=sound_tok, tokenizer_vision_gen=vision_tok) + + +def test_inject_sound_conditions_sound_and_preserves_image(): + model = _fake_model(sound_latent_t=50) + video = torch.zeros(1, 3, 48, 16, 16) # [1,3,T,H,W], T=48 -> 12 video latents @cf=4 + audio = torch.zeros(1, 2, 96000) + batch = { + "video": [video], + "sequence_plan": [ + SequencePlan(has_text=True, 
has_vision=True, condition_frame_indexes_vision=[0]) + ], + } + + inject_sound_into_batch(batch, audio, model, condition_sound=True) + + plan = batch["sequence_plan"][0] + assert plan.has_sound is True + assert plan.condition_frame_indexes_sound == list(range(50)) # all sound conditioned (ts2v) + assert plan.condition_frame_indexes_vision == [0] # image cond preserved + + +def test_inject_sound_default_generates_sound(): + model = _fake_model(sound_latent_t=50) + video = torch.zeros(1, 3, 48, 16, 16) + audio = torch.zeros(1, 2, 96000) + batch = { + "video": [video], + "sequence_plan": [ + SequencePlan(has_text=True, has_vision=True, condition_frame_indexes_vision=[]) + ], + } + + inject_sound_into_batch(batch, audio, model) # default condition_sound=False + + plan = batch["sequence_plan"][0] + assert plan.condition_frame_indexes_sound == [] # t2vs: sound generated From 35a6fb6db2782664e271b3ecb96ef4c45471bd09 Mon Sep 17 00:00:00 2001 From: Maosheng Liao Date: Mon, 15 Jun 2026 10:58:29 +0800 Subject: [PATCH 10/16] =?UTF-8?q?docs:=20fix=20stale=20dataloader-state=20?= =?UTF-8?q?callback=20references=20in=20custom=5Fdatase=E2=80=A6=20(#40)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update document for custom dataset. Co-authored-by: Claude Opus 4.8 (1M context) --- docs/custom_dataset.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/custom_dataset.md b/docs/custom_dataset.md index 5fbdae3..be7e1aa 100644 --- a/docs/custom_dataset.md +++ b/docs/custom_dataset.md @@ -249,11 +249,11 @@ Override from the CLI like any Hydra node, e.g. ## 5. 
Checkpoint / resume -Resume is handled by the existing `DataLoaderStateCallback`: +Resume is handled by `CosmosDataLoaderStateCallback`: ```python -from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback -cb = DataLoaderStateCallback(distributor_type="cosmos_dataloader") +from cosmos_framework.callbacks.cosmos_dataloader_state import CosmosDataLoaderStateCallback +cb = CosmosDataLoaderStateCallback() ``` - Use a **`MapDistributor`** source. On save, the callback records each worker's @@ -266,8 +266,8 @@ cb = DataLoaderStateCallback(distributor_type="cosmos_dataloader") - For multiple loaders sharing a process (e.g. inside `JointCosmosDataLoader`), give each a distinct `name=` so resume env vars are namespaced (`COSMOS_DL_STATE_{name}_WORKER_{id}_{EPOCH,INDEX}`), and use a single - `JointDataLoaderStateCallback(outer_loader=joint_loader, distributor_type="cosmos_dataloader")` - instead of one `DataLoaderStateCallback` per inner loader. + `JointCosmosDataLoaderStateCallback(outer_loader=joint_loader)` + instead of one `CosmosDataLoaderStateCallback` per inner loader. - Use `ckpt_type=dcp` (the default) — not `ckpt_type=dummy`, which disables all checkpointing. The on-disk checkpoint format is unchanged. @@ -423,7 +423,7 @@ collator: VFMListCollator # media kept as p - [ ] Pick a **collator**: `DefaultBatchCollator`, `VFMListCollator`, or your own (must match the structure the model consumes). - [ ] For real resume: use a `MapDistributor`, add - `DataLoaderStateCallback(distributor_type="cosmos_dataloader")`, and + `CosmosDataLoaderStateCallback()`, and `ckpt_type=dcp` (not `dummy`). - [ ] For FSDP+TP/PP, pass `parallel_dims=` so the correct DP rank is used. - [ ] Register the experiment in the Hydra ConfigStore @@ -437,8 +437,8 @@ collator: VFMListCollator # media kept as p `name=` matching its key in `dataloaders` (namespaces resume env vars). - [ ] Set each dataset's `ratio` (controls how often it is visited, per batch). 
- [ ] Use a single - `JointDataLoaderStateCallback(outer_loader=joint_loader, distributor_type="cosmos_dataloader")` - — do **not** also register a standalone `DataLoaderStateCallback` per inner + `JointCosmosDataLoaderStateCallback(outer_loader=joint_loader)` + — do **not** also register a standalone `CosmosDataLoaderStateCallback` per inner loader. - [ ] Avoid `"global_id"` as a dataset name (reserved by the checkpoint state). - [ ] Use `ckpt_type=dcp` for real checkpoint/resume. From 5fb773cf156ec7064d7909aec62b8efbde781dbd Mon Sep 17 00:00:00 2001 From: yangyangt Date: Thu, 11 Jun 2026 23:05:27 -0700 Subject: [PATCH 11/16] Release: sync from imaginaire4 + restore local_datasets to main Pipeline run via packages/cosmos-framework-release/release.sh: - 220+ files changed/added/removed across guardrails, callbacks, configs, data, model, tools, utils to match current i4 source. - local_datasets/ restored to match cosmos-framework main exactly; the dir is now CF-owned (excluded from the mapping going forward). - Removed 4 orphan files re-introduced on this branch (multiview_dataloader, vlm/defaults/dataloader, nvlm_data_unify, nvlm_sample_loaders_and_part_filters) -- already excluded in mapping_config.toml; nothing in CF imports them. - New modules brought in: data/imaginaire/webdataset/augmentors/image, data/vfm/action/action_processing, data/vfm/vlm/video_decoder_qwen, data/vlm/processors/{nemotron3densevl,nemotronvl}, model/tokenizer/evaluation, model/vfm/mot/cosmos3_vfm_qwen3_vl_network_test, utils/vfm/video_preprocess, others. - Internal http(s) URLs scrubbed to https://invalid_url (s3://, github, pytorch, docs.nvidia, arxiv, etc. preserved). NFS/usr leak paths scrubbed to /invalid_dir. SPDX/OpenMDW-1.1 headers applied. - COSMOS_INTERNAL flag now defaults to False (was inheriting TRAINING=True). - Zero dangling cosmos_framework module imports. 
Co-Authored-By: Claude Sonnet 4.6 --- cosmos_framework/__init__.py | 1 - .../auxiliary/guardrail/common/presets.py | 5 +- .../face_blur_filter/retinaface_utils.py | 1 - .../guardrail/qwen3guard/__init__.py | 1 - .../callbacks/compile_tokenizer.py | 15 +- cosmos_framework/callbacks/data_stats.py | 1 - .../callbacks/dataloader_state.py | 18 +- .../callbacks/every_n_draw_sample.py | 4 - cosmos_framework/callbacks/grad_clip.py | 2 +- cosmos_framework/callbacks/hf_export.py | 5 +- cosmos_framework/callbacks/mfu.py | 1 - cosmos_framework/callbacks/wandb_log_eval.py | 1 - cosmos_framework/checkpoint/dcp.py | 34 +- cosmos_framework/checkpoint/s3_filesystem.py | 165 +- cosmos_framework/configs/base/__init__.py | 1 - .../configs/base/base_config_test.py | 27 +- .../configs/base/defaults/__init__.py | 1 - .../configs/base/defaults/callbacks.py | 1 - .../configs/base/defaults/cluster.py | 22 +- .../configs/base/defaults/compile.py | 2 +- .../configs/base/defaults/model_config.py | 18 +- .../base/defaults/multiview_dataloader.py | 150 - .../configs/base/defaults/tokenizer.py | 43 +- .../configs/base/defaults/unittest.py | 5 + cosmos_framework/configs/base/defaults/vlm.py | 203 +- .../base/experiment/action/__init__.py | 1 - .../action/posttrain_config/__init__.py | 1 - .../action/pretrained_config/__init__.py | 1 - .../experiment/posttrain_video/__init__.py | 1 - cosmos_framework/configs/base/vlm/__init__.py | 1 - cosmos_framework/configs/base/vlm/config.py | 3 +- .../configs/base/vlm/defaults/__init__.py | 1 - .../configs/base/vlm/defaults/callbacks.py | 2 - .../configs/base/vlm/defaults/dataloader.py | 80 - .../configs/base/vlm/defaults/optimizer.py | 1 + .../base/vlm/defaults/policy_config.py | 7 +- .../configs/base/vlm/experiment/__init__.py | 1 - .../configs/base/vlm/freeze_config.py | 1 + .../webdataset/augmentors/image/__init__.py | 2 + .../webdataset/augmentors/image/cropping.py | 150 + .../webdataset/augmentors/image/flip.py | 32 + 
.../webdataset/augmentors/image/misc.py | 51 + .../webdataset/augmentors/image/normalize.py | 36 + .../webdataset/augmentors/image/padding.py | 60 + .../webdataset/augmentors/image/resize.py | 175 + cosmos_framework/data/vfm/action/__init__.py | 1 - .../data/vfm/action/action_processing.py | 257 ++ .../data/vfm/action/domain_utils.py | 29 +- .../data/vfm/action/json_formatter.py | 2 +- .../data/vfm/action/pose_utils.py | 44 +- .../data/vfm/action/pose_utils_test.py | 2 - .../data/vfm/action/transforms.py | 88 +- .../data/vfm/action/transforms_test.py | 24 +- .../data/vfm/augmentor_provider.py | 21 +- .../data/vfm/augmentors/__init__.py | 1 - .../vfm/augmentors/idle_frames_text_info.py | 2 +- .../vfm/augmentors/image_editing_transform.py | 184 +- .../image_editing_transform_test.py | 159 + .../augmentors/interleaved_video_parsing.py | 2 +- .../data/vfm/augmentors/pkl_to_media.py | 2 - .../data/vfm/augmentors/sequence_plan.py | 58 +- .../augmentors/text_transforms_for_image.py | 6 +- .../augmentors/transfer_control_transform.py | 2 +- .../data/vfm/augmentors/video_parsing.py | 100 +- .../data/vfm/augmentors/vlm/__init__.py | 1 - .../vfm/augmentors/vlm/nvlm_data_unify.py | 120 - .../nvlm_sample_loaders_and_part_filters.py | 2815 ----------------- .../data/vfm/augmentors/vlm/prompt_format.py | 4 +- .../data/vfm/augmentors/vlm/timestamp.py | 6 +- .../vlm/timestamp_with_subject_tracking.py | 4 +- .../vlm/timestamp_without_augment_message.py | 4 +- .../vlm/timestamp_without_end_time.py | 2 - .../data/vfm/augmentors/vlm/tokenize_data.py | 4 - cosmos_framework/data/vfm/joint_dataloader.py | 78 +- .../data/vfm/packing_iterable_dataset.py | 271 ++ .../data/vfm/processors/__init__.py | 10 +- .../vfm/processors/nemotronvl_processor.py | 4 +- .../data/vfm/processors/qwen3vl_processor.py | 4 +- cosmos_framework/data/vfm/sequence_packing.py | 210 +- cosmos_framework/data/vfm/sound_data_utils.py | 6 +- cosmos_framework/data/vfm/utils.py | 52 +- 
.../data/vfm/vlm/video_decoder_qwen.py | 249 ++ .../processors/nemotron3densevl_processor.py | 248 ++ .../vlm/processors/nemotronvl_processor.py | 553 ++++ .../data/vlm/processors/qwen3vl_processor.py | 19 +- cosmos_framework/model/attention/checks.py | 6 +- .../model/attention/cudnn/checks.py | 1 - .../model/attention/cudnn/cudnn_forward.py | 3 +- .../model/attention/cudnn/functions.py | 1 - .../model/attention/cudnn/meta.py | 1 - .../model/attention/flash2/__init__.py | 17 +- .../model/attention/flash2/checks.py | 17 +- .../model/attention/flash2/functions.py | 2 +- .../model/attention/flash3/functions.py | 4 +- cosmos_framework/model/attention/frontend.py | 4 +- .../model/attention/natten/checks.py | 2 +- cosmos_framework/model/attention/varlen.py | 6 +- .../model/tokenizer/evaluation/metric.py | 431 +++ .../evaluation/reconstruction_metrics.py | 497 +++ .../model/tokenizer/models/__init__.py | 3 + .../model/tokenizer/models/dense_backends.py | 53 +- .../model/tokenizer/models/dense_runtime.py | 331 +- .../models/modules/attention/full_attn.py | 10 +- .../models/modules/quantizers/fsq.py | 12 +- .../models/modules/quantizers/lfq.py | 152 +- .../models/modules/quantizers/residual_vq.py | 149 +- .../tokenizer/models/sparse_autoencoder.py | 51 +- .../model/tokenizer/models/text_decoder.py | 10 +- .../model/tokenizer/models/utils.py | 14 +- cosmos_framework/model/tokenizer/utils/hf.py | 6 +- .../model/vfm/algorithm/loss/__init__.py | 1 + .../model/vfm/algorithm/loss/cross_entropy.py | 1 + .../model/vfm/algorithm/loss/flow_matching.py | 1 + .../model/vfm/diffusion/samplers/__init__.py | 1 - .../model/vfm/diffusion/samplers/edm.py | 2 +- .../diffusion/samplers/fm_solvers_unipc.py | 7 +- cosmos_framework/model/vfm/hf_model.py | 7 +- cosmos_framework/model/vfm/mot/__init__.py | 1 - cosmos_framework/model/vfm/mot/attention.py | 12 +- .../model/vfm/mot/attention_test.py | 2 +- .../model/vfm/mot/cfgp_ar_test.py | 6 +- .../model/vfm/mot/context_parallel_utils.py | 2 +- 
.../model/vfm/mot/cosmos3_vfm_network.py | 5 +- .../mot/cosmos3_vfm_qwen3_vl_network_test.py | 1081 +++++++ .../model/vfm/mot/dot_product_attention.py | 7 +- .../model/vfm/mot/modeling_utils.py | 1 - .../model/vfm/mot/unified_3dmrope_utils.py | 48 +- cosmos_framework/model/vfm/mot/unified_mot.py | 133 +- cosmos_framework/model/vfm/omni_mot_model.py | 710 +++-- cosmos_framework/model/vfm/parallelize_vlm.py | 5 +- .../model/vfm/tokenizers/audio/__init__.py | 1 - .../model/vfm/tokenizers/audio/avae.py | 9 +- .../audio/avae_utils/activations.py | 1 + .../avae_utils/alias_free_torch/__init__.py | 1 + .../audio/avae_utils/alias_free_torch/act.py | 1 + .../avae_utils/alias_free_torch/filter.py | 1 + .../avae_utils/alias_free_torch/resample.py | 1 + .../audio/avae_utils/bottlenecks.py | 1 + .../vfm/tokenizers/audio/avae_utils/env.py | 1 + .../vfm/tokenizers/audio/avae_utils/models.py | 1 + .../tokenizers/audio/avae_utils/modules.py | 1 + .../audio/avae_utils/modules_encodec.py | 1 + .../vfm/tokenizers/dc_ae/dc_ae_4x32x32.py | 65 +- .../model/vfm/tokenizers/dc_ae/dc_ae_v.py | 7 +- .../model/vfm/tokenizers/dc_ae/dc_ae_v_ops.py | 1 - .../dc_ae/dc_ae_v_triton_rms_norm.py | 1 - .../model/vfm/tokenizers/flux_vae_8x8.py | 2 + .../model/vfm/tokenizers/interface.py | 31 +- .../vfm/tokenizers/tokenization_qwen2.py | 5 +- .../model/vfm/tokenizers/uniae/__init__.py | 1 - .../model/vfm/tokenizers/uniae/frame_math.py | 326 ++ .../vfm/tokenizers/uniae/noncausal_4x16x16.py | 412 ++- .../model/vfm/tokenizers/wan2pt1_vae_4x8x8.py | 3 + .../vfm/tokenizers/wan2pt2_vae_4x16x16.py | 7 +- .../model/vfm/upsampler/__init__.py | 1 - .../model/vfm/upsampler/prompts.py | 265 +- cosmos_framework/model/vfm/utils/__init__.py | 1 - .../model/vfm/utils/data_and_condition.py | 2 + cosmos_framework/model/vfm/utils/memory.py | 11 +- .../model/vfm/utils/safetensors_loader.py | 4 +- .../vfm/utils/safetensors_loader_test.py | 11 +- cosmos_framework/model/vfm/vlm/__init__.py | 1 - 
.../vfm/vlm/nemotron_3_dense_vl/__init__.py | 1 - .../nemotron_3_dense_vl_test.py | 741 ++++- .../model/vfm/vlm/qwen3_vl/__init__.py | 1 - .../vfm/vlm/qwen3_vl/configs/__init__.py | 1 - .../vlm/qwen3_vl/configuration_qwen3_vl.py | 16 +- .../model/vfm/vlm/qwen3_vl/qwen3_vl.py | 37 +- .../model/vfm/vlm/qwen3_vl/utils.py | 3 - .../vlm/qwen3_vl/video_processing_qwen3_vl.py | 16 +- .../model/vfm/vlm/qwen3_vl_moe/__init__.py | 1 - .../vfm/vlm/qwen3_vl_moe/configs/__init__.py | 1 - .../model/vfm/vlm/qwen3_vl_moe/moe.py | 11 +- .../model/vfm/vlm/qwen3_vl_moe/moe_test.py | 39 +- .../vfm/vlm/qwen3_vl_moe/qwen3_vl_moe.py | 32 +- cosmos_framework/model/vfm/vlm_model.py | 4 +- cosmos_framework/tools/flops/qwen3_vl.py | 2 +- cosmos_framework/tools/visualize/video.py | 3 +- cosmos_framework/trainer/__init__.py | 1 - cosmos_framework/utils/__init__.py | 1 - cosmos_framework/utils/callback.py | 2 +- cosmos_framework/utils/checkpoint_db.py | 13 +- cosmos_framework/utils/checkpointer.py | 3 - cosmos_framework/utils/config.py | 5 +- cosmos_framework/utils/device.py | 4 - cosmos_framework/utils/distributed.py | 4 +- .../utils/easy_io/backends/base_backend.py | 2 +- .../utils/easy_io/backends/boto3_backend.py | 4 +- .../utils/easy_io/backends/http_backend.py | 2 +- .../utils/easy_io/backends/local_backend.py | 4 +- .../utils/easy_io/backends/msc_backend.py | 12 +- cosmos_framework/utils/easy_io/easy_io.py | 3 +- cosmos_framework/utils/easy_io/file_client.py | 4 +- .../easy_io/handlers/imageio_video_handler.py | 69 +- .../utils/easy_io/handlers/registry_utils.py | 1 + cosmos_framework/utils/ema.py | 2 +- .../utils/env_parsers/cred_env_parser.py | 4 +- .../utils/lazy_config/__init__.py | 4 +- cosmos_framework/utils/lazy_config/file_io.py | 1 - cosmos_framework/utils/lazy_config/lazy.py | 7 +- cosmos_framework/utils/misc.py | 8 +- cosmos_framework/utils/object_store.py | 54 +- .../one_logger/one_logger_override_utils.py | 2 +- .../utils/one_logger/one_logger_utils.py | 10 +- 
cosmos_framework/utils/serialization.py | 32 +- .../utils/training_telemetry/callback.py | 4 +- .../utils/training_telemetry/utils.py | 12 +- cosmos_framework/utils/vfm/data_utils.py | 6 +- cosmos_framework/utils/vfm/flash_attn.py | 1 + .../utils/vfm/hf_attention_cosmos.py | 10 +- cosmos_framework/utils/vfm/model_loader.py | 2 +- cosmos_framework/utils/vfm/monkey_patch.py | 2 +- cosmos_framework/utils/vfm/optimizer.py | 10 +- cosmos_framework/utils/vfm/parallelism.py | 5 +- .../utils/vfm/video_preprocess.py | 32 + cosmos_framework/utils/vfm/vlm/__init__.py | 1 - .../utils/vfm/vlm/flop_calculator.py | 1 - .../vfm/vlm/pretrained_models_downloader.py | 7 +- cosmos_framework/utils/vlm/__init__.py | 1 - .../utils/vlm/compute_flops_qwen3vl.py | 4 +- .../utils/vlm/dcp_checkpointer.py | 2 - cosmos_framework/utils/vlm/distributed.py | 4 +- cosmos_framework/utils/vlm/optimizer.py | 1 - .../utils/vlm/pretrained_models_downloader.py | 2 +- 224 files changed, 8504 insertions(+), 4745 deletions(-) delete mode 100644 cosmos_framework/configs/base/defaults/multiview_dataloader.py delete mode 100644 cosmos_framework/configs/base/vlm/defaults/dataloader.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py create mode 100644 cosmos_framework/data/vfm/action/action_processing.py create mode 100644 cosmos_framework/data/vfm/augmentors/image_editing_transform_test.py delete mode 100644 
cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py delete mode 100644 cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py create mode 100644 cosmos_framework/data/vfm/packing_iterable_dataset.py create mode 100644 cosmos_framework/data/vfm/vlm/video_decoder_qwen.py create mode 100644 cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py create mode 100644 cosmos_framework/data/vlm/processors/nemotronvl_processor.py create mode 100644 cosmos_framework/model/tokenizer/evaluation/metric.py create mode 100644 cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py create mode 100644 cosmos_framework/model/vfm/mot/cosmos3_vfm_qwen3_vl_network_test.py create mode 100644 cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py create mode 100644 cosmos_framework/utils/vfm/video_preprocess.py diff --git a/cosmos_framework/__init__.py b/cosmos_framework/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/__init__.py +++ b/cosmos_framework/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/auxiliary/guardrail/common/presets.py b/cosmos_framework/auxiliary/guardrail/common/presets.py index d320b5e..73b27d5 100644 --- a/cosmos_framework/auxiliary/guardrail/common/presets.py +++ b/cosmos_framework/auxiliary/guardrail/common/presets.py @@ -27,8 +27,9 @@ def create_video_guardrail_runner(offload_model_to_cpu: bool = False) -> Guardra """Create the video guardrail runner.""" return GuardrailRunner( safety_models=[ - # VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu), # Too many false positives - ], + #VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu) + # Too many false positives, add back when fixed + ], postprocessors=[RetinaFaceFilter(offload_model_to_cpu=offload_model_to_cpu)], ) diff --git a/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py b/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py index cffebc2..805ecd5 100644 --- a/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py +++ b/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py @@ -1,4 +1,3 @@ -# Copyright (c) 2019 # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py b/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py +++ b/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/callbacks/compile_tokenizer.py b/cosmos_framework/callbacks/compile_tokenizer.py index 84efa16..3ee15d2 100644 --- a/cosmos_framework/callbacks/compile_tokenizer.py +++ b/cosmos_framework/callbacks/compile_tokenizer.py @@ -4,7 +4,7 @@ """Training callback that defers AOT compilation of the VAE tokenizer. The actual compilation logic lives in -:meth:`~projects.cosmos3.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. +:meth:`~cosmos_framework.model.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. This module provides a :class:`CompileTokenizer` callback that invokes it at the right point during training (after ``compile_after_iterations`` steps, to avoid NCCL timeouts during CUDA/cuDNN warm-up). @@ -21,6 +21,7 @@ """ from collections.abc import Sequence +from typing import Literal import torch @@ -43,6 +44,10 @@ def __init__( enabled: bool = False, compile_after_iterations: int = 3, warmup_resolutions: Sequence[str] | None = None, + backend: Literal["cudagraphs", "inductor"] = "inductor", + mode: Literal["reduce-overhead", "max-autotune"] | None = "reduce-overhead", + fullgraph: bool = False, + dynamic: bool = False, ): """ Args: @@ -60,6 +65,10 @@ def __init__( self.compile_after_iterations: int = compile_after_iterations self.skip_counter: int = 0 self.warmup_resolutions: Sequence[str] | None = warmup_resolutions + self.backend: Literal["cudagraphs", "inductor"] = backend + self.mode: Literal["reduce-overhead", "max-autotune"] | None = mode + self.fullgraph: bool = fullgraph + self.dynamic: bool = dynamic if self.enabled: if self.warmup_resolutions is None: @@ -101,6 +110,10 @@ def on_training_step_start( tokenizer.compile_encode( self.warmup_resolutions, output_dir=self.config.job.path_local, + backend=self.backend, + mode=self.mode, + fullgraph=self.fullgraph, + dynamic=self.dynamic, ) self.skip_counter += 1 diff --git 
a/cosmos_framework/callbacks/data_stats.py b/cosmos_framework/callbacks/data_stats.py index 20e3161..2914981 100644 --- a/cosmos_framework/callbacks/data_stats.py +++ b/cosmos_framework/callbacks/data_stats.py @@ -51,7 +51,6 @@ def on_training_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/callbacks/dataloader_state.py b/cosmos_framework/callbacks/dataloader_state.py index ec20eea..bee9e43 100644 --- a/cosmos_framework/callbacks/dataloader_state.py +++ b/cosmos_framework/callbacks/dataloader_state.py @@ -26,18 +26,14 @@ class DataLoaderStateCallback(Callback): def __init__( self, distributor_type: str | None = None, - name: str = "", ) -> None: super().__init__() self.distributor_type = distributor_type - self.name = name self.config: Any = None self.state: dict[int, NoReplaceShardlistState] = {} self.verbose = True def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: - if "sample_worker_id" not in data_batch: - return # batch has no position metadata (shuffle=False or iterable data_source) worker_ids = data_batch["sample_worker_id"].tolist() # [B] epochs = data_batch["sample_epoch"].tolist() # [B] indices = data_batch["sample_index"].tolist() # [B] @@ -50,8 +46,6 @@ def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: ): self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - _ACTIVE_DISTRIBUTOR_TYPES = ("no_replace",) - def on_training_step_batch_end( self, model: ImaginaireModel, @@ -60,7 +54,7 @@ def on_training_step_batch_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": self._update_state_from_batch(data_batch) def on_training_step_end( @@ -71,7 +65,7 @@ def on_training_step_end( loss: 
torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": if self.verbose: if iteration % self.config.trainer.logging_iter == 0: msg = "\n" @@ -80,10 +74,10 @@ def on_training_step_end( log.info(msg) def has_checkpoint_state(self) -> bool: - return self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES + return self.distributor_type == "no_replace" def state_dict(self) -> dict[int, dict[str, int]]: - if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return {} state_dict: dict[int, dict[str, int]] = {} @@ -96,7 +90,7 @@ def state_dict(self) -> dict[int, dict[str, int]]: return state_dict def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: - if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return if not state_dict: @@ -110,4 +104,4 @@ def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") + log.info(f"Loaded no replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") diff --git a/cosmos_framework/callbacks/every_n_draw_sample.py b/cosmos_framework/callbacks/every_n_draw_sample.py index baf1ffc..9aa96fa 100644 --- a/cosmos_framework/callbacks/every_n_draw_sample.py +++ b/cosmos_framework/callbacks/every_n_draw_sample.py @@ -154,8 +154,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): tag = "ema" if self.is_ema else "reg" log.debug("starting data and condition model", rank0_only=False) - - data_clean = model.get_data_and_condition(data_batch) raw_data = 
data_clean.raw_state_vision x0 = data_clean.x0_tokens_vision @@ -185,7 +183,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): log.debug(f"done denoising {sigma}", rank0_only=False) mse_loss = distributed.dist_reduce_tensor(F.mse_loss(sample, x0)) mse_loss_list.append(mse_loss) - if hasattr(model, "decode"): sample = model.decode(sample) to_show.append(sample.float().cpu()) @@ -316,7 +313,6 @@ def sample(self, trainer, model, data_batch, output_batch, loss, iteration): for sample_idx in range(data_clean.batch_size): n_vis = num_items[sample_idx] # First item(s) are condition, last item is generation target - # but we need to support multiple conditions per sample in the future. Current code # can handle this without throwing an error. condition_images.append(raw_data[vis_offset]) # source image (1, C, 1, H, W) diff --git a/cosmos_framework/callbacks/grad_clip.py b/cosmos_framework/callbacks/grad_clip.py index 151c1bc..f3cb4fa 100644 --- a/cosmos_framework/callbacks/grad_clip.py +++ b/cosmos_framework/callbacks/grad_clip.py @@ -132,7 +132,7 @@ def _clip_grad( # `torch.distributed._tensor.ops.math_ops._NormPartial`. # We can simply reduce the DTensor to get the total norm in this # tensor's process group and then convert it to a local tensor. - + # NOTE: It has two purposes: # 1. to make sure the total norm is computed correctly when PP is used (see below) # 2. to return a reduced mesh_norm tensor whose .item() would return the correct value if isinstance(mesh_norm, DTensor): diff --git a/cosmos_framework/callbacks/hf_export.py b/cosmos_framework/callbacks/hf_export.py index 8bcba5a..6d23568 100644 --- a/cosmos_framework/callbacks/hf_export.py +++ b/cosmos_framework/callbacks/hf_export.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 + """HFExportCallback: export VLM DCP checkpoints to HuggingFace safetensors format. 
Design notes @@ -137,11 +138,11 @@ def on_save_checkpoint(self, model: Any, state_dict: dict[str, Any]) -> None: if not isinstance(model, VLMModel): # The legacy vlm/train.py path passes model_parts: list[nn.Module] (raw HF # models without the VLMModel attribute structure). HF export requires the - # VLMModel wrapper, which is only available via the unified cosmos_framework/scripts/train.py path. + # VLMModel wrapper, which is only available via the unified scripts/train.py path. if isinstance(model, list): log.warning( "[HFExportCallback] Received model_parts (list) instead of VLMModel. " - "HF export requires the unified training path (cosmos_framework/scripts/train.py). Skipping." + "HF export requires the unified training path (scripts/train.py). Skipping." ) else: log.warning( diff --git a/cosmos_framework/callbacks/mfu.py b/cosmos_framework/callbacks/mfu.py index 3035437..4a6c792 100644 --- a/cosmos_framework/callbacks/mfu.py +++ b/cosmos_framework/callbacks/mfu.py @@ -138,7 +138,6 @@ def _ensure_initialised(self, model: ImaginaireModel) -> None: ac_cfg = getattr(model_cfg, "activation_checkpointing", None) ac_mode = getattr(ac_cfg, "mode", "none") - # Some activations don't need to be recomputed under selective AC, so # we need to remove them from the FLOP computation. 
self._use_activation_checkpointing = ac_mode != "none" diff --git a/cosmos_framework/callbacks/wandb_log_eval.py b/cosmos_framework/callbacks/wandb_log_eval.py index ac6911f..abea93f 100644 --- a/cosmos_framework/callbacks/wandb_log_eval.py +++ b/cosmos_framework/callbacks/wandb_log_eval.py @@ -88,7 +88,6 @@ def on_validation_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/checkpoint/dcp.py b/cosmos_framework/checkpoint/dcp.py index 4e036c9..7318e4a 100644 --- a/cosmos_framework/checkpoint/dcp.py +++ b/cosmos_framework/checkpoint/dcp.py @@ -63,6 +63,7 @@ set_model_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.nn.modules.module import _IncompatibleKeys from cosmos_framework.checkpoint.base import AbstractCheckpointer from cosmos_framework.checkpoint.s3_filesystem import S3StorageReader, S3StorageWriter @@ -85,11 +86,11 @@ def __init__(self, model: nn.Module) -> None: def state_dict(self) -> dict[str, Any]: return get_model_state_dict(self.model) - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - set_model_state_dict( + def load_state_dict(self, state_dict: dict[str, Any]) -> _IncompatibleKeys: + return set_model_state_dict( self.model, model_state_dict=state_dict, - options=StateDictOptions(strict=True), + options=StateDictOptions(strict=False), ) @@ -539,28 +540,13 @@ def load( "Ensure the model has net_ema submodule." ) _state_dict[sd_key] = _state_dict[key_ema] - elif warm_start and any(str(s).startswith("net_ema") for s in self.keys_to_skip_loading): - # Only when net_ema.* is explicitly skipped on load (e.g. 
an HF->DCP - # init from convert_model_to_dcp that has only net.*): the skipped - # net_ema.* keep build_net() construction values (random init when - # vlm_config.pretrained_weights.enabled=False), which would seed EMA - # from random weights -> copy net.* -> net_ema.* so EMA starts from the - # freshly-loaded init. When net_ema.* IS loaded (e.g. a training DCP - # that carries a trained EMA), do NOT clobber it. - log.info("Warm start: net_ema. skipped on load -> resetting net_ema = net.") - for sd_key in list(_state_dict.keys()): - if sd_key.startswith("net."): - key_ema = "net_ema." + sd_key.removeprefix("net.") - if key_ema in _state_dict: - _state_dict[key_ema] = _state_dict[sd_key] results = _model_wrapper.load_state_dict(_state_dict) - if results is not None: - if len(results.missing_keys) > 0: - raise ValueError(f"Missing keys (not found in checkpoint): {results.missing_keys}") - if len(results.unexpected_keys) > 0: - raise ValueError( - f"Unexpected keys (found in checkpoint but not in model): {results.unexpected_keys}" - ) + if len(results.missing_keys) > 0: + raise ValueError(f"Missing keys (not found in checkpoint): {results.missing_keys}") + if len(results.unexpected_keys) > 0: + raise ValueError( + f"Unexpected keys (found in checkpoint but not in model): {results.unexpected_keys}" + ) elif key == "optim": log.info("- Loading the optimizer...") diff --git a/cosmos_framework/checkpoint/s3_filesystem.py b/cosmos_framework/checkpoint/s3_filesystem.py index e47219e..029570e 100644 --- a/cosmos_framework/checkpoint/s3_filesystem.py +++ b/cosmos_framework/checkpoint/s3_filesystem.py @@ -3,32 +3,89 @@ import io import os +import threading import time from contextlib import contextmanager from typing import Generator, Union from urllib.parse import urlparse +import boto3 +from botocore.config import Config as S3Config from botocore.exceptions import ClientError from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter from 
torch.distributed.checkpoint.filesystem import FileSystemBase from cosmos_framework.utils import log from cosmos_framework.utils.easy_io import easy_io +from cosmos_framework.utils.easy_io.backends import auto_auth -class S3Stream(io.BytesIO): +class _CancellableReader: + """Pipe-reader wrapper whose ``read`` raises once a cancel event is set. + + Lets us abort an in-flight ``client.upload_fileobj`` on producer error: a + read exception makes boto3 abort the multipart upload, whereas just + closing the pipe writer would signal EOF and finalize a truncated file. """ - Workaround for PyTorch manually closing the stream before we can upload it to S3. We override the close() as noop - and instead call our own _true_close() method to close the stream after we are done using it. - The commit at fault is https://github.com/pytorch/pytorch/commit/9c909bf3bb122db2cce95e2eb7459bbe50dfa15a + + def __init__(self, f, cancel_event: threading.Event) -> None: + self._f = f + self._cancel = cancel_event + + def read(self, n: int = -1) -> bytes: + if self._cancel.is_set(): + raise IOError("S3 upload cancelled by caller") + return self._f.read(n) + + def readable(self) -> bool: + return True + + def close(self) -> None: + self._f.close() + + +class _CountingPipeWriter(io.RawIOBase): + """Write-only pipe wrapper that fakes ``tell()`` by counting bytes written. + + DCP calls ``stream.tell()`` to record per-tensor byte offsets in the + checkpoint metadata, but kernel pipes aren't seekable. We maintain the + byte count ourselves; nothing actually seeks. 
""" - def close(self): - self.flush() - # No close + def __init__(self, write_file) -> None: + super().__init__() + self._f = write_file + self._pos = 0 + + def write(self, b) -> int: + n = self._f.write(b) + if n is None: + raise OSError("_CountingPipeWriter: underlying pipe write returned None; expected a blocking write.") + self._pos += n + return n + + def writable(self) -> bool: + return True + + def seekable(self) -> bool: + return False # pipes can't seek; consumers (zipfile, etc.) check this + + def tell(self) -> int: + return self._pos - def _true_close(self): - super().close() + def fileno(self) -> int: + return self._f.fileno() + + def flush(self) -> None: + self._f.flush() + + def close(self) -> None: + if self.closed: + return + try: + super().close() # invokes self.flush(), then sets self.closed = True + finally: + self._f.close() class S3FileSystem(FileSystemBase): @@ -69,6 +126,33 @@ def __init__( if enable_gcs_patch_in_boto3: log.info("enable_gcs_patch_in_boto3: True") + # Direct boto3 client for streaming-multipart uploads (``upload_fileobj`` + # via boto3's TransferManager). We can't reuse ``self.easy_io_backend``'s + # client: easy_io abstracts the transport (could be ``Boto3Backend`` or + # ``MSCBackend``) and intentionally doesn't expose a raw boto3 client. + # Built lazily so read-only callers don't pay for it. + self._credential_path = credential_path + self._boto3_client = None + + def _get_boto3_client(self): + """Lazily build a boto3 S3 client configured for our endpoint. + + Config mirrors cosmos_framework/utils/easy_io/backends/boto3_client.py:289 to + preserve GCS-via-S3 signature/checksum compatibility. 
+ """ + if self._boto3_client is None: + with auto_auth.open_auth(self._credential_path, "r") as f: + cred_info = auto_auth.json_load_auth(f) + cfg = S3Config( + signature_version="s3v4", + s3={"addressing_style": "virtual"}, + response_checksum_validation="when_required", + request_checksum_calculation="when_required", + retries={"max_attempts": 5, "mode": "adaptive"}, + ) + self._boto3_client = boto3.client("s3", **cred_info, config=cfg) + return self._boto3_client + def _retry_with_backoff(self, operation_func, *args, **kwargs): """ Execute an operation with exponential backoff retry logic. @@ -135,24 +219,61 @@ def download_operation(): log.info(f"S3 Filesystem: Downloading {key} from bucket {bucket}", rank0_only=False) self._retry_with_backoff(download_operation) - log.info("S3 Filesystem: Download complete", rank0_only=False) + log.info(f"S3 Filesystem: Download complete for {key} in bucket {bucket}", rank0_only=False) yield stream finally: stream.close() elif mode == "wb": - stream = S3Stream() + # Streaming multipart upload: yield the writer end of a pipe to DCP + # and drain the reader end via ``client.upload_fileobj`` in a + # background thread. Peak memory is bounded by boto3's TransferConfig + # (~80 MiB) regardless of file size; the pipe (~64 KiB) provides + # backpressure. See ``_CancellableReader`` for how producer-side + # errors abort the multipart upload. 
+ client = self._get_boto3_client() + r_fd, w_fd = os.pipe() + read_file = os.fdopen(r_fd, "rb") + write_file = os.fdopen(w_fd, "wb") + counting_writer = _CountingPipeWriter(write_file) + upload_err: list = [None] + cancel_event = threading.Event() + + def _upload_thread(): + try: + client.upload_fileobj( + _CancellableReader(read_file, cancel_event), + Bucket=bucket, + Key=key, + ) + except Exception as e: # noqa: BLE001 — capture and re-raise on main thread + upload_err[0] = e + finally: + try: + read_file.close() + except Exception: + pass + + log.info(f"S3 Filesystem: Streaming upload {key} to bucket {bucket}", rank0_only=False) + uploader = threading.Thread(target=_upload_thread, daemon=True, name=f"s3-upload-{key[-32:]}") + uploader.start() + + caller_raised = False try: - yield stream - - def upload_operation(): - stream.seek(0) - self.easy_io_backend.put(obj=stream, filepath=path_str) - - log.info(f"S3 Filesystem: Uploading {key} to bucket {bucket}", rank0_only=False) - self._retry_with_backoff(upload_operation) - log.info("S3 Filesystem: Upload complete", rank0_only=False) + yield counting_writer + except Exception: + caller_raised = True + cancel_event.set() + raise finally: - stream._true_close() + try: + counting_writer.close() # closes the pipe write end → EOF for the reader + except Exception: + pass + uploader.join() + if upload_err[0] is not None and not caller_raised: + # Upload thread failed; surface that to the caller. 
+ raise upload_err[0] + log.info(f"S3 Filesystem: Upload complete for {key}", rank0_only=False) else: raise ValueError(f"Unsupported mode: {mode}") @@ -285,7 +406,7 @@ def __init__( """ super().__init__( path=path, - sync_files=False, + sync_files=False, # FIXME: setting this to True makes the run to fail (L#333: `os.fsync(stream.fileno())`) **kwargs, ) self.fs = S3FileSystem(credential_path, enable_gcs_patch_in_boto3=enable_gcs_patch_in_boto3) # type: ignore diff --git a/cosmos_framework/configs/base/__init__.py b/cosmos_framework/configs/base/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/__init__.py +++ b/cosmos_framework/configs/base/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/base_config_test.py b/cosmos_framework/configs/base/base_config_test.py index 3eb2b0d..bed1a91 100644 --- a/cosmos_framework/configs/base/base_config_test.py +++ b/cosmos_framework/configs/base/base_config_test.py @@ -17,21 +17,19 @@ from cosmos_framework.utils.config_helper import get_config_module, override +@pytest.mark.timeout(300) @pytest.mark.L0 @pytest.mark.parametrize( "experiment_name", [ - "vision_sft_nano", + "t2i_mot_exp001_009_qwen3_vl_2b_256res_frozen_llm", ], ) -def test_config_init_experiment_mot(experiment_name, monkeypatch): +def test_config_init_experiment_mot(experiment_name): """ Parameterized test to verify config initialization for multiple experiments. PYTHONPATH=. torchrun --nproc_per_node=8 -m pytest -s cosmos_framework/configs/base/config_test_mot.py --L1 """ - # The SFT experiments interpolate the dataset location from ${oc.env:DATASET_PATH}; - # config construction only needs the variable defined, not a real dataset on disk. 
- monkeypatch.setenv("DATASET_PATH", "/tmp/dataset") config_file = "cosmos_framework/configs/base/config.py" config_module = get_config_module(config_file) config = importlib.import_module(config_module).make_config() @@ -44,11 +42,17 @@ def test_config_init_experiment_mot(experiment_name, monkeypatch): ) -def _make_self_mock(*, pretrained_enabled: bool, load_weights_from_pretrained: bool) -> MagicMock: +def _make_self_mock( + *, + pretrained_enabled: bool, + load_weights_from_pretrained: bool, + exclude_reasoner_weights_from_checkpoint: bool = False, +) -> MagicMock: """Mock the OmniMoTModel attributes that load_pretrained_model_if_needed reads.""" self_mock = MagicMock() self_mock.vlm_config.pretrained_weights.enabled = pretrained_enabled self_mock.config.diffusion_expert_config.load_weights_from_pretrained = load_weights_from_pretrained + self_mock.config.exclude_reasoner_weights_from_checkpoint = exclude_reasoner_weights_from_checkpoint self_mock.config.ema.enabled = False return self_mock @@ -90,6 +94,17 @@ def test_resume_skips_everything(self): loader.assert_not_called() self_mock.net.language_model.init_moe.assert_not_called() + def test_resume_reloads_reasoner_when_excluded_from_checkpoint(self): + """Reasoner-excluding resumable checkpoint: HF load, but no generation copy.""" + self_mock = _make_self_mock( + pretrained_enabled=True, + load_weights_from_pretrained=True, + exclude_reasoner_weights_from_checkpoint=True, + ) + loader = self._call(self_mock, has_resumable_checkpoint=True, has_load_path=False) + loader.assert_called_once() + self_mock.net.language_model.init_moe.assert_not_called() + def test_warm_start_loads_but_skips_copy(self): """load_path set, no checkpoint: HF load but skip understanding→generation copy.""" self_mock = _make_self_mock(pretrained_enabled=True, load_weights_from_pretrained=True) diff --git a/cosmos_framework/configs/base/defaults/__init__.py b/cosmos_framework/configs/base/defaults/__init__.py index 503ec1b..28a81be 100644 
--- a/cosmos_framework/configs/base/defaults/__init__.py +++ b/cosmos_framework/configs/base/defaults/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/defaults/callbacks.py b/cosmos_framework/configs/base/defaults/callbacks.py index 46c85ed..805ffb5 100644 --- a/cosmos_framework/configs/base/defaults/callbacks.py +++ b/cosmos_framework/configs/base/defaults/callbacks.py @@ -10,7 +10,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer - from cosmos_framework.callbacks.device_monitor import DeviceMonitor from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap diff --git a/cosmos_framework/configs/base/defaults/cluster.py b/cosmos_framework/configs/base/defaults/cluster.py index 23b49dd..46450b9 100644 --- a/cosmos_framework/configs/base/defaults/cluster.py +++ b/cosmos_framework/configs/base/defaults/cluster.py @@ -23,14 +23,24 @@ class ClusterConfig: DefaultClusterConfig: ClusterConfig = ClusterConfig( object_store_bucket_data="", - object_store_bucket_checkpoint="bucket-checkpoint", - object_store_bucket_pretrained="bucket-pretrained", - object_store_credential_data="credentials/data.secret", - object_store_credential_checkpoint="credentials/checkpoint.secret", - object_store_credential_pretrained="credentials/pretrained.secret", + object_store_bucket_checkpoint="bucket4", + object_store_bucket_pretrained="bucket4", + object_store_credential_data="credentials/s3_training.secret", + object_store_credential_checkpoint="credentials/s3_checkpoint.secret", + object_store_credential_pretrained="credentials/s3_checkpoint.secret", +) + +DefaultClusterConfig: 
ClusterConfig = ClusterConfig( + object_store_bucket_data="", + object_store_bucket_checkpoint="bucket1", + object_store_bucket_pretrained="bucket0", + object_store_credential_data="credentials/gcp_checkpoint.secret", + object_store_credential_checkpoint="credentials/gcp_training.secret", + object_store_credential_pretrained="credentials/gcp_training.secret", ) def register_cluster(): cs = ConfigStore.instance() - cs.store(group="cluster", package="job.cluster", name="default", node=DefaultClusterConfig) + cs.store(group="cluster", package="job.cluster", name="aws_iad_h100", node=DefaultClusterConfig) + cs.store(group="cluster", package="job.cluster", name="gcp_iad_gb200", node=DefaultClusterConfig) diff --git a/cosmos_framework/configs/base/defaults/compile.py b/cosmos_framework/configs/base/defaults/compile.py index b0e1c88..3d5ebf7 100644 --- a/cosmos_framework/configs/base/defaults/compile.py +++ b/cosmos_framework/configs/base/defaults/compile.py @@ -24,7 +24,7 @@ class CompileConfig: # (maps to ``torch.compile(dynamic=...)``). Defaults to True for training, # which sees varying shapes across batches (sequence length, CP sharding, ...); # specializing would recompile continuously. See ParallelismOverrides in - # cosmos_framework/inference/common/args.py for the inference-side rationale + # packages/cosmos3/cosmos3/common/args.py for the inference-side rationale # (where dynamic=False is preferred for stable AR shapes). compile_dynamic: bool = True diff --git a/cosmos_framework/configs/base/defaults/model_config.py b/cosmos_framework/configs/base/defaults/model_config.py index c7b9a5c..f7e7c8d 100644 --- a/cosmos_framework/configs/base/defaults/model_config.py +++ b/cosmos_framework/configs/base/defaults/model_config.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: OpenMDW-1.1 - from typing import Any import attrs @@ -42,6 +41,14 @@ class DiffusionExpertConfig: rope_t_extrapolation_ratio: float = 1.0 enable_fps_modulation: bool = False base_fps: int = 24 + # Base temporal compression factor for SOUND m-RoPE. None = current behavior + # (sound advances at base_fps positions/sec). Set to the vision tcf (4) to put + # sound on the same latent-frame temporal grid as vision/action. + sound_base_temporal_compression_factor: int | None = None + # Temporal coordinates used for unified_3d_mrope vision tokens. + # - "latent_index": legacy behavior, positions are 0, 1, ..., T_latent-1. + # - "uniae_source_right_edge": use UniAE padded-patch right-edge source-frame coordinates. + vision_temporal_position_mode: str = "latent_index" # For unified_3d_mrope: whether spatial (H, W) indices reset to 0 for each vision segment unified_3d_mrope_reset_spatial_ids: bool = True # Setting the temporal gap on the boundary of the different modalities, default is 0, using a value greater than 0 will add an additional offset on the accumulated temporal offset. @@ -273,3 +280,12 @@ class OmniMoTModelConfig: sound_latent_fps: int = 25 # Sound tokenizer's latent rate (e.g., 48kHz / 1920 hop = 25 Hz) log_enc_time_every_n: int = 100 # Frequency of logging encoding time to W&B + + # When True, ``OmniMoTModel.state_dict`` / ``load_state_dict`` skip the + # reasoner (und) pathway weights under ``language_model`` — i.e. every key + # WITHOUT a ``_moe_gen`` suffix (including ``visual`` / ``lm_head`` / + # ``embed_tokens``). These are not written to checkpoints and are left + # untouched on load (typically already populated from the HF pretrained + # backbone). Generation-pathway (``_moe_gen``) and VFM heads are saved / + # loaded as usual. 
+ exclude_reasoner_weights_from_checkpoint: bool = False diff --git a/cosmos_framework/configs/base/defaults/multiview_dataloader.py b/cosmos_framework/configs/base/defaults/multiview_dataloader.py deleted file mode 100644 index a646ac6..0000000 --- a/cosmos_framework/configs/base/defaults/multiview_dataloader.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -Hydra ConfigStore registration for multiview dataloaders. - -Registers named dataloader configs that can be referenced via Hydra overrides -(e.g. ``{override /data_train: video_control_mads_multiview_0823_gcs_720p_10fps_93frames_7views}``) -or used as templates for inline ``L(get_multiview_video_loader)(...)`` in -experiment configs. - -Two naming conventions: - - **Transfer** (with control signal): - ``video_control_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` - - **Predict** (no control signal): - ``video_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` -""" - -from hydra.core.config_store import ConfigStore - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.data.vfm.multiview.multiview_data_source import ( - DEFAULT_CAMERAS, - INDEX_TO_CAMERA_MAPPING, - TRANSFER_CAPTION_KEY_MAPPING, - TRANSFER_CONTROL_KEY_MAPPING, - TRANSFER_VIDEO_KEY_MAPPING, -) -from cosmos_framework.data.vfm.multiview.multiview_dataset import ( - MultiviewAugmentationConfig, - get_multiview_video_loader, -) - -# --------------------------------------------------------------------------- -# Camera view subsets -# --------------------------------------------------------------------------- - -CAMERA_VIEW_CONFIGS: dict[str, tuple[str, ...]] = { - "7views": DEFAULT_CAMERAS, - "1view_front": ("camera_front_wide_120fov",), - "4views": ( - "camera_front_wide_120fov", - "camera_cross_right_120fov", - "camera_rear_tele_30fov", - "camera_cross_left_120fov", - ), -} - -# 
--------------------------------------------------------------------------- -# Grid dimensions -# --------------------------------------------------------------------------- - -_TRANSFER_DATASETS = ["mads_multiview_0823"] -_OBJECT_STORES = ["gcs"] - -_RESOLUTIONS: list[tuple[str, tuple[int, int]]] = [ - ("720p", (720, 1280)), -] - -_FPS: list[tuple[str, int]] = [ - ("10fps", 1), # MADS transfer data is already at 10 fps -] - -_NUM_VIDEO_FRAMES: list[tuple[str, int]] = [ - ("29frames", 29), - ("61frames", 61), - ("93frames", 93), -] - - -def register_multiview_dataloaders() -> None: - """Register all multiview dataloader configs with Hydra ConfigStore.""" - - cs = ConfigStore.instance() - - # ----- Transfer dataloaders (with control signals) ----- - for dataset in _TRANSFER_DATASETS: - for object_store in _OBJECT_STORES: - for resolution_str, resolution_hw in _RESOLUTIONS: - for fps_str, downsample_factor in _FPS: - for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - name = ( - f"video_control_{dataset}_{object_store}_{resolution_str}_" - f"{fps_str}_{num_frames_str}_{views_str}" - ) - cs.store( - group="data_train", - package="dataloader_train", - name=name, - node=L(get_multiview_video_loader)( - dataset_name=dataset, - is_train=True, - augmentation_config=L(MultiviewAugmentationConfig)( - resolution_hw=resolution_hw, - fps_downsample_factor=downsample_factor, - num_video_frames=num_frames, - camera_keys=camera_keys, - camera_video_key_mapping=TRANSFER_VIDEO_KEY_MAPPING, - camera_caption_key_mapping=TRANSFER_CAPTION_KEY_MAPPING, - camera_control_key_mapping=TRANSFER_CONTROL_KEY_MAPPING, - position_to_camera_mapping=INDEX_TO_CAMERA_MAPPING, - single_caption_camera_name="camera_front_wide_120fov", - ), - ), - ) - - # ----- Predict dataloaders (no control signals, for future use) ----- - # These use named keys (video_camera_front_wide_120fov, etc.) and need - # different datasets (e.g. 
alpamayo_dec2024) with 30 fps native data. - # Uncomment and add predict datasets to the catalog when needed. - # - # _PREDICT_DATASETS = ["alpamayo_dec2024"] - # _PREDICT_FPS = [("10fps", 3), ("15fps", 2)] # 30 fps native → downsample - # for dataset in _PREDICT_DATASETS: - # for object_store in _OBJECT_STORES: - # for resolution_str, resolution_hw in _RESOLUTIONS: - # for fps_str, downsample_factor in _PREDICT_FPS: - # for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - # for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - # name = ( - # f"video_{dataset}_{object_store}_{resolution_str}_" - # f"{fps_str}_{num_frames_str}_{views_str}" - # ) - # cs.store( - # group="data_train", - # package="dataloader_train", - # name=name, - # node=L(get_multiview_video_loader)( - # dataset_name=dataset, - # is_train=True, - # augmentation_config=L(MultiviewAugmentationConfig)( - # resolution_hw=resolution_hw, - # fps_downsample_factor=downsample_factor, - # num_video_frames=num_frames, - # camera_keys=camera_keys, - # camera_video_key_mapping=PREDICT_VIDEO_KEY_MAPPING, - # camera_caption_key_mapping=PREDICT_CAPTION_KEY_MAPPING, - # camera_control_key_mapping=None, - # position_to_camera_mapping=None, - # single_caption_camera_name=None, - # ), - # ), - # ) - - -# Auto-register on import -register_multiview_dataloaders() diff --git a/cosmos_framework/configs/base/defaults/tokenizer.py b/cosmos_framework/configs/base/defaults/tokenizer.py index 55cb01c..526d579 100644 --- a/cosmos_framework/configs/base/defaults/tokenizer.py +++ b/cosmos_framework/configs/base/defaults/tokenizer.py @@ -17,10 +17,12 @@ PRETRAINED_TOKENIZER_FLUX_VAE_PTH = "pretrained/tokenizers/image/flux/ae.safetensors" # UniAE checkpoint paths -PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH = "pretrained/tokenizers/video/cosmos/uniae4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1.pt" 
+PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T16TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_NEMOTRON2B_VAE_PTH = ( + "s3://bucket1/uniae/tok_experiments/" + "s3_siglip2_so400m_singledec_l48_textdec_nemotron2b_32node_bucketed_256480_v45i32c23_t16-160_exp009/checkpoints/iter_000050000.pt" +) # DCAE checkpoint paths -PRETRAINED_TOKENIZER_DCAE_PTH = "pretrained/tokenizers/video/cosmos/dc-ae-v-1.0-f32t4c64-cosmos-encoder-causal-decoder-chunk-causal-4-frame-120-pad-7-no-gan.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" @@ -44,6 +46,7 @@ chunk_duration=1, spatial_compression_factor=8, temporal_compression_factor=1, + causal=True, ) Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)( @@ -53,6 +56,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT1_VAE_PTH, spatial_compression_factor=8, temporal_compression_factor=4, + causal=True, ) Wan2pt2VAEConfig: LazyDict = L(Wan2pt2VAEInterface)( @@ -61,14 +65,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT2_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, -) - -DCAE4x32x32Config: LazyDict = L(DCAE4x32x32Interface)( - bucket_name=PLACEHOLDER, - object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_DCAE_PTH, - spatial_compression_factor=32, - 
temporal_compression_factor=4, + causal=True, ) DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -80,6 +77,7 @@ model_name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", spatial_compression_factor=32, temporal_compression_factor=4, + causal=True, ) DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -104,12 +102,17 @@ temporal_compression_factor=4, ) -UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config: LazyDict = L(UniAEVAEInterface)( + +UniAE4x16x16C48T16to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Nemotron2bVAEConfig: LazyDict = L( + UniAEVAEInterface +)( bucket_name=PLACEHOLDER, object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH, + vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T16TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_NEMOTRON2B_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, + pixel_trim=True, + causal=False, ) # ============================================================================= @@ -173,8 +176,8 @@ def register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="uniae_4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1_tokenizer", - node=UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config, + name="uniae_4x16x16_c48_t16to160_mixp_fps_mix_encoder_noncausal_decoder_noncausal_nogan_s3_nemotron2b_tokenizer", + node=UniAE4x16x16C48T16to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Nemotron2bVAEConfig, ) # Flux tokenizer cs.store(group="tokenizer", package="model.config.tokenizer", name="flux_tokenizer", node=FluxVAEConfig) @@ -182,25 +185,19 @@ def 
register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_tokenizer", - node=DCAE4x32x32Config, - ) - cs.store( - group="tokenizer", - package="model.config.tokenizer", - name="dc_ae_4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) diff --git a/cosmos_framework/configs/base/defaults/unittest.py b/cosmos_framework/configs/base/defaults/unittest.py index 5af5bb2..89ecfcc 100644 --- a/cosmos_framework/configs/base/defaults/unittest.py +++ b/cosmos_framework/configs/base/defaults/unittest.py @@ -1,9 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 + import attrs +# from cosmos_framework.configs.base.defaults.cluster import DefaultClusterConfig + # We are hardcoding the unittest assets in this file. 
+# CLUSTER_CONFIG = DefaultClusterConfig + # add codeowner for cosmos_framework/model/vfm/tokenizers diff --git a/cosmos_framework/configs/base/defaults/vlm.py b/cosmos_framework/configs/base/defaults/vlm.py index fa4d7c8..9176983 100644 --- a/cosmos_framework/configs/base/defaults/vlm.py +++ b/cosmos_framework/configs/base/defaults/vlm.py @@ -95,11 +95,7 @@ def download_tokenizer_files(model_name: str, config_variant: str) -> str: return destination_dir -def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str, **_unused_kwargs): - # **_unused_kwargs absorbs extras (e.g. tokenizer_type) that OmegaConf - # merges in from a vlm_config preset's tokenizer block when an experiment - # overrides the tokenizer with this function but doesn't fully replace - # the preset's kwarg dict. +def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str): destination_dir = download_tokenizer_files(pretrained_model_name, config_variant) return LLMTokenizerProcessor(Qwen2Tokenizer.from_pretrained(destination_dir)) @@ -140,7 +136,7 @@ class VLMConfig: # HuggingFace model identifier or local path. Drives AutoConfig + AutoModel selection. 
model_name: str = "" - # Safetensor path for model + # Safetensor path for model for load a safetensor from different folder safetensors_path: str = "" # Optional pretrained-weights overlay (separate from the AutoModel structural @@ -285,29 +281,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_30b_a3b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-30B-A3B-Private", - model_instance=L(Qwen3VLMoeTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoeMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl_moe/configs/Qwen3-VL-30B-A3B-Instruct.json" - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-30B-A3B-Instruct", - config_variant="gcp", - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-30B-A3B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 235B A22B Instruct model # Qwen3VLMoE uses Qwen2Tokenizer Qwen3VLMoT_VLM_235b_a22b_Instruct_GCP_Config: VLMConfig = VLMConfig( @@ -458,48 +431,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-2B-Private/", - 
credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - -Cosmos3Reasoner_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-2B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 4B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_4b_Instruct_Config: VLMConfig = VLMConfig( @@ -586,8 +517,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-8B-Private", +Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-8B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -601,14 +532,14 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-8B-Private", +Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Nano-Reasoner", 
model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -617,18 +548,18 @@ class VLMConfig: qk_norm_for_text=True, ), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-8B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-8B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3NanoReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Nano-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -643,13 +574,12 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner-bb9c6f5/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) - # Config for Qwen3VL 32B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_32b_Instruct_Config: VLMConfig = VLMConfig( @@ -693,8 +623,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-32B-Private", +Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-32B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -708,14 +638,14 @@ class VLMConfig: config_variant="gcp", ), 
pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-32B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-32B-Private", +Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -724,18 +654,18 @@ class VLMConfig: qk_norm_for_text=True, ), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-32B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-32B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3SuperReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -750,12 +680,63 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner-b6df0d1/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) +# Cosmos3-Edge-Reasoner at 
commit 4acb717. +# nemotron_siglip2 architecture: Nemotron text backbone (56-block hybrid layout, 2048 hidden) +# + SigLIP2 vision encoder. The text transformer is identical in shape to +# Nemotron-3-Dense-VL-2B (hidden_size=2048, 56 alternating attn/MLP blocks → 28 +# effective MoT layers after _transform_text_dict). Uses the same +# nemotron_3_dense_vl weight remapping and config JSON. +Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-4acb717/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", + ), +) + +# Cosmos3-Edge-Reasoner at commit 9b4c028 (2026-05-29). +# Same nemotron_siglip2 architecture as 4acb717; new weights uploaded 2026-05-29. 
+Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-9b4c028/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", + ), +) def register_vlm(): @@ -832,24 +813,6 @@ def register_vlm(): name="cosmos_reason2_vlm_2b_gcp", node=CosmosReason2_VLM_2b_GCP_Config, ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_2b_private_gcp", - node=CosmosReason2_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos3_reasoner_vlm_2b_private_gcp", - node=Cosmos3Reasoner_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_8b_private_gcp", - node=CosmosReason2_VLM_8b_Private_GCP_Config, - ) cs.store( group="vlm_config", package="model.config.vlm_config", @@ -865,8 +828,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_32b_private_gcp", - node=CosmosReason2_VLM_32b_Private_GCP_Config, + name="cosmos3_nano_reasoner_vlm_gcp_0517", + node=Cosmos3NanoReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -883,8 +846,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_30b_a3b_private_gcp", - 
node=CosmosReason2_VLM_30b_a3b_Private_GCP_Config, + name="cosmos3_super_reasoner_vlm_gcp_0517", + node=Cosmos3SuperReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -922,3 +885,15 @@ def register_vlm(): name="qwen3_vl_mot_vlm_32b_instruct_gcp", node=Qwen3VLMoT_VLM_32b_Instruct_GCP_Config, ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_4acb717", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717, + ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_9b4c028", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028, + ) diff --git a/cosmos_framework/configs/base/experiment/action/__init__.py b/cosmos_framework/configs/base/experiment/action/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/action/__init__.py +++ b/cosmos_framework/configs/base/experiment/action/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py b/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py +++ b/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py b/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py +++ b/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py b/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py +++ b/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/vlm/__init__.py b/cosmos_framework/configs/base/vlm/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/vlm/__init__.py +++ b/cosmos_framework/configs/base/vlm/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/vlm/config.py b/cosmos_framework/configs/base/vlm/config.py index 66dc3ed..4d1de85 100644 --- a/cosmos_framework/configs/base/vlm/config.py +++ b/cosmos_framework/configs/base/vlm/config.py @@ -4,10 +4,9 @@ from cosmos_framework.trainer import ImaginaireTrainer from cosmos_framework.utils import log from cosmos_framework.utils.config_helper import import_all_modules_from_package +from cosmos_framework.configs.base.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.callbacks import register_callbacks -from cosmos_framework.configs.base.vlm.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.config import Config - from cosmos_framework.configs.base.vlm.defaults.model import register_model from cosmos_framework.configs.base.vlm.defaults.optimizer import register_optimizer, register_scheduler from cosmos_framework.configs.base.vlm.defaults.vlm_policy import register_vlm_policy diff --git a/cosmos_framework/configs/base/vlm/defaults/__init__.py b/cosmos_framework/configs/base/vlm/defaults/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/vlm/defaults/__init__.py +++ b/cosmos_framework/configs/base/vlm/defaults/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/vlm/defaults/callbacks.py b/cosmos_framework/configs/base/vlm/defaults/callbacks.py index 1910b63..3392205 100644 --- a/cosmos_framework/configs/base/vlm/defaults/callbacks.py +++ b/cosmos_framework/configs/base/vlm/defaults/callbacks.py @@ -12,7 +12,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback - from cosmos_framework.callbacks.grad_clip import GradClip from cosmos_framework.callbacks.hf_export import HFExportCallback from cosmos_framework.callbacks.iter_speed import IterSpeed @@ -47,7 +46,6 @@ def register_callbacks(): config=PLACEHOLDER, trainer=PLACEHOLDER, ), # reads model.precision; no extra kwarg needed - # nvtx=L(NVTXCallback)(synchronize=True), ) diff --git a/cosmos_framework/configs/base/vlm/defaults/dataloader.py b/cosmos_framework/configs/base/vlm/defaults/dataloader.py deleted file mode 100644 index 36b878d..0000000 --- a/cosmos_framework/configs/base/vlm/defaults/dataloader.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: OpenMDW-1.1 - -from torch.utils.data import DataLoader - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.config_helper import ConfigStore -from cosmos_framework.data.vfm.vlm.collate_fn import custom_collate -from cosmos_framework.data.vfm.vlm.debug_data_qwen import DebugQwenDataset -from cosmos_framework.data.vfm.vlm.dummy_data_qwen import DummyQwenDataset -from cosmos_framework.data.vfm.processors import build_processor_lazy - - -# Debug dataset -def create_debug_dataloader_config_qwen( - num_images, loss_on_completion_only: bool = True, use_dummy_image: bool = False -): - return L(DataLoader)( - dataset=L(DebugQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_images=num_images, - seq_len="${model.config.policy.model_max_length}", - image_token_len="${model.config.policy.qwen_max_video_token_length}", - # use_dummy_image=use_dummy_image, - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def create_dummy_dataloader_config_qwen(): - return L(DataLoader)( - dataset=L(DummyQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_visual_tokens="${model.config.policy.qwen_max_video_token_length}", - total_tokens="${model.config.policy.model_max_length}", - batch_size="${dataloader_train.batch_size}", - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def register_data_debug(): - cs = ConfigStore.instance() 
- for split in ["train", "val"]: - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="debug_image_data_qwen", # This data is from pixtral model output, expected to have low loss ~1.4 - node=create_debug_dataloader_config_qwen(1), - ) - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="dummy_image_data_qwen", - node=create_dummy_dataloader_config_qwen(), - ) - - -def register_data(): - register_data_debug() diff --git a/cosmos_framework/configs/base/vlm/defaults/optimizer.py b/cosmos_framework/configs/base/vlm/defaults/optimizer.py index 6632dc8..538ab7d 100644 --- a/cosmos_framework/configs/base/vlm/defaults/optimizer.py +++ b/cosmos_framework/configs/base/vlm/defaults/optimizer.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 + """Hydra config registrations for VLM optimizer + LR scheduler.""" from typing import Any diff --git a/cosmos_framework/configs/base/vlm/defaults/policy_config.py b/cosmos_framework/configs/base/vlm/defaults/policy_config.py index 307eb92..c203168 100644 --- a/cosmos_framework/configs/base/vlm/defaults/policy_config.py +++ b/cosmos_framework/configs/base/vlm/defaults/policy_config.py @@ -29,7 +29,7 @@ class PolicyConfig: trainable_map: Union[str, None] = None monkey_patch_for_text_only_data: bool = False - # HF attention impl. Default "cosmos" routes through imaginaire.attention + # HF attention impl. Default "cosmos" routes through cosmos_framework.model.attention # (NATTEN/blackwell-fmha on GB200). Override to "flash_attention_2", # "sdpa", or "eager" for fallback. attn_implementation: str = "cosmos" @@ -53,5 +53,6 @@ class VLMModelConfig: ema: EMAConfig = EMAConfig(enabled=False) # Force deterministic kernels in Flash-Attention init (slower; required for - # parity bit-exactness) - deterministic: bool = False + # parity bit-exactness). 
VLM-only knob — consumed by VLMModel.__init__ via + # init_flash_attn_meta. + deterministic: bool = True diff --git a/cosmos_framework/configs/base/vlm/experiment/__init__.py b/cosmos_framework/configs/base/vlm/experiment/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/vlm/experiment/__init__.py +++ b/cosmos_framework/configs/base/vlm/experiment/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/vlm/freeze_config.py b/cosmos_framework/configs/base/vlm/freeze_config.py index 9629782..d161fe8 100644 --- a/cosmos_framework/configs/base/vlm/freeze_config.py +++ b/cosmos_framework/configs/base/vlm/freeze_config.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 + """VLM freeze config (read by ``vlm_model._apply_freeze_config``).""" import attrs diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py new file mode 100644 index 0000000..28a81be --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py new file mode 100644 index 0000000..b34cb81 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
class CenterCrop(Augmentor):
    """Augmentor that center-crops every input key to the configured size."""

    def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
        super().__init__(input_keys, output_keys, args)

    def __call__(self, data_dict: dict) -> dict:
        r"""Center-crop each input entry and record the crop geometry.

        Args:
            data_dict (dict): Input data dict
        Returns:
            data_dict (dict): The same dict with every input key center cropped.
            The crop parameters are stored under ``aug_params["cropping"]`` so
            that later transforms can reuse them.
        """
        assert (self.args is not None) and ("size" in self.args), "Please specify size in args"

        target_w, target_h = obtain_augmentation_size(data_dict, self.args)
        # The source size is read before any entry is overwritten below.
        source_w, source_h = obtain_image_size(data_dict, self.input_keys)

        for name in self.input_keys:
            data_dict[name] = transforms_F.center_crop(data_dict[name], [target_h, target_w])

        # Offsets of the crop window inside the source frame (integer floor division).
        aug_params = data_dict.setdefault("aug_params", dict())
        aug_params["cropping"] = {
            "resize_w": source_w,
            "resize_h": source_h,
            "crop_x0": (source_w - target_w) // 2,
            "crop_y0": (source_h - target_h) // 2,
            "crop_w": target_w,
            "crop_h": target_h,
        }
        return data_dict
class RandomCrop(Augmentor):
    """Augmentor that applies one shared random crop window to every input key."""

    def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
        super().__init__(input_keys, output_keys, args)

    def __call__(self, data_dict: dict) -> dict:
        r"""Performs random crop, falling back to a center crop when sampling fails.

        Args:
            data_dict (dict): Input data dict
        Returns:
            data_dict (dict): Output dict where every input key is cropped with the
            same window. The cropping parameters are saved in
            ``aug_params["cropping"]`` so that other transforms can use them.
        """

        img_size = obtain_augmentation_size(data_dict, self.args)
        width, height = img_size

        orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)

        # Sample the top-left corner once so that all keys share the same window.
        use_center_crop = False
        try:
            crop_x0 = int(torch.randint(0, orig_w - width + 1, size=(1,)).item())
            crop_y0 = int(torch.randint(0, orig_h - height + 1, size=(1,)).item())
        except Exception:
            # Deliberate best-effort: sampling fails when the source is smaller than
            # the requested crop (empty randint range). Fall back to a center crop.
            logging.warning(
                f"Random crop failed. Performing center crop, original_size(wxh): {orig_w}x{orig_h}, random_size(wxh): {width}x{height}"
            )
            use_center_crop = True
            crop_x0 = (orig_w - width) // 2
            crop_y0 = (orig_h - height) // 2

        # We also add the aug params we use. This will be useful for other transforms
        cropping_params = {
            "resize_w": orig_w,
            "resize_h": orig_h,
            "crop_x0": crop_x0,
            "crop_y0": crop_y0,
            "crop_w": width,
            "crop_h": height,
        }

        if "aug_params" not in data_dict:
            data_dict["aug_params"] = dict()

        data_dict["aug_params"]["cropping"] = cropping_params

        # Apply exactly one crop per key. Previously the fallback path center-cropped
        # the inputs and then cropped them a second time with the same offsets, which
        # shifted the content and padded it with zeros.
        for key in self.input_keys:
            if use_center_crop:
                data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width])
            else:
                data_dict[key] = transforms_F.crop(data_dict[key], crop_y0, crop_x0, height, width)
        return data_dict
def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]:
    r"""Look up the augmentation size that matches the sample's aspect ratio.

    The aspect ratio is read from ``data_dict["__url__"].meta.opts`` when that
    entry carries an ``"aspect_ratio"`` option, and from ``data_dict["aspect_ratio"]``
    otherwise. The size table is ``data_dict["_res_size_map"]`` when present
    (e.g. from resolution sampling), else ``augmentor_cfg["size"]``.

    Args:
        data_dict (dict): Input data dict
        augmentor_cfg (dict): Augmentor config
    Returns:
        aug_size (int): Size of augmentation
    """
    ratio_in_url = "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts
    if ratio_in_url:
        ratio_key = data_dict["__url__"].meta.opts["aspect_ratio"]
    else:
        # Non-webdataset format
        ratio_key = data_dict["aspect_ratio"]

    # augmentor_cfg is only consulted when the sample carries no size map of its own.
    size_table = data_dict["_res_size_map"] if "_res_size_map" in data_dict else augmentor_cfg["size"]
    return size_table[ratio_key]
class ReflectionPadding(Augmentor):
    """Augmentor that pads inputs on the right and bottom up to the target size."""

    def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
        super().__init__(input_keys, output_keys, args)

    def __call__(self, data_dict: dict) -> dict:
        r"""Pads every input to the target size, keeping content at the top-left.

        Reflection padding is used when the padding is smaller than the image side;
        otherwise edge padding is used instead.

        Args:
            data_dict (dict): Input data dict
        Returns:
            data_dict (dict): Output dict where images are padded and ``image_size``
            is set to ``[target_h, target_w, orig_h, orig_w]``.
        """

        assert self.args is not None, "Please specify args in augmentation"
        if self.output_keys is None:
            self.output_keys = self.input_keys

        # Obtain image and augmentation sizes
        orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
        target_size = obtain_augmentation_size(data_dict, self.args)

        assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple"
        raw_w, raw_h = target_size
        target_w, target_h = int(raw_w), int(raw_h)

        # One-sided padding (bottom and right only, content stays at top-left)
        pad_left, pad_top = 0, 0
        pad_right = target_w - orig_w
        pad_bottom = target_h - orig_h
        padding_vals = [pad_left, pad_top, pad_right, pad_bottom]

        # Reflection padding is impossible once a pad reaches the image side,
        # so edge padding is used in that case.
        exceeds_image = max(pad_left, pad_right) >= orig_w or max(pad_top, pad_bottom) >= orig_h
        padding_mode = "edge" if exceeds_image else "reflect"

        for src_key, dst_key in zip(self.input_keys, self.output_keys):
            data_dict[dst_key] = transforms_F.pad(data_dict[src_key], padding_vals, padding_mode=padding_mode)
            if dst_key != src_key:
                del data_dict[src_key]

        data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float)

        return data_dict
class ResizeLargestSide(Augmentor):
    """Augmentor that resizes inputs so that the larger side equals the configured size."""

    def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
        super().__init__(input_keys, output_keys, args)

    def __call__(self, data_dict: dict) -> dict:
        r"""Performs resizing to larger side

        Args:
            data_dict (dict): Input data dict
        Returns:
            data_dict (dict): Output dict where images are resized
        """

        if self.output_keys is None:
            self.output_keys = self.input_keys
        assert self.args is not None, "Please specify args in augmentations"

        # Nothing to resize; also keeps the size lookup below from indexing an empty key list.
        if not self.input_keys:
            return data_dict

        # Compute the target size once, before any entry is overwritten or deleted.
        # Doing this inside the loop re-read input_keys[0] after it had been resized,
        # or raised KeyError after it had been deleted in favour of a distinct output key.
        out_size = obtain_augmentation_size(data_dict, self.args)
        assert isinstance(out_size, int), "Arg size in resize should be an integer"
        orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)

        scaling_ratio = min(out_size / orig_w, out_size / orig_h)
        target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)]

        # Key lookup first so that a plain dict ``args`` is honoured; attribute lookup
        # is kept as a fallback for attribute-style config containers.
        if "interpolation" in self.args:
            interpolation = self.args["interpolation"]
        else:
            interpolation = getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC)

        for inp_key, out_key in zip(self.input_keys, self.output_keys):
            data_dict[out_key] = transforms_F.resize(
                data_dict[inp_key],
                size=target_size,
                interpolation=interpolation,
                antialias=True,
            )
            if out_key != inp_key:
                del data_dict[inp_key]
        return data_dict
class ResizeLargestSideAspectPreserving(Augmentor):
    """Augmentor that resizes inputs, preserving aspect ratio, to fit inside the target size."""

    def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
        super().__init__(input_keys, output_keys, args)

    def __call__(self, data_dict: dict) -> dict:
        r"""Performs aspect-ratio preserving resizing.

        The scale factor is ``min(w_target / w_img, h_target / h_img)``, so the
        resized image is no larger than the target in either dimension.

        Args:
            data_dict (dict): Input data dict
        Returns:
            data_dict (dict): Output dict where images are resized
        """

        if self.output_keys is None:
            self.output_keys = self.input_keys
        assert self.args is not None, "Please specify args in augmentations"

        img_size = obtain_augmentation_size(data_dict, self.args)
        assert isinstance(img_size, (tuple, omegaconf.listconfig.ListConfig)), (
            f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}"
        )
        img_w, img_h = img_size

        orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
        scaling_ratio = min((img_w / orig_w), (img_h / orig_h))
        # The +0.5 rounds to the nearest integer; target_size is (height, width).
        target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5))

        assert target_size[0] <= img_h and target_size[1] <= img_w, (
            f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}"
        )

        # Key lookup first (as ResizeSmallestSideAspectPreserving does) so that a plain
        # dict ``args`` is honoured; attribute lookup is kept as a fallback.
        if "interpolation" in self.args:
            interpolation = self.args["interpolation"]
        else:
            interpolation = getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC)

        for inp_key, out_key in zip(self.input_keys, self.output_keys):
            data_dict[out_key] = transforms_F.resize(
                data_dict[inp_key],
                size=target_size,  # type: ignore
                interpolation=interpolation,
                antialias=True,
            )

            if out_key != inp_key:
                del data_dict[inp_key]
        return data_dict
@dataclass(frozen=True)
class ActionAffineNormalization:
    """Affine action normalizer with an optional, optionally masked, forward clamp.

    Forward: ``(raw - offset) / scale``. Inverse: ``normalized * scale + offset``.

    When ``forward_clamp`` is set, the forward result is clamped to that
    ``(lo, hi)`` range, which is lossy. When ``forward_clamp_mask`` is also
    set, only channels whose mask entry is ``True`` are clamped; the remaining
    channels stay plain affine transforms.
    """

    offset: torch.Tensor
    scale: torch.Tensor
    forward_clamp: tuple[float, float] | None = None
    forward_clamp_mask: torch.Tensor | None = None

    def normalize_action(self, action: torch.Tensor) -> torch.Tensor:  # action: [...,D], returns [...,D]
        """Map raw action values into normalized values."""
        # Parameters are moved to the input's device/dtype on every call.
        shift = self.offset.to(device=action.device, dtype=action.dtype)  # [D]
        gain = self.scale.to(device=action.device, dtype=action.dtype)  # [D]
        affine = (action - shift) / gain  # [...,D]
        if self.forward_clamp is None:
            return affine  # [...,D]

        lower, upper = self.forward_clamp
        bounded = affine.clamp(lower, upper)  # [...,D]
        if self.forward_clamp_mask is None:
            return bounded  # [...,D]

        # Clamp only the channels selected by the mask.
        selected = self.forward_clamp_mask.to(device=action.device, dtype=torch.bool)  # [D]
        return torch.where(selected, bounded, affine)  # [...,D]

    def denormalize_action(self, action: torch.Tensor) -> torch.Tensor:  # action: [...,D], returns [...,D]
        """Map normalized action values back to raw values (clamping is not undone)."""
        shift = self.offset.to(device=action.device, dtype=action.dtype)  # [D]
        gain = self.scale.to(device=action.device, dtype=action.dtype)  # [D]
        return action * gain + shift  # [...,D]
def resolve_action_normalization(
    method: ActionNormalizationMethod,
    stats: dict[str, torch.Tensor],
) -> ActionAffineNormalization:
    """Turn configured action stats into an affine normalizer.

    ``"meanstd"`` uses ``mean`` as the offset and ``std`` (floored at 1e-8) as
    the scale, with no clamp. ``"minmax"`` uses ``min``/``max``, and
    ``"quantile"``/``"quantile_rot"`` use ``q01``/``q99``; these map the range
    onto ``[-1, 1]`` and clamp the forward result to it.

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    if method == "meanstd":
        return ActionAffineNormalization(
            offset=stats["mean"],  # [D]
            scale=stats["std"].clamp(min=1e-8),  # [D]
        )

    # (method names, (low stat key, high stat key)) for the range-based methods.
    range_sources = (
        (("minmax",), ("min", "max")),
        (("quantile", "quantile_rot"), ("q01", "q99")),
    )
    for names, (low_key, high_key) in range_sources:
        if method in names:
            lo = stats[low_key]  # [D]
            hi = stats[high_key]  # [D]
            break
    else:
        raise ValueError(f"Unknown normalization method: {method!r}")

    # The range midpoint becomes the offset, the (floored) half-width the scale.
    return ActionAffineNormalization(
        offset=(hi + lo) / 2.0,  # [D]
        scale=(hi - lo).clamp(min=1e-8) / 2.0,  # [D]
        forward_clamp=(-1.0, 1.0),
    )
+ + Pose actions use the shared layout ``[translation(3), rotation(...)]``. + The returned normalizer multiplies translation channels by + ``translation_scale`` and rotation channels by ``rotation_scale`` during + preprocessing, then inverts those factors during postprocessing. + """ + if action_dim < 3: + raise ValueError(f"Pose action_dim must be at least 3, got {action_dim}") + if translation_scale == 0: + raise ValueError("translation_scale must be non-zero") + if rotation_scale == 0: + raise ValueError("rotation_scale must be non-zero") + + offset = torch.zeros(action_dim, dtype=torch.float32) # [D] + scale = torch.ones(action_dim, dtype=torch.float32) # [D] + scale[:3] = 1.0 / float(translation_scale) # [D] + if action_dim > 3: + scale[3:] = 1.0 / float(rotation_scale) # [D] + return ActionAffineNormalization(offset=offset, scale=scale) + + +@dataclass(frozen=True) +class ActionProcessingRecord: + """Per-sample metadata needed to invert Action model-space preprocessing.""" + + raw_action_dim: int + action_normalizer: ActionNormalizer | None + + +def pad_action_to_max_dim( + action: torch.Tensor, max_action_dim: int +) -> torch.Tensor: # action: [T,D], returns [T,D_model] + """Pad action tensor to max_action_dim along the last dimension. + + Args: + action: Action tensor of shape (T, D) where D is the current action dimension. + max_action_dim: Target action dimension to pad to. + + Returns: + Padded action tensor of shape (T, max_action_dim). 
+ """ + if action.shape[-1] > max_action_dim: + raise ValueError(f"Action dimension {action.shape[-1]} is greater than max_action_dim {max_action_dim}") + if action.shape[-1] == max_action_dim: + return action # [T,D_model] + padding_size = max_action_dim - action.shape[-1] + zero_padding = torch.zeros(*action.shape[:-1], padding_size, dtype=action.dtype, device=action.device) # [T,D_pad] + return torch.cat([action, zero_padding], dim=-1) # [T,D_model] + + +def make_batched_action_processing_fields( + record: ActionProcessingRecord, + batch_size: int, + *, + action_channel_masking: bool = True, +) -> dict[str, list[torch.Tensor | ActionProcessingRecord | None]]: + """Build batch-list fields whose action width and inverse record cannot drift apart.""" + raw_action_dim = torch.tensor(record.raw_action_dim, dtype=torch.long) if action_channel_masking else None # [] + return { + "raw_action_dim": [raw_action_dim] * batch_size, + "action_processing_record": [record] * batch_size, + } + + +class ActionProcessor: + """Forward and inverse Action tensor processing for a single sample.""" + + def __init__(self, max_action_dim: int, action_channel_masking: bool = True) -> None: + self.max_action_dim = int(max_action_dim) + self.action_channel_masking = bool(action_channel_masking) + + def preprocess_action( + self, + data_dict: dict[str, Any], + action: torch.Tensor, + *, + action_normalizer: ActionNormalizer | None, + ) -> dict[str, Any]: + """Return a sample with normalized, padded action fields and the inverse record.""" + + raw_action_dim = int(action.shape[-1]) + if action_normalizer is not None: + action = action_normalizer.normalize_action(action) # [T,D] + if int(action.shape[-1]) != raw_action_dim: + raise ValueError( + f"Action normalizer changed action width from {raw_action_dim} to {int(action.shape[-1])}" + ) + + processed_data_dict = dict(data_dict) + processed_data_dict["action"] = pad_action_to_max_dim(action, self.max_action_dim) # [T,D_model] + record = 
ActionProcessingRecord( + raw_action_dim=raw_action_dim, + action_normalizer=action_normalizer, + ) + processed_data_dict["raw_action_dim"] = ( + torch.tensor(record.raw_action_dim, dtype=torch.long) if self.action_channel_masking else None + ) # [] + processed_data_dict["action_processing_record"] = record + return processed_data_dict + + @staticmethod + def _unpad_action(action: torch.Tensor, raw_action_dim: int) -> torch.Tensor: + """Drop model-only padded action channels.""" + if action.shape[-1] < raw_action_dim: + raise ValueError(f"invalid raw_action_dim={raw_action_dim} for action with shape {tuple(action.shape)}") + return action[..., :raw_action_dim].contiguous() # [...,D_raw] + + @staticmethod + def postprocess_action( + action: torch.Tensor, + record: ActionProcessingRecord, + ) -> torch.Tensor: + """Unpad and denormalize a model-space action tensor.""" + action = ActionProcessor._unpad_action(action, record.raw_action_dim) # [...,D_raw] + if record.action_normalizer is not None: + action = record.action_normalizer.denormalize_action(action) # [...,D_raw] + return action # [...,D_raw] + + +def get_action_processing_records(data_batch: dict[str, Any]) -> list[ActionProcessingRecord | None]: + """Read all per-sample processing records from a collated Action batch.""" + records = data_batch.get("action_processing_record") + if records is None: + return [] + if isinstance(records, ActionProcessingRecord): + return [records] + if isinstance(records, list): + for record in records: + if record is not None and not isinstance(record, ActionProcessingRecord): + raise TypeError(f"Unexpected action_processing_record entry type: {type(record).__name__}") + return records + raise TypeError(f"Unexpected action_processing_record type: {type(records).__name__}") diff --git a/cosmos_framework/data/vfm/action/domain_utils.py b/cosmos_framework/data/vfm/action/domain_utils.py index 910cc39..6f433f7 100644 --- a/cosmos_framework/data/vfm/action/domain_utils.py +++ 
b/cosmos_framework/data/vfm/action/domain_utils.py @@ -14,9 +14,12 @@ "bridge_orig_lerobot": 7, "droid_lerobot": 8, "robomind-franka": 8, # Both Droid and RoboMIND-Franka are using robotiq and franka + "embodiment_b": 9, "robomind-franka-dual": 12, "robomind-ur": 13, "agibotworld": 15, + "embodiment_c_gripper": 15, + "embodiment_c_gripper_ext": 15, "fractal": 20, } @@ -24,7 +27,6 @@ EMBODIMENT_TO_RAW_ACTION_DIM: dict[str, int] = { "av": 9, "camera_pose": 9, - "hand_pose": 57, "pusht": 2, "umi": 10, "bridge_orig_lerobot": 10, @@ -32,8 +34,16 @@ "robomind-franka": 10, "robomind-franka-dual": 20, "robomind-ur": 10, + "embodiment_b": 30, "agibotworld": 29, + "embodiment_c_gripper": 29, + "embodiment_c_gripper_ext": 29, "fractal": 10, + # NOTE: ``libero`` (7/10/13 depending on ``rotation_space``) and ``hand_pose`` + # (variable with ``keypoint_option`` and ``rotation_format``) are absent + # because their raw width is set per-dataset at construction time. Inference + # in inverse_dynamics/policy modes is not supported for these domains until + # canonical widths are added here. } @@ -46,3 +56,20 @@ def get_domain_id(embodiment_type: str) -> int: f"Available embodiments: {sorted(EMBODIMENT_TO_DOMAIN_ID.keys())}" ) return EMBODIMENT_TO_DOMAIN_ID[key] + + +def get_action_dim(embodiment_type: str) -> int: + """Get the raw action dimension for a given embodiment type.""" + key = embodiment_type.lower().strip() + if key not in EMBODIMENT_TO_RAW_ACTION_DIM: + raise KeyError( + f"Unknown embodiment type: {embodiment_type!r}. 
" + f"Available embodiments: {sorted(EMBODIMENT_TO_RAW_ACTION_DIM.keys())}" + ) + return EMBODIMENT_TO_RAW_ACTION_DIM[key] + + +def is_valid_domain_name(embodiment_type: str) -> bool: + """Check if the given embodiment type is recognized.""" + key = embodiment_type.lower().strip() + return key in EMBODIMENT_TO_RAW_ACTION_DIM diff --git a/cosmos_framework/data/vfm/action/json_formatter.py b/cosmos_framework/data/vfm/action/json_formatter.py index b511e93..201a76e 100644 --- a/cosmos_framework/data/vfm/action/json_formatter.py +++ b/cosmos_framework/data/vfm/action/json_formatter.py @@ -7,9 +7,9 @@ import torch +from cosmos_framework.utils import log from cosmos_framework.data.vfm.action.viewpoint_utils import DEFAULT_VIEWPOINT_TEMPLATES from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log def _should_append_idle_frame_info(mode: object) -> bool: diff --git a/cosmos_framework/data/vfm/action/pose_utils.py b/cosmos_framework/data/vfm/action/pose_utils.py index 12ea51b..d6d26a4 100644 --- a/cosmos_framework/data/vfm/action/pose_utils.py +++ b/cosmos_framework/data/vfm/action/pose_utils.py @@ -298,24 +298,15 @@ def build_abs_pose_from_components( def _delta_transform_to_pose_vector( delta_T: np.ndarray, rotation_output_format: RotationConvention, - translation_scale: float = 1.0, - rotation_scale: float = 1.0, ) -> np.ndarray: """Encode a relative transform as an action vector. The shared action-vector layout is always ``[translation(3), rotation(...)]``. - The translation block is multiplied by ``translation_scale`` before concatenation, - and the rotation block is multiplied by ``rotation_scale``. Args: delta_T: Relative transform of shape ``(4, 4)``. rotation_output_format: Concrete convention used for the output rotation block. - translation_scale: Scalar multiplier applied to the translation block. - rotation_scale: Scalar multiplier applied to the rotation block. 
Used to - match the loss scale of the rotation block to the translation block. - The decoder must divide by the same factor before reconstructing the - rotation matrix. Returns: A ``float32`` action vector whose first three values are translation and @@ -325,12 +316,11 @@ def _delta_transform_to_pose_vector( if delta_np.shape != (4, 4): raise ValueError(f"delta_T must have shape (4, 4), got {delta_np.shape}") - translation = delta_np[:3, 3] * translation_scale + translation = delta_np[:3, 3] rotation = np.asarray( convert_rotation(delta_np[:3, :3], input_format="matrix", output_format=rotation_output_format), dtype=np.float32, ) - rotation = rotation * rotation_scale return np.concatenate([translation, rotation]).astype(np.float32) @@ -344,19 +334,19 @@ def _pose_vector_to_delta_transform( """Decode an action vector back into a relative homogeneous transform. This is the inverse of `_delta_transform_to_pose_vector()` when the same - rotation convention and scale are used. + rotation convention is used. Scale arguments are provided for callers that + need to decode model-space pose actions before action-normalizer + denormalization has been applied. Args: pose_vector: Relative-pose action vector with layout ``[translation(3), rotation(...)]``. rotation_input_format: Concrete convention used by the rotation block. - translation_scale: Scalar used to undo the translation scaling applied during - encoding. + translation_scale: Scalar used to undo translation scaling in the input + vector. normalize_rotation: Whether to project the decoded rotation to a valid matrix before assembling the transform. - rotation_scale: Scalar used to undo the rotation scaling applied during - encoding. Must match the value used by - `_delta_transform_to_pose_vector()`. + rotation_scale: Scalar used to undo rotation scaling in the input vector. 
Returns: A relative homogeneous transform with shape ``(4, 4)`` and dtype @@ -440,8 +430,6 @@ def pose_abs_to_rel( poses_abs: np.ndarray, rotation_format: RotationConvention = "rot9d", pose_convention: PoseConvention = "backward_framewise", - translation_scale: float = 1.0, - rotation_scale: float = 1.0, ) -> np.ndarray: """Convert an absolute-pose trajectory into relative-pose action vectors. @@ -454,12 +442,6 @@ def pose_abs_to_rel( pose_convention: Pose convention: - ``backward_framewise``: ``delta_T = T_i^{-1} @ T_{i+1}`` - ``backward_anchored``: ``delta_T = T_0^{-1} @ T_{i+1}`` - translation_scale: Scalar multiplier applied to the translation block of each - encoded action vector. - rotation_scale: Scalar multiplier applied to the rotation block of each - encoded action vector. Use this to match the loss scale of rotation - and translation. `pose_rel_to_abs()` must be called with the same - value to invert the scaling. Returns: An array of shape ``(T - 1, D)`` where ``D = 3 + rotation_dim``. @@ -481,8 +463,6 @@ def pose_abs_to_rel( _delta_transform_to_pose_vector( delta_T, rotation_output_format=rotation_format, - translation_scale=translation_scale, - rotation_scale=rotation_scale, ) ) @@ -510,10 +490,12 @@ def pose_rel_to_abs( identity transform is used. normalize_rotation: Whether to project decoded rotations onto ``SO(3)`` before composing them back into the trajectory. - translation_scale: Scalar used to undo the translation scaling applied during - `pose_abs_to_rel()`. - rotation_scale: Scalar used to undo the rotation scaling applied during - `pose_abs_to_rel()`. Must match the value passed there. + translation_scale: Scalar used to undo translation scaling in + ``poses_rel``. Prefer denormalizing with the dataset action + normalizer before calling this function. + rotation_scale: Scalar used to undo rotation scaling in ``poses_rel``. + Prefer denormalizing with the dataset action normalizer before + calling this function. 
Returns: Absolute poses with shape ``(T, 4, 4)`` where ``T = len(poses_rel) + 1``. diff --git a/cosmos_framework/data/vfm/action/pose_utils_test.py b/cosmos_framework/data/vfm/action/pose_utils_test.py index 93f9791..3cdc20b 100644 --- a/cosmos_framework/data/vfm/action/pose_utils_test.py +++ b/cosmos_framework/data/vfm/action/pose_utils_test.py @@ -196,14 +196,12 @@ def test_pose_abs_to_rel_roundtrips_through_pose_rel_to_abs( poses_abs, rotation_format=rotation_format, pose_convention=pose_convention, - translation_scale=2.5, ) reconstructed = pose_rel_to_abs( poses_rel, rotation_format=rotation_format, pose_convention=pose_convention, initial_pose=poses_abs[0], - translation_scale=2.5, ) np.testing.assert_allclose(reconstructed, poses_abs, atol=1e-5) diff --git a/cosmos_framework/data/vfm/action/transforms.py b/cosmos_framework/data/vfm/action/transforms.py index d17b141..3462f1e 100644 --- a/cosmos_framework/data/vfm/action/transforms.py +++ b/cosmos_framework/data/vfm/action/transforms.py @@ -19,6 +19,11 @@ import torch import torchvision.transforms.functional as transforms_F +from cosmos_framework.utils import log +from cosmos_framework.data.vfm.action.action_processing import ( + ActionNormalizer, + ActionProcessor, +) from cosmos_framework.data.vfm.action.json_formatter import ActionPromptJsonFormatter from cosmos_framework.data.vfm.action.viewpoint_utils import ViewpointTextInfo from cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps import DurationFPSTextTimeStamps @@ -27,7 +32,6 @@ from cosmos_framework.data.vfm.augmentors.text_tokenizer import TextTokenizerTransform from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log from cosmos_framework.utils.vfm.data_utils import get_vision_data_resolution @@ -36,28 +40,6 @@ def _should_append_idle_frame_info(mode: object) -> bool: return mode != "inverse_dynamics" -def 
pad_action_to_max_dim(action: torch.Tensor, max_action_dim: int) -> torch.Tensor: - """Pad action tensor to max_action_dim along the last dimension. - - Args: - action: Action tensor of shape (T, D) where D is the current action dimension. - max_action_dim: Target action dimension to pad to. - - Returns: - Padded action tensor of shape (T, max_action_dim). - """ - if action.shape[-1] > max_action_dim: - raise ValueError(f"Action dimension {action.shape[-1]} is greater than max_action_dim {max_action_dim}") - elif action.shape[-1] == max_action_dim: - return action - else: - padding_size = max_action_dim - action.shape[-1] - zero_padding = torch.zeros( - *action.shape[:-1], padding_size, dtype=action.dtype, device=action.device - ) # [T,padding_size] - return torch.cat([action, zero_padding], dim=-1) # [T,max_action_dim] - - def find_closest_target_size(h: int, w: int, resolution: str | int) -> tuple[int, int]: """Find the closest predefined target size for a given input resolution. @@ -205,7 +187,7 @@ def reflection_pad_to_target( def remove_reflection_padding( tensor: torch.Tensor, - image_size: torch.Tensor, + image_size: torch.Tensor | list[torch.Tensor] | None, ) -> torch.Tensor: """Remove reflection padding added by :func:`reflection_pad_to_target`. @@ -215,17 +197,30 @@ def remove_reflection_padding( tensor: Tensor whose last two dimensions are the padded spatial dims. Supports any leading dimensions, e.g. ``(C, T, H, W)`` or ``(C, H, W)``. - image_size: 1-D tensor of shape ``(4,)`` containing - ``[target_h, target_w, orig_h_resized, orig_w_resized]`` where - ``orig_h/w_resized`` is the original spatial size after - aspect-preserving resize (i.e. the content region before - padding) — the same convention stored by - :func:`reflection_pad_to_target` and VFM's - ``ReflectionPadding``. + image_size: Spatial metadata using the convention produced by + :func:`reflection_pad_to_target`. 
Accepted forms are ``None`` (no + crop), a tensor with shape ``(4,)`` or ``(1, 4)``, or a non-empty + list whose first element has one of those tensor shapes. The four + values are ``[target_h, target_w, orig_h_resized, + orig_w_resized]``, where ``orig_h/w_resized`` is the original + spatial size after aspect-preserving resize (i.e. the content + region before padding). This matches the convention stored by + :func:`reflection_pad_to_target` and VFM's ``ReflectionPadding``. Returns: Cropped tensor of shape ``(..., orig_h_resized, orig_w_resized)``. """ + if image_size is None: + return tensor + if isinstance(image_size, list): + if not image_size: + raise ValueError("Expected at least one image_size entry") + image_size = image_size[0] # [1,4] or [4] + if image_size.ndim == 2 and image_size.shape[0] == 1: + image_size = image_size[0] # [4] + if image_size.ndim != 1: + raise ValueError(f"Expected image_size shape [4] or [1,4], got {tuple(image_size.shape)}") + target_h = int(image_size[0].item()) target_w = int(image_size[1].item()) orig_h_resized = int(image_size[2].item()) @@ -309,7 +304,6 @@ def build_sequence_plan_from_mode( base_action_length = action_length - num_history_actions if mode == "forward_dynamics": condition_frame_indexes_action = list(range(action_length)) - # This currently assumes that the action length is the same as the video length - 1 # and if action length is the same as the video length, then the first action is the conditioning action elif base_action_length == video_length - 1: @@ -487,6 +481,10 @@ def __init__( self.video_temporal_downsample: int = video_temporal_downsample self.max_action_dim: int = max_action_dim self.action_channel_masking: bool = action_channel_masking + self.action_processor: ActionProcessor = ActionProcessor( + max_action_dim=max_action_dim, + action_channel_masking=action_channel_masking, + ) # --- Spatial resize/padding stage (resolution supplied at call time) --- self.video_resize: VideoResize = VideoResize( @@ 
-557,7 +555,12 @@ def __init__( }, ) - def __call__(self, data_dict: dict, resolution: str | None) -> dict: + def __call__( + self, + data_dict: dict, + resolution: str | None, + action_normalizer: ActionNormalizer | None = None, + ) -> dict: """Apply the transform pipeline to a single data dictionary. Resolution is required at call time and is the only source of truth @@ -576,7 +579,9 @@ def __call__(self, data_dict: dict, resolution: str | None) -> dict: sample is in inverse dynamics mode (if enabled). 7. Tokenize caption text (if enabled). 8. Build a ``SequencePlan`` from the ``"mode"`` key (if present). - 9. If action is needed by the plan, pad ``"action"`` to ``max_action_dim``. + 9. If action is needed by the plan, normalize real channels, pad + ``"action"`` to ``max_action_dim``, and attach + ``"action_processing_record"``. 10. Otherwise, nullify ``"action"`` and ``"domain_id"`` (e.g. in ``"image2video"`` mode). @@ -584,11 +589,14 @@ def __call__(self, data_dict: dict, resolution: str | None) -> dict: data_dict: A sample dictionary as returned by a Action dataset. resolution: Resolution tier key (e.g. ``"256"``, ``"480"``, ``"720"``) for this sample. When ``None``, auto-detected from video dimensions. + action_normalizer: Optional source-provided action normalizer. When + present, only unpadded real action channels are normalized + before model-space channel padding. Returns: The same dictionary, mutated in-place with padded tensors, - ``image_size``, tokenized text IDs, and a - ``"sequence_plan"`` entry added. + ``image_size``, tokenized text IDs, a ``"sequence_plan"`` entry, + and action processing metadata added. 
""" mode = data_dict.get("mode") assert mode is not None, "mode is required" @@ -654,13 +662,17 @@ def __call__(self, data_dict: dict, resolution: str | None) -> dict: if sequence_plan.has_action: assert isinstance(action, torch.Tensor), "action tensor is required when sequence plan has action" - data_dict["raw_action_dim"] = torch.tensor(action.shape[1]) if self.action_channel_masking else None - data_dict["action"] = pad_action_to_max_dim(action, self.max_action_dim) + data_dict = self.action_processor.preprocess_action( + data_dict, + action, + action_normalizer=action_normalizer, + ) else: # Nullify action-related fields when action is not needed so the # collate function can simply stack all non-None actions. data_dict["raw_action_dim"] = None data_dict["action"] = None data_dict["domain_id"] = None + data_dict["action_processing_record"] = None return data_dict diff --git a/cosmos_framework/data/vfm/action/transforms_test.py b/cosmos_framework/data/vfm/action/transforms_test.py index 5f024fb..759ec21 100644 --- a/cosmos_framework/data/vfm/action/transforms_test.py +++ b/cosmos_framework/data/vfm/action/transforms_test.py @@ -9,7 +9,11 @@ import torch from cosmos_framework.data.vfm.action.json_formatter import ActionPromptJsonFormatter -from cosmos_framework.data.vfm.action.transforms import ActionTransformPipeline +from cosmos_framework.data.vfm.action.transforms import ( + ActionTransformPipeline, + reflection_pad_to_target, + remove_reflection_padding, +) from cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps import DurationFPSTextTimeStamps from cosmos_framework.data.vfm.augmentors.resolution_text_info import ResolutionTextInfo @@ -60,6 +64,24 @@ def test_action_prompt_json_formatter_builds_requested_structure() -> None: assert "additional_view_description" not in result +@pytest.mark.L0 +def test_video_padding_round_trips_to_unpadded_region() -> None: + video = torch.arange(3 * 2 * 4 * 5, dtype=torch.float32).reshape(3, 2, 4, 5) # 
[C,T,H,W] + data_dict = {"video": video} + + padded = reflection_pad_to_target( + data_dict, + keys=["video"], + keep_aspect_ratio=True, + target_w=8, + target_h=6, + ) + round_tripped = remove_reflection_padding(padded["video"], padded["image_size"]) # [C,T,H,W] + + assert padded["video"].shape == (3, 2, 6, 8) + torch.testing.assert_close(round_tripped, video) + + @pytest.mark.L0 def test_action_prompt_json_formatter_drops_empty_fields() -> None: formatter = ActionPromptJsonFormatter() diff --git a/cosmos_framework/data/vfm/augmentor_provider.py b/cosmos_framework/data/vfm/augmentor_provider.py index 2ace65c..3e3d785 100644 --- a/cosmos_framework/data/vfm/augmentor_provider.py +++ b/cosmos_framework/data/vfm/augmentor_provider.py @@ -564,6 +564,9 @@ def get_video_augmentor_v3( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) + uniae_pad_frames = kwargs.get("uniae_pad_frames", None) + uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) print("Running video_basic_augmentor_v3...") augmentors = { @@ -577,6 +580,10 @@ def get_video_augmentor_v3( "min_stride": min_stride, "seek_mode": "exact", # Change to "approximate"? 
"dataset_resolution_type": dataset_resolution_type, + "resolution": resolution, + "causal_vae": causal_vae, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( @@ -599,6 +606,9 @@ def get_video_augmentor_v3( "conditioning_config": conditioning_config, "uniform_conditioning": uniform_conditioning, "temporal_compression_factor": temporal_compression_factor, + "resolution": resolution, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ) augmentors.update( @@ -670,7 +680,6 @@ def get_video_augmentor_v3( return augmentors - # Use video_basic_augmentor_v3_json_caption instead. @augmentor_register("video_basic_augmentor_v3_with_audio") def get_video_augmentor_v3_with_audio( @@ -829,6 +838,9 @@ def get_video_augmentor_v3_json_caption( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) + uniae_pad_frames = kwargs.get("uniae_pad_frames", None) + uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) print("Running video_augmentor_v3_json_caption...") augmentors = { @@ -853,9 +865,13 @@ def get_video_augmentor_v3_json_caption( "min_stride": min_stride, "seek_mode": "exact", "dataset_resolution_type": dataset_resolution_type, + "resolution": resolution, "extract_audio": extract_audio, "audio_sample_rate": audio_sample_rate, "emit_placeholder_sound": not extract_audio, + "causal_vae": causal_vae, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( @@ -881,6 +897,9 @@ def get_video_augmentor_v3_json_caption( "conditioning_config": conditioning_config, "uniform_conditioning": uniform_conditioning, "temporal_compression_factor": temporal_compression_factor, + 
"resolution": resolution, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ) augmentors.update( diff --git a/cosmos_framework/data/vfm/augmentors/__init__.py b/cosmos_framework/data/vfm/augmentors/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/data/vfm/augmentors/__init__.py +++ b/cosmos_framework/data/vfm/augmentors/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py index 302715c..34f6116 100644 --- a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py +++ b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py @@ -8,7 +8,7 @@ frames (i.e. the relative-pose delta is close to identity and the gripper command does not change). The upstream dataset is responsible for populating ``data_dict[idle_frames_key]`` via -:func:`projects.cosmos3.vfm.datasets.action.pose_utils.compute_idle_frames`. +:func:`cosmos_framework.data.vfm.action.pose_utils.compute_idle_frames`. Per-field dropout (default 5%) is applied here, matching Pi0.7's approach of independently dropping each metadata component. 
This is complementary to the diff --git a/cosmos_framework/data/vfm/augmentors/image_editing_transform.py b/cosmos_framework/data/vfm/augmentors/image_editing_transform.py index fdaaa40..4af344b 100644 --- a/cosmos_framework/data/vfm/augmentors/image_editing_transform.py +++ b/cosmos_framework/data/vfm/augmentors/image_editing_transform.py @@ -18,6 +18,7 @@ from __future__ import annotations +import json import random import torch @@ -26,13 +27,12 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log -from cosmos_framework.data.vfm.sequence_packing import SequencePlan class ExtractImageEditingConversation(Augmentor): """Extract and validate image editing conversation from standard annotation format. - This augmentor processes the cosmos-interleaved conversation format for image editing: + This augmentor processes cosmos-interleaved conversation data for image editing: - Validates that the conversation has exactly one round (user + assistant) - User message must contain at least one image and text instruction - Assistant message must contain exactly one image (the edited result) @@ -42,6 +42,9 @@ class ExtractImageEditingConversation(Augmentor): - texts: Dict containing "content" with conversation data - mllm_media_list: Dict mapping image keys to PIL images (for understanding) - diffusion_media_list: Dict mapping image keys to PIL images (for diffusion/VAE) + - optional structured instruction key: Dict, JSON string, or JSON bytes containing + text_json.content and gemini_rewrite. When configured, gemini_rewrite is used as + the training prompt and text_json.content is used only to recover image references. 
Output Format (added to data_dict): - source_image: PIL.Image (the input image for editing) @@ -53,10 +56,124 @@ def __init__( self, input_keys: list | None = None, max_round: int = 1, + instruction_key: str = "texts", + conversation_key: str = "texts", + structured_instruction_field: str | None = None, args: dict | None = None, ) -> None: super().__init__(input_keys or [], None, args) - self.max_round = max_round + self.max_round: int = max_round + self.instruction_key: str = instruction_key + self.conversation_key: str = conversation_key + self.structured_instruction_field: str | None = structured_instruction_field + + def _decode_json_text(self, text: str, payload_name: str, sample_key: str) -> dict | None: + try: + payload = json.loads(text) + except json.JSONDecodeError as e: + log.warning( + f"Error decoding {payload_name} JSON: {sample_key}, {str(e)}", + rank0_only=False, + ) + return None + + if not isinstance(payload, dict): + log.warning( + f"Decoded {payload_name} is not a dict: {sample_key}, got {type(payload)}", + rank0_only=False, + ) + return None + return payload + + def _decode_payload(self, payload: object, payload_name: str, sample_key: str) -> dict | None: + if isinstance(payload, dict): + return payload + + if isinstance(payload, str): + return self._decode_json_text(payload, payload_name, sample_key) + + if isinstance(payload, (bytes, bytearray)): + try: + text = bytes(payload).decode("utf-8") + except UnicodeDecodeError as e: + log.warning( + f"Error decoding {payload_name} bytes as UTF-8: {sample_key}, {str(e)}", + rank0_only=False, + ) + return None + return self._decode_json_text(text, payload_name, sample_key) + + log.warning( + f"Unsupported {payload_name} payload type: {sample_key}, got {type(payload)}", + rank0_only=False, + ) + return None + + def _get_instruction_payload(self, data_dict: dict, sample_key: str) -> dict | None: + payload = data_dict.get(self.instruction_key) + if payload is None: + log.warning( + 
f"{self.instruction_key} not found in data_dict: {sample_key}", + rank0_only=False, + ) + return None + return self._decode_payload(payload, self.instruction_key, sample_key) + + def _get_conversation_payload( + self, + data_dict: dict, + instruction_payload: dict, + sample_key: str, + ) -> dict | None: + if self.conversation_key == self.instruction_key: + return instruction_payload + + if self.conversation_key in data_dict: + return self._decode_payload(data_dict[self.conversation_key], self.conversation_key, sample_key) + + nested_payload = instruction_payload.get(self.conversation_key) + if nested_payload is None: + log.warning( + f"{self.conversation_key} not found in {self.instruction_key}: {sample_key}", + rank0_only=False, + ) + return None + return self._decode_payload(nested_payload, f"{self.instruction_key}.{self.conversation_key}", sample_key) + + def _get_structured_instruction(self, instruction_payload: dict, sample_key: str) -> str | None: + if self.structured_instruction_field is None: + return None + + rewrite_error = instruction_payload.get("rewrite_error") + if rewrite_error is not None: + log.warning( + f"Structured instruction rewrite_error is non-null: {sample_key}, {rewrite_error}", + rank0_only=False, + ) + return None + + structured_payload = instruction_payload.get(self.structured_instruction_field) + if not isinstance(structured_payload, dict): + log.warning( + f"{self.structured_instruction_field} missing or not a dict: {sample_key}", + rank0_only=False, + ) + return None + + edit_type = structured_payload.get("edit_type") + structured_instruction = structured_payload.get("structured_instruction") + if not isinstance(edit_type, str) or not edit_type: + log.warning(f"Structured instruction edit_type missing: {sample_key}", rank0_only=False) + return None + if not isinstance(structured_instruction, dict) or not structured_instruction: + log.warning(f"Structured instruction body missing: {sample_key}", rank0_only=False) + return None + + 
prompt = { + "edit_type": edit_type, + "structured_instruction": structured_instruction, + } + return json.dumps(prompt, ensure_ascii=False) def __call__(self, data_dict: dict) -> dict | None: """Extract image editing conversation. @@ -69,23 +186,30 @@ def __call__(self, data_dict: dict) -> dict | None: or None if the data is invalid. """ # Validate required keys - for required_key in ["mllm_media_list", "diffusion_media_list", "texts"]: + sample_key = data_dict.get("__key__", "unknown") + for required_key in ["diffusion_media_list", self.instruction_key]: if required_key not in data_dict: log.warning( - f"{required_key} not found in data_dict: {data_dict.get('__key__', 'unknown')}", + f"{required_key} not found in data_dict: {sample_key}", rank0_only=False, ) return None - mllm_media_list = data_dict["mllm_media_list"] diffusion_media_list = data_dict["diffusion_media_list"] + instruction_payload = self._get_instruction_payload(data_dict, sample_key) + if instruction_payload is None: + return None + conversation_payload = self._get_conversation_payload(data_dict, instruction_payload, sample_key) + if conversation_payload is None: + return None + conversation_content_key = f"{self.conversation_key}.content" # Get conversation content try: - texts_content = data_dict["texts"].get("content") + texts_content = conversation_payload.get("content") if texts_content is None: log.warning( - f"texts.content is None: {data_dict.get('__key__', 'unknown')}", + f"{conversation_content_key} is None: {sample_key}", rank0_only=False, ) return None @@ -99,13 +223,13 @@ def __call__(self, data_dict: dict) -> dict | None: selected_conversations = texts_content else: log.warning( - f"Unexpected texts.content format: {data_dict.get('__key__', 'unknown')}", + f"Unexpected {conversation_content_key} format: {sample_key}", rank0_only=False, ) return None except Exception as e: log.warning( - f"Error accessing texts.content: {data_dict.get('__key__', 'unknown')}, {str(e)}", + f"Error 
accessing {conversation_content_key}: {sample_key}, {str(e)}", rank0_only=False, ) return None @@ -115,15 +239,14 @@ def __call__(self, data_dict: dict) -> dict | None: if len(selected_conversations) > 2: log.warning( f"Multi-round conversation found ({len(selected_conversations)} messages), " - f"keeping only first round: {data_dict.get('__key__', 'unknown')}", + f"keeping only first round: {sample_key}", rank0_only=False, ) selected_conversations = selected_conversations[:2] if len(selected_conversations) < 2: log.warning( - f"Expected at least 2 messages (user + assistant), got {len(selected_conversations)}: " - f"{data_dict.get('__key__', 'unknown')}", + f"Expected at least 2 messages (user + assistant), got {len(selected_conversations)}: {sample_key}", rank0_only=False, ) return None @@ -134,14 +257,14 @@ def __call__(self, data_dict: dict) -> dict | None: if user_msg.get("role") != "user": log.warning( - f"First message role is not 'user': {data_dict.get('__key__', 'unknown')}", + f"First message role is not 'user': {sample_key}", rank0_only=False, ) return None if assistant_msg.get("role") != "assistant": log.warning( - f"Second message role is not 'assistant': {data_dict.get('__key__', 'unknown')}", + f"Second message role is not 'assistant': {sample_key}", rank0_only=False, ) return None @@ -167,24 +290,29 @@ def __call__(self, data_dict: dict) -> dict | None: if user_image_key is None: log.warning( - f"No image found in user message: {data_dict.get('__key__', 'unknown')}", + f"No image found in user message: {sample_key}", rank0_only=False, ) return None - editing_instruction = " ".join(user_text_parts).strip() - if not editing_instruction: - log.warning( - f"No text instruction found in user message: {data_dict.get('__key__', 'unknown')}", - rank0_only=False, - ) - return None + if self.structured_instruction_field is None: + editing_instruction = " ".join(user_text_parts).strip() + if not editing_instruction: + log.warning( + f"No text instruction found 
in user message: {sample_key}", + rank0_only=False, + ) + return None + else: + editing_instruction = self._get_structured_instruction(instruction_payload, sample_key) + if editing_instruction is None: + return None # Extract assistant content: must have exactly one image assistant_content = assistant_msg.get("content", []) if isinstance(assistant_content, str): log.warning( - f"Assistant content is text-only (no image): {data_dict.get('__key__', 'unknown')}", + f"Assistant content is text-only (no image): {sample_key}", rank0_only=False, ) return None @@ -199,7 +327,7 @@ def __call__(self, data_dict: dict) -> dict | None: if assistant_image_key is None: log.warning( - f"No image found in assistant message: {data_dict.get('__key__', 'unknown')}", + f"No image found in assistant message: {sample_key}", rank0_only=False, ) return None @@ -208,7 +336,7 @@ def __call__(self, data_dict: dict) -> dict | None: for media_key in [user_image_key, assistant_image_key]: if media_key not in diffusion_media_list: log.warning( - f"Image {media_key} not found in diffusion_media_list: {data_dict.get('__key__', 'unknown')}", + f"Image {media_key} not found in diffusion_media_list: {sample_key}", rank0_only=False, ) return None @@ -225,7 +353,7 @@ def __call__(self, data_dict: dict) -> dict | None: if source_image is None or target_image is None: log.warning( - f"Source or target image is None: {data_dict.get('__key__', 'unknown')}", + f"Source or target image is None: {sample_key}", rank0_only=False, ) return None @@ -329,6 +457,8 @@ def __call__(self, data_dict: dict) -> dict | None: # by GenerationDataClean.num_vision_items_per_sample (set in get_data_and_condition). # In pack_input_sequence, all items except the last are fully conditioned; # the last item uses condition_frame_indexes_vision ([] = fully generated). 
+ from cosmos_framework.data.vfm.sequence_packing import SequencePlan + data_dict["sequence_plan"] = SequencePlan( has_text=True, has_vision=True, diff --git a/cosmos_framework/data/vfm/augmentors/image_editing_transform_test.py b/cosmos_framework/data/vfm/augmentors/image_editing_transform_test.py new file mode 100644 index 0000000..849ad00 --- /dev/null +++ b/cosmos_framework/data/vfm/augmentors/image_editing_transform_test.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import json + +import pytest +from PIL import Image + +from cosmos_framework.data.vfm.augmentors.image_editing_transform import ExtractImageEditingConversation + +_STRUCTURED_KEY = "edit_schema_all_inputs_qwen3-vl-235b-a22b-instruct" + + +def _conversation(instruction: str = "Make the cup red") -> list[list[dict]]: + return [ + [ + { + "role": "user", + "content": [ + {"type": "image", "image": "image_0"}, + {"type": "text", "text": instruction}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "image", "image": "image_1"}, + ], + }, + ] + ] + + +def _media_data() -> tuple[Image.Image, Image.Image, dict[str, Image.Image]]: + source_image = Image.new("RGB", (16, 16), color="blue") + target_image = Image.new("RGB", (16, 16), color="red") + media_list = { + "image_0": source_image, + "image_1": target_image, + } + return source_image, target_image, media_list + + +def _base_data_dict() -> tuple[dict, Image.Image, Image.Image]: + source_image, target_image, media_list = _media_data() + data_dict = { + "__key__": "sample_000001", + "mllm_media_list": media_list, + "diffusion_media_list": media_list, + } + return data_dict, source_image, target_image + + +def _structured_payload() -> dict: + return { + "rewrite_error": None, + "gemini_rewrite": { + "edit_type": "adjust", + "structured_instruction": { + "target_object": "cup", + "attribute_type": "color", + "desired_value": 
"red", + }, + }, + "text_json": { + "content": _conversation("Original dense instruction"), + }, + "original_instruction": "Original dense instruction", + } + + +@pytest.mark.L0 +@pytest.mark.CPU +def test_extract_image_editing_conversation_keeps_texts_behavior() -> None: + data_dict, source_image, target_image = _base_data_dict() + data_dict["texts"] = {"content": _conversation("Make the cup red")} + + result = ExtractImageEditingConversation()(data_dict) + + assert result is not None + assert result["source_image"] is source_image + assert result["target_image"] is target_image + assert result["editing_instruction"] == "Make the cup red" + + +@pytest.mark.L0 +@pytest.mark.CPU +def test_extract_structured_dict_payload_uses_gemini_rewrite() -> None: + data_dict, source_image, target_image = _base_data_dict() + payload = _structured_payload() + data_dict[_STRUCTURED_KEY] = payload + + result = ExtractImageEditingConversation( + instruction_key=_STRUCTURED_KEY, + conversation_key="text_json", + structured_instruction_field="gemini_rewrite", + )(data_dict) + + expected_instruction = json.dumps( + { + "edit_type": payload["gemini_rewrite"]["edit_type"], + "structured_instruction": payload["gemini_rewrite"]["structured_instruction"], + }, + ensure_ascii=False, + ) + assert result is not None + assert result["source_image"] is source_image + assert result["target_image"] is target_image + assert result["editing_instruction"] == expected_instruction + + +@pytest.mark.L0 +@pytest.mark.CPU +@pytest.mark.parametrize("encode_as_bytes", [False, True]) +def test_extract_structured_json_payload_uses_gemini_rewrite(encode_as_bytes: bool) -> None: + data_dict, _, _ = _base_data_dict() + payload = _structured_payload() + payload_json = json.dumps(payload, ensure_ascii=False) + data_dict[_STRUCTURED_KEY] = payload_json.encode("utf-8") if encode_as_bytes else payload_json + + result = ExtractImageEditingConversation( + instruction_key=_STRUCTURED_KEY, + conversation_key="text_json", 
+ structured_instruction_field="gemini_rewrite", + )(data_dict) + + expected_instruction = json.dumps( + { + "edit_type": payload["gemini_rewrite"]["edit_type"], + "structured_instruction": payload["gemini_rewrite"]["structured_instruction"], + }, + ensure_ascii=False, + ) + assert result is not None + assert result["editing_instruction"] == expected_instruction + + +@pytest.mark.L0 +@pytest.mark.CPU +@pytest.mark.parametrize( + "payload_update", + [ + {"gemini_rewrite": None}, + {"rewrite_error": "failed to rewrite"}, + ], +) +def test_extract_structured_invalid_payload_returns_none(payload_update: dict) -> None: + data_dict, _, _ = _base_data_dict() + payload = _structured_payload() + payload.update(payload_update) + data_dict[_STRUCTURED_KEY] = payload + + result = ExtractImageEditingConversation( + instruction_key=_STRUCTURED_KEY, + conversation_key="text_json", + structured_instruction_field="gemini_rewrite", + )(data_dict) + + assert result is None diff --git a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py index aea759e..41e4e4c 100644 --- a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py @@ -414,7 +414,7 @@ def __call__(self, data_dict: dict) -> dict | None: ) # [C,T,H,W] num_multiplier = (end_frame - start_frame) / self.num_frames - + # NOTE: matches legacy VideoParsing.__call__ output keys exactly. 
Do NOT add # variable-length fields like ``frame_indices`` here -- ``video_flatten_keys`` in # ``get_video_transfer_augmentor`` lists ``frame_indices``, and surfacing a # per-sample list there would crash ``custom_collate_fn`` (default_collate requires diff --git a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py index 54d47b8..aa9eb21 100644 --- a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py +++ b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py @@ -27,14 +27,12 @@ def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert token length to pixels based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return token_length * merged_patch_size**2 * temporal_patch_size def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert pixels to token length based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return pixels // merged_patch_size**2 // temporal_patch_size diff --git a/cosmos_framework/data/vfm/augmentors/sequence_plan.py b/cosmos_framework/data/vfm/augmentors/sequence_plan.py index 9f0500a..5a2202f 100644 --- a/cosmos_framework/data/vfm/augmentors/sequence_plan.py +++ b/cosmos_framework/data/vfm/augmentors/sequence_plan.py @@ -7,15 +7,22 @@ - weighted dict (``conditioning_config``): explicit frame-count → probability pairs - uniform (``uniform_conditioning=True``): k ~ Uniform{0, T_latent-1}, where T_latent is computed from the actual video length using the VAE temporal compression factor + or UniAE chunking parameters when provided. 
""" import random +from collections.abc import Mapping from typing import Optional import torch from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.data.vfm.sequence_packing import SequencePlan +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + get_uniae_chunk_frames, + get_uniae_latent_num_frames, + normalize_uniae_chunk_frames, +) class SequencePlanAugmentor(Augmentor): @@ -37,6 +44,11 @@ class SequencePlanAugmentor(Augmentor): must be provided. - "temporal_compression_factor" (int, default 4): VAE temporal compression factor used to convert pixel frame count N to T_latent = 1 + (N-1) // tcf. + - "uniae_chunk_frames" / "uniae_pad_frames" (optional): When provided, + use UniAE's non-causal first-frame plus padded-chunk latent count. + ``uniae_chunk_frames`` may be a scalar or a resolution-keyed mapping. + - "resolution" (str, optional): Target dataset resolution key. Preferred over + the current tensor shape when selecting a resolution-keyed UniAE chunk. 
""" def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: @@ -48,6 +60,9 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.conditioning_config = args.get("conditioning_config") self.uniform_conditioning = args.get("uniform_conditioning", False) self.temporal_compression_factor = args.get("temporal_compression_factor", 4) + self.target_resolution_key = None if args.get("resolution") is None else str(args["resolution"]) + self.uniae_pad_frames = None if args.get("uniae_pad_frames") is None else int(args["uniae_pad_frames"]) + self.uniae_chunk_frames = self._normalize_uniae_chunk_frames(args.get("uniae_chunk_frames")) if self.conditioning_config is None and not self.uniform_conditioning: raise ValueError("args must provide 'conditioning_config' or set 'uniform_conditioning=True'") @@ -70,6 +85,43 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O else: self.normalized_config = {0: 1.0} + def _normalize_uniae_chunk_frames( + self, uniae_chunk_frames: int | Mapping[str, int] | None + ) -> int | dict[str, int] | None: + return normalize_uniae_chunk_frames( + uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.temporal_compression_factor, + ) + + def _get_uniae_chunk_frames(self, spatial_shape: tuple[int, int] | None = None) -> int: + assert self.uniae_chunk_frames is not None + return get_uniae_chunk_frames( + self.uniae_chunk_frames, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + + def _get_latent_frame_count(self, num_frames: int | None, spatial_shape: tuple[int, int] | None = None) -> int: + if num_frames is None: + return 1 + if num_frames < 1: + raise ValueError(f"video must contain at least one frame, got {num_frames}") + if num_frames == 1: + return 1 + if self.uniae_chunk_frames is None: + return 1 + (num_frames - 1) // self.temporal_compression_factor + + 
assert self.uniae_pad_frames is not None + return get_uniae_latent_num_frames( + num_frames, + self.uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.temporal_compression_factor, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + def __call__(self, data_dict: dict) -> dict: """Create a SequencePlan with random conditional frames. @@ -94,15 +146,17 @@ def __call__(self, data_dict: dict) -> dict: # Determine number of frames # Video should be a tensor with shape (C, T, H, W) by this point in the pipeline + spatial_shape = None if isinstance(video, torch.Tensor): assert video.ndim == 4, "video should be a tensor with shape (C, T, H, W)" - num_frames = video.shape[1] + num_frames = video.shape[1] # video: [C,T,H,W] + spatial_shape = (video.shape[2], video.shape[3]) else: # If video is not a tensor or dict, we can't determine the exact number # Use a conservative approach - will be limited by max available frames num_frames = None - T_latent = 1 + (num_frames - 1) // self.temporal_compression_factor if num_frames is not None else 1 + T_latent = self._get_latent_frame_count(num_frames, spatial_shape) # Sample number of conditional frames if self.uniform_conditioning: diff --git a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py index edaef11..d38fae4 100644 --- a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py +++ b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py @@ -8,13 +8,17 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.v3_text_transforms import pad_and_resize from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.data_registration import _CAPTION_EMBEDDING_KEY_MAPPING_IMAGES # For the qwen captions, we have 3 variants: short, medium, long # 
In addition, for synthetic data, we create prompt embeddings as well. # There is quite a bit of entropy in the way prompt data is saved. # Captions are saved as "prompts", while the corresponding embeddings are saved as "original_prompt" # This part will be cleaned after synthetic data is cleaned to be in the same format as real data. +_CAPTION_EMBEDDING_KEY_MAPPING_IMAGES = { + "ai_v3p1": "ai_v3p1", + "qwen2p5_7b_v4": "qwen2p5_7b_v4", + "prompts": "qwen2p5_7b_v4", +} _AVAILABLE_QWEN_CAPTIONS = ["qwen2p5_7b_short", "qwen2p5_7b_medium", "qwen2p5_7b_long"] _AVAILABLE_QWEN3_30B_A3B_CAPTIONS = [ "qwen3_30b_a3b_short", diff --git a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py index 74fb523..a18e8fd 100644 --- a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py +++ b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py @@ -5,7 +5,7 @@ Augmentors for transfer (control-conditioned) image and video generation in the cosmos3 VFM pipeline. Transfer training conditions the model on control signals (edge, blur, depth, or segmentation) -to generate images or videos, aligned with cosmos_framework/transfer2. This module provides: +to generate images or videos, aligned with cosmos/transfer2. This module provides: - **TransferToTrainingFormat**: Converts (control_input, target) into the joint dataloader format with SequencePlan (condition frame + generated frame), for both image and video outputs. 
diff --git a/cosmos_framework/data/vfm/augmentors/video_parsing.py b/cosmos_framework/data/vfm/augmentors/video_parsing.py index cfaa934..25a5580 100644 --- a/cosmos_framework/data/vfm/augmentors/video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/video_parsing.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: OpenMDW-1.1 import random +from collections.abc import Mapping from typing import Optional import numpy as np @@ -15,12 +16,18 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size from cosmos_framework.utils import log from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + align_uniae_num_video_frames, + get_uniae_chunk_frames, + normalize_uniae_chunk_frames, +) # Map dataset_resolution_type to resolution tier key in VIDEO_RES_SIZE_INFO _DATASET_RESOLUTION_TIER: dict[str, str] = {"gt480p": "480", "gt720p": "720", "gt1080p": "1080"} _MIN_FPS = 10 _MAX_FPS = 60 +_UNIAE_TEMPORAL_COMPRESSION_FACTOR = 4 class VideoParsing(Augmentor): @@ -345,7 +352,7 @@ def __call__(self, data_dict: dict) -> dict | None: video_info["video"] = video_frames video_info["num_multiplier"] = num_multiplier # Store the frame skipping multiplier - + # NOTE: Explaining the logic of conditioning FPS calculation: # 1. Our video parser stores the original video FPS of the video. # 2. We have multiple modes of frame selection -- consecutive chunk of frames or subsampled frames. # Here's what we do in each case: @@ -434,6 +441,45 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.dataset_resolution_type = args.get("dataset_resolution_type", "all") self.resolution_tier = _DATASET_RESOLUTION_TIER.get(self.dataset_resolution_type) + # VAE temporal alignment mode. + # causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: align to 4N (non-causal VAE, e.g. 
UniAE) + self.causal_vae = args.get("causal_vae", True) + self.target_resolution_key = None if args.get("resolution") is None else str(args["resolution"]) + self.uniae_pad_frames = None if args.get("uniae_pad_frames") is None else int(args["uniae_pad_frames"]) + self.uniae_chunk_frames = self._normalize_uniae_chunk_frames(args.get("uniae_chunk_frames", None)) + + def _normalize_uniae_chunk_frames( + self, uniae_chunk_frames: int | Mapping[str, int] | None + ) -> int | dict[str, int] | None: + return normalize_uniae_chunk_frames( + uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=_UNIAE_TEMPORAL_COMPRESSION_FACTOR, + missing_pad_message="uniae_pad_frames must be specified if uniae_chunk_frames is specified", + temporal_divisibility_name="UniAE temporal compression factor", + ) + + def _get_uniae_chunk_frames(self, spatial_shape: tuple[int, int] | None = None) -> int: + assert self.uniae_chunk_frames is not None + return get_uniae_chunk_frames( + self.uniae_chunk_frames, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + + def _align_uniae_num_video_frames(self, num_video_frames: int, spatial_shape: tuple[int, int] | None = None) -> int: + assert self.uniae_pad_frames is not None + assert self.uniae_chunk_frames is not None + return align_uniae_num_video_frames( + num_video_frames, + self.uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=_UNIAE_TEMPORAL_COMPRESSION_FACTOR, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + def _sample_stride_with_bias(self, max_stride: int, min_stride: int = 1) -> int: """Sample a stride from [min_stride, max_stride] with bias controlled by low_fps_bias. 
@@ -520,7 +566,6 @@ def _validate_and_probe(self, video: Optional[bytes], meta_dict: dict, data_dict return True def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -553,8 +598,10 @@ def __call__(self, data_dict: dict) -> dict | None: f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" ) transform = [Resize(target_size)] + output_spatial_shape = target_size else: transform = None + output_spatial_shape = (meta_dict["height"], meta_dict["width"]) # Adding try-expcept because some of the data is bad and video decoding call fail. try: @@ -569,11 +616,33 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(0, num_video_frames, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to the active VAE temporal contract. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2). + # causal_vae=False: UniAE chunk/pad alignment if configured; otherwise 4N. num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + # If this is UniAE, we need to align the frame count to the chunk size and padding. + if self.uniae_chunk_frames is not None: + # T is valid when r = (T-1) % effective_chunk_frames satisfies: + # r == 0 (exact multiple of chunks) + # OR r % 4 == target_r where target_r = (-2*pad_frames) % 4 + # Compute minimum trim delta in O(1): + # delta = steps to nearest r' <= r satisfying the condition. 
+ num_video_frames = self._align_uniae_num_video_frames(num_video_frames, output_spatial_shape) + + if num_video_frames == 0: + log.warning( + f"VideoParsingWithFullFrames: video too short for UniAE. " + f"url: {data_dict['__url__']}, key: {data_dict['__key__']}", + rank0_only=False, + ) + return None + else: + N = num_video_frames // 4 + num_video_frames = 4 * N frame_indices = frame_indices[0:num_video_frames] frame_batch = video_decoder.get_frames_at(frame_indices) @@ -698,7 +767,6 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O super().__init__(input_keys, output_keys, args) def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -743,8 +811,10 @@ def __call__(self, data_dict: dict) -> dict | None: f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" ) transform = [Resize(target_size)] + output_spatial_shape = target_size else: transform = None + output_spatial_shape = (meta_dict["height"], meta_dict["width"]) # Adding try-expcept because some of the data is bad and video decoding call fail. try: @@ -772,11 +842,19 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(chunk_start_clamped, chunk_end_clamped, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to the active VAE temporal contract. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2). + # causal_vae=False: UniAE chunk/pad alignment if configured; otherwise 4N. 
num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + if self.uniae_chunk_frames is not None: + num_video_frames = self._align_uniae_num_video_frames(num_video_frames, output_spatial_shape) + else: + N = num_video_frames // 4 + num_video_frames = 4 * N if num_video_frames < 1: log.warning( f"VideoParsingChunkedFrames: chunk too short for stride. " diff --git a/cosmos_framework/data/vfm/augmentors/vlm/__init__.py b/cosmos_framework/data/vfm/augmentors/vlm/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/__init__.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py deleted file mode 100644 index eb029eb..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""Visual-Text Transformations or Augmentations.""" - -import io -from typing import Dict, Optional - -from PIL import Image - -from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.augmentors.vlm.nvlm_sample_loaders_and_part_filters import ( - get_data_class, - get_part_filter, - get_sample_loader, -) - - -class NVLMImageDataUnify(Augmentor): - """ - This augmentor is used to unify the data format of the nvlm data. 
- It will take the raw nvlm data tar and convert it to a dictionary with the following keys: - { - "__url__": str, - "__key__": str, - "data_class": str, - "images": List[PIL.Image.Image], - "text": str, - "words_boxes": Optional[List[List[int]]], - "words_text": Optional[List[str]], - "similarity_matrix": Optional[List[List[float]]], - } - """ - - def __init__( - self, - input_keys: list = ["raw_nvlm"], - output_keys: Optional[list] = [], - args: Optional[dict] = None, - data_path_prefix: list[str] = [ - "cosmos_framework/ar/v2/nvlm/", - ], # prefix of the data in s3 - ) -> None: - super().__init__(input_keys, output_keys, args) - self.data_path_prefix = data_path_prefix - - def convert_image(self, img): - try: - if isinstance(img, bytes): - img = Image.open(io.BytesIO(img)).convert("RGB") - elif isinstance(img, Image.Image): - img = img.convert("RGB") - pass # Image is already in PIL format - elif isinstance(img, list): - for i in range(len(img)): - img[i], success = self.convert_image(img[i]) - if not success: - return Image.new("RGB", (256, 256), (0, 0, 0)), False - return img, True - else: - raise ValueError(f"Invalid image type: {type(img)}") - - success = True - except Exception as e: - log.warning(f"Error processing image: {e}. 
Creating an empty black image.", rank0_only=False) - img = Image.new("RGB", (256, 256), (0, 0, 0)) # Creates a 256x256 black image - success = False - return img, success - - def __call__(self, data_dict: Dict) -> Dict: - url = data_dict["__url__"] - data_path = "/".join(url.path.split("/")[:-1]) # remove the last part of the path - sample_loader = get_sample_loader(data_path) - part_filter = get_part_filter(data_path) - data_class = get_data_class(data_path) - assert sample_loader is not None and part_filter is not None and data_class is not None, ( - f"sample_loader({sample_loader}) or part_filter({part_filter}) or data_class({data_class}) is not found for {data_path}" - ) - - raw = {"__url__": url, "__key__": data_dict["__key__"]} - output = {"__url__": url, "__key__": data_dict["__key__"]} - for k, v in data_dict.items(): - ext = k.split(".")[-1] - if part_filter(ext): - raw[ext] = v - try: - output_converted = sample_loader(raw) - # Here output_converted will be a dictionary with the following keys: - # { - # "__key__": str, - # "image": PIL.Image.Image, - # "images": List[PIL.Image.Image], - # "text": str, - # "words_boxes": Optional - # "words_text": Optional - # "similarity_matrix": Optional - # } - except Exception as e: - log.warning( - f"Error in sample_loader: {e}, sample_loader: {sample_loader}, data_path: {data_path}, raw: {raw.keys()}, original_data_dict: {data_dict.keys()}, __url__: {url}, __key__: {data_dict['__key__']}" - ) - return None - - output.update(output_converted) - if "image" not in output_converted and "images" not in output_converted: - success = False - log.warning(f"image not found in {output_converted.keys()}") - if "image" in output_converted: # Single image case - img, success = self.convert_image(output["image"]) - output["images"] = [img] # What should be the format for the iamges - elif "images" in output_converted: - output["images"] = output_converted["images"] - output["images"], success = 
self.convert_image(output["images"]) - if not success: - log.warning(f"image conversion failed for {data_dict['__key__']} url: {url} | Skip this data") - return None - output["data_class"] = data_class - - return output diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py deleted file mode 100644 index fabe0c3..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py +++ /dev/null @@ -1,2815 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -# Combined Sample Loaders -# Auto-generated script combining all sample_loader.py files (Dont edit this file! Edit the projects/cosmos/reasoning/v1/scripts/create_sample_loader_and_part_filter_file.py instead) - -import io - -import torch -from PIL import Image - -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.vlm.nvlm import data_path_mapping - -# This file was automatically generated by `nvgpt4 data prepare`. - -# import torch - - -def sample_loader_0(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_0(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `energon prepare`. 
- - - -def sample_loader_1(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_1(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_2(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_2(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_3(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_3(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. 
- - - -def sample_loader_4(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_4(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_5(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_5(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_6(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - key = raw["__key__"] - if "docvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = json_item["answer_weights"] - elif "textvqa" in key or "lrv_instruct" in key: - context = json_item["question"] - answers = json_item["answer"] - image = raw["jpg"] - answer_weights = None - elif "stvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = [1.0] * len(json_item["answers"]) - elif "chartqa" in key: - context = json_item["query"] - answers = json_item["label"] - image = raw["png"] - answer_weights = None - elif "screenqa" in key: - image = raw["jpg"] - context = json_item["question"] - answers = json_item["ground_truth"] - answer_weights = [1.0] * len(json_item["ground_truth"]) - elif "HME100K" in key: - image = raw["jpg"] - context = "Please write out the expression of the formula in the image using LaTeX format." - answers = json_item["latex_formula"] - answer_weights = None - else: # scale, textbook - image = raw["jpg"] - context = json_item["question"] - answers = json_item["answer"] - answer_weights = None - - return dict( - __key__=key, - image=image, - context=context, - answers=answers, - answer_weights=answer_weights, - ) - - -def part_filter_6(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "png") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_7(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question_string"], # expected type: str - answers=j["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_7(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_8(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=str(j["answer"]), # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_8(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_9(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"].strip(), # expected type: str - answers=j["gt_answer"].strip(), # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_9(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_10(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_10(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_11(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=j["question"], - answers=j["answer"], - answer_weights=None, - ) - - -def part_filter_11(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_12(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_12(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_13(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - key = raw["__key__"] - - if "geoqa_plus" in key or "tqa" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=json_item["correct_answer_index"], - ) - elif "geometry3k" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=ord(json_item["answer"].lower()) - 97, - ) - else: # science_qa, ai2d - image_key = "png" if "png" in raw else "jpg" - if image_key not in raw: - log.warning(f"Image key {image_key} not found in with raw keys: {raw.keys()}") - return dict( - __key__=raw["__key__"], # science_qa_sample_{idx} - image=raw[image_key], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["choices"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def 
part_filter_13(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_14(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - return dict( - __key__=raw["__key__"], # arxiv_qa_sample_{idx} - image=raw["jpg"], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["options"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def part_filter_14(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_15(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - if json_item["question_type"] == "multi_choice": - correct_choice_idx = json_item["choices"].index(json_item["answer"]) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=correct_choice_idx, - ) - else: - # A temporary hack for non multi-choice samples. - # If correct_choice_idx=-1, we should route it to the VQAWebdataset dataloading method. - # (74.7% free-text questions, 25.3% multi-choice questions) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=[json_item["answer"]], - correct_choice_idx=-1, - ) - - -def part_filter_15(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_16(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_16(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_17(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_17(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_18(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_18(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_19(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_19(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_20(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_20(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_21(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["png"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_21(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_22(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_22(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_23(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_23(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_24(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_24(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_25(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_25(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_26(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_26(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_27(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_27(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_28(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_28(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_29(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_29(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_30(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_30(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_31(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_31(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_32(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_32(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_33(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_33(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_34(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_34(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_35(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_35(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_36(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_36(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_37(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_37(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_38(raw: dict) -> dict: - j = raw["json"] - - if "ReCTs" in raw["__key__"]: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - else: # coco-text-multi, textocr-multi - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_38(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_39(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_39(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_40(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_40(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_41(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_41(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_42(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_42(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_43(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_43(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_44(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_44(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_45(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_45(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_46(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_46(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_47(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_47(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_48(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_48(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_49(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_49(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_50(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_50(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_51(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - for i, turn in enumerate(json_item["conversations"]): - if i > 0 and turn["from"] == "human" and "" in turn["value"]: - turn["value"] = turn["value"].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_51(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_52(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_52(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_53(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - # for i, turn in enumerate(json_item['conversations']): - # if i > 0 and turn['from'] == 'human' and '' in turn['value']: - # turn['value'] = turn['value'].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_53(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_54(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_54(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_55(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_55(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_56(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_56(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_57(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_57(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_58(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_58(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_59(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_59(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_60(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_60(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_61(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_61(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_62(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_62(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_63(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_63(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_64(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_64(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_65(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_65(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_66(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_66(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_67(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_67(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_68(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_68(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_69(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_69(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_70(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_70(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_71(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_71(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_72(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_72(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_73(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_73(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_74(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_74(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_75(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_75(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_76(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_76(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_77(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_77(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_78(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_78(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - - - -def sample_loader_79(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_79(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_80(raw: dict) -> dict: # Note: Images are already decoded to tensors - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=raw["json"]["question"], # expected type: str - answers=raw["json"]["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_80(part: str) -> bool: - - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_81(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_81(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_82(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_82(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_83(raw: dict) -> dict: - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_83(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_84(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_84(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_85(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict(__key__=raw["__key__"], image=raw["jpg"], text=j["text"], words_boxes=j["quad_1k_normalized"]) - - -def part_filter_85(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_86(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_86(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_87(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - quad = j["quad"] - quad = [val for point in quad for val in point] - - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=j["text"], # expected type: str - words_boxes=quad, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_87(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_88(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_88(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_89(raw: dict) -> dict: - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_89(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_90(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_90(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - - - -def sample_loader_91(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_91(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# Dataset -> Sample Loader Mapping -dataset_loader_mapping = { - "coco_train_val_restval": { - "sample_loader": "sample_loader_0", - "part_filter": "part_filter_0", - "data_class": "CaptioningWebdataset", - "data_weight": 0.01, - }, - "extended-sci/data/merged/CoT": { - "sample_loader": "sample_loader_1", - "part_filter": "part_filter_1", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "extended-sci/data/merged/single-choice": { - "sample_loader": "sample_loader_2", - "part_filter": "part_filter_2", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "extended-sci/data/extended-sci-3/CoT": { - "sample_loader": "sample_loader_3", - "part_filter": "part_filter_3", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - "extended-sci/data/extended-sci-3/single-choice": { - "sample_loader": "sample_loader_4", - "part_filter": "part_filter_4", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/SceMQA_processed": { - "sample_loader": "sample_loader_5", - "part_filter": "part_filter_5", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - 
"nvlm/wdai/data/vqa_collection_doc_text_st_chart_scale_textbook_LRV_Screen": { - "sample_loader": "sample_loader_6", - "part_filter": "part_filter_6", - "data_class": "VQAWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/plotqa/processed": { - "sample_loader": "sample_loader_7", - "part_filter": "part_filter_7", - "data_class": "VQAWebdataset", - "data_weight": 0.095, - }, - "nvlm/wdai/data/clevr-math/processed": { - "sample_loader": "sample_loader_8", - "part_filter": "part_filter_8", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/MMC-Instruction/processed": { - "sample_loader": "sample_loader_9", - "part_filter": "part_filter_9", - "data_class": "VQAWebdataset", - "data_weight": 0.07, - }, - "nvlm/wdai/data/ocrvqa/processed": { - "sample_loader": "sample_loader_10", - "part_filter": "part_filter_10", - "data_class": "VQAWebdataset", - "data_weight": 0.06, - }, - "nvlm/wdai/data/dude/processed": { - "sample_loader": "sample_loader_11", - "part_filter": "part_filter_11", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/VisualMRC/processed": { - "sample_loader": "sample_loader_12", - "part_filter": "part_filter_12", - "data_class": "VQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/mcvqa_collection_scienceqa_ai2d_geoqaplus_geometry3k_tqa": { - "sample_loader": "sample_loader_13", - "part_filter": "part_filter_13", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.025, - }, - "nvlm/wdai/data/arxiv_qa/processed": { - "sample_loader": "sample_loader_14", - "part_filter": "part_filter_14", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/tabmwp/processed": { - "sample_loader": "sample_loader_15", - "part_filter": "part_filter_15", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/ocr_vqa_aug/processed": { - "sample_loader": "sample_loader_16", - "part_filter": "part_filter_16", - "data_class": 
"SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/dvqa_full/processed": { - "sample_loader": "sample_loader_17", - "part_filter": "part_filter_17", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/LLaVA-v1.5_shuffle/no_refcoco_vg_ocrvqa": { - "sample_loader": "sample_loader_18", - "part_filter": "part_filter_18", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.085, - }, - "vqa/more_data/infographics_vqa/processed/train": { - "sample_loader": "sample_loader_19", - "part_filter": "part_filter_19", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/sharegpt4o/processed": { - "sample_loader": "sample_loader_20", - "part_filter": "part_filter_20", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/sparse_ocr_data/merged": { - "sample_loader": "sample_loader_21", - "part_filter": "part_filter_21", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.045, - }, - "nvlm/nayeonl/data/blendv4/MetaMathQA/processed/train_text_image": { - "sample_loader": "sample_loader_22", - "part_filter": "part_filter_22", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/nayeonl/data/blendv4/gsm8k/processed/train_text_image": { - "sample_loader": "sample_loader_23", - "part_filter": "part_filter_23", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/docmatix/processed": { - "sample_loader": "sample_loader_24", - "part_filter": "part_filter_24", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.1, - }, - "nvlm/wdai/data/bentham_hw_squad/processed": { - "sample_loader": "sample_loader_25", - "part_filter": "part_filter_25", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/WikiTableQA/processed": { - "sample_loader": 
"sample_loader_26", - "part_filter": "part_filter_26", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/figureqa/processed": { - "sample_loader": "sample_loader_27", - "part_filter": "part_filter_27", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ai2d_combined_processed": { - "sample_loader": "sample_loader_28", - "part_filter": "part_filter_28", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/math_combined_processed": { - "sample_loader": "sample_loader_29", - "part_filter": "part_filter_29", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.035, - }, - "nvlm/wdai/data/llava-onevision/robut_combined_processed": { - "sample_loader": "sample_loader_30", - "part_filter": "part_filter_30", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/llavar_20k_processed": { - "sample_loader": "sample_loader_31", - "part_filter": "part_filter_31", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/tallyqa_processed": { - "sample_loader": "sample_loader_32", - "part_filter": "part_filter_32", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/llava-onevision/ureader_ie_processed": { - "sample_loader": "sample_loader_33", - "part_filter": "part_filter_33", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/visual7w_processed": { - "sample_loader": "sample_loader_34", - "part_filter": "part_filter_34", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/llava-onevision/mavis_math_rule_geo_processed": { - "sample_loader": "sample_loader_35", - "part_filter": "part_filter_35", - "data_class": 
"SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ureader_kg_processed": { - "sample_loader": "sample_loader_36", - "part_filter": "part_filter_36", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/ureader_qa_processed": { - "sample_loader": "sample_loader_37", - "part_filter": "part_filter_37", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/ocr_multi_collection_cocotext_textocr_ReCTs": { - "sample_loader": "sample_loader_38", - "part_filter": "part_filter_38", - "data_class": "OCRWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_500": { - "sample_loader": "sample_loader_39", - "part_filter": "part_filter_39", - "data_class": "OCRWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/llava-onevision/super_clevr_processed": { - "sample_loader": "sample_loader_40", - "part_filter": "part_filter_40", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/icon_qa_processed": { - "sample_loader": "sample_loader_41", - "part_filter": "part_filter_41", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.009, - }, - "nvlm/wdai/data/augmentations/chartqa_aug": { - "sample_loader": "sample_loader_42", - "part_filter": "part_filter_42", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_chartqa": { - "sample_loader": "sample_loader_43", - "part_filter": "part_filter_43", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/gpt_docvqa": { - "sample_loader": "sample_loader_44", - "part_filter": "part_filter_44", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/docvqa_text": { - "sample_loader": "sample_loader_45", 
- "part_filter": "part_filter_45", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/textvqa_text": { - "sample_loader": "sample_loader_46", - "part_filter": "part_filter_46", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/augmentations/i2s-musicsheet": { - "sample_loader": "sample_loader_47", - "part_filter": "part_filter_47", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "nvlm/wdai/data/augmentations/music": { - "sample_loader": "sample_loader_48", - "part_filter": "part_filter_48", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/invoice": { - "sample_loader": "sample_loader_49", - "part_filter": "part_filter_49", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/k12": { - "sample_loader": "sample_loader_50", - "part_filter": "part_filter_50", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.019, - }, - "nvlm/wdai/data/augmentations/MTVQA": { - "sample_loader": "sample_loader_51", - "part_filter": "part_filter_51", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/VisualWebInstruct": { - "sample_loader": "sample_loader_52", - "part_filter": "part_filter_52", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.028, - }, - "nvlm/wdai/data/augmentations/financeqa": { - "sample_loader": "sample_loader_53", - "part_filter": "part_filter_53", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/docreason": { - "sample_loader": "sample_loader_54", - "part_filter": "part_filter_54", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/gpt_mtwi": { - "sample_loader": "sample_loader_55", 
- "part_filter": "part_filter_55", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/geos_gpt": { - "sample_loader": "sample_loader_56", - "part_filter": "part_filter_56", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0001, - }, - "nvlm/wdai/data/augmentations/cauldron_vistext": { - "sample_loader": "sample_loader_57", - "part_filter": "part_filter_57", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/memes": { - "sample_loader": "sample_loader_58", - "part_filter": "part_filter_58", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_roadtext": { - "sample_loader": "sample_loader_59", - "part_filter": "part_filter_59", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0002, - }, - "nvlm/wdai/data/augmentations/indoor_qa": { - "sample_loader": "sample_loader_60", - "part_filter": "part_filter_60", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/colpali": { - "sample_loader": "sample_loader_61", - "part_filter": "part_filter_61", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/pmc_vqa": { - "sample_loader": "sample_loader_62", - "part_filter": "part_filter_62", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/augmentations/pathvqa": { - "sample_loader": "sample_loader_63", - "part_filter": "part_filter_63", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/sciqa": { - "sample_loader": "sample_loader_64", - "part_filter": "part_filter_64", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.027, - }, - "nvlm/wdai/data/augmentations/chinese_meme": { - "sample_loader": "sample_loader_65", 
- "part_filter": "part_filter_65", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/gpt_hiertext": { - "sample_loader": "sample_loader_66", - "part_filter": "part_filter_66", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/augmentations/cauldron_cocoqa": { - "sample_loader": "sample_loader_67", - "part_filter": "part_filter_67", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/cmm-math/processed": { - "sample_loader": "sample_loader_68", - "part_filter": "part_filter_68", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/mmtab/processed": { - "sample_loader": "sample_loader_69", - "part_filter": "part_filter_69", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/simchart9k/processed": { - "sample_loader": "sample_loader_70", - "part_filter": "part_filter_70", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/llava-onevision/mapqa_processed": { - "sample_loader": "sample_loader_71", - "part_filter": "part_filter_71", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/vizwiz_processed": { - "sample_loader": "sample_loader_72", - "part_filter": "part_filter_72", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/gpt_infovqa": { - "sample_loader": "sample_loader_73", - "part_filter": "part_filter_73", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/viquae": { - "sample_loader": "sample_loader_74", - "part_filter": "part_filter_74", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "captioning/ccs_recaptioned/webdataset": { - "sample_loader": 
"sample_loader_75", - "part_filter": "part_filter_75", - "data_class": "CaptioningWebdataset", - "data_weight": 0.2, - }, - "captioning/laion115m-clean": { - "sample_loader": "sample_loader_76", - "part_filter": "part_filter_76", - "data_class": "CaptioningWebdataset", - "data_weight": 0.579, - }, - "nvlm/wdai/data/dvqa_full/processed_pt": { - "sample_loader": "sample_loader_77", - "part_filter": "part_filter_77", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/docmatix/processed_pt": { - "sample_loader": "sample_loader_78", - "part_filter": "part_filter_78", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "vqa/VQAv2/stage1": { - "sample_loader": "sample_loader_91", - "part_filter": "part_filter_91", - "data_class": "VQAWebdataset", - "data_weight": 1.0, - }, - "vqa/Visual_Genome": { - "sample_loader": "sample_loader_80", - "part_filter": "part_filter_80", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_300": { - "sample_loader": "sample_loader_81", - "part_filter": "part_filter_81", - "data_class": "OCRWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/textocr/processed": { - "sample_loader": "sample_loader_82", - "part_filter": "part_filter_82", - "data_class": "OCRWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/coco-text/processed": { - "sample_loader": "sample_loader_83", - "part_filter": "part_filter_83", - "data_class": "OCRWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/ArT/processed": { - "sample_loader": "sample_loader_84", - "part_filter": "part_filter_84", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/ReCTs/processed": { - "sample_loader": "sample_loader_85", - "part_filter": "part_filter_85", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/lsvt/processed": { - "sample_loader": "sample_loader_86", - "part_filter": "part_filter_86", - "data_class": 
"OCRWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/RCTW/processed": { - "sample_loader": "sample_loader_87", - "part_filter": "part_filter_87", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/coco-text/processed_multi": { - "sample_loader": "sample_loader_88", - "part_filter": "part_filter_88", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, - "nvlm/wdai/data/textocr/processed_multi": { - "sample_loader": "sample_loader_89", - "part_filter": "part_filter_89", - "data_class": "OCRWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/ReCTs/processed_multi": { - "sample_loader": "sample_loader_90", - "part_filter": "part_filter_90", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, -} - - -def get_sample_loader(path): - """Returns the correct sample_loader function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("sample_loader")) - - -def get_part_filter(path): - """Returns the correct part_filter function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("part_filter")) - - -def get_data_class(path): - """Returns the correct data_class for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return dataset_loader_mapping[path]["data_class"] diff --git a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py index ec86e66..5b576c4 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py +++ 
b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py @@ -45,7 +45,6 @@ def __call__(self, data_dict: Dict) -> Dict: if isinstance(list_of_conversation[0], list): selected_conversation = random.sample(list_of_conversation, 1)[0] elif isinstance(list_of_conversation[0], dict): - selected_conversation = list_of_conversation else: raise ValueError( @@ -82,7 +81,6 @@ def __call__(self, data_dict: Dict) -> Dict: del data_dict[conversation_key] - # # enforce chat order # self._enforce_text_chat_order(selected_conversation) @@ -91,7 +89,7 @@ def __call__(self, data_dict: Dict) -> Dict: def _enforce_text_chat_order(self, conversation: list) -> None: """ Reorder text content within user messages based on text_chat_order setting. - NOTE: this does NOT work for interleaved data!!!!!! + NOTE (maxzhaoshuol): this does NOT work for interleaved data!!!!!! Args: conversation: List of message dictionaries diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py index edede0c..88d3ac5 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py @@ -97,7 +97,7 @@ def overlay_text( return images, [compute_timestamps(i, fps, processor) for i in range(len(images))] # Try to use DejaVu Sans Mono font for better readability - font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", font_size) + font = ImageFont.truetype("/invalid_dir", font_size) # Process each image processed_images = [] @@ -392,17 +392,15 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: raise 
ValueError("Start and end time are the same for data.") if timestamp_format == "seconds": if random.random() < 0.5: diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py index 0507109..90cc9ab 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py @@ -224,17 +224,15 @@ def augment_user_prompt( elif output_format == "temporal_caption_subject": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: log.warning(f"Start and end time are the same for data. {event}") return None diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py index 584ca1d..7212510 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py @@ -162,14 +162,12 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.5: - start = round(event["start"]) end = round(event["end"]) else: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 - if start == end: # HACK: remove events with start == end + if start == end: raise ValueError("Start and end time are the same for data.") user_prompt = random.choice( [ diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py index 
8df9dd1..812e113 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py @@ -205,10 +205,8 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 else: start = event["start"] diff --git a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py index a0ce29c..f3a9914 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py @@ -158,7 +158,6 @@ def __call__(self, data_dict: Dict) -> Dict: if message["role"] == "user" and isinstance(message["content"], list): total_images += len([content for content in message["content"] if content["type"] == "image"]) total_videos += len([content for content in message["content"] if content["type"] == "video"]) - assert total_videos == 1 or total_videos == 0, "Only one video is supported for now" # url @@ -167,7 +166,6 @@ def __call__(self, data_dict: Dict) -> Dict: # go through each message in the conversation for message in conversation: # for user message, we insert the media - if message["role"] == "user" and isinstance( message["content"], list ): # Otherwise it's text and content is a string @@ -225,7 +223,6 @@ def __call__(self, data_dict: Dict) -> Dict: raw_images.append(image) elif content["type"] == "video": - # as tokenization will NOT upsample the video, we can use a larger value here at the cost of multiple video having 1.5x token length max_total_pixels = token_to_pixels(self.max_video_token_length * 1.5, temporal_patch_size=2) media_key = content["video"] @@ -248,7 +245,6 @@ def __call__(self, data_dict: Dict) -> Dict: return None videos = data_dict["media"][media_key]["videos"] # list of PIL images fps 
= data_dict["media"][media_key]["fps"] - # this is because videos are decoded to be around "max_video_token_length" tokens videos = maybe_subsample_frames( diff --git a/cosmos_framework/data/vfm/joint_dataloader.py b/cosmos_framework/data/vfm/joint_dataloader.py index 56b13e7..ba84ceb 100644 --- a/cosmos_framework/data/vfm/joint_dataloader.py +++ b/cosmos_framework/data/vfm/joint_dataloader.py @@ -1,7 +1,9 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 +import math from collections import deque +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, ClassVar, Dict, Union @@ -12,6 +14,11 @@ from cosmos_framework.utils.lazy_config import instantiate from cosmos_framework.utils import log +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + get_uniae_chunk_frames, + get_uniae_latent_num_frames, + normalize_uniae_chunk_frames, +) _TIMING_KEYS = {"_sample_time", "_aug_time", "_pre_aug_time", "_aug_step_times"} _BATCH_TIMING_KEYS = { @@ -38,6 +45,7 @@ def custom_collate_fn(batch): "sound", "raw_action_dim", "image_size", + "action_processing_record", } # Data keys where a per-sample value of ``None`` is a meaningful signal @@ -57,7 +65,6 @@ def custom_collate_fn(batch): # Handle standard list of samples elem = batch[0] if isinstance(elem, dict): - # Some Action datasets add optional metadata keys (for example # ``additional_view_description`` for concat-view captions) only for a # subset of samples. PyTorch can batch such samples together when @@ -72,6 +79,9 @@ def custom_collate_fn(batch): if key in _TIMING_KEYS: continue values = [d.get(key) for d in batch] + if key == "action_processing_record": + result[key] = values + continue if any(value is None for value in values): # Sparse data keys keep their None placeholders to preserve # 1:1 alignment with sequence_plan. 
Other (optional metadata) @@ -165,6 +175,8 @@ def __init__( prewarm: bool = True, default_lookahead_limit: int = _DEFAULT_LOOKAHEAD_LIMIT, lookahead_limits: Dict[str, int] | None = None, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): """ Initialize the JointDataLoader with multiple datasets. @@ -186,6 +198,8 @@ def __init__( default_lookahead_limit: Packing-loop look-ahead fallback for dataloaders not in ``lookahead_limits``. lookahead_limits: Optional ``{dataset_name: int}`` per-dataloader override. + uniae_chunk_frames: Optional UniAE full chunk size, or resolution-keyed chunk sizes. + uniae_pad_frames: Optional UniAE boundary padding frames per chunk. Example: joint_loader = IterativeJointDataLoader( @@ -211,6 +225,8 @@ def __init__( self.sound_latent_fps = sound_latent_fps self.audio_sample_rate = audio_sample_rate self.default_lookahead_limit = int(default_lookahead_limit) + self.uniae_pad_frames = int(uniae_pad_frames) if uniae_pad_frames is not None else None + self.uniae_chunk_frames = self._normalize_uniae_chunk_frames(uniae_chunk_frames) assert (self.max_sequence_length is None) != (self.max_samples_per_batch is None), ( "Exactly one of max_sequence_length or max_samples_per_batch must be None, but not both." 
@@ -221,6 +237,8 @@ def __init__( assert not unknown, f"lookahead_limits references unknown dataloaders {unknown}; valid: {sorted(dataloaders)}" for dataset_name, dataloader_data in dataloaders.items(): + if dataloader_data is None: + continue assert set(dataloader_data.keys()) == {"dataloader", "ratio"}, f"Invalid config: {dataloader_data}" if dataloader_data["ratio"] <= 0: continue @@ -255,13 +273,42 @@ def __init__( "JointDataLoader: prewarm DISABLED (debug mode); first iteration may incur per-stream cold-load cost" ) + def _normalize_uniae_chunk_frames( + self, uniae_chunk_frames: int | Mapping[str, int] | None + ) -> int | dict[str, int] | None: + return normalize_uniae_chunk_frames( + uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.tokenizer_temporal_compression_factor, + temporal_divisibility_name="tokenizer_temporal_compression_factor", + ) + + def _get_uniae_chunk_frames(self, spatial_shape: tuple[int, int]) -> int: + assert self.uniae_chunk_frames is not None + return get_uniae_chunk_frames(self.uniae_chunk_frames, spatial_shape=spatial_shape) + + def _compute_vision_latent_t_shape(self, T: int, H: int, W: int) -> int: + if T < 1: + raise ValueError(f"Vision media must contain at least one frame, got {T}.") + if T == 1 or self.uniae_chunk_frames is None: + return 1 + (T - 1) // self.tokenizer_temporal_compression_factor + + assert self.uniae_pad_frames is not None + return get_uniae_latent_num_frames( + T, + self.uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.tokenizer_temporal_compression_factor, + spatial_shape=(H, W), + ) + def _prewarm_dataloaders(self) -> None: """Force all dataloader iterators to spawn workers and produce one batch. The first ``next()`` call on an ``InfiniteDataLoader`` iterator triggers ``DataLoader.__iter__()`` which spawns worker processes. 
For action dataloaders using ``multiprocessing_context='spawn'``, each worker must - fully initialise heavy datasets (BridgeOrigLeRobotDataset, EMBODIMENT_A, etc.) + fully initialise heavy datasets (BridgeOrigLeRobotDataset, embodiment_a, etc.) from scratch. If this happens lazily during training, the resulting delay (potentially minutes) causes NCCL collective timeouts when faster ranks enter the forward pass while slower ranks are still loading data. @@ -362,14 +409,13 @@ def _compute_num_tokens_per_sample(self, data_batch: dict) -> int: else: _, T, H, W = media.shape - vae_spatial_downsample = self.tokenizer_spatial_compression_factor * self.patch_spatial - vae_temporal_downsample = self.tokenizer_temporal_compression_factor - - latent_h_shape = H // vae_spatial_downsample - latent_w_shape = W // vae_spatial_downsample - latent_t_shape = 1 + (T - 1) // vae_temporal_downsample + latent_h_shape = H // self.tokenizer_spatial_compression_factor + latent_w_shape = W // self.tokenizer_spatial_compression_factor + patch_h_shape = math.ceil(latent_h_shape / self.patch_spatial) + patch_w_shape = math.ceil(latent_w_shape / self.patch_spatial) + latent_t_shape = self._compute_vision_latent_t_shape(T, H, W) - num_vision_tokens = latent_h_shape * latent_w_shape * latent_t_shape + 2 + num_vision_tokens = patch_h_shape * patch_w_shape * latent_t_shape + 2 num_tokens += num_vision_tokens # Action part: each action time step is 1 token. 
@@ -534,6 +580,8 @@ def __init__( prewarm: bool = True, default_lookahead_limit: int = JointDataLoader._DEFAULT_LOOKAHEAD_LIMIT, lookahead_limits: Dict[str, int] | None = None, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): super().__init__( dataloaders, @@ -547,6 +595,8 @@ def __init__( prewarm=prewarm, default_lookahead_limit=default_lookahead_limit, lookahead_limits=lookahead_limits, + uniae_chunk_frames=uniae_chunk_frames, + uniae_pad_frames=uniae_pad_frames, ) self.seed = seed # Calculate probabilities for random sampling @@ -787,6 +837,8 @@ def __init__( audio_sample_rate: int = 48000, dataset_name: str = "default", lookahead_limit: int = JointDataLoader._DEFAULT_LOOKAHEAD_LIMIT, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): """ Args: @@ -802,6 +854,8 @@ def __init__( audio_sample_rate: Audio sample rate in Hz. dataset_name: Name tag attached to every sample in the output batch. lookahead_limit: Packing-loop look-ahead for the wrapped dataloader. + uniae_chunk_frames: Optional UniAE full chunk size, or resolution-keyed chunk sizes. + uniae_pad_frames: Optional UniAE boundary padding frames per chunk. 
""" wrapped = {dataset_name: {"dataloader": dataloader, "ratio": 1}} super().__init__( @@ -814,6 +868,8 @@ def __init__( sound_latent_fps=sound_latent_fps, audio_sample_rate=audio_sample_rate, lookahead_limits={dataset_name: int(lookahead_limit)}, + uniae_chunk_frames=uniae_chunk_frames, + uniae_pad_frames=uniae_pad_frames, ) def __iter__(self): @@ -905,6 +961,8 @@ def __init__( audio_sample_rate: int = 48000, default_lookahead_limit: int = JointDataLoader._DEFAULT_LOOKAHEAD_LIMIT, lookahead_limits: Dict[str, int] | None = None, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): super().__init__( dataloaders, @@ -917,6 +975,8 @@ def __init__( audio_sample_rate=audio_sample_rate, default_lookahead_limit=default_lookahead_limit, lookahead_limits=lookahead_limits, + uniae_chunk_frames=uniae_chunk_frames, + uniae_pad_frames=uniae_pad_frames, ) # Convert data ratios to probabilities diff --git a/cosmos_framework/data/vfm/packing_iterable_dataset.py b/cosmos_framework/data/vfm/packing_iterable_dataset.py new file mode 100644 index 0000000..ac30f0d --- /dev/null +++ b/cosmos_framework/data/vfm/packing_iterable_dataset.py @@ -0,0 +1,271 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +""" +Abstract base class for pool-based token-budget bin-packing over multiple datasets. + +Extracted from ``cosmos_framework.data.vfm.vlm.joint_dataset_dynamic_batch_webloader`` +so that both the VLM and VFM internal dataloaders can share a single packing implementation. + +Usage +----- +Subclass and implement ``compute_sample_tokens(sample) -> int``. +Optionally override ``collate_batch(samples) -> Any`` for custom collation. 
class Modality(Enum):
    """Modality of a sample; a packed batch only ever mixes samples of one modality."""

    IMAGE = "image"
    VIDEO = "video"
    TEXT = "text"


class PackingIterableDataset(torch.utils.data.IterableDataset, ABC):
    """Pool-based greedy bin-packing IterableDataset.

    Maintains a pool of ``pool_size`` samples and assembles batches by
    greedily selecting candidates that fit within the token budget
    ``max_tokens``.  Subclasses supply two hooks:

    * ``compute_sample_tokens(sample)`` — token cost of one sample (abstract).
    * ``collate_batch(samples)`` — assemble a packed list into a batch
      (default: identity, returns the list unchanged).

    Parameters
    ----------
    datasets_cfg:
        Mapping ``{name: {"dataset": <dataset>, "ratio": <weight>}}``.
        The *dataset* value may be a Hydra lazy config, an already-constructed
        ``IterableDataset``, or a plain ``DataLoader`` (its ``.dataset`` is
        unwrapped automatically).  Entries whose ratio is ``0`` are skipped.
    max_tokens:
        Token budget per batch (padded cost = ``cur_max_len * batch_size``).
    pool_size:
        Number of samples to buffer before selecting a batch.
    max_batch_size:
        Hard cap on items per batch (0 or None = no cap).
    long_threshold:
        Samples with token count ``>= long_threshold`` are emitted as
        singletons regardless of budget.
    batching_strategy:
        ``"prefer_closest"`` (default) or ``"prefer_first"``.  Both pick the
        candidate with the smallest padded cost; they differ only in how ties
        are broken (closest length to the current maximum vs. earliest in the
        pool).
    """

    def __init__(
        self,
        datasets_cfg: dict[str, dict[str, Union[int, object]]],
        max_tokens: int,
        pool_size: int,
        max_batch_size: int,
        long_threshold: int,
        batching_strategy: str = "prefer_closest",
    ):
        super().__init__()

        assert batching_strategy in ("prefer_first", "prefer_closest"), (
            f"batching_strategy must be 'prefer_first' or 'prefer_closest', got {batching_strategy!r}"
        )

        self.max_tokens = max_tokens
        self.pool_size = pool_size
        self.long_threshold = long_threshold
        self.max_batch_size = max_batch_size
        self.batching_strategy = batching_strategy

        self._pool: deque[dict] = deque()
        # Parallel lists: entry i of each describes the i-th active dataset.
        self._dataset_names: list[str] = []
        self._ratios: list[float] = []
        self._datasets: list[torch.utils.data.IterableDataset] = []

        for name, cfg in datasets_cfg.items():
            assert {"ratio", "dataset"} <= cfg.keys(), (
                f"Each entry must have 'dataset' and 'ratio' keys: {name} -> {cfg.keys()}"
            )
            ratio = cfg["ratio"]
            if ratio == 0:
                log.info(f"Skipping dataset {name} with ratio {ratio}")
                continue

            dataset_cfg = cfg["dataset"]
            # Already-built datasets / loaders are used as-is; anything else is
            # treated as a lazy config and instantiated.
            if isinstance(dataset_cfg, (torch.utils.data.IterableDataset, torch.utils.data.DataLoader)):
                ds = dataset_cfg
            else:
                ds = instantiate(dataset_cfg)
            if isinstance(ds, torch.utils.data.DataLoader):
                ds = ds.dataset
            if callable(getattr(ds, "build_dataset", None)):
                ds = ds.build_dataset()

            assert isinstance(ds, torch.utils.data.IterableDataset), (
                f"Expected an IterableDataset, got {type(ds)} for {name}"
            )

            self._dataset_names.append(name)
            self._ratios.append(float(ratio))
            self._datasets.append(ds)
            log.info(f"Added dataset {name} with ratio {ratio}")

        log.info(f"added data: {list(datasets_cfg.keys())}")
        assert len(self._datasets) > 0, "No datasets added"

        # Reported length: sum of the datasets' ``total_images`` attributes, or a
        # very large placeholder when none of them exposes a non-zero value.
        self._data_len: int = sum(int(getattr(ds, "total_images", 0)) for ds in self._datasets)
        if self._data_len == 0:
            self._data_len = 10**12
        self.iterators = [iter(ds) for ds in self._datasets]

    # ------------------------------------------------------------------
    # Abstract / overridable hooks
    # ------------------------------------------------------------------

    @abstractmethod
    def compute_sample_tokens(self, sample: dict) -> int:
        """Return the token cost of one sample for packing budget accounting."""

    def collate_batch(self, samples: list[dict]) -> Any:
        """Assemble a packed list of samples into one batch.

        Default implementation returns the list unchanged (identity).
        Override to pad, stack, or transform samples into tensors.
        """
        return samples

    # ------------------------------------------------------------------
    # PyTorch Dataset API
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        return self._data_len

    def __iter__(self):
        """Yield collated batches forever; source datasets are restarted when exhausted."""
        while True:
            batch = self._best_fit_batch()
            yield self.collate_batch(batch)

    # ------------------------------------------------------------------
    # Internal packing helpers
    # ------------------------------------------------------------------

    def _max_tokens(self, cur_max: int) -> int:
        """Return the batch token budget given the longest sample length ``cur_max``.

        The budget is halved once the longest sample reaches 1000 tokens.
        """
        if cur_max < 1000:
            return self.max_tokens
        return self.max_tokens // 2

    def _get_next_sample(self) -> dict:
        """Draw one sample from a dataset chosen at random, weighted by ratio.

        An exhausted dataset is restarted once.

        Raises:
            RuntimeError: If the dataset yields nothing even after the restart.
        """
        index_id = random.choices(range(len(self.iterators)), weights=self._ratios, k=1)[0]
        try:
            return next(self.iterators[index_id])
        except StopIteration:
            pass

        log.critical(f"dataset {self._dataset_names[index_id]} exhausted")
        self.iterators[index_id] = iter(self._datasets[index_id])
        try:
            return next(self.iterators[index_id])
        except StopIteration:
            # Never let StopIteration escape: this runs inside the ``__iter__``
            # generator, where PEP 479 would turn it into an opaque
            # "generator raised StopIteration" RuntimeError.
            raise RuntimeError(
                f"dataset {self._dataset_names[index_id]} yielded no samples after being restarted"
            ) from None

    def _fill_pool(self):
        """Top the pool up to ``pool_size`` samples."""
        while len(self._pool) < self.pool_size:
            self._pool.append(self._get_next_sample())

    def _padded_cost(self, cur_max: int, k: int) -> int:
        """Token cost of ``k`` samples all padded to length ``cur_max``."""
        return cur_max * k

    def _get_modality(self, sample: dict) -> Modality:
        """Classify a sample by which pixel key it carries (image wins over video)."""
        if "pixel_values" in sample:
            return Modality.IMAGE
        elif "pixel_values_videos" in sample:
            return Modality.VIDEO
        return Modality.TEXT

    def _best_fit_batch(self) -> list[dict]:
        """Build one batch using the configured token-budget strategy."""
        self._fill_pool()
        seed = self._pool.popleft()
        seed_modality = self._get_modality(seed)
        seed_len = self.compute_sample_tokens(seed)

        # Over-long seeds are emitted alone; nothing else could share their budget.
        if seed_len >= self.long_threshold or seed_len >= self._max_tokens(seed_len):
            return [seed]

        chosen = [seed]
        cur_max = seed_len

        # The pool is not refilled while a batch is being built, so a batch can
        # only draw from the samples buffered at the start of this call.
        while self._pool:
            if self.max_batch_size and len(chosen) >= self.max_batch_size:
                break
            best_idx = self._find_best_candidate(cur_max, len(chosen), seed_modality)
            if best_idx is None:
                break
            cand = self._remove_from_pool(best_idx)
            chosen.append(cand)
            cur_max = max(cur_max, self.compute_sample_tokens(cand))

        return chosen

    def _fitting_candidates(self, cur_max: int, num_chosen: int, seed_modality: Modality):
        """Yield ``(pool_index, length, new_padded_cost)`` for every pool sample that fits.

        A sample fits when it has the seed's modality and adding it keeps the
        padded cost of the batch within the budget.
        """
        # NOTE(review): the budget is derived from the current maximum length, not
        # from the maximum after adding the candidate, so a candidate that pushes
        # the batch past the 1000-token mark is still judged against the full
        # budget. Kept as-is; confirm this is intended.
        budget = self._max_tokens(cur_max)
        for idx, cand in enumerate(self._pool):
            if self._get_modality(cand) != seed_modality:
                continue
            length = self.compute_sample_tokens(cand)
            new_tokens = self._padded_cost(max(cur_max, length), num_chosen + 1)
            if new_tokens <= budget:
                yield idx, length, new_tokens

    def _find_best_candidate(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None:
        """Dispatch to the candidate finder selected by ``batching_strategy``."""
        if self.batching_strategy == "prefer_first":
            return self._find_best_candidate_prefer_first(cur_max, num_chosen, seed_modality)
        return self._find_best_candidate_prefer_closest(cur_max, num_chosen, seed_modality)

    def _find_best_candidate_prefer_first(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None:
        """Return the pool index with the lowest padded cost; ties go to the earliest sample."""
        best_idx = None
        best_new_tokens = None
        for idx, _, new_tokens in self._fitting_candidates(cur_max, num_chosen, seed_modality):
            if best_new_tokens is None or new_tokens < best_new_tokens:
                best_new_tokens = new_tokens
                best_idx = idx
        return best_idx

    def _find_best_candidate_prefer_closest(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None:
        """Return the pool index with the lowest padded cost; ties go to the length closest to ``cur_max``."""
        best_idx = None
        best_key = None
        for idx, length, new_tokens in self._fitting_candidates(cur_max, num_chosen, seed_modality):
            # Lexicographic order: padded cost first, then distance to cur_max.
            # Strict "<" keeps the earliest sample when both components tie.
            key = (new_tokens, abs(length - cur_max))
            if best_key is None or key < best_key:
                best_key = key
                best_idx = idx
        return best_idx

    def _remove_from_pool(self, idx: int) -> dict:
        """Remove and return the pool sample at ``idx``, preserving the order of the rest."""
        item = self._pool[idx]
        del self._pool[idx]
        return item
build_processor_lazy( if subdir: local_path = os.path.join(local_path, subdir) return sys.modules[__name__].build_processor(local_path, **kwargs) - return sys.modules[__name__].build_processor(*args, **kwargs) + return sys.modules[__name__].build_processor(*args, **kwargs) \ No newline at end of file diff --git a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py index 767c8ef..077c80e 100644 --- a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py +++ b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py @@ -248,7 +248,6 @@ def __init__( # NemotronVL hardcodes these helper attributes because they are not # discoverable from the HF model config; the values match the upstream # vision-encoder configuration. - # HACK: hardcoded based on the model config. self.min_height_width = 512 self.patch_size = 16 self.temporal_patch_size = 1 @@ -258,7 +257,6 @@ def __init__( def _resolve_pad_id(self): # NemotronVL's tokenizer does not specify a pad_token; reserve # for padding (project convention). 
- return self.processor.tokenizer.convert_tokens_to_ids("") def apply_chat_template( @@ -396,7 +394,7 @@ def add_assistant_tokens_mask(self, tokens): import requests - response = requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg") + response = requests.get("https://invalid_url") img = Image.open(BytesIO(response.content)) # test video diff --git a/cosmos_framework/data/vfm/processors/qwen3vl_processor.py b/cosmos_framework/data/vfm/processors/qwen3vl_processor.py index 030d040..dffa47c 100644 --- a/cosmos_framework/data/vfm/processors/qwen3vl_processor.py +++ b/cosmos_framework/data/vfm/processors/qwen3vl_processor.py @@ -71,7 +71,7 @@ def apply_chat_template( num_video, video_fps, video_total_num_frames, video_frames_indices = maybe_parse_video_content(messages) if num_video > 0: # Here we add the args to avoid the error: - # File "/usr/local/lib/python3.12/dist-packages/transformers/video_processing_utils.py", line 321, in _decode_and_sample_videos + # File "/invalid_dir", line 321, in _decode_and_sample_videos # raise ValueError( # ValueError: Sampling frames from a list of images is not supported! Set `do_sample_frames=False`. 
kwargs["videos_kwargs"] = dict(do_sample_frames=False) @@ -178,7 +178,7 @@ def add_assistant_tokens_mask(self, tokens): "content": [ { "type": "video", - "video": ["https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"] * 4, + "video": ["https://invalid_url"] * 4, "fps": 12, }, {"type": "text", "text": "What is the capital of France?"}, diff --git a/cosmos_framework/data/vfm/sequence_packing.py b/cosmos_framework/data/vfm/sequence_packing.py index d2821cc..1209a2d 100644 --- a/cosmos_framework/data/vfm/sequence_packing.py +++ b/cosmos_framework/data/vfm/sequence_packing.py @@ -685,6 +685,7 @@ def _pack_vision_tokens( enable_fps_modulation: bool = False, base_fps: float = 24.0, temporal_compression_factor: int = 4, + vision_temporal_positions: torch.Tensor | None = None, ) -> int: """Pack vision tokens into the sequence. @@ -701,6 +702,8 @@ def _pack_vision_tokens( enable_fps_modulation: If True, scale temporal position IDs based on video FPS. base_fps: Base FPS for normalization (default 24.0). temporal_compression_factor: VAE temporal compression factor (default 4). + vision_temporal_positions: Optional explicit temporal coordinate per latent + frame, shape ``(T,)``. Used by UniAE to account for kept boundary latents. Returns: Vision split length. 
""" @@ -773,6 +776,8 @@ def _pack_vision_tokens( if packed_seq._use_mrope: # Determine FPS for this vision segment (None disables FPS modulation) effective_fps = vision_fps if enable_fps_modulation else None + if vision_temporal_positions is not None: + vision_temporal_positions = vision_temporal_positions.to(device="cpu", dtype=torch.float32) # [T] vision_mrope_ids, packed_seq._mrope_temporal_offset = get_3d_mrope_ids_vae_tokens( grid_t=latent_t, @@ -783,6 +788,8 @@ def _pack_vision_tokens( fps=effective_fps, base_fps=base_fps, temporal_compression_factor=temporal_compression_factor, + temporal_positions=vision_temporal_positions, + actual_temporal_compression_factor=temporal_compression_factor, ) # vision_mrope_ids: [3,N_vision_tokens] packed_seq.position_ids.append(vision_mrope_ids) else: @@ -850,7 +857,6 @@ def _pack_action_tokens( packed_seq.action.token_shapes.append((action_split_len,)) packed_seq.action.tokens.append(input_action_tokens) - condition_set = {idx for idx in condition_frame_indexes_action if 0 <= idx < action_split_len} assert isinstance(packed_seq.action.condition_mask, list) @@ -917,6 +923,7 @@ def _pack_sound_tokens( enable_fps_modulation: bool = False, base_fps: float = 24.0, sound_fps: float | None = None, + sound_base_temporal_compression_factor: int | None = None, ) -> int: """Pack sound/audio tokens into the sequence. @@ -936,6 +943,8 @@ def _pack_sound_tokens( enable_fps_modulation: If True, scale temporal positions by FPS ratio. base_fps: Base FPS for normalization (default 24.0). sound_fps: Sound latent FPS (e.g., 25.0). Used for FPS-aware m-RoPE positions. + sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. + ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. Returns: Number of sound tokens added. 
@@ -1008,6 +1017,7 @@ def _pack_sound_tokens( fps=effective_fps, base_fps=base_fps, temporal_compression_factor=1, # Sound latent is already at sound_latent_fps (no further compression) + base_temporal_compression_factor=sound_base_temporal_compression_factor, start_frame_offset=0, # Sound[0] aligns with vision frame 0 ) # sound_mrope_ids: [3,N_sound_tokens] packed_seq.position_ids.append(sound_mrope_ids) @@ -1047,11 +1057,18 @@ def _pack_supertokens_temporal_causal( ``num_action_tokens_per_supertoken=0`` is stamped on the pack and read by the attention builder so NATTEN metadata stays in sync automatically. - mRoPE layout (with actions, unified_3d_mrope only): - - Null actions (frame 0): all tcf tokens at ``temporal_offset``. - - Real training actions (frames 1..T-1): ``start_frame_offset=1`` so the - last action in group i co-locates with vision frame i. - - AR real actions (single supertoken): ``start_frame_offset=0``. + mRoPE layout (with actions, unified_3d_mrope only). The layout is inferred from the + action tensor shape: + - Whole-clip training (frame 0 is the clean conditioning frame, so + ``real_actions`` has ``(T-1)*tcf`` rows): null action for supertoken 0, real + actions for frames 1..T-1 with ``start_frame_offset=1`` so the last action in + group i co-locates with vision frame i; vision uses ``start_frame_offset=0``. + - AR generation, single frame OR chunk (every frame carries a real action, so + ``real_actions`` has ``latent_t*tcf`` rows): vision AND action both use + ``start_frame_offset=1``, generalizing the single-frame AR supertoken to + ``latent_t`` frames. The caller (``pack_input_sequence_autoregressive``) + seeds ``temporal_offset`` one frame-stride back to compensate, so the unit + lands at the same absolute positions as the whole-clip training pack. - Interleaved per frame as cat([action_ids, vision_ids]). ``input_timestep`` is float (TF/none) or Tensor(T_max,) (DF, per-frame sigma). 
@@ -1094,32 +1111,36 @@ def _pack_supertokens_temporal_causal( if pack_action_tokens: # Build all_action_tokens: shape (latent_t * tcf, action_dim) # - # Cases: - # 1. Training with conditioning frame (latent_t > 1, real_actions < latent_t*tcf): - # Prepend tcf null tokens for frame 0, then real actions for frames 1..T-1. - # 2. KV-cache continuation (latent_t > 1, real_actions == latent_t*tcf): all supertokens - # carry real actions (no conditioning frame in-segment). - # 3. AR frame N>0 (latent_t == 1, action provided): real actions, no null prefix. - # 4. AR frame 0 / image2video (action is None): all null tokens. + # Cases (token assembly; mRoPE start_frame_offset is chosen separately below, + # inferred from the same action shape): + # 1. Whole-clip training with conditioning frame (latent_t > 1, real_actions + # has (T-1)*tcf rows): prepend tcf null tokens for frame 0, then real + # actions for frames 1..T-1. + # 2. AR generation (every frame has a real action, real_actions has + # latent_t*tcf rows — single frame OR chunk): no null prefix. + # 3. AR frame 0 / image2video (action is None): all null tokens. if input_action_tokens is not None: - # input_action_tokens shape: (1, T*tcf, D) or (T*tcf, D) for training; (tcf, D) for AR frame N>0 + # input_action_tokens shape: (1, T*tcf, D) or (T*tcf, D) for training; (T*tcf, D) for AR units if input_action_tokens.dim() == 3: real_actions = input_action_tokens.squeeze(0) # [T*tcf,action_dim] or [N,action_dim] else: real_actions = input_action_tokens # [N,action_dim] null_tokens = torch.zeros(tcf, action_dim, device=device, dtype=real_actions.dtype) # [tcf,action_dim] - if latent_t == 1: - # AR frame N>0: single supertoken with real actions, no null prefix - all_action_tokens = real_actions # [tcf,action_dim] - null_action_flag = False - elif real_actions.shape[0] == latent_t * tcf: - # All frames have real actions (e.g. 
KV-cache continuation segments) + if real_actions.shape[0] == latent_t * tcf: + # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf): + # every supertoken carries a real action, no null prefix. all_action_tokens = real_actions null_action_flag = False - else: + elif real_actions.shape[0] == (latent_t - 1) * tcf: # Conditioning frame present: null for supertoken 0, real for 1..T-1 all_action_tokens = torch.cat([null_tokens, real_actions], dim=0) # [T*tcf,action_dim] null_action_flag = True + else: + raise ValueError( + "Temporal-causal action tokens must have either latent_t*tcf rows for AR chunks " + f"or (latent_t-1)*tcf rows for whole-clip training; got {real_actions.shape[0]} rows " + f"for latent_t={latent_t}, tcf={tcf}." + ) else: # AR frame 0 or image2video: all action tokens are null all_action_tokens = torch.zeros( @@ -1171,14 +1192,17 @@ def _pack_supertokens_temporal_causal( temporal_offset = packed_seq._mrope_temporal_offset effective_vision_fps = vision_fps if enable_fps_modulation else None - # AR frame N>=1 with action_gen=True (latent_t==1 and real actions supplied): - # shift both vision and action by start_frame_offset=1 so the last action in - # the group co-locates with vision frame N, mirroring training's layout. - # All other cases (training latent_t>1, AR action_gen=False, AR frame 0 null) - # keep start_frame_offset=0. The caller in pack_input_sequence_autoregressive - # seeds temporal_offset accordingly (N-1 frames back when this shift applies). - ar_with_real_actions = latent_t == 1 and pack_action_tokens and input_action_tokens is not None - vision_sfo = 1 if ar_with_real_actions else 0 + # AR generation (single frame OR chunk) is detected by every frame carrying a + # real action (``real_actions`` has ``latent_t*tcf`` rows). 
There, vision AND + # action both use start_frame_offset=1 so the last action in each group + # co-locates with its vision frame, mirroring whole-clip training; the caller + # (pack_input_sequence_autoregressive) seeds temporal_offset one frame-stride + # back to compensate. Whole-clip training (frame 0 is the null conditioning + # frame, ``real_actions`` has ``(T-1)*tcf`` rows) keeps vision start_frame_offset=0. + all_frames_have_real_action = ( + pack_action_tokens and input_action_tokens is not None and real_actions.shape[0] == latent_t * tcf + ) + vision_sfo = 1 if all_frames_have_real_action else 0 vision_ids_flat, new_offset = get_3d_mrope_ids_vae_tokens( grid_t=latent_t, @@ -1195,10 +1219,10 @@ def _pack_supertokens_temporal_causal( if pack_action_tokens: effective_action_fps = action_fps if enable_fps_modulation else None - # Action IDs: null for frame 0 (all tcf tokens share temporal_offset, - # co-located with vision frame 0), real for frames 1..T-1. - # Real tokens (training and AR) use start_frame_offset=1 so the last - # action in a group co-locates with vision frame i. + # Action IDs. Real action tokens use start_frame_offset=1 so the last + # sub-token of a group co-locates with its vision frame. Whole-clip training + # has a null action at frame 0 (the conditioning frame); AR units have a real + # action for every frame. 
fps_active = effective_action_fps is not None t_dtype = torch.float32 if fps_active else torch.long t_offset = float(temporal_offset) if fps_active else int(temporal_offset) @@ -1221,28 +1245,24 @@ def _real_action_ids(n_frames: int, start_frame_offset: int) -> torch.Tensor: ) return flat.reshape(3, n_frames, tcf) # [3,n_frames,tcf] - if latent_t > 1 and input_action_tokens is not None: - if real_actions.shape[0] == latent_t * tcf: - # KV continuation: real action in every supertoken (including frame 0) - action_ids_3d = _real_action_ids(latent_t, start_frame_offset=0) - else: - # Training with conditioning frame: supertoken 0 = null, 1..T-1 = real - null_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] - real_ids_3d = _real_action_ids(latent_t - 1, start_frame_offset=1) # [3,T-1,tcf] - action_ids_3d = torch.cat([null_ids_3d, real_ids_3d], dim=1) # [3,T,tcf] + if all_frames_have_real_action: + # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf): + # every supertoken carries a real action. start_frame_offset=1 puts + # a_{j-1}'s last sub-token on vision frame j -- the whole-clip TF + # training layout. The caller seeds temporal_offset (N-1) frame-strides + # back to compensate. + action_ids_3d = _real_action_ids(latent_t, start_frame_offset=1) # [3,T,tcf] elif latent_t > 1: - # No action tensor (all-null layout): same ID structure as training w/ conditioning frame. + # Whole-clip training: supertoken 0 = null (conditioning frame), frames + # 1..T-1 = real with start_frame_offset=1. Covers real-action training + # (real_actions has (T-1)*tcf rows) and the architectural all-null layout + # (input_action_tokens is None); the tokens differ but the IDs match. 
null_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] real_ids_3d = _real_action_ids(latent_t - 1, start_frame_offset=1) # [3,T-1,tcf] action_ids_3d = torch.cat([null_ids_3d, real_ids_3d], dim=1) # [3,T,tcf] - elif input_action_tokens is None: - # AR frame 0 / image2video: only null - action_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] else: - # AR frame N>=1: single supertoken with real actions. start_frame_offset=1 - # matches training (last action co-locates with vision frame N); caller - # seeds temporal_offset to (N-1) frame-strides back to compensate. - action_ids_3d = _real_action_ids(1, start_frame_offset=1) # [3,1,tcf] + # AR frame 0 / image2video (latent_t == 1, no action): only null. + action_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] # (3, T*H*W) → (3, T, H*W) vision_ids_3d = vision_ids_flat.reshape(3, latent_t, patches_per_frame) # [3,T,patch_h*patch_w] @@ -1309,7 +1329,9 @@ def pack_input_sequence( unified_3d_mrope_temporal_modality_margin: int = 0, enable_fps_modulation: bool = False, base_fps: float = 24.0, + sound_base_temporal_compression_factor: int | None = None, temporal_compression_factor: int = 4, + vision_temporal_position_mode: str = "latent_index", video_temporal_causal: bool = False, action_dim: int = 32, initial_mrope_temporal_offset: int | float = 0, @@ -1347,8 +1369,13 @@ def pack_input_sequence( Uses the same flag as diffusion_expert_config.enable_fps_modulation. base_fps: Base FPS for normalization (default 24.0). Uses the same value as diffusion_expert_config.base_fps. + sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. + ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. temporal_compression_factor: VAE temporal compression factor (default 4). Obtained from the VAE tokenizer at runtime. + vision_temporal_position_mode: Temporal coordinates used for unified_3d_mrope vision tokens. 
+ "latent_index" keeps legacy positions; "uniae_source_right_edge" uses + per-latent positions from gen_data_clean.temporal_positions_vision. Returns: PackedSequence containing all packed tensors and metadata. See PackedSequence for field details. """ @@ -1361,6 +1388,44 @@ def pack_input_sequence( if isinstance(input_text_indexes, torch.Tensor): raise ValueError("input_text_tokens must be a list, not a tensor") + supported_vision_temporal_position_modes = {"latent_index", "uniae_source_right_edge"} + if vision_temporal_position_mode not in supported_vision_temporal_position_modes: + raise ValueError( + "Unsupported vision_temporal_position_mode: " + f"{vision_temporal_position_mode}. Supported modes: {supported_vision_temporal_position_modes}." + ) + has_any_vision = any(plan.has_vision for plan in sequence_plans) + explicit_vision_temporal_positions_active = vision_temporal_position_mode != "latent_index" and has_any_vision + if explicit_vision_temporal_positions_active: + if position_embedding_type != "unified_3d_mrope": + raise NotImplementedError( + "Explicit vision temporal positions are only supported with position_embedding_type='unified_3d_mrope'." + ) + if gen_data_clean.temporal_positions_vision is None: + raise ValueError( + f"vision_temporal_position_mode={vision_temporal_position_mode} requires " + "gen_data_clean.temporal_positions_vision." + ) + if gen_data_clean.x0_tokens_vision is not None and len(gen_data_clean.temporal_positions_vision) != len( + gen_data_clean.x0_tokens_vision + ): + raise ValueError( + "temporal_positions_vision must have one entry per x0_tokens_vision item, " + f"got {len(gen_data_clean.temporal_positions_vision)} positions for " + f"{len(gen_data_clean.x0_tokens_vision)} vision items." + ) + if video_temporal_causal: + raise NotImplementedError( + "video_temporal_causal=True is not wired for explicit UniAE vision temporal positions yet." 
+ ) + if any(plan.has_action for plan in sequence_plans): + raise NotImplementedError("Action packing is not wired for explicit UniAE vision temporal positions yet.") + if initial_mrope_temporal_offset != 0: + raise NotImplementedError( + "Autoregressive mRoPE temporal offsets are not wired for explicit UniAE vision temporal positions yet." + ) + use_float_mrope_positions = enable_fps_modulation or explicit_vision_temporal_positions_active + # Initialize packed sequence (acts as builder during packing) packed_seq = PackedSequence() @@ -1405,7 +1470,7 @@ def pack_input_sequence( special_tokens, curr_rope_id, has_generation=has_generation_for_sample, - use_float_positions=enable_fps_modulation, + use_float_positions=use_float_mrope_positions, ) sample_len += text_sample_len @@ -1496,6 +1561,7 @@ def pack_input_sequence( shared_latent_t: int | None = None shared_patch_h: int | None = None shared_patch_w: int | None = None + shared_temporal_positions: torch.Tensor | None = None # FPS is recorded per-sample (shape [B]); for multi-item samples # (transfer / image-edit) every vision item in this sample shares # the same conditioning FPS, so we read by sample_idx, not by the @@ -1510,7 +1576,18 @@ def pack_input_sequence( sample_vision_fps = float(gen_data_clean.fps_vision[sample_idx].item()) for item_idx in range(num_vis): - input_vision_tokens = gen_data_clean.x0_tokens_vision[idx_vision] + flat_vision_idx = idx_vision + input_vision_tokens = gen_data_clean.x0_tokens_vision[flat_vision_idx] + vision_temporal_positions: torch.Tensor | None = None + if explicit_vision_temporal_positions_active: + assert gen_data_clean.temporal_positions_vision is not None + vision_temporal_positions = gen_data_clean.temporal_positions_vision[flat_vision_idx] + if vision_temporal_positions.shape[0] != input_vision_tokens.shape[2]: + raise ValueError( + "vision_temporal_positions must match latent_t for each vision item, " + f"got {vision_temporal_positions.shape[0]} positions and " + 
f"latent_t={input_vision_tokens.shape[2]} for item {flat_vision_idx}." + ) vision_fps = sample_vision_fps idx_vision += 1 @@ -1544,6 +1621,19 @@ def pack_input_sequence( f"got item {item_idx} (H,W)=({item_latent_h},{item_latent_w}) " f"vs first=({shared_patch_h},{shared_patch_w})" ) + if vision_temporal_positions is not None: + if shared_temporal_positions is None: + shared_temporal_positions = vision_temporal_positions + else: + comparison_temporal_positions = vision_temporal_positions.to( + device=shared_temporal_positions.device + ) # [T] + assert torch.allclose(comparison_temporal_positions, shared_temporal_positions), ( + "share_vision_temporal_positions requires equal explicit temporal positions " + f"across vision items, got item {item_idx} positions " + f"{vision_temporal_positions.tolist()} vs first " + f"{shared_temporal_positions.tolist()}." + ) # Rewind so this item starts at the same temporal offset as item 0. packed_seq._mrope_temporal_offset = items_temporal_offset_snapshot @@ -1558,6 +1648,7 @@ def pack_input_sequence( enable_fps_modulation=enable_fps_modulation, base_fps=base_fps, temporal_compression_factor=temporal_compression_factor, + vision_temporal_positions=vision_temporal_positions, ) vision_split_len += item_split_len sample_len += vision_split_len @@ -1622,6 +1713,7 @@ def pack_input_sequence( enable_fps_modulation=enable_fps_modulation, base_fps=base_fps, sound_fps=sound_fps, + sound_base_temporal_compression_factor=sound_base_temporal_compression_factor, ) sample_len += sound_split_len else: @@ -1641,8 +1733,8 @@ def pack_input_sequence( # EOV position IDs: 3D mRoPE or 1D RoPE if packed_seq._use_mrope: - # Use float dtype when FPS modulation is enabled for consistency - eov_dtype = torch.float32 if enable_fps_modulation else torch.long + # Use float dtype when any vision mRoPE positions are fractional. 
+ eov_dtype = torch.float32 if use_float_mrope_positions else torch.long eov_mrope_ids = torch.full((3, 1), packed_seq._mrope_temporal_offset, dtype=eov_dtype) # [3,1] packed_seq.position_ids.append(eov_mrope_ids) # type: ignore[arg-type] packed_seq._mrope_temporal_offset += 1 @@ -2095,7 +2187,7 @@ def verify_natten_parameter_list( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Multi-profile: different parameters for 2D (images) and 3D (videos) @@ -2231,7 +2323,7 @@ def generate_natten_metadata( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Invalid: @@ -2363,9 +2455,9 @@ def filter_shape(shape: tuple) -> tuple: is_causal = dim_params["is_causal"] # Create varlen metadata for natten varlen/varsized ops - + # NOTE: generate_multi_dim_varlen_parameters will automatically map window size -1 to # full size, that's why constant window sizes aren't allowed. - + # NOTE: if any of the parameters are constant, natten will simplify them natten_metadata.append( generate_multi_dim_varlen_parameters( token_layout_list=token_layout_list, @@ -2780,7 +2872,6 @@ def build_sequence_plans_from_data_batch( Returns: List of SequencePlan objects, one per sample in the batch. """ - # For new modalities, please generate the sequence_plan in the dataset class!!!! # If sequence_plan already exists in data_batch, return it @@ -2790,7 +2881,6 @@ def build_sequence_plans_from_data_batch( assert "action" not in data_batch or data_batch["action"] is None, "Action data SHOULD have sequence_plans!" assert "sound" not in data_batch or data_batch["sound"] is None, "Sound data SHOULD have sequence_plans!" 
- # Determine batch size from available tensors batch_size = 0 for key in [input_video_key, input_image_key]: diff --git a/cosmos_framework/data/vfm/sound_data_utils.py b/cosmos_framework/data/vfm/sound_data_utils.py index 0a8ec63..2d739b0 100644 --- a/cosmos_framework/data/vfm/sound_data_utils.py +++ b/cosmos_framework/data/vfm/sound_data_utils.py @@ -3,7 +3,8 @@ """Sound data utilities for building sequence plans and handling audio-video generation modes. -This module provides utilities for building SequencePlan objects based on sound generation modes. +This module provides utilities for building SequencePlan objects based on sound generation modes, +similar to how action modes are handled in cosmos_framework/data/vfm/action/data_utils.py. Supported modes: - t2vs: Text → Video + Sound (joint generation) @@ -27,7 +28,8 @@ def build_sequence_plan_for_sound( """Build a SequencePlan based on the sound generation mode. This function determines the appropriate condition frame indexes for vision and sound - based on the specified mode. + based on the specified mode. It mirrors how `build_sequence_plan_from_mode` works + for action in cosmos_framework/data/vfm/action/data_utils.py. Args: mode: Generation mode. One of: diff --git a/cosmos_framework/data/vfm/utils.py b/cosmos_framework/data/vfm/utils.py index ab6f238..8b04f0a 100644 --- a/cosmos_framework/data/vfm/utils.py +++ b/cosmos_framework/data/vfm/utils.py @@ -6,7 +6,6 @@ IMAGE_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). - # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), @@ -41,7 +40,6 @@ VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). 
- # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), @@ -111,10 +109,42 @@ def parse_frame_range_from_wdinfo(wdinfo: str) -> tuple[int, int] | None: return None +def _normalize_skip_frame_ranges( + skip_frame_range: str | list[str] | None, +) -> set[tuple[int, int]]: + """Normalize ``skip_frame_range`` into a set of (min_frames, max_frames) buckets. + + Args: + skip_frame_range: A single bucket string like ``"300_400"``, a list of such + strings, or None. Each string identifies the frame-range bucket + (e.g. ``frames_300_400``) that should be skipped. + + Returns: + Set of (min_frames, max_frames) tuples to skip. Empty if ``skip_frame_range`` is None. + """ + if skip_frame_range is None: + return set() + + if isinstance(skip_frame_range, str): + skip_frame_range = [skip_frame_range] + + skip_buckets: set[tuple[int, int]] = set() + for bucket in skip_frame_range: + match = re.fullmatch(r"(\d+)_(\d+)", bucket.strip()) + if match is None: + raise ValueError( + f"Invalid skip_frame_range entry {bucket!r}. Expected the form '_', e.g. '300_400'." + ) + skip_buckets.add((int(match.group(1)), int(match.group(2)))) + + return skip_buckets + + def filter_wdinfos_by_frame_range( wdinfos: list[str], min_frames: int | None = None, max_frames: int | None = None, + skip_frame_range: str | list[str] | None = None, ) -> list[str]: """ Filter wdinfo files based on frame range. @@ -125,10 +155,16 @@ def filter_wdinfos_by_frame_range( - min_frames is EXCLUSIVE: wdinfo_max must be > min_frames - max_frames is INCLUSIVE: wdinfo_max must be <= max_frames + Additionally, any wdinfo whose frame-range bucket matches an entry in + ``skip_frame_range`` is excluded. + Args: wdinfos: List of wdinfo paths min_frames: Minimum number of frames (exclusive). If None, no lower bound. max_frames: Maximum number of frames (inclusive). If None, no upper bound. 
+ skip_frame_range: Frame-range bucket(s) to exclude, e.g. ``"300_400"`` to + drop the ``frames_300_400`` bucket. Accepts a single string or a list + of strings. If None, no bucket is skipped. Returns: Filtered list of wdinfo paths @@ -144,8 +180,14 @@ def filter_wdinfos_by_frame_range( # frames_400_500 excluded because wdinfo_max (500) <= min_frames (500) # frames_500_600 included because wdinfo_max (600) > min_frames (500) AND <= max_frames (600) # frames_600_700 excluded because wdinfo_max (700) > max_frames (600) + + >>> filter_wdinfos_by_frame_range(wdinfos, skip_frame_range="500_600") + ['wdinfo/frames_400_500/wdinfo.json', 'wdinfo/frames_600_700/wdinfo.json'] + # frames_500_600 excluded because its bucket matches skip_frame_range """ - if min_frames is None and max_frames is None: + skip_buckets = _normalize_skip_frame_ranges(skip_frame_range) + + if min_frames is None and max_frames is None and not skip_buckets: return wdinfos filtered = [] @@ -158,6 +200,10 @@ def filter_wdinfos_by_frame_range( wdinfo_min, wdinfo_max = frame_range + # Skip explicitly excluded buckets (matched on the full (min, max) bucket). + if (wdinfo_min, wdinfo_max) in skip_buckets: + continue + # Filter based on wdinfo's upper bound (wdinfo_max): # - min_frames is exclusive: wdinfo_max must be > min_frames # - max_frames is inclusive: wdinfo_max must be <= max_frames diff --git a/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py new file mode 100644 index 0000000..12c9bcc --- /dev/null +++ b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +""" +Copied from projects/cosmos/reason1/datasets/video_decoder_qwen.py +Changes: +1: remove hardcoded hyper-parameters for Qwen, now read it from processor +2: support skipping smart resize, since it may resize the video frames to be smaller than model input and frames will get resized up later in processor +""" + +import random +import re +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Callable, Optional + +import torch +from PIL import Image +from qwen_vl_utils.vision_process import smart_nframes, smart_resize +from torchcodec.decoders import VideoDecoder +from torchvision import transforms +from torchvision.transforms import InterpolationMode + +from cosmos_framework.utils import log +from cosmos_framework.data.vfm.processors.qwen3vl_processor import Qwen3VLProcessor + +Image.MAX_IMAGE_PIXELS = 933120000 +_VIDEO_EXTENSIONS = "mp4 avi webm mov".split() + +VIDEO_DECODER_OPTIONS = {} + + +def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert token length to pixels based on patch size and temporal patch size. 
+ + Args: + token_length: Token length + patch_size: Patch size + temporal_patch_size: Temporal patch size, + for Qwen it has 3D conv, temporal patch size is 2; for other models like internVL or eagle er, the temporal patch size is 1 since their ViT is an image encoder; + merge_size: Merge size, also called the pixel shuffling factor; + for Qwen and internVL it is 2; for eagle er it is 1; + """ + merged_patch_size = patch_size * merge_size + return token_length * merged_patch_size**2 * temporal_patch_size + + +def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert pixels to token length based on patch size and temporal patch size.""" + merged_patch_size = patch_size * merge_size + return pixels // merged_patch_size**2 // temporal_patch_size + + +def video_decoder_qwen( + num_threads: int = 0, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + random_augmentation: bool = False, + frame_count_random_range: Optional[list[int]] = None, + **kwargs, +) -> Callable: + """ + Sampling video frames similar to Qwen. It prioritizes matching the target FPS first and then resizing the video frames. + See https://github.com/kq-chen/qwen-vl-utils/blob/main/src/qwen_vl_utils/vision_process.py#L118 for more details. 
+ + Args: + key: Video file name/key (passed to the returned decoder, not to this factory) + data: Video binary data (passed to the returned decoder, not to this factory) + min_fps_thres: Minimum FPS threshold + max_fps_thres: Maximum FPS threshold + target_fps: Target FPS + min_video_token_length: Minimum token length + max_video_token_length: Maximum token length + num_threads: Number of threads for the torchcodec video decoder + random_augmentation: Whether to randomize the FPS and max_video_token_length + frame_count_random_range: Random frame count range + + Returns: + Callable: ``_video_decoder_qwen_func`` with the options above pre-bound via ``functools.partial`` + """ + + video_decoder_configured = partial( + _video_decoder_qwen_func, + min_fps_thres=min_fps_thres, + max_fps_thres=max_fps_thres, + num_threads=num_threads, + target_fps=target_fps, + min_video_token_length=min_video_token_length, + max_video_token_length=max_video_token_length, + random_augmentation=random_augmentation, + frame_count_random_range=frame_count_random_range, + ) + + return video_decoder_configured + + +def _video_decoder_qwen_func( + key: str, + data: bytes, + processor: Qwen3VLProcessor, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + num_threads: int = 0, + random_augmentation: bool = False, + fps_random_range: list[float] = [0.5, 1.5], + max_video_token_length_random_range: list[float] = [0.75, 1.25], + frame_count_random_range: Optional[list[int]] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + decoding_timeout: int = 60, + **kwargs, +) -> dict | None: + """Actual video decoder function. + + Args: + key (str): Video file name/key + data (bytes): Video binary data + min_fps_thres (int, optional): Minimum FPS threshold. Defaults to 4. + max_fps_thres (int, optional): Maximum FPS threshold. Defaults to 60. + target_fps (float, optional): Target FPS. Defaults to 2.0. + min_video_token_length (int, optional): Minimum token length. Defaults to 16. 
+ max_video_token_length (int, optional): Maximum token length. Defaults to 8192. + num_threads (int, optional): Number of threads for the torchcodec video decoder. Defaults to 0. + random_augmentation (bool, optional): Whether to randomize the FPS and max_video_token_length. Defaults to False. + fps_random_range (list[float], optional): Range of the random multiplier applied to target_fps. Defaults to [0.5, 1.5]. + max_video_token_length_random_range (list[float], optional): Range of the random multiplier applied to max_video_token_length. Defaults to [0.75, 1.25]. + frame_count_random_range (list[int], optional): Random frame count range. If provided, takes priority over fps_random_range. + start_frame (Optional[int], optional): Start frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + end_frame (Optional[int], optional): End frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + decoding_timeout (int, optional): Timeout in seconds. Defaults to 60. + Raises: + ValueError: Video fps lower than 1, skipping + ValueError: Video fps lower than min_fps_thres, skipping + ValueError: Video fps higher than max_fps_thres, skipping + + Returns: + dict | None: Dictionary with video frames tensor and sampled FPS; None if the extension is unsupported or decoding times out + """ + # Check video extension + extension = re.sub(r".*[.]", "", key) + if extension.lower() not in _VIDEO_EXTENSIONS: + return None + + # Read video with torchcodec + video_reader = VideoDecoder(data, num_ffmpeg_threads=num_threads) + total_frames = video_reader.metadata.num_frames + video_fps = video_reader.metadata.average_fps + + # torchcodec returns ``None`` for containers that don't store frame count + # or average fps (e.g. some MKV/WebM streams). Downstream arithmetic + # (``total_frames - 1``, ``video_fps < 1``, ...) would TypeError on None; + # surface a ValueError so the dataloader's skip path handles it uniformly. 
+ if total_frames is None or video_fps is None: + raise ValueError(f"torchcodec missing metadata (num_frames={total_frames}, average_fps={video_fps}), skipping") + + if start_frame is not None and end_frame is not None: + total_frames = end_frame - start_frame + + if video_fps < 1: + raise ValueError("Video fps lower than 1, skipping") + if video_fps < min_fps_thres: + raise ValueError(f"Video fps {video_fps} lower than {min_fps_thres}, skipping") + if video_fps > max_fps_thres: + raise ValueError(f"Video fps {video_fps} higher than {max_fps_thres}, skipping") + + if random_augmentation: + if frame_count_random_range is not None: + # Random number of frames + min_frames_range, max_frames_range = frame_count_random_range + min_frames_range = min(min_frames_range, total_frames) + max_frames_range = min(max_frames_range, total_frames) + target_frames = random.uniform(min_frames_range, max_frames_range) + target_fps = target_frames / total_frames * video_fps + else: + # randomize fps + target_fps = ( + random.uniform(fps_random_range[0], fps_random_range[1]) * target_fps + if random.random() < 0.5 + else target_fps + ) + # randomize max_video_token_length + max_video_token_length = int( + random.uniform(max_video_token_length_random_range[0], max_video_token_length_random_range[1]) + * max_video_token_length + ) + log.debug(f"random_augmentation: max_video_token_length: {max_video_token_length}, target_fps: {target_fps}") + + patch_size = processor.patch_size + min_height_width = processor.min_height_width + temporal_patch_size = processor.temporal_patch_size + merge_size = processor.merge_size + min_pixels: int = token_to_pixels(min_video_token_length, patch_size, temporal_patch_size, merge_size) + max_pixels: int = token_to_pixels(max_video_token_length, patch_size, temporal_patch_size, merge_size) + max_frames: int = max_pixels // (min_height_width) ** 2 // temporal_patch_size + + # sample based on target fps + nframes = smart_nframes(dict(fps=target_fps), 
total_frames=total_frames, video_fps=video_fps) + nframes = min(nframes, max_frames) + if start_frame is not None and end_frame is not None: + idx = torch.linspace(start_frame, end_frame - 1, nframes).round().long().tolist() # [nframes] + else: + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() # [nframes] + + def _decode_video() -> torch.Tensor: + return video_reader.get_frames_at(indices=idx).data # [T, C, H, W] uint8 + + # Use ThreadPoolExecutor to run video decoding with a timeout. + # If the thread is stuck, abandon it immediately. + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit(_decode_video) + try: + video_frames = future.result(timeout=decoding_timeout) + executor.shutdown(wait=False) + except TimeoutError as e: + log.warning(f"[{key}] Video decoding timed out after {decoding_timeout} seconds") + executor.shutdown(wait=False) + return None + + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + + # recompute max_pixels based on number of sampled frames + nframes, _, height, width = video_frames.shape + max_pixels = max_pixels // nframes + if processor.use_smart_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_size * merge_size, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + log.debug( + f"resized_height: {resized_height}, resized_width: {resized_width} | original height: {height}, original width: {width}" + ) + video_frames = transforms.functional.resize( + video_frames, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() # [T,C,H,W] + video_frames = video_frames.permute(1, 0, 2, 3) # [C,T,H,W] + + return dict(videos=video_frames, fps=sample_fps) diff --git a/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py b/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py new file mode 100644 index 0000000..fd8406f --- /dev/null +++ 
b/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from qwen_vl_utils.vision_process import smart_resize +from transformers.models.auto.processing_auto import AutoProcessor + +from cosmos_framework.utils import log +from cosmos_framework.utils.vlm.pretrained_models_downloader import maybe_download_hf_model_from_s3 + + +def convert_string_content_to_list_content(messages: List[Dict]) -> List[Dict]: + """ + Convert the string content to a list of dicts. + """ + for message_id, message in enumerate(messages): + if isinstance(message["content"], str): + messages[message_id]["content"] = [{"type": "text", "text": message["content"]}] + return messages + + +def maybe_parse_video_content( + messages: List[Dict], +) -> tuple[int, Optional[list[float]], Optional[list[int]], Optional[list[list[int]]]]: + """ + Count videos given as frame lists and collect each one's fps, total frame count, and frame indices. + """ + num_video = 0 + video_fps = [] + video_total_num_frames = [] + video_frames_indices = [] + for message_id, message in enumerate(messages): + if isinstance(message["content"], list): + for sub_content in message["content"]: + if sub_content.get("type", "") == "video" and isinstance(sub_content["video"], list): + num_video += 1 + fps = sub_content.get("fps", None) + if fps is None: + log.critical( + f"fps is None for video {sub_content}. 
Better to set the fps explicitly", rank0_only=False + ) + video_fps.append(fps) + video_total_num_frames.append(len(sub_content["video"])) + video_frames_indices.append(list(range(video_total_num_frames[-1]))) + return num_video, video_fps, video_total_num_frames, video_frames_indices + + +class Nemotron3DenseVLProcessor: + # This is a wrapper around the AutoProcessor class to add some helper functions + def __init__( + self, + name="Qwen/Qwen3-VL-2B-Init", + credentials: str = "./credentials/s3_training.secret", + bucket: str = "bucket4", + cache_dir: str = None, + ): + self.name = name + if os.path.isdir(name): + model_name_or_path_local = name + else: + model_name_or_path_local = maybe_download_hf_model_from_s3( + name, credentials, bucket, include_model_weights=False + ) + + self.processor = AutoProcessor.from_pretrained(model_name_or_path_local, trust_remote_code=True) + log.info("Successfully loaded processor from local cache") + + if hasattr(self.processor, "image_token"): + self.image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token) + else: + self.image_token_id = None + if hasattr(self.processor, "video_token"): + self.video_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.video_token) + else: + self.video_token_id = None + self.eos_id = self.processor.tokenizer.eos_token_id + self.pad_id = self.processor.tokenizer.pad_token_id + self.vision_end_id = self.processor.tokenizer.convert_tokens_to_ids("") + + # Helper attributes for the dataloader video decoding function + self.shortest_edge = self.processor.image_processor.size["shortest_edge"] + self.min_height_width = int(np.sqrt(self.shortest_edge)) + self.patch_size = self.processor.video_processor.patch_size + self.temporal_patch_size = self.processor.video_processor.temporal_patch_size + self.merge_size = self.processor.video_processor.merge_size + self.use_smart_resize = True + if self.pad_id is None: + self.pad_id = self.eos_id + + def 
apply_chat_template( + self, + messages, + add_generation_prompt=False, + return_tensors="pt", + tokenize=True, + **kwargs, + ): + """ + Return: + inputs: dict + input_ids: torch.Tensor, shape: (N_token) + attention_mask: torch.Tensor, shape: (N_token) + texts: str, the raw text + image_sizes: torch.Tensor, shape (N_img, 2) + pixel_values: torch.Tensor, shape (N_img_patch, 3, 224, 224) + """ + + # messages = [msg for msg in messages if msg.get("role") != "system"] + assert tokenize, "tokenize must be True" + assert return_tensors == "pt", "return_tensors must be pt" + # Note: this tokenizer does not support "content": str, it always expect "content" entry to be a list of dicts + messages = convert_string_content_to_list_content(messages) + kwargs = {} + for message_id, message in enumerate(messages): + if isinstance(message["content"], list): + for sub_content in message["content"]: + if sub_content.get("type", "") == "image": + image = sub_content["image"] + max_pixels = sub_content.get("max_pixels", self.processor.image_processor.size["longest_edge"]) + min_pixels = sub_content.get("min_pixels", self.processor.image_processor.size["shortest_edge"]) + assert isinstance(image, Image.Image), ( + "image must be a url string for now, not support list of images for one content" + ) + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=32, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + sub_content["image"] = image + + num_video, video_fps, video_total_num_frames, video_frames_indices = maybe_parse_video_content(messages) + if num_video > 0: + # Here we add the args to avoid the error: + # File "/invalid_dir", line 321, in _decode_and_sample_videos + # raise ValueError( + # ValueError: Sampling frames from a list of images is not supported! Set `do_sample_frames=False`. 
+ video_metadata = [ + dict(fps=fps, total_num_frames=total_num_frames, frames_indices=frames_indices) + for fps, total_num_frames, frames_indices in zip( + video_fps, video_total_num_frames, video_frames_indices + ) + ] + kwargs["videos_kwargs"] = { + "do_sample_frames": False, + "video_metadata": video_metadata[0] if num_video == 1 else video_metadata, + } + + inputs = self.processor.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + return_dict=True, + return_tensors=return_tensors, + # padding="max_length", + # max_length=16000, + # truncation=False, + **kwargs, + ) + + # Convert batch features into single features + # By default, the processor returns a batch of features, but we use processor in dataloader, so we need to convert it to single features + inputs["input_ids"] = inputs["input_ids"][0] # [N_token] + inputs["attention_mask"] = inputs["attention_mask"][0] # [N_token] + num_image_tokens = inputs["input_ids"] == self.image_token_id # [N_token] + num_video_tokens = inputs["input_ids"] == self.video_token_id # [N_token] + return inputs + + def add_assistant_tokens_mask(self, tokens): + """ + Add a mask to the assistant tokens. + This is used to mask out tokens that are not generated by the assistant (e.g., system prompts, user prompts, chat templates), such that in the loss computation, only the tokens generated by the assistant are used. + If there are multiple turns in the conversation, the mask will mask all the assistant tokens in each turn. + + Args: + tokens (Union[List[int], torch.Tensor]): The tokens to add the mask to. + Returns: + Union[List[bool], torch.Tensor]: The mask. True for tokens generated by the assistant (i.e. should apply loss on), False for tokens not generated by the assistant. 
+ """ + if isinstance(tokens, torch.Tensor) and tokens.ndim == 2: + mask = torch.stack( + [self.add_assistant_tokens_mask(tokens[i]) for i in range(tokens.shape[0])] + ) # [B,N_token] + assert mask.shape == tokens.shape + return mask + np_tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.array(tokens) + assert np_tokens.ndim == 1 + + # Constants defining bos, eos and fixed offsets. + BOS_TOKEN = "<|im_start|>" + EOS_TOKEN = "<|im_end|>" + ROLE = "assistant" + # Offsets: skip the bos + "assistant\n" (always 3 tokens) and include the eos (+1) for supervision + START_OFFSET = 3 + END_OFFSET = 1 + + # Retrieve token IDs for the markers and the role. + bos_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOS_TOKEN) + eos_token_id = self.processor.tokenizer.convert_tokens_to_ids(EOS_TOKEN) + role_id = self.processor.tokenizer.convert_tokens_to_ids(ROLE) + role_ids = self.processor.tokenizer.encode( + ROLE, add_special_tokens=False + ) # In case the role_id corresponds to multiple tokens, decode it back to string for accurate comparison + think_start_id = self.processor.tokenizer.convert_tokens_to_ids("") + think_end_id = self.processor.tokenizer.convert_tokens_to_ids("") + + # Locate all positions where the start and end markers appear. + start_indices = np.where(np_tokens == bos_token_id)[0] + end_indices = np.where(np_tokens == eos_token_id)[0] + + # Initialize the mask with False values. + masks = np.zeros_like(np_tokens, dtype=bool) + assert len(start_indices) == len(end_indices) + # For each pair of bos/eos, check if the role is 'assistant' + # and apply the mask accordingly. 
+ for start, end in zip(start_indices, end_indices): + end_pos = None + if np_tokens[start + 1] == role_id: + # Mask tokens from after the assistant header (start+3) to include the end marker (end+1) + masks[start + START_OFFSET : end + END_OFFSET] = True + end_pos = start + START_OFFSET + elif all(np_tokens[start + 1 : start + 1 + len(role_ids)] == role_ids): + masks[start + START_OFFSET + len(role_ids) - 1 : end + END_OFFSET] = True + end_pos = start + START_OFFSET + len(role_ids) - 1 + if end_pos is not None and np_tokens[end_pos] == think_start_id: + masks[end_pos] = False + if np_tokens[end_pos + 1] == think_end_id: + masks[end_pos + 1] = False + + assert masks.shape == np_tokens.shape + if isinstance(tokens, torch.Tensor): + return torch.from_numpy(masks) + else: + return masks.tolist() + + def encode(self, *args, **kwargs): + return self.processor.encode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.processor.decode(*args, **kwargs) diff --git a/cosmos_framework/data/vlm/processors/nemotronvl_processor.py b/cosmos_framework/data/vlm/processors/nemotronvl_processor.py new file mode 100644 index 0000000..1fde099 --- /dev/null +++ b/cosmos_framework/data/vlm/processors/nemotronvl_processor.py @@ -0,0 +1,553 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from transformers.models.auto.processing_auto import AutoProcessor +from transformers.processing_utils import VideosKwargs +from transformers.video_utils import VideoMetadata + +from cosmos_framework.utils import log +from cosmos_framework.utils.vlm.pretrained_models_downloader import maybe_download_hf_model_from_s3 + +nemotron_chat_template = """ +{%- set ns = namespace(enable_thinking=false, has_sys_prompt=false, non_tool_system_content='', has_video=false, explicit_think_requested=false) -%} +{%- set msg = namespace(content='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.has_sys_prompt = true -%} + {# Extract system content without tool flags #} + {%- if message['content'] is string -%} + {%- set ns.non_tool_system_content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set ns.non_tool_system_content = '' -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content + content['text'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '') -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content.strip() -%} + {%- endif -%} + {%- endif -%} + {# Check for video content in all messages #} + {%- if message['content'] is not string -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'video' or content['type'] == 'video_url' -%} + {%- set ns.has_video = true -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- if message['content'] is string -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in 
message['content'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in message['content'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- else -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in content['text'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in content['text'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endfor -%} + +{{- bos_token -}} +{%- if messages[0]['role'] != 'system' -%} + {{- 'System\n' -}} +{%- else -%} + {{- 'System\n' + ns.non_tool_system_content }} +{%- endif -%} + +{%- if tools -%} + {%- if ns.non_tool_system_content != '' -%} + {{- '\n\n' -}} + {%- endif -%} + {{- 'You can use the following tools to assist the user if required:\n' -}} + {{- '[' -}} + {%- for tool in tools -%} + {{- (tool.function if tool.function is defined else tool) | tojson -}} + {{- ', ' if not loop.last else '' -}} + {%- endfor -%} + {{- ']\n\n' -}} + + {{- 'If you decide to call any tool(s), use the following format:\n' -}} + {{- '[{"name": "tool_name1", "arguments": "tool_args1"}, ' -}} + {{- '{"name": "tool_name2", "arguments": "tool_args2"}]\n\n' -}} + + {{- 'The user will execute tool-calls and return responses from tool(s) in this format:\n' -}} + {{- '[{"response": "tool_response1"}, ' -}} + {{- '{"response": "tool_response2"}]\n\n' -}} + + {{- 'Based on the tool responses, you can call additional tools if needed, ' -}} + {{- 'correct tool calls if any errors are found, or just respond to the user.' 
-}} +{%- endif -%} +{{- '\n' -}} + +{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%} + +{# Prevent no user or assistant message #} +{%- if messages|length == 0 -%} + {%- set messages = [{'role': 'user', 'content': ''}] -%} +{%- endif -%} + +{%- for message in messages %} + {%- if message['content'] is string -%} + {%- set msg.content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set msg.content = '' -%} + {%- set mm_content = '' -%} + {%- set counters = namespace(images=0, videos=0) -%} + + {%- for content in message['content'] -%} + {%- if content['type'] == 'image' -%} + {%- set counters.images = counters.images + 1 -%} + {%- elif content['type'] == 'video' -%} + {%- set counters.videos = counters.videos + 1 -%} + {%- elif content['type'] == 'text' -%} + {%- set msg.content = msg.content + content['text'] -%} + {%- endif -%} + {%- endfor -%} + {%- if '' in msg.content -%} + {%- set counters.images = 0 -%} + {%- endif -%} + {%- if '