diff --git a/cosmos_framework/__init__.py b/cosmos_framework/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/__init__.py +++ b/cosmos_framework/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/auxiliary/guardrail/common/presets.py b/cosmos_framework/auxiliary/guardrail/common/presets.py index d320b5e..8536685 100644 --- a/cosmos_framework/auxiliary/guardrail/common/presets.py +++ b/cosmos_framework/auxiliary/guardrail/common/presets.py @@ -7,9 +7,6 @@ from cosmos_framework.auxiliary.guardrail.common.core import GuardrailRunner from cosmos_framework.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter from cosmos_framework.auxiliary.guardrail.qwen3guard.qwen3guard import Qwen3Guard -from cosmos_framework.auxiliary.guardrail.video_content_safety_filter.video_content_safety_filter import ( - VideoContentSafetyFilter, -) from cosmos_framework.utils import log @@ -27,7 +24,8 @@ def create_video_guardrail_runner(offload_model_to_cpu: bool = False) -> Guardra """Create the video guardrail runner.""" return GuardrailRunner( safety_models=[ - # VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu), # Too many false positives + # VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu) + # Too many false positives, add back when fixed ], postprocessors=[RetinaFaceFilter(offload_model_to_cpu=offload_model_to_cpu)], ) diff --git a/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py b/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py index cffebc2..805ecd5 100644 --- a/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py +++ b/cosmos_framework/auxiliary/guardrail/face_blur_filter/retinaface_utils.py @@ -1,4 +1,3 @@ -# Copyright (c) 2019 # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py b/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py +++ b/cosmos_framework/auxiliary/guardrail/qwen3guard/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/callbacks/compile_tokenizer.py b/cosmos_framework/callbacks/compile_tokenizer.py index 84efa16..3ee15d2 100644 --- a/cosmos_framework/callbacks/compile_tokenizer.py +++ b/cosmos_framework/callbacks/compile_tokenizer.py @@ -4,7 +4,7 @@ """Training callback that defers AOT compilation of the VAE tokenizer. The actual compilation logic lives in -:meth:`~projects.cosmos3.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. +:meth:`~cosmos_framework.model.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. This module provides a :class:`CompileTokenizer` callback that invokes it at the right point during training (after ``compile_after_iterations`` steps, to avoid NCCL timeouts during CUDA/cuDNN warm-up). @@ -21,6 +21,7 @@ """ from collections.abc import Sequence +from typing import Literal import torch @@ -43,6 +44,10 @@ def __init__( enabled: bool = False, compile_after_iterations: int = 3, warmup_resolutions: Sequence[str] | None = None, + backend: Literal["cudagraphs", "inductor"] = "inductor", + mode: Literal["reduce-overhead", "max-autotune"] | None = "reduce-overhead", + fullgraph: bool = False, + dynamic: bool = False, ): """ Args: @@ -60,6 +65,10 @@ def __init__( self.compile_after_iterations: int = compile_after_iterations self.skip_counter: int = 0 self.warmup_resolutions: Sequence[str] | None = warmup_resolutions + self.backend: Literal["cudagraphs", "inductor"] = backend + self.mode: Literal["reduce-overhead", "max-autotune"] | None = mode + self.fullgraph: bool = fullgraph + self.dynamic: bool = dynamic if self.enabled: if self.warmup_resolutions is None: @@ -101,6 +110,10 @@ def on_training_step_start( tokenizer.compile_encode( self.warmup_resolutions, output_dir=self.config.job.path_local, + backend=self.backend, + mode=self.mode, + fullgraph=self.fullgraph, + dynamic=self.dynamic, ) self.skip_counter += 1 diff --git a/cosmos_framework/callbacks/data_stats.py b/cosmos_framework/callbacks/data_stats.py index 20e3161..2914981 100644 --- a/cosmos_framework/callbacks/data_stats.py +++ b/cosmos_framework/callbacks/data_stats.py @@ -51,7 +51,6 @@ def on_training_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/callbacks/dataloader_state.py b/cosmos_framework/callbacks/dataloader_state.py index ec20eea..bee9e43 100644 --- a/cosmos_framework/callbacks/dataloader_state.py +++ b/cosmos_framework/callbacks/dataloader_state.py @@ -26,18 +26,14 @@ class DataLoaderStateCallback(Callback): def __init__( self, distributor_type: str | None = None, - name: str = "", ) -> None: super().__init__() self.distributor_type = distributor_type - self.name = name self.config: Any = None self.state: dict[int, NoReplaceShardlistState] = {} self.verbose = True def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: - if "sample_worker_id" not in data_batch: - return # batch has no position metadata (shuffle=False or iterable data_source) worker_ids = data_batch["sample_worker_id"].tolist() # [B] epochs = data_batch["sample_epoch"].tolist() # [B] indices = data_batch["sample_index"].tolist() # [B] @@ -50,8 +46,6 @@ def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: ): self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - _ACTIVE_DISTRIBUTOR_TYPES = ("no_replace",) - def on_training_step_batch_end( self, model: ImaginaireModel, @@ -60,7 +54,7 @@ def on_training_step_batch_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": self._update_state_from_batch(data_batch) def on_training_step_end( @@ -71,7 +65,7 @@ def on_training_step_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": if self.verbose: if iteration % self.config.trainer.logging_iter == 0: msg = "\n" @@ -80,10 +74,10 @@ def on_training_step_end( log.info(msg) def has_checkpoint_state(self) -> bool: - return self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES + return self.distributor_type == "no_replace" def state_dict(self) -> dict[int, dict[str, int]]: - if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return {} state_dict: dict[int, dict[str, int]] = {} @@ -96,7 +90,7 @@ def state_dict(self) -> dict[int, dict[str, int]]: return state_dict def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: - if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return if not state_dict: @@ -110,4 +104,4 @@ def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") + log.info(f"Loaded no replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") diff --git a/cosmos_framework/callbacks/every_n_draw_sample.py b/cosmos_framework/callbacks/every_n_draw_sample.py index baf1ffc..9aa96fa 100644 --- a/cosmos_framework/callbacks/every_n_draw_sample.py +++ b/cosmos_framework/callbacks/every_n_draw_sample.py @@ -154,8 +154,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): tag = "ema" if self.is_ema else "reg" log.debug("starting data and condition model", rank0_only=False) - - data_clean = model.get_data_and_condition(data_batch) raw_data = data_clean.raw_state_vision x0 = data_clean.x0_tokens_vision @@ -185,7 +183,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): log.debug(f"done denoising {sigma}", rank0_only=False) mse_loss = distributed.dist_reduce_tensor(F.mse_loss(sample, x0)) mse_loss_list.append(mse_loss) - if hasattr(model, "decode"): sample = model.decode(sample) to_show.append(sample.float().cpu()) @@ -316,7 +313,6 @@ def sample(self, trainer, model, data_batch, output_batch, loss, iteration): for sample_idx in range(data_clean.batch_size): n_vis = num_items[sample_idx] # First item(s) are condition, last item is generation target - # but we need to support multiple conditions per sample in the future. Current code # can handle this without throwing an error. condition_images.append(raw_data[vis_offset]) # source image (1, C, 1, H, W) diff --git a/cosmos_framework/callbacks/grad_clip.py b/cosmos_framework/callbacks/grad_clip.py index 151c1bc..f3cb4fa 100644 --- a/cosmos_framework/callbacks/grad_clip.py +++ b/cosmos_framework/callbacks/grad_clip.py @@ -132,7 +132,7 @@ def _clip_grad( # `torch.distributed._tensor.ops.math_ops._NormPartial`. # We can simply reduce the DTensor to get the total norm in this # tensor's process group and then convert it to a local tensor. - + # NOTE: It has two purposes: # 1. to make sure the total norm is computed correctly when PP is used (see below) # 2. to return a reduced mesh_norm tensor whose .item() would return the correct value if isinstance(mesh_norm, DTensor): diff --git a/cosmos_framework/callbacks/hf_export.py b/cosmos_framework/callbacks/hf_export.py index 8bcba5a..6d23568 100644 --- a/cosmos_framework/callbacks/hf_export.py +++ b/cosmos_framework/callbacks/hf_export.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 + """HFExportCallback: export VLM DCP checkpoints to HuggingFace safetensors format. Design notes @@ -137,11 +138,11 @@ def on_save_checkpoint(self, model: Any, state_dict: dict[str, Any]) -> None: if not isinstance(model, VLMModel): # The legacy vlm/train.py path passes model_parts: list[nn.Module] (raw HF # models without the VLMModel attribute structure). HF export requires the - # VLMModel wrapper, which is only available via the unified cosmos_framework/scripts/train.py path. + # VLMModel wrapper, which is only available via the unified scripts/train.py path. if isinstance(model, list): log.warning( "[HFExportCallback] Received model_parts (list) instead of VLMModel. " - "HF export requires the unified training path (cosmos_framework/scripts/train.py). Skipping." + "HF export requires the unified training path (scripts/train.py). Skipping." ) else: log.warning( diff --git a/cosmos_framework/callbacks/mfu.py b/cosmos_framework/callbacks/mfu.py index 3035437..4a6c792 100644 --- a/cosmos_framework/callbacks/mfu.py +++ b/cosmos_framework/callbacks/mfu.py @@ -138,7 +138,6 @@ def _ensure_initialised(self, model: ImaginaireModel) -> None: ac_cfg = getattr(model_cfg, "activation_checkpointing", None) ac_mode = getattr(ac_cfg, "mode", "none") - # Some activations don't need to be recomputed under selective AC, so # we need to remove them from the FLOP computation. self._use_activation_checkpointing = ac_mode != "none" diff --git a/cosmos_framework/callbacks/wandb_log_eval.py b/cosmos_framework/callbacks/wandb_log_eval.py index ac6911f..abea93f 100644 --- a/cosmos_framework/callbacks/wandb_log_eval.py +++ b/cosmos_framework/callbacks/wandb_log_eval.py @@ -88,7 +88,6 @@ def on_validation_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/checkpoint/dcp.py b/cosmos_framework/checkpoint/dcp.py index 4e036c9..7318e4a 100644 --- a/cosmos_framework/checkpoint/dcp.py +++ b/cosmos_framework/checkpoint/dcp.py @@ -63,6 +63,7 @@ set_model_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.nn.modules.module import _IncompatibleKeys from cosmos_framework.checkpoint.base import AbstractCheckpointer from cosmos_framework.checkpoint.s3_filesystem import S3StorageReader, S3StorageWriter @@ -85,11 +86,11 @@ def __init__(self, model: nn.Module) -> None: def state_dict(self) -> dict[str, Any]: return get_model_state_dict(self.model) - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - set_model_state_dict( + def load_state_dict(self, state_dict: dict[str, Any]) -> _IncompatibleKeys: + return set_model_state_dict( self.model, model_state_dict=state_dict, - options=StateDictOptions(strict=True), + options=StateDictOptions(strict=False), ) @@ -539,28 +540,13 @@ def load( "Ensure the model has net_ema submodule." ) _state_dict[sd_key] = _state_dict[key_ema] - elif warm_start and any(str(s).startswith("net_ema") for s in self.keys_to_skip_loading): - # Only when net_ema.* is explicitly skipped on load (e.g. an HF->DCP - # init from convert_model_to_dcp that has only net.*): the skipped - # net_ema.* keep build_net() construction values (random init when - # vlm_config.pretrained_weights.enabled=False), which would seed EMA - # from random weights -> copy net.* -> net_ema.* so EMA starts from the - # freshly-loaded init. When net_ema.* IS loaded (e.g. a training DCP - # that carries a trained EMA), do NOT clobber it. - log.info("Warm start: net_ema. skipped on load -> resetting net_ema = net.") - for sd_key in list(_state_dict.keys()): - if sd_key.startswith("net."): - key_ema = "net_ema." + sd_key.removeprefix("net.") - if key_ema in _state_dict: - _state_dict[key_ema] = _state_dict[sd_key] results = _model_wrapper.load_state_dict(_state_dict) - if results is not None: - if len(results.missing_keys) > 0: - raise ValueError(f"Missing keys (not found in checkpoint): {results.missing_keys}") - if len(results.unexpected_keys) > 0: - raise ValueError( - f"Unexpected keys (found in checkpoint but not in model): {results.unexpected_keys}" - ) + if len(results.missing_keys) > 0: + raise ValueError(f"Missing keys (not found in checkpoint): {results.missing_keys}") + if len(results.unexpected_keys) > 0: + raise ValueError( + f"Unexpected keys (found in checkpoint but not in model): {results.unexpected_keys}" + ) elif key == "optim": log.info("- Loading the optimizer...") diff --git a/cosmos_framework/checkpoint/s3_filesystem.py b/cosmos_framework/checkpoint/s3_filesystem.py index e47219e..029570e 100644 --- a/cosmos_framework/checkpoint/s3_filesystem.py +++ b/cosmos_framework/checkpoint/s3_filesystem.py @@ -3,32 +3,89 @@ import io import os +import threading import time from contextlib import contextmanager from typing import Generator, Union from urllib.parse import urlparse +import boto3 +from botocore.config import Config as S3Config from botocore.exceptions import ClientError from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter from torch.distributed.checkpoint.filesystem import FileSystemBase from cosmos_framework.utils import log from cosmos_framework.utils.easy_io import easy_io +from cosmos_framework.utils.easy_io.backends import auto_auth -class S3Stream(io.BytesIO): +class _CancellableReader: + """Pipe-reader wrapper whose ``read`` raises once a cancel event is set. + + Lets us abort an in-flight ``client.upload_fileobj`` on producer error: a + read exception makes boto3 abort the multipart upload, whereas just + closing the pipe writer would signal EOF and finalize a truncated file. """ - Workaround for PyTorch manually closing the stream before we can upload it to S3. We override the close() as noop - and instead call our own _true_close() method to close the stream after we are done using it. - The commit at fault is https://github.com/pytorch/pytorch/commit/9c909bf3bb122db2cce95e2eb7459bbe50dfa15a + + def __init__(self, f, cancel_event: threading.Event) -> None: + self._f = f + self._cancel = cancel_event + + def read(self, n: int = -1) -> bytes: + if self._cancel.is_set(): + raise IOError("S3 upload cancelled by caller") + return self._f.read(n) + + def readable(self) -> bool: + return True + + def close(self) -> None: + self._f.close() + + +class _CountingPipeWriter(io.RawIOBase): + """Write-only pipe wrapper that fakes ``tell()`` by counting bytes written. + + DCP calls ``stream.tell()`` to record per-tensor byte offsets in the + checkpoint metadata, but kernel pipes aren't seekable. We maintain the + byte count ourselves; nothing actually seeks. """ - def close(self): - self.flush() - # No close + def __init__(self, write_file) -> None: + super().__init__() + self._f = write_file + self._pos = 0 + + def write(self, b) -> int: + n = self._f.write(b) + if n is None: + raise OSError("_CountingPipeWriter: underlying pipe write returned None; expected a blocking write.") + self._pos += n + return n + + def writable(self) -> bool: + return True + + def seekable(self) -> bool: + return False # pipes can't seek; consumers (zipfile, etc.) check this + + def tell(self) -> int: + return self._pos - def _true_close(self): - super().close() + def fileno(self) -> int: + return self._f.fileno() + + def flush(self) -> None: + self._f.flush() + + def close(self) -> None: + if self.closed: + return + try: + super().close() # invokes self.flush(), then sets self.closed = True + finally: + self._f.close() class S3FileSystem(FileSystemBase): @@ -69,6 +126,33 @@ def __init__( if enable_gcs_patch_in_boto3: log.info("enable_gcs_patch_in_boto3: True") + # Direct boto3 client for streaming-multipart uploads (``upload_fileobj`` + # via boto3's TransferManager). We can't reuse ``self.easy_io_backend``'s + # client: easy_io abstracts the transport (could be ``Boto3Backend`` or + # ``MSCBackend``) and intentionally doesn't expose a raw boto3 client. + # Built lazily so read-only callers don't pay for it. + self._credential_path = credential_path + self._boto3_client = None + + def _get_boto3_client(self): + """Lazily build a boto3 S3 client configured for our endpoint. + + Config mirrors cosmos_framework/utils/easy_io/backends/boto3_client.py:289 to + preserve GCS-via-S3 signature/checksum compatibility. + """ + if self._boto3_client is None: + with auto_auth.open_auth(self._credential_path, "r") as f: + cred_info = auto_auth.json_load_auth(f) + cfg = S3Config( + signature_version="s3v4", + s3={"addressing_style": "virtual"}, + response_checksum_validation="when_required", + request_checksum_calculation="when_required", + retries={"max_attempts": 5, "mode": "adaptive"}, + ) + self._boto3_client = boto3.client("s3", **cred_info, config=cfg) + return self._boto3_client + def _retry_with_backoff(self, operation_func, *args, **kwargs): """ Execute an operation with exponential backoff retry logic. @@ -135,24 +219,61 @@ def download_operation(): log.info(f"S3 Filesystem: Downloading {key} from bucket {bucket}", rank0_only=False) self._retry_with_backoff(download_operation) - log.info("S3 Filesystem: Download complete", rank0_only=False) + log.info(f"S3 Filesystem: Download complete for {key} in bucket {bucket}", rank0_only=False) yield stream finally: stream.close() elif mode == "wb": - stream = S3Stream() + # Streaming multipart upload: yield the writer end of a pipe to DCP + # and drain the reader end via ``client.upload_fileobj`` in a + # background thread. Peak memory is bounded by boto3's TransferConfig + # (~80 MiB) regardless of file size; the pipe (~64 KiB) provides + # backpressure. See ``_CancellableReader`` for how producer-side + # errors abort the multipart upload. + client = self._get_boto3_client() + r_fd, w_fd = os.pipe() + read_file = os.fdopen(r_fd, "rb") + write_file = os.fdopen(w_fd, "wb") + counting_writer = _CountingPipeWriter(write_file) + upload_err: list = [None] + cancel_event = threading.Event() + + def _upload_thread(): + try: + client.upload_fileobj( + _CancellableReader(read_file, cancel_event), + Bucket=bucket, + Key=key, + ) + except Exception as e: # noqa: BLE001 — capture and re-raise on main thread + upload_err[0] = e + finally: + try: + read_file.close() + except Exception: + pass + + log.info(f"S3 Filesystem: Streaming upload {key} to bucket {bucket}", rank0_only=False) + uploader = threading.Thread(target=_upload_thread, daemon=True, name=f"s3-upload-{key[-32:]}") + uploader.start() + + caller_raised = False try: - yield stream - - def upload_operation(): - stream.seek(0) - self.easy_io_backend.put(obj=stream, filepath=path_str) - - log.info(f"S3 Filesystem: Uploading {key} to bucket {bucket}", rank0_only=False) - self._retry_with_backoff(upload_operation) - log.info("S3 Filesystem: Upload complete", rank0_only=False) + yield counting_writer + except Exception: + caller_raised = True + cancel_event.set() + raise finally: - stream._true_close() + try: + counting_writer.close() # closes the pipe write end → EOF for the reader + except Exception: + pass + uploader.join() + if upload_err[0] is not None and not caller_raised: + # Upload thread failed; surface that to the caller. + raise upload_err[0] + log.info(f"S3 Filesystem: Upload complete for {key}", rank0_only=False) else: raise ValueError(f"Unsupported mode: {mode}") @@ -285,7 +406,7 @@ def __init__( """ super().__init__( path=path, - sync_files=False, + sync_files=False, # FIXME: setting this to True makes the run to fail (L#333: `os.fsync(stream.fileno())`) **kwargs, ) self.fs = S3FileSystem(credential_path, enable_gcs_patch_in_boto3=enable_gcs_patch_in_boto3) # type: ignore diff --git a/cosmos_framework/configs/base/__init__.py b/cosmos_framework/configs/base/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/__init__.py +++ b/cosmos_framework/configs/base/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/base_config_test.py b/cosmos_framework/configs/base/base_config_test.py deleted file mode 100644 index 3eb2b0d..0000000 --- a/cosmos_framework/configs/base/base_config_test.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -This file is used to test the config of the cosmos3 vfm project. -It is used to verify the config is loadable. - -To run the test, you can use the following command: -pytest -s cosmos_framework/configs/base/base_config_test.py -""" - -import importlib -from unittest.mock import MagicMock, patch - -import pytest - -from cosmos_framework.utils.config_helper import get_config_module, override - - -@pytest.mark.L0 -@pytest.mark.parametrize( - "experiment_name", - [ - "vision_sft_nano", - ], -) -def test_config_init_experiment_mot(experiment_name, monkeypatch): - """ - Parameterized test to verify config initialization for multiple experiments. - PYTHONPATH=. torchrun --nproc_per_node=8 -m pytest -s cosmos_framework/configs/base/config_test_mot.py --L1 - """ - # The SFT experiments interpolate the dataset location from ${oc.env:DATASET_PATH}; - # config construction only needs the variable defined, not a real dataset on disk. - monkeypatch.setenv("DATASET_PATH", "/tmp/dataset") - config_file = "cosmos_framework/configs/base/config.py" - config_module = get_config_module(config_file) - config = importlib.import_module(config_module).make_config() - config = override( - config, - [ - "--", - f"experiment={experiment_name}", - ], - ) - - -def _make_self_mock(*, pretrained_enabled: bool, load_weights_from_pretrained: bool) -> MagicMock: - """Mock the OmniMoTModel attributes that load_pretrained_model_if_needed reads.""" - self_mock = MagicMock() - self_mock.vlm_config.pretrained_weights.enabled = pretrained_enabled - self_mock.config.diffusion_expert_config.load_weights_from_pretrained = load_weights_from_pretrained - self_mock.config.ema.enabled = False - return self_mock - - -@pytest.mark.L0 -class TestLoadPretrainedGate: - """Decision matrix for ``OmniMoTModel.load_pretrained_model_if_needed``. - - Replaces the previous ``OmniMoTModelConfig.validate`` tests now that - LoadPretrained callback probes ``latest_checkpoint.txt`` / ``load_path`` at - ``on_train_start`` and forwards the two booleans, instead of mutating the - config during validation. - """ - - _LOADER_TARGET = "cosmos_framework.model.vfm.omni_mot_model.load_language_model_safetensors" - - def _call(self, self_mock: MagicMock, *, has_resumable_checkpoint: bool, has_load_path: bool) -> MagicMock: - from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel - - with patch(self._LOADER_TARGET) as loader: - OmniMoTModel.load_pretrained_model_if_needed( - self_mock, - has_resumable_checkpoint=has_resumable_checkpoint, - has_load_path=has_load_path, - ) - return loader - - def test_fresh_init_loads_and_copies(self): - """No checkpoint, no load_path: HF load AND understanding→generation copy.""" - self_mock = _make_self_mock(pretrained_enabled=True, load_weights_from_pretrained=True) - loader = self._call(self_mock, has_resumable_checkpoint=False, has_load_path=False) - loader.assert_called_once() - self_mock.net.language_model.init_moe.assert_called_once() - - def test_resume_skips_everything(self): - """Resumable checkpoint exists: neither HF load nor copy.""" - self_mock = _make_self_mock(pretrained_enabled=True, load_weights_from_pretrained=True) - loader = self._call(self_mock, has_resumable_checkpoint=True, has_load_path=False) - loader.assert_not_called() - self_mock.net.language_model.init_moe.assert_not_called() - - def test_warm_start_loads_but_skips_copy(self): - """load_path set, no checkpoint: HF load but skip understanding→generation copy.""" - self_mock = _make_self_mock(pretrained_enabled=True, load_weights_from_pretrained=True) - loader = self._call(self_mock, has_resumable_checkpoint=False, has_load_path=True) - loader.assert_called_once() - self_mock.net.language_model.init_moe.assert_not_called() - - def test_pretrained_disabled_short_circuits(self): - """pretrained_weights.enabled=False: early return regardless of other flags.""" - self_mock = _make_self_mock(pretrained_enabled=False, load_weights_from_pretrained=True) - loader = self._call(self_mock, has_resumable_checkpoint=False, has_load_path=False) - loader.assert_not_called() - self_mock.net.language_model.init_moe.assert_not_called() diff --git a/cosmos_framework/configs/base/config.py b/cosmos_framework/configs/base/config.py index d1f975d..e766c5c 100644 --- a/cosmos_framework/configs/base/config.py +++ b/cosmos_framework/configs/base/config.py @@ -38,7 +38,6 @@ class Config(config.Config): {"ema": "power"}, {"tokenizer": "wan2pt2_tokenizer"}, {"sound_tokenizer": None}, # Optional: for audio-video generation - {"cluster": "default"}, {"vlm_config": None}, {"ckpt_type": "dcp"}, {"experiment": None}, @@ -72,7 +71,6 @@ def make_config() -> Config: from cosmos_framework.configs.base.defaults.callbacks import register_callbacks from cosmos_framework.configs.base.defaults.checkpointer import register_checkpoint, register_ckpt_type - from cosmos_framework.configs.base.defaults.cluster import register_cluster from cosmos_framework.configs.base.defaults.ema import register_ema # from cosmos_framework.configs.base.defaults.data import register_data @@ -92,7 +90,6 @@ def make_config() -> Config: register_tokenizer() register_sound_tokenizer() register_ema() - register_cluster() register_vlm() # Register shipped experiments explicitly. (vision_sft_nano also defines diff --git a/cosmos_framework/configs/base/defaults/__init__.py b/cosmos_framework/configs/base/defaults/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/defaults/__init__.py +++ b/cosmos_framework/configs/base/defaults/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/defaults/callbacks.py b/cosmos_framework/configs/base/defaults/callbacks.py index 46c85ed..805ffb5 100644 --- a/cosmos_framework/configs/base/defaults/callbacks.py +++ b/cosmos_framework/configs/base/defaults/callbacks.py @@ -10,7 +10,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer - from cosmos_framework.callbacks.device_monitor import DeviceMonitor from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap diff --git a/cosmos_framework/configs/base/defaults/cluster.py b/cosmos_framework/configs/base/defaults/cluster.py deleted file mode 100644 index 23b49dd..0000000 --- a/cosmos_framework/configs/base/defaults/cluster.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -import attrs -from hydra.core.config_store import ConfigStore - - -@attrs.define(slots=False) -class ClusterConfig: - """ - Config for the cluster specific information. - Everything cluster specific should be here. - """ - - object_store_bucket_data: str - object_store_bucket_checkpoint: str - object_store_bucket_pretrained: str - - object_store_credential_data: str - object_store_credential_checkpoint: str - object_store_credential_pretrained: str - - -DefaultClusterConfig: ClusterConfig = ClusterConfig( - object_store_bucket_data="", - object_store_bucket_checkpoint="bucket-checkpoint", - object_store_bucket_pretrained="bucket-pretrained", - object_store_credential_data="credentials/data.secret", - object_store_credential_checkpoint="credentials/checkpoint.secret", - object_store_credential_pretrained="credentials/pretrained.secret", -) - - -def register_cluster(): - cs = ConfigStore.instance() - cs.store(group="cluster", package="job.cluster", name="default", node=DefaultClusterConfig) diff --git a/cosmos_framework/configs/base/defaults/compile.py b/cosmos_framework/configs/base/defaults/compile.py index b0e1c88..3d5ebf7 100644 --- a/cosmos_framework/configs/base/defaults/compile.py +++ b/cosmos_framework/configs/base/defaults/compile.py @@ -24,7 +24,7 @@ class CompileConfig: # (maps to ``torch.compile(dynamic=...)``). Defaults to True for training, # which sees varying shapes across batches (sequence length, CP sharding, ...); # specializing would recompile continuously. See ParallelismOverrides in - # cosmos_framework/inference/common/args.py for the inference-side rationale + # packages/cosmos3/cosmos3/common/args.py for the inference-side rationale # (where dynamic=False is preferred for stable AR shapes). compile_dynamic: bool = True diff --git a/cosmos_framework/configs/base/defaults/model_config.py b/cosmos_framework/configs/base/defaults/model_config.py index c7b9a5c..f7e7c8d 100644 --- a/cosmos_framework/configs/base/defaults/model_config.py +++ b/cosmos_framework/configs/base/defaults/model_config.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - from typing import Any import attrs @@ -42,6 +41,14 @@ class DiffusionExpertConfig: rope_t_extrapolation_ratio: float = 1.0 enable_fps_modulation: bool = False base_fps: int = 24 + # Base temporal compression factor for SOUND m-RoPE. None = current behavior + # (sound advances at base_fps positions/sec). Set to the vision tcf (4) to put + # sound on the same latent-frame temporal grid as vision/action. + sound_base_temporal_compression_factor: int | None = None + # Temporal coordinates used for unified_3d_mrope vision tokens. + # - "latent_index": legacy behavior, positions are 0, 1, ..., T_latent-1. + # - "uniae_source_right_edge": use UniAE padded-patch right-edge source-frame coordinates. + vision_temporal_position_mode: str = "latent_index" # For unified_3d_mrope: whether spatial (H, W) indices reset to 0 for each vision segment unified_3d_mrope_reset_spatial_ids: bool = True # Setting the temporal gap on the boundary of the different modalities, default is 0, using a value greater than 0 will add an additional offset on the accumulated temporal offset. @@ -273,3 +280,12 @@ class OmniMoTModelConfig: sound_latent_fps: int = 25 # Sound tokenizer's latent rate (e.g., 48kHz / 1920 hop = 25 Hz) log_enc_time_every_n: int = 100 # Frequency of logging encoding time to W&B + + # When True, ``OmniMoTModel.state_dict`` / ``load_state_dict`` skip the + # reasoner (und) pathway weights under ``language_model`` — i.e. every key + # WITHOUT a ``_moe_gen`` suffix (including ``visual`` / ``lm_head`` / + # ``embed_tokens``). These are not written to checkpoints and are left + # untouched on load (typically already populated from the HF pretrained + # backbone). Generation-pathway (``_moe_gen``) and VFM heads are saved / + # loaded as usual. + exclude_reasoner_weights_from_checkpoint: bool = False diff --git a/cosmos_framework/configs/base/defaults/multiview_dataloader.py b/cosmos_framework/configs/base/defaults/multiview_dataloader.py deleted file mode 100644 index a646ac6..0000000 --- a/cosmos_framework/configs/base/defaults/multiview_dataloader.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -Hydra ConfigStore registration for multiview dataloaders. - -Registers named dataloader configs that can be referenced via Hydra overrides -(e.g. ``{override /data_train: video_control_mads_multiview_0823_gcs_720p_10fps_93frames_7views}``) -or used as templates for inline ``L(get_multiview_video_loader)(...)`` in -experiment configs. - -Two naming conventions: - - **Transfer** (with control signal): - ``video_control_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` - - **Predict** (no control signal): - ``video_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` -""" - -from hydra.core.config_store import ConfigStore - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.data.vfm.multiview.multiview_data_source import ( - DEFAULT_CAMERAS, - INDEX_TO_CAMERA_MAPPING, - TRANSFER_CAPTION_KEY_MAPPING, - TRANSFER_CONTROL_KEY_MAPPING, - TRANSFER_VIDEO_KEY_MAPPING, -) -from cosmos_framework.data.vfm.multiview.multiview_dataset import ( - MultiviewAugmentationConfig, - get_multiview_video_loader, -) - -# --------------------------------------------------------------------------- -# Camera view subsets -# --------------------------------------------------------------------------- - -CAMERA_VIEW_CONFIGS: dict[str, tuple[str, ...]] = { - "7views": DEFAULT_CAMERAS, - "1view_front": ("camera_front_wide_120fov",), - "4views": ( - "camera_front_wide_120fov", - "camera_cross_right_120fov", - "camera_rear_tele_30fov", - "camera_cross_left_120fov", - ), -} - -# --------------------------------------------------------------------------- -# Grid dimensions -# --------------------------------------------------------------------------- - -_TRANSFER_DATASETS = ["mads_multiview_0823"] -_OBJECT_STORES = ["gcs"] - -_RESOLUTIONS: list[tuple[str, tuple[int, int]]] = [ - ("720p", (720, 1280)), -] - -_FPS: list[tuple[str, int]] = [ - ("10fps", 1), # MADS transfer data is already at 10 fps -] - -_NUM_VIDEO_FRAMES: list[tuple[str, int]] = [ - ("29frames", 29), - ("61frames", 61), - ("93frames", 93), -] - - -def register_multiview_dataloaders() -> None: - """Register all multiview dataloader configs with Hydra ConfigStore.""" - - cs = ConfigStore.instance() - - # ----- Transfer dataloaders (with control signals) ----- - for dataset in _TRANSFER_DATASETS: - for object_store in _OBJECT_STORES: - for resolution_str, resolution_hw in _RESOLUTIONS: - for fps_str, downsample_factor in _FPS: - for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - name = ( - f"video_control_{dataset}_{object_store}_{resolution_str}_" - f"{fps_str}_{num_frames_str}_{views_str}" - ) - cs.store( - group="data_train", - package="dataloader_train", - name=name, - node=L(get_multiview_video_loader)( - dataset_name=dataset, - is_train=True, - augmentation_config=L(MultiviewAugmentationConfig)( - resolution_hw=resolution_hw, - fps_downsample_factor=downsample_factor, - num_video_frames=num_frames, - camera_keys=camera_keys, - camera_video_key_mapping=TRANSFER_VIDEO_KEY_MAPPING, - camera_caption_key_mapping=TRANSFER_CAPTION_KEY_MAPPING, - camera_control_key_mapping=TRANSFER_CONTROL_KEY_MAPPING, - position_to_camera_mapping=INDEX_TO_CAMERA_MAPPING, - single_caption_camera_name="camera_front_wide_120fov", - ), - ), - ) - - # ----- Predict dataloaders (no control signals, for future use) ----- - # These use named keys (video_camera_front_wide_120fov, etc.) and need - # different datasets (e.g. alpamayo_dec2024) with 30 fps native data. - # Uncomment and add predict datasets to the catalog when needed. - # - # _PREDICT_DATASETS = ["alpamayo_dec2024"] - # _PREDICT_FPS = [("10fps", 3), ("15fps", 2)] # 30 fps native → downsample - # for dataset in _PREDICT_DATASETS: - # for object_store in _OBJECT_STORES: - # for resolution_str, resolution_hw in _RESOLUTIONS: - # for fps_str, downsample_factor in _PREDICT_FPS: - # for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - # for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - # name = ( - # f"video_{dataset}_{object_store}_{resolution_str}_" - # f"{fps_str}_{num_frames_str}_{views_str}" - # ) - # cs.store( - # group="data_train", - # package="dataloader_train", - # name=name, - # node=L(get_multiview_video_loader)( - # dataset_name=dataset, - # is_train=True, - # augmentation_config=L(MultiviewAugmentationConfig)( - # resolution_hw=resolution_hw, - # fps_downsample_factor=downsample_factor, - # num_video_frames=num_frames, - # camera_keys=camera_keys, - # camera_video_key_mapping=PREDICT_VIDEO_KEY_MAPPING, - # camera_caption_key_mapping=PREDICT_CAPTION_KEY_MAPPING, - # camera_control_key_mapping=None, - # position_to_camera_mapping=None, - # single_caption_camera_name=None, - # ), - # ), - # ) - - -# Auto-register on import -register_multiview_dataloaders() diff --git a/cosmos_framework/configs/base/defaults/tokenizer.py b/cosmos_framework/configs/base/defaults/tokenizer.py index 55cb01c..526d579 100644 --- a/cosmos_framework/configs/base/defaults/tokenizer.py +++ b/cosmos_framework/configs/base/defaults/tokenizer.py @@ -17,10 +17,12 @@ PRETRAINED_TOKENIZER_FLUX_VAE_PTH = "pretrained/tokenizers/image/flux/ae.safetensors" # UniAE checkpoint paths -PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH = "pretrained/tokenizers/video/cosmos/uniae4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1.pt" +PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T16TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_NEMOTRON2B_VAE_PTH = ( + "s3://bucket1/uniae/tok_experiments/" + "s3_siglip2_so400m_singledec_l48_textdec_nemotron2b_32node_bucketed_256480_v45i32c23_t16-160_exp009/checkpoints/iter_000050000.pt" +) # DCAE checkpoint paths -PRETRAINED_TOKENIZER_DCAE_PTH = "pretrained/tokenizers/video/cosmos/dc-ae-v-1.0-f32t4c64-cosmos-encoder-causal-decoder-chunk-causal-4-frame-120-pad-7-no-gan.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" @@ -44,6 +46,7 @@ chunk_duration=1, spatial_compression_factor=8, temporal_compression_factor=1, + causal=True, ) Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)( @@ -53,6 +56,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT1_VAE_PTH, spatial_compression_factor=8, temporal_compression_factor=4, + causal=True, ) Wan2pt2VAEConfig: LazyDict = L(Wan2pt2VAEInterface)( @@ -61,14 +65,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT2_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, -) - -DCAE4x32x32Config: LazyDict = L(DCAE4x32x32Interface)( - bucket_name=PLACEHOLDER, - object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_DCAE_PTH, - spatial_compression_factor=32, - temporal_compression_factor=4, + causal=True, ) DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -80,6 +77,7 @@ model_name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", spatial_compression_factor=32, temporal_compression_factor=4, + causal=True, ) DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -104,12 +102,17 @@ temporal_compression_factor=4, ) -UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config: LazyDict = L(UniAEVAEInterface)( + +UniAE4x16x16C48T16to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Nemotron2bVAEConfig: LazyDict = L( + UniAEVAEInterface +)( bucket_name=PLACEHOLDER, object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH, + vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T16TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_NEMOTRON2B_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, + pixel_trim=True, + causal=False, ) # ============================================================================= @@ -173,8 +176,8 @@ def register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="uniae_4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1_tokenizer", - node=UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config, + name="uniae_4x16x16_c48_t16to160_mixp_fps_mix_encoder_noncausal_decoder_noncausal_nogan_s3_nemotron2b_tokenizer", + node=UniAE4x16x16C48T16to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Nemotron2bVAEConfig, ) # Flux tokenizer cs.store(group="tokenizer", package="model.config.tokenizer", name="flux_tokenizer", node=FluxVAEConfig) @@ -182,25 +185,19 @@ def register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_tokenizer", - node=DCAE4x32x32Config, - ) - cs.store( - group="tokenizer", - package="model.config.tokenizer", - name="dc_ae_4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) diff --git a/cosmos_framework/configs/base/defaults/unittest.py b/cosmos_framework/configs/base/defaults/unittest.py deleted file mode 100644 index 5af5bb2..0000000 --- a/cosmos_framework/configs/base/defaults/unittest.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 -import attrs - -# We are hardcoding the unittest assets in this file. - -# add codeowner for cosmos_framework/model/vfm/tokenizers - - -@attrs.define(slots=False) -class SwfitStackPDXrConfig: - """ - Config for the cluster specific information. - Everything cluster specific should be here. - """ - - object_store_bucket_data: str - object_store_credential_data: str - - -UNITTEST_CONFIG = SwfitStackPDXrConfig( - object_store_bucket_data="unittest", - object_store_credential_data="credentials/pdx_dir.secret", -) - -TOKENIZER_RECONSTRUCTION_VIDEO_PATH = "tokenizer/video/panda70m_test_0000039_00000.mp4" -AVAE_RECONSTRUCTION_AUDIO_PATH = "tokenizer/audio/test_audio.wav" diff --git a/cosmos_framework/configs/base/defaults/vlm.py b/cosmos_framework/configs/base/defaults/vlm.py index fa4d7c8..32718f4 100644 --- a/cosmos_framework/configs/base/defaults/vlm.py +++ b/cosmos_framework/configs/base/defaults/vlm.py @@ -95,11 +95,7 @@ def download_tokenizer_files(model_name: str, config_variant: str) -> str: return destination_dir -def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str, **_unused_kwargs): - # **_unused_kwargs absorbs extras (e.g. tokenizer_type) that OmegaConf - # merges in from a vlm_config preset's tokenizer block when an experiment - # overrides the tokenizer with this function but doesn't fully replace - # the preset's kwarg dict. +def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str): destination_dir = download_tokenizer_files(pretrained_model_name, config_variant) return LLMTokenizerProcessor(Qwen2Tokenizer.from_pretrained(destination_dir)) @@ -140,7 +136,7 @@ class VLMConfig: # HuggingFace model identifier or local path. Drives AutoConfig + AutoModel selection. model_name: str = "" - # Safetensor path for model + # Safetensor path for model for load a safetensor from different folder safetensors_path: str = "" # Optional pretrained-weights overlay (separate from the AutoModel structural @@ -285,29 +281,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_30b_a3b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-30B-A3B-Private", - model_instance=L(Qwen3VLMoeTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoeMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl_moe/configs/Qwen3-VL-30B-A3B-Instruct.json" - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-30B-A3B-Instruct", - config_variant="gcp", - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-30B-A3B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 235B A22B Instruct model # Qwen3VLMoE uses Qwen2Tokenizer Qwen3VLMoT_VLM_235b_a22b_Instruct_GCP_Config: VLMConfig = VLMConfig( @@ -458,48 +431,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-2B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - -Cosmos3Reasoner_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-2B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 4B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_4b_Instruct_Config: VLMConfig = VLMConfig( @@ -586,8 +517,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-8B-Private", +Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-8B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -601,14 +532,14 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-8B-Private", +Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Nano-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -617,18 +548,18 @@ class VLMConfig: qk_norm_for_text=True, ), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-8B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-8B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3NanoReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Nano-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -643,13 +574,12 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner-bb9c6f5/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) - # Config for Qwen3VL 32B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_32b_Instruct_Config: VLMConfig = VLMConfig( @@ -693,8 +623,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-32B-Private", +Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-32B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -708,14 +638,14 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-32B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-32B-Private", +Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -724,18 +654,18 @@ class VLMConfig: qk_norm_for_text=True, ), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-32B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-32B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3SuperReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -750,12 +680,63 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner-b6df0d1/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) +# Cosmos3-Edge-Reasoner at commit 4acb717. +# nemotron_siglip2 architecture: Nemotron text backbone (56-block hybrid layout, 2048 hidden) +# + SigLIP2 vision encoder. The text transformer is identical in shape to +# Nemotron-3-Dense-VL-2B (hidden_size=2048, 56 alternating attn/MLP blocks → 28 +# effective MoT layers after _transform_text_dict). Uses the same +# nemotron_3_dense_vl weight remapping and config JSON. +Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-4acb717/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", + ), +) + +# Cosmos3-Edge-Reasoner at commit 9b4c028 (2026-05-29). +# Same nemotron_siglip2 architecture as 4acb717; new weights uploaded 2026-05-29. +Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-9b4c028/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", + ), +) def register_vlm(): @@ -832,24 +813,6 @@ def register_vlm(): name="cosmos_reason2_vlm_2b_gcp", node=CosmosReason2_VLM_2b_GCP_Config, ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_2b_private_gcp", - node=CosmosReason2_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos3_reasoner_vlm_2b_private_gcp", - node=Cosmos3Reasoner_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_8b_private_gcp", - node=CosmosReason2_VLM_8b_Private_GCP_Config, - ) cs.store( group="vlm_config", package="model.config.vlm_config", @@ -865,8 +828,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_32b_private_gcp", - node=CosmosReason2_VLM_32b_Private_GCP_Config, + name="cosmos3_nano_reasoner_vlm_gcp_0517", + node=Cosmos3NanoReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -883,8 +846,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_30b_a3b_private_gcp", - node=CosmosReason2_VLM_30b_a3b_Private_GCP_Config, + name="cosmos3_super_reasoner_vlm_gcp_0517", + node=Cosmos3SuperReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -922,3 +885,15 @@ def register_vlm(): name="qwen3_vl_mot_vlm_32b_instruct_gcp", node=Qwen3VLMoT_VLM_32b_Instruct_GCP_Config, ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_4acb717", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717, + ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_9b4c028", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028, + ) diff --git a/cosmos_framework/configs/base/experiment/action/__init__.py b/cosmos_framework/configs/base/experiment/action/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/action/__init__.py +++ b/cosmos_framework/configs/base/experiment/action/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py b/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py +++ b/cosmos_framework/configs/base/experiment/action/posttrain_config/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py index 4c3a6fb..4bc0c29 100644 --- a/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py +++ b/cosmos_framework/configs/base/experiment/action/posttrain_config/action_policy_droid_nano.py @@ -59,7 +59,6 @@ {"override /ema": "power"}, {"override /tokenizer": "wan2pt2_tokenizer"}, {"override /sound_tokenizer": None}, - {"override /cluster": None}, {"override /vlm_config": None}, {"override /ckpt_type": "dcp"}, "_self_", diff --git a/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py b/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py +++ b/cosmos_framework/configs/base/experiment/action/pretrained_config/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py b/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py +++ b/cosmos_framework/configs/base/experiment/posttrain_video/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py b/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py index fc78994..46957d1 100644 --- a/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py +++ b/cosmos_framework/configs/base/experiment/sft/vision_sft_nano.py @@ -78,7 +78,6 @@ {"override /ema": "power"}, {"override /tokenizer": "wan2pt2_tokenizer"}, {"override /sound_tokenizer": None}, - {"override /cluster": None}, {"override /vlm_config": None}, {"override /ckpt_type": "dcp"}, "_self_", diff --git a/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py b/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py index a1134c5..a49bb3d 100644 --- a/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py +++ b/cosmos_framework/configs/base/experiment/sft/vision_sft_super.py @@ -87,7 +87,6 @@ {"override /ema": "power"}, {"override /tokenizer": "wan2pt2_tokenizer"}, {"override /sound_tokenizer": None}, - {"override /cluster": None}, {"override /vlm_config": None}, {"override /ckpt_type": "dcp"}, "_self_", diff --git a/cosmos_framework/configs/base/vlm/__init__.py b/cosmos_framework/configs/base/vlm/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/vlm/__init__.py +++ b/cosmos_framework/configs/base/vlm/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/vlm/config.py b/cosmos_framework/configs/base/vlm/config.py index 66dc3ed..4d1de85 100644 --- a/cosmos_framework/configs/base/vlm/config.py +++ b/cosmos_framework/configs/base/vlm/config.py @@ -4,10 +4,9 @@ from cosmos_framework.trainer import ImaginaireTrainer from cosmos_framework.utils import log from cosmos_framework.utils.config_helper import import_all_modules_from_package +from cosmos_framework.configs.base.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.callbacks import register_callbacks -from cosmos_framework.configs.base.vlm.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.config import Config - from cosmos_framework.configs.base.vlm.defaults.model import register_model from cosmos_framework.configs.base.vlm.defaults.optimizer import register_optimizer, register_scheduler from cosmos_framework.configs.base.vlm.defaults.vlm_policy import register_vlm_policy diff --git a/cosmos_framework/configs/base/vlm/defaults/__init__.py b/cosmos_framework/configs/base/vlm/defaults/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/vlm/defaults/__init__.py +++ b/cosmos_framework/configs/base/vlm/defaults/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/vlm/defaults/callbacks.py b/cosmos_framework/configs/base/vlm/defaults/callbacks.py index 1910b63..3392205 100644 --- a/cosmos_framework/configs/base/vlm/defaults/callbacks.py +++ b/cosmos_framework/configs/base/vlm/defaults/callbacks.py @@ -12,7 +12,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback - from cosmos_framework.callbacks.grad_clip import GradClip from cosmos_framework.callbacks.hf_export import HFExportCallback from cosmos_framework.callbacks.iter_speed import IterSpeed @@ -47,7 +46,6 @@ def register_callbacks(): config=PLACEHOLDER, trainer=PLACEHOLDER, ), # reads model.precision; no extra kwarg needed - # nvtx=L(NVTXCallback)(synchronize=True), ) diff --git a/cosmos_framework/configs/base/vlm/defaults/dataloader.py b/cosmos_framework/configs/base/vlm/defaults/dataloader.py deleted file mode 100644 index 36b878d..0000000 --- a/cosmos_framework/configs/base/vlm/defaults/dataloader.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -from torch.utils.data import DataLoader - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.config_helper import ConfigStore -from cosmos_framework.data.vfm.vlm.collate_fn import custom_collate -from cosmos_framework.data.vfm.vlm.debug_data_qwen import DebugQwenDataset -from cosmos_framework.data.vfm.vlm.dummy_data_qwen import DummyQwenDataset -from cosmos_framework.data.vfm.processors import build_processor_lazy - - -# Debug dataset -def create_debug_dataloader_config_qwen( - num_images, loss_on_completion_only: bool = True, use_dummy_image: bool = False -): - return L(DataLoader)( - dataset=L(DebugQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_images=num_images, - seq_len="${model.config.policy.model_max_length}", - image_token_len="${model.config.policy.qwen_max_video_token_length}", - # use_dummy_image=use_dummy_image, - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def create_dummy_dataloader_config_qwen(): - return L(DataLoader)( - dataset=L(DummyQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_visual_tokens="${model.config.policy.qwen_max_video_token_length}", - total_tokens="${model.config.policy.model_max_length}", - batch_size="${dataloader_train.batch_size}", - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def register_data_debug(): - cs = ConfigStore.instance() - for split in ["train", "val"]: - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="debug_image_data_qwen", # This data is from pixtral model output, expected to have low loss ~1.4 - node=create_debug_dataloader_config_qwen(1), - ) - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="dummy_image_data_qwen", - node=create_dummy_dataloader_config_qwen(), - ) - - -def register_data(): - register_data_debug() diff --git a/cosmos_framework/configs/base/vlm/defaults/optimizer.py b/cosmos_framework/configs/base/vlm/defaults/optimizer.py index 6632dc8..538ab7d 100644 --- a/cosmos_framework/configs/base/vlm/defaults/optimizer.py +++ b/cosmos_framework/configs/base/vlm/defaults/optimizer.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 + """Hydra config registrations for VLM optimizer + LR scheduler.""" from typing import Any diff --git a/cosmos_framework/configs/base/vlm/defaults/policy_config.py b/cosmos_framework/configs/base/vlm/defaults/policy_config.py index 307eb92..7df9987 100644 --- a/cosmos_framework/configs/base/vlm/defaults/policy_config.py +++ b/cosmos_framework/configs/base/vlm/defaults/policy_config.py @@ -29,7 +29,7 @@ class PolicyConfig: trainable_map: Union[str, None] = None monkey_patch_for_text_only_data: bool = False - # HF attention impl. Default "cosmos" routes through imaginaire.attention + # HF attention impl. Default "cosmos" routes through cosmos_framework.model.attention # (NATTEN/blackwell-fmha on GB200). Override to "flash_attention_2", # "sdpa", or "eager" for fallback. attn_implementation: str = "cosmos" @@ -53,5 +53,6 @@ class VLMModelConfig: ema: EMAConfig = EMAConfig(enabled=False) # Force deterministic kernels in Flash-Attention init (slower; required for - # parity bit-exactness) + # parity bit-exactness). VLM-only knob — consumed by VLMModel.__init__ via + # init_flash_attn_meta. deterministic: bool = False diff --git a/cosmos_framework/configs/base/vlm/experiment/__init__.py b/cosmos_framework/configs/base/vlm/experiment/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/configs/base/vlm/experiment/__init__.py +++ b/cosmos_framework/configs/base/vlm/experiment/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/configs/base/vlm/freeze_config.py b/cosmos_framework/configs/base/vlm/freeze_config.py index 9629782..d161fe8 100644 --- a/cosmos_framework/configs/base/vlm/freeze_config.py +++ b/cosmos_framework/configs/base/vlm/freeze_config.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 + """VLM freeze config (read by ``vlm_model._apply_freeze_config``).""" import attrs diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py new file mode 100644 index 0000000..28a81be --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py new file mode 100644 index 0000000..b34cb81 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F +from loguru import logger as logging + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + assert (self.args is not None) and ("size" in self.args), "Please specify size in args" + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + + # We also add the aug params we use. This will be useful for other transforms + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + return data_dict + + +class BottomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Crops rows from the bottom of the image/video to reach ``target_height``. + + The top of the frame is preserved (content is top-anchored). Width is unchanged. + Works for 3-D ``[C, H, W]`` images and 4-D ``[C, T, H, W]`` or ``[T, C, H, W]`` + videos — the last two dims are always treated as (H, W). + + Args: + data_dict (dict): Input data dict. ``self.args["target_height"]`` is the + desired output height. Source height must be ``>= target_height``. + + Returns: + data_dict (dict): Output dict where images are bottom-cropped and + ``image_size`` is updated to ``[target_h, w, orig_h, orig_w]`` to mirror + :class:`ReflectionPadding`'s contract. + """ + assert (self.args is not None) and ("target_height" in self.args), "Please specify target_height in args" + if self.output_keys is None: + self.output_keys = self.input_keys + + target_h = int(self.args["target_height"]) + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + assert orig_h >= target_h, ( + f"BottomCrop requires source height >= target_height: got orig_h={orig_h}, target_h={target_h}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + tensor = data_dict[inp_key] + # Slice the last 2 dims; the second-to-last dim is height regardless of + # whether the tensor is CHW, CTHW, or TCHW. + data_dict[out_key] = tensor[..., :target_h, :] + + if out_key != inp_key: + del data_dict[inp_key] + + data_dict["image_size"] = torch.tensor([target_h, orig_w, orig_h, orig_w], dtype=torch.float) + + return data_dict + + +class RandomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + # Obtaining random crop coords + try: + crop_x0 = int(torch.randint(0, orig_w - width + 1, size=(1,)).item()) + crop_y0 = int(torch.randint(0, orig_h - height + 1, size=(1,)).item()) + except Exception as e: + logging.warning( + f"Random crop failed. Performing center crop, original_size(wxh): {orig_w}x{orig_h}, random_size(wxh): {width}x{height}" + ) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + + # We also add the aug params we use. This will be useful for other transforms + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + + # We must perform same random cropping for all input keys + for key in self.input_keys: + data_dict[key] = transforms_F.crop(data_dict[key], crop_y0, crop_x0, height, width) + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py new file mode 100644 index 0000000..8f0bb7d --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor + + +class HorizontalFlip(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs horizontal flipping. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + flip_enabled = getattr(self.args, "enabled", True) + if flip_enabled: + p = getattr(self.args, "prob", 0.5) + coin_flip = torch.rand(1).item() > p + for key in self.input_keys: + if coin_flip: + data_dict[key] = transforms_F.hflip(data_dict[key]) + + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py new file mode 100644 index 0000000..d3e5216 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Union + +import torch +from PIL import Image + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]: + r"""Function for obtaining size of the augmentation. + When dealing with multi-aspect ratio dataloaders, we need to + find the augmentation size from the aspect ratio of the data. + If data_dict contains "_res_size_map" (e.g. from resolution sampling), + that map is used instead of augmentor_cfg["size"]. + + Args: + data_dict (dict): Input data dict + augmentor_cfg (dict): Augmentor config + Returns: + aug_size (int): Size of augmentation + """ + if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: + aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] + else: # Non-webdataset format + aspect_ratio = data_dict["aspect_ratio"] + if "_res_size_map" in data_dict: + return data_dict["_res_size_map"][aspect_ratio] + return augmentor_cfg["size"][aspect_ratio] diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py new file mode 100644 index 0000000..a949230 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py new file mode 100644 index 0000000..e14d66f --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import omegaconf +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class ReflectionPadding(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs reflection padding. This function also returns a padding mask. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + + assert self.args is not None, "Please specify args in augmentation" + if self.output_keys is None: + self.output_keys = self.input_keys + + # Obtain image and augmentation sizes + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + target_size = obtain_augmentation_size(data_dict, self.args) + + assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" + target_w, target_h = target_size + + target_w = int(target_w) + target_h = int(target_h) + + # One-sided padding (bottom and right only, content stays at top-left) + padding_right = target_w - orig_w + padding_bottom = target_h - orig_h + padding_vals = [0, 0, padding_right, padding_bottom] + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: + # In this case, we can't perform reflection padding. This is because padding values + # are larger than the image size. So, perform edge padding instead. + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") + else: + # Perform reflection padding + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") + + if out_key != inp_key: + del data_dict[inp_key] + + data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) + + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py new file mode 100644 index 0000000..82cdea9 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import omegaconf +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class ResizeSmallestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to smaller side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=out_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to larger side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + + scaling_ratio = min(out_size / orig_w, out_size / orig_h) + target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] + + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(img_size, (tuple, omegaconf.listconfig.ListConfig)), ( + f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + ) + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert target_size[0] >= img_h and target_size[1] >= img_w, ( + f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=( + self.args["interpolation"] + if "interpolation" in self.args + else transforms_F.InterpolationMode.BICUBIC + ), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the larger ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the larger of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(img_size, (tuple, omegaconf.listconfig.ListConfig)), ( + f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + ) + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert target_size[0] <= img_h and target_size[1] <= img_w, ( + f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict diff --git a/cosmos_framework/data/vfm/action/__init__.py b/cosmos_framework/data/vfm/action/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/data/vfm/action/__init__.py +++ b/cosmos_framework/data/vfm/action/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/data/vfm/action/action_processing.py b/cosmos_framework/data/vfm/action/action_processing.py new file mode 100644 index 0000000..1a09170 --- /dev/null +++ b/cosmos_framework/data/vfm/action/action_processing.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Shared Action processing records and normalization helpers.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, Protocol + +import numpy as np +import torch + +from cosmos_framework.utils import log + +ActionNormalizationMethod = Literal["quantile", "quantile_rot", "meanstd", "minmax"] + + +class ActionNormalizer(Protocol): + """Tensor-level action normalization interface used by ActionProcessor.""" + + def normalize_action(self, action: torch.Tensor) -> torch.Tensor: # action: [...,D], returns [...,D] + """Map raw action values into model-space values.""" + ... + + def denormalize_action(self, action: torch.Tensor) -> torch.Tensor: # action: [...,D], returns [...,D] + """Invert model-space action values back into raw action values.""" + ... + + +@dataclass(frozen=True) +class ActionAffineNormalization: + """Resolved affine action normalizer. + + Forward normalization is ``(raw - offset) / scale``. Inverse + denormalization is ``normalized * scale + offset``. + + ``forward_clamp`` records lossy range-style forward clamping. When + ``forward_clamp_mask`` is provided, only channels with a ``True`` mask + entry are clamped; this represents mixed UMI normalizers where some fields + are range-clamped and others are plain affine transforms. + """ + + offset: torch.Tensor + scale: torch.Tensor + forward_clamp: tuple[float, float] | None = None + forward_clamp_mask: torch.Tensor | None = None + + def normalize_action(self, action: torch.Tensor) -> torch.Tensor: # action: [...,D], returns [...,D] + """Normalize raw action values with resolved affine parameters.""" + offset = self.offset.to(device=action.device, dtype=action.dtype) # [D] + scale = self.scale.to(device=action.device, dtype=action.dtype) # [D] + normalized = (action - offset) / scale # [...,D] + if self.forward_clamp is not None: + lo, hi = self.forward_clamp + clamped = normalized.clamp(lo, hi) # [...,D] + if self.forward_clamp_mask is None: + normalized = clamped # [...,D] + else: + clamp_mask = self.forward_clamp_mask.to(device=action.device, dtype=torch.bool) # [D] + normalized = torch.where(clamp_mask, clamped, normalized) # [...,D] + return normalized # [...,D] + + def denormalize_action(self, action: torch.Tensor) -> torch.Tensor: # action: [...,D], returns [...,D] + """Invert action normalization with resolved affine parameters.""" + offset = self.offset.to(device=action.device, dtype=action.dtype) # [D] + scale = self.scale.to(device=action.device, dtype=action.dtype) # [D] + return action * scale + offset # [...,D] + + +def load_action_stats(stats_path: str, stats_key: str = "global") -> dict[str, np.ndarray]: + """Load pre-computed action normalization stats from a JSON file.""" + path = Path(stats_path) + if not path.exists(): + raise FileNotFoundError(f"Action normalization stats not found at {stats_path}.") + log.info(f"Loading action normalization stats from {stats_path}") + with path.open("r") as f: + raw = json.load(f) + stat_keys = {"mean", "std", "min", "max", "q01", "q99"} + if stats_key in raw: + raw = raw[stats_key] + if not isinstance(raw, dict): + raise TypeError(f"Action normalization stats block {stats_key!r} in {stats_path} must be a dict.") + elif stats_key != "global" and not any(key in raw for key in stat_keys): + raise KeyError(f"Action normalization stats block {stats_key!r} not found in {stats_path}.") + return {k: np.array(v, dtype=np.float32) for k, v in raw.items() if k in stat_keys} + + +def resolve_action_normalization( + method: ActionNormalizationMethod, + stats: dict[str, torch.Tensor], +) -> ActionAffineNormalization: + """Resolve configured action stats into affine forward/inverse parameters.""" + if method == "meanstd": + offset = stats["mean"] # [D] + scale = stats["std"].clamp(min=1e-8) # [D] + return ActionAffineNormalization(offset=offset, scale=scale) + + if method == "minmax": + lo = stats["min"] # [D] + hi = stats["max"] # [D] + elif method in ("quantile", "quantile_rot"): + lo = stats["q01"] # [D] + hi = stats["q99"] # [D] + else: + raise ValueError(f"Unknown normalization method: {method!r}") + + offset = (hi + lo) / 2.0 # [D] + scale = (hi - lo).clamp(min=1e-8) / 2.0 # [D] + return ActionAffineNormalization( + offset=offset, + scale=scale, + forward_clamp=(-1.0, 1.0), + ) + + +def make_pose_action_scale_normalizer( + action_dim: int, + *, + translation_scale: float = 1.0, + rotation_scale: float = 1.0, +) -> ActionAffineNormalization: + """Build a normalizer that maps raw pose deltas into scaled model space. + + Pose actions use the shared layout ``[translation(3), rotation(...)]``. + The returned normalizer multiplies translation channels by + ``translation_scale`` and rotation channels by ``rotation_scale`` during + preprocessing, then inverts those factors during postprocessing. + """ + if action_dim < 3: + raise ValueError(f"Pose action_dim must be at least 3, got {action_dim}") + if translation_scale == 0: + raise ValueError("translation_scale must be non-zero") + if rotation_scale == 0: + raise ValueError("rotation_scale must be non-zero") + + offset = torch.zeros(action_dim, dtype=torch.float32) # [D] + scale = torch.ones(action_dim, dtype=torch.float32) # [D] + scale[:3] = 1.0 / float(translation_scale) # [D] + if action_dim > 3: + scale[3:] = 1.0 / float(rotation_scale) # [D] + return ActionAffineNormalization(offset=offset, scale=scale) + + +@dataclass(frozen=True) +class ActionProcessingRecord: + """Per-sample metadata needed to invert Action model-space preprocessing.""" + + raw_action_dim: int + action_normalizer: ActionNormalizer | None + + +def pad_action_to_max_dim( + action: torch.Tensor, max_action_dim: int +) -> torch.Tensor: # action: [T,D], returns [T,D_model] + """Pad action tensor to max_action_dim along the last dimension. + + Args: + action: Action tensor of shape (T, D) where D is the current action dimension. + max_action_dim: Target action dimension to pad to. + + Returns: + Padded action tensor of shape (T, max_action_dim). + """ + if action.shape[-1] > max_action_dim: + raise ValueError(f"Action dimension {action.shape[-1]} is greater than max_action_dim {max_action_dim}") + if action.shape[-1] == max_action_dim: + return action # [T,D_model] + padding_size = max_action_dim - action.shape[-1] + zero_padding = torch.zeros(*action.shape[:-1], padding_size, dtype=action.dtype, device=action.device) # [T,D_pad] + return torch.cat([action, zero_padding], dim=-1) # [T,D_model] + + +def make_batched_action_processing_fields( + record: ActionProcessingRecord, + batch_size: int, + *, + action_channel_masking: bool = True, +) -> dict[str, list[torch.Tensor | ActionProcessingRecord | None]]: + """Build batch-list fields whose action width and inverse record cannot drift apart.""" + raw_action_dim = torch.tensor(record.raw_action_dim, dtype=torch.long) if action_channel_masking else None # [] + return { + "raw_action_dim": [raw_action_dim] * batch_size, + "action_processing_record": [record] * batch_size, + } + + +class ActionProcessor: + """Forward and inverse Action tensor processing for a single sample.""" + + def __init__(self, max_action_dim: int, action_channel_masking: bool = True) -> None: + self.max_action_dim = int(max_action_dim) + self.action_channel_masking = bool(action_channel_masking) + + def preprocess_action( + self, + data_dict: dict[str, Any], + action: torch.Tensor, + *, + action_normalizer: ActionNormalizer | None, + ) -> dict[str, Any]: + """Return a sample with normalized, padded action fields and the inverse record.""" + + raw_action_dim = int(action.shape[-1]) + if action_normalizer is not None: + action = action_normalizer.normalize_action(action) # [T,D] + if int(action.shape[-1]) != raw_action_dim: + raise ValueError( + f"Action normalizer changed action width from {raw_action_dim} to {int(action.shape[-1])}" + ) + + processed_data_dict = dict(data_dict) + processed_data_dict["action"] = pad_action_to_max_dim(action, self.max_action_dim) # [T,D_model] + record = ActionProcessingRecord( + raw_action_dim=raw_action_dim, + action_normalizer=action_normalizer, + ) + processed_data_dict["raw_action_dim"] = ( + torch.tensor(record.raw_action_dim, dtype=torch.long) if self.action_channel_masking else None + ) # [] + processed_data_dict["action_processing_record"] = record + return processed_data_dict + + @staticmethod + def _unpad_action(action: torch.Tensor, raw_action_dim: int) -> torch.Tensor: + """Drop model-only padded action channels.""" + if action.shape[-1] < raw_action_dim: + raise ValueError(f"invalid raw_action_dim={raw_action_dim} for action with shape {tuple(action.shape)}") + return action[..., :raw_action_dim].contiguous() # [...,D_raw] + + @staticmethod + def postprocess_action( + action: torch.Tensor, + record: ActionProcessingRecord, + ) -> torch.Tensor: + """Unpad and denormalize a model-space action tensor.""" + action = ActionProcessor._unpad_action(action, record.raw_action_dim) # [...,D_raw] + if record.action_normalizer is not None: + action = record.action_normalizer.denormalize_action(action) # [...,D_raw] + return action # [...,D_raw] + + +def get_action_processing_records(data_batch: dict[str, Any]) -> list[ActionProcessingRecord | None]: + """Read all per-sample processing records from a collated Action batch.""" + records = data_batch.get("action_processing_record") + if records is None: + return [] + if isinstance(records, ActionProcessingRecord): + return [records] + if isinstance(records, list): + for record in records: + if record is not None and not isinstance(record, ActionProcessingRecord): + raise TypeError(f"Unexpected action_processing_record entry type: {type(record).__name__}") + return records + raise TypeError(f"Unexpected action_processing_record type: {type(records).__name__}") diff --git a/cosmos_framework/data/vfm/action/domain_utils.py b/cosmos_framework/data/vfm/action/domain_utils.py index 910cc39..6f433f7 100644 --- a/cosmos_framework/data/vfm/action/domain_utils.py +++ b/cosmos_framework/data/vfm/action/domain_utils.py @@ -14,9 +14,12 @@ "bridge_orig_lerobot": 7, "droid_lerobot": 8, "robomind-franka": 8, # Both Droid and RoboMIND-Franka are using robotiq and franka + "embodiment_b": 9, "robomind-franka-dual": 12, "robomind-ur": 13, "agibotworld": 15, + "embodiment_c_gripper": 15, + "embodiment_c_gripper_ext": 15, "fractal": 20, } @@ -24,7 +27,6 @@ EMBODIMENT_TO_RAW_ACTION_DIM: dict[str, int] = { "av": 9, "camera_pose": 9, - "hand_pose": 57, "pusht": 2, "umi": 10, "bridge_orig_lerobot": 10, @@ -32,8 +34,16 @@ "robomind-franka": 10, "robomind-franka-dual": 20, "robomind-ur": 10, + "embodiment_b": 30, "agibotworld": 29, + "embodiment_c_gripper": 29, + "embodiment_c_gripper_ext": 29, "fractal": 10, + # NOTE: ``libero`` (7/10/13 depending on ``rotation_space``) and ``hand_pose`` + # (variable with ``keypoint_option`` and ``rotation_format``) are absent + # because their raw width is set per-dataset at construction time. Inference + # in inverse_dynamics/policy modes is not supported for these domains until + # canonical widths are added here. } @@ -46,3 +56,20 @@ def get_domain_id(embodiment_type: str) -> int: f"Available embodiments: {sorted(EMBODIMENT_TO_DOMAIN_ID.keys())}" ) return EMBODIMENT_TO_DOMAIN_ID[key] + + +def get_action_dim(embodiment_type: str) -> int: + """Get the raw action dimension for a given embodiment type.""" + key = embodiment_type.lower().strip() + if key not in EMBODIMENT_TO_RAW_ACTION_DIM: + raise KeyError( + f"Unknown embodiment type: {embodiment_type!r}. " + f"Available embodiments: {sorted(EMBODIMENT_TO_RAW_ACTION_DIM.keys())}" + ) + return EMBODIMENT_TO_RAW_ACTION_DIM[key] + + +def is_valid_domain_name(embodiment_type: str) -> bool: + """Check if the given embodiment type is recognized.""" + key = embodiment_type.lower().strip() + return key in EMBODIMENT_TO_RAW_ACTION_DIM diff --git a/cosmos_framework/data/vfm/action/json_formatter.py b/cosmos_framework/data/vfm/action/json_formatter.py index b511e93..201a76e 100644 --- a/cosmos_framework/data/vfm/action/json_formatter.py +++ b/cosmos_framework/data/vfm/action/json_formatter.py @@ -7,9 +7,9 @@ import torch +from cosmos_framework.utils import log from cosmos_framework.data.vfm.action.viewpoint_utils import DEFAULT_VIEWPOINT_TEMPLATES from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log def _should_append_idle_frame_info(mode: object) -> bool: diff --git a/cosmos_framework/data/vfm/action/pose_utils.py b/cosmos_framework/data/vfm/action/pose_utils.py index 12ea51b..d6d26a4 100644 --- a/cosmos_framework/data/vfm/action/pose_utils.py +++ b/cosmos_framework/data/vfm/action/pose_utils.py @@ -298,24 +298,15 @@ def build_abs_pose_from_components( def _delta_transform_to_pose_vector( delta_T: np.ndarray, rotation_output_format: RotationConvention, - translation_scale: float = 1.0, - rotation_scale: float = 1.0, ) -> np.ndarray: """Encode a relative transform as an action vector. The shared action-vector layout is always ``[translation(3), rotation(...)]``. - The translation block is multiplied by ``translation_scale`` before concatenation, - and the rotation block is multiplied by ``rotation_scale``. Args: delta_T: Relative transform of shape ``(4, 4)``. rotation_output_format: Concrete convention used for the output rotation block. - translation_scale: Scalar multiplier applied to the translation block. - rotation_scale: Scalar multiplier applied to the rotation block. Used to - match the loss scale of the rotation block to the translation block. - The decoder must divide by the same factor before reconstructing the - rotation matrix. Returns: A ``float32`` action vector whose first three values are translation and @@ -325,12 +316,11 @@ def _delta_transform_to_pose_vector( if delta_np.shape != (4, 4): raise ValueError(f"delta_T must have shape (4, 4), got {delta_np.shape}") - translation = delta_np[:3, 3] * translation_scale + translation = delta_np[:3, 3] rotation = np.asarray( convert_rotation(delta_np[:3, :3], input_format="matrix", output_format=rotation_output_format), dtype=np.float32, ) - rotation = rotation * rotation_scale return np.concatenate([translation, rotation]).astype(np.float32) @@ -344,19 +334,19 @@ def _pose_vector_to_delta_transform( """Decode an action vector back into a relative homogeneous transform. This is the inverse of `_delta_transform_to_pose_vector()` when the same - rotation convention and scale are used. + rotation convention is used. Scale arguments are provided for callers that + need to decode model-space pose actions before action-normalizer + denormalization has been applied. Args: pose_vector: Relative-pose action vector with layout ``[translation(3), rotation(...)]``. rotation_input_format: Concrete convention used by the rotation block. - translation_scale: Scalar used to undo the translation scaling applied during - encoding. + translation_scale: Scalar used to undo translation scaling in the input + vector. normalize_rotation: Whether to project the decoded rotation to a valid matrix before assembling the transform. - rotation_scale: Scalar used to undo the rotation scaling applied during - encoding. Must match the value used by - `_delta_transform_to_pose_vector()`. + rotation_scale: Scalar used to undo rotation scaling in the input vector. Returns: A relative homogeneous transform with shape ``(4, 4)`` and dtype @@ -440,8 +430,6 @@ def pose_abs_to_rel( poses_abs: np.ndarray, rotation_format: RotationConvention = "rot9d", pose_convention: PoseConvention = "backward_framewise", - translation_scale: float = 1.0, - rotation_scale: float = 1.0, ) -> np.ndarray: """Convert an absolute-pose trajectory into relative-pose action vectors. @@ -454,12 +442,6 @@ def pose_abs_to_rel( pose_convention: Pose convention: - ``backward_framewise``: ``delta_T = T_i^{-1} @ T_{i+1}`` - ``backward_anchored``: ``delta_T = T_0^{-1} @ T_{i+1}`` - translation_scale: Scalar multiplier applied to the translation block of each - encoded action vector. - rotation_scale: Scalar multiplier applied to the rotation block of each - encoded action vector. Use this to match the loss scale of rotation - and translation. `pose_rel_to_abs()` must be called with the same - value to invert the scaling. Returns: An array of shape ``(T - 1, D)`` where ``D = 3 + rotation_dim``. @@ -481,8 +463,6 @@ def pose_abs_to_rel( _delta_transform_to_pose_vector( delta_T, rotation_output_format=rotation_format, - translation_scale=translation_scale, - rotation_scale=rotation_scale, ) ) @@ -510,10 +490,12 @@ def pose_rel_to_abs( identity transform is used. normalize_rotation: Whether to project decoded rotations onto ``SO(3)`` before composing them back into the trajectory. - translation_scale: Scalar used to undo the translation scaling applied during - `pose_abs_to_rel()`. - rotation_scale: Scalar used to undo the rotation scaling applied during - `pose_abs_to_rel()`. Must match the value passed there. + translation_scale: Scalar used to undo translation scaling in + ``poses_rel``. Prefer denormalizing with the dataset action + normalizer before calling this function. + rotation_scale: Scalar used to undo rotation scaling in ``poses_rel``. + Prefer denormalizing with the dataset action normalizer before + calling this function. Returns: Absolute poses with shape ``(T, 4, 4)`` where ``T = len(poses_rel) + 1``. diff --git a/cosmos_framework/data/vfm/action/pose_utils_test.py b/cosmos_framework/data/vfm/action/pose_utils_test.py index 93f9791..3cdc20b 100644 --- a/cosmos_framework/data/vfm/action/pose_utils_test.py +++ b/cosmos_framework/data/vfm/action/pose_utils_test.py @@ -196,14 +196,12 @@ def test_pose_abs_to_rel_roundtrips_through_pose_rel_to_abs( poses_abs, rotation_format=rotation_format, pose_convention=pose_convention, - translation_scale=2.5, ) reconstructed = pose_rel_to_abs( poses_rel, rotation_format=rotation_format, pose_convention=pose_convention, initial_pose=poses_abs[0], - translation_scale=2.5, ) np.testing.assert_allclose(reconstructed, poses_abs, atol=1e-5) diff --git a/cosmos_framework/data/vfm/action/transforms.py b/cosmos_framework/data/vfm/action/transforms.py index d17b141..3462f1e 100644 --- a/cosmos_framework/data/vfm/action/transforms.py +++ b/cosmos_framework/data/vfm/action/transforms.py @@ -19,6 +19,11 @@ import torch import torchvision.transforms.functional as transforms_F +from cosmos_framework.utils import log +from cosmos_framework.data.vfm.action.action_processing import ( + ActionNormalizer, + ActionProcessor, +) from cosmos_framework.data.vfm.action.json_formatter import ActionPromptJsonFormatter from cosmos_framework.data.vfm.action.viewpoint_utils import ViewpointTextInfo from cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps import DurationFPSTextTimeStamps @@ -27,7 +32,6 @@ from cosmos_framework.data.vfm.augmentors.text_tokenizer import TextTokenizerTransform from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log from cosmos_framework.utils.vfm.data_utils import get_vision_data_resolution @@ -36,28 +40,6 @@ def _should_append_idle_frame_info(mode: object) -> bool: return mode != "inverse_dynamics" -def pad_action_to_max_dim(action: torch.Tensor, max_action_dim: int) -> torch.Tensor: - """Pad action tensor to max_action_dim along the last dimension. - - Args: - action: Action tensor of shape (T, D) where D is the current action dimension. - max_action_dim: Target action dimension to pad to. - - Returns: - Padded action tensor of shape (T, max_action_dim). - """ - if action.shape[-1] > max_action_dim: - raise ValueError(f"Action dimension {action.shape[-1]} is greater than max_action_dim {max_action_dim}") - elif action.shape[-1] == max_action_dim: - return action - else: - padding_size = max_action_dim - action.shape[-1] - zero_padding = torch.zeros( - *action.shape[:-1], padding_size, dtype=action.dtype, device=action.device - ) # [T,padding_size] - return torch.cat([action, zero_padding], dim=-1) # [T,max_action_dim] - - def find_closest_target_size(h: int, w: int, resolution: str | int) -> tuple[int, int]: """Find the closest predefined target size for a given input resolution. @@ -205,7 +187,7 @@ def reflection_pad_to_target( def remove_reflection_padding( tensor: torch.Tensor, - image_size: torch.Tensor, + image_size: torch.Tensor | list[torch.Tensor] | None, ) -> torch.Tensor: """Remove reflection padding added by :func:`reflection_pad_to_target`. @@ -215,17 +197,30 @@ def remove_reflection_padding( tensor: Tensor whose last two dimensions are the padded spatial dims. Supports any leading dimensions, e.g. ``(C, T, H, W)`` or ``(C, H, W)``. - image_size: 1-D tensor of shape ``(4,)`` containing - ``[target_h, target_w, orig_h_resized, orig_w_resized]`` where - ``orig_h/w_resized`` is the original spatial size after - aspect-preserving resize (i.e. the content region before - padding) — the same convention stored by - :func:`reflection_pad_to_target` and VFM's - ``ReflectionPadding``. + image_size: Spatial metadata using the convention produced by + :func:`reflection_pad_to_target`. Accepted forms are ``None`` (no + crop), a tensor with shape ``(4,)`` or ``(1, 4)``, or a non-empty + list whose first element has one of those tensor shapes. The four + values are ``[target_h, target_w, orig_h_resized, + orig_w_resized]``, where ``orig_h/w_resized`` is the original + spatial size after aspect-preserving resize (i.e. the content + region before padding). This matches the convention stored by + :func:`reflection_pad_to_target` and VFM's ``ReflectionPadding``. Returns: Cropped tensor of shape ``(..., orig_h_resized, orig_w_resized)``. """ + if image_size is None: + return tensor + if isinstance(image_size, list): + if not image_size: + raise ValueError("Expected at least one image_size entry") + image_size = image_size[0] # [1,4] or [4] + if image_size.ndim == 2 and image_size.shape[0] == 1: + image_size = image_size[0] # [4] + if image_size.ndim != 1: + raise ValueError(f"Expected image_size shape [4] or [1,4], got {tuple(image_size.shape)}") + target_h = int(image_size[0].item()) target_w = int(image_size[1].item()) orig_h_resized = int(image_size[2].item()) @@ -309,7 +304,6 @@ def build_sequence_plan_from_mode( base_action_length = action_length - num_history_actions if mode == "forward_dynamics": condition_frame_indexes_action = list(range(action_length)) - # This currently assumes that the action length is the same as the video length - 1 # and if action length is the same as the video length, then the first action is the conditioning action elif base_action_length == video_length - 1: @@ -487,6 +481,10 @@ def __init__( self.video_temporal_downsample: int = video_temporal_downsample self.max_action_dim: int = max_action_dim self.action_channel_masking: bool = action_channel_masking + self.action_processor: ActionProcessor = ActionProcessor( + max_action_dim=max_action_dim, + action_channel_masking=action_channel_masking, + ) # --- Spatial resize/padding stage (resolution supplied at call time) --- self.video_resize: VideoResize = VideoResize( @@ -557,7 +555,12 @@ def __init__( }, ) - def __call__(self, data_dict: dict, resolution: str | None) -> dict: + def __call__( + self, + data_dict: dict, + resolution: str | None, + action_normalizer: ActionNormalizer | None = None, + ) -> dict: """Apply the transform pipeline to a single data dictionary. Resolution is required at call time and is the only source of truth @@ -576,7 +579,9 @@ def __call__(self, data_dict: dict, resolution: str | None) -> dict: sample is in inverse dynamics mode (if enabled). 7. Tokenize caption text (if enabled). 8. Build a ``SequencePlan`` from the ``"mode"`` key (if present). - 9. If action is needed by the plan, pad ``"action"`` to ``max_action_dim``. + 9. If action is needed by the plan, normalize real channels, pad + ``"action"`` to ``max_action_dim``, and attach + ``"action_processing_record"``. 10. Otherwise, nullify ``"action"`` and ``"domain_id"`` (e.g. in ``"image2video"`` mode). @@ -584,11 +589,14 @@ def __call__(self, data_dict: dict, resolution: str | None) -> dict: data_dict: A sample dictionary as returned by a Action dataset. resolution: Resolution tier key (e.g. ``"256"``, ``"480"``, ``"720"``) for this sample. When ``None``, auto-detected from video dimensions. + action_normalizer: Optional source-provided action normalizer. When + present, only unpadded real action channels are normalized + before model-space channel padding. Returns: The same dictionary, mutated in-place with padded tensors, - ``image_size``, tokenized text IDs, and a - ``"sequence_plan"`` entry added. + ``image_size``, tokenized text IDs, a ``"sequence_plan"`` entry, + and action processing metadata added. """ mode = data_dict.get("mode") assert mode is not None, "mode is required" @@ -654,13 +662,17 @@ def __call__(self, data_dict: dict, resolution: str | None) -> dict: if sequence_plan.has_action: assert isinstance(action, torch.Tensor), "action tensor is required when sequence plan has action" - data_dict["raw_action_dim"] = torch.tensor(action.shape[1]) if self.action_channel_masking else None - data_dict["action"] = pad_action_to_max_dim(action, self.max_action_dim) + data_dict = self.action_processor.preprocess_action( + data_dict, + action, + action_normalizer=action_normalizer, + ) else: # Nullify action-related fields when action is not needed so the # collate function can simply stack all non-None actions. data_dict["raw_action_dim"] = None data_dict["action"] = None data_dict["domain_id"] = None + data_dict["action_processing_record"] = None return data_dict diff --git a/cosmos_framework/data/vfm/action/transforms_test.py b/cosmos_framework/data/vfm/action/transforms_test.py index 5f024fb..759ec21 100644 --- a/cosmos_framework/data/vfm/action/transforms_test.py +++ b/cosmos_framework/data/vfm/action/transforms_test.py @@ -9,7 +9,11 @@ import torch from cosmos_framework.data.vfm.action.json_formatter import ActionPromptJsonFormatter -from cosmos_framework.data.vfm.action.transforms import ActionTransformPipeline +from cosmos_framework.data.vfm.action.transforms import ( + ActionTransformPipeline, + reflection_pad_to_target, + remove_reflection_padding, +) from cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps import DurationFPSTextTimeStamps from cosmos_framework.data.vfm.augmentors.resolution_text_info import ResolutionTextInfo @@ -60,6 +64,24 @@ def test_action_prompt_json_formatter_builds_requested_structure() -> None: assert "additional_view_description" not in result +@pytest.mark.L0 +def test_video_padding_round_trips_to_unpadded_region() -> None: + video = torch.arange(3 * 2 * 4 * 5, dtype=torch.float32).reshape(3, 2, 4, 5) # [C,T,H,W] + data_dict = {"video": video} + + padded = reflection_pad_to_target( + data_dict, + keys=["video"], + keep_aspect_ratio=True, + target_w=8, + target_h=6, + ) + round_tripped = remove_reflection_padding(padded["video"], padded["image_size"]) # [C,T,H,W] + + assert padded["video"].shape == (3, 2, 6, 8) + torch.testing.assert_close(round_tripped, video) + + @pytest.mark.L0 def test_action_prompt_json_formatter_drops_empty_fields() -> None: formatter = ActionPromptJsonFormatter() diff --git a/cosmos_framework/data/vfm/augmentor_provider.py b/cosmos_framework/data/vfm/augmentor_provider.py index 2ace65c..3e3d785 100644 --- a/cosmos_framework/data/vfm/augmentor_provider.py +++ b/cosmos_framework/data/vfm/augmentor_provider.py @@ -564,6 +564,9 @@ def get_video_augmentor_v3( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) + uniae_pad_frames = kwargs.get("uniae_pad_frames", None) + uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) print("Running video_basic_augmentor_v3...") augmentors = { @@ -577,6 +580,10 @@ def get_video_augmentor_v3( "min_stride": min_stride, "seek_mode": "exact", # Change to "approximate"? "dataset_resolution_type": dataset_resolution_type, + "resolution": resolution, + "causal_vae": causal_vae, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( @@ -599,6 +606,9 @@ def get_video_augmentor_v3( "conditioning_config": conditioning_config, "uniform_conditioning": uniform_conditioning, "temporal_compression_factor": temporal_compression_factor, + "resolution": resolution, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ) augmentors.update( @@ -670,7 +680,6 @@ def get_video_augmentor_v3( return augmentors - # Use video_basic_augmentor_v3_json_caption instead. @augmentor_register("video_basic_augmentor_v3_with_audio") def get_video_augmentor_v3_with_audio( @@ -829,6 +838,9 @@ def get_video_augmentor_v3_json_caption( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) + uniae_pad_frames = kwargs.get("uniae_pad_frames", None) + uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) print("Running video_augmentor_v3_json_caption...") augmentors = { @@ -853,9 +865,13 @@ def get_video_augmentor_v3_json_caption( "min_stride": min_stride, "seek_mode": "exact", "dataset_resolution_type": dataset_resolution_type, + "resolution": resolution, "extract_audio": extract_audio, "audio_sample_rate": audio_sample_rate, "emit_placeholder_sound": not extract_audio, + "causal_vae": causal_vae, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( @@ -881,6 +897,9 @@ def get_video_augmentor_v3_json_caption( "conditioning_config": conditioning_config, "uniform_conditioning": uniform_conditioning, "temporal_compression_factor": temporal_compression_factor, + "resolution": resolution, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ) augmentors.update( diff --git a/cosmos_framework/data/vfm/augmentors/__init__.py b/cosmos_framework/data/vfm/augmentors/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/data/vfm/augmentors/__init__.py +++ b/cosmos_framework/data/vfm/augmentors/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py index 302715c..34f6116 100644 --- a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py +++ b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py @@ -8,7 +8,7 @@ frames (i.e. the relative-pose delta is close to identity and the gripper command does not change). The upstream dataset is responsible for populating ``data_dict[idle_frames_key]`` via -:func:`projects.cosmos3.vfm.datasets.action.pose_utils.compute_idle_frames`. +:func:`cosmos_framework.data.vfm.action.pose_utils.compute_idle_frames`. Per-field dropout (default 5%) is applied here, matching Pi0.7's approach of independently dropping each metadata component. This is complementary to the diff --git a/cosmos_framework/data/vfm/augmentors/image_editing_transform.py b/cosmos_framework/data/vfm/augmentors/image_editing_transform.py index fdaaa40..4af344b 100644 --- a/cosmos_framework/data/vfm/augmentors/image_editing_transform.py +++ b/cosmos_framework/data/vfm/augmentors/image_editing_transform.py @@ -18,6 +18,7 @@ from __future__ import annotations +import json import random import torch @@ -26,13 +27,12 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log -from cosmos_framework.data.vfm.sequence_packing import SequencePlan class ExtractImageEditingConversation(Augmentor): """Extract and validate image editing conversation from standard annotation format. - This augmentor processes the cosmos-interleaved conversation format for image editing: + This augmentor processes cosmos-interleaved conversation data for image editing: - Validates that the conversation has exactly one round (user + assistant) - User message must contain at least one image and text instruction - Assistant message must contain exactly one image (the edited result) @@ -42,6 +42,9 @@ class ExtractImageEditingConversation(Augmentor): - texts: Dict containing "content" with conversation data - mllm_media_list: Dict mapping image keys to PIL images (for understanding) - diffusion_media_list: Dict mapping image keys to PIL images (for diffusion/VAE) + - optional structured instruction key: Dict, JSON string, or JSON bytes containing + text_json.content and gemini_rewrite. When configured, gemini_rewrite is used as + the training prompt and text_json.content is used only to recover image references. Output Format (added to data_dict): - source_image: PIL.Image (the input image for editing) @@ -53,10 +56,124 @@ def __init__( self, input_keys: list | None = None, max_round: int = 1, + instruction_key: str = "texts", + conversation_key: str = "texts", + structured_instruction_field: str | None = None, args: dict | None = None, ) -> None: super().__init__(input_keys or [], None, args) - self.max_round = max_round + self.max_round: int = max_round + self.instruction_key: str = instruction_key + self.conversation_key: str = conversation_key + self.structured_instruction_field: str | None = structured_instruction_field + + def _decode_json_text(self, text: str, payload_name: str, sample_key: str) -> dict | None: + try: + payload = json.loads(text) + except json.JSONDecodeError as e: + log.warning( + f"Error decoding {payload_name} JSON: {sample_key}, {str(e)}", + rank0_only=False, + ) + return None + + if not isinstance(payload, dict): + log.warning( + f"Decoded {payload_name} is not a dict: {sample_key}, got {type(payload)}", + rank0_only=False, + ) + return None + return payload + + def _decode_payload(self, payload: object, payload_name: str, sample_key: str) -> dict | None: + if isinstance(payload, dict): + return payload + + if isinstance(payload, str): + return self._decode_json_text(payload, payload_name, sample_key) + + if isinstance(payload, (bytes, bytearray)): + try: + text = bytes(payload).decode("utf-8") + except UnicodeDecodeError as e: + log.warning( + f"Error decoding {payload_name} bytes as UTF-8: {sample_key}, {str(e)}", + rank0_only=False, + ) + return None + return self._decode_json_text(text, payload_name, sample_key) + + log.warning( + f"Unsupported {payload_name} payload type: {sample_key}, got {type(payload)}", + rank0_only=False, + ) + return None + + def _get_instruction_payload(self, data_dict: dict, sample_key: str) -> dict | None: + payload = data_dict.get(self.instruction_key) + if payload is None: + log.warning( + f"{self.instruction_key} not found in data_dict: {sample_key}", + rank0_only=False, + ) + return None + return self._decode_payload(payload, self.instruction_key, sample_key) + + def _get_conversation_payload( + self, + data_dict: dict, + instruction_payload: dict, + sample_key: str, + ) -> dict | None: + if self.conversation_key == self.instruction_key: + return instruction_payload + + if self.conversation_key in data_dict: + return self._decode_payload(data_dict[self.conversation_key], self.conversation_key, sample_key) + + nested_payload = instruction_payload.get(self.conversation_key) + if nested_payload is None: + log.warning( + f"{self.conversation_key} not found in {self.instruction_key}: {sample_key}", + rank0_only=False, + ) + return None + return self._decode_payload(nested_payload, f"{self.instruction_key}.{self.conversation_key}", sample_key) + + def _get_structured_instruction(self, instruction_payload: dict, sample_key: str) -> str | None: + if self.structured_instruction_field is None: + return None + + rewrite_error = instruction_payload.get("rewrite_error") + if rewrite_error is not None: + log.warning( + f"Structured instruction rewrite_error is non-null: {sample_key}, {rewrite_error}", + rank0_only=False, + ) + return None + + structured_payload = instruction_payload.get(self.structured_instruction_field) + if not isinstance(structured_payload, dict): + log.warning( + f"{self.structured_instruction_field} missing or not a dict: {sample_key}", + rank0_only=False, + ) + return None + + edit_type = structured_payload.get("edit_type") + structured_instruction = structured_payload.get("structured_instruction") + if not isinstance(edit_type, str) or not edit_type: + log.warning(f"Structured instruction edit_type missing: {sample_key}", rank0_only=False) + return None + if not isinstance(structured_instruction, dict) or not structured_instruction: + log.warning(f"Structured instruction body missing: {sample_key}", rank0_only=False) + return None + + prompt = { + "edit_type": edit_type, + "structured_instruction": structured_instruction, + } + return json.dumps(prompt, ensure_ascii=False) def __call__(self, data_dict: dict) -> dict | None: """Extract image editing conversation. @@ -69,23 +186,30 @@ def __call__(self, data_dict: dict) -> dict | None: or None if the data is invalid. """ # Validate required keys - for required_key in ["mllm_media_list", "diffusion_media_list", "texts"]: + sample_key = data_dict.get("__key__", "unknown") + for required_key in ["diffusion_media_list", self.instruction_key]: if required_key not in data_dict: log.warning( - f"{required_key} not found in data_dict: {data_dict.get('__key__', 'unknown')}", + f"{required_key} not found in data_dict: {sample_key}", rank0_only=False, ) return None - mllm_media_list = data_dict["mllm_media_list"] diffusion_media_list = data_dict["diffusion_media_list"] + instruction_payload = self._get_instruction_payload(data_dict, sample_key) + if instruction_payload is None: + return None + conversation_payload = self._get_conversation_payload(data_dict, instruction_payload, sample_key) + if conversation_payload is None: + return None + conversation_content_key = f"{self.conversation_key}.content" # Get conversation content try: - texts_content = data_dict["texts"].get("content") + texts_content = conversation_payload.get("content") if texts_content is None: log.warning( - f"texts.content is None: {data_dict.get('__key__', 'unknown')}", + f"{conversation_content_key} is None: {sample_key}", rank0_only=False, ) return None @@ -99,13 +223,13 @@ def __call__(self, data_dict: dict) -> dict | None: selected_conversations = texts_content else: log.warning( - f"Unexpected texts.content format: {data_dict.get('__key__', 'unknown')}", + f"Unexpected {conversation_content_key} format: {sample_key}", rank0_only=False, ) return None except Exception as e: log.warning( - f"Error accessing texts.content: {data_dict.get('__key__', 'unknown')}, {str(e)}", + f"Error accessing {conversation_content_key}: {sample_key}, {str(e)}", rank0_only=False, ) return None @@ -115,15 +239,14 @@ def __call__(self, data_dict: dict) -> dict | None: if len(selected_conversations) > 2: log.warning( f"Multi-round conversation found ({len(selected_conversations)} messages), " - f"keeping only first round: {data_dict.get('__key__', 'unknown')}", + f"keeping only first round: {sample_key}", rank0_only=False, ) selected_conversations = selected_conversations[:2] if len(selected_conversations) < 2: log.warning( - f"Expected at least 2 messages (user + assistant), got {len(selected_conversations)}: " - f"{data_dict.get('__key__', 'unknown')}", + f"Expected at least 2 messages (user + assistant), got {len(selected_conversations)}: {sample_key}", rank0_only=False, ) return None @@ -134,14 +257,14 @@ def __call__(self, data_dict: dict) -> dict | None: if user_msg.get("role") != "user": log.warning( - f"First message role is not 'user': {data_dict.get('__key__', 'unknown')}", + f"First message role is not 'user': {sample_key}", rank0_only=False, ) return None if assistant_msg.get("role") != "assistant": log.warning( - f"Second message role is not 'assistant': {data_dict.get('__key__', 'unknown')}", + f"Second message role is not 'assistant': {sample_key}", rank0_only=False, ) return None @@ -167,24 +290,29 @@ def __call__(self, data_dict: dict) -> dict | None: if user_image_key is None: log.warning( - f"No image found in user message: {data_dict.get('__key__', 'unknown')}", + f"No image found in user message: {sample_key}", rank0_only=False, ) return None - editing_instruction = " ".join(user_text_parts).strip() - if not editing_instruction: - log.warning( - f"No text instruction found in user message: {data_dict.get('__key__', 'unknown')}", - rank0_only=False, - ) - return None + if self.structured_instruction_field is None: + editing_instruction = " ".join(user_text_parts).strip() + if not editing_instruction: + log.warning( + f"No text instruction found in user message: {sample_key}", + rank0_only=False, + ) + return None + else: + editing_instruction = self._get_structured_instruction(instruction_payload, sample_key) + if editing_instruction is None: + return None # Extract assistant content: must have exactly one image assistant_content = assistant_msg.get("content", []) if isinstance(assistant_content, str): log.warning( - f"Assistant content is text-only (no image): {data_dict.get('__key__', 'unknown')}", + f"Assistant content is text-only (no image): {sample_key}", rank0_only=False, ) return None @@ -199,7 +327,7 @@ def __call__(self, data_dict: dict) -> dict | None: if assistant_image_key is None: log.warning( - f"No image found in assistant message: {data_dict.get('__key__', 'unknown')}", + f"No image found in assistant message: {sample_key}", rank0_only=False, ) return None @@ -208,7 +336,7 @@ def __call__(self, data_dict: dict) -> dict | None: for media_key in [user_image_key, assistant_image_key]: if media_key not in diffusion_media_list: log.warning( - f"Image {media_key} not found in diffusion_media_list: {data_dict.get('__key__', 'unknown')}", + f"Image {media_key} not found in diffusion_media_list: {sample_key}", rank0_only=False, ) return None @@ -225,7 +353,7 @@ def __call__(self, data_dict: dict) -> dict | None: if source_image is None or target_image is None: log.warning( - f"Source or target image is None: {data_dict.get('__key__', 'unknown')}", + f"Source or target image is None: {sample_key}", rank0_only=False, ) return None @@ -329,6 +457,8 @@ def __call__(self, data_dict: dict) -> dict | None: # by GenerationDataClean.num_vision_items_per_sample (set in get_data_and_condition). # In pack_input_sequence, all items except the last are fully conditioned; # the last item uses condition_frame_indexes_vision ([] = fully generated). + from cosmos_framework.data.vfm.sequence_packing import SequencePlan + data_dict["sequence_plan"] = SequencePlan( has_text=True, has_vision=True, diff --git a/cosmos_framework/data/vfm/augmentors/image_editing_transform_test.py b/cosmos_framework/data/vfm/augmentors/image_editing_transform_test.py new file mode 100644 index 0000000..849ad00 --- /dev/null +++ b/cosmos_framework/data/vfm/augmentors/image_editing_transform_test.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import json + +import pytest +from PIL import Image + +from cosmos_framework.data.vfm.augmentors.image_editing_transform import ExtractImageEditingConversation + +_STRUCTURED_KEY = "edit_schema_all_inputs_qwen3-vl-235b-a22b-instruct" + + +def _conversation(instruction: str = "Make the cup red") -> list[list[dict]]: + return [ + [ + { + "role": "user", + "content": [ + {"type": "image", "image": "image_0"}, + {"type": "text", "text": instruction}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "image", "image": "image_1"}, + ], + }, + ] + ] + + +def _media_data() -> tuple[Image.Image, Image.Image, dict[str, Image.Image]]: + source_image = Image.new("RGB", (16, 16), color="blue") + target_image = Image.new("RGB", (16, 16), color="red") + media_list = { + "image_0": source_image, + "image_1": target_image, + } + return source_image, target_image, media_list + + +def _base_data_dict() -> tuple[dict, Image.Image, Image.Image]: + source_image, target_image, media_list = _media_data() + data_dict = { + "__key__": "sample_000001", + "mllm_media_list": media_list, + "diffusion_media_list": media_list, + } + return data_dict, source_image, target_image + + +def _structured_payload() -> dict: + return { + "rewrite_error": None, + "gemini_rewrite": { + "edit_type": "adjust", + "structured_instruction": { + "target_object": "cup", + "attribute_type": "color", + "desired_value": "red", + }, + }, + "text_json": { + "content": _conversation("Original dense instruction"), + }, + "original_instruction": "Original dense instruction", + } + + +@pytest.mark.L0 +@pytest.mark.CPU +def test_extract_image_editing_conversation_keeps_texts_behavior() -> None: + data_dict, source_image, target_image = _base_data_dict() + data_dict["texts"] = {"content": _conversation("Make the cup red")} + + result = ExtractImageEditingConversation()(data_dict) + + assert result is not None + assert result["source_image"] is source_image + assert result["target_image"] is target_image + assert result["editing_instruction"] == "Make the cup red" + + +@pytest.mark.L0 +@pytest.mark.CPU +def test_extract_structured_dict_payload_uses_gemini_rewrite() -> None: + data_dict, source_image, target_image = _base_data_dict() + payload = _structured_payload() + data_dict[_STRUCTURED_KEY] = payload + + result = ExtractImageEditingConversation( + instruction_key=_STRUCTURED_KEY, + conversation_key="text_json", + structured_instruction_field="gemini_rewrite", + )(data_dict) + + expected_instruction = json.dumps( + { + "edit_type": payload["gemini_rewrite"]["edit_type"], + "structured_instruction": payload["gemini_rewrite"]["structured_instruction"], + }, + ensure_ascii=False, + ) + assert result is not None + assert result["source_image"] is source_image + assert result["target_image"] is target_image + assert result["editing_instruction"] == expected_instruction + + +@pytest.mark.L0 +@pytest.mark.CPU +@pytest.mark.parametrize("encode_as_bytes", [False, True]) +def test_extract_structured_json_payload_uses_gemini_rewrite(encode_as_bytes: bool) -> None: + data_dict, _, _ = _base_data_dict() + payload = _structured_payload() + payload_json = json.dumps(payload, ensure_ascii=False) + data_dict[_STRUCTURED_KEY] = payload_json.encode("utf-8") if encode_as_bytes else payload_json + + result = ExtractImageEditingConversation( + instruction_key=_STRUCTURED_KEY, + conversation_key="text_json", + structured_instruction_field="gemini_rewrite", + )(data_dict) + + expected_instruction = json.dumps( + { + "edit_type": payload["gemini_rewrite"]["edit_type"], + "structured_instruction": payload["gemini_rewrite"]["structured_instruction"], + }, + ensure_ascii=False, + ) + assert result is not None + assert result["editing_instruction"] == expected_instruction + + +@pytest.mark.L0 +@pytest.mark.CPU +@pytest.mark.parametrize( + "payload_update", + [ + {"gemini_rewrite": None}, + {"rewrite_error": "failed to rewrite"}, + ], +) +def test_extract_structured_invalid_payload_returns_none(payload_update: dict) -> None: + data_dict, _, _ = _base_data_dict() + payload = _structured_payload() + payload.update(payload_update) + data_dict[_STRUCTURED_KEY] = payload + + result = ExtractImageEditingConversation( + instruction_key=_STRUCTURED_KEY, + conversation_key="text_json", + structured_instruction_field="gemini_rewrite", + )(data_dict) + + assert result is None diff --git a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py index aea759e..41e4e4c 100644 --- a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py @@ -414,7 +414,7 @@ def __call__(self, data_dict: dict) -> dict | None: ) # [C,T,H,W] num_multiplier = (end_frame - start_frame) / self.num_frames - + # NOTE: matches legacy VideoParsing.__call__ output keys exactly. Do NOT add # variable-length fields like ``frame_indices`` here -- ``video_flatten_keys`` in # ``get_video_transfer_augmentor`` lists ``frame_indices``, and surfacing a # per-sample list there would crash ``custom_collate_fn`` (default_collate requires diff --git a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py index 54d47b8..aa9eb21 100644 --- a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py +++ b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py @@ -27,14 +27,12 @@ def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert token length to pixels based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return token_length * merged_patch_size**2 * temporal_patch_size def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert pixels to token length based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return pixels // merged_patch_size**2 // temporal_patch_size diff --git a/cosmos_framework/data/vfm/augmentors/sequence_plan.py b/cosmos_framework/data/vfm/augmentors/sequence_plan.py index 9f0500a..5a2202f 100644 --- a/cosmos_framework/data/vfm/augmentors/sequence_plan.py +++ b/cosmos_framework/data/vfm/augmentors/sequence_plan.py @@ -7,15 +7,22 @@ - weighted dict (``conditioning_config``): explicit frame-count → probability pairs - uniform (``uniform_conditioning=True``): k ~ Uniform{0, T_latent-1}, where T_latent is computed from the actual video length using the VAE temporal compression factor + or UniAE chunking parameters when provided. """ import random +from collections.abc import Mapping from typing import Optional import torch from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.data.vfm.sequence_packing import SequencePlan +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + get_uniae_chunk_frames, + get_uniae_latent_num_frames, + normalize_uniae_chunk_frames, +) class SequencePlanAugmentor(Augmentor): @@ -37,6 +44,11 @@ class SequencePlanAugmentor(Augmentor): must be provided. - "temporal_compression_factor" (int, default 4): VAE temporal compression factor used to convert pixel frame count N to T_latent = 1 + (N-1) // tcf. + - "uniae_chunk_frames" / "uniae_pad_frames" (optional): When provided, + use UniAE's non-causal first-frame plus padded-chunk latent count. + ``uniae_chunk_frames`` may be a scalar or a resolution-keyed mapping. + - "resolution" (str, optional): Target dataset resolution key. Preferred over + the current tensor shape when selecting a resolution-keyed UniAE chunk. """ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: @@ -48,6 +60,9 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.conditioning_config = args.get("conditioning_config") self.uniform_conditioning = args.get("uniform_conditioning", False) self.temporal_compression_factor = args.get("temporal_compression_factor", 4) + self.target_resolution_key = None if args.get("resolution") is None else str(args["resolution"]) + self.uniae_pad_frames = None if args.get("uniae_pad_frames") is None else int(args["uniae_pad_frames"]) + self.uniae_chunk_frames = self._normalize_uniae_chunk_frames(args.get("uniae_chunk_frames")) if self.conditioning_config is None and not self.uniform_conditioning: raise ValueError("args must provide 'conditioning_config' or set 'uniform_conditioning=True'") @@ -70,6 +85,43 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O else: self.normalized_config = {0: 1.0} + def _normalize_uniae_chunk_frames( + self, uniae_chunk_frames: int | Mapping[str, int] | None + ) -> int | dict[str, int] | None: + return normalize_uniae_chunk_frames( + uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.temporal_compression_factor, + ) + + def _get_uniae_chunk_frames(self, spatial_shape: tuple[int, int] | None = None) -> int: + assert self.uniae_chunk_frames is not None + return get_uniae_chunk_frames( + self.uniae_chunk_frames, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + + def _get_latent_frame_count(self, num_frames: int | None, spatial_shape: tuple[int, int] | None = None) -> int: + if num_frames is None: + return 1 + if num_frames < 1: + raise ValueError(f"video must contain at least one frame, got {num_frames}") + if num_frames == 1: + return 1 + if self.uniae_chunk_frames is None: + return 1 + (num_frames - 1) // self.temporal_compression_factor + + assert self.uniae_pad_frames is not None + return get_uniae_latent_num_frames( + num_frames, + self.uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.temporal_compression_factor, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + def __call__(self, data_dict: dict) -> dict: """Create a SequencePlan with random conditional frames. @@ -94,15 +146,17 @@ def __call__(self, data_dict: dict) -> dict: # Determine number of frames # Video should be a tensor with shape (C, T, H, W) by this point in the pipeline + spatial_shape = None if isinstance(video, torch.Tensor): assert video.ndim == 4, "video should be a tensor with shape (C, T, H, W)" - num_frames = video.shape[1] + num_frames = video.shape[1] # video: [C,T,H,W] + spatial_shape = (video.shape[2], video.shape[3]) else: # If video is not a tensor or dict, we can't determine the exact number # Use a conservative approach - will be limited by max available frames num_frames = None - T_latent = 1 + (num_frames - 1) // self.temporal_compression_factor if num_frames is not None else 1 + T_latent = self._get_latent_frame_count(num_frames, spatial_shape) # Sample number of conditional frames if self.uniform_conditioning: diff --git a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py index edaef11..d38fae4 100644 --- a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py +++ b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py @@ -8,13 +8,17 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.v3_text_transforms import pad_and_resize from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.data_registration import _CAPTION_EMBEDDING_KEY_MAPPING_IMAGES # For the qwen captions, we have 3 variants: short, medium, long # In addition, for synthetic data, we create prompt embeddings as well. # There is quite a bit of entropy in the way prompt data is saved. # Captions are saved as "prompts", while the corresponding embeddings are saved as "original_prompt" # This part will be cleaned after synthetic data is cleaned to be in the same format as real data. +_CAPTION_EMBEDDING_KEY_MAPPING_IMAGES = { + "ai_v3p1": "ai_v3p1", + "qwen2p5_7b_v4": "qwen2p5_7b_v4", + "prompts": "qwen2p5_7b_v4", +} _AVAILABLE_QWEN_CAPTIONS = ["qwen2p5_7b_short", "qwen2p5_7b_medium", "qwen2p5_7b_long"] _AVAILABLE_QWEN3_30B_A3B_CAPTIONS = [ "qwen3_30b_a3b_short", diff --git a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py index 74fb523..a18e8fd 100644 --- a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py +++ b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py @@ -5,7 +5,7 @@ Augmentors for transfer (control-conditioned) image and video generation in the cosmos3 VFM pipeline. Transfer training conditions the model on control signals (edge, blur, depth, or segmentation) -to generate images or videos, aligned with cosmos_framework/transfer2. This module provides: +to generate images or videos, aligned with cosmos/transfer2. This module provides: - **TransferToTrainingFormat**: Converts (control_input, target) into the joint dataloader format with SequencePlan (condition frame + generated frame), for both image and video outputs. diff --git a/cosmos_framework/data/vfm/augmentors/video_parsing.py b/cosmos_framework/data/vfm/augmentors/video_parsing.py index cfaa934..25a5580 100644 --- a/cosmos_framework/data/vfm/augmentors/video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/video_parsing.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: OpenMDW-1.1 import random +from collections.abc import Mapping from typing import Optional import numpy as np @@ -15,12 +16,18 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size from cosmos_framework.utils import log from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + align_uniae_num_video_frames, + get_uniae_chunk_frames, + normalize_uniae_chunk_frames, +) # Map dataset_resolution_type to resolution tier key in VIDEO_RES_SIZE_INFO _DATASET_RESOLUTION_TIER: dict[str, str] = {"gt480p": "480", "gt720p": "720", "gt1080p": "1080"} _MIN_FPS = 10 _MAX_FPS = 60 +_UNIAE_TEMPORAL_COMPRESSION_FACTOR = 4 class VideoParsing(Augmentor): @@ -345,7 +352,7 @@ def __call__(self, data_dict: dict) -> dict | None: video_info["video"] = video_frames video_info["num_multiplier"] = num_multiplier # Store the frame skipping multiplier - + # NOTE: Explaining the logic of conditioning FPS calculation: # 1. Our video parser stores the original video FPS of the video. # 2. We have multiple modes of frame selection -- consecutive chunk of frames or subsampled frames. # Here's what we do in each case: @@ -434,6 +441,45 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.dataset_resolution_type = args.get("dataset_resolution_type", "all") self.resolution_tier = _DATASET_RESOLUTION_TIER.get(self.dataset_resolution_type) + # VAE temporal alignment mode. + # causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: align to 4N (non-causal VAE, e.g. UniAE) + self.causal_vae = args.get("causal_vae", True) + self.target_resolution_key = None if args.get("resolution") is None else str(args["resolution"]) + self.uniae_pad_frames = None if args.get("uniae_pad_frames") is None else int(args["uniae_pad_frames"]) + self.uniae_chunk_frames = self._normalize_uniae_chunk_frames(args.get("uniae_chunk_frames", None)) + + def _normalize_uniae_chunk_frames( + self, uniae_chunk_frames: int | Mapping[str, int] | None + ) -> int | dict[str, int] | None: + return normalize_uniae_chunk_frames( + uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=_UNIAE_TEMPORAL_COMPRESSION_FACTOR, + missing_pad_message="uniae_pad_frames must be specified if uniae_chunk_frames is specified", + temporal_divisibility_name="UniAE temporal compression factor", + ) + + def _get_uniae_chunk_frames(self, spatial_shape: tuple[int, int] | None = None) -> int: + assert self.uniae_chunk_frames is not None + return get_uniae_chunk_frames( + self.uniae_chunk_frames, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + + def _align_uniae_num_video_frames(self, num_video_frames: int, spatial_shape: tuple[int, int] | None = None) -> int: + assert self.uniae_pad_frames is not None + assert self.uniae_chunk_frames is not None + return align_uniae_num_video_frames( + num_video_frames, + self.uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=_UNIAE_TEMPORAL_COMPRESSION_FACTOR, + spatial_shape=spatial_shape, + target_resolution_key=self.target_resolution_key, + ) + def _sample_stride_with_bias(self, max_stride: int, min_stride: int = 1) -> int: """Sample a stride from [min_stride, max_stride] with bias controlled by low_fps_bias. @@ -520,7 +566,6 @@ def _validate_and_probe(self, video: Optional[bytes], meta_dict: dict, data_dict return True def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -553,8 +598,10 @@ def __call__(self, data_dict: dict) -> dict | None: f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" ) transform = [Resize(target_size)] + output_spatial_shape = target_size else: transform = None + output_spatial_shape = (meta_dict["height"], meta_dict["width"]) # Adding try-expcept because some of the data is bad and video decoding call fail. try: @@ -569,11 +616,33 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(0, num_video_frames, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to the active VAE temporal contract. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2). + # causal_vae=False: UniAE chunk/pad alignment if configured; otherwise 4N. num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + # If this is UniAE, we need to align the frame count to the chunk size and padding. + if self.uniae_chunk_frames is not None: + # T is valid when r = (T-1) % effective_chunk_frames satisfies: + # r == 0 (exact multiple of chunks) + # OR r % 4 == target_r where target_r = (-2*pad_frames) % 4 + # Compute minimum trim delta in O(1): + # delta = steps to nearest r' <= r satisfying the condition. + num_video_frames = self._align_uniae_num_video_frames(num_video_frames, output_spatial_shape) + + if num_video_frames == 0: + log.warning( + f"VideoParsingWithFullFrames: video too short for UniAE. " + f"url: {data_dict['__url__']}, key: {data_dict['__key__']}", + rank0_only=False, + ) + return None + else: + N = num_video_frames // 4 + num_video_frames = 4 * N frame_indices = frame_indices[0:num_video_frames] frame_batch = video_decoder.get_frames_at(frame_indices) @@ -698,7 +767,6 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O super().__init__(input_keys, output_keys, args) def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -743,8 +811,10 @@ def __call__(self, data_dict: dict) -> dict | None: f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" ) transform = [Resize(target_size)] + output_spatial_shape = target_size else: transform = None + output_spatial_shape = (meta_dict["height"], meta_dict["width"]) # Adding try-expcept because some of the data is bad and video decoding call fail. try: @@ -772,11 +842,19 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(chunk_start_clamped, chunk_end_clamped, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to the active VAE temporal contract. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2). + # causal_vae=False: UniAE chunk/pad alignment if configured; otherwise 4N. num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + if self.uniae_chunk_frames is not None: + num_video_frames = self._align_uniae_num_video_frames(num_video_frames, output_spatial_shape) + else: + N = num_video_frames // 4 + num_video_frames = 4 * N if num_video_frames < 1: log.warning( f"VideoParsingChunkedFrames: chunk too short for stride. " diff --git a/cosmos_framework/data/vfm/augmentors/vlm/__init__.py b/cosmos_framework/data/vfm/augmentors/vlm/__init__.py index 503ec1b..28a81be 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/__init__.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/__init__.py @@ -1,3 +1,2 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 - diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py deleted file mode 100644 index eb029eb..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""Visual-Text Transformations or Augmentations.""" - -import io -from typing import Dict, Optional - -from PIL import Image - -from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.augmentors.vlm.nvlm_sample_loaders_and_part_filters import ( - get_data_class, - get_part_filter, - get_sample_loader, -) - - -class NVLMImageDataUnify(Augmentor): - """ - This augmentor is used to unify the data format of the nvlm data. - It will take the raw nvlm data tar and convert it to a dictionary with the following keys: - { - "__url__": str, - "__key__": str, - "data_class": str, - "images": List[PIL.Image.Image], - "text": str, - "words_boxes": Optional[List[List[int]]], - "words_text": Optional[List[str]], - "similarity_matrix": Optional[List[List[float]]], - } - """ - - def __init__( - self, - input_keys: list = ["raw_nvlm"], - output_keys: Optional[list] = [], - args: Optional[dict] = None, - data_path_prefix: list[str] = [ - "cosmos_framework/ar/v2/nvlm/", - ], # prefix of the data in s3 - ) -> None: - super().__init__(input_keys, output_keys, args) - self.data_path_prefix = data_path_prefix - - def convert_image(self, img): - try: - if isinstance(img, bytes): - img = Image.open(io.BytesIO(img)).convert("RGB") - elif isinstance(img, Image.Image): - img = img.convert("RGB") - pass # Image is already in PIL format - elif isinstance(img, list): - for i in range(len(img)): - img[i], success = self.convert_image(img[i]) - if not success: - return Image.new("RGB", (256, 256), (0, 0, 0)), False - return img, True - else: - raise ValueError(f"Invalid image type: {type(img)}") - - success = True - except Exception as e: - log.warning(f"Error processing image: {e}. Creating an empty black image.", rank0_only=False) - img = Image.new("RGB", (256, 256), (0, 0, 0)) # Creates a 256x256 black image - success = False - return img, success - - def __call__(self, data_dict: Dict) -> Dict: - url = data_dict["__url__"] - data_path = "/".join(url.path.split("/")[:-1]) # remove the last part of the path - sample_loader = get_sample_loader(data_path) - part_filter = get_part_filter(data_path) - data_class = get_data_class(data_path) - assert sample_loader is not None and part_filter is not None and data_class is not None, ( - f"sample_loader({sample_loader}) or part_filter({part_filter}) or data_class({data_class}) is not found for {data_path}" - ) - - raw = {"__url__": url, "__key__": data_dict["__key__"]} - output = {"__url__": url, "__key__": data_dict["__key__"]} - for k, v in data_dict.items(): - ext = k.split(".")[-1] - if part_filter(ext): - raw[ext] = v - try: - output_converted = sample_loader(raw) - # Here output_converted will be a dictionary with the following keys: - # { - # "__key__": str, - # "image": PIL.Image.Image, - # "images": List[PIL.Image.Image], - # "text": str, - # "words_boxes": Optional - # "words_text": Optional - # "similarity_matrix": Optional - # } - except Exception as e: - log.warning( - f"Error in sample_loader: {e}, sample_loader: {sample_loader}, data_path: {data_path}, raw: {raw.keys()}, original_data_dict: {data_dict.keys()}, __url__: {url}, __key__: {data_dict['__key__']}" - ) - return None - - output.update(output_converted) - if "image" not in output_converted and "images" not in output_converted: - success = False - log.warning(f"image not found in {output_converted.keys()}") - if "image" in output_converted: # Single image case - img, success = self.convert_image(output["image"]) - output["images"] = [img] # What should be the format for the iamges - elif "images" in output_converted: - output["images"] = output_converted["images"] - output["images"], success = self.convert_image(output["images"]) - if not success: - log.warning(f"image conversion failed for {data_dict['__key__']} url: {url} | Skip this data") - return None - output["data_class"] = data_class - - return output diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py deleted file mode 100644 index fabe0c3..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py +++ /dev/null @@ -1,2815 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -# Combined Sample Loaders -# Auto-generated script combining all sample_loader.py files (Dont edit this file! Edit the projects/cosmos/reasoning/v1/scripts/create_sample_loader_and_part_filter_file.py instead) - -import io - -import torch -from PIL import Image - -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.vlm.nvlm import data_path_mapping - -# This file was automatically generated by `nvgpt4 data prepare`. - -# import torch - - -def sample_loader_0(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_0(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_1(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_1(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_2(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_2(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_3(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_3(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_4(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_4(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_5(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_5(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_6(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - key = raw["__key__"] - if "docvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = json_item["answer_weights"] - elif "textvqa" in key or "lrv_instruct" in key: - context = json_item["question"] - answers = json_item["answer"] - image = raw["jpg"] - answer_weights = None - elif "stvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = [1.0] * len(json_item["answers"]) - elif "chartqa" in key: - context = json_item["query"] - answers = json_item["label"] - image = raw["png"] - answer_weights = None - elif "screenqa" in key: - image = raw["jpg"] - context = json_item["question"] - answers = json_item["ground_truth"] - answer_weights = [1.0] * len(json_item["ground_truth"]) - elif "HME100K" in key: - image = raw["jpg"] - context = "Please write out the expression of the formula in the image using LaTeX format." - answers = json_item["latex_formula"] - answer_weights = None - else: # scale, textbook - image = raw["jpg"] - context = json_item["question"] - answers = json_item["answer"] - answer_weights = None - - return dict( - __key__=key, - image=image, - context=context, - answers=answers, - answer_weights=answer_weights, - ) - - -def part_filter_6(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "png") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_7(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question_string"], # expected type: str - answers=j["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_7(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_8(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=str(j["answer"]), # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_8(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_9(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"].strip(), # expected type: str - answers=j["gt_answer"].strip(), # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_9(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_10(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_10(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_11(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=j["question"], - answers=j["answer"], - answer_weights=None, - ) - - -def part_filter_11(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_12(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_12(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_13(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - key = raw["__key__"] - - if "geoqa_plus" in key or "tqa" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=json_item["correct_answer_index"], - ) - elif "geometry3k" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=ord(json_item["answer"].lower()) - 97, - ) - else: # science_qa, ai2d - image_key = "png" if "png" in raw else "jpg" - if image_key not in raw: - log.warning(f"Image key {image_key} not found in with raw keys: {raw.keys()}") - return dict( - __key__=raw["__key__"], # science_qa_sample_{idx} - image=raw[image_key], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["choices"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def part_filter_13(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_14(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - return dict( - __key__=raw["__key__"], # arxiv_qa_sample_{idx} - image=raw["jpg"], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["options"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def part_filter_14(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_15(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - if json_item["question_type"] == "multi_choice": - correct_choice_idx = json_item["choices"].index(json_item["answer"]) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=correct_choice_idx, - ) - else: - # A temporary hack for non multi-choice samples. - # If correct_choice_idx=-1, we should route it to the VQAWebdataset dataloading method. - # (74.7% free-text questions, 25.3% multi-choice questions) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=[json_item["answer"]], - correct_choice_idx=-1, - ) - - -def part_filter_15(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_16(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_16(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_17(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_17(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_18(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_18(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_19(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_19(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_20(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_20(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_21(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["png"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_21(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_22(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_22(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_23(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_23(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_24(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_24(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_25(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_25(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_26(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_26(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_27(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_27(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_28(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_28(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_29(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_29(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_30(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_30(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_31(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_31(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_32(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_32(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_33(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_33(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_34(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_34(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_35(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_35(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_36(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_36(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_37(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_37(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_38(raw: dict) -> dict: - j = raw["json"] - - if "ReCTs" in raw["__key__"]: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - else: # coco-text-multi, textocr-multi - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_38(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_39(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_39(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_40(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_40(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_41(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_41(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_42(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_42(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_43(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_43(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_44(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_44(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_45(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_45(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_46(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_46(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_47(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_47(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_48(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_48(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_49(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_49(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_50(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_50(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_51(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - for i, turn in enumerate(json_item["conversations"]): - if i > 0 and turn["from"] == "human" and "" in turn["value"]: - turn["value"] = turn["value"].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_51(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_52(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_52(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_53(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - # for i, turn in enumerate(json_item['conversations']): - # if i > 0 and turn['from'] == 'human' and '' in turn['value']: - # turn['value'] = turn['value'].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_53(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_54(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_54(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_55(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_55(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_56(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_56(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_57(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_57(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_58(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_58(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_59(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_59(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_60(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_60(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_61(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_61(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_62(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_62(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_63(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_63(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_64(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_64(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_65(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_65(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_66(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_66(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_67(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_67(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_68(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_68(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_69(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_69(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_70(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_70(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_71(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_71(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_72(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_72(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_73(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_73(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_74(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_74(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_75(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_75(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_76(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_76(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_77(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_77(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_78(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_78(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - - - -def sample_loader_79(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_79(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_80(raw: dict) -> dict: # Note: Images are already decoded to tensors - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=raw["json"]["question"], # expected type: str - answers=raw["json"]["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_80(part: str) -> bool: - - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_81(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_81(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_82(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_82(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_83(raw: dict) -> dict: - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_83(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_84(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_84(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_85(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict(__key__=raw["__key__"], image=raw["jpg"], text=j["text"], words_boxes=j["quad_1k_normalized"]) - - -def part_filter_85(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_86(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_86(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_87(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - quad = j["quad"] - quad = [val for point in quad for val in point] - - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=j["text"], # expected type: str - words_boxes=quad, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_87(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_88(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_88(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_89(raw: dict) -> dict: - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_89(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_90(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_90(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - - - -def sample_loader_91(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_91(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# Dataset -> Sample Loader Mapping -dataset_loader_mapping = { - "coco_train_val_restval": { - "sample_loader": "sample_loader_0", - "part_filter": "part_filter_0", - "data_class": "CaptioningWebdataset", - "data_weight": 0.01, - }, - "extended-sci/data/merged/CoT": { - "sample_loader": "sample_loader_1", - "part_filter": "part_filter_1", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "extended-sci/data/merged/single-choice": { - "sample_loader": "sample_loader_2", - "part_filter": "part_filter_2", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "extended-sci/data/extended-sci-3/CoT": { - "sample_loader": "sample_loader_3", - "part_filter": "part_filter_3", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - "extended-sci/data/extended-sci-3/single-choice": { - "sample_loader": "sample_loader_4", - "part_filter": "part_filter_4", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/SceMQA_processed": { - "sample_loader": "sample_loader_5", - "part_filter": "part_filter_5", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - "nvlm/wdai/data/vqa_collection_doc_text_st_chart_scale_textbook_LRV_Screen": { - "sample_loader": "sample_loader_6", - "part_filter": "part_filter_6", - "data_class": "VQAWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/plotqa/processed": { - "sample_loader": "sample_loader_7", - "part_filter": "part_filter_7", - "data_class": "VQAWebdataset", - "data_weight": 0.095, - }, - "nvlm/wdai/data/clevr-math/processed": { - "sample_loader": "sample_loader_8", - "part_filter": "part_filter_8", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/MMC-Instruction/processed": { - "sample_loader": "sample_loader_9", - "part_filter": "part_filter_9", - "data_class": "VQAWebdataset", - "data_weight": 0.07, - }, - "nvlm/wdai/data/ocrvqa/processed": { - "sample_loader": "sample_loader_10", - "part_filter": "part_filter_10", - "data_class": "VQAWebdataset", - "data_weight": 0.06, - }, - "nvlm/wdai/data/dude/processed": { - "sample_loader": "sample_loader_11", - "part_filter": "part_filter_11", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/VisualMRC/processed": { - "sample_loader": "sample_loader_12", - "part_filter": "part_filter_12", - "data_class": "VQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/mcvqa_collection_scienceqa_ai2d_geoqaplus_geometry3k_tqa": { - "sample_loader": "sample_loader_13", - "part_filter": "part_filter_13", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.025, - }, - "nvlm/wdai/data/arxiv_qa/processed": { - "sample_loader": "sample_loader_14", - "part_filter": "part_filter_14", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/tabmwp/processed": { - "sample_loader": "sample_loader_15", - "part_filter": "part_filter_15", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/ocr_vqa_aug/processed": { - "sample_loader": "sample_loader_16", - "part_filter": "part_filter_16", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/dvqa_full/processed": { - "sample_loader": "sample_loader_17", - "part_filter": "part_filter_17", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/LLaVA-v1.5_shuffle/no_refcoco_vg_ocrvqa": { - "sample_loader": "sample_loader_18", - "part_filter": "part_filter_18", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.085, - }, - "vqa/more_data/infographics_vqa/processed/train": { - "sample_loader": "sample_loader_19", - "part_filter": "part_filter_19", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/sharegpt4o/processed": { - "sample_loader": "sample_loader_20", - "part_filter": "part_filter_20", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/sparse_ocr_data/merged": { - "sample_loader": "sample_loader_21", - "part_filter": "part_filter_21", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.045, - }, - "nvlm/nayeonl/data/blendv4/MetaMathQA/processed/train_text_image": { - "sample_loader": "sample_loader_22", - "part_filter": "part_filter_22", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/nayeonl/data/blendv4/gsm8k/processed/train_text_image": { - "sample_loader": "sample_loader_23", - "part_filter": "part_filter_23", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/docmatix/processed": { - "sample_loader": "sample_loader_24", - "part_filter": "part_filter_24", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.1, - }, - "nvlm/wdai/data/bentham_hw_squad/processed": { - "sample_loader": "sample_loader_25", - "part_filter": "part_filter_25", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/WikiTableQA/processed": { - "sample_loader": "sample_loader_26", - "part_filter": "part_filter_26", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/figureqa/processed": { - "sample_loader": "sample_loader_27", - "part_filter": "part_filter_27", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ai2d_combined_processed": { - "sample_loader": "sample_loader_28", - "part_filter": "part_filter_28", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/math_combined_processed": { - "sample_loader": "sample_loader_29", - "part_filter": "part_filter_29", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.035, - }, - "nvlm/wdai/data/llava-onevision/robut_combined_processed": { - "sample_loader": "sample_loader_30", - "part_filter": "part_filter_30", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/llavar_20k_processed": { - "sample_loader": "sample_loader_31", - "part_filter": "part_filter_31", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/tallyqa_processed": { - "sample_loader": "sample_loader_32", - "part_filter": "part_filter_32", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/llava-onevision/ureader_ie_processed": { - "sample_loader": "sample_loader_33", - "part_filter": "part_filter_33", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/visual7w_processed": { - "sample_loader": "sample_loader_34", - "part_filter": "part_filter_34", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/llava-onevision/mavis_math_rule_geo_processed": { - "sample_loader": "sample_loader_35", - "part_filter": "part_filter_35", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ureader_kg_processed": { - "sample_loader": "sample_loader_36", - "part_filter": "part_filter_36", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/ureader_qa_processed": { - "sample_loader": "sample_loader_37", - "part_filter": "part_filter_37", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/ocr_multi_collection_cocotext_textocr_ReCTs": { - "sample_loader": "sample_loader_38", - "part_filter": "part_filter_38", - "data_class": "OCRWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_500": { - "sample_loader": "sample_loader_39", - "part_filter": "part_filter_39", - "data_class": "OCRWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/llava-onevision/super_clevr_processed": { - "sample_loader": "sample_loader_40", - "part_filter": "part_filter_40", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/icon_qa_processed": { - "sample_loader": "sample_loader_41", - "part_filter": "part_filter_41", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.009, - }, - "nvlm/wdai/data/augmentations/chartqa_aug": { - "sample_loader": "sample_loader_42", - "part_filter": "part_filter_42", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_chartqa": { - "sample_loader": "sample_loader_43", - "part_filter": "part_filter_43", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/gpt_docvqa": { - "sample_loader": "sample_loader_44", - "part_filter": "part_filter_44", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/docvqa_text": { - "sample_loader": "sample_loader_45", - "part_filter": "part_filter_45", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/textvqa_text": { - "sample_loader": "sample_loader_46", - "part_filter": "part_filter_46", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/augmentations/i2s-musicsheet": { - "sample_loader": "sample_loader_47", - "part_filter": "part_filter_47", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "nvlm/wdai/data/augmentations/music": { - "sample_loader": "sample_loader_48", - "part_filter": "part_filter_48", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/invoice": { - "sample_loader": "sample_loader_49", - "part_filter": "part_filter_49", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/k12": { - "sample_loader": "sample_loader_50", - "part_filter": "part_filter_50", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.019, - }, - "nvlm/wdai/data/augmentations/MTVQA": { - "sample_loader": "sample_loader_51", - "part_filter": "part_filter_51", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/VisualWebInstruct": { - "sample_loader": "sample_loader_52", - "part_filter": "part_filter_52", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.028, - }, - "nvlm/wdai/data/augmentations/financeqa": { - "sample_loader": "sample_loader_53", - "part_filter": "part_filter_53", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/docreason": { - "sample_loader": "sample_loader_54", - "part_filter": "part_filter_54", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/gpt_mtwi": { - "sample_loader": "sample_loader_55", - "part_filter": "part_filter_55", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/geos_gpt": { - "sample_loader": "sample_loader_56", - "part_filter": "part_filter_56", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0001, - }, - "nvlm/wdai/data/augmentations/cauldron_vistext": { - "sample_loader": "sample_loader_57", - "part_filter": "part_filter_57", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/memes": { - "sample_loader": "sample_loader_58", - "part_filter": "part_filter_58", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_roadtext": { - "sample_loader": "sample_loader_59", - "part_filter": "part_filter_59", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0002, - }, - "nvlm/wdai/data/augmentations/indoor_qa": { - "sample_loader": "sample_loader_60", - "part_filter": "part_filter_60", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/colpali": { - "sample_loader": "sample_loader_61", - "part_filter": "part_filter_61", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/pmc_vqa": { - "sample_loader": "sample_loader_62", - "part_filter": "part_filter_62", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/augmentations/pathvqa": { - "sample_loader": "sample_loader_63", - "part_filter": "part_filter_63", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/sciqa": { - "sample_loader": "sample_loader_64", - "part_filter": "part_filter_64", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.027, - }, - "nvlm/wdai/data/augmentations/chinese_meme": { - "sample_loader": "sample_loader_65", - "part_filter": "part_filter_65", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/gpt_hiertext": { - "sample_loader": "sample_loader_66", - "part_filter": "part_filter_66", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/augmentations/cauldron_cocoqa": { - "sample_loader": "sample_loader_67", - "part_filter": "part_filter_67", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/cmm-math/processed": { - "sample_loader": "sample_loader_68", - "part_filter": "part_filter_68", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/mmtab/processed": { - "sample_loader": "sample_loader_69", - "part_filter": "part_filter_69", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/simchart9k/processed": { - "sample_loader": "sample_loader_70", - "part_filter": "part_filter_70", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/llava-onevision/mapqa_processed": { - "sample_loader": "sample_loader_71", - "part_filter": "part_filter_71", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/vizwiz_processed": { - "sample_loader": "sample_loader_72", - "part_filter": "part_filter_72", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/gpt_infovqa": { - "sample_loader": "sample_loader_73", - "part_filter": "part_filter_73", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/viquae": { - "sample_loader": "sample_loader_74", - "part_filter": "part_filter_74", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "captioning/ccs_recaptioned/webdataset": { - "sample_loader": "sample_loader_75", - "part_filter": "part_filter_75", - "data_class": "CaptioningWebdataset", - "data_weight": 0.2, - }, - "captioning/laion115m-clean": { - "sample_loader": "sample_loader_76", - "part_filter": "part_filter_76", - "data_class": "CaptioningWebdataset", - "data_weight": 0.579, - }, - "nvlm/wdai/data/dvqa_full/processed_pt": { - "sample_loader": "sample_loader_77", - "part_filter": "part_filter_77", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/docmatix/processed_pt": { - "sample_loader": "sample_loader_78", - "part_filter": "part_filter_78", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "vqa/VQAv2/stage1": { - "sample_loader": "sample_loader_91", - "part_filter": "part_filter_91", - "data_class": "VQAWebdataset", - "data_weight": 1.0, - }, - "vqa/Visual_Genome": { - "sample_loader": "sample_loader_80", - "part_filter": "part_filter_80", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_300": { - "sample_loader": "sample_loader_81", - "part_filter": "part_filter_81", - "data_class": "OCRWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/textocr/processed": { - "sample_loader": "sample_loader_82", - "part_filter": "part_filter_82", - "data_class": "OCRWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/coco-text/processed": { - "sample_loader": "sample_loader_83", - "part_filter": "part_filter_83", - "data_class": "OCRWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/ArT/processed": { - "sample_loader": "sample_loader_84", - "part_filter": "part_filter_84", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/ReCTs/processed": { - "sample_loader": "sample_loader_85", - "part_filter": "part_filter_85", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/lsvt/processed": { - "sample_loader": "sample_loader_86", - "part_filter": "part_filter_86", - "data_class": "OCRWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/RCTW/processed": { - "sample_loader": "sample_loader_87", - "part_filter": "part_filter_87", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/coco-text/processed_multi": { - "sample_loader": "sample_loader_88", - "part_filter": "part_filter_88", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, - "nvlm/wdai/data/textocr/processed_multi": { - "sample_loader": "sample_loader_89", - "part_filter": "part_filter_89", - "data_class": "OCRWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/ReCTs/processed_multi": { - "sample_loader": "sample_loader_90", - "part_filter": "part_filter_90", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, -} - - -def get_sample_loader(path): - """Returns the correct sample_loader function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("sample_loader")) - - -def get_part_filter(path): - """Returns the correct part_filter function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("part_filter")) - - -def get_data_class(path): - """Returns the correct data_class for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return dataset_loader_mapping[path]["data_class"] diff --git a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py index ec86e66..5b576c4 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py @@ -45,7 +45,6 @@ def __call__(self, data_dict: Dict) -> Dict: if isinstance(list_of_conversation[0], list): selected_conversation = random.sample(list_of_conversation, 1)[0] elif isinstance(list_of_conversation[0], dict): - selected_conversation = list_of_conversation else: raise ValueError( @@ -82,7 +81,6 @@ def __call__(self, data_dict: Dict) -> Dict: del data_dict[conversation_key] - # # enforce chat order # self._enforce_text_chat_order(selected_conversation) @@ -91,7 +89,7 @@ def __call__(self, data_dict: Dict) -> Dict: def _enforce_text_chat_order(self, conversation: list) -> None: """ Reorder text content within user messages based on text_chat_order setting. - NOTE: this does NOT work for interleaved data!!!!!! + NOTE (maxzhaoshuol): this does NOT work for interleaved data!!!!!! Args: conversation: List of message dictionaries diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py index edede0c..88d3ac5 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py @@ -97,7 +97,7 @@ def overlay_text( return images, [compute_timestamps(i, fps, processor) for i in range(len(images))] # Try to use DejaVu Sans Mono font for better readability - font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", font_size) + font = ImageFont.truetype("/invalid_dir", font_size) # Process each image processed_images = [] @@ -392,17 +392,15 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: raise ValueError("Start and end time are the same for data.") if timestamp_format == "seconds": if random.random() < 0.5: diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py index 0507109..90cc9ab 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py @@ -224,17 +224,15 @@ def augment_user_prompt( elif output_format == "temporal_caption_subject": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: log.warning(f"Start and end time are the same for data. {event}") return None diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py index 584ca1d..7212510 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py @@ -162,14 +162,12 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.5: - start = round(event["start"]) end = round(event["end"]) else: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 - if start == end: # HACK: remove events with start == end + if start == end: raise ValueError("Start and end time are the same for data.") user_prompt = random.choice( [ diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py index 8df9dd1..812e113 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py @@ -205,10 +205,8 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 else: start = event["start"] diff --git a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py index a0ce29c..f3a9914 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py @@ -158,7 +158,6 @@ def __call__(self, data_dict: Dict) -> Dict: if message["role"] == "user" and isinstance(message["content"], list): total_images += len([content for content in message["content"] if content["type"] == "image"]) total_videos += len([content for content in message["content"] if content["type"] == "video"]) - assert total_videos == 1 or total_videos == 0, "Only one video is supported for now" # url @@ -167,7 +166,6 @@ def __call__(self, data_dict: Dict) -> Dict: # go through each message in the conversation for message in conversation: # for user message, we insert the media - if message["role"] == "user" and isinstance( message["content"], list ): # Otherwise it's text and content is a string @@ -225,7 +223,6 @@ def __call__(self, data_dict: Dict) -> Dict: raw_images.append(image) elif content["type"] == "video": - # as tokenization will NOT upsample the video, we can use a larger value here at the cost of multiple video having 1.5x token length max_total_pixels = token_to_pixels(self.max_video_token_length * 1.5, temporal_patch_size=2) media_key = content["video"] @@ -248,7 +245,6 @@ def __call__(self, data_dict: Dict) -> Dict: return None videos = data_dict["media"][media_key]["videos"] # list of PIL images fps = data_dict["media"][media_key]["fps"] - # this is because videos are decoded to be around "max_video_token_length" tokens videos = maybe_subsample_frames( diff --git a/cosmos_framework/data/vfm/joint_dataloader.py b/cosmos_framework/data/vfm/joint_dataloader.py index 56b13e7..ba84ceb 100644 --- a/cosmos_framework/data/vfm/joint_dataloader.py +++ b/cosmos_framework/data/vfm/joint_dataloader.py @@ -1,7 +1,9 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 +import math from collections import deque +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, ClassVar, Dict, Union @@ -12,6 +14,11 @@ from cosmos_framework.utils.lazy_config import instantiate from cosmos_framework.utils import log +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + get_uniae_chunk_frames, + get_uniae_latent_num_frames, + normalize_uniae_chunk_frames, +) _TIMING_KEYS = {"_sample_time", "_aug_time", "_pre_aug_time", "_aug_step_times"} _BATCH_TIMING_KEYS = { @@ -38,6 +45,7 @@ def custom_collate_fn(batch): "sound", "raw_action_dim", "image_size", + "action_processing_record", } # Data keys where a per-sample value of ``None`` is a meaningful signal @@ -57,7 +65,6 @@ def custom_collate_fn(batch): # Handle standard list of samples elem = batch[0] if isinstance(elem, dict): - # Some Action datasets add optional metadata keys (for example # ``additional_view_description`` for concat-view captions) only for a # subset of samples. PyTorch can batch such samples together when @@ -72,6 +79,9 @@ def custom_collate_fn(batch): if key in _TIMING_KEYS: continue values = [d.get(key) for d in batch] + if key == "action_processing_record": + result[key] = values + continue if any(value is None for value in values): # Sparse data keys keep their None placeholders to preserve # 1:1 alignment with sequence_plan. Other (optional metadata) @@ -165,6 +175,8 @@ def __init__( prewarm: bool = True, default_lookahead_limit: int = _DEFAULT_LOOKAHEAD_LIMIT, lookahead_limits: Dict[str, int] | None = None, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): """ Initialize the JointDataLoader with multiple datasets. @@ -186,6 +198,8 @@ def __init__( default_lookahead_limit: Packing-loop look-ahead fallback for dataloaders not in ``lookahead_limits``. lookahead_limits: Optional ``{dataset_name: int}`` per-dataloader override. + uniae_chunk_frames: Optional UniAE full chunk size, or resolution-keyed chunk sizes. + uniae_pad_frames: Optional UniAE boundary padding frames per chunk. Example: joint_loader = IterativeJointDataLoader( @@ -211,6 +225,8 @@ def __init__( self.sound_latent_fps = sound_latent_fps self.audio_sample_rate = audio_sample_rate self.default_lookahead_limit = int(default_lookahead_limit) + self.uniae_pad_frames = int(uniae_pad_frames) if uniae_pad_frames is not None else None + self.uniae_chunk_frames = self._normalize_uniae_chunk_frames(uniae_chunk_frames) assert (self.max_sequence_length is None) != (self.max_samples_per_batch is None), ( "Exactly one of max_sequence_length or max_samples_per_batch must be None, but not both." @@ -221,6 +237,8 @@ def __init__( assert not unknown, f"lookahead_limits references unknown dataloaders {unknown}; valid: {sorted(dataloaders)}" for dataset_name, dataloader_data in dataloaders.items(): + if dataloader_data is None: + continue assert set(dataloader_data.keys()) == {"dataloader", "ratio"}, f"Invalid config: {dataloader_data}" if dataloader_data["ratio"] <= 0: continue @@ -255,13 +273,42 @@ def __init__( "JointDataLoader: prewarm DISABLED (debug mode); first iteration may incur per-stream cold-load cost" ) + def _normalize_uniae_chunk_frames( + self, uniae_chunk_frames: int | Mapping[str, int] | None + ) -> int | dict[str, int] | None: + return normalize_uniae_chunk_frames( + uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.tokenizer_temporal_compression_factor, + temporal_divisibility_name="tokenizer_temporal_compression_factor", + ) + + def _get_uniae_chunk_frames(self, spatial_shape: tuple[int, int]) -> int: + assert self.uniae_chunk_frames is not None + return get_uniae_chunk_frames(self.uniae_chunk_frames, spatial_shape=spatial_shape) + + def _compute_vision_latent_t_shape(self, T: int, H: int, W: int) -> int: + if T < 1: + raise ValueError(f"Vision media must contain at least one frame, got {T}.") + if T == 1 or self.uniae_chunk_frames is None: + return 1 + (T - 1) // self.tokenizer_temporal_compression_factor + + assert self.uniae_pad_frames is not None + return get_uniae_latent_num_frames( + T, + self.uniae_chunk_frames, + pad_frames=self.uniae_pad_frames, + temporal_compression_factor=self.tokenizer_temporal_compression_factor, + spatial_shape=(H, W), + ) + def _prewarm_dataloaders(self) -> None: """Force all dataloader iterators to spawn workers and produce one batch. The first ``next()`` call on an ``InfiniteDataLoader`` iterator triggers ``DataLoader.__iter__()`` which spawns worker processes. For action dataloaders using ``multiprocessing_context='spawn'``, each worker must - fully initialise heavy datasets (BridgeOrigLeRobotDataset, EMBODIMENT_A, etc.) + fully initialise heavy datasets (BridgeOrigLeRobotDataset, embodiment_a, etc.) from scratch. If this happens lazily during training, the resulting delay (potentially minutes) causes NCCL collective timeouts when faster ranks enter the forward pass while slower ranks are still loading data. @@ -362,14 +409,13 @@ def _compute_num_tokens_per_sample(self, data_batch: dict) -> int: else: _, T, H, W = media.shape - vae_spatial_downsample = self.tokenizer_spatial_compression_factor * self.patch_spatial - vae_temporal_downsample = self.tokenizer_temporal_compression_factor - - latent_h_shape = H // vae_spatial_downsample - latent_w_shape = W // vae_spatial_downsample - latent_t_shape = 1 + (T - 1) // vae_temporal_downsample + latent_h_shape = H // self.tokenizer_spatial_compression_factor + latent_w_shape = W // self.tokenizer_spatial_compression_factor + patch_h_shape = math.ceil(latent_h_shape / self.patch_spatial) + patch_w_shape = math.ceil(latent_w_shape / self.patch_spatial) + latent_t_shape = self._compute_vision_latent_t_shape(T, H, W) - num_vision_tokens = latent_h_shape * latent_w_shape * latent_t_shape + 2 + num_vision_tokens = patch_h_shape * patch_w_shape * latent_t_shape + 2 num_tokens += num_vision_tokens # Action part: each action time step is 1 token. @@ -534,6 +580,8 @@ def __init__( prewarm: bool = True, default_lookahead_limit: int = JointDataLoader._DEFAULT_LOOKAHEAD_LIMIT, lookahead_limits: Dict[str, int] | None = None, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): super().__init__( dataloaders, @@ -547,6 +595,8 @@ def __init__( prewarm=prewarm, default_lookahead_limit=default_lookahead_limit, lookahead_limits=lookahead_limits, + uniae_chunk_frames=uniae_chunk_frames, + uniae_pad_frames=uniae_pad_frames, ) self.seed = seed # Calculate probabilities for random sampling @@ -787,6 +837,8 @@ def __init__( audio_sample_rate: int = 48000, dataset_name: str = "default", lookahead_limit: int = JointDataLoader._DEFAULT_LOOKAHEAD_LIMIT, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): """ Args: @@ -802,6 +854,8 @@ def __init__( audio_sample_rate: Audio sample rate in Hz. dataset_name: Name tag attached to every sample in the output batch. lookahead_limit: Packing-loop look-ahead for the wrapped dataloader. + uniae_chunk_frames: Optional UniAE full chunk size, or resolution-keyed chunk sizes. + uniae_pad_frames: Optional UniAE boundary padding frames per chunk. """ wrapped = {dataset_name: {"dataloader": dataloader, "ratio": 1}} super().__init__( @@ -814,6 +868,8 @@ def __init__( sound_latent_fps=sound_latent_fps, audio_sample_rate=audio_sample_rate, lookahead_limits={dataset_name: int(lookahead_limit)}, + uniae_chunk_frames=uniae_chunk_frames, + uniae_pad_frames=uniae_pad_frames, ) def __iter__(self): @@ -905,6 +961,8 @@ def __init__( audio_sample_rate: int = 48000, default_lookahead_limit: int = JointDataLoader._DEFAULT_LOOKAHEAD_LIMIT, lookahead_limits: Dict[str, int] | None = None, + uniae_chunk_frames: int | Mapping[str, int] | None = None, + uniae_pad_frames: int | None = None, ): super().__init__( dataloaders, @@ -917,6 +975,8 @@ def __init__( audio_sample_rate=audio_sample_rate, default_lookahead_limit=default_lookahead_limit, lookahead_limits=lookahead_limits, + uniae_chunk_frames=uniae_chunk_frames, + uniae_pad_frames=uniae_pad_frames, ) # Convert data ratios to probabilities diff --git a/cosmos_framework/data/vfm/packing_iterable_dataset.py b/cosmos_framework/data/vfm/packing_iterable_dataset.py new file mode 100644 index 0000000..ac30f0d --- /dev/null +++ b/cosmos_framework/data/vfm/packing_iterable_dataset.py @@ -0,0 +1,271 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +""" +Abstract base class for pool-based token-budget bin-packing over multiple datasets. + +Extracted from ``cosmos_framework.data.vfm.vlm.joint_dataset_dynamic_batch_webloader`` +so that both the VLM and VFM internal dataloaders can share a single packing implementation. + +Usage +----- +Subclass and implement ``compute_sample_tokens(sample) -> int``. +Optionally override ``collate_batch(samples) -> Any`` for custom collation. + + class MyPacker(PackingIterableDataset): + def compute_sample_tokens(self, sample): + return len(sample["input_ids"]) +""" + +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from collections import deque +from enum import Enum +from typing import Any, Union + +import torch + +from cosmos_framework.utils.lazy_config import instantiate +from cosmos_framework.utils import log + + +class Modality(Enum): + IMAGE = "image" + VIDEO = "video" + TEXT = "text" + + +class PackingIterableDataset(torch.utils.data.IterableDataset, ABC): + """Pool-based greedy bin-packing IterableDataset. + + Maintains a pool of ``pool_size`` samples and assembles batches by + greedily selecting candidates that fit within the token budget + ``max_tokens``. Subclasses supply two hooks: + + * ``compute_sample_tokens(sample)`` — token cost of one sample (abstract). + * ``collate_batch(samples)`` — assemble a packed list into a batch + (default: identity, returns the list unchanged). + + Parameters + ---------- + datasets_cfg: + Mapping ``{name: {"dataset": , "ratio": }}``. + The *dataset* value may be a Hydra lazy config, an already-constructed + ``IterableDataset``, or a plain ``DataLoader`` (its ``.dataset`` is + unwrapped automatically). + max_tokens: + Token budget per batch (padded cost = ``cur_max_len * batch_size``). + pool_size: + Number of samples to buffer before selecting a batch. + max_batch_size: + Hard cap on items per batch (0 or None = no cap). + long_threshold: + Samples with token count ``>= long_threshold`` are emitted as + singletons regardless of budget. + batching_strategy: + ``"prefer_closest"`` (default) or ``"prefer_first"``. + """ + + def __init__( + self, + datasets_cfg: dict[str, dict[str, Union[int, object]]], + max_tokens: int, + pool_size: int, + max_batch_size: int, + long_threshold: int, + batching_strategy: str, + ): + super().__init__() + + assert batching_strategy in ("prefer_first", "prefer_closest"), ( + f"batching_strategy must be 'prefer_first' or 'prefer_closest', got {batching_strategy!r}" + ) + + self.max_tokens = max_tokens + self.pool_size = pool_size + self.long_threshold = long_threshold + self.max_batch_size = max_batch_size + self.batching_strategy = batching_strategy + + self._pool: deque[dict] = deque() + self._dataset_names: list[str] = [] + self._ratios: list[float] = [] + self._datasets: list[torch.utils.data.IterableDataset] = [] + + for name, cfg in datasets_cfg.items(): + assert {"ratio", "dataset"} <= cfg.keys(), ( + f"Each entry must have 'dataset' and 'ratio' keys: {name} -> {cfg.keys()}" + ) + ratio = cfg["ratio"] + if ratio == 0: + log.info(f"Skipping dataset {name} with ratio {ratio}") + continue + dataset_cfg = cfg["dataset"] + + ds = ( + instantiate(dataset_cfg) + if not isinstance(dataset_cfg, (torch.utils.data.IterableDataset, torch.utils.data.DataLoader)) + else dataset_cfg + ) + if isinstance(ds, torch.utils.data.DataLoader): + ds = ds.dataset + if hasattr(ds, "build_dataset") and callable(getattr(ds, "build_dataset")): + ds = ds.build_dataset() + + assert isinstance(ds, torch.utils.data.IterableDataset), ( + f"Expected an IterableDataset, got {type(ds)} for {name}" + ) + + self._dataset_names.append(name) + self._ratios.append(float(ratio)) + self._datasets.append(ds) + log.info(f"Added dataset {name} with ratio {ratio}") + + log.info(f"added data: {list(datasets_cfg.keys())}") + assert len(self._datasets) > 0, "No datasets added" + self._data_len: int = sum(int(getattr(ds, "total_images", 0)) for ds in self._datasets) + if self._data_len == 0: + self._data_len = 10**12 + self.iterators = [iter(ds) for ds in self._datasets] + + # ------------------------------------------------------------------ + # Abstract / overridable hooks + # ------------------------------------------------------------------ + + @abstractmethod + def compute_sample_tokens(self, sample: dict) -> int: + """Return the token cost of one sample for packing budget accounting.""" + + def collate_batch(self, samples: list[dict]) -> Any: + """Assemble a packed list of samples into one batch. + + Default implementation returns the list unchanged (identity). + Override to pad, stack, or transform samples into tensors. + """ + return samples + + # ------------------------------------------------------------------ + # PyTorch Dataset API + # ------------------------------------------------------------------ + + def __len__(self) -> int: + return self._data_len + + def __iter__(self): + while True: + batch = self._best_fit_batch() + yield self.collate_batch(batch) + + # ------------------------------------------------------------------ + # Internal packing helpers (moved verbatim from _JointIterableDataset) + # ------------------------------------------------------------------ + + def _max_tokens(self, cur_max: int) -> int: + if cur_max < 1000: + return self.max_tokens + return self.max_tokens // 2 + + def _get_next_sample(self) -> dict: + index_id = random.choices(range(len(self.iterators)), weights=self._ratios, k=1)[0] + curr_dataset = self.iterators[index_id] + try: + output = next(curr_dataset) + except StopIteration: + log.critical(f"dataset {self._dataset_names[index_id]} exhausted") + self.iterators[index_id] = iter(self._datasets[index_id]) + output = next(self.iterators[index_id]) + return output + + def _fill_pool(self): + while len(self._pool) < self.pool_size: + self._pool.append(self._get_next_sample()) + + def _padded_cost(self, cur_max: int, k: int) -> int: + return cur_max * k + + def _get_modality(self, sample: dict) -> Modality: + if "pixel_values" in sample: + return Modality.IMAGE + elif "pixel_values_videos" in sample: + return Modality.VIDEO + return Modality.TEXT + + def _best_fit_batch(self) -> list[dict]: + """Build one batch using the configured token-budget strategy.""" + self._fill_pool() + seed = self._pool.popleft() + seed_modality = self._get_modality(seed) + L0 = self.compute_sample_tokens(seed) + + if L0 >= self.long_threshold or L0 >= self._max_tokens(L0): + return [seed] + + chosen = [seed] + cur_max = L0 + + while self._pool: + if self.max_batch_size and len(chosen) >= self.max_batch_size: + break + best_idx = self._find_best_candidate(cur_max, len(chosen), seed_modality) + if best_idx is None: + break + cand = self._remove_from_pool(best_idx) + chosen.append(cand) + cur_max = max(cur_max, self.compute_sample_tokens(cand)) + + return chosen + + def _find_best_candidate(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None: + if self.batching_strategy == "prefer_first": + return self._find_best_candidate_prefer_first(cur_max, num_chosen, seed_modality) + return self._find_best_candidate_prefer_closest(cur_max, num_chosen, seed_modality) + + def _find_best_candidate_prefer_first(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None: + best_idx = None + best_new_tokens = None + for idx, cand in enumerate(self._pool): + if self._get_modality(cand) != seed_modality: + continue + L = self.compute_sample_tokens(cand) + new_max = max(cur_max, L) + new_tokens = self._padded_cost(new_max, num_chosen + 1) + if new_tokens <= self._max_tokens(cur_max): + if best_new_tokens is None or new_tokens < best_new_tokens: + best_new_tokens = new_tokens + best_idx = idx + return best_idx + + def _find_best_candidate_prefer_closest(self, cur_max: int, num_chosen: int, seed_modality: Modality) -> int | None: + best_idx = None + best_new_tokens = None + smallest_length_diff = None + for idx, cand in enumerate(self._pool): + if self._get_modality(cand) != seed_modality: + continue + L = self.compute_sample_tokens(cand) + new_max = max(cur_max, L) + new_tokens = self._padded_cost(new_max, num_chosen + 1) + if new_tokens <= self._max_tokens(cur_max): + length_diff = abs(L - cur_max) + if ( + best_new_tokens is None + or new_tokens < best_new_tokens + or (new_tokens == best_new_tokens and length_diff < smallest_length_diff) + ): + best_new_tokens = new_tokens + best_idx = idx + smallest_length_diff = length_diff + return best_idx + + def _remove_from_pool(self, idx: int) -> dict: + if idx == 0: + return self._pool.popleft() + elif idx == len(self._pool) - 1: + return self._pool.pop() + else: + self._pool.rotate(-idx) + item = self._pool.popleft() + self._pool.rotate(idx) + return item diff --git a/cosmos_framework/data/vfm/processors/__init__.py b/cosmos_framework/data/vfm/processors/__init__.py index a1cc60e..ce98401 100644 --- a/cosmos_framework/data/vfm/processors/__init__.py +++ b/cosmos_framework/data/vfm/processors/__init__.py @@ -125,7 +125,12 @@ def build_processor( return Qwen3VLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) elif "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16" in tokenizer_type: return NemotronVLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) - elif "NVIDIA-Nemotron-3-Dense-VL" in tokenizer_type or "Qwen3-2B-ViT" in tokenizer_type: + elif ( + "NVIDIA-Nemotron-3-Dense-VL" in tokenizer_type + or "Qwen3-2B-ViT" in tokenizer_type + or "nvidia/Cosmos3-Reasoner-2B-Private" in tokenizer_type + or "nvidia/Cosmos3-Edge-Reasoner" in tokenizer_type + ): return Nemotron3DenseVLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) elif "Qwen/Qwen3-0.6B" in tokenizer_type: local_path = _download_llm_tokenizer(tokenizer_type, credentials, bucket, cache_dir) diff --git a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py index 767c8ef..077c80e 100644 --- a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py +++ b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py @@ -248,7 +248,6 @@ def __init__( # NemotronVL hardcodes these helper attributes because they are not # discoverable from the HF model config; the values match the upstream # vision-encoder configuration. - # HACK: hardcoded based on the model config. self.min_height_width = 512 self.patch_size = 16 self.temporal_patch_size = 1 @@ -258,7 +257,6 @@ def __init__( def _resolve_pad_id(self): # NemotronVL's tokenizer does not specify a pad_token; reserve # for padding (project convention). - return self.processor.tokenizer.convert_tokens_to_ids("") def apply_chat_template( @@ -396,7 +394,7 @@ def add_assistant_tokens_mask(self, tokens): import requests - response = requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg") + response = requests.get("https://invalid_url") img = Image.open(BytesIO(response.content)) # test video diff --git a/cosmos_framework/data/vfm/processors/qwen3vl_processor.py b/cosmos_framework/data/vfm/processors/qwen3vl_processor.py index 030d040..dffa47c 100644 --- a/cosmos_framework/data/vfm/processors/qwen3vl_processor.py +++ b/cosmos_framework/data/vfm/processors/qwen3vl_processor.py @@ -71,7 +71,7 @@ def apply_chat_template( num_video, video_fps, video_total_num_frames, video_frames_indices = maybe_parse_video_content(messages) if num_video > 0: # Here we add the args to avoid the error: - # File "/usr/local/lib/python3.12/dist-packages/transformers/video_processing_utils.py", line 321, in _decode_and_sample_videos + # File "/invalid_dir", line 321, in _decode_and_sample_videos # raise ValueError( # ValueError: Sampling frames from a list of images is not supported! Set `do_sample_frames=False`. kwargs["videos_kwargs"] = dict(do_sample_frames=False) @@ -178,7 +178,7 @@ def add_assistant_tokens_mask(self, tokens): "content": [ { "type": "video", - "video": ["https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"] * 4, + "video": ["https://invalid_url"] * 4, "fps": 12, }, {"type": "text", "text": "What is the capital of France?"}, diff --git a/cosmos_framework/data/vfm/sequence_packing.py b/cosmos_framework/data/vfm/sequence_packing.py index d2821cc..1209a2d 100644 --- a/cosmos_framework/data/vfm/sequence_packing.py +++ b/cosmos_framework/data/vfm/sequence_packing.py @@ -685,6 +685,7 @@ def _pack_vision_tokens( enable_fps_modulation: bool = False, base_fps: float = 24.0, temporal_compression_factor: int = 4, + vision_temporal_positions: torch.Tensor | None = None, ) -> int: """Pack vision tokens into the sequence. @@ -701,6 +702,8 @@ def _pack_vision_tokens( enable_fps_modulation: If True, scale temporal position IDs based on video FPS. base_fps: Base FPS for normalization (default 24.0). temporal_compression_factor: VAE temporal compression factor (default 4). + vision_temporal_positions: Optional explicit temporal coordinate per latent + frame, shape ``(T,)``. Used by UniAE to account for kept boundary latents. Returns: Vision split length. """ @@ -773,6 +776,8 @@ def _pack_vision_tokens( if packed_seq._use_mrope: # Determine FPS for this vision segment (None disables FPS modulation) effective_fps = vision_fps if enable_fps_modulation else None + if vision_temporal_positions is not None: + vision_temporal_positions = vision_temporal_positions.to(device="cpu", dtype=torch.float32) # [T] vision_mrope_ids, packed_seq._mrope_temporal_offset = get_3d_mrope_ids_vae_tokens( grid_t=latent_t, @@ -783,6 +788,8 @@ def _pack_vision_tokens( fps=effective_fps, base_fps=base_fps, temporal_compression_factor=temporal_compression_factor, + temporal_positions=vision_temporal_positions, + actual_temporal_compression_factor=temporal_compression_factor, ) # vision_mrope_ids: [3,N_vision_tokens] packed_seq.position_ids.append(vision_mrope_ids) else: @@ -850,7 +857,6 @@ def _pack_action_tokens( packed_seq.action.token_shapes.append((action_split_len,)) packed_seq.action.tokens.append(input_action_tokens) - condition_set = {idx for idx in condition_frame_indexes_action if 0 <= idx < action_split_len} assert isinstance(packed_seq.action.condition_mask, list) @@ -917,6 +923,7 @@ def _pack_sound_tokens( enable_fps_modulation: bool = False, base_fps: float = 24.0, sound_fps: float | None = None, + sound_base_temporal_compression_factor: int | None = None, ) -> int: """Pack sound/audio tokens into the sequence. @@ -936,6 +943,8 @@ def _pack_sound_tokens( enable_fps_modulation: If True, scale temporal positions by FPS ratio. base_fps: Base FPS for normalization (default 24.0). sound_fps: Sound latent FPS (e.g., 25.0). Used for FPS-aware m-RoPE positions. + sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. + ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. Returns: Number of sound tokens added. @@ -1008,6 +1017,7 @@ def _pack_sound_tokens( fps=effective_fps, base_fps=base_fps, temporal_compression_factor=1, # Sound latent is already at sound_latent_fps (no further compression) + base_temporal_compression_factor=sound_base_temporal_compression_factor, start_frame_offset=0, # Sound[0] aligns with vision frame 0 ) # sound_mrope_ids: [3,N_sound_tokens] packed_seq.position_ids.append(sound_mrope_ids) @@ -1047,11 +1057,18 @@ def _pack_supertokens_temporal_causal( ``num_action_tokens_per_supertoken=0`` is stamped on the pack and read by the attention builder so NATTEN metadata stays in sync automatically. - mRoPE layout (with actions, unified_3d_mrope only): - - Null actions (frame 0): all tcf tokens at ``temporal_offset``. - - Real training actions (frames 1..T-1): ``start_frame_offset=1`` so the - last action in group i co-locates with vision frame i. - - AR real actions (single supertoken): ``start_frame_offset=0``. + mRoPE layout (with actions, unified_3d_mrope only). The layout is inferred from the + action tensor shape: + - Whole-clip training (frame 0 is the clean conditioning frame, so + ``real_actions`` has ``(T-1)*tcf`` rows): null action for supertoken 0, real + actions for frames 1..T-1 with ``start_frame_offset=1`` so the last action in + group i co-locates with vision frame i; vision uses ``start_frame_offset=0``. + - AR generation, single frame OR chunk (every frame carries a real action, so + ``real_actions`` has ``latent_t*tcf`` rows): vision AND action both use + ``start_frame_offset=1``, generalizing the single-frame AR supertoken to + ``latent_t`` frames. The caller (``pack_input_sequence_autoregressive``) + seeds ``temporal_offset`` one frame-stride back to compensate, so the unit + lands at the same absolute positions as the whole-clip training pack. - Interleaved per frame as cat([action_ids, vision_ids]). ``input_timestep`` is float (TF/none) or Tensor(T_max,) (DF, per-frame sigma). @@ -1094,32 +1111,36 @@ def _pack_supertokens_temporal_causal( if pack_action_tokens: # Build all_action_tokens: shape (latent_t * tcf, action_dim) # - # Cases: - # 1. Training with conditioning frame (latent_t > 1, real_actions < latent_t*tcf): - # Prepend tcf null tokens for frame 0, then real actions for frames 1..T-1. - # 2. KV-cache continuation (latent_t > 1, real_actions == latent_t*tcf): all supertokens - # carry real actions (no conditioning frame in-segment). - # 3. AR frame N>0 (latent_t == 1, action provided): real actions, no null prefix. - # 4. AR frame 0 / image2video (action is None): all null tokens. + # Cases (token assembly; mRoPE start_frame_offset is chosen separately below, + # inferred from the same action shape): + # 1. Whole-clip training with conditioning frame (latent_t > 1, real_actions + # has (T-1)*tcf rows): prepend tcf null tokens for frame 0, then real + # actions for frames 1..T-1. + # 2. AR generation (every frame has a real action, real_actions has + # latent_t*tcf rows — single frame OR chunk): no null prefix. + # 3. AR frame 0 / image2video (action is None): all null tokens. if input_action_tokens is not None: - # input_action_tokens shape: (1, T*tcf, D) or (T*tcf, D) for training; (tcf, D) for AR frame N>0 + # input_action_tokens shape: (1, T*tcf, D) or (T*tcf, D) for training; (T*tcf, D) for AR units if input_action_tokens.dim() == 3: real_actions = input_action_tokens.squeeze(0) # [T*tcf,action_dim] or [N,action_dim] else: real_actions = input_action_tokens # [N,action_dim] null_tokens = torch.zeros(tcf, action_dim, device=device, dtype=real_actions.dtype) # [tcf,action_dim] - if latent_t == 1: - # AR frame N>0: single supertoken with real actions, no null prefix - all_action_tokens = real_actions # [tcf,action_dim] - null_action_flag = False - elif real_actions.shape[0] == latent_t * tcf: - # All frames have real actions (e.g. KV-cache continuation segments) + if real_actions.shape[0] == latent_t * tcf: + # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf): + # every supertoken carries a real action, no null prefix. all_action_tokens = real_actions null_action_flag = False - else: + elif real_actions.shape[0] == (latent_t - 1) * tcf: # Conditioning frame present: null for supertoken 0, real for 1..T-1 all_action_tokens = torch.cat([null_tokens, real_actions], dim=0) # [T*tcf,action_dim] null_action_flag = True + else: + raise ValueError( + "Temporal-causal action tokens must have either latent_t*tcf rows for AR chunks " + f"or (latent_t-1)*tcf rows for whole-clip training; got {real_actions.shape[0]} rows " + f"for latent_t={latent_t}, tcf={tcf}." + ) else: # AR frame 0 or image2video: all action tokens are null all_action_tokens = torch.zeros( @@ -1171,14 +1192,17 @@ def _pack_supertokens_temporal_causal( temporal_offset = packed_seq._mrope_temporal_offset effective_vision_fps = vision_fps if enable_fps_modulation else None - # AR frame N>=1 with action_gen=True (latent_t==1 and real actions supplied): - # shift both vision and action by start_frame_offset=1 so the last action in - # the group co-locates with vision frame N, mirroring training's layout. - # All other cases (training latent_t>1, AR action_gen=False, AR frame 0 null) - # keep start_frame_offset=0. The caller in pack_input_sequence_autoregressive - # seeds temporal_offset accordingly (N-1 frames back when this shift applies). - ar_with_real_actions = latent_t == 1 and pack_action_tokens and input_action_tokens is not None - vision_sfo = 1 if ar_with_real_actions else 0 + # AR generation (single frame OR chunk) is detected by every frame carrying a + # real action (``real_actions`` has ``latent_t*tcf`` rows). There, vision AND + # action both use start_frame_offset=1 so the last action in each group + # co-locates with its vision frame, mirroring whole-clip training; the caller + # (pack_input_sequence_autoregressive) seeds temporal_offset one frame-stride + # back to compensate. Whole-clip training (frame 0 is the null conditioning + # frame, ``real_actions`` has ``(T-1)*tcf`` rows) keeps vision start_frame_offset=0. + all_frames_have_real_action = ( + pack_action_tokens and input_action_tokens is not None and real_actions.shape[0] == latent_t * tcf + ) + vision_sfo = 1 if all_frames_have_real_action else 0 vision_ids_flat, new_offset = get_3d_mrope_ids_vae_tokens( grid_t=latent_t, @@ -1195,10 +1219,10 @@ def _pack_supertokens_temporal_causal( if pack_action_tokens: effective_action_fps = action_fps if enable_fps_modulation else None - # Action IDs: null for frame 0 (all tcf tokens share temporal_offset, - # co-located with vision frame 0), real for frames 1..T-1. - # Real tokens (training and AR) use start_frame_offset=1 so the last - # action in a group co-locates with vision frame i. + # Action IDs. Real action tokens use start_frame_offset=1 so the last + # sub-token of a group co-locates with its vision frame. Whole-clip training + # has a null action at frame 0 (the conditioning frame); AR units have a real + # action for every frame. fps_active = effective_action_fps is not None t_dtype = torch.float32 if fps_active else torch.long t_offset = float(temporal_offset) if fps_active else int(temporal_offset) @@ -1221,28 +1245,24 @@ def _real_action_ids(n_frames: int, start_frame_offset: int) -> torch.Tensor: ) return flat.reshape(3, n_frames, tcf) # [3,n_frames,tcf] - if latent_t > 1 and input_action_tokens is not None: - if real_actions.shape[0] == latent_t * tcf: - # KV continuation: real action in every supertoken (including frame 0) - action_ids_3d = _real_action_ids(latent_t, start_frame_offset=0) - else: - # Training with conditioning frame: supertoken 0 = null, 1..T-1 = real - null_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] - real_ids_3d = _real_action_ids(latent_t - 1, start_frame_offset=1) # [3,T-1,tcf] - action_ids_3d = torch.cat([null_ids_3d, real_ids_3d], dim=1) # [3,T,tcf] + if all_frames_have_real_action: + # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf): + # every supertoken carries a real action. start_frame_offset=1 puts + # a_{j-1}'s last sub-token on vision frame j -- the whole-clip TF + # training layout. The caller seeds temporal_offset (N-1) frame-strides + # back to compensate. + action_ids_3d = _real_action_ids(latent_t, start_frame_offset=1) # [3,T,tcf] elif latent_t > 1: - # No action tensor (all-null layout): same ID structure as training w/ conditioning frame. + # Whole-clip training: supertoken 0 = null (conditioning frame), frames + # 1..T-1 = real with start_frame_offset=1. Covers real-action training + # (real_actions has (T-1)*tcf rows) and the architectural all-null layout + # (input_action_tokens is None); the tokens differ but the IDs match. null_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] real_ids_3d = _real_action_ids(latent_t - 1, start_frame_offset=1) # [3,T-1,tcf] action_ids_3d = torch.cat([null_ids_3d, real_ids_3d], dim=1) # [3,T,tcf] - elif input_action_tokens is None: - # AR frame 0 / image2video: only null - action_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] else: - # AR frame N>=1: single supertoken with real actions. start_frame_offset=1 - # matches training (last action co-locates with vision frame N); caller - # seeds temporal_offset to (N-1) frame-strides back to compensate. - action_ids_3d = _real_action_ids(1, start_frame_offset=1) # [3,1,tcf] + # AR frame 0 / image2video (latent_t == 1, no action): only null. + action_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] # (3, T*H*W) → (3, T, H*W) vision_ids_3d = vision_ids_flat.reshape(3, latent_t, patches_per_frame) # [3,T,patch_h*patch_w] @@ -1309,7 +1329,9 @@ def pack_input_sequence( unified_3d_mrope_temporal_modality_margin: int = 0, enable_fps_modulation: bool = False, base_fps: float = 24.0, + sound_base_temporal_compression_factor: int | None = None, temporal_compression_factor: int = 4, + vision_temporal_position_mode: str = "latent_index", video_temporal_causal: bool = False, action_dim: int = 32, initial_mrope_temporal_offset: int | float = 0, @@ -1347,8 +1369,13 @@ def pack_input_sequence( Uses the same flag as diffusion_expert_config.enable_fps_modulation. base_fps: Base FPS for normalization (default 24.0). Uses the same value as diffusion_expert_config.base_fps. + sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. + ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. temporal_compression_factor: VAE temporal compression factor (default 4). Obtained from the VAE tokenizer at runtime. + vision_temporal_position_mode: Temporal coordinates used for unified_3d_mrope vision tokens. + "latent_index" keeps legacy positions; "uniae_source_right_edge" uses + per-latent positions from gen_data_clean.temporal_positions_vision. Returns: PackedSequence containing all packed tensors and metadata. See PackedSequence for field details. """ @@ -1361,6 +1388,44 @@ def pack_input_sequence( if isinstance(input_text_indexes, torch.Tensor): raise ValueError("input_text_tokens must be a list, not a tensor") + supported_vision_temporal_position_modes = {"latent_index", "uniae_source_right_edge"} + if vision_temporal_position_mode not in supported_vision_temporal_position_modes: + raise ValueError( + "Unsupported vision_temporal_position_mode: " + f"{vision_temporal_position_mode}. Supported modes: {supported_vision_temporal_position_modes}." + ) + has_any_vision = any(plan.has_vision for plan in sequence_plans) + explicit_vision_temporal_positions_active = vision_temporal_position_mode != "latent_index" and has_any_vision + if explicit_vision_temporal_positions_active: + if position_embedding_type != "unified_3d_mrope": + raise NotImplementedError( + "Explicit vision temporal positions are only supported with position_embedding_type='unified_3d_mrope'." + ) + if gen_data_clean.temporal_positions_vision is None: + raise ValueError( + f"vision_temporal_position_mode={vision_temporal_position_mode} requires " + "gen_data_clean.temporal_positions_vision." + ) + if gen_data_clean.x0_tokens_vision is not None and len(gen_data_clean.temporal_positions_vision) != len( + gen_data_clean.x0_tokens_vision + ): + raise ValueError( + "temporal_positions_vision must have one entry per x0_tokens_vision item, " + f"got {len(gen_data_clean.temporal_positions_vision)} positions for " + f"{len(gen_data_clean.x0_tokens_vision)} vision items." + ) + if video_temporal_causal: + raise NotImplementedError( + "video_temporal_causal=True is not wired for explicit UniAE vision temporal positions yet." + ) + if any(plan.has_action for plan in sequence_plans): + raise NotImplementedError("Action packing is not wired for explicit UniAE vision temporal positions yet.") + if initial_mrope_temporal_offset != 0: + raise NotImplementedError( + "Autoregressive mRoPE temporal offsets are not wired for explicit UniAE vision temporal positions yet." + ) + use_float_mrope_positions = enable_fps_modulation or explicit_vision_temporal_positions_active + # Initialize packed sequence (acts as builder during packing) packed_seq = PackedSequence() @@ -1405,7 +1470,7 @@ def pack_input_sequence( special_tokens, curr_rope_id, has_generation=has_generation_for_sample, - use_float_positions=enable_fps_modulation, + use_float_positions=use_float_mrope_positions, ) sample_len += text_sample_len @@ -1496,6 +1561,7 @@ def pack_input_sequence( shared_latent_t: int | None = None shared_patch_h: int | None = None shared_patch_w: int | None = None + shared_temporal_positions: torch.Tensor | None = None # FPS is recorded per-sample (shape [B]); for multi-item samples # (transfer / image-edit) every vision item in this sample shares # the same conditioning FPS, so we read by sample_idx, not by the @@ -1510,7 +1576,18 @@ def pack_input_sequence( sample_vision_fps = float(gen_data_clean.fps_vision[sample_idx].item()) for item_idx in range(num_vis): - input_vision_tokens = gen_data_clean.x0_tokens_vision[idx_vision] + flat_vision_idx = idx_vision + input_vision_tokens = gen_data_clean.x0_tokens_vision[flat_vision_idx] + vision_temporal_positions: torch.Tensor | None = None + if explicit_vision_temporal_positions_active: + assert gen_data_clean.temporal_positions_vision is not None + vision_temporal_positions = gen_data_clean.temporal_positions_vision[flat_vision_idx] + if vision_temporal_positions.shape[0] != input_vision_tokens.shape[2]: + raise ValueError( + "vision_temporal_positions must match latent_t for each vision item, " + f"got {vision_temporal_positions.shape[0]} positions and " + f"latent_t={input_vision_tokens.shape[2]} for item {flat_vision_idx}." + ) vision_fps = sample_vision_fps idx_vision += 1 @@ -1544,6 +1621,19 @@ def pack_input_sequence( f"got item {item_idx} (H,W)=({item_latent_h},{item_latent_w}) " f"vs first=({shared_patch_h},{shared_patch_w})" ) + if vision_temporal_positions is not None: + if shared_temporal_positions is None: + shared_temporal_positions = vision_temporal_positions + else: + comparison_temporal_positions = vision_temporal_positions.to( + device=shared_temporal_positions.device + ) # [T] + assert torch.allclose(comparison_temporal_positions, shared_temporal_positions), ( + "share_vision_temporal_positions requires equal explicit temporal positions " + f"across vision items, got item {item_idx} positions " + f"{vision_temporal_positions.tolist()} vs first " + f"{shared_temporal_positions.tolist()}." + ) # Rewind so this item starts at the same temporal offset as item 0. packed_seq._mrope_temporal_offset = items_temporal_offset_snapshot @@ -1558,6 +1648,7 @@ def pack_input_sequence( enable_fps_modulation=enable_fps_modulation, base_fps=base_fps, temporal_compression_factor=temporal_compression_factor, + vision_temporal_positions=vision_temporal_positions, ) vision_split_len += item_split_len sample_len += vision_split_len @@ -1622,6 +1713,7 @@ def pack_input_sequence( enable_fps_modulation=enable_fps_modulation, base_fps=base_fps, sound_fps=sound_fps, + sound_base_temporal_compression_factor=sound_base_temporal_compression_factor, ) sample_len += sound_split_len else: @@ -1641,8 +1733,8 @@ def pack_input_sequence( # EOV position IDs: 3D mRoPE or 1D RoPE if packed_seq._use_mrope: - # Use float dtype when FPS modulation is enabled for consistency - eov_dtype = torch.float32 if enable_fps_modulation else torch.long + # Use float dtype when any vision mRoPE positions are fractional. + eov_dtype = torch.float32 if use_float_mrope_positions else torch.long eov_mrope_ids = torch.full((3, 1), packed_seq._mrope_temporal_offset, dtype=eov_dtype) # [3,1] packed_seq.position_ids.append(eov_mrope_ids) # type: ignore[arg-type] packed_seq._mrope_temporal_offset += 1 @@ -2095,7 +2187,7 @@ def verify_natten_parameter_list( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Multi-profile: different parameters for 2D (images) and 3D (videos) @@ -2231,7 +2323,7 @@ def generate_natten_metadata( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Invalid: @@ -2363,9 +2455,9 @@ def filter_shape(shape: tuple) -> tuple: is_causal = dim_params["is_causal"] # Create varlen metadata for natten varlen/varsized ops - + # NOTE: generate_multi_dim_varlen_parameters will automatically map window size -1 to # full size, that's why constant window sizes aren't allowed. - + # NOTE: if any of the parameters are constant, natten will simplify them natten_metadata.append( generate_multi_dim_varlen_parameters( token_layout_list=token_layout_list, @@ -2780,7 +2872,6 @@ def build_sequence_plans_from_data_batch( Returns: List of SequencePlan objects, one per sample in the batch. """ - # For new modalities, please generate the sequence_plan in the dataset class!!!! # If sequence_plan already exists in data_batch, return it @@ -2790,7 +2881,6 @@ def build_sequence_plans_from_data_batch( assert "action" not in data_batch or data_batch["action"] is None, "Action data SHOULD have sequence_plans!" assert "sound" not in data_batch or data_batch["sound"] is None, "Sound data SHOULD have sequence_plans!" - # Determine batch size from available tensors batch_size = 0 for key in [input_video_key, input_image_key]: diff --git a/cosmos_framework/data/vfm/sound_data_utils.py b/cosmos_framework/data/vfm/sound_data_utils.py index 0a8ec63..2d739b0 100644 --- a/cosmos_framework/data/vfm/sound_data_utils.py +++ b/cosmos_framework/data/vfm/sound_data_utils.py @@ -3,7 +3,8 @@ """Sound data utilities for building sequence plans and handling audio-video generation modes. -This module provides utilities for building SequencePlan objects based on sound generation modes. +This module provides utilities for building SequencePlan objects based on sound generation modes, +similar to how action modes are handled in cosmos_framework/data/vfm/action/data_utils.py. Supported modes: - t2vs: Text → Video + Sound (joint generation) @@ -27,7 +28,8 @@ def build_sequence_plan_for_sound( """Build a SequencePlan based on the sound generation mode. This function determines the appropriate condition frame indexes for vision and sound - based on the specified mode. + based on the specified mode. It mirrors how `build_sequence_plan_from_mode` works + for action in cosmos_framework/data/vfm/action/data_utils.py. Args: mode: Generation mode. One of: diff --git a/cosmos_framework/data/vfm/utils.py b/cosmos_framework/data/vfm/utils.py index ab6f238..8b04f0a 100644 --- a/cosmos_framework/data/vfm/utils.py +++ b/cosmos_framework/data/vfm/utils.py @@ -6,7 +6,6 @@ IMAGE_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). - # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), @@ -41,7 +40,6 @@ VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). - # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), @@ -111,10 +109,42 @@ def parse_frame_range_from_wdinfo(wdinfo: str) -> tuple[int, int] | None: return None +def _normalize_skip_frame_ranges( + skip_frame_range: str | list[str] | None, +) -> set[tuple[int, int]]: + """Normalize ``skip_frame_range`` into a set of (min_frames, max_frames) buckets. + + Args: + skip_frame_range: A single bucket string like ``"300_400"``, a list of such + strings, or None. Each string identifies the frame-range bucket + (e.g. ``frames_300_400``) that should be skipped. + + Returns: + Set of (min_frames, max_frames) tuples to skip. Empty if ``skip_frame_range`` is None. + """ + if skip_frame_range is None: + return set() + + if isinstance(skip_frame_range, str): + skip_frame_range = [skip_frame_range] + + skip_buckets: set[tuple[int, int]] = set() + for bucket in skip_frame_range: + match = re.fullmatch(r"(\d+)_(\d+)", bucket.strip()) + if match is None: + raise ValueError( + f"Invalid skip_frame_range entry {bucket!r}. Expected the form '_', e.g. '300_400'." + ) + skip_buckets.add((int(match.group(1)), int(match.group(2)))) + + return skip_buckets + + def filter_wdinfos_by_frame_range( wdinfos: list[str], min_frames: int | None = None, max_frames: int | None = None, + skip_frame_range: str | list[str] | None = None, ) -> list[str]: """ Filter wdinfo files based on frame range. @@ -125,10 +155,16 @@ def filter_wdinfos_by_frame_range( - min_frames is EXCLUSIVE: wdinfo_max must be > min_frames - max_frames is INCLUSIVE: wdinfo_max must be <= max_frames + Additionally, any wdinfo whose frame-range bucket matches an entry in + ``skip_frame_range`` is excluded. + Args: wdinfos: List of wdinfo paths min_frames: Minimum number of frames (exclusive). If None, no lower bound. max_frames: Maximum number of frames (inclusive). If None, no upper bound. + skip_frame_range: Frame-range bucket(s) to exclude, e.g. ``"300_400"`` to + drop the ``frames_300_400`` bucket. Accepts a single string or a list + of strings. If None, no bucket is skipped. Returns: Filtered list of wdinfo paths @@ -144,8 +180,14 @@ def filter_wdinfos_by_frame_range( # frames_400_500 excluded because wdinfo_max (500) <= min_frames (500) # frames_500_600 included because wdinfo_max (600) > min_frames (500) AND <= max_frames (600) # frames_600_700 excluded because wdinfo_max (700) > max_frames (600) + + >>> filter_wdinfos_by_frame_range(wdinfos, skip_frame_range="500_600") + ['wdinfo/frames_400_500/wdinfo.json', 'wdinfo/frames_600_700/wdinfo.json'] + # frames_500_600 excluded because its bucket matches skip_frame_range """ - if min_frames is None and max_frames is None: + skip_buckets = _normalize_skip_frame_ranges(skip_frame_range) + + if min_frames is None and max_frames is None and not skip_buckets: return wdinfos filtered = [] @@ -158,6 +200,10 @@ def filter_wdinfos_by_frame_range( wdinfo_min, wdinfo_max = frame_range + # Skip explicitly excluded buckets (matched on the full (min, max) bucket). + if (wdinfo_min, wdinfo_max) in skip_buckets: + continue + # Filter based on wdinfo's upper bound (wdinfo_max): # - min_frames is exclusive: wdinfo_max must be > min_frames # - max_frames is inclusive: wdinfo_max must be <= max_frames diff --git a/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py new file mode 100644 index 0000000..12c9bcc --- /dev/null +++ b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +""" +Copied from projects/cosmos/reason1/datasets/video_decoder_qwen.py +Changes: +1: remove hardcoded hyper-parameters for Qwen, now read it from processor +2: support skipping smart resize, since it may resize the video frames to be smaller than model input and frames will get resized up later in processor +""" + +import random +import re +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Callable, Optional + +import torch +from PIL import Image +from qwen_vl_utils.vision_process import smart_nframes, smart_resize +from torchcodec.decoders import VideoDecoder +from torchvision import transforms +from torchvision.transforms import InterpolationMode + +from cosmos_framework.utils import log +from cosmos_framework.data.vfm.processors.qwen3vl_processor import Qwen3VLProcessor + +Image.MAX_IMAGE_PIXELS = 933120000 +_VIDEO_EXTENSIONS = "mp4 avi webm mov".split() + +VIDEO_DECODER_OPTIONS = {} + + +def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert token length to pixels based on patch size and temporal patch size. + + Args: + token_length: Token length + patch_size: Patch size + temporal_patch_size: Temporal patch size, + for Qwen it has 3D conv, temporal patch size is 2; for other models like internVL or eagle er, the temporal patch size is 1 since their VIT is image encoder; + merge_size: Merge size, or called pixel shuffing factor; + for Qwen and internVL it is 2; for eagle er it is 1; + """ + merged_patch_size = patch_size * merge_size + return token_length * merged_patch_size**2 * temporal_patch_size + + +def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert pixels to token length based on patch size and temporal patch size.""" + merged_patch_size = patch_size * merge_size + return pixels // merged_patch_size**2 // temporal_patch_size + + +def video_decoder_qwen( + num_threads: int = 0, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + random_augmentation: bool = False, + frame_count_random_range: Optional[list[int]] = None, + **kwargs, +) -> Callable: + """ + Sampling video frames similar to Qwen. It prioritizes matching the target FPS first and then resizing the video frames. + See https://github.com/kq-chen/qwen-vl-utils/blob/main/src/qwen_vl_utils/vision_process.py#L118 for more details. + + Args: + key: Video file name/key + data: Video binary data + min_fps_thres: Minimum FPS threshold + max_fps_thres: Maximum FPS threshold + target_fps: Target FPS + min_video_token_length: Minimum token length + max_video_token_length: Maximum token length + num_threads: Number of threads for the torchcodec video decoder + random_augmentation: Whether to randomize the FPS and max_video_token_length + frame_count_random_range: Random frame count range + + Returns: + dict with video frames tensor and target FPS + """ + + video_decoder_configured = partial( + _video_decoder_qwen_func, + min_fps_thres=min_fps_thres, + max_fps_thres=max_fps_thres, + num_threads=num_threads, + target_fps=target_fps, + min_video_token_length=min_video_token_length, + max_video_token_length=max_video_token_length, + random_augmentation=random_augmentation, + frame_count_random_range=frame_count_random_range, + ) + + return video_decoder_configured + + +def _video_decoder_qwen_func( + key: str, + data: bytes, + processor: Qwen3VLProcessor, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + num_threads: int = 0, + random_augmentation: bool = False, + fps_random_range: list[float] = [0.5, 1.5], + max_video_token_length_random_range: list[float] = [0.75, 1.25], + frame_count_random_range: Optional[list[int]] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + decoding_timeout: int = 60, + **kwargs, +) -> dict | None: + """Actual video decoder function. + + Args: + key (str): Video file name/key + data (bytes): Video binary data + min_fps_thres (int, optional): Minimum FPS threshold. Defaults to 4. + max_fps_thres (int, optional): Maximum FPS threshold. Defaults to 60. + target_fps (float, optional): Target FPS. Defaults to 2.0. + min_video_token_length (int, optional): Minimum token length. Defaults to 16. + max_video_token_length (int, optional): Maximum token length. Defaults to 8192. + num_threads (int, optional): Number of threads for the torchcodec video decoder. Defaults to 0. + random_augmentation (bool, optional): Whether to randomize the FPS and max_video_token_length. Defaults to False. + fps_random_range (list[float], optional): Random FPS range. Defaults to [10.0, 24.0]. + max_video_token_length_random_range (list[float], optional): Random max_video_token_length range. Defaults to [0.75, 1.25]. + frame_count_random_range (list[int], optional): Random frame count range. If provided, take priority over fps_random_range. + start_frame (Optional[int], optional): Start frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + end_frame (Optional[int], optional): End frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + decoding_timeout (int, optional): Timeout in seconds. Defaults to 60. + Raises: + ValueError: Video fps lower than 1, skipping + ValueError: Video fps lower than min_fps_thres, skipping + ValueError: Video fps higher than max_fps_thres, skipping + + Returns: + dict | None: Dictionary with video frames tensor and target FPS + """ + # Check video extension + extension = re.sub(r".*[.]", "", key) + if extension.lower() not in _VIDEO_EXTENSIONS: + return None + + # Read video with torchcodec + video_reader = VideoDecoder(data, num_ffmpeg_threads=num_threads) + total_frames = video_reader.metadata.num_frames + video_fps = video_reader.metadata.average_fps + + # torchcodec returns ``None`` for containers that don't store frame count + # or average fps (e.g. some MKV/WebM streams). Downstream arithmetic + # (``total_frames - 1``, ``video_fps < 1``, ...) would TypeError on None; + # surface a ValueError so the dataloader's skip path handles it uniformly. + if total_frames is None or video_fps is None: + raise ValueError(f"torchcodec missing metadata (num_frames={total_frames}, average_fps={video_fps}), skipping") + + if start_frame is not None and end_frame is not None: + total_frames = end_frame - start_frame + + if video_fps < 1: + raise ValueError("Video fps lower than 1, skipping") + if video_fps < min_fps_thres: + raise ValueError(f"Video fps {video_fps} lower than {min_fps_thres}, skipping") + if video_fps > max_fps_thres: + raise ValueError(f"Video fps {video_fps} higher than {max_fps_thres}, skipping") + + if random_augmentation: + if frame_count_random_range is not None: + # Random number of frames + min_frames_range, max_frames_range = frame_count_random_range + min_frames_range = min(min_frames_range, total_frames) + max_frames_range = min(max_frames_range, total_frames) + target_frames = random.uniform(min_frames_range, max_frames_range) + target_fps = target_frames / total_frames * video_fps + else: + # randomize fps + target_fps = ( + random.uniform(fps_random_range[0], fps_random_range[1]) * target_fps + if random.random() < 0.5 + else target_fps + ) + # randomize max_video_token_length + max_video_token_length = int( + random.uniform(max_video_token_length_random_range[0], max_video_token_length_random_range[1]) + * max_video_token_length + ) + log.debug(f"random_augmentation: max_video_token_length: {max_video_token_length}, target_fps: {target_fps}") + + patch_size = processor.patch_size + min_height_width = processor.min_height_width + temporal_patch_size = processor.temporal_patch_size + merge_size = processor.merge_size + min_pixels: int = token_to_pixels(min_video_token_length, patch_size, temporal_patch_size, merge_size) + max_pixels: int = token_to_pixels(max_video_token_length, patch_size, temporal_patch_size, merge_size) + max_frames: int = max_pixels // (min_height_width) ** 2 // temporal_patch_size + + # sample based on target fps + nframes = smart_nframes(dict(fps=target_fps), total_frames=total_frames, video_fps=video_fps) + nframes = min(nframes, max_frames) + if start_frame is not None and end_frame is not None: + idx = torch.linspace(start_frame, end_frame - 1, nframes).round().long().tolist() # [nframes] + else: + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() # [nframes] + + def _decode_video() -> torch.Tensor: + return video_reader.get_frames_at(indices=idx).data # [T, C, H, W] uint8 + + # Use ThreadPoolExecutor to run video decoding with a timeout. + # If the thread is stuck, abandon it immediately. + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit(_decode_video) + try: + video_frames = future.result(timeout=decoding_timeout) + executor.shutdown(wait=False) + except TimeoutError as e: + log.warning(f"[{key}] Video decoding timed out after {decoding_timeout} seconds") + executor.shutdown(wait=False) + return None + + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + + # recompute max_pixels based on number of sampled frames + nframes, _, height, width = video_frames.shape + max_pixels = max_pixels // nframes + if processor.use_smart_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_size * merge_size, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + log.debug( + f"resized_height: {resized_height}, resized_width: {resized_width} | original height: {height}, original width: {width}" + ) + video_frames = transforms.functional.resize( + video_frames, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() # [T,C,H,W] + video_frames = video_frames.permute(1, 0, 2, 3) # [C,T,H,W] + + return dict(videos=video_frames, fps=sample_fps) diff --git a/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py b/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py new file mode 100644 index 0000000..fd8406f --- /dev/null +++ b/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from qwen_vl_utils.vision_process import smart_resize +from transformers.models.auto.processing_auto import AutoProcessor + +from cosmos_framework.utils import log +from cosmos_framework.utils.vlm.pretrained_models_downloader import maybe_download_hf_model_from_s3 + + +def convert_string_content_to_list_content(messages: List[Dict]) -> List[Dict]: + """ + Convert the string content to a list of dicts. + """ + for message_id, message in enumerate(messages): + if isinstance(message["content"], str): + messages[message_id]["content"] = [{"type": "text", "text": message["content"]}] + return messages + + +def maybe_parse_video_content( + messages: List[Dict], +) -> tuple[int, Optional[list[float]], Optional[list[int]], Optional[list[list[int]]]]: + """ + Convert the string content to a list of dicts. + """ + num_video = 0 + video_fps = [] + video_total_num_frames = [] + video_frames_indices = [] + for message_id, message in enumerate(messages): + if isinstance(message["content"], list): + for sub_content in message["content"]: + if sub_content.get("type", "") == "video" and isinstance(sub_content["video"], list): + num_video += 1 + fps = sub_content.get("fps", None) + if fps is None: + log.critical( + f"fps is None for video {sub_content}. Better to set the fps explicitly", rank0_only=False + ) + video_fps.append(fps) + video_total_num_frames.append(len(sub_content["video"])) + video_frames_indices.append(list(range(video_total_num_frames[-1]))) + return num_video, video_fps, video_total_num_frames, video_frames_indices + + +class Nemotron3DenseVLProcessor: + # This is a wrapper around the AutoProcessor class to add some helper functions + def __init__( + self, + name="Qwen/Qwen3-VL-2B-Init", + credentials: str = "./credentials/s3_training.secret", + bucket: str = "bucket4", + cache_dir: str = None, + ): + self.name = name + if os.path.isdir(name): + model_name_or_path_local = name + else: + model_name_or_path_local = maybe_download_hf_model_from_s3( + name, credentials, bucket, include_model_weights=False + ) + + self.processor = AutoProcessor.from_pretrained(model_name_or_path_local, trust_remote_code=True) + log.info("Successfully loaded processor from local cache") + + if hasattr(self.processor, "image_token"): + self.image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token) + else: + self.image_token_id = None + if hasattr(self.processor, "video_token"): + self.video_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.video_token) + else: + self.video_token_id = None + self.eos_id = self.processor.tokenizer.eos_token_id + self.pad_id = self.processor.tokenizer.pad_token_id + self.vision_end_id = self.processor.tokenizer.convert_tokens_to_ids("") + + # Helper attributes for the dataloader video decoding function + self.shortest_edge = self.processor.image_processor.size["shortest_edge"] + self.min_height_width = int(np.sqrt(self.shortest_edge)) + self.patch_size = self.processor.video_processor.patch_size + self.temporal_patch_size = self.processor.video_processor.temporal_patch_size + self.merge_size = self.processor.video_processor.merge_size + self.use_smart_resize = True + if self.pad_id is None: + self.pad_id = self.eos_id + + def apply_chat_template( + self, + messages, + add_generation_prompt=False, + return_tensors="pt", + tokenize=True, + **kwargs, + ): + """ + Return: + inputs: dict + input_ids: torch.Tensor, shape: (N_token) + attention_mask: torch.Tensor, shape: (N_token) + texts: str, the raw text + image_sizes: torch.Tensor, shape (N_img, 2) + pixel_values: torch.Tensor, shape (N_img_patch, 3, 224, 224) + """ + + # messages = [msg for msg in messages if msg.get("role") != "system"] + assert tokenize, "tokenize must be True" + assert return_tensors == "pt", "return_tensors must be pt" + # Note: this tokenizer does not support "content": str, it always expect "content" entry to be a list of dicts + messages = convert_string_content_to_list_content(messages) + kwargs = {} + for message_id, message in enumerate(messages): + if isinstance(message["content"], list): + for sub_content in message["content"]: + if sub_content.get("type", "") == "image": + image = sub_content["image"] + max_pixels = sub_content.get("max_pixels", self.processor.image_processor.size["longest_edge"]) + min_pixels = sub_content.get("min_pixels", self.processor.image_processor.size["shortest_edge"]) + assert isinstance(image, Image.Image), ( + "image must be a url string for now, not support list of images for one content" + ) + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=32, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + sub_content["image"] = image + + num_video, video_fps, video_total_num_frames, video_frames_indices = maybe_parse_video_content(messages) + if num_video > 0: + # Here we add the args to avoid the error: + # File "/invalid_dir", line 321, in _decode_and_sample_videos + # raise ValueError( + # ValueError: Sampling frames from a list of images is not supported! Set `do_sample_frames=False`. + video_metadata = [ + dict(fps=fps, total_num_frames=total_num_frames, frames_indices=frames_indices) + for fps, total_num_frames, frames_indices in zip( + video_fps, video_total_num_frames, video_frames_indices + ) + ] + kwargs["videos_kwargs"] = { + "do_sample_frames": False, + "video_metadata": video_metadata[0] if num_video == 1 else video_metadata, + } + + inputs = self.processor.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + return_dict=True, + return_tensors=return_tensors, + # padding="max_length", + # max_length=16000, + # truncation=False, + **kwargs, + ) + + # Convert batch features into single features + # By default, the processor returns a batch of features, but we use processor in dataloader, so we need to convert it to single features + inputs["input_ids"] = inputs["input_ids"][0] # [N_token] + inputs["attention_mask"] = inputs["attention_mask"][0] # [N_token] + num_image_tokens = inputs["input_ids"] == self.image_token_id # [N_token] + num_video_tokens = inputs["input_ids"] == self.video_token_id # [N_token] + return inputs + + def add_assistant_tokens_mask(self, tokens): + """ + Add a mask to the assistant tokens. + This is used to mask out tokens that are not generated by the assistant (e.g., system prompts, user prompts, chat templates), such that in the loss computation, only the tokens generated by the assistant are used. + If there are multiple turns in the conversation, the mask will mask all the assistant tokens in each turn. + + Args: + tokens (Union[List[int], torch.Tensor]): The tokens to add the mask to. + Returns: + Union[List[bool], torch.Tensor]: The mask. True for tokens generated by the assistant (i.e. should apply loss on), False for tokens not generated by the assistant. + """ + if isinstance(tokens, torch.Tensor) and tokens.ndim == 2: + mask = torch.stack( + [self.add_assistant_tokens_mask(tokens[i]) for i in range(tokens.shape[0])] + ) # [B,N_token] + assert mask.shape == tokens.shape + return mask + np_tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.array(tokens) + assert np_tokens.ndim == 1 + + # Constants defining bos, eos and fixed offsets. + BOS_TOKEN = "<|im_start|>" + EOS_TOKEN = "<|im_end|>" + ROLE = "assistant" + # Offsets: skip the bos + "assistant\n" (always 3 tokens) and include the eos (+1) for supervision + START_OFFSET = 3 + END_OFFSET = 1 + + # Retrieve token IDs for the markers and the role. + bos_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOS_TOKEN) + eos_token_id = self.processor.tokenizer.convert_tokens_to_ids(EOS_TOKEN) + role_id = self.processor.tokenizer.convert_tokens_to_ids(ROLE) + role_ids = self.processor.tokenizer.encode( + ROLE, add_special_tokens=False + ) # In case the role_id corresponds to multiple tokens, decode it back to string for accurate comparison + think_start_id = self.processor.tokenizer.convert_tokens_to_ids("") + think_end_id = self.processor.tokenizer.convert_tokens_to_ids("") + + # Locate all positions where the start and end markers appear. + start_indices = np.where(np_tokens == bos_token_id)[0] + end_indices = np.where(np_tokens == eos_token_id)[0] + + # Initialize the mask with False values. + masks = np.zeros_like(np_tokens, dtype=bool) + assert len(start_indices) == len(end_indices) + # For each pair of bos/eos, check if the role is 'assistant' + # and apply the mask accordingly. + for start, end in zip(start_indices, end_indices): + end_pos = None + if np_tokens[start + 1] == role_id: + # Mask tokens from after the assistant header (start+3) to include the end marker (end+1) + masks[start + START_OFFSET : end + END_OFFSET] = True + end_pos = start + START_OFFSET + elif all(np_tokens[start + 1 : start + 1 + len(role_ids)] == role_ids): + masks[start + START_OFFSET + len(role_ids) - 1 : end + END_OFFSET] = True + end_pos = start + START_OFFSET + len(role_ids) - 1 + if end_pos is not None and np_tokens[end_pos] == think_start_id: + masks[end_pos] = False + if np_tokens[end_pos + 1] == think_end_id: + masks[end_pos + 1] = False + + assert masks.shape == np_tokens.shape + if isinstance(tokens, torch.Tensor): + return torch.from_numpy(masks) + else: + return masks.tolist() + + def encode(self, *args, **kwargs): + return self.processor.encode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.processor.decode(*args, **kwargs) diff --git a/cosmos_framework/data/vlm/processors/nemotronvl_processor.py b/cosmos_framework/data/vlm/processors/nemotronvl_processor.py new file mode 100644 index 0000000..1fde099 --- /dev/null +++ b/cosmos_framework/data/vlm/processors/nemotronvl_processor.py @@ -0,0 +1,553 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from transformers.models.auto.processing_auto import AutoProcessor +from transformers.processing_utils import VideosKwargs +from transformers.video_utils import VideoMetadata + +from cosmos_framework.utils import log +from cosmos_framework.utils.vlm.pretrained_models_downloader import maybe_download_hf_model_from_s3 + +nemotron_chat_template = """ +{%- set ns = namespace(enable_thinking=false, has_sys_prompt=false, non_tool_system_content='', has_video=false, explicit_think_requested=false) -%} +{%- set msg = namespace(content='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.has_sys_prompt = true -%} + {# Extract system content without tool flags #} + {%- if message['content'] is string -%} + {%- set ns.non_tool_system_content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set ns.non_tool_system_content = '' -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content + content['text'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '') -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content.strip() -%} + {%- endif -%} + {%- endif -%} + {# Check for video content in all messages #} + {%- if message['content'] is not string -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'video' or content['type'] == 'video_url' -%} + {%- set ns.has_video = true -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- if message['content'] is string -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in message['content'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in message['content'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- else -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in content['text'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in content['text'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endfor -%} + +{{- bos_token -}} +{%- if messages[0]['role'] != 'system' -%} + {{- 'System\n' -}} +{%- else -%} + {{- 'System\n' + ns.non_tool_system_content }} +{%- endif -%} + +{%- if tools -%} + {%- if ns.non_tool_system_content != '' -%} + {{- '\n\n' -}} + {%- endif -%} + {{- 'You can use the following tools to assist the user if required:\n' -}} + {{- '[' -}} + {%- for tool in tools -%} + {{- (tool.function if tool.function is defined else tool) | tojson -}} + {{- ', ' if not loop.last else '' -}} + {%- endfor -%} + {{- ']\n\n' -}} + + {{- 'If you decide to call any tool(s), use the following format:\n' -}} + {{- '[{"name": "tool_name1", "arguments": "tool_args1"}, ' -}} + {{- '{"name": "tool_name2", "arguments": "tool_args2"}]\n\n' -}} + + {{- 'The user will execute tool-calls and return responses from tool(s) in this format:\n' -}} + {{- '[{"response": "tool_response1"}, ' -}} + {{- '{"response": "tool_response2"}]\n\n' -}} + + {{- 'Based on the tool responses, you can call additional tools if needed, ' -}} + {{- 'correct tool calls if any errors are found, or just respond to the user.' -}} +{%- endif -%} +{{- '\n' -}} + +{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%} + +{# Prevent no user or assistant message #} +{%- if messages|length == 0 -%} + {%- set messages = [{'role': 'user', 'content': ''}] -%} +{%- endif -%} + +{%- for message in messages %} + {%- if message['content'] is string -%} + {%- set msg.content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set msg.content = '' -%} + {%- set mm_content = '' -%} + {%- set counters = namespace(images=0, videos=0) -%} + + {%- for content in message['content'] -%} + {%- if content['type'] == 'image' -%} + {%- set counters.images = counters.images + 1 -%} + {%- elif content['type'] == 'video' -%} + {%- set counters.videos = counters.videos + 1 -%} + {%- elif content['type'] == 'text' -%} + {%- set msg.content = msg.content + content['text'] -%} + {%- endif -%} + {%- endfor -%} + {%- if '' in msg.content -%} + {%- set counters.images = 0 -%} + {%- endif -%} + {%- if '