diff --git a/cosmos_framework/auxiliary/guardrail/common/presets.py b/cosmos_framework/auxiliary/guardrail/common/presets.py index d320b5e..20bba86 100644 --- a/cosmos_framework/auxiliary/guardrail/common/presets.py +++ b/cosmos_framework/auxiliary/guardrail/common/presets.py @@ -26,9 +26,7 @@ def create_text_guardrail_runner(offload_model_to_cpu: bool = False) -> Guardrai def create_video_guardrail_runner(offload_model_to_cpu: bool = False) -> GuardrailRunner: """Create the video guardrail runner.""" return GuardrailRunner( - safety_models=[ - # VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu), # Too many false positives - ], + safety_models=[VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu)], postprocessors=[RetinaFaceFilter(offload_model_to_cpu=offload_model_to_cpu)], ) diff --git a/cosmos_framework/callbacks/compile_tokenizer.py b/cosmos_framework/callbacks/compile_tokenizer.py index 84efa16..3ee15d2 100644 --- a/cosmos_framework/callbacks/compile_tokenizer.py +++ b/cosmos_framework/callbacks/compile_tokenizer.py @@ -4,7 +4,7 @@ """Training callback that defers AOT compilation of the VAE tokenizer. The actual compilation logic lives in -:meth:`~projects.cosmos3.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. +:meth:`~cosmos_framework.model.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. This module provides a :class:`CompileTokenizer` callback that invokes it at the right point during training (after ``compile_after_iterations`` steps, to avoid NCCL timeouts during CUDA/cuDNN warm-up). 
@@ -21,6 +21,7 @@ """ from collections.abc import Sequence +from typing import Literal import torch @@ -43,6 +44,10 @@ def __init__( enabled: bool = False, compile_after_iterations: int = 3, warmup_resolutions: Sequence[str] | None = None, + backend: Literal["cudagraphs", "inductor"] = "inductor", + mode: Literal["reduce-overhead", "max-autotune"] | None = "reduce-overhead", + fullgraph: bool = False, + dynamic: bool = False, ): """ Args: @@ -60,6 +65,10 @@ def __init__( self.compile_after_iterations: int = compile_after_iterations self.skip_counter: int = 0 self.warmup_resolutions: Sequence[str] | None = warmup_resolutions + self.backend: Literal["cudagraphs", "inductor"] = backend + self.mode: Literal["reduce-overhead", "max-autotune"] | None = mode + self.fullgraph: bool = fullgraph + self.dynamic: bool = dynamic if self.enabled: if self.warmup_resolutions is None: @@ -101,6 +110,10 @@ def on_training_step_start( tokenizer.compile_encode( self.warmup_resolutions, output_dir=self.config.job.path_local, + backend=self.backend, + mode=self.mode, + fullgraph=self.fullgraph, + dynamic=self.dynamic, ) self.skip_counter += 1 diff --git a/cosmos_framework/callbacks/data_stats.py b/cosmos_framework/callbacks/data_stats.py index 20e3161..2914981 100644 --- a/cosmos_framework/callbacks/data_stats.py +++ b/cosmos_framework/callbacks/data_stats.py @@ -51,7 +51,6 @@ def on_training_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/callbacks/dataloader_state.py b/cosmos_framework/callbacks/dataloader_state.py index fee45b9..bee9e43 100644 --- a/cosmos_framework/callbacks/dataloader_state.py +++ b/cosmos_framework/callbacks/dataloader_state.py @@ -26,18 +26,14 @@ class DataLoaderStateCallback(Callback): def __init__( self, distributor_type: str | None = None, - name: str = "", ) -> None: 
super().__init__() self.distributor_type = distributor_type - self.name = name self.config: Any = None self.state: dict[int, NoReplaceShardlistState] = {} self.verbose = True def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: - if "sample_worker_id" not in data_batch: - return # batch has no position metadata (shuffle=False or iterable data_source) worker_ids = data_batch["sample_worker_id"].tolist() # [B] epochs = data_batch["sample_epoch"].tolist() # [B] indices = data_batch["sample_index"].tolist() # [B] @@ -50,8 +46,6 @@ def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: ): self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - _ACTIVE_DISTRIBUTOR_TYPES = ("no_replace", "data_packer") - def on_training_step_batch_end( self, model: ImaginaireModel, @@ -60,7 +54,7 @@ def on_training_step_batch_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": self._update_state_from_batch(data_batch) def on_training_step_end( @@ -71,7 +65,7 @@ def on_training_step_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": if self.verbose: if iteration % self.config.trainer.logging_iter == 0: msg = "\n" @@ -80,10 +74,10 @@ def on_training_step_end( log.info(msg) def has_checkpoint_state(self) -> bool: - return self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES + return self.distributor_type == "no_replace" def state_dict(self) -> dict[int, dict[str, int]]: - if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return {} state_dict: dict[int, dict[str, int]] = {} @@ -96,7 +90,7 @@ def state_dict(self) -> dict[int, dict[str, int]]: return state_dict def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: - if 
self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return if not state_dict: @@ -104,114 +98,10 @@ def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: return self.state = {} - # Build env var prefix. For data_packer, namespacing avoids conflicts - # when multiple DataPackerDataLoader instances share the same process - # (e.g. inside JointDataPackerDataLoader). name="" → original format. - _dp_pfx = f"DP_STATE_{self.name}_" if self.name else "DP_STATE_" for worker_id, per_worker_state in state_dict.items(): epoch = per_worker_state["epoch"] index = per_worker_state["index"] self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - if self.distributor_type == "data_packer": - os.environ[f"{_dp_pfx}WORKER_{worker_id}_EPOCH"] = str(epoch) - os.environ[f"{_dp_pfx}WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded data_packer dataloader state for worker {worker_id}: epoch={epoch}, index={index}") - else: - os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) - os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") - - -class JointDataLoaderStateCallback(Callback): - """Checkpoint/resume state for ``JointDataPackerDataLoader``. - - Manages two levels of state in a single DCP checkpoint entry - (``checkpoint_component = "dataloader"``): - - 1. **Outer** ``global_id`` — the number of batches the outer loader has - yielded. Restored via ``outer_loader.set_start_iteration(global_id)`` - so the deterministic dataset-selection sequence resumes from the correct - step. - - 2. **Inner** per-dataset, per-worker ``(epoch, index)`` — one - ``DataLoaderStateCallback`` per inner loader, keyed by the dataset name. - Each inner callback sets namespaced env vars on ``load_state_dict`` so - workers fast-forward to the saved sample position. 
- - Usage in experiment configs:: - - joint_loader = JointDataPackerDataLoader(dataloaders={...}, seed=42) - exp["dataloader_train"] = joint_loader - exp["trainer"]["callbacks"]["dataloader_state"] = JointDataLoaderStateCallback( - outer_loader=joint_loader, - distributor_type="data_packer", - ) - - The ``checkpoint_component = "dataloader"`` class attribute ensures the DCP - checkpointer's ``_DataloaderWrapper`` discovers exactly this callback (it - picks the first matching callback). Do **not** also register standalone - ``DataLoaderStateCallback`` instances for the inner loaders — this class - already handles them all. - """ - - checkpoint_component: str = "dataloader" - - def __init__( - self, - outer_loader: Any, - distributor_type: str = "data_packer", - ) -> None: - super().__init__() - self._outer = outer_loader - self._inner: dict[str, DataLoaderStateCallback] = { - name: DataLoaderStateCallback(distributor_type=distributor_type, name=name) - for name in outer_loader._names - } - self.config: Any = None - - def _update_state_from_batch(self, batch: dict) -> None: - name = batch.get("dataset_name") - if name in self._inner: - self._inner[name]._update_state_from_batch(batch) - - def on_training_step_batch_end( - self, - model: Any, - data_batch: dict, - output_batch: dict, - loss: Any, - iteration: int = 0, - ) -> None: - self._update_state_from_batch(data_batch) - - def on_training_step_end( - self, - model: Any, - data_batch: dict, - output_batch: dict, - loss: Any, - iteration: int = 0, - ) -> None: - if self.config and iteration % self.config.trainer.logging_iter == 0: - msg = f"\nJointDataPackerDataLoader global_id={self._outer._global_id}\n" - for name, cb in self._inner.items(): - for wid, state in cb.state.items(): - msg += f" [{name}] worker {wid}: epoch={state.epoch}, index={state.index}\n" - log.info(msg) - - def has_checkpoint_state(self) -> bool: - return True - - def state_dict(self) -> dict: - return { - "global_id": self._outer._global_id, - 
**{name: cb.state_dict() for name, cb in self._inner.items()}, - } - - def load_state_dict(self, state: dict) -> None: - global_id = state.get("global_id", 0) - self._outer.set_start_iteration(global_id) - log.info(f"JointDataLoaderStateCallback: resumed outer global_id={global_id}") - for name, cb in self._inner.items(): - if name in state: - cb.load_state_dict(state[name]) + os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) + os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) + log.info(f"Loaded no replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") diff --git a/cosmos_framework/callbacks/every_n_draw_sample.py b/cosmos_framework/callbacks/every_n_draw_sample.py index baf1ffc..9aa96fa 100644 --- a/cosmos_framework/callbacks/every_n_draw_sample.py +++ b/cosmos_framework/callbacks/every_n_draw_sample.py @@ -154,8 +154,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): tag = "ema" if self.is_ema else "reg" log.debug("starting data and condition model", rank0_only=False) - - data_clean = model.get_data_and_condition(data_batch) raw_data = data_clean.raw_state_vision x0 = data_clean.x0_tokens_vision @@ -185,7 +183,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): log.debug(f"done denoising {sigma}", rank0_only=False) mse_loss = distributed.dist_reduce_tensor(F.mse_loss(sample, x0)) mse_loss_list.append(mse_loss) - if hasattr(model, "decode"): sample = model.decode(sample) to_show.append(sample.float().cpu()) @@ -316,7 +313,6 @@ def sample(self, trainer, model, data_batch, output_batch, loss, iteration): for sample_idx in range(data_clean.batch_size): n_vis = num_items[sample_idx] # First item(s) are condition, last item is generation target - # but we need to support multiple conditions per sample in the future. Current code # can handle this without throwing an error. 
condition_images.append(raw_data[vis_offset]) # source image (1, C, 1, H, W) diff --git a/cosmos_framework/callbacks/grad_clip.py b/cosmos_framework/callbacks/grad_clip.py index 151c1bc..f3cb4fa 100644 --- a/cosmos_framework/callbacks/grad_clip.py +++ b/cosmos_framework/callbacks/grad_clip.py @@ -132,7 +132,7 @@ def _clip_grad( # `torch.distributed._tensor.ops.math_ops._NormPartial`. # We can simply reduce the DTensor to get the total norm in this # tensor's process group and then convert it to a local tensor. - + # NOTE: It has two purposes: # 1. to make sure the total norm is computed correctly when PP is used (see below) # 2. to return a reduced mesh_norm tensor whose .item() would return the correct value if isinstance(mesh_norm, DTensor): diff --git a/cosmos_framework/callbacks/hf_export.py b/cosmos_framework/callbacks/hf_export.py index 8bcba5a..07c324f 100644 --- a/cosmos_framework/callbacks/hf_export.py +++ b/cosmos_framework/callbacks/hf_export.py @@ -137,11 +137,11 @@ def on_save_checkpoint(self, model: Any, state_dict: dict[str, Any]) -> None: if not isinstance(model, VLMModel): # The legacy vlm/train.py path passes model_parts: list[nn.Module] (raw HF # models without the VLMModel attribute structure). HF export requires the - # VLMModel wrapper, which is only available via the unified cosmos_framework/scripts/train.py path. + # VLMModel wrapper, which is only available via the unified scripts/train.py path. if isinstance(model, list): log.warning( "[HFExportCallback] Received model_parts (list) instead of VLMModel. " - "HF export requires the unified training path (cosmos_framework/scripts/train.py). Skipping." + "HF export requires the unified training path (scripts/train.py). Skipping." 
) else: log.warning( diff --git a/cosmos_framework/callbacks/mfu.py b/cosmos_framework/callbacks/mfu.py index 3035437..4a6c792 100644 --- a/cosmos_framework/callbacks/mfu.py +++ b/cosmos_framework/callbacks/mfu.py @@ -138,7 +138,6 @@ def _ensure_initialised(self, model: ImaginaireModel) -> None: ac_cfg = getattr(model_cfg, "activation_checkpointing", None) ac_mode = getattr(ac_cfg, "mode", "none") - # Some activations don't need to be recomputed under selective AC, so # we need to remove them from the FLOP computation. self._use_activation_checkpointing = ac_mode != "none" diff --git a/cosmos_framework/callbacks/wandb_log_eval.py b/cosmos_framework/callbacks/wandb_log_eval.py index ac6911f..abea93f 100644 --- a/cosmos_framework/callbacks/wandb_log_eval.py +++ b/cosmos_framework/callbacks/wandb_log_eval.py @@ -88,7 +88,6 @@ def on_validation_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/checkpoint/s3_filesystem.py b/cosmos_framework/checkpoint/s3_filesystem.py index e47219e..8975ce1 100644 --- a/cosmos_framework/checkpoint/s3_filesystem.py +++ b/cosmos_framework/checkpoint/s3_filesystem.py @@ -285,7 +285,7 @@ def __init__( """ super().__init__( path=path, - sync_files=False, + sync_files=False, # FIXME: setting this to True makes the run to fail (L#333: `os.fsync(stream.fileno())`) **kwargs, ) self.fs = S3FileSystem(credential_path, enable_gcs_patch_in_boto3=enable_gcs_patch_in_boto3) # type: ignore diff --git a/cosmos_framework/configs/base/base_config_test.py b/cosmos_framework/configs/base/base_config_test.py index 093e17b..2ea4319 100644 --- a/cosmos_framework/configs/base/base_config_test.py +++ b/cosmos_framework/configs/base/base_config_test.py @@ -29,7 +29,7 @@ def test_config_init_experiment_mot(experiment_name): Parameterized test to verify config initialization for 
multiple experiments. PYTHONPATH=. torchrun --nproc_per_node=8 -m pytest -s cosmos_framework/configs/base/config_test_mot.py --L1 """ - config_file = "configs/base/config.py" + config_file = "cosmos_framework/configs/base/config.py" config_module = get_config_module(config_file) config = importlib.import_module(config_module).make_config() config = override( diff --git a/cosmos_framework/configs/base/defaults/callbacks.py b/cosmos_framework/configs/base/defaults/callbacks.py index 46c85ed..805ffb5 100644 --- a/cosmos_framework/configs/base/defaults/callbacks.py +++ b/cosmos_framework/configs/base/defaults/callbacks.py @@ -10,7 +10,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer - from cosmos_framework.callbacks.device_monitor import DeviceMonitor from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap diff --git a/cosmos_framework/configs/base/defaults/cluster.py b/cosmos_framework/configs/base/defaults/cluster.py index 23b49dd..46450b9 100644 --- a/cosmos_framework/configs/base/defaults/cluster.py +++ b/cosmos_framework/configs/base/defaults/cluster.py @@ -23,14 +23,24 @@ class ClusterConfig: DefaultClusterConfig: ClusterConfig = ClusterConfig( object_store_bucket_data="", - object_store_bucket_checkpoint="bucket-checkpoint", - object_store_bucket_pretrained="bucket-pretrained", - object_store_credential_data="credentials/data.secret", - object_store_credential_checkpoint="credentials/checkpoint.secret", - object_store_credential_pretrained="credentials/pretrained.secret", + object_store_bucket_checkpoint="bucket4", + object_store_bucket_pretrained="bucket4", + object_store_credential_data="credentials/s3_training.secret", + object_store_credential_checkpoint="credentials/s3_checkpoint.secret", + 
object_store_credential_pretrained="credentials/s3_checkpoint.secret", +) + +DefaultClusterConfig: ClusterConfig = ClusterConfig( + object_store_bucket_data="", + object_store_bucket_checkpoint="bucket1", + object_store_bucket_pretrained="bucket0", + object_store_credential_data="credentials/gcp_checkpoint.secret", + object_store_credential_checkpoint="credentials/gcp_training.secret", + object_store_credential_pretrained="credentials/gcp_training.secret", ) def register_cluster(): cs = ConfigStore.instance() - cs.store(group="cluster", package="job.cluster", name="default", node=DefaultClusterConfig) + cs.store(group="cluster", package="job.cluster", name="aws_iad_h100", node=DefaultClusterConfig) + cs.store(group="cluster", package="job.cluster", name="gcp_iad_gb200", node=DefaultClusterConfig) diff --git a/cosmos_framework/configs/base/defaults/compile.py b/cosmos_framework/configs/base/defaults/compile.py index b0e1c88..3d5ebf7 100644 --- a/cosmos_framework/configs/base/defaults/compile.py +++ b/cosmos_framework/configs/base/defaults/compile.py @@ -24,7 +24,7 @@ class CompileConfig: # (maps to ``torch.compile(dynamic=...)``). Defaults to True for training, # which sees varying shapes across batches (sequence length, CP sharding, ...); # specializing would recompile continuously. See ParallelismOverrides in - # cosmos_framework/inference/common/args.py for the inference-side rationale + # packages/cosmos3/cosmos3/common/args.py for the inference-side rationale # (where dynamic=False is preferred for stable AR shapes). compile_dynamic: bool = True diff --git a/cosmos_framework/configs/base/defaults/multiview_dataloader.py b/cosmos_framework/configs/base/defaults/multiview_dataloader.py deleted file mode 100644 index a646ac6..0000000 --- a/cosmos_framework/configs/base/defaults/multiview_dataloader.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: OpenMDW-1.1 - -""" -Hydra ConfigStore registration for multiview dataloaders. - -Registers named dataloader configs that can be referenced via Hydra overrides -(e.g. ``{override /data_train: video_control_mads_multiview_0823_gcs_720p_10fps_93frames_7views}``) -or used as templates for inline ``L(get_multiview_video_loader)(...)`` in -experiment configs. - -Two naming conventions: - - **Transfer** (with control signal): - ``video_control_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` - - **Predict** (no control signal): - ``video_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` -""" - -from hydra.core.config_store import ConfigStore - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.data.vfm.multiview.multiview_data_source import ( - DEFAULT_CAMERAS, - INDEX_TO_CAMERA_MAPPING, - TRANSFER_CAPTION_KEY_MAPPING, - TRANSFER_CONTROL_KEY_MAPPING, - TRANSFER_VIDEO_KEY_MAPPING, -) -from cosmos_framework.data.vfm.multiview.multiview_dataset import ( - MultiviewAugmentationConfig, - get_multiview_video_loader, -) - -# --------------------------------------------------------------------------- -# Camera view subsets -# --------------------------------------------------------------------------- - -CAMERA_VIEW_CONFIGS: dict[str, tuple[str, ...]] = { - "7views": DEFAULT_CAMERAS, - "1view_front": ("camera_front_wide_120fov",), - "4views": ( - "camera_front_wide_120fov", - "camera_cross_right_120fov", - "camera_rear_tele_30fov", - "camera_cross_left_120fov", - ), -} - -# --------------------------------------------------------------------------- -# Grid dimensions -# --------------------------------------------------------------------------- - -_TRANSFER_DATASETS = ["mads_multiview_0823"] -_OBJECT_STORES = ["gcs"] - -_RESOLUTIONS: list[tuple[str, tuple[int, int]]] = [ - ("720p", (720, 1280)), -] - -_FPS: list[tuple[str, int]] = [ - ("10fps", 1), # MADS transfer data is already at 10 fps -] - -_NUM_VIDEO_FRAMES: 
list[tuple[str, int]] = [ - ("29frames", 29), - ("61frames", 61), - ("93frames", 93), -] - - -def register_multiview_dataloaders() -> None: - """Register all multiview dataloader configs with Hydra ConfigStore.""" - - cs = ConfigStore.instance() - - # ----- Transfer dataloaders (with control signals) ----- - for dataset in _TRANSFER_DATASETS: - for object_store in _OBJECT_STORES: - for resolution_str, resolution_hw in _RESOLUTIONS: - for fps_str, downsample_factor in _FPS: - for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - name = ( - f"video_control_{dataset}_{object_store}_{resolution_str}_" - f"{fps_str}_{num_frames_str}_{views_str}" - ) - cs.store( - group="data_train", - package="dataloader_train", - name=name, - node=L(get_multiview_video_loader)( - dataset_name=dataset, - is_train=True, - augmentation_config=L(MultiviewAugmentationConfig)( - resolution_hw=resolution_hw, - fps_downsample_factor=downsample_factor, - num_video_frames=num_frames, - camera_keys=camera_keys, - camera_video_key_mapping=TRANSFER_VIDEO_KEY_MAPPING, - camera_caption_key_mapping=TRANSFER_CAPTION_KEY_MAPPING, - camera_control_key_mapping=TRANSFER_CONTROL_KEY_MAPPING, - position_to_camera_mapping=INDEX_TO_CAMERA_MAPPING, - single_caption_camera_name="camera_front_wide_120fov", - ), - ), - ) - - # ----- Predict dataloaders (no control signals, for future use) ----- - # These use named keys (video_camera_front_wide_120fov, etc.) and need - # different datasets (e.g. alpamayo_dec2024) with 30 fps native data. - # Uncomment and add predict datasets to the catalog when needed. 
- # - # _PREDICT_DATASETS = ["alpamayo_dec2024"] - # _PREDICT_FPS = [("10fps", 3), ("15fps", 2)] # 30 fps native → downsample - # for dataset in _PREDICT_DATASETS: - # for object_store in _OBJECT_STORES: - # for resolution_str, resolution_hw in _RESOLUTIONS: - # for fps_str, downsample_factor in _PREDICT_FPS: - # for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - # for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - # name = ( - # f"video_{dataset}_{object_store}_{resolution_str}_" - # f"{fps_str}_{num_frames_str}_{views_str}" - # ) - # cs.store( - # group="data_train", - # package="dataloader_train", - # name=name, - # node=L(get_multiview_video_loader)( - # dataset_name=dataset, - # is_train=True, - # augmentation_config=L(MultiviewAugmentationConfig)( - # resolution_hw=resolution_hw, - # fps_downsample_factor=downsample_factor, - # num_video_frames=num_frames, - # camera_keys=camera_keys, - # camera_video_key_mapping=PREDICT_VIDEO_KEY_MAPPING, - # camera_caption_key_mapping=PREDICT_CAPTION_KEY_MAPPING, - # camera_control_key_mapping=None, - # position_to_camera_mapping=None, - # single_caption_camera_name=None, - # ), - # ), - # ) - - -# Auto-register on import -register_multiview_dataloaders() diff --git a/cosmos_framework/configs/base/defaults/tokenizer.py b/cosmos_framework/configs/base/defaults/tokenizer.py index 55cb01c..237dba6 100644 --- a/cosmos_framework/configs/base/defaults/tokenizer.py +++ b/cosmos_framework/configs/base/defaults/tokenizer.py @@ -17,10 +17,12 @@ PRETRAINED_TOKENIZER_FLUX_VAE_PTH = "pretrained/tokenizers/image/flux/ae.safetensors" # UniAE checkpoint paths -PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH = "pretrained/tokenizers/video/cosmos/uniae4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1.pt" 
+PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T32TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_QWEN0P6B_VAE_PTH = ( + "s3://bucket1/uniae/tok_experiments/" + "uniae_s3_prod32_ditval_video_b1_50k_r1/checkpoints/iter_000050000.pt" +) # DCAE checkpoint paths -PRETRAINED_TOKENIZER_DCAE_PTH = "pretrained/tokenizers/video/cosmos/dc-ae-v-1.0-f32t4c64-cosmos-encoder-causal-decoder-chunk-causal-4-frame-120-pad-7-no-gan.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" @@ -44,6 +46,7 @@ chunk_duration=1, spatial_compression_factor=8, temporal_compression_factor=1, + causal=True, ) Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)( @@ -53,6 +56,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT1_VAE_PTH, spatial_compression_factor=8, temporal_compression_factor=4, + causal=True, ) Wan2pt2VAEConfig: LazyDict = L(Wan2pt2VAEInterface)( @@ -61,14 +65,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT2_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, -) - -DCAE4x32x32Config: LazyDict = L(DCAE4x32x32Interface)( - bucket_name=PLACEHOLDER, - object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_DCAE_PTH, - spatial_compression_factor=32, - temporal_compression_factor=4, + causal=True, ) 
DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -80,6 +77,7 @@ model_name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", spatial_compression_factor=32, temporal_compression_factor=4, + causal=True, ) DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -104,12 +102,17 @@ temporal_compression_factor=4, ) -UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config: LazyDict = L(UniAEVAEInterface)( + +UniAE4x16x16C48T32to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Qwen0p6bVAEConfig: LazyDict = L( + UniAEVAEInterface +)( bucket_name=PLACEHOLDER, object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH, + vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T32TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_QWEN0P6B_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, + pixel_trim=True, + causal=False, ) # ============================================================================= @@ -173,8 +176,8 @@ def register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="uniae_4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1_tokenizer", - node=UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config, + name="uniae_4x16x16_c48_t32to160_mixp_fps_mix_encoder_noncausal_decoder_noncausal_nogan_s3_qwen0p6b_tokenizer", + node=UniAE4x16x16C48T32to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Qwen0p6bVAEConfig, ) # Flux tokenizer cs.store(group="tokenizer", package="model.config.tokenizer", name="flux_tokenizer", node=FluxVAEConfig) @@ -182,25 +185,19 @@ def register_tokenizer(): cs.store( group="tokenizer", 
package="model.config.tokenizer", - name="dc_ae_4x32x32_tokenizer", - node=DCAE4x32x32Config, - ) - cs.store( - group="tokenizer", - package="model.config.tokenizer", - name="dc_ae_4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) diff --git a/cosmos_framework/configs/base/defaults/unittest.py b/cosmos_framework/configs/base/defaults/unittest.py index 5af5bb2..7e1b8c1 100644 --- a/cosmos_framework/configs/base/defaults/unittest.py +++ b/cosmos_framework/configs/base/defaults/unittest.py @@ -2,8 +2,12 @@ # SPDX-License-Identifier: OpenMDW-1.1 import attrs +# from cosmos_framework.configs.base.defaults.cluster import DefaultClusterConfig + # We are hardcoding the unittest assets in this file. 
+# CLUSTER_CONFIG = DefaultClusterConfig + # add codeowner for cosmos_framework/model/vfm/tokenizers diff --git a/cosmos_framework/configs/base/defaults/vlm.py b/cosmos_framework/configs/base/defaults/vlm.py index fa4d7c8..e031d33 100644 --- a/cosmos_framework/configs/base/defaults/vlm.py +++ b/cosmos_framework/configs/base/defaults/vlm.py @@ -95,11 +95,7 @@ def download_tokenizer_files(model_name: str, config_variant: str) -> str: return destination_dir -def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str, **_unused_kwargs): - # **_unused_kwargs absorbs extras (e.g. tokenizer_type) that OmegaConf - # merges in from a vlm_config preset's tokenizer block when an experiment - # overrides the tokenizer with this function but doesn't fully replace - # the preset's kwarg dict. +def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str): destination_dir = download_tokenizer_files(pretrained_model_name, config_variant) return LLMTokenizerProcessor(Qwen2Tokenizer.from_pretrained(destination_dir)) @@ -140,9 +136,6 @@ class VLMConfig: # HuggingFace model identifier or local path. Drives AutoConfig + AutoModel selection. model_name: str = "" - # Safetensor path for model - safetensors_path: str = "" - # Optional pretrained-weights overlay (separate from the AutoModel structural # init driven by model_name). 
pretrained_weights: PretrainedWeightsConfig = PretrainedWeightsConfig() @@ -285,29 +278,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_30b_a3b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-30B-A3B-Private", - model_instance=L(Qwen3VLMoeTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoeMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl_moe/configs/Qwen3-VL-30B-A3B-Instruct.json" - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-30B-A3B-Instruct", - config_variant="gcp", - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-30B-A3B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 235B A22B Instruct model # Qwen3VLMoE uses Qwen2Tokenizer Qwen3VLMoT_VLM_235b_a22b_Instruct_GCP_Config: VLMConfig = VLMConfig( @@ -458,48 +428,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-2B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - -Cosmos3Reasoner_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - 
model_name="nvidia/Cosmos3-Reasoner-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-2B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 4B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_4b_Instruct_Config: VLMConfig = VLMConfig( @@ -586,8 +514,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-8B-Private", +Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-8B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -601,14 +529,14 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-8B-Private", +Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Nano-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -617,18 +545,18 @@ class VLMConfig: qk_norm_for_text=True, 
), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-8B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-8B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3NanoReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Nano-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -643,13 +571,12 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner-bb9c6f5/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) - # Config for Qwen3VL 32B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_32b_Instruct_Config: VLMConfig = VLMConfig( @@ -693,8 +620,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-32B-Private", +Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-32B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -708,14 +635,14 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-32B-Private/", + 
backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-32B-Private", +Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -724,18 +651,18 @@ class VLMConfig: qk_norm_for_text=True, ), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-32B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-32B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3SuperReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -750,12 +677,63 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner-b6df0d1/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + ), +) + +# Cosmos3-Edge-Reasoner at commit 4acb717. +# nemotron_siglip2 architecture: Nemotron text backbone (56-block hybrid layout, 2048 hidden) +# + SigLIP2 vision encoder. 
The text transformer is identical in shape to +# Nemotron-3-Dense-VL-2B (hidden_size=2048, 56 alternating attn/MLP blocks → 28 +# effective MoT layers after _transform_text_dict). Uses the same +# nemotron_3_dense_vl weight remapping and config JSON. +Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-4acb717/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", ), ) +# Cosmos3-Edge-Reasoner at commit 9b4c028 (2026-05-29). +# Same nemotron_siglip2 architecture as 4acb717; new weights uploaded 2026-05-29. 
+Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-9b4c028/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", + ), +) def register_vlm(): @@ -832,24 +810,6 @@ def register_vlm(): name="cosmos_reason2_vlm_2b_gcp", node=CosmosReason2_VLM_2b_GCP_Config, ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_2b_private_gcp", - node=CosmosReason2_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos3_reasoner_vlm_2b_private_gcp", - node=Cosmos3Reasoner_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_8b_private_gcp", - node=CosmosReason2_VLM_8b_Private_GCP_Config, - ) cs.store( group="vlm_config", package="model.config.vlm_config", @@ -865,8 +825,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_32b_private_gcp", - node=CosmosReason2_VLM_32b_Private_GCP_Config, + name="cosmos3_nano_reasoner_vlm_gcp_0517", + node=Cosmos3NanoReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -883,8 +843,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_30b_a3b_private_gcp", - 
node=CosmosReason2_VLM_30b_a3b_Private_GCP_Config, + name="cosmos3_super_reasoner_vlm_gcp_0517", + node=Cosmos3SuperReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -922,3 +882,15 @@ def register_vlm(): name="qwen3_vl_mot_vlm_32b_instruct_gcp", node=Qwen3VLMoT_VLM_32b_Instruct_GCP_Config, ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_4acb717", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717, + ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_9b4c028", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028, + ) diff --git a/cosmos_framework/configs/base/vlm/config.py b/cosmos_framework/configs/base/vlm/config.py index 66dc3ed..4d1de85 100644 --- a/cosmos_framework/configs/base/vlm/config.py +++ b/cosmos_framework/configs/base/vlm/config.py @@ -4,10 +4,9 @@ from cosmos_framework.trainer import ImaginaireTrainer from cosmos_framework.utils import log from cosmos_framework.utils.config_helper import import_all_modules_from_package +from cosmos_framework.configs.base.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.callbacks import register_callbacks -from cosmos_framework.configs.base.vlm.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.config import Config - from cosmos_framework.configs.base.vlm.defaults.model import register_model from cosmos_framework.configs.base.vlm.defaults.optimizer import register_optimizer, register_scheduler from cosmos_framework.configs.base.vlm.defaults.vlm_policy import register_vlm_policy diff --git a/cosmos_framework/configs/base/vlm/defaults/callbacks.py b/cosmos_framework/configs/base/vlm/defaults/callbacks.py index 1910b63..3392205 100644 --- a/cosmos_framework/configs/base/vlm/defaults/callbacks.py +++ b/cosmos_framework/configs/base/vlm/defaults/callbacks.py @@ 
-12,7 +12,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback - from cosmos_framework.callbacks.grad_clip import GradClip from cosmos_framework.callbacks.hf_export import HFExportCallback from cosmos_framework.callbacks.iter_speed import IterSpeed @@ -47,7 +46,6 @@ def register_callbacks(): config=PLACEHOLDER, trainer=PLACEHOLDER, ), # reads model.precision; no extra kwarg needed - # nvtx=L(NVTXCallback)(synchronize=True), ) diff --git a/cosmos_framework/configs/base/vlm/defaults/dataloader.py b/cosmos_framework/configs/base/vlm/defaults/dataloader.py deleted file mode 100644 index 36b878d..0000000 --- a/cosmos_framework/configs/base/vlm/defaults/dataloader.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -from torch.utils.data import DataLoader - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.config_helper import ConfigStore -from cosmos_framework.data.vfm.vlm.collate_fn import custom_collate -from cosmos_framework.data.vfm.vlm.debug_data_qwen import DebugQwenDataset -from cosmos_framework.data.vfm.vlm.dummy_data_qwen import DummyQwenDataset -from cosmos_framework.data.vfm.processors import build_processor_lazy - - -# Debug dataset -def create_debug_dataloader_config_qwen( - num_images, loss_on_completion_only: bool = True, use_dummy_image: bool = False -): - return L(DataLoader)( - dataset=L(DebugQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_images=num_images, - seq_len="${model.config.policy.model_max_length}", - 
image_token_len="${model.config.policy.qwen_max_video_token_length}", - # use_dummy_image=use_dummy_image, - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def create_dummy_dataloader_config_qwen(): - return L(DataLoader)( - dataset=L(DummyQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_visual_tokens="${model.config.policy.qwen_max_video_token_length}", - total_tokens="${model.config.policy.model_max_length}", - batch_size="${dataloader_train.batch_size}", - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def register_data_debug(): - cs = ConfigStore.instance() - for split in ["train", "val"]: - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="debug_image_data_qwen", # This data is from pixtral model output, expected to have low loss ~1.4 - node=create_debug_dataloader_config_qwen(1), - ) - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="dummy_image_data_qwen", - node=create_dummy_dataloader_config_qwen(), - ) - - -def register_data(): - register_data_debug() diff --git a/cosmos_framework/configs/base/vlm/defaults/policy_config.py b/cosmos_framework/configs/base/vlm/defaults/policy_config.py index 307eb92..c203168 100644 --- a/cosmos_framework/configs/base/vlm/defaults/policy_config.py +++ b/cosmos_framework/configs/base/vlm/defaults/policy_config.py @@ -29,7 +29,7 @@ class PolicyConfig: trainable_map: Union[str, None] = None monkey_patch_for_text_only_data: bool = False - # HF attention impl. Default "cosmos" routes through imaginaire.attention + # HF attention impl. 
Default "cosmos" routes through cosmos_framework.model.attention # (NATTEN/blackwell-fmha on GB200). Override to "flash_attention_2", # "sdpa", or "eager" for fallback. attn_implementation: str = "cosmos" @@ -53,5 +53,6 @@ class VLMModelConfig: ema: EMAConfig = EMAConfig(enabled=False) # Force deterministic kernels in Flash-Attention init (slower; required for - # parity bit-exactness) - deterministic: bool = False + # parity bit-exactness). VLM-only knob — consumed by VLMModel.__init__ via + # init_flash_attn_meta. + deterministic: bool = True diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py new file mode 100644 index 0000000..28a81be --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py new file mode 100644 index 0000000..b34cb81 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F +from loguru import logger as logging + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + assert (self.args is not None) and ("size" in self.args), "Please specify size in args" + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + + # We also add the aug params we use. 
This will be useful for other transforms + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + return data_dict + + +class BottomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Crops rows from the bottom of the image/video to reach ``target_height``. + + The top of the frame is preserved (content is top-anchored). Width is unchanged. + Works for 3-D ``[C, H, W]`` images and 4-D ``[C, T, H, W]`` or ``[T, C, H, W]`` + videos — the last two dims are always treated as (H, W). + + Args: + data_dict (dict): Input data dict. ``self.args["target_height"]`` is the + desired output height. Source height must be ``>= target_height``. + + Returns: + data_dict (dict): Output dict where images are bottom-cropped and + ``image_size`` is updated to ``[target_h, w, orig_h, orig_w]`` to mirror + :class:`ReflectionPadding`'s contract. + """ + assert (self.args is not None) and ("target_height" in self.args), "Please specify target_height in args" + if self.output_keys is None: + self.output_keys = self.input_keys + + target_h = int(self.args["target_height"]) + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + assert orig_h >= target_h, ( + f"BottomCrop requires source height >= target_height: got orig_h={orig_h}, target_h={target_h}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + tensor = data_dict[inp_key] + # Slice the last 2 dims; the second-to-last dim is height regardless of + # whether the tensor is CHW, CTHW, or TCHW. 
+ data_dict[out_key] = tensor[..., :target_h, :] + + if out_key != inp_key: + del data_dict[inp_key] + + data_dict["image_size"] = torch.tensor([target_h, orig_w, orig_h, orig_w], dtype=torch.float) + + return data_dict + + +class RandomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + # Obtaining random crop coords + try: + crop_x0 = int(torch.randint(0, orig_w - width + 1, size=(1,)).item()) + crop_y0 = int(torch.randint(0, orig_h - height + 1, size=(1,)).item()) + except Exception as e: + logging.warning( + f"Random crop failed. Performing center crop, original_size(wxh): {orig_w}x{orig_h}, random_size(wxh): {width}x{height}" + ) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + + # We also add the aug params we use. 
This will be useful for other transforms + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + + # We must perform same random cropping for all input keys + for key in self.input_keys: + data_dict[key] = transforms_F.crop(data_dict[key], crop_y0, crop_x0, height, width) + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py new file mode 100644 index 0000000..8f0bb7d --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor + + +class HorizontalFlip(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs horizontal flipping. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. 
+ """ + flip_enabled = getattr(self.args, "enabled", True) + if flip_enabled: + p = getattr(self.args, "prob", 0.5) + coin_flip = torch.rand(1).item() > p + for key in self.input_keys: + if coin_flip: + data_dict[key] = transforms_F.hflip(data_dict[key]) + + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py new file mode 100644 index 0000000..d3e5216 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Union + +import torch +from PIL import Image + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]: + r"""Function for obtaining size of the augmentation. + When dealing with multi-aspect ratio dataloaders, we need to + find the augmentation size from the aspect ratio of the data. + If data_dict contains "_res_size_map" (e.g. from resolution sampling), + that map is used instead of augmentor_cfg["size"]. 
+ + Args: + data_dict (dict): Input data dict + augmentor_cfg (dict): Augmentor config + Returns: + aug_size (int): Size of augmentation + """ + if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: + aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] + else: # Non-webdataset format + aspect_ratio = data_dict["aspect_ratio"] + if "_res_size_map" in data_dict: + return data_dict["_res_size_map"][aspect_ratio] + return augmentor_cfg["size"][aspect_ratio] diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py new file mode 100644 index 0000000..a949230 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. 
+ """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py new file mode 100644 index 0000000..e14d66f --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import omegaconf +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class ReflectionPadding(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs reflection padding. This function also returns a padding mask. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. 
+ """ + + assert self.args is not None, "Please specify args in augmentation" + if self.output_keys is None: + self.output_keys = self.input_keys + + # Obtain image and augmentation sizes + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + target_size = obtain_augmentation_size(data_dict, self.args) + + assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" + target_w, target_h = target_size + + target_w = int(target_w) + target_h = int(target_h) + + # One-sided padding (bottom and right only, content stays at top-left) + padding_right = target_w - orig_w + padding_bottom = target_h - orig_h + padding_vals = [0, 0, padding_right, padding_bottom] + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: + # In this case, we can't perform reflection padding. This is because padding values + # are larger than the image size. So, perform edge padding instead. + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") + else: + # Perform reflection padding + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") + + if out_key != inp_key: + del data_dict[inp_key] + + data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) + + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py new file mode 100644 index 0000000..82cdea9 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import omegaconf +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class ResizeSmallestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to smaller side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=out_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to larger side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in 
zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + + scaling_ratio = min(out_size / orig_w, out_size / orig_h) + target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] + + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(img_size, (tuple, omegaconf.listconfig.ListConfig)), ( + f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + ) + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert target_size[0] >= img_h and target_size[1] >= img_w, ( + f"Resize error. 
orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=( + self.args["interpolation"] + if "interpolation" in self.args + else transforms_F.InterpolationMode.BICUBIC + ), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the larger ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the larger of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(img_size, (tuple, omegaconf.listconfig.ListConfig)), ( + f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + ) + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert target_size[0] <= img_h and target_size[1] <= img_w, ( + f"Resize error. 
orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict diff --git a/cosmos_framework/data/vfm/action/action_normalization.py b/cosmos_framework/data/vfm/action/action_normalization.py index c58bb90..d553161 100644 --- a/cosmos_framework/data/vfm/action/action_normalization.py +++ b/cosmos_framework/data/vfm/action/action_normalization.py @@ -27,7 +27,7 @@ def load_action_stats(stats_path: str, stats_key: str = "global") -> dict[str, n elif stats_key != "global": raise KeyError(f"Action normalization stats block {stats_key!r} not found in {stats_path}.") stat_keys = {"mean", "std", "min", "max", "q01", "q99"} - return {key: np.array(value, dtype=np.float32) for key, value in raw.items() if key in stat_keys} + return {k: np.array(v, dtype=np.float32) for k, v in raw.items() if k in stat_keys} def normalize_action( @@ -35,7 +35,7 @@ def normalize_action( method: str, stats: dict[str, torch.Tensor], ) -> torch.Tensor: - """Normalize action tensor.""" + """Normalize action tensor (all dimensions including gripper).""" if method == "quantile": q01, q99 = stats["q01"], stats["q99"] denom = (q99 - q01).clamp(min=1e-8) diff --git a/cosmos_framework/data/vfm/action/domain_utils.py b/cosmos_framework/data/vfm/action/domain_utils.py index 917fe4c..6f433f7 100644 --- a/cosmos_framework/data/vfm/action/domain_utils.py +++ b/cosmos_framework/data/vfm/action/domain_utils.py @@ -14,9 +14,12 @@ "bridge_orig_lerobot": 7, "droid_lerobot": 8, "robomind-franka": 8, # Both Droid and RoboMIND-Franka are using robotiq and franka + "embodiment_b": 9, "robomind-franka-dual": 12, "robomind-ur": 13, "agibotworld": 15, + 
"embodiment_c_gripper": 15, + "embodiment_c_gripper_ext": 15, "fractal": 20, } @@ -31,8 +34,16 @@ "robomind-franka": 10, "robomind-franka-dual": 20, "robomind-ur": 10, + "embodiment_b": 30, "agibotworld": 29, + "embodiment_c_gripper": 29, + "embodiment_c_gripper_ext": 29, "fractal": 10, + # NOTE: ``libero`` (7/10/13 depending on ``rotation_space``) and ``hand_pose`` + # (variable with ``keypoint_option`` and ``rotation_format``) are absent + # because their raw width is set per-dataset at construction time. Inference + # in inverse_dynamics/policy modes is not supported for these domains until + # canonical widths are added here. } @@ -45,3 +56,20 @@ def get_domain_id(embodiment_type: str) -> int: f"Available embodiments: {sorted(EMBODIMENT_TO_DOMAIN_ID.keys())}" ) return EMBODIMENT_TO_DOMAIN_ID[key] + + +def get_action_dim(embodiment_type: str) -> int: + """Get the raw action dimension for a given embodiment type.""" + key = embodiment_type.lower().strip() + if key not in EMBODIMENT_TO_RAW_ACTION_DIM: + raise KeyError( + f"Unknown embodiment type: {embodiment_type!r}. 
" + f"Available embodiments: {sorted(EMBODIMENT_TO_RAW_ACTION_DIM.keys())}" + ) + return EMBODIMENT_TO_RAW_ACTION_DIM[key] + + +def is_valid_domain_name(embodiment_type: str) -> bool: + """Check if the given embodiment type is recognized.""" + key = embodiment_type.lower().strip() + return key in EMBODIMENT_TO_RAW_ACTION_DIM diff --git a/cosmos_framework/data/vfm/action/json_formatter.py b/cosmos_framework/data/vfm/action/json_formatter.py index b511e93..201a76e 100644 --- a/cosmos_framework/data/vfm/action/json_formatter.py +++ b/cosmos_framework/data/vfm/action/json_formatter.py @@ -7,9 +7,9 @@ import torch +from cosmos_framework.utils import log from cosmos_framework.data.vfm.action.viewpoint_utils import DEFAULT_VIEWPOINT_TEMPLATES from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log def _should_append_idle_frame_info(mode: object) -> bool: diff --git a/cosmos_framework/data/vfm/action/transforms.py b/cosmos_framework/data/vfm/action/transforms.py index d17b141..9cebb4e 100644 --- a/cosmos_framework/data/vfm/action/transforms.py +++ b/cosmos_framework/data/vfm/action/transforms.py @@ -19,6 +19,7 @@ import torch import torchvision.transforms.functional as transforms_F +from cosmos_framework.utils import log from cosmos_framework.data.vfm.action.json_formatter import ActionPromptJsonFormatter from cosmos_framework.data.vfm.action.viewpoint_utils import ViewpointTextInfo from cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps import DurationFPSTextTimeStamps @@ -27,7 +28,6 @@ from cosmos_framework.data.vfm.augmentors.text_tokenizer import TextTokenizerTransform from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log from cosmos_framework.utils.vfm.data_utils import get_vision_data_resolution @@ -309,7 +309,6 @@ def build_sequence_plan_from_mode( base_action_length = action_length - 
num_history_actions if mode == "forward_dynamics": condition_frame_indexes_action = list(range(action_length)) - # This currently assumes that the action length is the same as the video length - 1 # and if action length is the same as the video length, then the first action is the conditioning action elif base_action_length == video_length - 1: diff --git a/cosmos_framework/data/vfm/augmentor_provider.py b/cosmos_framework/data/vfm/augmentor_provider.py index 2ace65c..5ce1727 100644 --- a/cosmos_framework/data/vfm/augmentor_provider.py +++ b/cosmos_framework/data/vfm/augmentor_provider.py @@ -564,6 +564,9 @@ def get_video_augmentor_v3( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) + uniae_pad_frames = kwargs.get("uniae_pad_frames", None) + uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) print("Running video_basic_augmentor_v3...") augmentors = { @@ -577,6 +580,9 @@ def get_video_augmentor_v3( "min_stride": min_stride, "seek_mode": "exact", # Change to "approximate"? "dataset_resolution_type": dataset_resolution_type, + "causal_vae": causal_vae, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( @@ -670,7 +676,6 @@ def get_video_augmentor_v3( return augmentors - # Use video_basic_augmentor_v3_json_caption instead. 
@augmentor_register("video_basic_augmentor_v3_with_audio") def get_video_augmentor_v3_with_audio( @@ -829,6 +834,7 @@ def get_video_augmentor_v3_json_caption( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) print("Running video_augmentor_v3_json_caption...") augmentors = { @@ -856,6 +862,7 @@ def get_video_augmentor_v3_json_caption( "extract_audio": extract_audio, "audio_sample_rate": audio_sample_rate, "emit_placeholder_sound": not extract_audio, + "causal_vae": causal_vae, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( diff --git a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py index 302715c..34f6116 100644 --- a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py +++ b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py @@ -8,7 +8,7 @@ frames (i.e. the relative-pose delta is close to identity and the gripper command does not change). The upstream dataset is responsible for populating ``data_dict[idle_frames_key]`` via -:func:`projects.cosmos3.vfm.datasets.action.pose_utils.compute_idle_frames`. +:func:`cosmos_framework.data.vfm.action.pose_utils.compute_idle_frames`. Per-field dropout (default 5%) is applied here, matching Pi0.7's approach of independently dropping each metadata component. 
This is complementary to the diff --git a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py index aea759e..41e4e4c 100644 --- a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py @@ -414,7 +414,7 @@ def __call__(self, data_dict: dict) -> dict | None: ) # [C,T,H,W] num_multiplier = (end_frame - start_frame) / self.num_frames - + # NOTE: matches legacy VideoParsing.__call__ output keys exactly. Do NOT add # variable-length fields like ``frame_indices`` here -- ``video_flatten_keys`` in # ``get_video_transfer_augmentor`` lists ``frame_indices``, and surfacing a # per-sample list there would crash ``custom_collate_fn`` (default_collate requires diff --git a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py index 54d47b8..aa9eb21 100644 --- a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py +++ b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py @@ -27,14 +27,12 @@ def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert token length to pixels based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return token_length * merged_patch_size**2 * temporal_patch_size def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert pixels to token length based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return pixels // merged_patch_size**2 // temporal_patch_size diff --git a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py index edaef11..d38fae4 100644 --- a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py +++ b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py @@ -8,13 
+8,17 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.v3_text_transforms import pad_and_resize from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.data_registration import _CAPTION_EMBEDDING_KEY_MAPPING_IMAGES # For the qwen captions, we have 3 variants: short, medium, long # In addition, for synthetic data, we create prompt embeddings as well. # There is quite a bit of entropy in the way prompt data is saved. # Captions are saved as "prompts", while the corresponding embeddings are saved as "original_prompt" # This part will be cleaned after synthetic data is cleaned to be in the same format as real data. +_CAPTION_EMBEDDING_KEY_MAPPING_IMAGES = { + "ai_v3p1": "ai_v3p1", + "qwen2p5_7b_v4": "qwen2p5_7b_v4", + "prompts": "qwen2p5_7b_v4", +} _AVAILABLE_QWEN_CAPTIONS = ["qwen2p5_7b_short", "qwen2p5_7b_medium", "qwen2p5_7b_long"] _AVAILABLE_QWEN3_30B_A3B_CAPTIONS = [ "qwen3_30b_a3b_short", diff --git a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py index 74fb523..a18e8fd 100644 --- a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py +++ b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py @@ -5,7 +5,7 @@ Augmentors for transfer (control-conditioned) image and video generation in the cosmos3 VFM pipeline. Transfer training conditions the model on control signals (edge, blur, depth, or segmentation) -to generate images or videos, aligned with cosmos_framework/transfer2. This module provides: +to generate images or videos, aligned with cosmos/transfer2. This module provides: - **TransferToTrainingFormat**: Converts (control_input, target) into the joint dataloader format with SequencePlan (condition frame + generated frame), for both image and video outputs. 
diff --git a/cosmos_framework/data/vfm/augmentors/video_parsing.py b/cosmos_framework/data/vfm/augmentors/video_parsing.py index cfaa934..8972818 100644 --- a/cosmos_framework/data/vfm/augmentors/video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/video_parsing.py @@ -345,7 +345,7 @@ def __call__(self, data_dict: dict) -> dict | None: video_info["video"] = video_frames video_info["num_multiplier"] = num_multiplier # Store the frame skipping multiplier - + # NOTE: Explaining the logic of conditioning FPS calculation: # 1. Our video parser stores the original video FPS of the video. # 2. We have multiple modes of frame selection -- consecutive chunk of frames or subsampled frames. # Here's what we do in each case: @@ -434,6 +434,20 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.dataset_resolution_type = args.get("dataset_resolution_type", "all") self.resolution_tier = _DATASET_RESOLUTION_TIER.get(self.dataset_resolution_type) + # VAE temporal alignment mode. + # causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: align to 4N (non-causal VAE, e.g. UniAE) + self.causal_vae = args.get("causal_vae", True) + self.uniae_pad_frames = args.get("uniae_pad_frames", None) + self.uniae_chunk_frames = args.get("uniae_chunk_frames", None) + if self.uniae_chunk_frames is not None: + assert self.uniae_pad_frames is not None, ( + "uniae_pad_frames must be specified if uniae_chunk_frames is specified" + ) + assert self.uniae_chunk_frames > 2 * self.uniae_pad_frames, ( + "uniae_chunk_frames must be greater than 2 * uniae_pad_frames" + ) + def _sample_stride_with_bias(self, max_stride: int, min_stride: int = 1) -> int: """Sample a stride from [min_stride, max_stride] with bias controlled by low_fps_bias. 
@@ -520,7 +534,6 @@ def _validate_and_probe(self, video: Optional[bytes], meta_dict: dict, data_dict return True def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -569,11 +582,35 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(0, num_video_frames, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to VAE temporal compression requirement. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: 4N (non-causal VAE, e.g. UniAE) num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + # If this is UniAE, we need to align the frame count to the chunk size and padding. + if self.uniae_chunk_frames is not None: + # trim excess frames + effective_chunk_frames = self.uniae_chunk_frames - 2 * self.uniae_pad_frames + while ( + num_video_frames % effective_chunk_frames != 0 + and (num_video_frames % effective_chunk_frames + 2 * self.uniae_pad_frames) % 4 != 0 + and num_video_frames > 0 + ): + num_video_frames -= 1 + + if num_video_frames == 0: + log.warning( + f"VideoParsingWithFullFrames: video too short for UniAE. 
" + f"url: {data_dict['__url__']}, key: {data_dict['__key__']}", + rank0_only=False, + ) + return None + else: + N = num_video_frames // 4 + num_video_frames = 4 * N frame_indices = frame_indices[0:num_video_frames] frame_batch = video_decoder.get_frames_at(frame_indices) @@ -698,7 +735,6 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O super().__init__(input_keys, output_keys, args) def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -772,11 +808,16 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(chunk_start_clamped, chunk_end_clamped, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to VAE temporal compression requirement. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: 4N (non-causal VAE, e.g. UniAE) num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + N = num_video_frames // 4 + num_video_frames = 4 * N if num_video_frames < 1: log.warning( f"VideoParsingChunkedFrames: chunk too short for stride. " diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py deleted file mode 100644 index eb029eb..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: OpenMDW-1.1 - -"""Visual-Text Transformations or Augmentations.""" - -import io -from typing import Dict, Optional - -from PIL import Image - -from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.augmentors.vlm.nvlm_sample_loaders_and_part_filters import ( - get_data_class, - get_part_filter, - get_sample_loader, -) - - -class NVLMImageDataUnify(Augmentor): - """ - This augmentor is used to unify the data format of the nvlm data. - It will take the raw nvlm data tar and convert it to a dictionary with the following keys: - { - "__url__": str, - "__key__": str, - "data_class": str, - "images": List[PIL.Image.Image], - "text": str, - "words_boxes": Optional[List[List[int]]], - "words_text": Optional[List[str]], - "similarity_matrix": Optional[List[List[float]]], - } - """ - - def __init__( - self, - input_keys: list = ["raw_nvlm"], - output_keys: Optional[list] = [], - args: Optional[dict] = None, - data_path_prefix: list[str] = [ - "cosmos_framework/ar/v2/nvlm/", - ], # prefix of the data in s3 - ) -> None: - super().__init__(input_keys, output_keys, args) - self.data_path_prefix = data_path_prefix - - def convert_image(self, img): - try: - if isinstance(img, bytes): - img = Image.open(io.BytesIO(img)).convert("RGB") - elif isinstance(img, Image.Image): - img = img.convert("RGB") - pass # Image is already in PIL format - elif isinstance(img, list): - for i in range(len(img)): - img[i], success = self.convert_image(img[i]) - if not success: - return Image.new("RGB", (256, 256), (0, 0, 0)), False - return img, True - else: - raise ValueError(f"Invalid image type: {type(img)}") - - success = True - except Exception as e: - log.warning(f"Error processing image: {e}. 
Creating an empty black image.", rank0_only=False) - img = Image.new("RGB", (256, 256), (0, 0, 0)) # Creates a 256x256 black image - success = False - return img, success - - def __call__(self, data_dict: Dict) -> Dict: - url = data_dict["__url__"] - data_path = "/".join(url.path.split("/")[:-1]) # remove the last part of the path - sample_loader = get_sample_loader(data_path) - part_filter = get_part_filter(data_path) - data_class = get_data_class(data_path) - assert sample_loader is not None and part_filter is not None and data_class is not None, ( - f"sample_loader({sample_loader}) or part_filter({part_filter}) or data_class({data_class}) is not found for {data_path}" - ) - - raw = {"__url__": url, "__key__": data_dict["__key__"]} - output = {"__url__": url, "__key__": data_dict["__key__"]} - for k, v in data_dict.items(): - ext = k.split(".")[-1] - if part_filter(ext): - raw[ext] = v - try: - output_converted = sample_loader(raw) - # Here output_converted will be a dictionary with the following keys: - # { - # "__key__": str, - # "image": PIL.Image.Image, - # "images": List[PIL.Image.Image], - # "text": str, - # "words_boxes": Optional - # "words_text": Optional - # "similarity_matrix": Optional - # } - except Exception as e: - log.warning( - f"Error in sample_loader: {e}, sample_loader: {sample_loader}, data_path: {data_path}, raw: {raw.keys()}, original_data_dict: {data_dict.keys()}, __url__: {url}, __key__: {data_dict['__key__']}" - ) - return None - - output.update(output_converted) - if "image" not in output_converted and "images" not in output_converted: - success = False - log.warning(f"image not found in {output_converted.keys()}") - if "image" in output_converted: # Single image case - img, success = self.convert_image(output["image"]) - output["images"] = [img] # What should be the format for the iamges - elif "images" in output_converted: - output["images"] = output_converted["images"] - output["images"], success = 
self.convert_image(output["images"]) - if not success: - log.warning(f"image conversion failed for {data_dict['__key__']} url: {url} | Skip this data") - return None - output["data_class"] = data_class - - return output diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py deleted file mode 100644 index fabe0c3..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py +++ /dev/null @@ -1,2815 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -# Combined Sample Loaders -# Auto-generated script combining all sample_loader.py files (Dont edit this file! Edit the projects/cosmos/reasoning/v1/scripts/create_sample_loader_and_part_filter_file.py instead) - -import io - -import torch -from PIL import Image - -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.vlm.nvlm import data_path_mapping - -# This file was automatically generated by `nvgpt4 data prepare`. - -# import torch - - -def sample_loader_0(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_0(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `energon prepare`. 
- - - -def sample_loader_1(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_1(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_2(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_2(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. - - - -def sample_loader_3(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_3(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `energon prepare`. 
- - - -def sample_loader_4(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_4(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_5(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_5(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_6(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - key = raw["__key__"] - if "docvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = json_item["answer_weights"] - elif "textvqa" in key or "lrv_instruct" in key: - context = json_item["question"] - answers = json_item["answer"] - image = raw["jpg"] - answer_weights = None - elif "stvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = [1.0] * len(json_item["answers"]) - elif "chartqa" in key: - context = json_item["query"] - answers = json_item["label"] - image = raw["png"] - answer_weights = None - elif "screenqa" in key: - image = raw["jpg"] - context = json_item["question"] - answers = json_item["ground_truth"] - answer_weights = [1.0] * len(json_item["ground_truth"]) - elif "HME100K" in key: - image = raw["jpg"] - context = "Please write out the expression of the formula in the image using LaTeX format." - answers = json_item["latex_formula"] - answer_weights = None - else: # scale, textbook - image = raw["jpg"] - context = json_item["question"] - answers = json_item["answer"] - answer_weights = None - - return dict( - __key__=key, - image=image, - context=context, - answers=answers, - answer_weights=answer_weights, - ) - - -def part_filter_6(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "png") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_7(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question_string"], # expected type: str - answers=j["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_7(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_8(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=str(j["answer"]), # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_8(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_9(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"].strip(), # expected type: str - answers=j["gt_answer"].strip(), # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_9(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_10(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_10(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_11(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=j["question"], - answers=j["answer"], - answer_weights=None, - ) - - -def part_filter_11(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_12(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_12(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_13(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - key = raw["__key__"] - - if "geoqa_plus" in key or "tqa" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=json_item["correct_answer_index"], - ) - elif "geometry3k" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=ord(json_item["answer"].lower()) - 97, - ) - else: # science_qa, ai2d - image_key = "png" if "png" in raw else "jpg" - if image_key not in raw: - log.warning(f"Image key {image_key} not found in with raw keys: {raw.keys()}") - return dict( - __key__=raw["__key__"], # science_qa_sample_{idx} - image=raw[image_key], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["choices"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def 
part_filter_13(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_14(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - return dict( - __key__=raw["__key__"], # arxiv_qa_sample_{idx} - image=raw["jpg"], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["options"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def part_filter_14(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_15(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - if json_item["question_type"] == "multi_choice": - correct_choice_idx = json_item["choices"].index(json_item["answer"]) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=correct_choice_idx, - ) - else: - # A temporary hack for non multi-choice samples. - # If correct_choice_idx=-1, we should route it to the VQAWebdataset dataloading method. - # (74.7% free-text questions, 25.3% multi-choice questions) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=[json_item["answer"]], - correct_choice_idx=-1, - ) - - -def part_filter_15(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_16(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_16(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_17(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_17(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_18(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_18(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_19(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_19(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_20(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_20(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_21(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["png"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_21(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_22(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_22(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_23(raw: dict) -> dict: # Note: Images are already decoded to tensors - - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_23(part: str) -> bool: - - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_24(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_24(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_25(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_25(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_26(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_26(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_27(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_27(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_28(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_28(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_29(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_29(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_30(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_30(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_31(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_31(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_32(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_32(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_33(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_33(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_34(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_34(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_35(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_35(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_36(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_36(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_37(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_37(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_38(raw: dict) -> dict: - j = raw["json"] - - if "ReCTs" in raw["__key__"]: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - else: # coco-text-multi, textocr-multi - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_38(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_39(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_39(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_40(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_40(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_41(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_41(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_42(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_42(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_43(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_43(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_44(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_44(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_45(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_45(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_46(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_46(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_47(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_47(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_48(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_48(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_49(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_49(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_50(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_50(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_51(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - for i, turn in enumerate(json_item["conversations"]): - if i > 0 and turn["from"] == "human" and "" in turn["value"]: - turn["value"] = turn["value"].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_51(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_52(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_52(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_53(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - # for i, turn in enumerate(json_item['conversations']): - # if i > 0 and turn['from'] == 'human' and '' in turn['value']: - # turn['value'] = turn['value'].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_53(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_54(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_54(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_55(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_55(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_56(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_56(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_57(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_57(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_58(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_58(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_59(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_59(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_60(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_60(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_61(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_61(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_62(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_62(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_63(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_63(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_64(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_64(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_65(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_65(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_66(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_66(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_67(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_67(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_68(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_68(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_69(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_69(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_70(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_70(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_71(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_71(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_72(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_72(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_73(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_73(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_74(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_74(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_75(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_75(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_76(raw: dict) -> dict: # Note: Images are already decoded to tensors - - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_76(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_77(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_77(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_78(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_78(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - - - -def sample_loader_79(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_79(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_80(raw: dict) -> dict: # Note: Images are already decoded to tensors - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=raw["json"]["question"], # expected type: str - answers=raw["json"]["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_80(part: str) -> bool: - - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_81(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_81(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. 
- - - -def sample_loader_82(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_82(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_83(raw: dict) -> dict: - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_83(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_84(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_84(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_85(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict(__key__=raw["__key__"], image=raw["jpg"], text=j["text"], words_boxes=j["quad_1k_normalized"]) - - -def part_filter_85(part: str) -> bool: - - # E.g. 
if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_86(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_86(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_87(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - quad = j["quad"] - quad = [val for point in quad for val in point] - - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=j["text"], # expected type: str - words_boxes=quad, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_87(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_88(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_88(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_89(raw: dict) -> dict: - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_89(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# This file was automatically generated by `nvgpt4 data prepare`. - - - -def sample_loader_90(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_90(part: str) -> bool: - - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. 
If you need all, keep as is - return part in ("json", "jpg") - - - - -def sample_loader_91(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_91(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# Dataset -> Sample Loader Mapping -dataset_loader_mapping = { - "coco_train_val_restval": { - "sample_loader": "sample_loader_0", - "part_filter": "part_filter_0", - "data_class": "CaptioningWebdataset", - "data_weight": 0.01, - }, - "extended-sci/data/merged/CoT": { - "sample_loader": "sample_loader_1", - "part_filter": "part_filter_1", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "extended-sci/data/merged/single-choice": { - "sample_loader": "sample_loader_2", - "part_filter": "part_filter_2", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "extended-sci/data/extended-sci-3/CoT": { - "sample_loader": "sample_loader_3", - "part_filter": "part_filter_3", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - "extended-sci/data/extended-sci-3/single-choice": { - "sample_loader": "sample_loader_4", - "part_filter": "part_filter_4", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/SceMQA_processed": { - "sample_loader": "sample_loader_5", - "part_filter": "part_filter_5", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - 
"nvlm/wdai/data/vqa_collection_doc_text_st_chart_scale_textbook_LRV_Screen": { - "sample_loader": "sample_loader_6", - "part_filter": "part_filter_6", - "data_class": "VQAWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/plotqa/processed": { - "sample_loader": "sample_loader_7", - "part_filter": "part_filter_7", - "data_class": "VQAWebdataset", - "data_weight": 0.095, - }, - "nvlm/wdai/data/clevr-math/processed": { - "sample_loader": "sample_loader_8", - "part_filter": "part_filter_8", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/MMC-Instruction/processed": { - "sample_loader": "sample_loader_9", - "part_filter": "part_filter_9", - "data_class": "VQAWebdataset", - "data_weight": 0.07, - }, - "nvlm/wdai/data/ocrvqa/processed": { - "sample_loader": "sample_loader_10", - "part_filter": "part_filter_10", - "data_class": "VQAWebdataset", - "data_weight": 0.06, - }, - "nvlm/wdai/data/dude/processed": { - "sample_loader": "sample_loader_11", - "part_filter": "part_filter_11", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/VisualMRC/processed": { - "sample_loader": "sample_loader_12", - "part_filter": "part_filter_12", - "data_class": "VQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/mcvqa_collection_scienceqa_ai2d_geoqaplus_geometry3k_tqa": { - "sample_loader": "sample_loader_13", - "part_filter": "part_filter_13", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.025, - }, - "nvlm/wdai/data/arxiv_qa/processed": { - "sample_loader": "sample_loader_14", - "part_filter": "part_filter_14", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/tabmwp/processed": { - "sample_loader": "sample_loader_15", - "part_filter": "part_filter_15", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/ocr_vqa_aug/processed": { - "sample_loader": "sample_loader_16", - "part_filter": "part_filter_16", - "data_class": 
"SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/dvqa_full/processed": { - "sample_loader": "sample_loader_17", - "part_filter": "part_filter_17", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/LLaVA-v1.5_shuffle/no_refcoco_vg_ocrvqa": { - "sample_loader": "sample_loader_18", - "part_filter": "part_filter_18", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.085, - }, - "vqa/more_data/infographics_vqa/processed/train": { - "sample_loader": "sample_loader_19", - "part_filter": "part_filter_19", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/sharegpt4o/processed": { - "sample_loader": "sample_loader_20", - "part_filter": "part_filter_20", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/sparse_ocr_data/merged": { - "sample_loader": "sample_loader_21", - "part_filter": "part_filter_21", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.045, - }, - "nvlm/nayeonl/data/blendv4/MetaMathQA/processed/train_text_image": { - "sample_loader": "sample_loader_22", - "part_filter": "part_filter_22", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/nayeonl/data/blendv4/gsm8k/processed/train_text_image": { - "sample_loader": "sample_loader_23", - "part_filter": "part_filter_23", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/docmatix/processed": { - "sample_loader": "sample_loader_24", - "part_filter": "part_filter_24", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.1, - }, - "nvlm/wdai/data/bentham_hw_squad/processed": { - "sample_loader": "sample_loader_25", - "part_filter": "part_filter_25", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/WikiTableQA/processed": { - "sample_loader": 
"sample_loader_26", - "part_filter": "part_filter_26", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/figureqa/processed": { - "sample_loader": "sample_loader_27", - "part_filter": "part_filter_27", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ai2d_combined_processed": { - "sample_loader": "sample_loader_28", - "part_filter": "part_filter_28", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/math_combined_processed": { - "sample_loader": "sample_loader_29", - "part_filter": "part_filter_29", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.035, - }, - "nvlm/wdai/data/llava-onevision/robut_combined_processed": { - "sample_loader": "sample_loader_30", - "part_filter": "part_filter_30", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/llavar_20k_processed": { - "sample_loader": "sample_loader_31", - "part_filter": "part_filter_31", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/tallyqa_processed": { - "sample_loader": "sample_loader_32", - "part_filter": "part_filter_32", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/llava-onevision/ureader_ie_processed": { - "sample_loader": "sample_loader_33", - "part_filter": "part_filter_33", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/visual7w_processed": { - "sample_loader": "sample_loader_34", - "part_filter": "part_filter_34", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/llava-onevision/mavis_math_rule_geo_processed": { - "sample_loader": "sample_loader_35", - "part_filter": "part_filter_35", - "data_class": 
"SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ureader_kg_processed": { - "sample_loader": "sample_loader_36", - "part_filter": "part_filter_36", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/ureader_qa_processed": { - "sample_loader": "sample_loader_37", - "part_filter": "part_filter_37", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/ocr_multi_collection_cocotext_textocr_ReCTs": { - "sample_loader": "sample_loader_38", - "part_filter": "part_filter_38", - "data_class": "OCRWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_500": { - "sample_loader": "sample_loader_39", - "part_filter": "part_filter_39", - "data_class": "OCRWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/llava-onevision/super_clevr_processed": { - "sample_loader": "sample_loader_40", - "part_filter": "part_filter_40", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/icon_qa_processed": { - "sample_loader": "sample_loader_41", - "part_filter": "part_filter_41", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.009, - }, - "nvlm/wdai/data/augmentations/chartqa_aug": { - "sample_loader": "sample_loader_42", - "part_filter": "part_filter_42", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_chartqa": { - "sample_loader": "sample_loader_43", - "part_filter": "part_filter_43", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/gpt_docvqa": { - "sample_loader": "sample_loader_44", - "part_filter": "part_filter_44", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/docvqa_text": { - "sample_loader": "sample_loader_45", 
- "part_filter": "part_filter_45", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/textvqa_text": { - "sample_loader": "sample_loader_46", - "part_filter": "part_filter_46", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/augmentations/i2s-musicsheet": { - "sample_loader": "sample_loader_47", - "part_filter": "part_filter_47", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "nvlm/wdai/data/augmentations/music": { - "sample_loader": "sample_loader_48", - "part_filter": "part_filter_48", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/invoice": { - "sample_loader": "sample_loader_49", - "part_filter": "part_filter_49", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/k12": { - "sample_loader": "sample_loader_50", - "part_filter": "part_filter_50", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.019, - }, - "nvlm/wdai/data/augmentations/MTVQA": { - "sample_loader": "sample_loader_51", - "part_filter": "part_filter_51", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/VisualWebInstruct": { - "sample_loader": "sample_loader_52", - "part_filter": "part_filter_52", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.028, - }, - "nvlm/wdai/data/augmentations/financeqa": { - "sample_loader": "sample_loader_53", - "part_filter": "part_filter_53", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/docreason": { - "sample_loader": "sample_loader_54", - "part_filter": "part_filter_54", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/gpt_mtwi": { - "sample_loader": "sample_loader_55", 
- "part_filter": "part_filter_55", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/geos_gpt": { - "sample_loader": "sample_loader_56", - "part_filter": "part_filter_56", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0001, - }, - "nvlm/wdai/data/augmentations/cauldron_vistext": { - "sample_loader": "sample_loader_57", - "part_filter": "part_filter_57", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/memes": { - "sample_loader": "sample_loader_58", - "part_filter": "part_filter_58", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_roadtext": { - "sample_loader": "sample_loader_59", - "part_filter": "part_filter_59", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0002, - }, - "nvlm/wdai/data/augmentations/indoor_qa": { - "sample_loader": "sample_loader_60", - "part_filter": "part_filter_60", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/colpali": { - "sample_loader": "sample_loader_61", - "part_filter": "part_filter_61", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/pmc_vqa": { - "sample_loader": "sample_loader_62", - "part_filter": "part_filter_62", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/augmentations/pathvqa": { - "sample_loader": "sample_loader_63", - "part_filter": "part_filter_63", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/sciqa": { - "sample_loader": "sample_loader_64", - "part_filter": "part_filter_64", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.027, - }, - "nvlm/wdai/data/augmentations/chinese_meme": { - "sample_loader": "sample_loader_65", 
- "part_filter": "part_filter_65", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/gpt_hiertext": { - "sample_loader": "sample_loader_66", - "part_filter": "part_filter_66", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/augmentations/cauldron_cocoqa": { - "sample_loader": "sample_loader_67", - "part_filter": "part_filter_67", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/cmm-math/processed": { - "sample_loader": "sample_loader_68", - "part_filter": "part_filter_68", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/mmtab/processed": { - "sample_loader": "sample_loader_69", - "part_filter": "part_filter_69", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/simchart9k/processed": { - "sample_loader": "sample_loader_70", - "part_filter": "part_filter_70", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/llava-onevision/mapqa_processed": { - "sample_loader": "sample_loader_71", - "part_filter": "part_filter_71", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/vizwiz_processed": { - "sample_loader": "sample_loader_72", - "part_filter": "part_filter_72", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/gpt_infovqa": { - "sample_loader": "sample_loader_73", - "part_filter": "part_filter_73", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/viquae": { - "sample_loader": "sample_loader_74", - "part_filter": "part_filter_74", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "captioning/ccs_recaptioned/webdataset": { - "sample_loader": 
"sample_loader_75", - "part_filter": "part_filter_75", - "data_class": "CaptioningWebdataset", - "data_weight": 0.2, - }, - "captioning/laion115m-clean": { - "sample_loader": "sample_loader_76", - "part_filter": "part_filter_76", - "data_class": "CaptioningWebdataset", - "data_weight": 0.579, - }, - "nvlm/wdai/data/dvqa_full/processed_pt": { - "sample_loader": "sample_loader_77", - "part_filter": "part_filter_77", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/docmatix/processed_pt": { - "sample_loader": "sample_loader_78", - "part_filter": "part_filter_78", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "vqa/VQAv2/stage1": { - "sample_loader": "sample_loader_91", - "part_filter": "part_filter_91", - "data_class": "VQAWebdataset", - "data_weight": 1.0, - }, - "vqa/Visual_Genome": { - "sample_loader": "sample_loader_80", - "part_filter": "part_filter_80", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_300": { - "sample_loader": "sample_loader_81", - "part_filter": "part_filter_81", - "data_class": "OCRWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/textocr/processed": { - "sample_loader": "sample_loader_82", - "part_filter": "part_filter_82", - "data_class": "OCRWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/coco-text/processed": { - "sample_loader": "sample_loader_83", - "part_filter": "part_filter_83", - "data_class": "OCRWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/ArT/processed": { - "sample_loader": "sample_loader_84", - "part_filter": "part_filter_84", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/ReCTs/processed": { - "sample_loader": "sample_loader_85", - "part_filter": "part_filter_85", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/lsvt/processed": { - "sample_loader": "sample_loader_86", - "part_filter": "part_filter_86", - "data_class": 
"OCRWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/RCTW/processed": { - "sample_loader": "sample_loader_87", - "part_filter": "part_filter_87", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/coco-text/processed_multi": { - "sample_loader": "sample_loader_88", - "part_filter": "part_filter_88", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, - "nvlm/wdai/data/textocr/processed_multi": { - "sample_loader": "sample_loader_89", - "part_filter": "part_filter_89", - "data_class": "OCRWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/ReCTs/processed_multi": { - "sample_loader": "sample_loader_90", - "part_filter": "part_filter_90", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, -} - - -def get_sample_loader(path): - """Returns the correct sample_loader function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("sample_loader")) - - -def get_part_filter(path): - """Returns the correct part_filter function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("part_filter")) - - -def get_data_class(path): - """Returns the correct data_class for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return dataset_loader_mapping[path]["data_class"] diff --git a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py index ec86e66..5b576c4 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py +++ 
b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py @@ -45,7 +45,6 @@ def __call__(self, data_dict: Dict) -> Dict: if isinstance(list_of_conversation[0], list): selected_conversation = random.sample(list_of_conversation, 1)[0] elif isinstance(list_of_conversation[0], dict): - selected_conversation = list_of_conversation else: raise ValueError( @@ -82,7 +81,6 @@ def __call__(self, data_dict: Dict) -> Dict: del data_dict[conversation_key] - # # enforce chat order # self._enforce_text_chat_order(selected_conversation) @@ -91,7 +89,7 @@ def __call__(self, data_dict: Dict) -> Dict: def _enforce_text_chat_order(self, conversation: list) -> None: """ Reorder text content within user messages based on text_chat_order setting. - NOTE: this does NOT work for interleaved data!!!!!! + NOTE (maxzhaoshuol): this does NOT work for interleaved data!!!!!! Args: conversation: List of message dictionaries diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py index edede0c..6228153 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py @@ -392,17 +392,15 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: raise ValueError("Start and end time are the same for data.") if timestamp_format == "seconds": if random.random() < 0.5: diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py index 0507109..90cc9ab 100644 --- 
a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py @@ -224,17 +224,15 @@ def augment_user_prompt( elif output_format == "temporal_caption_subject": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: log.warning(f"Start and end time are the same for data. {event}") return None diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py index 584ca1d..7212510 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py @@ -162,14 +162,12 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.5: - start = round(event["start"]) end = round(event["end"]) else: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 - if start == end: # HACK: remove events with start == end + if start == end: raise ValueError("Start and end time are the same for data.") user_prompt = random.choice( [ diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py index 8df9dd1..812e113 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py @@ -205,10 +205,8 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - 
start = round(event["start"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 else: start = event["start"] diff --git a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py index a0ce29c..f3a9914 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py @@ -158,7 +158,6 @@ def __call__(self, data_dict: Dict) -> Dict: if message["role"] == "user" and isinstance(message["content"], list): total_images += len([content for content in message["content"] if content["type"] == "image"]) total_videos += len([content for content in message["content"] if content["type"] == "video"]) - assert total_videos == 1 or total_videos == 0, "Only one video is supported for now" # url @@ -167,7 +166,6 @@ def __call__(self, data_dict: Dict) -> Dict: # go through each message in the conversation for message in conversation: # for user message, we insert the media - if message["role"] == "user" and isinstance( message["content"], list ): # Otherwise it's text and content is a string @@ -225,7 +223,6 @@ def __call__(self, data_dict: Dict) -> Dict: raw_images.append(image) elif content["type"] == "video": - # as tokenization will NOT upsample the video, we can use a larger value here at the cost of multiple video having 1.5x token length max_total_pixels = token_to_pixels(self.max_video_token_length * 1.5, temporal_patch_size=2) media_key = content["video"] @@ -248,7 +245,6 @@ def __call__(self, data_dict: Dict) -> Dict: return None videos = data_dict["media"][media_key]["videos"] # list of PIL images fps = data_dict["media"][media_key]["fps"] - # this is because videos are decoded to be around "max_video_token_length" tokens videos = maybe_subsample_frames( diff --git a/cosmos_framework/data/vfm/joint_dataloader.py b/cosmos_framework/data/vfm/joint_dataloader.py index 56b13e7..4b6b443 100644 --- 
a/cosmos_framework/data/vfm/joint_dataloader.py +++ b/cosmos_framework/data/vfm/joint_dataloader.py @@ -57,7 +57,6 @@ def custom_collate_fn(batch): # Handle standard list of samples elem = batch[0] if isinstance(elem, dict): - # Some Action datasets add optional metadata keys (for example # ``additional_view_description`` for concat-view captions) only for a # subset of samples. PyTorch can batch such samples together when @@ -261,7 +260,7 @@ def _prewarm_dataloaders(self) -> None: The first ``next()`` call on an ``InfiniteDataLoader`` iterator triggers ``DataLoader.__iter__()`` which spawns worker processes. For action dataloaders using ``multiprocessing_context='spawn'``, each worker must - fully initialise heavy datasets (BridgeOrigLeRobotDataset, EMBODIMENT_A, etc.) + fully initialise heavy datasets (BridgeOrigLeRobotDataset, embodiment_a, etc.) from scratch. If this happens lazily during training, the resulting delay (potentially minutes) causes NCCL collective timeouts when faster ranks enter the forward pass while slower ranks are still loading data. diff --git a/cosmos_framework/data/vfm/packing_iterable_dataset.py b/cosmos_framework/data/vfm/packing_iterable_dataset.py index 847b85a..ac30f0d 100644 --- a/cosmos_framework/data/vfm/packing_iterable_dataset.py +++ b/cosmos_framework/data/vfm/packing_iterable_dataset.py @@ -4,13 +4,17 @@ """ Abstract base class for pool-based token-budget bin-packing over multiple datasets. -Extracted from ``projects.cosmos3.vfm.datasets.vlm.joint_dataset_dynamic_batch_webloader`` +Extracted from ``cosmos_framework.data.vfm.vlm.joint_dataset_dynamic_batch_webloader`` so that both the VLM and VFM internal dataloaders can share a single packing implementation. Usage ----- Subclass and implement ``compute_sample_tokens(sample) -> int``. Optionally override ``collate_batch(samples) -> Any`` for custom collation. 
+ + class MyPacker(PackingIterableDataset): + def compute_sample_tokens(self, sample): + return len(sample["input_ids"]) """ from __future__ import annotations @@ -62,11 +66,6 @@ class PackingIterableDataset(torch.utils.data.IterableDataset, ABC): singletons regardless of budget. batching_strategy: ``"prefer_closest"`` (default) or ``"prefer_first"``. - apply_long_sample_halving: - When ``True`` (default), ``_max_tokens`` halves the budget for any - batch whose largest sample exceeds 1000 tokens — a memory-safety - heuristic. Set ``False`` only when memory headroom at the literal - ``max_tokens`` budget has been validated for the recipe. """ def __init__( @@ -77,7 +76,6 @@ def __init__( max_batch_size: int, long_threshold: int, batching_strategy: str, - apply_long_sample_halving: bool = True, ): super().__init__() @@ -90,7 +88,6 @@ def __init__( self.long_threshold = long_threshold self.max_batch_size = max_batch_size self.batching_strategy = batching_strategy - self.apply_long_sample_halving = apply_long_sample_halving self._pool: deque[dict] = deque() self._dataset_names: list[str] = [] @@ -166,8 +163,6 @@ def __iter__(self): # ------------------------------------------------------------------ def _max_tokens(self, cur_max: int) -> int: - if not self.apply_long_sample_halving: - return self.max_tokens if cur_max < 1000: return self.max_tokens return self.max_tokens // 2 diff --git a/cosmos_framework/data/vfm/processors/__init__.py b/cosmos_framework/data/vfm/processors/__init__.py index a1cc60e..646800b 100644 --- a/cosmos_framework/data/vfm/processors/__init__.py +++ b/cosmos_framework/data/vfm/processors/__init__.py @@ -104,13 +104,6 @@ def build_processor( bucket: Optional[str] = None, cache_dir: Optional[str] = None, ): - # Local artifact path: source the processor from a bundled directory - # (e.g. the top level of nvidia/Cosmos3-Nano, which ships its own - # preprocessor_config.json, tokenizer.json, etc). 
Avoids the redundant - # upstream Qwen/Qwen3-VL-*-Instruct fetch. Cosmos3-Nano/Super both ship - # a Qwen3VL-compatible processor, so dispatch to Qwen3VLProcessor. - if os.path.isdir(tokenizer_type): - return Qwen3VLProcessor(tokenizer_type, cache_dir=cache_dir) if credentials is None or bucket is None: if config_variant is None: config_variant = "s3" @@ -125,7 +118,12 @@ def build_processor( return Qwen3VLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) elif "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16" in tokenizer_type: return NemotronVLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) - elif "NVIDIA-Nemotron-3-Dense-VL" in tokenizer_type or "Qwen3-2B-ViT" in tokenizer_type: + elif ( + "NVIDIA-Nemotron-3-Dense-VL" in tokenizer_type + or "Qwen3-2B-ViT" in tokenizer_type + or "nvidia/Cosmos3-Reasoner-2B-Private" in tokenizer_type + or "nvidia/Cosmos3-Edge-Reasoner" in tokenizer_type + ): return Nemotron3DenseVLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) elif "Qwen/Qwen3-0.6B" in tokenizer_type: local_path = _download_llm_tokenizer(tokenizer_type, credentials, bucket, cache_dir) @@ -138,26 +136,9 @@ def build_processor( raise ValueError(f"Tokenizer type {tokenizer_type} not supported") -def build_processor_lazy( - *args, - repository: Optional[str] = None, - revision: Optional[str] = None, - subdir: str = "", - **kwargs, -): +def build_processor_lazy(*args, **kwargs): """LazyCall wrapper that resolves ``build_processor`` on this module at call time. - Two modes: - 1. Upstream tokenizer (legacy): pass ``tokenizer_type=""`` - (and optional ``config_variant`` / ``credentials`` / ``bucket``). - The processor is sourced from the upstream HF repo (e.g. - ``Qwen/Qwen3-VL-8B-Instruct``). - 2. Local artifact: pass ``repository`` + ``revision`` (and optional - ``subdir``). The processor is sourced from the HF cache of the - named artifact (e.g. 
``nvidia/Cosmos3-Nano``), reusing the same - revision the OmniModel checkpoint download uses. Avoids a - redundant upstream Qwen3-VL-*-Instruct fetch. - LazyCall captures its target at config-construction time, so a direct ``L(build_processor)`` would freeze the original function reference and bypass any later ``monkeypatch.setattr`` on this module's @@ -165,13 +146,4 @@ def build_processor_lazy( lookup on every call, so test fixtures patching ``build_processor`` are honored when the config is instantiated. """ - if repository is not None: - from cosmos_framework.utils.checkpoint_db import CheckpointDirHf - - if revision is None: - raise ValueError("'revision' is required when 'repository' is set") - local_path = CheckpointDirHf(repository=repository, revision=revision).download() - if subdir: - local_path = os.path.join(local_path, subdir) - return sys.modules[__name__].build_processor(local_path, **kwargs) return sys.modules[__name__].build_processor(*args, **kwargs) diff --git a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py index 767c8ef..c6addbf 100644 --- a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py +++ b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py @@ -248,7 +248,6 @@ def __init__( # NemotronVL hardcodes these helper attributes because they are not # discoverable from the HF model config; the values match the upstream # vision-encoder configuration. - # HACK: hardcoded based on the model config. self.min_height_width = 512 self.patch_size = 16 self.temporal_patch_size = 1 @@ -258,7 +257,6 @@ def __init__( def _resolve_pad_id(self): # NemotronVL's tokenizer does not specify a pad_token; reserve # for padding (project convention). 
- return self.processor.tokenizer.convert_tokens_to_ids("") def apply_chat_template( diff --git a/cosmos_framework/data/vfm/sequence_packing.py b/cosmos_framework/data/vfm/sequence_packing.py index d2821cc..aa4e275 100644 --- a/cosmos_framework/data/vfm/sequence_packing.py +++ b/cosmos_framework/data/vfm/sequence_packing.py @@ -850,7 +850,6 @@ def _pack_action_tokens( packed_seq.action.token_shapes.append((action_split_len,)) packed_seq.action.tokens.append(input_action_tokens) - condition_set = {idx for idx in condition_frame_indexes_action if 0 <= idx < action_split_len} assert isinstance(packed_seq.action.condition_mask, list) @@ -2095,7 +2094,7 @@ def verify_natten_parameter_list( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Multi-profile: different parameters for 2D (images) and 3D (videos) @@ -2231,7 +2230,7 @@ def generate_natten_metadata( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Invalid: @@ -2363,9 +2362,9 @@ def filter_shape(shape: tuple) -> tuple: is_causal = dim_params["is_causal"] # Create varlen metadata for natten varlen/varsized ops - + # NOTE: generate_multi_dim_varlen_parameters will automatically map window size -1 to # full size, that's why constant window sizes aren't allowed. - + # NOTE: if any of the parameters are constant, natten will simplify them natten_metadata.append( generate_multi_dim_varlen_parameters( token_layout_list=token_layout_list, @@ -2780,7 +2779,6 @@ def build_sequence_plans_from_data_batch( Returns: List of SequencePlan objects, one per sample in the batch. """ - # For new modalities, please generate the sequence_plan in the dataset class!!!! 
# If sequence_plan already exists in data_batch, return it @@ -2790,7 +2788,6 @@ def build_sequence_plans_from_data_batch( assert "action" not in data_batch or data_batch["action"] is None, "Action data SHOULD have sequence_plans!" assert "sound" not in data_batch or data_batch["sound"] is None, "Sound data SHOULD have sequence_plans!" - # Determine batch size from available tensors batch_size = 0 for key in [input_video_key, input_image_key]: diff --git a/cosmos_framework/data/vfm/sound_data_utils.py b/cosmos_framework/data/vfm/sound_data_utils.py index 0a8ec63..2d739b0 100644 --- a/cosmos_framework/data/vfm/sound_data_utils.py +++ b/cosmos_framework/data/vfm/sound_data_utils.py @@ -3,7 +3,8 @@ """Sound data utilities for building sequence plans and handling audio-video generation modes. -This module provides utilities for building SequencePlan objects based on sound generation modes. +This module provides utilities for building SequencePlan objects based on sound generation modes, +similar to how action modes are handled in cosmos_framework/data/vfm/action/data_utils.py. Supported modes: - t2vs: Text → Video + Sound (joint generation) @@ -27,7 +28,8 @@ def build_sequence_plan_for_sound( """Build a SequencePlan based on the sound generation mode. This function determines the appropriate condition frame indexes for vision and sound - based on the specified mode. + based on the specified mode. It mirrors how `build_sequence_plan_from_mode` works + for action in cosmos_framework/data/vfm/action/data_utils.py. Args: mode: Generation mode. One of: diff --git a/cosmos_framework/data/vfm/utils.py b/cosmos_framework/data/vfm/utils.py index ab6f238..a96f725 100644 --- a/cosmos_framework/data/vfm/utils.py +++ b/cosmos_framework/data/vfm/utils.py @@ -6,7 +6,6 @@ IMAGE_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). 
- # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), @@ -41,7 +40,6 @@ VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). - # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), diff --git a/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py new file mode 100644 index 0000000..12c9bcc --- /dev/null +++ b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +""" +Copied from projects/cosmos/reason1/datasets/video_decoder_qwen.py +Changes: +1: remove hardcoded hyper-parameters for Qwen, now read it from processor +2: support skipping smart resize, since it may resize the video frames to be smaller than model input and frames will get resized up later in processor +""" + +import random +import re +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Callable, Optional + +import torch +from PIL import Image +from qwen_vl_utils.vision_process import smart_nframes, smart_resize +from torchcodec.decoders import VideoDecoder +from torchvision import transforms +from torchvision.transforms import InterpolationMode + +from cosmos_framework.utils import log +from cosmos_framework.data.vfm.processors.qwen3vl_processor import Qwen3VLProcessor + +Image.MAX_IMAGE_PIXELS = 933120000 +_VIDEO_EXTENSIONS = "mp4 avi webm mov".split() + +VIDEO_DECODER_OPTIONS = {} + + +def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert token length to pixels based on patch size and temporal patch size. 
+ + Args: + token_length: Token length + patch_size: Patch size + temporal_patch_size: Temporal patch size, + for Qwen it has 3D conv, temporal patch size is 2; for other models like internVL or eagle er, the temporal patch size is 1 since their ViT is an image encoder; + merge_size: Merge size, also called the pixel shuffling factor; + for Qwen and internVL it is 2; for eagle er it is 1; + """ + merged_patch_size = patch_size * merge_size + return token_length * merged_patch_size**2 * temporal_patch_size + + +def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert pixels to token length based on patch size and temporal patch size.""" + merged_patch_size = patch_size * merge_size + return pixels // merged_patch_size**2 // temporal_patch_size + + +def video_decoder_qwen( + num_threads: int = 0, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + random_augmentation: bool = False, + frame_count_random_range: Optional[list[int]] = None, + **kwargs, +) -> Callable: + """ + Sampling video frames similar to Qwen. It prioritizes matching the target FPS first and then resizing the video frames. + See https://github.com/kq-chen/qwen-vl-utils/blob/main/src/qwen_vl_utils/vision_process.py#L118 for more details.
+ + Args: + key: Video file name/key + data: Video binary data + min_fps_thres: Minimum FPS threshold + max_fps_thres: Maximum FPS threshold + target_fps: Target FPS + min_video_token_length: Minimum token length + max_video_token_length: Maximum token length + num_threads: Number of threads for the torchcodec video decoder + random_augmentation: Whether to randomize the FPS and max_video_token_length + frame_count_random_range: Random frame count range + + Returns: + dict with video frames tensor and target FPS + """ + + video_decoder_configured = partial( + _video_decoder_qwen_func, + min_fps_thres=min_fps_thres, + max_fps_thres=max_fps_thres, + num_threads=num_threads, + target_fps=target_fps, + min_video_token_length=min_video_token_length, + max_video_token_length=max_video_token_length, + random_augmentation=random_augmentation, + frame_count_random_range=frame_count_random_range, + ) + + return video_decoder_configured + + +def _video_decoder_qwen_func( + key: str, + data: bytes, + processor: Qwen3VLProcessor, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + num_threads: int = 0, + random_augmentation: bool = False, + fps_random_range: list[float] = [0.5, 1.5], + max_video_token_length_random_range: list[float] = [0.75, 1.25], + frame_count_random_range: Optional[list[int]] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + decoding_timeout: int = 60, + **kwargs, +) -> dict | None: + """Actual video decoder function. + + Args: + key (str): Video file name/key + data (bytes): Video binary data + min_fps_thres (int, optional): Minimum FPS threshold. Defaults to 4. + max_fps_thres (int, optional): Maximum FPS threshold. Defaults to 60. + target_fps (float, optional): Target FPS. Defaults to 2.0. + min_video_token_length (int, optional): Minimum token length. Defaults to 16. 
+ max_video_token_length (int, optional): Maximum token length. Defaults to 8192. + num_threads (int, optional): Number of threads for the torchcodec video decoder. Defaults to 0. + random_augmentation (bool, optional): Whether to randomize the FPS and max_video_token_length. Defaults to False. + fps_random_range (list[float], optional): Random multiplier range applied to target_fps. Defaults to [0.5, 1.5]. + max_video_token_length_random_range (list[float], optional): Random max_video_token_length range. Defaults to [0.75, 1.25]. + frame_count_random_range (list[int], optional): Random frame count range. If provided, takes priority over fps_random_range. + start_frame (Optional[int], optional): Start frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + end_frame (Optional[int], optional): End frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + decoding_timeout (int, optional): Timeout in seconds. Defaults to 60. + Raises: + ValueError: Video fps lower than 1, skipping + ValueError: Video fps lower than min_fps_thres, skipping + ValueError: Video fps higher than max_fps_thres, skipping + + Returns: + dict | None: Dictionary with video frames tensor and sampled FPS + """ + # Check video extension + extension = re.sub(r".*[.]", "", key) + if extension.lower() not in _VIDEO_EXTENSIONS: + return None + + # Read video with torchcodec + video_reader = VideoDecoder(data, num_ffmpeg_threads=num_threads) + total_frames = video_reader.metadata.num_frames + video_fps = video_reader.metadata.average_fps + + # torchcodec returns ``None`` for containers that don't store frame count + # or average fps (e.g. some MKV/WebM streams). Downstream arithmetic + # (``total_frames - 1``, ``video_fps < 1``, ...) would TypeError on None; + # surface a ValueError so the dataloader's skip path handles it uniformly.
+ if total_frames is None or video_fps is None: + raise ValueError(f"torchcodec missing metadata (num_frames={total_frames}, average_fps={video_fps}), skipping") + + if start_frame is not None and end_frame is not None: + total_frames = end_frame - start_frame + + if video_fps < 1: + raise ValueError("Video fps lower than 1, skipping") + if video_fps < min_fps_thres: + raise ValueError(f"Video fps {video_fps} lower than {min_fps_thres}, skipping") + if video_fps > max_fps_thres: + raise ValueError(f"Video fps {video_fps} higher than {max_fps_thres}, skipping") + + if random_augmentation: + if frame_count_random_range is not None: + # Random number of frames + min_frames_range, max_frames_range = frame_count_random_range + min_frames_range = min(min_frames_range, total_frames) + max_frames_range = min(max_frames_range, total_frames) + target_frames = random.uniform(min_frames_range, max_frames_range) + target_fps = target_frames / total_frames * video_fps + else: + # randomize fps + target_fps = ( + random.uniform(fps_random_range[0], fps_random_range[1]) * target_fps + if random.random() < 0.5 + else target_fps + ) + # randomize max_video_token_length + max_video_token_length = int( + random.uniform(max_video_token_length_random_range[0], max_video_token_length_random_range[1]) + * max_video_token_length + ) + log.debug(f"random_augmentation: max_video_token_length: {max_video_token_length}, target_fps: {target_fps}") + + patch_size = processor.patch_size + min_height_width = processor.min_height_width + temporal_patch_size = processor.temporal_patch_size + merge_size = processor.merge_size + min_pixels: int = token_to_pixels(min_video_token_length, patch_size, temporal_patch_size, merge_size) + max_pixels: int = token_to_pixels(max_video_token_length, patch_size, temporal_patch_size, merge_size) + max_frames: int = max_pixels // (min_height_width) ** 2 // temporal_patch_size + + # sample based on target fps + nframes = smart_nframes(dict(fps=target_fps), 
total_frames=total_frames, video_fps=video_fps) + nframes = min(nframes, max_frames) + if start_frame is not None and end_frame is not None: + idx = torch.linspace(start_frame, end_frame - 1, nframes).round().long().tolist() # [nframes] + else: + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() # [nframes] + + def _decode_video() -> torch.Tensor: + return video_reader.get_frames_at(indices=idx).data # [T, C, H, W] uint8 + + # Use ThreadPoolExecutor to run video decoding with a timeout. + # If the thread is stuck, abandon it immediately. + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit(_decode_video) + try: + video_frames = future.result(timeout=decoding_timeout) + executor.shutdown(wait=False) + except TimeoutError as e: + log.warning(f"[{key}] Video decoding timed out after {decoding_timeout} seconds") + executor.shutdown(wait=False) + return None + + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + + # recompute max_pixels based on number of sampled frames + nframes, _, height, width = video_frames.shape + max_pixels = max_pixels // nframes + if processor.use_smart_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_size * merge_size, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + log.debug( + f"resized_height: {resized_height}, resized_width: {resized_width} | original height: {height}, original width: {width}" + ) + video_frames = transforms.functional.resize( + video_frames, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() # [T,C,H,W] + video_frames = video_frames.permute(1, 0, 2, 3) # [C,T,H,W] + + return dict(videos=video_frames, fps=sample_fps) diff --git a/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py b/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py new file mode 100644 index 0000000..08c5d37 --- /dev/null +++ 
def convert_string_content_to_list_content(messages: List[Dict]) -> List[Dict]:
    """Normalize plain-string message contents into the list-of-parts form.

    Every message whose ``"content"`` is a ``str`` is rewritten **in place** to
    ``[{"type": "text", "text": <original string>}]``. Messages whose content is
    not a string are left untouched.

    Args:
        messages: Chat messages; each must have a ``"content"`` entry.

    Returns:
        The same ``messages`` list object that was passed in (mutated in place).

    Raises:
        KeyError: If a message has no ``"content"`` entry.
    """
    for message in messages:
        content = message["content"]
        if isinstance(content, str):
            message["content"] = [{"type": "text", "text": content}]
    return messages


def maybe_parse_video_content(
    messages: List[Dict],
) -> tuple[int, list[Optional[float]], list[int], list[list[int]]]:
    """Collect per-video metadata for videos supplied as a list of frames.

    Only content parts with ``"type" == "video"`` whose ``"video"`` value is a
    ``list`` are counted; video parts without a list-valued ``"video"`` entry
    are skipped.

    Args:
        messages: Chat messages in list-of-parts form. Messages whose content
            is not a list are ignored.

    Returns:
        A 4-tuple ``(num_video, video_fps, video_total_num_frames,
        video_frames_indices)`` where the three lists have one entry per
        counted video:

        * ``video_fps`` -- the part's ``"fps"`` value, or ``None`` if absent
          (a critical log message is emitted in that case).
        * ``video_total_num_frames`` -- ``len`` of the frame list.
        * ``video_frames_indices`` -- ``[0, 1, ..., total_num_frames - 1]``.
    """
    num_video = 0
    video_fps: list[Optional[float]] = []
    video_total_num_frames: list[int] = []
    video_frames_indices: list[list[int]] = []
    for message in messages:
        content = message["content"]
        if not isinstance(content, list):
            continue
        for sub_content in content:
            if sub_content.get("type", "") != "video":
                continue
            # ``.get`` so that a video part carrying no "video" entry is skipped instead of raising.
            frames = sub_content.get("video")
            if not isinstance(frames, list):
                continue
            num_video += 1
            fps = sub_content.get("fps", None)
            if fps is None:
                log.critical(
                    f"fps is None for video {sub_content}. Better to set the fps explicitly", rank0_only=False
                )
            video_fps.append(fps)
            video_total_num_frames.append(len(frames))
            video_frames_indices.append(list(range(len(frames))))
    return num_video, video_fps, video_total_num_frames, video_frames_indices


class Nemotron3DenseVLProcessor:
    """Wrapper around ``AutoProcessor`` that adds helpers used by the dataloader.

    It loads the processor, caches a few token ids and video-processor
    geometry attributes, applies the chat template for a single sample, and
    builds the assistant-token loss mask.
    """

    def __init__(
        self,
        name="Qwen/Qwen3-VL-2B-Init",
        credentials: str = "./credentials/s3_training.secret",
        bucket: str = "bucket4",
        cache_dir: Optional[str] = None,
    ):
        """Load the processor and cache token ids / geometry attributes.

        Args:
            name: Local directory containing the processor files, or a model
                name passed to ``maybe_download_hf_model_from_s3``.
            credentials: Credentials argument forwarded to the downloader.
            bucket: Bucket argument forwarded to the downloader.
            cache_dir: Accepted for interface compatibility; not used here.
        """
        self.name = name
        if os.path.isdir(name):
            model_name_or_path_local = name
        else:
            model_name_or_path_local = maybe_download_hf_model_from_s3(
                name, credentials, bucket, include_model_weights=False
            )

        # NOTE: trust_remote_code=True lets the checkpoint supply its own processor code.
        self.processor = AutoProcessor.from_pretrained(model_name_or_path_local, trust_remote_code=True)
        log.info("Successfully loaded processor from local cache")

        tokenizer = self.processor.tokenizer
        if hasattr(self.processor, "image_token"):
            self.image_token_id = tokenizer.convert_tokens_to_ids(self.processor.image_token)
        else:
            self.image_token_id = None
        if hasattr(self.processor, "video_token"):
            self.video_token_id = tokenizer.convert_tokens_to_ids(self.processor.video_token)
        else:
            self.video_token_id = None
        self.eos_id = tokenizer.eos_token_id
        self.pad_id = tokenizer.pad_token_id
        # NOTE(review): looking up the empty string looks like a token literal that was lost
        # (a vision-end marker, judging by the attribute name) -- confirm the intended token.
        self.vision_end_id = tokenizer.convert_tokens_to_ids("")

        # Helper attributes for the dataloader video decoding function
        self.shortest_edge = self.processor.image_processor.size["shortest_edge"]
        # sqrt of ``shortest_edge``: presumably the value is a pixel-count budget, so this is a side length.
        self.min_height_width = int(np.sqrt(self.shortest_edge))
        self.patch_size = self.processor.video_processor.patch_size
        self.temporal_patch_size = self.processor.video_processor.temporal_patch_size
        self.merge_size = self.processor.video_processor.merge_size
        self.use_smart_resize = True
        # Fall back to EOS for padding when the tokenizer defines no pad token.
        if self.pad_id is None:
            self.pad_id = self.eos_id

    def apply_chat_template(
        self,
        messages,
        add_generation_prompt=False,
        return_tensors="pt",
        tokenize=True,
        **kwargs,
    ):
        """Tokenize one conversation with the wrapped processor's chat template.

        Images in ``messages`` are resized in place (via ``smart_resize``)
        before templating, and string contents are converted to the
        list-of-parts form in place.

        Args:
            messages: Chat messages for a single sample.
            add_generation_prompt: Forwarded to the wrapped processor.
            return_tensors: Must be ``"pt"``.
            tokenize: Must be ``True``.
            **kwargs: Accepted for interface compatibility but NOT forwarded;
                only the video arguments built below are passed on.

        Returns:
            The mapping returned by the wrapped processor (``return_dict=True``)
            with ``input_ids`` and ``attention_mask`` reduced from a batch of
            one to 1-D tensors of shape ``[N_token]``. Other entries are
            returned as produced by the processor.

        Raises:
            AssertionError: If ``tokenize`` is false, ``return_tensors`` is not
                ``"pt"``, an image part is not a ``PIL.Image.Image``, or more
                than one list-of-frames video is present.
        """
        # messages = [msg for msg in messages if msg.get("role") != "system"]
        assert tokenize, "tokenize must be True"
        assert return_tensors == "pt", "return_tensors must be pt"
        # Note: this tokenizer does not support "content": str, it always expect "content" entry to be a list of dicts
        messages = convert_string_content_to_list_content(messages)

        # Caller-supplied **kwargs are intentionally not forwarded; a separate dict
        # makes that explicit instead of rebinding the ``kwargs`` parameter.
        processor_kwargs = {}

        for message in messages:
            if not isinstance(message["content"], list):
                continue
            for sub_content in message["content"]:
                if sub_content.get("type", "") != "image":
                    continue
                image = sub_content["image"]
                max_pixels = sub_content.get("max_pixels", self.processor.image_processor.size["longest_edge"])
                min_pixels = sub_content.get("min_pixels", self.processor.image_processor.size["shortest_edge"])
                assert isinstance(image, Image.Image), (
                    "image must be a PIL.Image.Image for now; url strings and lists of images "
                    "for one content are not supported"
                )
                width, height = image.size
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=32,
                    min_pixels=min_pixels,
                    max_pixels=max_pixels,
                )
                sub_content["image"] = image.resize((resized_width, resized_height))

        num_video, video_fps, video_total_num_frames, video_frames_indices = maybe_parse_video_content(messages)
        if num_video > 0:
            # Frames arrive already decoded as a list, so frame sampling must be disabled;
            # otherwise transformers raises:
            #   ValueError: Sampling frames from a list of images is not supported! Set `do_sample_frames=False`.
            assert num_video == 1, "only support one video for now"
            processor_kwargs.update(
                {
                    "do_sample_frames": False,
                    "video_metadata": dict(
                        fps=video_fps[0],
                        total_num_frames=video_total_num_frames[0],
                        frames_indices=video_frames_indices[0],
                    ),
                }
            )

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            return_dict=True,
            return_tensors=return_tensors,
            **processor_kwargs,
        )

        # Convert batch features into single features
        # By default, the processor returns a batch of features, but we use processor in dataloader, so we need to convert it to single features
        inputs["input_ids"] = inputs["input_ids"][0]  # [N_token]
        inputs["attention_mask"] = inputs["attention_mask"][0]  # [N_token]
        return inputs

    def add_assistant_tokens_mask(self, tokens):
        """Build a loss mask that selects only assistant-generated tokens.

        Each ``<|im_start|>`` / ``<|im_end|>`` pair is inspected; when the
        tokens right after ``<|im_start|>`` spell the role ``assistant``, the
        span from after the role header through the closing ``<|im_end|>``
        (inclusive) is marked ``True``. All assistant turns in a multi-turn
        conversation are marked. A leading think-start token (and a think-end
        token immediately following it) is excluded from the mask.

        Args:
            tokens (Union[List[int], torch.Tensor]): Token ids, 1-D, or a 2-D
                ``[B, N_token]`` tensor which is processed row by row.

        Returns:
            Union[List[bool], torch.Tensor]: ``True`` for tokens the loss should
            be applied to. A tensor input yields a bool tensor of the same
            shape; a list input yields a list of ``bool``.

        Raises:
            AssertionError: If the numbers of ``<|im_start|>`` and
                ``<|im_end|>`` tokens differ, or the input is not 1-D / 2-D as
                described above.
        """
        if isinstance(tokens, torch.Tensor) and tokens.ndim == 2:
            if tokens.shape[0] == 0:
                # torch.stack cannot stack an empty list; an empty batch has an empty mask.
                return torch.zeros_like(tokens, dtype=torch.bool)
            mask = torch.stack([self.add_assistant_tokens_mask(row) for row in tokens])  # [B,N_token]
            assert mask.shape == tokens.shape
            return mask
        np_tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.array(tokens)
        assert np_tokens.ndim == 1

        # Constants defining bos, eos and fixed offsets.
        BOS_TOKEN = "<|im_start|>"
        EOS_TOKEN = "<|im_end|>"
        ROLE = "assistant"
        # Offsets: skip the bos + "assistant\n" (always 3 tokens) and include the eos (+1) for supervision
        START_OFFSET = 3
        END_OFFSET = 1

        # Retrieve token IDs for the markers and the role.
        tokenizer = self.processor.tokenizer
        bos_token_id = tokenizer.convert_tokens_to_ids(BOS_TOKEN)
        eos_token_id = tokenizer.convert_tokens_to_ids(EOS_TOKEN)
        role_id = tokenizer.convert_tokens_to_ids(ROLE)
        # The role may tokenize into several ids; keep the full sequence for that case.
        role_ids = tokenizer.encode(ROLE, add_special_tokens=False)
        num_role_ids = len(role_ids)
        # NOTE(review): both lookups use the empty string, which looks like lost token
        # literals (think-start / think-end markers by the variable names) -- confirm.
        think_start_id = tokenizer.convert_tokens_to_ids("")
        think_end_id = tokenizer.convert_tokens_to_ids("")

        # Locate all positions where the start and end markers appear.
        start_indices = np.where(np_tokens == bos_token_id)[0]
        end_indices = np.where(np_tokens == eos_token_id)[0]

        # Initialize the mask with False values.
        masks = np.zeros_like(np_tokens, dtype=bool)
        assert len(start_indices) == len(end_indices)

        num_tokens = np_tokens.shape[0]
        for start, end in zip(start_indices, end_indices):
            header_pos = start + 1
            if header_pos >= num_tokens:
                # bos is the very last token: there is no role to inspect.
                continue
            if np_tokens[header_pos] == role_id:
                content_start = start + START_OFFSET
            elif num_role_ids > 0 and np.array_equal(np_tokens[header_pos : header_pos + num_role_ids], role_ids):
                # array_equal is False (not an error / broadcast) when the slice is cut short by the sequence end.
                content_start = start + START_OFFSET + num_role_ids - 1
            else:
                continue

            # Mask tokens from after the assistant header up to and including the end marker.
            masks[content_start : end + END_OFFSET] = True

            # Do not supervise a leading think-start token, nor a think-end token right after it.
            if content_start < num_tokens and np_tokens[content_start] == think_start_id:
                masks[content_start] = False
                next_pos = content_start + 1
                if next_pos < num_tokens and np_tokens[next_pos] == think_end_id:
                    masks[next_pos] = False

        assert masks.shape == np_tokens.shape
        if isinstance(tokens, torch.Tensor):
            return torch.from_numpy(masks)
        return masks.tolist()

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped processor's ``encode``."""
        return self.processor.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the wrapped processor's ``decode``."""
        return self.processor.decode(*args, **kwargs)
+# SPDX-License-Identifier: OpenMDW-1.1 + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from transformers.models.auto.processing_auto import AutoProcessor +from transformers.processing_utils import VideosKwargs +from transformers.video_utils import VideoMetadata + +from cosmos_framework.utils import log +from cosmos_framework.utils.vlm.pretrained_models_downloader import maybe_download_hf_model_from_s3 + +nemotron_chat_template = """ +{%- set ns = namespace(enable_thinking=false, has_sys_prompt=false, non_tool_system_content='', has_video=false, explicit_think_requested=false) -%} +{%- set msg = namespace(content='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.has_sys_prompt = true -%} + {# Extract system content without tool flags #} + {%- if message['content'] is string -%} + {%- set ns.non_tool_system_content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set ns.non_tool_system_content = '' -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content + content['text'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '') -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content.strip() -%} + {%- endif -%} + {%- endif -%} + {# Check for video content in all messages #} + {%- if message['content'] is not string -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'video' or content['type'] == 'video_url' -%} + {%- set ns.has_video = true -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- if message['content'] is string -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in 
message['content'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in message['content'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- else -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in content['text'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in content['text'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endfor -%} + +{{- bos_token -}} +{%- if messages[0]['role'] != 'system' -%} + {{- 'System\n' -}} +{%- else -%} + {{- 'System\n' + ns.non_tool_system_content }} +{%- endif -%} + +{%- if tools -%} + {%- if ns.non_tool_system_content != '' -%} + {{- '\n\n' -}} + {%- endif -%} + {{- 'You can use the following tools to assist the user if required:\n' -}} + {{- '[' -}} + {%- for tool in tools -%} + {{- (tool.function if tool.function is defined else tool) | tojson -}} + {{- ', ' if not loop.last else '' -}} + {%- endfor -%} + {{- ']\n\n' -}} + + {{- 'If you decide to call any tool(s), use the following format:\n' -}} + {{- '[{"name": "tool_name1", "arguments": "tool_args1"}, ' -}} + {{- '{"name": "tool_name2", "arguments": "tool_args2"}]\n\n' -}} + + {{- 'The user will execute tool-calls and return responses from tool(s) in this format:\n' -}} + {{- '[{"response": "tool_response1"}, ' -}} + {{- '{"response": "tool_response2"}]\n\n' -}} + + {{- 'Based on the tool responses, you can call additional tools if needed, ' -}} + {{- 'correct tool calls if any errors are found, or just respond to the user.' 
-}} +{%- endif -%} +{{- '\n' -}} + +{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%} + +{# Prevent no user or assistant message #} +{%- if messages|length == 0 -%} + {%- set messages = [{'role': 'user', 'content': ''}] -%} +{%- endif -%} + +{%- for message in messages %} + {%- if message['content'] is string -%} + {%- set msg.content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set msg.content = '' -%} + {%- set mm_content = '' -%} + {%- set counters = namespace(images=0, videos=0) -%} + + {%- for content in message['content'] -%} + {%- if content['type'] == 'image' -%} + {%- set counters.images = counters.images + 1 -%} + {%- elif content['type'] == 'video' -%} + {%- set counters.videos = counters.videos + 1 -%} + {%- elif content['type'] == 'text' -%} + {%- set msg.content = msg.content + content['text'] -%} + {%- endif -%} + {%- endfor -%} + {%- if '' in msg.content -%} + {%- set counters.images = 0 -%} + {%- endif -%} + {%- if '