From ca60be5543df4001cdcb9b2f2fb2fb6b7c7b29ba Mon Sep 17 00:00:00 2001 From: yangyangt Date: Tue, 2 Jun 2026 07:46:45 -0700 Subject: [PATCH 01/11] Sync release drift from internal imaginaire4 Apply the cosmos-framework-release pipeline output over the public tree: - rewrite import paths (imaginaire.* / projects.cosmos3.* -> cosmos_framework.*) - rewrite internal module/file references in comments, docstrings, and path-form string literals via the file mapping - redact sensitive infra identifiers (S3 buckets, NFS/user paths, internal cluster config classes) - normalize license headers to OpenMDW (stamp missing, swap Apache, bump year) - exclude the one_logger / training_telemetry READMEs (broken relative links) - strip all TODO comments and any name-attributed marker comments (e.g. `# NOTE (alice): ...`); keep un-attributed NOTE/HACK/etc. Co-Authored-By: Claude Opus 4.7 --- .../auxiliary/guardrail/common/presets.py | 4 +- .../callbacks/compile_tokenizer.py | 15 +- cosmos_framework/callbacks/data_stats.py | 1 - .../callbacks/dataloader_state.py | 126 +------- .../callbacks/every_n_draw_sample.py | 4 - cosmos_framework/callbacks/grad_clip.py | 2 +- cosmos_framework/callbacks/hf_export.py | 4 +- cosmos_framework/callbacks/mfu.py | 1 - cosmos_framework/callbacks/wandb_log_eval.py | 1 - cosmos_framework/checkpoint/s3_filesystem.py | 2 +- .../configs/base/base_config_test.py | 2 +- .../configs/base/defaults/callbacks.py | 6 +- .../configs/base/defaults/cluster.py | 22 +- .../configs/base/defaults/compile.py | 2 +- .../configs/base/defaults/tokenizer.py | 43 ++- .../configs/base/defaults/unittest.py | 4 + cosmos_framework/configs/base/defaults/vlm.py | 204 ++++++------- cosmos_framework/configs/base/vlm/config.py | 13 +- .../configs/base/vlm/defaults/callbacks.py | 7 +- .../base/vlm/defaults/policy_config.py | 7 +- .../data/vfm/action/action_normalization.py | 4 +- .../data/vfm/action/domain_utils.py | 28 ++ .../data/vfm/action/json_formatter.py | 2 +- .../data/vfm/action/transforms.py | 3 +- .../data/vfm/augmentor_provider.py | 9 +- .../vfm/augmentors/idle_frames_text_info.py | 2 +- .../augmentors/interleaved_video_parsing.py | 2 +- .../data/vfm/augmentors/pkl_to_media.py | 2 - .../augmentors/transfer_control_transform.py | 2 +- .../data/vfm/augmentors/video_parsing.py | 63 +++- .../vfm/augmentors/vlm/nvlm_data_unify.py | 2 +- .../nvlm_sample_loaders_and_part_filters.py | 278 ++++++------------ .../data/vfm/augmentors/vlm/prompt_format.py | 4 +- .../data/vfm/augmentors/vlm/timestamp.py | 4 +- .../vlm/timestamp_with_subject_tracking.py | 4 +- .../vlm/timestamp_without_augment_message.py | 4 +- .../vlm/timestamp_without_end_time.py | 2 - .../data/vfm/augmentors/vlm/tokenize_data.py | 4 - cosmos_framework/data/vfm/joint_dataloader.py | 3 +- .../data/vfm/packing_iterable_dataset.py | 26 +- .../data/vfm/processors/__init__.py | 42 +-- .../vfm/processors/nemotronvl_processor.py | 2 - cosmos_framework/data/vfm/sequence_packing.py | 21 +- cosmos_framework/data/vfm/sound_data_utils.py | 6 +- cosmos_framework/data/vfm/utils.py | 2 - cosmos_framework/model/attention/backends.py | 11 + cosmos_framework/model/attention/checks.py | 6 +- .../model/attention/cudnn/__init__.py | 6 + .../model/attention/cudnn/checks.py | 1 - .../model/attention/cudnn/cudnn_forward.py | 18 +- .../model/attention/cudnn/functions.py | 5 +- .../model/attention/cudnn/meta.py | 4 +- .../model/attention/flash2/__init__.py | 17 +- .../model/attention/flash2/checks.py | 17 +- .../model/attention/flash2/functions.py | 2 +- .../model/attention/flash3/functions.py | 4 +- cosmos_framework/model/attention/frontend.py | 10 +- .../model/attention/natten/checks.py | 2 +- cosmos_framework/model/attention/varlen.py | 6 +- .../model/tokenizer/models/__init__.py | 3 + .../model/tokenizer/models/dense_backends.py | 53 +++- .../model/tokenizer/models/dense_runtime.py | 235 +++++++++++++-- .../models/modules/attention/full_attn.py | 10 +- .../models/modules/quantizers/residual_vq.py | 149 +++++++--- .../tokenizer/models/sparse_autoencoder.py | 2 +- .../model/tokenizer/models/text_decoder.py | 4 +- .../model/tokenizer/models/utils.py | 4 +- .../model/vfm/diffusion/samplers/edm.py | 2 +- .../diffusion/samplers/fm_solvers_unipc.py | 2 +- cosmos_framework/model/vfm/hf_model.py | 6 +- cosmos_framework/model/vfm/mot/attention.py | 24 +- .../model/vfm/mot/attention_test.py | 11 +- .../model/vfm/mot/context_parallel_utils.py | 2 +- .../model/vfm/mot/cosmos3_vfm_network.py | 2 +- .../mot/cosmos3_vfm_qwen3_vl_network_test.py | 7 +- .../model/vfm/mot/dot_product_attention.py | 5 +- .../model/vfm/mot/modeling_utils.py | 1 - cosmos_framework/model/vfm/mot/unified_mot.py | 118 ++++---- cosmos_framework/model/vfm/omni_mot_model.py | 184 ++---------- cosmos_framework/model/vfm/parallelize_vlm.py | 4 +- .../model/vfm/tokenizers/audio/avae.py | 9 +- .../model/vfm/tokenizers/dc_ae/__init__.py | 15 +- .../dc_ae/cosmos_ae_4x32x32_compile_test.py | 6 +- .../vfm/tokenizers/dc_ae/dc_ae_4x32x32.py | 65 +++- .../tokenizers/dc_ae/dc_ae_4x32x32_test.py | 6 +- .../model/vfm/tokenizers/dc_ae/dc_ae_v.py | 22 +- .../model/vfm/tokenizers/dc_ae/dc_ae_v_ops.py | 16 +- .../dc_ae/dc_ae_v_triton_rms_norm.py | 15 +- .../model/vfm/tokenizers/flux_vae_8x8.py | 2 + .../model/vfm/tokenizers/interface.py | 16 +- .../vfm/tokenizers/tokenization_qwen2.py | 4 +- .../vfm/tokenizers/uniae/noncausal_4x16x16.py | 220 +++++++------- .../uniae/noncausal_4x16x16_test.py | 211 ++++++++----- .../model/vfm/tokenizers/wan2pt1_vae_4x8x8.py | 3 + .../vfm/tokenizers/wan2pt2_vae_4x16x16.py | 9 +- .../model/vfm/upsampler/prompts.py | 2 +- .../model/vfm/utils/safetensors_loader.py | 4 +- .../vfm/utils/safetensors_loader_test.py | 4 +- .../vlm/qwen3_vl/configuration_qwen3_vl.py | 2 +- .../model/vfm/vlm/qwen3_vl/qwen3_vl.py | 17 +- .../model/vfm/vlm/qwen3_vl/utils.py | 12 + .../vlm/qwen3_vl/video_processing_qwen3_vl.py | 2 +- cosmos_framework/model/vfm/vlm_model.py | 14 +- cosmos_framework/scripts/train.py | 85 ++---- cosmos_framework/tools/flops/qwen3_vl.py | 2 +- cosmos_framework/tools/visualize/video.py | 17 +- cosmos_framework/trainer/__init__.py | 11 +- cosmos_framework/utils/callback.py | 107 ++++++- cosmos_framework/utils/checkpoint_db.py | 33 +-- cosmos_framework/utils/checkpointer.py | 3 - cosmos_framework/utils/config.py | 67 +---- cosmos_framework/utils/context_managers.py | 19 ++ cosmos_framework/utils/device.py | 4 - cosmos_framework/utils/distributed.py | 4 +- .../utils/easy_io/backends/local_backend.py | 1 - cosmos_framework/utils/easy_io/easy_io.py | 1 - .../utils/easy_io/easy_io_test.py | 6 +- cosmos_framework/utils/ema.py | 2 +- .../utils/env_parsers/cred_env_parser.py | 3 +- cosmos_framework/utils/flags.py | 7 +- .../utils/lazy_config/__init__.py | 2 +- cosmos_framework/utils/lazy_config/lazy.py | 26 +- .../utils/lazy_config/lazy_call.py | 12 + cosmos_framework/utils/misc.py | 8 +- cosmos_framework/utils/object_store.py | 54 +--- .../one_logger/one_logger_override_utils.py | 2 +- .../utils/one_logger/one_logger_utils.py | 10 +- cosmos_framework/utils/serialization.py | 32 +- .../utils/training_telemetry/__init__.py | 15 +- .../utils/training_telemetry/callback.py | 4 +- .../utils/training_telemetry/utils.py | 12 +- .../utils/vfm/hf_attention_cosmos.py | 10 +- cosmos_framework/utils/vfm/lora.py | 14 +- cosmos_framework/utils/vfm/model_loader.py | 2 +- cosmos_framework/utils/vfm/monkey_patch.py | 4 +- cosmos_framework/utils/vfm/optimizer.py | 10 +- cosmos_framework/utils/vfm/parallelism.py | 4 +- .../utils/vfm/vlm/flop_calculator.py | 1 - .../vfm/vlm/pretrained_models_downloader.py | 7 +- .../utils/vlm/compute_flops_qwen3vl.py | 4 +- .../utils/vlm/dcp_checkpointer.py | 2 - cosmos_framework/utils/vlm/distributed.py | 4 +- cosmos_framework/utils/vlm/optimizer.py | 1 - .../utils/vlm/pretrained_models_downloader.py | 2 +- 144 files changed, 1724 insertions(+), 1494 deletions(-) diff --git a/cosmos_framework/auxiliary/guardrail/common/presets.py b/cosmos_framework/auxiliary/guardrail/common/presets.py index d320b5e..20bba86 100644 --- a/cosmos_framework/auxiliary/guardrail/common/presets.py +++ b/cosmos_framework/auxiliary/guardrail/common/presets.py @@ -26,9 +26,7 @@ def create_text_guardrail_runner(offload_model_to_cpu: bool = False) -> Guardrai def create_video_guardrail_runner(offload_model_to_cpu: bool = False) -> GuardrailRunner: """Create the video guardrail runner.""" return GuardrailRunner( - safety_models=[ - # VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu), # Too many false positives - ], + safety_models=[VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu)], postprocessors=[RetinaFaceFilter(offload_model_to_cpu=offload_model_to_cpu)], ) diff --git a/cosmos_framework/callbacks/compile_tokenizer.py b/cosmos_framework/callbacks/compile_tokenizer.py index 84efa16..3ee15d2 100644 --- a/cosmos_framework/callbacks/compile_tokenizer.py +++ b/cosmos_framework/callbacks/compile_tokenizer.py @@ -4,7 +4,7 @@ """Training callback that defers AOT compilation of the VAE tokenizer. The actual compilation logic lives in -:meth:`~projects.cosmos3.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. +:meth:`~cosmos_framework.model.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`. This module provides a :class:`CompileTokenizer` callback that invokes it at the right point during training (after ``compile_after_iterations`` steps, to avoid NCCL timeouts during CUDA/cuDNN warm-up). @@ -21,6 +21,7 @@ """ from collections.abc import Sequence +from typing import Literal import torch @@ -43,6 +44,10 @@ def __init__( enabled: bool = False, compile_after_iterations: int = 3, warmup_resolutions: Sequence[str] | None = None, + backend: Literal["cudagraphs", "inductor"] = "inductor", + mode: Literal["reduce-overhead", "max-autotune"] | None = "reduce-overhead", + fullgraph: bool = False, + dynamic: bool = False, ): """ Args: @@ -60,6 +65,10 @@ def __init__( self.compile_after_iterations: int = compile_after_iterations self.skip_counter: int = 0 self.warmup_resolutions: Sequence[str] | None = warmup_resolutions + self.backend: Literal["cudagraphs", "inductor"] = backend + self.mode: Literal["reduce-overhead", "max-autotune"] | None = mode + self.fullgraph: bool = fullgraph + self.dynamic: bool = dynamic if self.enabled: if self.warmup_resolutions is None: @@ -101,6 +110,10 @@ def on_training_step_start( tokenizer.compile_encode( self.warmup_resolutions, output_dir=self.config.job.path_local, + backend=self.backend, + mode=self.mode, + fullgraph=self.fullgraph, + dynamic=self.dynamic, ) self.skip_counter += 1 diff --git a/cosmos_framework/callbacks/data_stats.py b/cosmos_framework/callbacks/data_stats.py index 20e3161..2914981 100644 --- a/cosmos_framework/callbacks/data_stats.py +++ b/cosmos_framework/callbacks/data_stats.py @@ -51,7 +51,6 @@ def on_training_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/callbacks/dataloader_state.py b/cosmos_framework/callbacks/dataloader_state.py index fee45b9..bee9e43 100644 --- a/cosmos_framework/callbacks/dataloader_state.py +++ b/cosmos_framework/callbacks/dataloader_state.py @@ -26,18 +26,14 @@ class DataLoaderStateCallback(Callback): def __init__( self, distributor_type: str | None = None, - name: str = "", ) -> None: super().__init__() self.distributor_type = distributor_type - self.name = name self.config: Any = None self.state: dict[int, NoReplaceShardlistState] = {} self.verbose = True def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: - if "sample_worker_id" not in data_batch: - return # batch has no position metadata (shuffle=False or iterable data_source) worker_ids = data_batch["sample_worker_id"].tolist() # [B] epochs = data_batch["sample_epoch"].tolist() # [B] indices = data_batch["sample_index"].tolist() # [B] @@ -50,8 +46,6 @@ def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None: ): self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - _ACTIVE_DISTRIBUTOR_TYPES = ("no_replace", "data_packer") - def on_training_step_batch_end( self, model: ImaginaireModel, @@ -60,7 +54,7 @@ def on_training_step_batch_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": self._update_state_from_batch(data_batch) def on_training_step_end( @@ -71,7 +65,7 @@ def on_training_step_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type == "no_replace": if self.verbose: if iteration % self.config.trainer.logging_iter == 0: msg = "\n" @@ -80,10 +74,10 @@ def on_training_step_end( log.info(msg) def has_checkpoint_state(self) -> bool: - return self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES + return self.distributor_type == "no_replace" def state_dict(self) -> dict[int, dict[str, int]]: - if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return {} state_dict: dict[int, dict[str, int]] = {} @@ -96,7 +90,7 @@ def state_dict(self) -> dict[int, dict[str, int]]: return state_dict def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: - if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES: + if self.distributor_type != "no_replace": return if not state_dict: @@ -104,114 +98,10 @@ def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None: return self.state = {} - # Build env var prefix. For data_packer, namespacing avoids conflicts - # when multiple DataPackerDataLoader instances share the same process - # (e.g. inside JointDataPackerDataLoader). name="" → original format. - _dp_pfx = f"DP_STATE_{self.name}_" if self.name else "DP_STATE_" for worker_id, per_worker_state in state_dict.items(): epoch = per_worker_state["epoch"] index = per_worker_state["index"] self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index) - if self.distributor_type == "data_packer": - os.environ[f"{_dp_pfx}WORKER_{worker_id}_EPOCH"] = str(epoch) - os.environ[f"{_dp_pfx}WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded data_packer dataloader state for worker {worker_id}: epoch={epoch}, index={index}") - else: - os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) - os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) - log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") - - -class JointDataLoaderStateCallback(Callback): - """Checkpoint/resume state for ``JointDataPackerDataLoader``. - - Manages two levels of state in a single DCP checkpoint entry - (``checkpoint_component = "dataloader"``): - - 1. **Outer** ``global_id`` — the number of batches the outer loader has - yielded. Restored via ``outer_loader.set_start_iteration(global_id)`` - so the deterministic dataset-selection sequence resumes from the correct - step. - - 2. **Inner** per-dataset, per-worker ``(epoch, index)`` — one - ``DataLoaderStateCallback`` per inner loader, keyed by the dataset name. - Each inner callback sets namespaced env vars on ``load_state_dict`` so - workers fast-forward to the saved sample position. - - Usage in experiment configs:: - - joint_loader = JointDataPackerDataLoader(dataloaders={...}, seed=42) - exp["dataloader_train"] = joint_loader - exp["trainer"]["callbacks"]["dataloader_state"] = JointDataLoaderStateCallback( - outer_loader=joint_loader, - distributor_type="data_packer", - ) - - The ``checkpoint_component = "dataloader"`` class attribute ensures the DCP - checkpointer's ``_DataloaderWrapper`` discovers exactly this callback (it - picks the first matching callback). Do **not** also register standalone - ``DataLoaderStateCallback`` instances for the inner loaders — this class - already handles them all. - """ - - checkpoint_component: str = "dataloader" - - def __init__( - self, - outer_loader: Any, - distributor_type: str = "data_packer", - ) -> None: - super().__init__() - self._outer = outer_loader - self._inner: dict[str, DataLoaderStateCallback] = { - name: DataLoaderStateCallback(distributor_type=distributor_type, name=name) - for name in outer_loader._names - } - self.config: Any = None - - def _update_state_from_batch(self, batch: dict) -> None: - name = batch.get("dataset_name") - if name in self._inner: - self._inner[name]._update_state_from_batch(batch) - - def on_training_step_batch_end( - self, - model: Any, - data_batch: dict, - output_batch: dict, - loss: Any, - iteration: int = 0, - ) -> None: - self._update_state_from_batch(data_batch) - - def on_training_step_end( - self, - model: Any, - data_batch: dict, - output_batch: dict, - loss: Any, - iteration: int = 0, - ) -> None: - if self.config and iteration % self.config.trainer.logging_iter == 0: - msg = f"\nJointDataPackerDataLoader global_id={self._outer._global_id}\n" - for name, cb in self._inner.items(): - for wid, state in cb.state.items(): - msg += f" [{name}] worker {wid}: epoch={state.epoch}, index={state.index}\n" - log.info(msg) - - def has_checkpoint_state(self) -> bool: - return True - - def state_dict(self) -> dict: - return { - "global_id": self._outer._global_id, - **{name: cb.state_dict() for name, cb in self._inner.items()}, - } - - def load_state_dict(self, state: dict) -> None: - global_id = state.get("global_id", 0) - self._outer.set_start_iteration(global_id) - log.info(f"JointDataLoaderStateCallback: resumed outer global_id={global_id}") - for name, cb in self._inner.items(): - if name in state: - cb.load_state_dict(state[name]) + os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch) + os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index) + log.info(f"Loaded no replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}") diff --git a/cosmos_framework/callbacks/every_n_draw_sample.py b/cosmos_framework/callbacks/every_n_draw_sample.py index baf1ffc..9aa96fa 100644 --- a/cosmos_framework/callbacks/every_n_draw_sample.py +++ b/cosmos_framework/callbacks/every_n_draw_sample.py @@ -154,8 +154,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): tag = "ema" if self.is_ema else "reg" log.debug("starting data and condition model", rank0_only=False) - - data_clean = model.get_data_and_condition(data_batch) raw_data = data_clean.raw_state_vision x0 = data_clean.x0_tokens_vision @@ -185,7 +183,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration): log.debug(f"done denoising {sigma}", rank0_only=False) mse_loss = distributed.dist_reduce_tensor(F.mse_loss(sample, x0)) mse_loss_list.append(mse_loss) - if hasattr(model, "decode"): sample = model.decode(sample) to_show.append(sample.float().cpu()) @@ -316,7 +313,6 @@ def sample(self, trainer, model, data_batch, output_batch, loss, iteration): for sample_idx in range(data_clean.batch_size): n_vis = num_items[sample_idx] # First item(s) are condition, last item is generation target - # but we need to support multiple conditions per sample in the future. Current code # can handle this without throwing an error. condition_images.append(raw_data[vis_offset]) # source image (1, C, 1, H, W) diff --git a/cosmos_framework/callbacks/grad_clip.py b/cosmos_framework/callbacks/grad_clip.py index 151c1bc..f3cb4fa 100644 --- a/cosmos_framework/callbacks/grad_clip.py +++ b/cosmos_framework/callbacks/grad_clip.py @@ -132,7 +132,7 @@ def _clip_grad( # `torch.distributed._tensor.ops.math_ops._NormPartial`. # We can simply reduce the DTensor to get the total norm in this # tensor's process group and then convert it to a local tensor. - + # NOTE: It has two purposes: # 1. to make sure the total norm is computed correctly when PP is used (see below) # 2. to return a reduced mesh_norm tensor whose .item() would return the correct value if isinstance(mesh_norm, DTensor): diff --git a/cosmos_framework/callbacks/hf_export.py b/cosmos_framework/callbacks/hf_export.py index 8bcba5a..07c324f 100644 --- a/cosmos_framework/callbacks/hf_export.py +++ b/cosmos_framework/callbacks/hf_export.py @@ -137,11 +137,11 @@ def on_save_checkpoint(self, model: Any, state_dict: dict[str, Any]) -> None: if not isinstance(model, VLMModel): # The legacy vlm/train.py path passes model_parts: list[nn.Module] (raw HF # models without the VLMModel attribute structure). HF export requires the - # VLMModel wrapper, which is only available via the unified cosmos_framework/scripts/train.py path. + # VLMModel wrapper, which is only available via the unified scripts/train.py path. if isinstance(model, list): log.warning( "[HFExportCallback] Received model_parts (list) instead of VLMModel. " - "HF export requires the unified training path (cosmos_framework/scripts/train.py). Skipping." + "HF export requires the unified training path (scripts/train.py). Skipping." ) else: log.warning( diff --git a/cosmos_framework/callbacks/mfu.py b/cosmos_framework/callbacks/mfu.py index 3035437..4a6c792 100644 --- a/cosmos_framework/callbacks/mfu.py +++ b/cosmos_framework/callbacks/mfu.py @@ -138,7 +138,6 @@ def _ensure_initialised(self, model: ImaginaireModel) -> None: ac_cfg = getattr(model_cfg, "activation_checkpointing", None) ac_mode = getattr(ac_cfg, "mode", "none") - # Some activations don't need to be recomputed under selective AC, so # we need to remove them from the FLOP computation. self._use_activation_checkpointing = ac_mode != "none" diff --git a/cosmos_framework/callbacks/wandb_log_eval.py b/cosmos_framework/callbacks/wandb_log_eval.py index ac6911f..abea93f 100644 --- a/cosmos_framework/callbacks/wandb_log_eval.py +++ b/cosmos_framework/callbacks/wandb_log_eval.py @@ -88,7 +88,6 @@ def on_validation_step_end( # Handle case where dataset_name gets batched into a list if isinstance(dataset_name, list): - assert len(dataset_name) == 1, "dataset_name should be a list of 1" dataset_name = dataset_name[0] diff --git a/cosmos_framework/checkpoint/s3_filesystem.py b/cosmos_framework/checkpoint/s3_filesystem.py index e47219e..8975ce1 100644 --- a/cosmos_framework/checkpoint/s3_filesystem.py +++ b/cosmos_framework/checkpoint/s3_filesystem.py @@ -285,7 +285,7 @@ def __init__( """ super().__init__( path=path, - sync_files=False, + sync_files=False, # FIXME: setting this to True makes the run to fail (L#333: `os.fsync(stream.fileno())`) **kwargs, ) self.fs = S3FileSystem(credential_path, enable_gcs_patch_in_boto3=enable_gcs_patch_in_boto3) # type: ignore diff --git a/cosmos_framework/configs/base/base_config_test.py b/cosmos_framework/configs/base/base_config_test.py index 093e17b..2ea4319 100644 --- a/cosmos_framework/configs/base/base_config_test.py +++ b/cosmos_framework/configs/base/base_config_test.py @@ -29,7 +29,7 @@ def test_config_init_experiment_mot(experiment_name): Parameterized test to verify config initialization for multiple experiments. PYTHONPATH=. torchrun --nproc_per_node=8 -m pytest -s cosmos_framework/configs/base/config_test_mot.py --L1 """ - config_file = "configs/base/config.py" + config_file = "cosmos_framework/configs/base/config.py" config_module = get_config_module(config_file) config = importlib.import_module(config_module).make_config() config = override( diff --git a/cosmos_framework/configs/base/defaults/callbacks.py b/cosmos_framework/configs/base/defaults/callbacks.py index 46c85ed..602646d 100644 --- a/cosmos_framework/configs/base/defaults/callbacks.py +++ b/cosmos_framework/configs/base/defaults/callbacks.py @@ -10,7 +10,7 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer - +from cosmos_framework.callbacks.dataloading_monitor import DetailedDataLoadingSpeedMonitor from cosmos_framework.callbacks.device_monitor import DeviceMonitor from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap @@ -49,6 +49,10 @@ param_count=L(ParamCount)( # use model save_s3="${upload_reproducible_setup}", ), + dataloader_speed=L(DetailedDataLoadingSpeedMonitor)( + every_n=100, + save_s3="${upload_reproducible_setup}", + ), wandb_val=L(WandBCallbackEval)( save_s3="${upload_reproducible_setup}", ), diff --git a/cosmos_framework/configs/base/defaults/cluster.py b/cosmos_framework/configs/base/defaults/cluster.py index 23b49dd..46450b9 100644 --- a/cosmos_framework/configs/base/defaults/cluster.py +++ b/cosmos_framework/configs/base/defaults/cluster.py @@ -23,14 +23,24 @@ class ClusterConfig: DefaultClusterConfig: ClusterConfig = ClusterConfig( object_store_bucket_data="", - object_store_bucket_checkpoint="bucket-checkpoint", - object_store_bucket_pretrained="bucket-pretrained", - object_store_credential_data="credentials/data.secret", - object_store_credential_checkpoint="credentials/checkpoint.secret", - object_store_credential_pretrained="credentials/pretrained.secret", + object_store_bucket_checkpoint="bucket4", + object_store_bucket_pretrained="bucket4", + object_store_credential_data="credentials/s3_training.secret", + object_store_credential_checkpoint="credentials/s3_checkpoint.secret", + object_store_credential_pretrained="credentials/s3_checkpoint.secret", +) + +DefaultClusterConfig: ClusterConfig = ClusterConfig( + object_store_bucket_data="", + object_store_bucket_checkpoint="bucket1", + object_store_bucket_pretrained="bucket0", + object_store_credential_data="credentials/gcp_checkpoint.secret", + object_store_credential_checkpoint="credentials/gcp_training.secret", + object_store_credential_pretrained="credentials/gcp_training.secret", ) def register_cluster(): cs = ConfigStore.instance() - cs.store(group="cluster", package="job.cluster", name="default", node=DefaultClusterConfig) + cs.store(group="cluster", package="job.cluster", name="aws_iad_h100", node=DefaultClusterConfig) + cs.store(group="cluster", package="job.cluster", name="gcp_iad_gb200", node=DefaultClusterConfig) diff --git a/cosmos_framework/configs/base/defaults/compile.py b/cosmos_framework/configs/base/defaults/compile.py index b0e1c88..3d5ebf7 100644 --- a/cosmos_framework/configs/base/defaults/compile.py +++ b/cosmos_framework/configs/base/defaults/compile.py @@ -24,7 +24,7 @@ class CompileConfig: # (maps to ``torch.compile(dynamic=...)``). Defaults to True for training, # which sees varying shapes across batches (sequence length, CP sharding, ...); # specializing would recompile continuously. See ParallelismOverrides in - # cosmos_framework/inference/common/args.py for the inference-side rationale + # packages/cosmos3/cosmos3/common/args.py for the inference-side rationale # (where dynamic=False is preferred for stable AR shapes). compile_dynamic: bool = True diff --git a/cosmos_framework/configs/base/defaults/tokenizer.py b/cosmos_framework/configs/base/defaults/tokenizer.py index 55cb01c..237dba6 100644 --- a/cosmos_framework/configs/base/defaults/tokenizer.py +++ b/cosmos_framework/configs/base/defaults/tokenizer.py @@ -17,10 +17,12 @@ PRETRAINED_TOKENIZER_FLUX_VAE_PTH = "pretrained/tokenizers/image/flux/ae.safetensors" # UniAE checkpoint paths -PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH = "pretrained/tokenizers/video/cosmos/uniae4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1.pt" +PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T32TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_QWEN0P6B_VAE_PTH = ( + "s3://bucket1/uniae/tok_experiments/" + "uniae_s3_prod32_ditval_video_b1_50k_r1/checkpoints/iter_000050000.pt" +) # DCAE checkpoint paths -PRETRAINED_TOKENIZER_DCAE_PTH = "pretrained/tokenizers/video/cosmos/dc-ae-v-1.0-f32t4c64-cosmos-encoder-causal-decoder-chunk-causal-4-frame-120-pad-7-no-gan.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" @@ -44,6 +46,7 @@ chunk_duration=1, spatial_compression_factor=8, temporal_compression_factor=1, + causal=True, ) Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)( @@ -53,6 +56,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT1_VAE_PTH, spatial_compression_factor=8, temporal_compression_factor=4, + causal=True, ) Wan2pt2VAEConfig: LazyDict = L(Wan2pt2VAEInterface)( @@ -61,14 +65,7 @@ vae_path=PRETRAINED_TOKENIZER_WAN2PT2_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, -) - -DCAE4x32x32Config: LazyDict = L(DCAE4x32x32Interface)( - bucket_name=PLACEHOLDER, - object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_DCAE_PTH, - spatial_compression_factor=32, - temporal_compression_factor=4, + causal=True, ) DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -80,6 +77,7 @@ model_name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", spatial_compression_factor=32, temporal_compression_factor=4, + causal=True, ) DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( @@ -104,12 +102,17 @@ temporal_compression_factor=4, ) -UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config: LazyDict = L(UniAEVAEInterface)( + +UniAE4x16x16C48T32to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Qwen0p6bVAEConfig: LazyDict = L( + UniAEVAEInterface +)( bucket_name=PLACEHOLDER, object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T8TO24_64TO512P_FPS_ALL_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_BEST_S1_VAE_PTH, + vae_path=PRETRAINED_TOKENIZER_UNIAE_4X16X16_C48_T32TO160_MIXP_FPS_MIX_ENCODER_NONCAUSAL_DECODER_NONCAUSAL_NOGAN_S3_QWEN0P6B_VAE_PTH, spatial_compression_factor=16, temporal_compression_factor=4, + pixel_trim=True, + causal=False, ) # ============================================================================= @@ -173,8 +176,8 @@ def register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="uniae_4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1_tokenizer", - node=UniAE4x16x16C48T8to24_64to512pFpsAllEncoderNoncausalDecoderNoncausalNoganBestS1Config, + name="uniae_4x16x16_c48_t32to160_mixp_fps_mix_encoder_noncausal_decoder_noncausal_nogan_s3_qwen0p6b_tokenizer", + node=UniAE4x16x16C48T32to160MixpFpsMixEncoderNoncausalDecoderNoncausalNoganS3Qwen0p6bVAEConfig, ) # Flux tokenizer cs.store(group="tokenizer", package="model.config.tokenizer", name="flux_tokenizer", node=FluxVAEConfig) @@ -182,25 +185,19 @@ def register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_tokenizer", - node=DCAE4x32x32Config, - ) - cs.store( - group="tokenizer", - package="model.config.tokenizer", - name="dc_ae_4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C64T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dc_ae_4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", + name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", node=DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, ) diff --git a/cosmos_framework/configs/base/defaults/unittest.py b/cosmos_framework/configs/base/defaults/unittest.py index 5af5bb2..7e1b8c1 100644 --- a/cosmos_framework/configs/base/defaults/unittest.py +++ b/cosmos_framework/configs/base/defaults/unittest.py @@ -2,8 +2,12 @@ # SPDX-License-Identifier: OpenMDW-1.1 import attrs +# from cosmos_framework.configs.base.defaults.cluster import DefaultClusterConfig + # We are hardcoding the unittest assets in this file. +# CLUSTER_CONFIG = DefaultClusterConfig + # add codeowner for cosmos_framework/model/vfm/tokenizers diff --git a/cosmos_framework/configs/base/defaults/vlm.py b/cosmos_framework/configs/base/defaults/vlm.py index fa4d7c8..e031d33 100644 --- a/cosmos_framework/configs/base/defaults/vlm.py +++ b/cosmos_framework/configs/base/defaults/vlm.py @@ -95,11 +95,7 @@ def download_tokenizer_files(model_name: str, config_variant: str) -> str: return destination_dir -def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str, **_unused_kwargs): - # **_unused_kwargs absorbs extras (e.g. tokenizer_type) that OmegaConf - # merges in from a vlm_config preset's tokenizer block when an experiment - # overrides the tokenizer with this function but doesn't fully replace - # the preset's kwarg dict. +def create_qwen2_tokenizer_with_download(pretrained_model_name: str, config_variant: str): destination_dir = download_tokenizer_files(pretrained_model_name, config_variant) return LLMTokenizerProcessor(Qwen2Tokenizer.from_pretrained(destination_dir)) @@ -140,9 +136,6 @@ class VLMConfig: # HuggingFace model identifier or local path. Drives AutoConfig + AutoModel selection. model_name: str = "" - # Safetensor path for model - safetensors_path: str = "" - # Optional pretrained-weights overlay (separate from the AutoModel structural # init driven by model_name). pretrained_weights: PretrainedWeightsConfig = PretrainedWeightsConfig() @@ -285,29 +278,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_30b_a3b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-30B-A3B-Private", - model_instance=L(Qwen3VLMoeTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoeMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl_moe/configs/Qwen3-VL-30B-A3B-Instruct.json" - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-30B-A3B-Instruct", - config_variant="gcp", - ), - layer_module="Qwen3VLMoeTextMoTDecoderLayer", - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-30B-A3B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 235B A22B Instruct model # Qwen3VLMoE uses Qwen2Tokenizer Qwen3VLMoT_VLM_235b_a22b_Instruct_GCP_Config: VLMConfig = VLMConfig( @@ -458,48 +428,6 @@ class VLMConfig: ), ) -CosmosReason2_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-2B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - -Cosmos3Reasoner_VLM_2b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-2B-Private", - model_instance=L(Qwen3VLTextForCausalLM)( - config=L(create_vlm_config)( - base_config=L(Qwen3VLMoTConfig.from_json_file)( - json_file="cosmos_framework/model/vfm/vlm/qwen3_vl/configs/Qwen3-VL-2B-Instruct.json" - ), - qk_norm_for_text=True, - ), - ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-2B-Instruct", - config_variant="gcp", - ), - pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-2B-Private/", - credentials_path="credentials/gcp_checkpoint.secret", - enable_gcs_patch_in_boto3=True, - ), -) - # Config for Qwen3VL 4B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_4b_Instruct_Config: VLMConfig = VLMConfig( @@ -586,8 +514,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-8B-Private", +Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-8B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -601,14 +529,14 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_8b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-8B-Private", +Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Nano-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -617,18 +545,18 @@ class VLMConfig: qk_norm_for_text=True, ), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-8B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-8B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-8B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3NanoReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3NanoReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Nano-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -643,13 +571,12 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Nano-Reasoner-bb9c6f5/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) - # Config for Qwen3VL 32B Instruct model # Qwen3VL uses Qwen2Tokenizer Qwen3VLMoT_VLM_32b_Instruct_Config: VLMConfig = VLMConfig( @@ -693,8 +620,8 @@ class VLMConfig: ), ) -CosmosReason2_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos-Reason2-32B-Private", +Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Reasoner-32B-Private", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -708,14 +635,14 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos-Reason2-32B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3Reasoner_VLM_32b_Private_GCP_Config: VLMConfig = VLMConfig( - model_name="nvidia/Cosmos3-Reasoner-32B-Private", +Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( base_config=L(Qwen3VLMoTConfig.from_json_file)( @@ -724,18 +651,18 @@ class VLMConfig: qk_norm_for_text=True, ), ), - tokenizer=L(build_processor_lazy)( - tokenizer_type="Qwen/Qwen3-VL-32B-Instruct", + tokenizer=L(create_qwen2_tokenizer_with_download)( + pretrained_model_name="Qwen/Qwen3-VL-32B-Instruct", config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Reasoner-32B-Private/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, ), ) -Cosmos3SuperReasoner_VLM_GCP_Config: VLMConfig = VLMConfig( +Cosmos3SuperReasoner_VLM_GCP_Config_0517: VLMConfig = VLMConfig( model_name="nvidia/Cosmos3-Super-Reasoner", model_instance=L(Qwen3VLTextForCausalLM)( config=L(create_vlm_config)( @@ -750,12 +677,63 @@ class VLMConfig: config_variant="gcp", ), pretrained_weights=PretrainedWeightsConfig( - backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner/", + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner-b6df0d1/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + ), +) + +# Cosmos3-Edge-Reasoner at commit 4acb717. +# nemotron_siglip2 architecture: Nemotron text backbone (56-block hybrid layout, 2048 hidden) +# + SigLIP2 vision encoder. The text transformer is identical in shape to +# Nemotron-3-Dense-VL-2B (hidden_size=2048, 56 alternating attn/MLP blocks → 28 +# effective MoT layers after _transform_text_dict). Uses the same +# nemotron_3_dense_vl weight remapping and config JSON. +Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-4acb717/", credentials_path="credentials/gcp_checkpoint.secret", enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", ), ) +# Cosmos3-Edge-Reasoner at commit 9b4c028 (2026-05-29). +# Same nemotron_siglip2 architecture as 4acb717; new weights uploaded 2026-05-29. +Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028: VLMConfig = VLMConfig( + model_name="nvidia/Cosmos3-Edge-Reasoner", + model_instance=L(Nemotron3DenseVLTextForCausalLM)( + config=L(create_vlm_config)( + base_config=L(Nemotron3DenseVLMoTConfig.from_json_file)( + json_file="cosmos_framework/model/vfm/vlm/nemotron_3_dense_vl/configs/Nemotron-2B-Dense-VL.json" + ), + qk_norm_for_text=False, + ), + ), + tokenizer=L(build_processor_lazy)( + tokenizer_type="nvidia/Cosmos3-Edge-Reasoner", + config_variant="gcp", + ), + pretrained_weights=PretrainedWeightsConfig( + backbone_path="s3://bucket0/cosmos3/pretrained/huggingface/nvidia/Cosmos3-Edge-Reasoner-9b4c028/", + credentials_path="credentials/gcp_checkpoint.secret", + enable_gcs_patch_in_boto3=True, + checkpoint_format="nemotron_3_dense_vl", + ), +) def register_vlm(): @@ -832,24 +810,6 @@ def register_vlm(): name="cosmos_reason2_vlm_2b_gcp", node=CosmosReason2_VLM_2b_GCP_Config, ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_2b_private_gcp", - node=CosmosReason2_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos3_reasoner_vlm_2b_private_gcp", - node=Cosmos3Reasoner_VLM_2b_Private_GCP_Config, - ) - cs.store( - group="vlm_config", - package="model.config.vlm_config", - name="cosmos_reason2_vlm_8b_private_gcp", - node=CosmosReason2_VLM_8b_Private_GCP_Config, - ) cs.store( group="vlm_config", package="model.config.vlm_config", @@ -865,8 +825,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_32b_private_gcp", - node=CosmosReason2_VLM_32b_Private_GCP_Config, + name="cosmos3_nano_reasoner_vlm_gcp_0517", + node=Cosmos3NanoReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -883,8 +843,8 @@ def register_vlm(): cs.store( group="vlm_config", package="model.config.vlm_config", - name="cosmos_reason2_vlm_30b_a3b_private_gcp", - node=CosmosReason2_VLM_30b_a3b_Private_GCP_Config, + name="cosmos3_super_reasoner_vlm_gcp_0517", + node=Cosmos3SuperReasoner_VLM_GCP_Config_0517, ) cs.store( group="vlm_config", @@ -922,3 +882,15 @@ def register_vlm(): name="qwen3_vl_mot_vlm_32b_instruct_gcp", node=Qwen3VLMoT_VLM_32b_Instruct_GCP_Config, ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_4acb717", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_4acb717, + ) + cs.store( + group="vlm_config", + package="model.config.vlm_config", + name="cosmos3_edge_reasoner_vlm_gcp_9b4c028", + node=Cosmos3EdgeReasoner_VLM_GCP_Config_9b4c028, + ) diff --git a/cosmos_framework/configs/base/vlm/config.py b/cosmos_framework/configs/base/vlm/config.py index 66dc3ed..c5a087d 100644 --- a/cosmos_framework/configs/base/vlm/config.py +++ b/cosmos_framework/configs/base/vlm/config.py @@ -4,10 +4,15 @@ from cosmos_framework.trainer import ImaginaireTrainer from cosmos_framework.utils import log from cosmos_framework.utils.config_helper import import_all_modules_from_package +from cosmos_framework.configs.base.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.callbacks import register_callbacks -from cosmos_framework.configs.base.vlm.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.config import Config - +from cosmos_framework.configs.base.vlm.defaults.dataloader import register_data_debug +from cosmos_framework.configs.base.vlm.defaults.dataloader_weighted_url import ( + register_data_recipe, + register_data_weighted_url, + register_data_weighted_url_with_text, +) from cosmos_framework.configs.base.vlm.defaults.model import register_model from cosmos_framework.configs.base.vlm.defaults.optimizer import register_optimizer, register_scheduler from cosmos_framework.configs.base.vlm.defaults.vlm_policy import register_vlm_policy @@ -42,6 +47,10 @@ def make_config() -> Config: register_model() register_vlm_policy() # Register dataloader configs + register_data_weighted_url() + register_data_recipe() + register_data_weighted_url_with_text() + register_data_debug() log.info("Registering optimizer, scheduler, checkpoint, ckpt type, and callbacks") register_optimizer() register_scheduler() diff --git a/cosmos_framework/configs/base/vlm/defaults/callbacks.py b/cosmos_framework/configs/base/vlm/defaults/callbacks.py index 1910b63..f742972 100644 --- a/cosmos_framework/configs/base/vlm/defaults/callbacks.py +++ b/cosmos_framework/configs/base/vlm/defaults/callbacks.py @@ -12,7 +12,7 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback - +from cosmos_framework.callbacks.dataloading_monitor import DetailedDataLoadingSpeedMonitor from cosmos_framework.callbacks.grad_clip import GradClip from cosmos_framework.callbacks.hf_export import HFExportCallback from cosmos_framework.callbacks.iter_speed import IterSpeed @@ -40,6 +40,10 @@ def register_callbacks(): param_count=L(ParamCount)( # use model save_s3="${upload_reproducible_setup}", ), + dataloader_speed=L(DetailedDataLoadingSpeedMonitor)( + every_n=100, + save_s3="${upload_reproducible_setup}", + ), grad_clip=L(GradClip)(clip_norm=1.0, force_finite=False), # use model learning_rate_logger=L(LearningRateLogger)(every_n=10), low_precision=L(LowPrecisionCallback)( @@ -47,7 +51,6 @@ def register_callbacks(): config=PLACEHOLDER, trainer=PLACEHOLDER, ), # reads model.precision; no extra kwarg needed - # nvtx=L(NVTXCallback)(synchronize=True), ) diff --git a/cosmos_framework/configs/base/vlm/defaults/policy_config.py b/cosmos_framework/configs/base/vlm/defaults/policy_config.py index 307eb92..c203168 100644 --- a/cosmos_framework/configs/base/vlm/defaults/policy_config.py +++ b/cosmos_framework/configs/base/vlm/defaults/policy_config.py @@ -29,7 +29,7 @@ class PolicyConfig: trainable_map: Union[str, None] = None monkey_patch_for_text_only_data: bool = False - # HF attention impl. Default "cosmos" routes through imaginaire.attention + # HF attention impl. Default "cosmos" routes through cosmos_framework.model.attention # (NATTEN/blackwell-fmha on GB200). Override to "flash_attention_2", # "sdpa", or "eager" for fallback. attn_implementation: str = "cosmos" @@ -53,5 +53,6 @@ class VLMModelConfig: ema: EMAConfig = EMAConfig(enabled=False) # Force deterministic kernels in Flash-Attention init (slower; required for - # parity bit-exactness) - deterministic: bool = False + # parity bit-exactness). VLM-only knob — consumed by VLMModel.__init__ via + # init_flash_attn_meta. + deterministic: bool = True diff --git a/cosmos_framework/data/vfm/action/action_normalization.py b/cosmos_framework/data/vfm/action/action_normalization.py index c58bb90..d553161 100644 --- a/cosmos_framework/data/vfm/action/action_normalization.py +++ b/cosmos_framework/data/vfm/action/action_normalization.py @@ -27,7 +27,7 @@ def load_action_stats(stats_path: str, stats_key: str = "global") -> dict[str, n elif stats_key != "global": raise KeyError(f"Action normalization stats block {stats_key!r} not found in {stats_path}.") stat_keys = {"mean", "std", "min", "max", "q01", "q99"} - return {key: np.array(value, dtype=np.float32) for key, value in raw.items() if key in stat_keys} + return {k: np.array(v, dtype=np.float32) for k, v in raw.items() if k in stat_keys} def normalize_action( @@ -35,7 +35,7 @@ def normalize_action( method: str, stats: dict[str, torch.Tensor], ) -> torch.Tensor: - """Normalize action tensor.""" + """Normalize action tensor (all dimensions including gripper).""" if method == "quantile": q01, q99 = stats["q01"], stats["q99"] denom = (q99 - q01).clamp(min=1e-8) diff --git a/cosmos_framework/data/vfm/action/domain_utils.py b/cosmos_framework/data/vfm/action/domain_utils.py index 917fe4c..6f433f7 100644 --- a/cosmos_framework/data/vfm/action/domain_utils.py +++ b/cosmos_framework/data/vfm/action/domain_utils.py @@ -14,9 +14,12 @@ "bridge_orig_lerobot": 7, "droid_lerobot": 8, "robomind-franka": 8, # Both Droid and RoboMIND-Franka are using robotiq and franka + "embodiment_b": 9, "robomind-franka-dual": 12, "robomind-ur": 13, "agibotworld": 15, + "embodiment_c_gripper": 15, + "embodiment_c_gripper_ext": 15, "fractal": 20, } @@ -31,8 +34,16 @@ "robomind-franka": 10, "robomind-franka-dual": 20, "robomind-ur": 10, + "embodiment_b": 30, "agibotworld": 29, + "embodiment_c_gripper": 29, + "embodiment_c_gripper_ext": 29, "fractal": 10, + # NOTE: ``libero`` (7/10/13 depending on ``rotation_space``) and ``hand_pose`` + # (variable with ``keypoint_option`` and ``rotation_format``) are absent + # because their raw width is set per-dataset at construction time. Inference + # in inverse_dynamics/policy modes is not supported for these domains until + # canonical widths are added here. } @@ -45,3 +56,20 @@ def get_domain_id(embodiment_type: str) -> int: f"Available embodiments: {sorted(EMBODIMENT_TO_DOMAIN_ID.keys())}" ) return EMBODIMENT_TO_DOMAIN_ID[key] + + +def get_action_dim(embodiment_type: str) -> int: + """Get the raw action dimension for a given embodiment type.""" + key = embodiment_type.lower().strip() + if key not in EMBODIMENT_TO_RAW_ACTION_DIM: + raise KeyError( + f"Unknown embodiment type: {embodiment_type!r}. " + f"Available embodiments: {sorted(EMBODIMENT_TO_RAW_ACTION_DIM.keys())}" + ) + return EMBODIMENT_TO_RAW_ACTION_DIM[key] + + +def is_valid_domain_name(embodiment_type: str) -> bool: + """Check if the given embodiment type is recognized.""" + key = embodiment_type.lower().strip() + return key in EMBODIMENT_TO_RAW_ACTION_DIM diff --git a/cosmos_framework/data/vfm/action/json_formatter.py b/cosmos_framework/data/vfm/action/json_formatter.py index b511e93..201a76e 100644 --- a/cosmos_framework/data/vfm/action/json_formatter.py +++ b/cosmos_framework/data/vfm/action/json_formatter.py @@ -7,9 +7,9 @@ import torch +from cosmos_framework.utils import log from cosmos_framework.data.vfm.action.viewpoint_utils import DEFAULT_VIEWPOINT_TEMPLATES from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log def _should_append_idle_frame_info(mode: object) -> bool: diff --git a/cosmos_framework/data/vfm/action/transforms.py b/cosmos_framework/data/vfm/action/transforms.py index d17b141..9cebb4e 100644 --- a/cosmos_framework/data/vfm/action/transforms.py +++ b/cosmos_framework/data/vfm/action/transforms.py @@ -19,6 +19,7 @@ import torch import torchvision.transforms.functional as transforms_F +from cosmos_framework.utils import log from cosmos_framework.data.vfm.action.json_formatter import ActionPromptJsonFormatter from cosmos_framework.data.vfm.action.viewpoint_utils import ViewpointTextInfo from cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps import DurationFPSTextTimeStamps @@ -27,7 +28,6 @@ from cosmos_framework.data.vfm.augmentors.text_tokenizer import TextTokenizerTransform from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO -from cosmos_framework.utils import log from cosmos_framework.utils.vfm.data_utils import get_vision_data_resolution @@ -309,7 +309,6 @@ def build_sequence_plan_from_mode( base_action_length = action_length - num_history_actions if mode == "forward_dynamics": condition_frame_indexes_action = list(range(action_length)) - # This currently assumes that the action length is the same as the video length - 1 # and if action length is the same as the video length, then the first action is the conditioning action elif base_action_length == video_length - 1: diff --git a/cosmos_framework/data/vfm/augmentor_provider.py b/cosmos_framework/data/vfm/augmentor_provider.py index 2ace65c..5ce1727 100644 --- a/cosmos_framework/data/vfm/augmentor_provider.py +++ b/cosmos_framework/data/vfm/augmentor_provider.py @@ -564,6 +564,9 @@ def get_video_augmentor_v3( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) + uniae_pad_frames = kwargs.get("uniae_pad_frames", None) + uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) print("Running video_basic_augmentor_v3...") augmentors = { @@ -577,6 +580,9 @@ def get_video_augmentor_v3( "min_stride": min_stride, "seek_mode": "exact", # Change to "approximate"? "dataset_resolution_type": dataset_resolution_type, + "causal_vae": causal_vae, + "uniae_pad_frames": uniae_pad_frames, + "uniae_chunk_frames": uniae_chunk_frames, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( @@ -670,7 +676,6 @@ def get_video_augmentor_v3( return augmentors - # Use video_basic_augmentor_v3_json_caption instead. @augmentor_register("video_basic_augmentor_v3_with_audio") def get_video_augmentor_v3_with_audio( @@ -829,6 +834,7 @@ def get_video_augmentor_v3_json_caption( conditioning_config = kwargs.get("conditioning_config", None) uniform_conditioning = kwargs.get("uniform_conditioning", False) temporal_compression_factor = kwargs.get("temporal_compression_factor", 4) + causal_vae = kwargs.get("causal_vae", True) print("Running video_augmentor_v3_json_caption...") augmentors = { @@ -856,6 +862,7 @@ def get_video_augmentor_v3_json_caption( "extract_audio": extract_audio, "audio_sample_rate": audio_sample_rate, "emit_placeholder_sound": not extract_audio, + "causal_vae": causal_vae, }, ), "merge_datadict": L(merge_datadict.DataDictMerger)( diff --git a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py index 302715c..34f6116 100644 --- a/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py +++ b/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py @@ -8,7 +8,7 @@ frames (i.e. the relative-pose delta is close to identity and the gripper command does not change). The upstream dataset is responsible for populating ``data_dict[idle_frames_key]`` via -:func:`projects.cosmos3.vfm.datasets.action.pose_utils.compute_idle_frames`. +:func:`cosmos_framework.data.vfm.action.pose_utils.compute_idle_frames`. Per-field dropout (default 5%) is applied here, matching Pi0.7's approach of independently dropping each metadata component. This is complementary to the diff --git a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py index aea759e..41e4e4c 100644 --- a/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/interleaved_video_parsing.py @@ -414,7 +414,7 @@ def __call__(self, data_dict: dict) -> dict | None: ) # [C,T,H,W] num_multiplier = (end_frame - start_frame) / self.num_frames - + # NOTE: matches legacy VideoParsing.__call__ output keys exactly. Do NOT add # variable-length fields like ``frame_indices`` here -- ``video_flatten_keys`` in # ``get_video_transfer_augmentor`` lists ``frame_indices``, and surfacing a # per-sample list there would crash ``custom_collate_fn`` (default_collate requires diff --git a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py index 54d47b8..aa9eb21 100644 --- a/cosmos_framework/data/vfm/augmentors/pkl_to_media.py +++ b/cosmos_framework/data/vfm/augmentors/pkl_to_media.py @@ -27,14 +27,12 @@ def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert token length to pixels based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return token_length * merged_patch_size**2 * temporal_patch_size def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2) -> int: """Convert pixels to token length based on patch size and temporal patch size.""" - merged_patch_size = patch_size * 2 return pixels // merged_patch_size**2 // temporal_patch_size diff --git a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py index 74fb523..a18e8fd 100644 --- a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py +++ b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py @@ -5,7 +5,7 @@ Augmentors for transfer (control-conditioned) image and video generation in the cosmos3 VFM pipeline. Transfer training conditions the model on control signals (edge, blur, depth, or segmentation) -to generate images or videos, aligned with cosmos_framework/transfer2. This module provides: +to generate images or videos, aligned with cosmos/transfer2. This module provides: - **TransferToTrainingFormat**: Converts (control_input, target) into the joint dataloader format with SequencePlan (condition frame + generated frame), for both image and video outputs. diff --git a/cosmos_framework/data/vfm/augmentors/video_parsing.py b/cosmos_framework/data/vfm/augmentors/video_parsing.py index cfaa934..8972818 100644 --- a/cosmos_framework/data/vfm/augmentors/video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/video_parsing.py @@ -345,7 +345,7 @@ def __call__(self, data_dict: dict) -> dict | None: video_info["video"] = video_frames video_info["num_multiplier"] = num_multiplier # Store the frame skipping multiplier - + # NOTE: Explaining the logic of conditioning FPS calculation: # 1. Our video parser stores the original video FPS of the video. # 2. We have multiple modes of frame selection -- consecutive chunk of frames or subsampled frames. # Here's what we do in each case: @@ -434,6 +434,20 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.dataset_resolution_type = args.get("dataset_resolution_type", "all") self.resolution_tier = _DATASET_RESOLUTION_TIER.get(self.dataset_resolution_type) + # VAE temporal alignment mode. + # causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: align to 4N (non-causal VAE, e.g. UniAE) + self.causal_vae = args.get("causal_vae", True) + self.uniae_pad_frames = args.get("uniae_pad_frames", None) + self.uniae_chunk_frames = args.get("uniae_chunk_frames", None) + if self.uniae_chunk_frames is not None: + assert self.uniae_pad_frames is not None, ( + "uniae_pad_frames must be specified if uniae_chunk_frames is specified" + ) + assert self.uniae_chunk_frames > 2 * self.uniae_pad_frames, ( + "uniae_chunk_frames must be greater than 2 * uniae_pad_frames" + ) + def _sample_stride_with_bias(self, max_stride: int, min_stride: int = 1) -> int: """Sample a stride from [min_stride, max_stride] with bias controlled by low_fps_bias. @@ -520,7 +534,6 @@ def _validate_and_probe(self, video: Optional[bytes], meta_dict: dict, data_dict return True def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -569,11 +582,35 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(0, num_video_frames, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to VAE temporal compression requirement. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: 4N (non-causal VAE, e.g. UniAE) num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + # If this is UniAE, we need to align the frame count to the chunk size and padding. + if self.uniae_chunk_frames is not None: + # trim excess frames + effective_chunk_frames = self.uniae_chunk_frames - 2 * self.uniae_pad_frames + while ( + num_video_frames % effective_chunk_frames != 0 + and (num_video_frames % effective_chunk_frames + 2 * self.uniae_pad_frames) % 4 != 0 + and num_video_frames > 0 + ): + num_video_frames -= 1 + + if num_video_frames == 0: + log.warning( + f"VideoParsingWithFullFrames: video too short for UniAE. " + f"url: {data_dict['__url__']}, key: {data_dict['__key__']}", + rank0_only=False, + ) + return None + else: + N = num_video_frames // 4 + num_video_frames = 4 * N frame_indices = frame_indices[0:num_video_frames] frame_batch = video_decoder.get_frames_at(frame_indices) @@ -698,7 +735,6 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O super().__init__(input_keys, output_keys, args) def __call__(self, data_dict: dict) -> dict | None: - # if in future we need to train with batch size > 1, need to pad frames try: meta_dict = data_dict[self.meta_key] @@ -772,11 +808,16 @@ def __call__(self, data_dict: dict) -> dict | None: stride = self._sample_stride_with_bias(self.max_stride, self.min_stride) frame_indices = np.arange(chunk_start_clamped, chunk_end_clamped, stride).tolist() - # VAE compress temporal by 4x, with 1 as condition - # thus the max_video_frames must be 1 + 4N + # Align frame count to VAE temporal compression requirement. + # causal_vae=True: 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: 4N (non-causal VAE, e.g. UniAE) num_video_frames = min(len(frame_indices), self.args.get("max_num_frames", 1000)) - N = (num_video_frames - 1) // 4 - num_video_frames = 1 + 4 * N + if self.causal_vae: + N = (num_video_frames - 1) // 4 + num_video_frames = 1 + 4 * N + else: + N = num_video_frames // 4 + num_video_frames = 4 * N if num_video_frames < 1: log.warning( f"VideoParsingChunkedFrames: chunk too short for stride. " diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py index eb029eb..c9ee9da 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py @@ -39,7 +39,7 @@ def __init__( output_keys: Optional[list] = [], args: Optional[dict] = None, data_path_prefix: list[str] = [ - "cosmos_framework/ar/v2/nvlm/", + "cosmos/ar/v2/nvlm/", ], # prefix of the data in s3 ) -> None: super().__init__(input_keys, output_keys, args) diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py index fabe0c3..88f93bf 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py @@ -13,12 +13,10 @@ from cosmos_framework.data.vfm.data_sources.vlm.nvlm import data_path_mapping # This file was automatically generated by `nvgpt4 data prepare`. - # import torch def sample_loader_0(raw: dict) -> dict: # Note: Images are already decoded to tensors - if "text" in raw: caption = raw["text"] else: @@ -31,16 +29,15 @@ def sample_loader_0(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_0(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg", "text") +# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/merged/CoT # This file was automatically generated by `energon prepare`. - def sample_loader_1(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -54,16 +51,15 @@ def sample_loader_1(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_1(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/merged/single-choice # This file was automatically generated by `energon prepare`. - def sample_loader_2(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -77,16 +73,15 @@ def sample_loader_2(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_2(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/extended-sci-3/CoT # This file was automatically generated by `energon prepare`. - def sample_loader_3(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -100,16 +95,15 @@ def sample_loader_3(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_3(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/extended-sci-3/single-choice # This file was automatically generated by `energon prepare`. - def sample_loader_4(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -123,16 +117,15 @@ def sample_loader_4(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_4(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/SceMQA_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_5(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -146,16 +139,15 @@ def sample_loader_5(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_5(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/vqa_collection_doc_text_st_chart_scale_textbook_LRV_Screen # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_6(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -206,16 +198,15 @@ def sample_loader_6(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_6(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg", "png") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/plotqa/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_7(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -229,16 +220,15 @@ def sample_loader_7(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_7(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/clevr-math/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_8(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -252,16 +242,15 @@ def sample_loader_8(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_8(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/MMC-Instruction/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_9(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -275,16 +264,15 @@ def sample_loader_9(raw: dict) -> dict: # Note: Images are already decoded to t def part_filter_9(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ocrvqa/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_10(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] return dict( @@ -296,16 +284,15 @@ def sample_loader_10(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_10(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/dude/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_11(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -319,16 +306,15 @@ def sample_loader_11(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_11(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/VisualMRC/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_12(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] return dict( @@ -341,16 +327,15 @@ def sample_loader_12(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_12(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/mcvqa_collection_scienceqa_ai2d_geoqaplus_geometry3k_tqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_13(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] key = raw["__key__"] @@ -385,16 +370,15 @@ def sample_loader_13(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_13(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "png", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/arxiv_qa/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_14(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -408,16 +392,15 @@ def sample_loader_14(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_14(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/tabmwp/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_15(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -444,16 +427,15 @@ def sample_loader_15(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_15(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ocr_vqa_aug/processed/ # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_16(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -463,16 +445,15 @@ def sample_loader_16(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_16(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/dvqa_full/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_17(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -486,16 +467,15 @@ def sample_loader_17(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_17(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/LLaVA-v1.5_shuffle/no_refcoco_vg_ocrvqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_18(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -509,18 +489,16 @@ def sample_loader_18(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_18(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/more_data/infographics_vqa/processed/train/ # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_19(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] images = [raw["jpg"]] @@ -533,14 +511,13 @@ def sample_loader_19(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_19(part: str) -> bool: - return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/sharegpt4o/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_20(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -550,16 +527,15 @@ def sample_loader_20(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_20(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/sparse_ocr_data/merged # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_21(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -569,18 +545,16 @@ def sample_loader_21(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_21(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "png") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/nayeonl/data/blendv4/MetaMathQA/processed/train_text_image # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_22(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] images = [raw["jpg"]] @@ -593,16 +567,14 @@ def sample_loader_22(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_22(part: str) -> bool: - return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/nayeonl/data/blendv4/gsm8k/processed/train_text_image # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_23(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] images = [raw["jpg"]] @@ -615,14 +587,13 @@ def sample_loader_23(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_23(part: str) -> bool: - return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/docmatix/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_24(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -636,16 +607,15 @@ def sample_loader_24(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_24(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/bentham_hw_squad/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_25(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -655,16 +625,15 @@ def sample_loader_25(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_25(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/WikiTableQA/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_26(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -674,16 +643,15 @@ def sample_loader_26(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_26(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/figureqa/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_27(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -697,16 +665,15 @@ def sample_loader_27(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_27(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ai2d_combined_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_28(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -720,16 +687,15 @@ def sample_loader_28(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_28(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/math_combined_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_29(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -743,16 +709,15 @@ def sample_loader_29(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_29(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/robut_combined_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_30(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -766,16 +731,15 @@ def sample_loader_30(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_30(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/llavar_20k_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_31(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -789,16 +753,15 @@ def sample_loader_31(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_31(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/tallyqa_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_32(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -812,16 +775,15 @@ def sample_loader_32(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_32(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ureader_ie_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_33(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -835,16 +797,15 @@ def sample_loader_33(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_33(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/visual7w_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_34(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -858,16 +819,15 @@ def sample_loader_34(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_34(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/mavis_math_rule_geo_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_35(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -881,16 +841,15 @@ def sample_loader_35(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_35(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ureader_kg_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_36(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -904,16 +863,15 @@ def sample_loader_36(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_36(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ureader_qa_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_37(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -927,16 +885,15 @@ def sample_loader_37(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_37(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ocr_multi_collection_cocotext_textocr_ReCTs # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_38(raw: dict) -> dict: j = raw["json"] @@ -959,16 +916,15 @@ def sample_loader_38(raw: dict) -> dict: def part_filter_38(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/pdfa-eng-wds/processed_word_len_500 # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_39(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] return dict( @@ -978,16 +934,15 @@ def sample_loader_39(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_39(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/super_clevr_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_40(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -1001,16 +956,15 @@ def sample_loader_40(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_40(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/icon_qa_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_41(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -1024,16 +978,15 @@ def sample_loader_41(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_41(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/chartqa_aug # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_42(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1049,16 +1002,15 @@ def sample_loader_42(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_42(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_chartqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_43(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1074,16 +1026,15 @@ def sample_loader_43(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_43(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_docvqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_44(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1099,16 +1050,15 @@ def sample_loader_44(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_44(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/docvqa_text # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_45(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1124,16 +1074,15 @@ def sample_loader_45(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_45(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/textvqa_text # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_46(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1149,16 +1098,15 @@ def sample_loader_46(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_46(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/i2s-musicsheet # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_47(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1174,16 +1122,15 @@ def sample_loader_47(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_47(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/music # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_48(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1199,16 +1146,15 @@ def sample_loader_48(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_48(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/invoice # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_49(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1224,16 +1170,15 @@ def sample_loader_49(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_49(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/k12 # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_50(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1249,16 +1194,15 @@ def sample_loader_50(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_50(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/MTVQA # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_51(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] for i, turn in enumerate(json_item["conversations"]): @@ -1277,16 +1221,15 @@ def sample_loader_51(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_51(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/VisualWebInstruct # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_52(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1302,16 +1245,15 @@ def sample_loader_52(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_52(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/financeqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_53(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] # for i, turn in enumerate(json_item['conversations']): @@ -1330,16 +1272,15 @@ def sample_loader_53(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_53(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/docreason # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_54(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1355,16 +1296,15 @@ def sample_loader_54(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_54(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_mtwi # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_55(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1380,16 +1320,15 @@ def sample_loader_55(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_55(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/geos_gpt # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_56(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1405,16 +1344,15 @@ def sample_loader_56(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_56(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/cauldron_vistext # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_57(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1430,16 +1368,15 @@ def sample_loader_57(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_57(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/memes # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_58(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1455,16 +1392,15 @@ def sample_loader_58(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_58(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_roadtext # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_59(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1480,16 +1416,15 @@ def sample_loader_59(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_59(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/indoor_qa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_60(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1505,16 +1440,15 @@ def sample_loader_60(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_60(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/colpali # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_61(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1530,16 +1464,15 @@ def sample_loader_61(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_61(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/pmc_vqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_62(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1555,16 +1488,15 @@ def sample_loader_62(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_62(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/pathvqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_63(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1580,16 +1512,15 @@ def sample_loader_63(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_63(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/sciqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_64(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1605,16 +1536,15 @@ def sample_loader_64(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_64(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/chinese_meme # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_65(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1630,16 +1560,15 @@ def sample_loader_65(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_65(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_hiertext # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_66(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1655,16 +1584,15 @@ def sample_loader_66(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_66(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/cauldron_cocoqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_67(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1680,16 +1608,15 @@ def sample_loader_67(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_67(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("img", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/cmm-math/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_68(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -1703,16 +1630,15 @@ def sample_loader_68(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_68(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/mmtab/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_69(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -1726,16 +1652,15 @@ def sample_loader_69(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_69(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/simchart9k/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_70(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -1749,16 +1674,15 @@ def sample_loader_70(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_70(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/mapqa_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_71(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -1772,16 +1696,15 @@ def sample_loader_71(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_71(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/vizwiz_processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_72(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] images = [raw["jpg"]] @@ -1795,16 +1718,15 @@ def sample_loader_72(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_72(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_infovqa # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_73(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1820,16 +1742,15 @@ def sample_loader_73(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_73(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "img") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/viquae # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_74(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1845,18 +1766,16 @@ def sample_loader_74(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_74(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "img") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/ccs_recaptioned/webdataset # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_75(raw: dict) -> dict: # Note: Images are already decoded to tensors - if "text" in raw: caption = raw["text"] else: @@ -1869,18 +1788,16 @@ def sample_loader_75(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_75(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg", "text") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/laion115m-clean # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_76(raw: dict) -> dict: # Note: Images are already decoded to tensors - if "text" in raw: caption = raw["text"] else: @@ -1893,16 +1810,15 @@ def sample_loader_76(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_76(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg", "text") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/dvqa_full/processed_pt # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_77(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1921,16 +1837,15 @@ def sample_loader_77(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_77(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/docmatix/processed_pt # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_78(raw: dict) -> dict: # Note: Images are already decoded to tensors json_item = raw["json"] @@ -1949,12 +1864,12 @@ def sample_loader_78(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_78(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/VQAv2/stage1 def sample_loader_79(raw: dict) -> dict: # Note: Images are already decoded to tensors @@ -1980,12 +1895,11 @@ def part_filter_79(part: str) -> bool: return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/Visual_Genome # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_80(raw: dict) -> dict: # Note: Images are already decoded to tensors - return dict( __key__=raw["__key__"], image=raw["jpg"], # expected type: torch.Tensor @@ -1996,14 +1910,13 @@ def sample_loader_80(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_80(part: str) -> bool: - return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/pdfa-eng-wds/processed_word_len_300 # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_81(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] return dict( @@ -2013,16 +1926,15 @@ def sample_loader_81(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_81(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/textocr/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_82(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -2035,16 +1947,15 @@ def sample_loader_82(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_82(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/coco-text/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_83(raw: dict) -> dict: j = raw["json"] @@ -2057,16 +1968,15 @@ def sample_loader_83(raw: dict) -> dict: def part_filter_83(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ArT/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_84(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -2079,16 +1989,15 @@ def sample_loader_84(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_84(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ReCTs/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_85(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -2096,16 +2005,15 @@ def sample_loader_85(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_85(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/lsvt/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_86(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -2118,16 +2026,15 @@ def sample_loader_86(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_86(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/RCTW/processed # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_87(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] @@ -2142,16 +2049,15 @@ def sample_loader_87(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_87(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/coco-text/processed_multi # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_88(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] return dict( @@ -2164,16 +2070,15 @@ def sample_loader_88(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_88(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/textocr/processed_multi # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_89(raw: dict) -> dict: j = raw["json"] return dict( @@ -2186,16 +2091,15 @@ def sample_loader_89(raw: dict) -> dict: def part_filter_89(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("jpg", "json") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ReCTs/processed_multi # This file was automatically generated by `nvgpt4 data prepare`. - def sample_loader_90(raw: dict) -> dict: # Note: Images are already decoded to tensors j = raw["json"] return dict( @@ -2208,12 +2112,12 @@ def sample_loader_90(raw: dict) -> dict: # Note: Images are already decoded to def part_filter_90(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, # remove it from the list, such that it is not decoded. If you need all, keep as is return part in ("json", "jpg") +# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/VQAv2/stage1 def sample_loader_91(raw: dict) -> dict: # Note: Images are already decoded to tensors diff --git a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py index ec86e66..5b576c4 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/prompt_format.py @@ -45,7 +45,6 @@ def __call__(self, data_dict: Dict) -> Dict: if isinstance(list_of_conversation[0], list): selected_conversation = random.sample(list_of_conversation, 1)[0] elif isinstance(list_of_conversation[0], dict): - selected_conversation = list_of_conversation else: raise ValueError( @@ -82,7 +81,6 @@ def __call__(self, data_dict: Dict) -> Dict: del data_dict[conversation_key] - # # enforce chat order # self._enforce_text_chat_order(selected_conversation) @@ -91,7 +89,7 @@ def __call__(self, data_dict: Dict) -> Dict: def _enforce_text_chat_order(self, conversation: list) -> None: """ Reorder text content within user messages based on text_chat_order setting. - NOTE: this does NOT work for interleaved data!!!!!! + NOTE (maxzhaoshuol): this does NOT work for interleaved data!!!!!! Args: conversation: List of message dictionaries diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py index edede0c..6228153 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp.py @@ -392,17 +392,15 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: raise ValueError("Start and end time are the same for data.") if timestamp_format == "seconds": if random.random() < 0.5: diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py index 0507109..90cc9ab 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_with_subject_tracking.py @@ -224,17 +224,15 @@ def augment_user_prompt( elif output_format == "temporal_caption_subject": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) end = round(event["end"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 else: start = event["start"] end = event["end"] - if start == end: # HACK: remove events with start == end + if start == end: log.warning(f"Start and end time are the same for data. {event}") return None diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py index 584ca1d..7212510 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_augment_message.py @@ -162,14 +162,12 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.5: - start = round(event["start"]) end = round(event["end"]) else: - start = round(event["start"] * 2) / 2 end = round(event["end"] * 2) / 2 - if start == end: # HACK: remove events with start == end + if start == end: raise ValueError("Start and end time are the same for data.") user_prompt = random.choice( [ diff --git a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py index 8df9dd1..812e113 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/timestamp_without_end_time.py @@ -205,10 +205,8 @@ def augment_user_prompt( elif output_format == "temporal_caption": event = assistant_message[0] if random.random() < 0.333333: - start = round(event["start"]) elif random.random() < 0.666666: - start = round(event["start"] * 2) / 2 else: start = event["start"] diff --git a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py index a0ce29c..f3a9914 100644 --- a/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py +++ b/cosmos_framework/data/vfm/augmentors/vlm/tokenize_data.py @@ -158,7 +158,6 @@ def __call__(self, data_dict: Dict) -> Dict: if message["role"] == "user" and isinstance(message["content"], list): total_images += len([content for content in message["content"] if content["type"] == "image"]) total_videos += len([content for content in message["content"] if content["type"] == "video"]) - assert total_videos == 1 or total_videos == 0, "Only one video is supported for now" # url @@ -167,7 +166,6 @@ def __call__(self, data_dict: Dict) -> Dict: # go through each message in the conversation for message in conversation: # for user message, we insert the media - if message["role"] == "user" and isinstance( message["content"], list ): # Otherwise it's text and content is a string @@ -225,7 +223,6 @@ def __call__(self, data_dict: Dict) -> Dict: raw_images.append(image) elif content["type"] == "video": - # as tokenization will NOT upsample the video, we can use a larger value here at the cost of multiple video having 1.5x token length max_total_pixels = token_to_pixels(self.max_video_token_length * 1.5, temporal_patch_size=2) media_key = content["video"] @@ -248,7 +245,6 @@ def __call__(self, data_dict: Dict) -> Dict: return None videos = data_dict["media"][media_key]["videos"] # list of PIL images fps = data_dict["media"][media_key]["fps"] - # this is because videos are decoded to be around "max_video_token_length" tokens videos = maybe_subsample_frames( diff --git a/cosmos_framework/data/vfm/joint_dataloader.py b/cosmos_framework/data/vfm/joint_dataloader.py index 56b13e7..4b6b443 100644 --- a/cosmos_framework/data/vfm/joint_dataloader.py +++ b/cosmos_framework/data/vfm/joint_dataloader.py @@ -57,7 +57,6 @@ def custom_collate_fn(batch): # Handle standard list of samples elem = batch[0] if isinstance(elem, dict): - # Some Action datasets add optional metadata keys (for example # ``additional_view_description`` for concat-view captions) only for a # subset of samples. PyTorch can batch such samples together when @@ -261,7 +260,7 @@ def _prewarm_dataloaders(self) -> None: The first ``next()`` call on an ``InfiniteDataLoader`` iterator triggers ``DataLoader.__iter__()`` which spawns worker processes. For action dataloaders using ``multiprocessing_context='spawn'``, each worker must - fully initialise heavy datasets (BridgeOrigLeRobotDataset, EMBODIMENT_A, etc.) + fully initialise heavy datasets (BridgeOrigLeRobotDataset, embodiment_a, etc.) from scratch. If this happens lazily during training, the resulting delay (potentially minutes) causes NCCL collective timeouts when faster ranks enter the forward pass while slower ranks are still loading data. diff --git a/cosmos_framework/data/vfm/packing_iterable_dataset.py b/cosmos_framework/data/vfm/packing_iterable_dataset.py index 847b85a..715ce6e 100644 --- a/cosmos_framework/data/vfm/packing_iterable_dataset.py +++ b/cosmos_framework/data/vfm/packing_iterable_dataset.py @@ -1,16 +1,27 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 +# ----------------------------------------------------------------------------- +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# This codebase constitutes NVIDIA proprietary technology and is strictly +# confidential. Any unauthorized reproduction, distribution, or disclosure +# of this code, in whole or in part, without express written consent of +# NVIDIA is strictly prohibited. +# ----------------------------------------------------------------------------- """ Abstract base class for pool-based token-budget bin-packing over multiple datasets. -Extracted from ``projects.cosmos3.vfm.datasets.vlm.joint_dataset_dynamic_batch_webloader`` +Extracted from ``cosmos_framework.data.vfm.vlm.joint_dataset_dynamic_batch_webloader`` so that both the VLM and VFM internal dataloaders can share a single packing implementation. Usage ----- Subclass and implement ``compute_sample_tokens(sample) -> int``. Optionally override ``collate_batch(samples) -> Any`` for custom collation. + + class MyPacker(PackingIterableDataset): + def compute_sample_tokens(self, sample): + return len(sample["input_ids"]) """ from __future__ import annotations @@ -62,11 +73,6 @@ class PackingIterableDataset(torch.utils.data.IterableDataset, ABC): singletons regardless of budget. batching_strategy: ``"prefer_closest"`` (default) or ``"prefer_first"``. - apply_long_sample_halving: - When ``True`` (default), ``_max_tokens`` halves the budget for any - batch whose largest sample exceeds 1000 tokens — a memory-safety - heuristic. Set ``False`` only when memory headroom at the literal - ``max_tokens`` budget has been validated for the recipe. """ def __init__( @@ -77,7 +83,6 @@ def __init__( max_batch_size: int, long_threshold: int, batching_strategy: str, - apply_long_sample_halving: bool = True, ): super().__init__() @@ -90,7 +95,6 @@ def __init__( self.long_threshold = long_threshold self.max_batch_size = max_batch_size self.batching_strategy = batching_strategy - self.apply_long_sample_halving = apply_long_sample_halving self._pool: deque[dict] = deque() self._dataset_names: list[str] = [] @@ -166,8 +170,6 @@ def __iter__(self): # ------------------------------------------------------------------ def _max_tokens(self, cur_max: int) -> int: - if not self.apply_long_sample_halving: - return self.max_tokens if cur_max < 1000: return self.max_tokens return self.max_tokens // 2 diff --git a/cosmos_framework/data/vfm/processors/__init__.py b/cosmos_framework/data/vfm/processors/__init__.py index a1cc60e..646800b 100644 --- a/cosmos_framework/data/vfm/processors/__init__.py +++ b/cosmos_framework/data/vfm/processors/__init__.py @@ -104,13 +104,6 @@ def build_processor( bucket: Optional[str] = None, cache_dir: Optional[str] = None, ): - # Local artifact path: source the processor from a bundled directory - # (e.g. the top level of nvidia/Cosmos3-Nano, which ships its own - # preprocessor_config.json, tokenizer.json, etc). Avoids the redundant - # upstream Qwen/Qwen3-VL-*-Instruct fetch. Cosmos3-Nano/Super both ship - # a Qwen3VL-compatible processor, so dispatch to Qwen3VLProcessor. - if os.path.isdir(tokenizer_type): - return Qwen3VLProcessor(tokenizer_type, cache_dir=cache_dir) if credentials is None or bucket is None: if config_variant is None: config_variant = "s3" @@ -125,7 +118,12 @@ def build_processor( return Qwen3VLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) elif "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16" in tokenizer_type: return NemotronVLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) - elif "NVIDIA-Nemotron-3-Dense-VL" in tokenizer_type or "Qwen3-2B-ViT" in tokenizer_type: + elif ( + "NVIDIA-Nemotron-3-Dense-VL" in tokenizer_type + or "Qwen3-2B-ViT" in tokenizer_type + or "nvidia/Cosmos3-Reasoner-2B-Private" in tokenizer_type + or "nvidia/Cosmos3-Edge-Reasoner" in tokenizer_type + ): return Nemotron3DenseVLProcessor(tokenizer_type, credentials=credentials, bucket=bucket, cache_dir=cache_dir) elif "Qwen/Qwen3-0.6B" in tokenizer_type: local_path = _download_llm_tokenizer(tokenizer_type, credentials, bucket, cache_dir) @@ -138,26 +136,9 @@ def build_processor( raise ValueError(f"Tokenizer type {tokenizer_type} not supported") -def build_processor_lazy( - *args, - repository: Optional[str] = None, - revision: Optional[str] = None, - subdir: str = "", - **kwargs, -): +def build_processor_lazy(*args, **kwargs): """LazyCall wrapper that resolves ``build_processor`` on this module at call time. - Two modes: - 1. Upstream tokenizer (legacy): pass ``tokenizer_type=""`` - (and optional ``config_variant`` / ``credentials`` / ``bucket``). - The processor is sourced from the upstream HF repo (e.g. - ``Qwen/Qwen3-VL-8B-Instruct``). - 2. Local artifact: pass ``repository`` + ``revision`` (and optional - ``subdir``). The processor is sourced from the HF cache of the - named artifact (e.g. ``nvidia/Cosmos3-Nano``), reusing the same - revision the OmniModel checkpoint download uses. Avoids a - redundant upstream Qwen3-VL-*-Instruct fetch. - LazyCall captures its target at config-construction time, so a direct ``L(build_processor)`` would freeze the original function reference and bypass any later ``monkeypatch.setattr`` on this module's @@ -165,13 +146,4 @@ def build_processor_lazy( lookup on every call, so test fixtures patching ``build_processor`` are honored when the config is instantiated. """ - if repository is not None: - from cosmos_framework.utils.checkpoint_db import CheckpointDirHf - - if revision is None: - raise ValueError("'revision' is required when 'repository' is set") - local_path = CheckpointDirHf(repository=repository, revision=revision).download() - if subdir: - local_path = os.path.join(local_path, subdir) - return sys.modules[__name__].build_processor(local_path, **kwargs) return sys.modules[__name__].build_processor(*args, **kwargs) diff --git a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py index 767c8ef..c6addbf 100644 --- a/cosmos_framework/data/vfm/processors/nemotronvl_processor.py +++ b/cosmos_framework/data/vfm/processors/nemotronvl_processor.py @@ -248,7 +248,6 @@ def __init__( # NemotronVL hardcodes these helper attributes because they are not # discoverable from the HF model config; the values match the upstream # vision-encoder configuration. - # HACK: hardcoded based on the model config. self.min_height_width = 512 self.patch_size = 16 self.temporal_patch_size = 1 @@ -258,7 +257,6 @@ def __init__( def _resolve_pad_id(self): # NemotronVL's tokenizer does not specify a pad_token; reserve # for padding (project convention). - return self.processor.tokenizer.convert_tokens_to_ids("") def apply_chat_template( diff --git a/cosmos_framework/data/vfm/sequence_packing.py b/cosmos_framework/data/vfm/sequence_packing.py index d2821cc..ae8720d 100644 --- a/cosmos_framework/data/vfm/sequence_packing.py +++ b/cosmos_framework/data/vfm/sequence_packing.py @@ -850,7 +850,6 @@ def _pack_action_tokens( packed_seq.action.token_shapes.append((action_split_len,)) packed_seq.action.tokens.append(input_action_tokens) - condition_set = {idx for idx in condition_frame_indexes_action if 0 <= idx < action_split_len} assert isinstance(packed_seq.action.condition_mask, list) @@ -2095,7 +2094,7 @@ def verify_natten_parameter_list( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Multi-profile: different parameters for 2D (images) and 3D (videos) @@ -2231,7 +2230,7 @@ def generate_natten_metadata( {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid # Fixed window size of 8x8, dilation of 2x1. - + # NOTE: requires ALL inputs to be at least 16x8 {'window_size': (8, 8), 'dilation': (2, 1)} # valid # Invalid: @@ -2247,6 +2246,11 @@ def generate_natten_metadata( over layers (nn.ModuleList). """ + # COSMOS-RELEASE-BEGIN-IGNORE + # sequence-packed tensors containing only and exactly subsequences with sizes from + # token_shapes, in the same order, and with no padding in between. + # We should either make sure this never happens, or have static checks in place. + # COSMOS-RELEASE-END-IGNORE if token_shapes is None or len(token_shapes) < 1: raise ValueError("'token_shapes' is required for 'three_way' attention.") @@ -2269,6 +2273,11 @@ def filter_shape(shape: tuple) -> tuple: return tuple(x for x in shape if x > 1) # Infer token layout rank (dimensionality) + # COSMOS-RELEASE-BEGIN-IGNORE + # compresses that dimension into size 1, which gets filtered out. To avoid failing checks + # we need to take the maximum dimensionality over the entire batch. We'll assert each token + # shape matches that dimensionality later, if NATTEN is required for this batch. + # COSMOS-RELEASE-END-IGNORE num_dims = max([len(filter_shape(token_shape)) for token_shape in token_shapes]) # Single pass: check if all layers support this dimensionality and if any need processing @@ -2363,9 +2372,9 @@ def filter_shape(shape: tuple) -> tuple: is_causal = dim_params["is_causal"] # Create varlen metadata for natten varlen/varsized ops - + # NOTE: generate_multi_dim_varlen_parameters will automatically map window size -1 to # full size, that's why constant window sizes aren't allowed. - + # NOTE: if any of the parameters are constant, natten will simplify them natten_metadata.append( generate_multi_dim_varlen_parameters( token_layout_list=token_layout_list, @@ -2780,7 +2789,6 @@ def build_sequence_plans_from_data_batch( Returns: List of SequencePlan objects, one per sample in the batch. """ - # For new modalities, please generate the sequence_plan in the dataset class!!!! # If sequence_plan already exists in data_batch, return it @@ -2790,7 +2798,6 @@ def build_sequence_plans_from_data_batch( assert "action" not in data_batch or data_batch["action"] is None, "Action data SHOULD have sequence_plans!" assert "sound" not in data_batch or data_batch["sound"] is None, "Sound data SHOULD have sequence_plans!" - # Determine batch size from available tensors batch_size = 0 for key in [input_video_key, input_image_key]: diff --git a/cosmos_framework/data/vfm/sound_data_utils.py b/cosmos_framework/data/vfm/sound_data_utils.py index 0a8ec63..2d739b0 100644 --- a/cosmos_framework/data/vfm/sound_data_utils.py +++ b/cosmos_framework/data/vfm/sound_data_utils.py @@ -3,7 +3,8 @@ """Sound data utilities for building sequence plans and handling audio-video generation modes. -This module provides utilities for building SequencePlan objects based on sound generation modes. +This module provides utilities for building SequencePlan objects based on sound generation modes, +similar to how action modes are handled in cosmos_framework/data/vfm/action/data_utils.py. Supported modes: - t2vs: Text → Video + Sound (joint generation) @@ -27,7 +28,8 @@ def build_sequence_plan_for_sound( """Build a SequencePlan based on the sound generation mode. This function determines the appropriate condition frame indexes for vision and sound - based on the specified mode. + based on the specified mode. It mirrors how `build_sequence_plan_from_mode` works + for action in cosmos_framework/data/vfm/action/data_utils.py. Args: mode: Generation mode. One of: diff --git a/cosmos_framework/data/vfm/utils.py b/cosmos_framework/data/vfm/utils.py index ab6f238..a96f725 100644 --- a/cosmos_framework/data/vfm/utils.py +++ b/cosmos_framework/data/vfm/utils.py @@ -6,7 +6,6 @@ IMAGE_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). - # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), @@ -41,7 +40,6 @@ VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = { # Our desired 256 resolution is the one below (commented). - # Desired: "256": {"1,1": (336, 336), "4,3": (384, 288), "3,4": (288, 384), "16,9": (448, 256), "9,16": (256, 448)}, "256": { "1,1": (256, 256), diff --git a/cosmos_framework/model/attention/backends.py b/cosmos_framework/model/attention/backends.py index 4236b35..74f3cf0 100644 --- a/cosmos_framework/model/attention/backends.py +++ b/cosmos_framework/model/attention/backends.py @@ -22,8 +22,14 @@ from cosmos_framework.model.attention.utils.safe_ops import log from cosmos_framework.model.attention.utils.safe_ops.functools import lru_cache +# COSMOS-RELEASE-BEGIN-IGNORE +# isort: split +from cosmos_framework.model.attention.cudnn.checks import cudnn_attention_check + +# COSMOS-RELEASE-END-IGNORE BACKEND_CHECK_MAP = { + "cudnn": cudnn_attention_check, # COSMOS-RELEASE-IGNORELINE "natten": natten_attention_check, "flash2": flash2_attention_check, "flash3": flash3_attention_check, @@ -131,17 +137,22 @@ def get_backend_list(arch_tag: int) -> list[str]: if arch_tag == 90: default_backends = [ "flash3", + "cudnn", # COSMOS-RELEASE-IGNORELINE "natten", "flash2", ] elif arch_tag in [100, 103]: default_backends = [ + # COSMOS-RELEASE-BEGIN-IGNORE + "cudnn", + # COSMOS-RELEASE-END-IGNORE "natten", "flash2", ] elif arch_tag >= 80: default_backends = [ "flash2", + "cudnn", # COSMOS-RELEASE-IGNORELINE "natten", ] else: diff --git a/cosmos_framework/model/attention/checks.py b/cosmos_framework/model/attention/checks.py index 87d1d53..cb74e62 100644 --- a/cosmos_framework/model/attention/checks.py +++ b/cosmos_framework/model/attention/checks.py @@ -254,7 +254,7 @@ def varlen_tensor_checks( f"Q, K, and V must match in batch size, got {query.shape[0]=}, {key.shape[0]=}, {value.shape[0]=}." ) - + # NOTE: these checks introduce recompiles if not is_torch_compiling(): # Validate max_seqlen values: neither can be negative, and they must be # both zero/None (not varlen) or both positive (varlen). @@ -299,7 +299,7 @@ def varlen_tensor_checks( ) # Validate user-input cumulative_seqlen_{Q,KV}, max_seqlen_{Q,KV}, total_seqlen_{Q,KV} - + # NOTE: max_seqlen_Q == max_seqlen_KV == 0 is valid here (skip kernel / empty-batch case). # Mismatch (one 0, the other positive) is already caught by the early check above. # This feature may require support in the backends themselves; see NATTEN PR: # https://github.com/SHI-Labs/NATTEN/pull/327 @@ -334,7 +334,7 @@ def varlen_tensor_checks( total_seqlen_Q = query.shape[1] total_seqlen_KV = key.shape[1] - + # NOTE: these checks introduce recompiles if not is_torch_compiling(): # When both max_seqlens are 0, skip bounds checks (skip kernel / empty-batch case). # Mismatch is already caught by the early check, so at this point either both are 0 or both are positive. diff --git a/cosmos_framework/model/attention/cudnn/__init__.py b/cosmos_framework/model/attention/cudnn/__init__.py index ca3697b..3b7453f 100644 --- a/cosmos_framework/model/attention/cudnn/__init__.py +++ b/cosmos_framework/model/attention/cudnn/__init__.py @@ -13,6 +13,12 @@ from cosmos_framework.model.attention.utils.safe_ops import log from cosmos_framework.model.attention.utils.version import version_at_least +# COSMOS-RELEASE-BEGIN-IGNORE +# (ahassani) [11-20-2025] Banning cuDNN until reliability issues are resolved. +# Versions checked: 91300, 91400, 91500 +# (ahassani) [12-01-2025] +# 91500 ran on both GB200 and H100 SXM. +# COSMOS-RELEASE-END-IGNORE CUDNN_DISALLOWED = True CUDNN_MIN_BACKEND_VERSION = 91300 diff --git a/cosmos_framework/model/attention/cudnn/checks.py b/cosmos_framework/model/attention/cudnn/checks.py index 32b4540..ffc5f96 100644 --- a/cosmos_framework/model/attention/cudnn/checks.py +++ b/cosmos_framework/model/attention/cudnn/checks.py @@ -124,7 +124,6 @@ def cudnn_attention_check( causal_type=causal_type, ) - if is_causal and causal_type not in [CausalType.TopLeft, CausalType.DontCare]: target_fn("cuDNN Attention only supports top-left causal masking for now.", exception=RuntimeError) return False diff --git a/cosmos_framework/model/attention/cudnn/cudnn_forward.py b/cosmos_framework/model/attention/cudnn/cudnn_forward.py index 3637f7e..909f9a6 100644 --- a/cosmos_framework/model/attention/cudnn/cudnn_forward.py +++ b/cosmos_framework/model/attention/cudnn/cudnn_forward.py @@ -19,7 +19,6 @@ from cosmos_framework.model.attention.utils.safe_ops import log from cosmos_framework.model.attention.utils.safe_ops.functools import lru_cache - # Force using padded mask as a potential workaround for failing use cases FORCE_PADDED_MASK = False @@ -46,6 +45,17 @@ def get_dtype_choices(arch_tag: int) -> dict: log.debug("cuDNN Attention is not supported because compute capability is below the minimum (8.0).") return {} + # COSMOS-RELEASE-BEGIN-IGNORE + ## not seem to work. + # if arch_tag in [90, 100]: + # log.debug(f"cuDNN Attention supports FP8 for {arch_tag=}.") + # return { + # torch.float16: cudnn.data_type.HALF, + # torch.bfloat16: cudnn.data_type.BFLOAT16, + # torch.float8_e4m3fn: cudnn.data_type.FP8_E4M3, + # torch.float8_e5m2: cudnn.data_type.FP8_E5M2, + # } + # COSMOS-RELEASE-END-IGNORE log.debug(f"cuDNN Attention only supports FP16 and BF16 for {arch_tag=}.") return { @@ -316,11 +326,15 @@ def cudnn_sdpa_fwd_generate_op( handle = cudnn.create_handle() def cudnn_operation(q: Tensor, k: Tensor, v: Tensor, output: Tensor, lse: Tensor | None = None): - + # NOTE: This is INCREDIBLY important to do -- this is what wasted days of my time # with random NaNs and illegal memory accesses and things of that nature. stream = torch.cuda.current_stream(q.device) cudnn.set_stream(handle=handle, stream=stream.cuda_stream) + # COSMOS-RELEASE-BEGIN-IGNORE + # caching allocator plays nicely with the LRU cache over this, but for now let's avoid + # premature optimization. + # COSMOS-RELEASE-END-IGNORE workspace = torch.zeros(workspace_size_bytes, device=device, dtype=torch.uint8) # [workspace_size_bytes] variant_pack = { diff --git a/cosmos_framework/model/attention/cudnn/functions.py b/cosmos_framework/model/attention/cudnn/functions.py index 4087390..3014d74 100644 --- a/cosmos_framework/model/attention/cudnn/functions.py +++ b/cosmos_framework/model/attention/cudnn/functions.py @@ -55,6 +55,10 @@ def forward( padding_Q = 0 padding_KV = 0 + # COSMOS-RELEASE-BEGIN-IGNORE + # but as of 11/12/2025 does not seem to fix any issues. Keeping here in case + # it ever comes back. + # COSMOS-RELEASE-END-IGNORE if CUDNN_PADDING_REQUIRED: Q_multiplier = 256 KV_multiplier = 256 @@ -236,7 +240,6 @@ def cudnn_attention( raise_error=True, ) - assert not is_varlen # cudnn_attention_check should prevent this assertion failing num_heads = query.shape[-2] diff --git a/cosmos_framework/model/attention/cudnn/meta.py b/cosmos_framework/model/attention/cudnn/meta.py index 9a21992..61d8f5b 100644 --- a/cosmos_framework/model/attention/cudnn/meta.py +++ b/cosmos_framework/model/attention/cudnn/meta.py @@ -30,6 +30,9 @@ def get_fwd_dtypes(arch_tag: int) -> list[torch.dtype]: log.debug("cuDNN Attention is not supported because compute capability is below the minimum (8.0).") return [] + # COSMOS-RELEASE-BEGIN-IGNORE + ## not seem to work. + # COSMOS-RELEASE-END-IGNORE log.debug(f"cuDNN Attention only supports FP16 and BF16 for {arch_tag=}.") return [torch.float16, torch.bfloat16] @@ -46,5 +49,4 @@ def get_bwd_dtypes(arch_tag: int) -> list[torch.dtype]: """ - return [] diff --git a/cosmos_framework/model/attention/flash2/__init__.py b/cosmos_framework/model/attention/flash2/__init__.py index 77ada63..85c7cbc 100644 --- a/cosmos_framework/model/attention/flash2/__init__.py +++ b/cosmos_framework/model/attention/flash2/__init__.py @@ -11,6 +11,7 @@ import torch from cosmos_framework.model.attention.utils.safe_ops import log +from cosmos_framework.model.attention.utils.version import version_in_range # We lock to safe releases of Flash 2 # We will have a separate backend identifier for 2025 releases with CuTeDSL @@ -50,15 +51,13 @@ def flash2_supported() -> bool: else: flash2_version_str = flash_attn.__version__ - # Version range check disabled to accept whatever flash_attn the OSS - # container ships. - # if not version_in_range(flash2_version_str, FLASH_ATTENTION_V2_MIN_VERSION, FLASH_ATTENTION_V2_MAX_VERSION): - # log.debug( - # "Flash Attention v2 build is not supported; this backend only supports versions " - # f"{FLASH_ATTENTION_V2_MIN_VERSION} through {FLASH_ATTENTION_V2_MAX_VERSION}, got " - # f"{flash2_version_str}." - # ) - # return False + if not version_in_range(flash2_version_str, FLASH_ATTENTION_V2_MIN_VERSION, FLASH_ATTENTION_V2_MAX_VERSION): + log.debug( + "Flash Attention v2 build is not supported; this backend only supports versions " + f"{FLASH_ATTENTION_V2_MIN_VERSION} through {FLASH_ATTENTION_V2_MAX_VERSION}, got " + f"{flash2_version_str}." + ) + return False return True diff --git a/cosmos_framework/model/attention/flash2/checks.py b/cosmos_framework/model/attention/flash2/checks.py index 5332edb..73fef61 100644 --- a/cosmos_framework/model/attention/flash2/checks.py +++ b/cosmos_framework/model/attention/flash2/checks.py @@ -75,17 +75,12 @@ def flash2_attention_check( ) return False - - # mixed_modality_sft_8b smoke on Blackwell — flash3 isn't built for arch - # 100/103 and natten doesn't support varlen. Revisit before production - # training on this hardware. - # if is_varlen: - # target_fn( - # "Flash Attention v2 (flash2) varlen is banned due to instability. " - # "Please choose another backend.", - # exception=ValueError, - # ) - # return False + if is_varlen: + target_fn( + "Flash Attention v2 (flash2) varlen is banned due to instability. Please choose another backend.", + exception=ValueError, + ) + return False arch_tag = get_arch_tag(device) fwd_dtypes = get_fwd_dtypes(arch_tag) diff --git a/cosmos_framework/model/attention/flash2/functions.py b/cosmos_framework/model/attention/flash2/functions.py index 25cfd1a..2d1f491 100644 --- a/cosmos_framework/model/attention/flash2/functions.py +++ b/cosmos_framework/model/attention/flash2/functions.py @@ -175,7 +175,7 @@ def flash2_attention( assert output.dim() == 4 # [B,N,H,Dv] or [1,total_tokens,H,Dv] assert lse.dim() == 3 # [B,H,N] or [1,H,total_tokens] - + # NOTE: Do NOT call .contiguous on LSE, otherwise Attention Merging backward pass will be # incorrect. All output and lse tensors passed into `merge_attentions` must have the same data # pointer as their corresponding attention autograd ops! lse = lse.permute(0, 2, 1) # [B,N,H] or [1,total_tokens,H] diff --git a/cosmos_framework/model/attention/flash3/functions.py b/cosmos_framework/model/attention/flash3/functions.py index fd8fb7f..76c1505 100644 --- a/cosmos_framework/model/attention/flash3/functions.py +++ b/cosmos_framework/model/attention/flash3/functions.py @@ -15,7 +15,7 @@ from flash_attn_3_nv.flash_attn_interface import flash_attn_func, flash_attn_varlen_func from torch import Tensor - +# NOTE: older commits didn't have `return_attn_probs` as an argument, and there is no # reflection of the commit hash in the version, so we have to manually inspect the signatures HAS_RETURN_ATTN_PROBS = "return_attn_probs" in inspect.signature(flash_attn_func).parameters @@ -190,7 +190,7 @@ def flash3_attention( assert output.dim() == 4 # [B,N,H,Dv] or [1,total_tokens,H,Dv] assert lse.dim() == 3 # [B,H,N] or [1,H,total_tokens] - + # NOTE: Do NOT call .contiguous on LSE, otherwise Attention Merging backward pass will be # incorrect. All output and lse tensors passed into `merge_attentions` must have the same data # pointer as their corresponding attention autograd ops! lse = lse.permute(0, 2, 1) # [B,N,H] or [1,total_tokens,H] diff --git a/cosmos_framework/model/attention/frontend.py b/cosmos_framework/model/attention/frontend.py index d69c980..f2740ee 100644 --- a/cosmos_framework/model/attention/frontend.py +++ b/cosmos_framework/model/attention/frontend.py @@ -28,9 +28,15 @@ from cosmos_framework.model.attention.utils.environment import filter_attention_merge_backends from cosmos_framework.model.attention.utils.safe_ops import log +# COSMOS-RELEASE-BEGIN-IGNORE +# isort: split +from cosmos_framework.model.attention.cudnn import cudnn_attention + +# COSMOS-RELEASE-END-IGNORE # Map backend names to their frontend attention API BACKEND_MAP = { + "cudnn": cudnn_attention, # COSMOS-RELEASE-IGNORELINE "natten": natten_attention, "flash2": flash2_attention, "flash3": flash3_attention, @@ -393,7 +399,7 @@ def multi_dimensional_attention( # Automatic transformation for 1s in token layout # I.e. Attention over a (1, 16, 32) token layout is identical to over a (16, 32) - + # NOTE: assumes QKV token layouts match token_layout_ones = [i for i in range(num_dims) if token_layout_shape[i] == 1] if len(token_layout_ones) > 0: token_layout_t = tuple(s for i, s in enumerate(token_layout_shape) if i not in token_layout_ones) @@ -552,7 +558,7 @@ def multi_dimensional_attention_varlen( value (Tensor): 4-D value tensor with sequence-packed layout (`[1, seqlen_total, heads_kv, head_dim_v]`) - metadata (dict): Pre-computed varlen metadata from `imaginaire.varlen.generate_multi_dim_varlen_parameters`. + metadata (dict): Pre-computed varlen metadata from `cosmos_framework.varlen.generate_multi_dim_varlen_parameters`. scale (float | None): Attention scale. Defaults to head_dim ** -0.5. diff --git a/cosmos_framework/model/attention/natten/checks.py b/cosmos_framework/model/attention/natten/checks.py index 49b1a29..dfae074 100644 --- a/cosmos_framework/model/attention/natten/checks.py +++ b/cosmos_framework/model/attention/natten/checks.py @@ -118,7 +118,7 @@ def choose_natten_backend( target_fn = partial(log_or_raise_error, raise_error=raise_error) - + # NOTE: assumes attention_tensor_checks have already been run once! arch_tag = get_arch_tag(device) is_mla = query_shape[-1] != value_shape[-1] diff --git a/cosmos_framework/model/attention/varlen.py b/cosmos_framework/model/attention/varlen.py index e67ca6b..270e1d2 100644 --- a/cosmos_framework/model/attention/varlen.py +++ b/cosmos_framework/model/attention/varlen.py @@ -23,7 +23,7 @@ def generate_varlen_parameters( ) -> ( tuple[None, None, int, int] | tuple[Tensor, Tensor, int, int] ): # (cumseqlen_Q[B+1], cumseqlen_KV[B+1], max_seqlen_Q, max_seqlen_KV) - + # NOTE: max_seqlen_{Q,KV} require a device-host sync, since they're expected to be ints (with # which we launch the varlen kernel) and not device tensors. # .item() introduces control flow and breaks the graph. # It is also inefficient to repeat this per-op, and mostly there for convenience. @@ -97,7 +97,7 @@ def generate_varlen_parameters( if max_seqlen_Q < 0 or max_seqlen_KV < 0: raise ValueError(f"max_seqlen_Q and max_seqlen_KV cannot be negative, got {max_seqlen_Q=}, {max_seqlen_KV=}.") - + # NOTE: max_seqlen_Q == max_seqlen_KV == 0 is a valid case (skip kernel / empty batch). # This feature may require support in the backends themselves; see NATTEN PR: # https://github.com/SHI-Labs/NATTEN/pull/327 if (max_seqlen_Q == 0) != (max_seqlen_KV == 0): @@ -106,7 +106,7 @@ def generate_varlen_parameters( f"but computed {max_seqlen_Q=}, {max_seqlen_KV=} from provided seqlens." ) - + # NOTE: we have to prepend with 0 manually :( z = torch.tensor([0], dtype=torch.int32, device=seqlens_Q.device) # [1] cumulative_seqlen_Q = torch.cat([z, seqlens_Q.cumsum(0).to(torch.int32)], dim=0) # [B+1] cumulative_seqlen_KV = torch.cat([z, seqlens_KV.cumsum(0).to(torch.int32)], dim=0) # [B+1] diff --git a/cosmos_framework/model/tokenizer/models/__init__.py b/cosmos_framework/model/tokenizer/models/__init__.py index e3c835b..798b8f6 100644 --- a/cosmos_framework/model/tokenizer/models/__init__.py +++ b/cosmos_framework/model/tokenizer/models/__init__.py @@ -11,6 +11,8 @@ # Generic utilities # Metrics (moved from utils to metrics module for consolidation) +from cosmos_framework.model.tokenizer.evaluation.reconstruction_metrics import calculate_psnr + # Dense runtime from cosmos_framework.model.tokenizer.models.dense_runtime import ( DenseAutoencoderRuntime, @@ -46,6 +48,7 @@ # Utils "average_with_scatter_add", "batch_tensor_to_sparse", + "calculate_psnr", "crop_tensors_to_match", "reconstruct_from_temporal_slices", "resize_and_crop", diff --git a/cosmos_framework/model/tokenizer/models/dense_backends.py b/cosmos_framework/model/tokenizer/models/dense_backends.py index 841fd7c..35d9647 100644 --- a/cosmos_framework/model/tokenizer/models/dense_backends.py +++ b/cosmos_framework/model/tokenizer/models/dense_backends.py @@ -12,10 +12,12 @@ import torch.nn as nn import torch.nn.functional as F -from cosmos_framework.model.tokenizer.models.modules.attention.full_attn import tensor_dense_scaled_dot_product_attention +from cosmos_framework.model.tokenizer.models.modules.attention.full_attn import ( + tensor_dense_scaled_dot_product_attention, +) -DenseRuntimeBackend = Literal["varlen", "batched", "auto"] -DenseResolvedBackend = Literal["varlen", "batched"] +DenseRuntimeBackend = Literal["varlen", "batched", "batched_with_padding", "auto"] +DenseResolvedBackend = Literal["varlen", "batched", "batched_with_padding"] def resolve_dense_backend(backend: DenseRuntimeBackend, use_compile: bool) -> DenseResolvedBackend: @@ -33,7 +35,7 @@ def resolve_dense_backend(backend: DenseRuntimeBackend, use_compile: bool) -> De """ if backend == "auto": return "batched" if use_compile else "varlen" - if backend in ("varlen", "batched"): + if backend in ("varlen", "batched", "batched_with_padding"): return backend raise ValueError(f"Unsupported dense runtime backend: {backend}") @@ -69,6 +71,8 @@ def run_varlen_block_stack( def run_batched_block_stack( blocks: nn.ModuleList, feats: torch.Tensor, + cu_seqlens_q: torch.Tensor | None = None, + max_q_seqlen: int | None = None, q_freqs_cis: torch.Tensor | None = None, ) -> torch.Tensor: """Run the dense batched block path over uniform `[B, S, D]` chunks.""" @@ -79,21 +83,31 @@ def run_batched_block_stack( for block in blocks: if block.training and getattr(block, "use_checkpoint", False): output = torch.utils.checkpoint.checkpoint( - partial(run_batched_block, block, q_freqs_cis=q_freqs_cis), + partial( + run_batched_block, + block, + cu_seqlens_q=cu_seqlens_q, + max_q_seqlen=max_q_seqlen, + q_freqs_cis=q_freqs_cis, + ), output, use_reentrant=False, ) else: - output = run_batched_block(block, output, q_freqs_cis=q_freqs_cis) + output = run_batched_block( + block, output, cu_seqlens_q=cu_seqlens_q, max_q_seqlen=max_q_seqlen, q_freqs_cis=q_freqs_cis + ) return output def run_batched_block( block: nn.Module, feats: torch.Tensor, + cu_seqlens_q: torch.Tensor | None = None, + max_q_seqlen: int | None = None, q_freqs_cis: torch.Tensor | None = None, ) -> torch.Tensor: - """Run one transformer block with the dense batched attention path.""" + """Run one transformer block with the dense batched attention path with optional padding.""" if getattr(block, "multiscale", None) is not None: raise NotImplementedError("Dense runtime batched backend does not support multiscale blocks.") if getattr(block.attn, "_type", None) != "self": @@ -101,7 +115,9 @@ def run_batched_block( residual = feats h = block.norm1(feats) - h = run_batched_attention(block.attn, h, q_freqs_cis=q_freqs_cis) + h = run_batched_attention( + block.attn, h, cu_seqlens_q=cu_seqlens_q, max_q_seqlen=max_q_seqlen, q_freqs_cis=q_freqs_cis + ) feats = residual + h residual = feats h = block.norm2(feats) @@ -112,15 +128,19 @@ def run_batched_block( def run_batched_attention( attention: nn.Module, feats: torch.Tensor, + cu_seqlens_q: torch.Tensor | None = None, + max_q_seqlen: int | None = None, q_freqs_cis: torch.Tensor | None = None, ) -> torch.Tensor: - """Run one dense self-attention layer via the imaginaire attention frontend.""" + """Run one dense self-attention layer via the cosmos_framework attention frontend.""" if not hasattr(attention, "to_qkv"): raise ValueError("Dense runtime batched backend requires fused to_qkv linear projections.") if not hasattr(attention, "to_out"): raise ValueError("Dense runtime batched backend requires an output projection linear layer.") + # feats: [B, S_padded, hidden] (S_padded = pad_to tokens per batch item, padded for CUDA graph) batch_size, seq_len, hidden_size = feats.shape + # qkv: [B, S_padded, 3, H, D] qkv = F.linear(feats, attention.to_qkv.weight, attention.to_qkv.bias).reshape( batch_size, seq_len, @@ -128,9 +148,11 @@ def run_batched_attention( attention.num_heads, -1, ) + # q, k, v: [B, S_padded, H, D] q, k, v = qkv.unbind(dim=2) if getattr(attention, "qk_rms_norm", False): + # flatten to [B*S_padded, H, D] for per-token RMSNorm, then restore flat_q = q.reshape(batch_size * seq_len, attention.num_heads, -1) flat_k = k.reshape(batch_size * seq_len, attention.num_heads, -1) q = attention.q_rms_norm(flat_q).reshape(batch_size, seq_len, attention.num_heads, -1) @@ -139,6 +161,7 @@ def run_batched_attention( if getattr(attention, "use_rope", False): if q_freqs_cis is None: raise ValueError("Dense runtime batched backend requires precomputed q_freqs_cis when RoPE is enabled.") + # flatten to [B*S_padded, H, D] for RoPE application, then restore to [B, S_padded, H, D] flat_q = q.reshape(batch_size * seq_len, attention.num_heads, -1) flat_k = k.reshape(batch_size * seq_len, attention.num_heads, -1) flat_q, flat_k = attention.rope.apply_rotary_emb( @@ -150,6 +173,16 @@ def run_batched_attention( q = flat_q.reshape(batch_size, seq_len, attention.num_heads, -1) k = flat_k.reshape(batch_size, seq_len, attention.num_heads, -1) - h = tensor_dense_scaled_dot_product_attention(q=q, k=k, v=v) + # q, k, v: [B, S_padded, H, D] → attention → h: [B, S_padded, H, D] + h = tensor_dense_scaled_dot_product_attention( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_q, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_q_seqlen, + ) + # h: [B, S_padded, hidden] h = h.reshape(batch_size, seq_len, hidden_size) return F.linear(h, attention.to_out.weight, attention.to_out.bias) diff --git a/cosmos_framework/model/tokenizer/models/dense_runtime.py b/cosmos_framework/model/tokenizer/models/dense_runtime.py index 02a938e..3133108 100644 --- a/cosmos_framework/model/tokenizer/models/dense_runtime.py +++ b/cosmos_framework/model/tokenizer/models/dense_runtime.py @@ -115,11 +115,42 @@ def __init__( self, autoencoder: AutoencoderKL, backend: DenseRuntimeBackend = "auto", + pad_frames: int = 0, + pixel_trim: bool = True, + chunk_size: int = 16, ) -> None: - """Initialize the dense runtime wrapper.""" + """Initialize the dense runtime wrapper. + + Args: + autoencoder: The sparse autoencoder to wrap. + backend: Backend selection for block-stack execution. + pad_frames: Number of boundary frames to replicate at each end of + every temporal chunk before encoding. Must be divisible by + ``patch_size[0]``. Set ``0`` to disable boundary padding; + set ``>0`` (typically one temporal patch, e.g. ``4``) to give + the non-causal encoder additional context across chunk edges, + eliminating the per-chunk-boundary PSNR dip. + pixel_trim: When ``True`` and ``pad_frames > 0``, boundary latents + are kept in the encoded output and trimmed in pixel space after + decoding. When ``False``, boundary latents are trimmed + immediately after encoding. ``True`` should always be used + for the best reconstruction quality. + chunk_size: Number of *raw* frames consumed by the encoder per + temporal chunk. Forwarded to + ``autoencoder.num_sample_frames_batch_size`` and used to + slice the input video into encode batches. Must satisfy + ``2 * pad_frames < chunk_size``. Default ``16``. + """ super().__init__() self.autoencoder = autoencoder self.backend = backend + autoencoder.num_sample_frames_batch_size = chunk_size + if pad_frames < 0: + raise ValueError(f"pad_frames must be non-negative, got {pad_frames}.") + if 2 * pad_frames >= chunk_size: + raise ValueError(f"pad_frames must be less than chunk_size / 2, got {pad_frames=}, {chunk_size=}.") + self.pad_frames = pad_frames + self.pixel_trim = pixel_trim self._metadata_cache = {} self.cg_compiled = False @@ -128,10 +159,19 @@ def from_autoencoder( cls, autoencoder: AutoencoderKL, backend: DenseRuntimeBackend = "auto", + pad_frames: int = 0, + pixel_trim: bool = True, + chunk_size: int = 16, ) -> "DenseAutoencoderRuntime": """Build a dense runtime from a supported sparse autoencoder.""" cls._validate_autoencoder(autoencoder) - return cls(autoencoder=autoencoder, backend=backend) + return cls( + autoencoder=autoencoder, + backend=backend, + pad_frames=pad_frames, + pixel_trim=pixel_trim, + chunk_size=chunk_size, + ) @staticmethod def _validate_autoencoder(autoencoder: AutoencoderKL) -> None: @@ -204,9 +244,15 @@ def clear_metadata_cache(self) -> None: """Drop cached dense-grid metadata.""" self._metadata_cache.clear() - def encode(self, dense_video: torch.Tensor, sample_posterior: bool = False) -> torch.Tensor: + def encode( + self, + dense_video: torch.Tensor, + sample_posterior: bool = False, + pad_to: int | None = None, + chunk_raw_frames: int | None = None, + ) -> torch.Tensor: """Encode a dense video tensor into latent moments or posterior samples.""" - moments = self.encode_moments(dense_video) + moments = self.encode_moments(dense_video, chunk_raw_frames=chunk_raw_frames, pad_to=pad_to) if not sample_posterior: return moments return self._sample_dense_posterior(moments) @@ -215,8 +261,44 @@ def encode_moments( self, video: torch.Tensor, chunk_raw_frames: int | None = None, + pad_to: int | None = None, ) -> torch.Tensor: - """Encode a dense video tensor into `[B, T_p, H_p, W_p, 2C]` latent moments.""" + """Encode a dense video tensor into `[B, T_p, H_p, W_p, 2C]` latent moments. + + Args: + video: Dense channels-last video tensor ``[B, T, H, W, 3]``. + chunk_raw_frames: Number of raw frames per encoder chunk. Defaults + to ``self.encoder_chunk_spec.raw_frames``. + pad_to: Sequence-length padding target for the ``batched_with_padding`` + backend (reduces CUDA graph recapture). + + Shapes (example): + Config: ``patch_size = (1, 16, 16)``, ``chunk_size = 16``, + ``pad_frames = 1`` (1 raw frame replicated on each chunk edge). + Whole-video input: ``[B=1, T=28, H=480, W=832, 3]``. + + Per-chunk pipeline (loop slices the 28 frames into 2 chunks of + ``chunk_raw_frames = 16 - 2*1 = 14``): + + :: + + step shape notes + --------------------------------------------------------------------------------------- + 1. raw chunk [1, 14, 480, 832, 3] 1 of 2 chunks + 2. after input padding [1, 16, 480, 832, 3] 1 pre + 14 raw + 1 post + 3. after encoding (latent) [1, 16, 30, 52, 2C] T_p=16/1, H_p=480/16, W_p=832/16 + 4. after decoding [1, 16, 480, 832, 3] + 5. after pixel trim [1, 14, 480, 832, 3] drops pad_frames=1 pixel frame + on each end + + Across both chunks the concatenated pixel-space output is + ``[1, 28, 480, 832, 3]``; the latent fed to a downstream DiT is + ``[1, 32, 30, 52, 2C]``. + + For images (``T = 1``) the input is repeated to one temporal patch + (``T = patch_time``) and ``latents_per_boundary = 0``, so the + DiT-facing shape is ``[B, 1, H_p, W_p, 2C]``. + """ if video.ndim != 5: raise ValueError(f"Dense runtime expects 5D video tensor, got {video.ndim}D") if video.shape[4] != 3: @@ -224,30 +306,78 @@ def encode_moments( batch_size, raw_frames, height, width, _ = video.shape patch_time, patch_height, patch_width = self.patch_size - if raw_frames % patch_time != 0: + + if chunk_raw_frames is None: + chunk_raw_frames = self.encoder_chunk_spec.raw_frames + chunk_raw_frames = chunk_raw_frames - 2 * self.pad_frames + assert chunk_raw_frames > 0, ( + f"Padding frames must be less than chunk_raw_frames, got {chunk_raw_frames=}, {self.pad_frames=}." + ) + if chunk_raw_frames <= 0: + raise ValueError(f"chunk_raw_frames must be positive, got {chunk_raw_frames}.") + + # if input is an image, we pad to form single temporal patch + if raw_frames == 1: + is_image = True + video = video.repeat(1, patch_time, 1, 1, 1) + raw_frames = patch_time + else: + is_image = False + + if (chunk_raw_frames + 2 * self.pad_frames) % patch_time != 0: + raise ValueError( + f"chunk_raw_frames + 2 * pad_frames must be divisible by patch_size[0]={patch_time}, got {chunk_raw_frames=}, {self.pad_frames=}." + ) + + remainder = raw_frames % chunk_raw_frames + if remainder != 0 and (remainder + 2 * self.pad_frames) % patch_time != 0 and not is_image: raise ValueError( - f"Dense runtime requires frame count divisible by patch_size[0]={patch_time}, got {raw_frames}." + f"Dense runtime requires frame count equal to chunk_raw_frames * N + patch_time - (2 * pad_frames), got {raw_frames=}, {chunk_raw_frames=}, {self.pad_frames=}, {patch_time=}." ) if height % patch_height != 0 or width % patch_width != 0: raise ValueError( "Dense runtime requires spatial dimensions divisible by patch size " f"{(patch_height, patch_width)}, got {(height, width)}." ) + pad_frames = self.pad_frames + if not is_image: + latents_per_boundary = pad_frames // patch_time + else: + latents_per_boundary = 0 del batch_size - if chunk_raw_frames is None: - chunk_raw_frames = self.encoder_chunk_spec.raw_frames - if chunk_raw_frames <= 0: - raise ValueError(f"chunk_raw_frames must be positive, got {chunk_raw_frames}.") - if chunk_raw_frames % patch_time != 0: - raise ValueError( - f"chunk_raw_frames must be divisible by patch_size[0]={patch_time}, got {chunk_raw_frames}." - ) + + # preserve the chunk size to reduce number of captured cuda graphs + if self.backend == "batched_with_padding" and pad_to is None and self.cg_compiled: + width_patches = width // patch_width + height_patches = height // patch_height + padded_chunk_frames = chunk_raw_frames + 2 * pad_frames + temporal_patches = padded_chunk_frames // patch_time + pad_to = width_patches * height_patches * temporal_patches + encoded_chunks: list[torch.Tensor] = [] for start_frame in range(0, raw_frames, chunk_raw_frames): end_frame = min(start_frame + chunk_raw_frames, raw_frames) video_chunk = video[:, start_frame:end_frame] - encoded_chunk = self._encode_video_chunk(video_chunk) + + if pad_frames > 0 and not is_image: + # UniAE chunk-wise encoding suffers a PSNR dip at chunk boundaries + # because the non-causal encoder lacks context beyond the chunk edges. + # Padding each chunk with pad_frames replicated boundary frames on both + # sides gives the encoder that context, eliminating the boundary dip. + # In practice pad_frames=4 (one temporal patch) is used. + # The corresponding boundary latents are trimmed after decoding + # (see pixel_trim / latents_per_boundary below). + pre = video_chunk[:, 0:1].expand(-1, pad_frames, -1, -1, -1) + post = video_chunk[:, -1:].expand(-1, pad_frames, -1, -1, -1) + video_chunk = torch.cat([pre, video_chunk, post], dim=1) + + encoded_chunk = self._encode_video_chunk(video_chunk, pad_to=pad_to) + + if latents_per_boundary > 0 and not self.pixel_trim: + t_latent = encoded_chunk.shape[1] + encoded_chunk = encoded_chunk[:, latents_per_boundary : t_latent - latents_per_boundary] + encoded_chunks.append(encoded_chunk) return torch.cat(encoded_chunks, dim=1) @@ -256,7 +386,12 @@ def decode( dense_latent: torch.Tensor, chunk_raw_frames: int | None = None, ) -> torch.Tensor: - """Decode a dense latent grid into a dense channels-last video tensor.""" + """Decode a dense latent grid into a dense channels-last video tensor. + + When ``pixel_trim`` is enabled and ``pad_frames > 0``, the latent + contains boundary tokens from encoding. After decoding, the + corresponding boundary pixel frames are trimmed from each chunk. + """ if self.decoder_cache_spec.patch_frames != 0: raise NotImplementedError("Dense runtime decoder V1 does not support KV cache.") @@ -272,11 +407,18 @@ def decode( f"chunk_raw_frames must be divisible by patch_size[0]={self.patch_size[0]}, got {chunk_raw_frames}." ) chunk_patch_frames = chunk_raw_frames // self.patch_size[0] + + pad_frames = self.pad_frames + trim_pixel = self.pixel_trim and pad_frames > 0 + decoded_chunks: list[torch.Tensor] = [] for start_patch in range(0, temporal_patches, chunk_patch_frames): end_patch = min(start_patch + chunk_patch_frames, temporal_patches) latent_chunk = latent[:, start_patch:end_patch] - decoded_chunks.append(self._decode_latent_chunk(latent_chunk)) + decoded_chunk = self._decode_latent_chunk(latent_chunk) + if trim_pixel: + decoded_chunk = decoded_chunk[:, pad_frames:-pad_frames] + decoded_chunks.append(decoded_chunk) return torch.cat(decoded_chunks, dim=1) def _metadata_cache_key( @@ -361,13 +503,22 @@ def _canonicalize_dense_latent(self, dense_latent: torch.Tensor) -> torch.Tensor ) return latent.contiguous() - def _encode_video_chunk(self, dense_video_chunk: torch.Tensor) -> torch.Tensor: + def _encode_video_chunk( + self, + dense_video_chunk: torch.Tensor, + pad_to: int | None = None, + ) -> torch.Tensor: """Encode one dense video chunk into projected latent moments.""" + assert pad_to is None or self.backend == "batched_with_padding", ( + "pad_to is only supported for batched_with_padding backend" + ) + batch_size, raw_frames, height, width, _ = dense_video_chunk.shape patch_time, patch_height, patch_width = self.patch_size temporal_patches = raw_frames // patch_time height_patches = height // patch_height width_patches = width // patch_width + seq_len = temporal_patches * height_patches * width_patches patch_feats = self._patchify_dense_video(dense_video_chunk) metadata = self._get_or_build_grid_metadata( @@ -380,14 +531,43 @@ def _encode_video_chunk(self, dense_video_chunk: torch.Tensor) -> torch.Tensor: device=patch_feats.device, dtype=self.autoencoder.encoder.input_layer.weight.dtype, ) + + learned_pe = metadata.learned_pe + rope_freqs_cis = metadata.rope_freqs_cis + + needs_padding = pad_to is not None and pad_to > seq_len + if pad_to is not None and pad_to < seq_len: + raise ValueError(f"pad_to ({pad_to}) must be >= sequence length ({seq_len}).") + if needs_padding: + if batch_size != 1: + raise ValueError( + f"pad_to requires batch_size=1 for correct varlen masking, got batch_size={batch_size}." + ) + pad_amount = pad_to - seq_len + patch_feats = F.pad(patch_feats, (0, 0, 0, pad_amount)) + if learned_pe is not None: + learned_pe = F.pad(learned_pe, (0, 0, 0, pad_amount)) + if rope_freqs_cis is not None: + rope_pad = torch.zeros( + pad_amount, + rope_freqs_cis.shape[-1], + dtype=rope_freqs_cis.dtype, + device=rope_freqs_cis.device, + ) + rope_freqs_cis = torch.cat([rope_freqs_cis, rope_pad], dim=0) + moments = self._encode_chunk_core( patch_feats, - learned_pe=metadata.learned_pe, - rope_freqs_cis=metadata.rope_freqs_cis, + learned_pe=learned_pe, + rope_freqs_cis=rope_freqs_cis, q_seqlen=metadata.q_seqlen, cu_seqlens_q=metadata.cu_seqlens, - max_q_seqlen=metadata.max_seq_len, + max_q_seqlen=metadata.max_seq_len if not needs_padding else pad_to, ) + + if needs_padding: + moments = moments[:, :seq_len] + if self.cg_compiled: moments = moments.clone() return moments.reshape(batch_size, temporal_patches, height_patches, width_patches, -1) @@ -568,6 +748,17 @@ def _run_block_stack( feats, q_freqs_cis=rope_freqs_cis, ) + if backend == "batched_with_padding": + assert feats.shape[0] == 1, ( + "batched_with_padding backend only supports batch_size=1, due to varlen kernel requirements." + ) + return run_batched_block_stack( + blocks, + feats, + cu_seqlens_q=cu_seqlens_q, + max_q_seqlen=max_q_seqlen, + q_freqs_cis=rope_freqs_cis, + ) raise ValueError(f"Unsupported dense runtime backend: {backend}") def _get_or_build_grid_metadata( diff --git a/cosmos_framework/model/tokenizer/models/modules/attention/full_attn.py b/cosmos_framework/model/tokenizer/models/modules/attention/full_attn.py index d5a80e1..78b077f 100644 --- a/cosmos_framework/model/tokenizer/models/modules/attention/full_attn.py +++ b/cosmos_framework/model/tokenizer/models/modules/attention/full_attn.py @@ -85,8 +85,12 @@ def tensor_dense_scaled_dot_product_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_kv: torch.Tensor | None = None, + max_q_seqlen: int | None = None, + max_kv_seqlen: int | None = None, ) -> torch.Tensor: - """Apply dense batched attention via the imaginaire attention frontend.""" + """Apply dense batched attention via the cosmos_framework attention frontend.""" if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError( "Dense tensor attention expects [B, S, H, D]-style tensors, " @@ -99,6 +103,10 @@ def tensor_dense_scaled_dot_product_attention( query=q.contiguous(), key=k.contiguous(), value=v.contiguous(), + cumulative_seqlen_Q=cu_seqlens_q, + cumulative_seqlen_KV=cu_seqlens_kv, + max_seqlen_Q=max_q_seqlen, + max_seqlen_KV=max_kv_seqlen, ) diff --git a/cosmos_framework/model/tokenizer/models/modules/quantizers/residual_vq.py b/cosmos_framework/model/tokenizer/models/modules/quantizers/residual_vq.py index d2f160d..806ac12 100644 --- a/cosmos_framework/model/tokenizer/models/modules/quantizers/residual_vq.py +++ b/cosmos_framework/model/tokenizer/models/modules/quantizers/residual_vq.py @@ -291,6 +291,78 @@ def __init__( self.commitment_loss = commitment_loss + def to_code_shape(self, x: torch.Tensor) -> torch.Tensor: + """Reshape dense latent features to code-grid feature vectors.""" + embed_dim = self.codebooks[0].weight.shape[-1] + if x.ndim == 2 and x.shape[-1] == embed_dim: + return x # [N,E] + + if x.ndim != 4 or tuple(x.shape[1:]) != tuple(self.latent_shape): + raise ValueError( + f"Expected latent shape [B,{tuple(self.latent_shape)}] or [N,{embed_dim}], got {tuple(x.shape)}." + ) + + batch_size = x.shape[0] + latent_h, latent_w, latent_dim = [int(dim) for dim in self.latent_shape] + code_h, code_w, _ = [int(dim) for dim in self.code_shape] + height_factor = latent_h // code_h + width_factor = latent_w // code_w + + x = x.reshape(batch_size, code_h, height_factor, code_w, width_factor, latent_dim) # [B,h,Hs,w,Ws,D] + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() # [B,h,w,Hs,Ws,D] + return x.reshape(batch_size, code_h, code_w, embed_dim) # [B,h,w,E] + + def to_latent_shape(self, embeds: torch.Tensor) -> torch.Tensor: + """Reshape code-grid embeddings back to dense latent layout.""" + embed_dim = self.codebooks[0].weight.shape[-1] + if embeds.ndim == 2 and embeds.shape[-1] == embed_dim: + return embeds # [N,E] + + code_h, code_w, _ = [int(dim) for dim in self.code_shape] + if embeds.ndim != 4 or tuple(embeds.shape[1:3]) != (code_h, code_w) or embeds.shape[-1] != embed_dim: + raise ValueError( + f"Expected code embedding shape [B,{code_h},{code_w},{embed_dim}] or [N,{embed_dim}], " + f"got {tuple(embeds.shape)}." + ) + + batch_size = embeds.shape[0] + latent_h, latent_w, latent_dim = [int(dim) for dim in self.latent_shape] + height_factor = latent_h // code_h + width_factor = latent_w // code_w + + embeds = embeds.reshape(batch_size, code_h, code_w, height_factor, width_factor, latent_dim) # [B,h,w,Hs,Ws,D] + embeds = embeds.permute(0, 1, 3, 2, 4, 5).contiguous() # [B,h,Hs,w,Ws,D] + return embeds.reshape(batch_size, latent_h, latent_w, latent_dim) # [B,H,W,D] + + def _embed_code_slices(self, code: torch.Tensor) -> list[torch.Tensor]: + """Look up per-depth codebook embeddings without summing codebook depth.""" + if code.shape[-1] != self.code_shape[-1]: + raise ValueError(f"Expected code depth {self.code_shape[-1]}, got code shape {tuple(code.shape)}.") + + code = code.long() # [...,Dq] + code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) # list[[...,1]] + + if self.shared_codebook: + embeds = [self.codebooks[0].embed(code_slice) for code_slice in code_slices] # list[[...,1,E]] + else: + embeds = [ + self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices) + ] # list[[...,1,E]] + + return embeds + + def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: + """Decode flat residual-quantizer indices to summed embedding vectors.""" + if indices.ndim == 1: + if self.code_shape[-1] != 1: + raise ValueError( + f"Flat indices require one codebook, but this RQ bottleneck has depth {self.code_shape[-1]}." + ) + indices = indices.unsqueeze(-1) # [N,1] + + embeds = self._embed_code_slices(indices) # list[[...,1,E]] + return torch.cat(embeds, dim=-2).sum(-2) # [...,E] + def quantize(self, x: torch.Tensor) -> tuple[list[torch.Tensor], torch.Tensor]: """Quantize input using residual quantization. @@ -365,22 +437,22 @@ def embed_code(self, code: torch.Tensor) -> torch.Tensor: """Decode codes to embeddings. Args: - code: Code tensor of shape (B, h, w, d). + code: Code tensor of shape (B, h, w, d) or flat shape (N, d). Returns: - Embedded features of shape (B, H, W, embed_dim). + Embedded features of shape (B, H, W, D) or flat shape (N, embed_dim). """ - assert code.shape[1:] == self.code_shape - - code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) + if code.ndim == 2: + return self.get_codes_from_indices(code) # [N,E] - if self.shared_codebook: - embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] - else: - embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] + if tuple(code.shape[1:]) != tuple(self.code_shape): + raise ValueError( + f"Expected code shape [B,{tuple(self.code_shape)}] or [N,{self.code_shape[-1]}], " + f"got {tuple(code.shape)}." + ) - embeds = torch.cat(embeds, dim=-2).sum(-2) - embeds = self.to_latent_shape(embeds) + embeds = self.get_codes_from_indices(code) # [B,h,w,E] + embeds = self.to_latent_shape(embeds) # [B,H,W,D] return embeds @@ -397,18 +469,11 @@ def embed_code_with_depth(self, code: torch.Tensor, to_latent_shape: bool = Fals Returns: Tuple of (embedded features, None). """ - assert code.shape[-1] == self.code_shape[-1] - - code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) - - if self.shared_codebook: - embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] - else: - embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] + embeds = self._embed_code_slices(code) # list[[...,1,E]] if to_latent_shape: - embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds] - embeds = torch.cat(embeds, dim=-2) + embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds] # list[[B,H,W,1,D]] + embeds = torch.cat(embeds, dim=-2) # [...,Dq,E] or [B,H,W,Dq,D] return embeds, None @@ -429,25 +494,23 @@ def embed_partial_code( Returns: Quantized feature map. """ - assert code.shape[1:] == self.code_shape - assert code_idx < code.shape[-1] + if tuple(code.shape[1:]) != tuple(self.code_shape): + raise ValueError(f"Expected code shape [B,{tuple(self.code_shape)}], got {tuple(code.shape)}.") + if code_idx >= code.shape[-1]: + raise ValueError(f"code_idx must be smaller than code depth {code.shape[-1]}, got {code_idx}.") B, h, w, _ = code.shape - code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1) - if self.shared_codebook: - embeds = [self.codebooks[0].embed(code_slice) for i, code_slice in enumerate(code_slices)] - else: - embeds = [self.codebooks[i].embed(code_slice) for i, code_slice in enumerate(code_slices)] + embeds = self._embed_code_slices(code) # list[[B,h,w,1,E]] if decode_type == "select": - embeds = embeds[code_idx].view(B, h, w, -1) + embeds = embeds[code_idx].view(B, h, w, -1) # [B,h,w,E] elif decode_type == "add": - embeds = torch.cat(embeds[: code_idx + 1], dim=-2).sum(-2) + embeds = torch.cat(embeds[: code_idx + 1], dim=-2).sum(-2) # [B,h,w,E] else: raise NotImplementedError(f"{decode_type} is not implemented in partial decoding") - embeds = self.to_latent_shape(embeds) + embeds = self.to_latent_shape(embeds) # [B,H,W,D] return embeds @@ -468,30 +531,30 @@ def get_soft_codes( Returns: Tuple of (soft codes, hard codes). """ - x = self.to_code_shape(x) + x = self.to_code_shape(x) # [N,E] or [B,h,w,E] - residual_feature = x.detach().clone() + residual_feature = x.detach().clone() # [N,E] or [B,h,w,E] soft_code_list = [] code_list = [] n_codebooks = self.code_shape[-1] for i in range(n_codebooks): codebook = self.codebooks[i] - distances = codebook.compute_distances(residual_feature) - soft_code = F.softmax(-distances / temp, dim=-1) + distances = codebook.compute_distances(residual_feature) # [N,K] or [B,h,w,K] + soft_code = F.softmax(-distances / temp, dim=-1) # [N,K] or [B,h,w,K] if stochastic: - soft_code_flat = soft_code.reshape(-1, soft_code.shape[-1]) - code = torch.multinomial(soft_code_flat, 1) - code = code.reshape(*soft_code.shape[:-1]) + soft_code_flat = soft_code.reshape(-1, soft_code.shape[-1]) # [M,K] + code = torch.multinomial(soft_code_flat, 1) # [M,1] + code = code.reshape(*soft_code.shape[:-1]) # [N] or [B,h,w] else: - code = distances.argmin(dim=-1) - quants = codebook.embed(code) - residual_feature -= quants + code = distances.argmin(dim=-1) # [N] or [B,h,w] + quants = codebook.embed(code) # [N,E] or [B,h,w,E] + residual_feature -= quants # [N,E] or [B,h,w,E] code_list.append(code.unsqueeze(-1)) soft_code_list.append(soft_code.unsqueeze(-2)) - code = torch.cat(code_list, dim=-1) - soft_code = torch.cat(soft_code_list, dim=-2) + code = torch.cat(code_list, dim=-1) # [N,Dq] or [B,h,w,Dq] + soft_code = torch.cat(soft_code_list, dim=-2) # [N,Dq,K] or [B,h,w,Dq,K] return soft_code, code diff --git a/cosmos_framework/model/tokenizer/models/sparse_autoencoder.py b/cosmos_framework/model/tokenizer/models/sparse_autoencoder.py index 2da33f6..9354636 100644 --- a/cosmos_framework/model/tokenizer/models/sparse_autoencoder.py +++ b/cosmos_framework/model/tokenizer/models/sparse_autoencoder.py @@ -1517,7 +1517,7 @@ def __init__( self.post_logit_scale = None self.post_logit_bias = None - # Text decoder (Qwen3-based causal LM for image-to-text generation) + # Text decoder (configured causal LM for image-to-text generation) if self.use_text_decoder and text_decoder_model_name is not None: from cosmos_framework.model.tokenizer.models.text_decoder import ( TextDecoderWrapper, diff --git a/cosmos_framework/model/tokenizer/models/text_decoder.py b/cosmos_framework/model/tokenizer/models/text_decoder.py index 322dd67..9f72370 100644 --- a/cosmos_framework/model/tokenizer/models/text_decoder.py +++ b/cosmos_framework/model/tokenizer/models/text_decoder.py @@ -1257,8 +1257,8 @@ def generate_answer( """Generate an answer to a question about an image. Qwen3 uses its native chat template. Nemotron uses its own native - chat template with ``<|im_start|>``, ``<|im_end|>``, and - ```` in no-thinking mode. + chat template with ``<|im_start|>``, ``<|im_end|>``, and the + ``\n`` no-thinking prefix. Args: image_feats_tensor: [N, encoder_dim] features for ONE image. diff --git a/cosmos_framework/model/tokenizer/models/utils.py b/cosmos_framework/model/tokenizer/models/utils.py index 33eaad8..a5b923b 100644 --- a/cosmos_framework/model/tokenizer/models/utils.py +++ b/cosmos_framework/model/tokenizer/models/utils.py @@ -10,7 +10,7 @@ - Temporal utilities: split_temporal_dimension, restore_original_shape, reconstruct_from_temporal_slices -Note: Metrics like calculate_psnr have been moved to projects.cosmos3.tokenizer.evaluation.reconstruction_metrics +Note: Metrics like calculate_psnr have been moved to cosmos_framework.model.tokenizer.evaluation.reconstruction_metrics """ from __future__ import annotations @@ -561,7 +561,7 @@ def resize_and_crop( # ============================================================================= # Logging # ============================================================================= -# Note: calculate_psnr has been moved to projects.cosmos3.tokenizer.evaluation.reconstruction_metrics +# Note: calculate_psnr has been moved to cosmos_framework.model.tokenizer.evaluation.reconstruction_metrics # ============================================================================= diff --git a/cosmos_framework/model/vfm/diffusion/samplers/edm.py b/cosmos_framework/model/vfm/diffusion/samplers/edm.py index ed5e156..17cd296 100644 --- a/cosmos_framework/model/vfm/diffusion/samplers/edm.py +++ b/cosmos_framework/model/vfm/diffusion/samplers/edm.py @@ -199,7 +199,7 @@ def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_ """ val = init_val for i in range(lower, upper): - # Periodic log during sampling so long-running jobs keep producing output. + # Add log during sampling to meet APS job health requirement of one log every 2mins if i % 10 == 0: log.info(f"fori_loop: {i}") val = body_fun(i, val) diff --git a/cosmos_framework/model/vfm/diffusion/samplers/fm_solvers_unipc.py b/cosmos_framework/model/vfm/diffusion/samplers/fm_solvers_unipc.py index 2c87fea..92f0c44 100644 --- a/cosmos_framework/model/vfm/diffusion/samplers/fm_solvers_unipc.py +++ b/cosmos_framework/model/vfm/diffusion/samplers/fm_solvers_unipc.py @@ -1,6 +1,6 @@ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 # # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py diff --git a/cosmos_framework/model/vfm/hf_model.py b/cosmos_framework/model/vfm/hf_model.py index 8b8efc0..8b4ebdf 100644 --- a/cosmos_framework/model/vfm/hf_model.py +++ b/cosmos_framework/model/vfm/hf_model.py @@ -259,8 +259,8 @@ def load_weights( ``org/model`` repo IDs fall back to Hugging Face. credential_path: S3 credential file, or None for local/HF. parallel_dims: ``ParallelDims`` instance (from - ``projects.cosmos3.vfm.utils.parallelism``). The loader uses - it via :func:`~projects.cosmos3.vfm.models.utils.safetensors_loader._get_dp_shard_mesh` + ``cosmos_framework.utils.vfm.parallelism``). The loader uses + it via :func:`~cosmos_framework.model.vfm.utils.safetensors_loader._get_dp_shard_mesh` to obtain the 1-D ``dp_shard`` sub-mesh (or None when ``dp_shard <= 1``) for striping checkpoint reads across FSDP shard ranks. When non-None, the caller MUST have @@ -317,7 +317,7 @@ def load_weights( "raw_video", # image_sizes is collected by collate_fn but is NOT a Qwen3-VL forward arg # (Qwen3-VL uses image_grid_thw instead). Strip it so strict HF signatures - # don't reject it. + # don't reject it. NOTE: image_sizes IS valid for LLaVA-style models — if # a future Phase extends to those, remove this entry. "image_sizes", } diff --git a/cosmos_framework/model/vfm/mot/attention.py b/cosmos_framework/model/vfm/mot/attention.py index 002ca62..50fd413 100644 --- a/cosmos_framework/model/vfm/mot/attention.py +++ b/cosmos_framework/model/vfm/mot/attention.py @@ -98,9 +98,19 @@ def two_way_attention( sample_offsets = packed_query_states["sample_offsets"] + # COSMOS-RELEASE-BEGIN-IGNORE + # NOTE: we can only use the don't care causal mask when we know seqlen_Q == seqlen_KV. + # Since this is a varlen use case, we would need to statically check all Q and KV offsets + # are the same. + # We don't want to launch a kernel just to perform this check and slow down our model, + # and we don't want to just assume no one is going to copy this piece of code without + # reading this, and we definitely don't want to complicate the sequence_packing code so that + # it performs a static check when creating the packed sequence and metadata, so we can just rely + # on causal_q_offsets and causal_k_offsets being the same tensor. + # COSMOS-RELEASE-END-IGNORE use_dont_care_mask = causal_q_offsets is causal_k_offsets - + # NOTE: cosmos_framework attention is BSHD in, BSHD out causal_res = attention( causal_q.unsqueeze(0), # [1,N_und,heads,head_dim] causal_k.unsqueeze(0), # [1,N_und,heads,head_dim] @@ -176,9 +186,19 @@ def three_way_attention( ).reshape(-1) full_v[null_positions] = 0 + # COSMOS-RELEASE-BEGIN-IGNORE + # NOTE: we can only use the don't care causal mask when we know seqlen_Q == seqlen_KV. + # Since this is a varlen use case, we would need to statically check all Q and KV offsets + # are the same. + # We don't want to launch a kernel just to perform this check and slow down our model, + # and we don't want to just assume no one is going to copy this piece of code without + # reading this, and we definitely don't want to complicate the sequence_packing code so that + # it performs a static check when creating the packed sequence and metadata, so we can just rely + # on causal_q_offsets and causal_k_offsets being the same tensor. + # COSMOS-RELEASE-END-IGNORE use_dont_care_mask = causal_q_offsets is causal_k_offsets - + # NOTE: cosmos_framework attention is BSHD in, BSHD out causal_res = attention( causal_q.unsqueeze(0), # [1,N_und,heads,head_dim] causal_k.unsqueeze(0), # [1,N_und,heads,head_dim] diff --git a/cosmos_framework/model/vfm/mot/attention_test.py b/cosmos_framework/model/vfm/mot/attention_test.py index 66b58d5..d41fdff 100644 --- a/cosmos_framework/model/vfm/mot/attention_test.py +++ b/cosmos_framework/model/vfm/mot/attention_test.py @@ -244,7 +244,7 @@ def forward(self, *args, **kwargs): kwargs["sdpa_func"] = self.sdpa_func return self.attention_func(*args, **kwargs) - + # NOTE: we should try and maintain only one copy of QKV offsets if they're identical # between queries and key/values, since this enables the "don't care" mask, which enables # more attention backends in I4 attention. if query_factored_1["_causal_seq_offsets"].equal(key_factored_1["_causal_seq_offsets"]) and query_factored_1[ @@ -365,6 +365,15 @@ def forward(self, *args, **kwargs): ) +# COSMOS-RELEASE-BEGIN-IGNORE +# because we need GQA support, varlen, torch.compile, and we need it across architectures. +# Flash3 + torch.compile is banned because our container build of Flash3 doesn't support it, and +# patching on our end and lack of versioning on their end makes it very difficult to check for this +# at runtime. +# Flash2 varlen introduces instability in Blackwell, and is therefore banned. +# cuDNN is banned entirely until it can pass our tests. +# NATTEN must meet the version requirements for all the features to be available. +# COSMOS-RELEASE-END-IGNORE @pytest.mark.L0 @pytest.mark.skipif(not NATTEN_SUPPORTED, reason="NATTEN is not available, or too old.") def test_two_way_attention_cmp_flex_attn(): diff --git a/cosmos_framework/model/vfm/mot/context_parallel_utils.py b/cosmos_framework/model/vfm/mot/context_parallel_utils.py index 4bb49e0..96bf607 100644 --- a/cosmos_framework/model/vfm/mot/context_parallel_utils.py +++ b/cosmos_framework/model/vfm/mot/context_parallel_utils.py @@ -346,7 +346,7 @@ def context_parallel_attention( f"Local query heads ({q_heads_per_rank}) must be divisible by local KV heads ({kv_heads_per_rank})" ) - + # NOTE: q_und_seq, k_und_seq, and v_und_seq may have length 0 # when doing AR-inference with a KV-cache. if kv_head_repeats > 1: diff --git a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py index 03f0c3f..39f7346 100644 --- a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py +++ b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py @@ -882,7 +882,7 @@ def _encode_sound( packed_tokens_sound = packed_tokens_sound.to(target_dtype) # [total_sound_tokens,sound_dim] # Project sound tokens + modality embedding - + # NOTE: Sound position info comes from m-RoPE position IDs in the attention layers. # No additive position embedding is used (unlike legacy video which keeps one for backward compat). packed_tokens_sound = ( self.sound2llm(packed_tokens_sound) + self.sound_modality_embed diff --git a/cosmos_framework/model/vfm/mot/cosmos3_vfm_qwen3_vl_network_test.py b/cosmos_framework/model/vfm/mot/cosmos3_vfm_qwen3_vl_network_test.py index 921e989..3bf9ca4 100644 --- a/cosmos_framework/model/vfm/mot/cosmos3_vfm_qwen3_vl_network_test.py +++ b/cosmos_framework/model/vfm/mot/cosmos3_vfm_qwen3_vl_network_test.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: OpenMDW-1.1 import gc +import getpass import math import os from typing import Any @@ -736,11 +737,15 @@ def test_unified_llm_outputs_with_hf(config_name: str): from transformers import Qwen3VLForConditionalGeneration as HFModelClass # Create the HF model - tokenizer_hf = AutoTokenizer.from_pretrained(config["model_name"]) + tokenizer_hf = AutoTokenizer.from_pretrained( + config["model_name"], + cache_dir=f"/nfs/dir/dir_cosmos_base/users/{getpass.getuser()}/hf_cache/", + ) hf_vlm_model = HFModelClass.from_pretrained( config["model_name"], dtype=torch.bfloat16, device_map="auto", + cache_dir=f"/nfs/dir/dir_cosmos_base/users/{getpass.getuser()}/hf_cache/", ) if config["model_type"] == "dense_llm": hf_model = hf_vlm_model diff --git a/cosmos_framework/model/vfm/mot/dot_product_attention.py b/cosmos_framework/model/vfm/mot/dot_product_attention.py index 8a6f46e..be8a968 100644 --- a/cosmos_framework/model/vfm/mot/dot_product_attention.py +++ b/cosmos_framework/model/vfm/mot/dot_product_attention.py @@ -117,7 +117,6 @@ def cudnn_fused_attn( o_quantizer = None rng_gen = None - # "thd_thd_thd" format requires contiguous tensors. # We should benchmark thd_th2d / th3d formats as well. q = q.contiguous() @@ -176,7 +175,7 @@ def cudnn_fused_attn( # is_cuda_graph args += (False,) - + # NOTE: The reason we do this instead of just calling DotProductAttention.forward is # I'd have to create DotProductAttention class and somehow pass it in here, but argument types for these torch.ops are very strict. # Moreover, back-propagation would still need additional tweaks to work properly. output_tensors = tex.fused_attn_fwd(*args) @@ -207,7 +206,7 @@ def _get_max_tokens(num_tokens: int) -> int: return max_t - +# NOTE: we need register_fake in order to make this operator fully torch.compile compatible. # The goal for this function is to return fake tensors of the correct shape and dtype # without having to run the actual operator. diff --git a/cosmos_framework/model/vfm/mot/modeling_utils.py b/cosmos_framework/model/vfm/mot/modeling_utils.py index d774fa6..418daec 100644 --- a/cosmos_framework/model/vfm/mot/modeling_utils.py +++ b/cosmos_framework/model/vfm/mot/modeling_utils.py @@ -162,7 +162,6 @@ def __init__( dim_w = dim_h dim_t = dim - 2 * dim_h assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" - self.register_buffer( "dim_spatial_range", torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h, diff --git a/cosmos_framework/model/vfm/mot/unified_mot.py b/cosmos_framework/model/vfm/mot/unified_mot.py index 4908e7a..b6fb011 100644 --- a/cosmos_framework/model/vfm/mot/unified_mot.py +++ b/cosmos_framework/model/vfm/mot/unified_mot.py @@ -189,7 +189,7 @@ class _MoTConfigBase(object): ``text_config`` access picks it up. Post-construction overrides via plain ``setattr`` (the - ``create_vlm_config`` flow in ``cosmos_framework/configs/base/defaults/vlm.py``) + ``create_vlm_config`` flow in ``configs/base/defaults/vlm.py``) just update the same plain attributes, so the next property access picks up the latest values. No cache, no ``__setattr__`` interception, no override bucket — the property rebuild is cheap @@ -620,7 +620,7 @@ def reasoner_forward( in a clean AR loop. All attention compute is dispatched through - ``imaginaire.attention.attention`` (per repo policy) which expects the + ``cosmos_framework.model.attention.attention`` (per repo policy) which expects the heads-last contiguous layout ``[B, S, H, D]`` and natively handles GQA (``H_KV != H``) — no manual head expansion is needed. @@ -653,7 +653,7 @@ def reasoner_forward( # q: [B,T,num_heads,head_dim], k: [B,T,num_kv_heads,head_dim] # The KV cache stores tensors in the same BSHD layout that - # ``imaginaire.attention.attention`` expects, with the seq dim at axis 1. + # ``cosmos_framework.model.attention.attention`` expects, with the seq dim at axis 1. if cache is not None: k_full, v_full = cache.update(layer_idx, k, v) else: @@ -691,7 +691,7 @@ def _impl_init( ``Nemotron3DenseVLTextModel``. Sub-layer classes (MLP, RMSNorm, rotary embedding) are dispatched through ``layer_types``. """ - self.padding_idx = config.pad_token_id + self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) @@ -1024,7 +1024,7 @@ def forward( ln_out_gen = self.post_attention_layernorm_moe_gen(residual_gen) # [N_gen,hidden_size] # UNPAD MLP INPUT =============== - + # NOTE: This is only need for the MoE auxiliary loss computation and to avoid # artificial expert inbalance due to routing padding tokens. gen_len = pack_attn_out["_num_full_tokens"] und_len = pack_attn_out["_num_causal_tokens"] @@ -1225,7 +1225,7 @@ class ReasonerKVCache: """Per-layer KV cache for the reasoner-tower autoregressive loop. Tensors are stored in the heads-last BSHD layout that - ``imaginaire.attention.attention`` expects:: + ``cosmos_framework.model.attention.attention`` expects:: keys[layer_idx]: [B, T, num_kv_heads, head_dim] values[layer_idx]: [B, T, num_kv_heads, head_dim] @@ -1364,7 +1364,8 @@ def _sample_next_token( top_p: float | None, repetition_penalty: float = 1.0, presence_penalty: float = 0.0, - seen_mask: torch.Tensor | None = None, # [B,vocab_size] bool + seen_mask: torch.Tensor | None = None, # [B,vocab_size] bool — prompt ∪ output, for repetition_penalty + output_seen_mask: torch.Tensor | None = None, # [B,vocab_size] bool — output only, for presence_penalty generator: torch.Generator | None = None, ) -> torch.Tensor: # [B] """Greedy / multinomial sampling with optional top-k, top-p, and presence/repetition penalties. @@ -1375,20 +1376,21 @@ def _sample_next_token( ``>1.0`` discourages repetition, ``<1.0`` encourages it, ``1.0`` is identity. 2. Presence penalty (OpenAI semantics) — additive shift of every - logit at a position seen in history. ``>0`` discourages, - ``<0`` encourages, ``0`` is identity. Applied once per token - regardless of how often it appeared (presence, not frequency). + logit at a position seen in **output** (``output_seen_mask``). ``>0`` + discourages, ``<0`` encourages, ``0`` is identity. 3. ``do_sample=False`` short-circuits to argmax. The two penalties above are applied *before* this branch so they legitimately shift the greedy argmax — they're logit transformations, not sampling-only tricks. 4. ``do_sample=True``: temperature → top-k → top-p → multinomial. - ``seen_mask`` is the canonical "has this vocab id appeared in this - sample's history" matrix maintained by - :func:`_impl_generate_reasoner_text`. Both penalties default to - identity, and the fast path (both off) skips all penalty work and - leaves the existing greedy/sampling logic bit-identical. + Mask semantics (match vLLM): + * ``seen_mask`` is seeded with prompt tokens and updated with each + generated token — penalizes prompt ∪ output (HF convention). + * ``output_seen_mask`` is updated with each generated token only — penalizes + output only. + Both penalties default to identity; the fast path (both off) leaves the + existing greedy/sampling logic bit-identical. ``generator`` is the only RNG-consuming primitive in this module: when provided, it is threaded into ``torch.multinomial`` so the @@ -1398,30 +1400,22 @@ def _sample_next_token( pre-seed behavior of consuming the device's default RNG and is bit-identical to the previous call signature. """ - # Logit-transform stage: repetition + presence penalties. Both gate - # on ``seen_mask`` being present AND a non-identity coefficient, so - # the default-off path costs zero extra ops. - if seen_mask is not None and (repetition_penalty != 1.0 or presence_penalty != 0.0): - if repetition_penalty != 1.0: - # CTRL/HF formula: divide positive logits, multiply negative - # logits (both by ``penalty``). Phrasing the scale as a - # single ``where`` over a precomputed factor keeps the - # masked-update path branchless and lets autograd / inductor - # fuse it with the surrounding ops. - penalty_factor = torch.where( - logits > 0, - torch.full_like(logits, 1.0 / repetition_penalty), - torch.full_like(logits, repetition_penalty), - ) - logits = torch.where(seen_mask, logits * penalty_factor, logits) - if presence_penalty != 0.0: - # OpenAI semantics: subtract a constant from every seen - # token's logit, once per token (presence, not frequency). - logits = torch.where( - seen_mask, - logits - presence_penalty, - logits, - ) + if seen_mask is not None and repetition_penalty != 1.0: + # CTRL/HF formula: divide positive logits, multiply negative. + penalty_factor = torch.where( + logits > 0, + torch.full_like(logits, 1.0 / repetition_penalty), + torch.full_like(logits, repetition_penalty), + ) + logits = torch.where(seen_mask, logits * penalty_factor, logits) + if output_seen_mask is not None and presence_penalty != 0.0: + # OpenAI semantics: subtract a constant from every seen + # token's logit, once per token (presence, not frequency). + logits = torch.where( + output_seen_mask, + logits - presence_penalty, + logits, + ) if not do_sample: return torch.argmax(logits, dim=-1) @@ -1590,9 +1584,11 @@ def _impl_generate_reasoner_text( — appearing twice costs the same as appearing once. Both penalties are applied *before* the ``do_sample`` argmax/multinomial branch, so they shift the greedy - argmax too. When both are at identity, the per-sample - ``seen_mask`` is never allocated and the loop is - bit-identical to the un-penalized fast path. + argmax too. When both are at identity, no history mask + is allocated and the loop is bit-identical to the + un-penalized fast path. Repetition penalty uses prompt ∪ + output; presence penalty uses output only (OpenAI / vLLM + convention). seed: Optional integer seed for the sampling RNG. When provided (and ``do_sample=True``), a fresh ``torch.Generator`` is allocated on ``input_ids.device`` and seeded once with @@ -1675,23 +1671,15 @@ def _impl_generate_reasoner_text( ) # [B,T_prompt,hidden_size] logits = causal_lm.lm_head(hidden[:, -1, :]) # [B,vocab_size] - # ``seen_mask`` is the per-sample "vocab id has appeared in this - # sample's history" matrix consumed by ``_sample_next_token``'s - # repetition / presence penalty paths. Allocate only when at least - # one penalty is non-identity so the un-penalized fast path is - # bit-identical to the previous behavior (no extra alloc, no scatter, - # no per-step writes). We size from ``logits.size(-1)`` so we don't - # have to reach into ``lm_head.weight.shape`` (which would also - # work under FSDP but is one extra coupling point). The mask - # captures prompt tokens first so the prefill's own sampling step - # already sees the prompt as history — matching HF's - # ``RepetitionPenaltyLogitsProcessor`` convention of penalizing - # against the full ``input_ids``. - apply_penalties = repetition_penalty != 1.0 or presence_penalty != 0.0 + # seen_mask is seeded with prompt tokens (HF convention). + # output_seen_mask stays empty until output tokens accumulate (OpenAI convention). seen_mask: torch.Tensor | None = None - if apply_penalties: + output_seen_mask: torch.Tensor | None = None + if repetition_penalty != 1.0: seen_mask = torch.zeros(B, logits.size(-1), dtype=torch.bool, device=device) seen_mask.scatter_(1, input_ids, True) + if presence_penalty != 0.0: + output_seen_mask = torch.zeros(B, logits.size(-1), dtype=torch.bool, device=device) # Build a device-local ``torch.Generator`` only when an explicit # seed is supplied. ``torch.multinomial(generator=None)`` falls @@ -1717,13 +1705,14 @@ def _impl_generate_reasoner_text( repetition_penalty=repetition_penalty, presence_penalty=presence_penalty, seen_mask=seen_mask, + output_seen_mask=output_seen_mask, generator=generator, ) # [B] + # Fold the just-sampled token into both penalty histories. if seen_mask is not None: - # Fold the just-sampled token into each sample's history so the - # next decode step penalizes it too. Per-sample row writes are - # idempotent — writing True over True is a no-op. seen_mask.scatter_(1, next_token.unsqueeze(1), True) + if output_seen_mask is not None: + output_seen_mask.scatter_(1, next_token.unsqueeze(1), True) # Hoist invariants used by every decode step out of the loop body so we # don't pay per-iter Python and allocator overhead for what is in fact @@ -1804,19 +1793,19 @@ def _impl_generate_reasoner_text( repetition_penalty=repetition_penalty, presence_penalty=presence_penalty, seen_mask=seen_mask, + output_seen_mask=output_seen_mask, generator=generator, ) # [B] # Force pad on already-finished samples; finished stays True afterwards. # ``pad_tensor`` is hoisted above so we avoid the per-step # ``torch.full_like(next_token, pad_token_id)`` allocation. next_token = torch.where(finished, pad_tensor, next_token) + # Record (post-pad) emitted token in both penalty histories. Finished + # samples write pad_token_id, which is dead state and harmless. if seen_mask is not None: - # Record the (post-pad) emitted token in history. For - # still-running samples this is the actual sampled token; - # for already-finished samples it's ``pad_token_id``, which - # is harmless because finished samples don't sample anymore - # — their row of ``seen_mask`` is dead state from here on. seen_mask.scatter_(1, next_token.unsqueeze(1), True) + if output_seen_mask is not None: + output_seen_mask.scatter_(1, next_token.unsqueeze(1), True) if eos_tensor is not None: # Vectorized EOS comparison: broadcast ``next_token`` (``[B,1]``) # against ``eos_tensor`` (``[E]``) and reduce-any across the @@ -2197,5 +2186,4 @@ def generate_reasoner_text( seed: int | None = None, return_only_new_tokens: bool = False, ) -> torch.Tensor: - raise NotImplementedError("This method is not implemented for Nemotron 3 Dense VL.") diff --git a/cosmos_framework/model/vfm/omni_mot_model.py b/cosmos_framework/model/vfm/omni_mot_model.py index 9e93f98..33a6269 100644 --- a/cosmos_framework/model/vfm/omni_mot_model.py +++ b/cosmos_framework/model/vfm/omni_mot_model.py @@ -4,6 +4,7 @@ from __future__ import annotations import collections +import json import time from contextlib import contextmanager from typing import Any, Callable, Dict, Mapping, Optional, Tuple @@ -99,7 +100,6 @@ def set_precision(self) -> None: torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = False def set_up_data_key(self) -> None: - self.input_video_key = self.config.input_video_key # by default it is video key for Video diffusion model self.input_image_key = self.config.input_image_key self.input_caption_key = self.config.input_caption_key @@ -143,7 +143,6 @@ def set_up_tokenizers(self) -> None: vlm_tokenizer, special_tokens = add_special_tokens(vlm_tokenizer) self.vlm_tokenizer = vlm_tokenizer - self.llm_special_tokens = special_tokens self.llm_special_tokens["eos_token_id"] = vlm_tokenizer.eos_token_id @@ -170,7 +169,6 @@ def set_up_tokenizers(self) -> None: self.tokenizer_sound_gen = None - def build_net(self, dtype: torch.dtype): # Build model network and parallelize it. with torch.device("meta"): @@ -178,7 +176,7 @@ def build_net(self, dtype: torch.dtype): language_model = lazy_instantiate(self.vlm_config.model_instance) - + # NOTE: We pass "RF timesteps" to the network in the same scale as the scheduler # (i.e., roughly [0, num_train_timesteps]). The MoT network expects to internally # rescale timesteps before embedding; avoid hard-coding 1e-3 by computing it from # the configured scheduler resolution. @@ -321,7 +319,6 @@ def set_up_model(self): self.net_ema_worker = DTensorFastEmaModelUpdater() - s = config.ema.rate self.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() @@ -1121,7 +1118,6 @@ def _get_train_noise_level_vision( (timesteps, sigmas): Both [B,1] for TF/base, or [B,T_max] for diffusion_forcing. """ - rectified_flow = self.rectified_flow_image if is_image_batch else self.rectified_flow_video assert not self.config.rectified_flow_training_config.use_discrete_rf, ( @@ -2265,43 +2261,6 @@ def generate_samples_from_batch( assert n_sample == len(seed), f"Number of samples {n_sample} must match number of seeds {len(seed)}" - # FSDP collective-sequence alignment (throughput-preset inference). - # - # In throughput-preset inference each rank holds a different sample, - # and different samples can diverge on (a) the CFG decision per - # step — ``guidance != 1.0`` (and the optional ``guidance_interval`` - # gate) determines whether ``velocity_fn`` issues 1 or 2 model - # forwards — and (b) ``num_steps``. Either divergence makes the - # FSDP allgather sequence misalign across ranks, deadlocking NCCL - # at the 30-min watchdog timeout. - # - # We align in two places: - # 1. Inside velocity_fn (per call): all_reduce the local CFG - # decision; if ANY rank needs CFG, every rank does both - # forwards (cond + uncond). Ranks whose local decision was - # "no CFG" return ``cond_v`` directly — bit-identical to the - # original no-CFG path (no guidance blend, no normalize_cfg). - # 2. Around the sampler call: all_reduce the local num_steps; - # ranks with local < max issue a dummy sampler call with the - # remaining steps to pad the FSDP allgather stream. The - # dummy call's output is discarded; ``latents`` is never - # re-bound. - # - # Both collectives are scoped to the FSDP shard group (the only - # process group whose collective sequence is at risk), so they're - # safe under non-trivial parallel layouts. - if ( - self.parallel_dims is not None - and self.parallel_dims.dp_shard_mesh is not None - and torch.distributed.is_initialized() - and self.parallel_dims.dp_shard_mesh.size() > 1 - ): - _dp_shard_group = self.parallel_dims.dp_shard_mesh.get_group() - _align_device = self.tensor_kwargs["device"] - else: - _dp_shard_group = None - _align_device = None - # Create a velocity function for a single sample (for use with self.sampler). def velocity_fn(noise_x: list[torch.Tensor], timestep: torch.Tensor) -> list[torch.Tensor]: @@ -2326,34 +2285,16 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): skip_text_tokens=skip_text_tokens, ) - # Local CFG decision — honors ``guidance_interval`` for this rank. - _local_needs_cfg = guidance != 1.0 - if _local_needs_cfg and guidance_interval is not None: + # Skip unconditional branch when outside the guidance interval + needs_cfg = guidance != 1.0 + if needs_cfg and guidance_interval is not None: assert len(guidance_interval) == 2, f"guidance_interval must be [lo, hi], got {guidance_interval}" t_lo, t_hi = guidance_interval - _local_needs_cfg = t_lo < timestep[0].item() < t_hi - - # FSDP alignment: if ANY rank in the shard group needs CFG this - # call, every rank computes both forwards. Cheap 1-element - # all_reduce per velocity_fn call; the alternative (forcing CFG - # always-on globally) would silently ignore the per-timestep - # ``guidance_interval`` gate. - if _dp_shard_group is not None: - _cfg_t = torch.tensor( - [1 if _local_needs_cfg else 0], device=_align_device, dtype=torch.int32 - ) - torch.distributed.all_reduce( - _cfg_t, op=torch.distributed.ReduceOp.MAX, group=_dp_shard_group - ) - _any_needs_cfg = bool(_cfg_t.item()) - else: - _any_needs_cfg = _local_needs_cfg + needs_cfg = t_lo < timestep[0].item() < t_hi - if not _any_needs_cfg: + if not needs_cfg: return _single_velocity_fn(cond_tokens, skip_text_tokens=False) - # Both forwards happen — needed for FSDP collective alignment - # across ranks even if THIS rank's local decision was "no CFG". cond_v, uncond_v = self._run_classifier_free_guidance( cond_tokens=cond_tokens, uncond_tokens=uncond_tokens, @@ -2361,14 +2302,6 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): single_velocity_fn=_single_velocity_fn, ) - if not _local_needs_cfg: - # This rank doesn't actually need CFG (guidance==1.0 or sigma - # outside guidance_interval). Return cond_v directly so the - # output is bit-identical to the original no-CFG path; the - # uncond_v forward was only run to keep the FSDP allgather - # sequence aligned with peers. - return cond_v - v_pred = [u_i + guidance * (c_i - u_i) for c_i, u_i in zip(cond_v, uncond_v)] if normalize_cfg: @@ -2379,21 +2312,6 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): return v_pred - # FSDP collective-sequence alignment (sampler outer loop). See the - # large block above the velocity_fn definition for the full - # rationale. all_reduce on the local num_steps so every rank knows - # the max; below, ranks with local < max issue a dummy sampler call - # to pad their FSDP allgather sequence. - if _dp_shard_group is not None: - _local_steps_t = torch.tensor([num_steps], device=_align_device, dtype=torch.int32) - torch.distributed.all_reduce( - _local_steps_t, op=torch.distributed.ReduceOp.MAX, group=_dp_shard_group - ) - _max_num_steps = int(_local_steps_t.item()) - else: - _max_num_steps = num_steps - _extra_num_steps = _max_num_steps - num_steps - # Run sampler for all samples at once. sampler = sampler or self.sampler scheduler_type = self.config.rectified_flow_inference_config.scheduler_type @@ -2410,23 +2328,6 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): shift=shift, seed=seed, ) - if _extra_num_steps > 0: - # Dummy sampler call to issue (_extra_num_steps × per-step) - # FSDP allgathers; output discarded so `latents` keeps the - # real result captured above. Slow ranks have _extra_num_steps==0 - # here, but they're issuing the SAME number of in-sampler - # collectives via their longer real call. - log.debug( - f"FSDP alignment: dummy sampler run with {_extra_num_steps} " - f"extra steps (local={num_steps}, max={_max_num_steps})" - ) - _ = sampler( - velocity_fn, - latents, - num_steps=_extra_num_steps, - shift=shift, - seed=seed, - ) else: # EDM Sampler chunk_sizes = [_x.shape[0] for _x in initial_noise] @@ -2453,41 +2354,6 @@ def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: sigma_min=0.002, solver_option="2ab", ) - if _extra_num_steps > 0: - # Pad the FSDP allgather sequence with ``_extra_num_steps`` - # direct ``x0_fn`` calls instead of a second EDM sampler - # run. Avoids two EDM-specific footguns: - # (1) ``EDMSampler._forward_impl`` always runs an extra - # ``sample_clean`` denoiser forward (see - # ``cosmos_framework/model/vfm/diffusion/samplers/edm.py``). - # A nested sampler call would add one too many - # forwards on fast ranks, since the slow rank's - # single call also pays the ``sample_clean`` cost. - # (2) ``get_rev_ts(..., num_steps=0)`` divides by zero, - # producing NaN sigmas. The fix's ``extra==1`` edge - # case would need num_steps=0 to balance the count. - # Direct ``x0_fn`` calls bypass both: each call routes - # through the same ``velocity_fn`` closure (so the - # per-call CFG all_reduce still aligns ranks), issues - # exactly one model forward, and discards its return. - # ``latents`` is the catted single tensor at this point; - # the dummy sigma value is irrelevant for collective - # alignment because the model's allgather sequence is - # determined by tensor shapes, not sigma. - log.debug( - f"FSDP alignment: padding {_extra_num_steps} dummy x0_fn calls " - f"(local={num_steps}, max={_max_num_steps})" - ) - # ``x0_fn`` expects a sigma in the RF domain (the real EDM - # loop converts raw sigmas via ``sigmas_L / (1 + sigmas_L)`` - # at edm.py:174, landing them in ``(0, 1)``). Mirror that - # transform here so the dummy call's timestep stays in the - # same numerical domain as a real sampler step. The exact - # value doesn't matter for collective alignment, only the - # domain. - _dummy_sigma = latents.new_tensor(sigma_max / (1.0 + sigma_max)) - for _ in range(_extra_num_steps): - _ = x0_fn(latents, _dummy_sigma) latents = list(torch.split(latents, chunk_sizes, dim=0)) # Split flattened latents back into vision, action, and sound @@ -2511,7 +2377,6 @@ def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: vision_shape = gen_data_clean.x0_tokens_vision[idx_vision + j].shape vision_dim = int(torch.prod(torch.tensor(vision_shape))) if j == n_vis - 1: # the last vision item is the only target for each sample. - result_vision.append(latents[i][offset : offset + vision_dim].reshape(vision_shape)) else: # the other vision items are the condition inputs that we don't need to return pass @@ -2701,7 +2566,6 @@ def get_data_and_condition(self, data_batch: dict[str, torch.Tensor], iteration: is_image_batch = self.is_image_batch(data_batch) sample_vision_list = data_batch[self.input_image_key if is_image_batch else self.input_video_key] - # we should always get this information here during training. If we can read this field # from data_batch it means we are in the visualization callback: if "num_vision_items_per_sample" not in data_batch: @@ -2714,7 +2578,6 @@ def get_data_and_condition(self, data_batch: dict[str, torch.Tensor], iteration: num_vision_items_per_sample: list[int] | None = ( [len(v) for v in sample_vision_list] if has_multiple_vision_per_sample else None ) - # information is only stored in the GenerationDataClean object which will be discarded # outside the training loop. Error will be raised when the data batch is passed to the # visualization callbacks. @@ -3123,7 +2986,7 @@ def _extract_upsample_video_specs( tensor (``shape[-1]`` for width, ``shape[-2]`` for height), and the ``aspect_ratio`` string is reverse-looked-up against the canonical ``{IMAGE,VIDEO}_RES_SIZE_INFO`` tables in - :mod:`projects.cosmos3.vfm.datasets.utils` — image table for + :mod:`cosmos_framework.data.vfm.utils` — image table for ``"t2i"``, video table otherwise. Note these tables are ``{res: {ar: (W, H)}}`` (the first entry is *width*); the existing logging-only lookup in @@ -3138,7 +3001,7 @@ def _extract_upsample_video_specs( where ``num_frames`` is the temporal dimension (``shape[-3]``) of the same vision tensor. For ``"t2i"`` both fields are returned as ``None`` so - :func:`projects.cosmos3.vfm.upsampler.prompts.build_user_text`'s + :func:`cosmos_framework.model.vfm.upsampler.prompts.build_user_text`'s ``t2i``-must-have-no-video-args contract is satisfied. Args: @@ -3213,7 +3076,7 @@ def _extract_upsample_video_specs( raise ValueError(f"upsample task={task!r}: conditioning_fps must be positive; got {fps_int}.") num_frames = int(sample.shape[-3]) # Integer-floor seconds matches the canonical V4.2 ``M:SS`` rendering - # in :func:`projects.cosmos3.vfm.upsampler.prompts._format_duration`, + # in :func:`cosmos_framework.model.vfm.upsampler.prompts._format_duration`, # which expects an int and rejects fractional seconds. duration_secs = max(1, num_frames // fps_int) return aspect_ratio, w, h, fps_int, duration_secs @@ -3724,7 +3587,7 @@ def generate_reasoner_text( ``np.ndarray``, or a CHW / HWC tensor). prompt_builder: Optional callback that maps a raw prompt string to a chat-style messages list (e.g. - :func:`projects.cosmos3.vfm.upsampler.prompts.build_messages` + :func:`cosmos_framework.model.vfm.upsampler.prompts.build_messages` for V4.2 caption upsampling). When ``None``, prompts are wrapped as ``[{"role": "user", "content": prompt}]`` with no system message. @@ -3969,7 +3832,7 @@ def upsample_captions( prompt-driven branch. The only thing this method adds on top of the generic per-prompt loop is the V4.2 chat-template injection: each caption is wrapped via - :func:`projects.cosmos3.vfm.upsampler.prompts.build_messages` + :func:`cosmos_framework.model.vfm.upsampler.prompts.build_messages` (which returns ``[system, user]`` with the user content embedding the caption inside the canonical V4.2 template — instructions, task constraints, and output JSON schema for the requested task). @@ -3992,7 +3855,7 @@ def upsample_captions( position ids) before kicking off the AR decode loop. Each raw reasoner output is post-processed by - :func:`projects.cosmos3.vfm.upsampler.prompts.clean_response` + :func:`cosmos_framework.model.vfm.upsampler.prompts.clean_response` before being returned. The cleaner strips ```` / ```` / ```` / etc. reasoning blocks and any prose preamble that appears before the @@ -4044,7 +3907,7 @@ def upsample_captions( fps: Target frames-per-second for the generated clip. Required for the video tasks (``"t2v"``, ``"i2v"``) and must be ``None`` for ``"t2i"`` — the underlying - :func:`projects.cosmos3.vfm.upsampler.prompts.build_user_text` + :func:`cosmos_framework.model.vfm.upsampler.prompts.build_user_text` raises ``ValueError`` if a video task is missing ``fps`` or ``duration_secs``. duration_secs: Clip duration in whole seconds (rendered as @@ -4146,25 +4009,32 @@ def _builder(description: str) -> list[dict[str, Any]]: # into ``data_batch[self.input_caption_key]`` at the call site. cleaned_outputs: list[str] = [] n_stripped = 0 - n_fallback = 0 for raw, original in zip(raw_outputs, captions): cleaned_text, clean_info = clean_response(raw) if not clean_info["was_clean"]: n_stripped += 1 if not cleaned_text.strip(): cleaned_text = original - n_fallback += 1 + + # Stamp the actual generation ``duration`` onto the upsampled + # JSON object using the duration_secs argument. Only done for + # T2V and I2V tasks. + if duration_secs is not None: + cleaned_text = cleaned_text.removeprefix("```json").removesuffix("```").strip() + obj = json.loads(cleaned_text) + assert isinstance(obj, dict), f"JSON parsing failed with error: {type(obj)}" + obj["duration"] = f"{duration_secs}s" + cleaned_text = json.dumps(obj) + cleaned_outputs.append(cleaned_text) # Stay silent on the canonical all-clean path; only emit # telemetry when something actually happened. Logged per-rank # to match the surrounding upsampling logs in # :meth:`generate_samples_from_batch` (line ~2218). - if n_stripped or n_fallback: + if n_stripped: log.info( - f"upsample_captions(task={task!r}, n={len(raw_outputs)}): " - f"thinking-stripped={n_stripped}, " - f"empty-clean-fallback={n_fallback}", + f"upsample_captions(task={task!r}, n={len(raw_outputs)}): thinking-stripped={n_stripped}", rank0_only=False, ) diff --git a/cosmos_framework/model/vfm/parallelize_vlm.py b/cosmos_framework/model/vfm/parallelize_vlm.py index d4f31a4..e38b37c 100644 --- a/cosmos_framework/model/vfm/parallelize_vlm.py +++ b/cosmos_framework/model/vfm/parallelize_vlm.py @@ -3,12 +3,12 @@ """FSDP2 wrapping for Cosmos3 VLM ``HFModel`` instances. Hosts the single VLM-specific ``parallelize`` entry point used by -``vlm_model.VLMModel._init_vlm``. Lives under ``projects/cosmos3/vfm/models/`` +``vlm_model.VLMModel._init_vlm``. Lives under ``cosmos_framework/model/vfm/`` so the FSDP wrapping concern sits next to the model class it operates on (mirroring the layout of ``models/mot/parallelize_unified_mot.py`` for the MoT path). -Pure parallelism plumbing — :class:`~projects.cosmos3.vfm.utils.parallelism.ParallelDims` +Pure parallelism plumbing — :class:`~cosmos_framework.utils.vfm.parallelism.ParallelDims` and its meshes — stays in ``vfm/utils/parallelism.py``. """ diff --git a/cosmos_framework/model/vfm/tokenizers/audio/avae.py b/cosmos_framework/model/vfm/tokenizers/audio/avae.py index a23ee80..69c35cd 100644 --- a/cosmos_framework/model/vfm/tokenizers/audio/avae.py +++ b/cosmos_framework/model/vfm/tokenizers/audio/avae.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: OpenMDW-1.1 """ -AVAE (Audio Variational AutoEncoder) Tokenizer. - -Ported from BigVGAN (https://github.com/NVIDIA/BigVGAN). +AVAE (Audio Variational AutoEncoder) Tokenizer for Imaginaire4 +ported from https://gitlab-master.nvidia.com/ADLR/bigvgan +commit hash: 80fbd8cfecb1867cc864e6d4fe0a474d8403a474 """ import os @@ -128,7 +128,7 @@ def _load_avae_model( ) # Create model directly on device (don't use meta device) - + # NOTE: Unlike WanVAE/FluxVAE, AVAE uses weight_norm extensively in OobleckDecoder # and SpectrogramConvNeXtEncoder. After loading the checkpoint, we must call # remove_weight_norm() which requires materialized tensors (not meta tensors). # Therefore, we create the model directly on the target device instead of using @@ -358,7 +358,6 @@ def __init__( use_object_store = False # Parent directory is registered in checkpoint_db. - if vae_path_full: vae_dir, vae_name = os.path.split(vae_path_full) vae_dir = download_checkpoint_v2(vae_dir) diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/__init__.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/__init__.py index 3f85ab5..990ad7f 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/__init__.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/__init__.py @@ -1,4 +1,17 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # SPDX-License-Identifier: OpenMDW-1.1 from cosmos_framework.model.vfm.tokenizers.dc_ae.dc_ae_v import ( diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/cosmos_ae_4x32x32_compile_test.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/cosmos_ae_4x32x32_compile_test.py index a027bf0..91ef71e 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/cosmos_ae_4x32x32_compile_test.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/cosmos_ae_4x32x32_compile_test.py @@ -14,7 +14,9 @@ from cosmos_framework.utils.easy_io import easy_io from cosmos_framework.utils.helper_test import RunIf from cosmos_framework.configs.base.defaults.cluster import DefaultClusterConfig as CLUSTER_CONFIG -from cosmos_framework.configs.base.defaults.tokenizer import PRETRAINED_TOKENIZER_DCAE_PTH +from cosmos_framework.configs.base.defaults.tokenizer import ( + PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH, +) from cosmos_framework.model.vfm.tokenizers.dc_ae.dc_ae_4x32x32 import DCAE4x32x32Interface from cosmos_framework.model.vfm.tokenizers.unittest_utils import ( numpy2tensor, @@ -29,7 +31,7 @@ """ -VAE_PATH = PRETRAINED_TOKENIZER_DCAE_PTH +VAE_PATH = PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH def _make_cosmos_ae_from_s3(encoder_width_list): diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py index 9aceec2..18b242a 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py @@ -1,11 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 +from collections.abc import Sequence + import torch from cosmos_framework.utils import log from cosmos_framework.utils.distributed import get_rank, sync_model_states from cosmos_framework.utils.easy_io import easy_io +from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO from cosmos_framework.model.vfm.tokenizers.dc_ae.dc_ae_v import ( DCAEV, DCAEVConfig, @@ -13,7 +16,7 @@ ) from cosmos_framework.model.vfm.tokenizers.interface import VideoTokenizerInterface -DEFAULT_MODEL_NAME = "dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.1" +DEFAULT_MODEL_NAME = "dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2" class DCAE4x32x32Interface(VideoTokenizerInterface): @@ -22,7 +25,7 @@ def __init__( bucket_name: str = "", object_store_credential_path_pretrained: str = "", vae_path: str = "", - chunk_duration: int = 81, + chunk_duration: int = 16, model_name: str = DEFAULT_MODEL_NAME, spatial_compression_factor: int = 32, temporal_compression_factor: int = 4, @@ -30,15 +33,21 @@ def __init__( encode_bucket_multiple: int = 2, # Placeholder device: str = "cuda", compilable: bool = True, + causal: bool = True, ): + self._causal = causal + assert self._causal, "DCAE4x32x32Interface is a causal tokenizer; causal must be True." vae_path_full = f"s3://{bucket_name}/{vae_path}" self._spatial_compression_factor = spatial_compression_factor self._temporal_compression_factor = temporal_compression_factor self.chunk_duration = chunk_duration + self.model_name = model_name + self.resolutions = None # Build config (without pretrained_path so DCAEV doesn't try to load itself). cfg: DCAEVConfig = dc_ae_v_f32t4_encoder_causal_decoder_chunk_causal_4(model_name, pretrained_path=None) cfg.compilable = compilable + cfg.encode_temporal_tile_size = chunk_duration # Instantiate model on meta device to avoid double allocation. with torch.device("meta"): @@ -61,6 +70,7 @@ def __init__( self.model.to(dtype=torch.bfloat16) sync_model_states(self.model) + self.model.encoder = self.model.encoder.to(memory_format=torch.channels_last_3d) self.is_compiled = False self.use_streaming_encode = False @@ -72,11 +82,52 @@ def compile_encode_for_cudagraphs( dynamic: bool = False, backend: str = "inductor", ) -> None: - - self.model.encoder = self.model.encoder.to(memory_format=torch.channels_last_3d) self.model.encoder = torch.compile(self.model.encoder, fullgraph=True, mode=mode) self.is_compiled = True + @torch.inference_mode() + def compile_encode( + self, + warmup_resolutions: Sequence[str], + output_dir: str | None = None, + aspect_ratio: str | None = None, + backend: str | None = "inductor", + mode: str | None = "reduce-overhead", + fullgraph: bool = False, + dynamic: bool = False, + ) -> None: + """Compile the encode function for the given resolutions.""" + if self.is_compiled: + log.warning("Tokenizer is already compiled, skipping compilation.") + return + + if backend is None: + raise ValueError("backend must be provided") + + self.compile_encode_for_cudagraphs(mode=mode, fullgraph=fullgraph, dynamic=dynamic, backend=backend) + + # Run warmup resolutions + if aspect_ratio is None: + aspect_ratios = list(VIDEO_RES_SIZE_INFO["256"].keys()) + else: + if isinstance(aspect_ratio, str): + if aspect_ratio not in VIDEO_RES_SIZE_INFO["256"]: + raise ValueError(f"Aspect ratio {aspect_ratio} not found in predefined aspect ratios") + aspect_ratios = [aspect_ratio] + else: + raise ValueError(f"Aspect ratio {aspect_ratio} must be a string") + + self.resolutions = warmup_resolutions + self.aspect_ratios = aspect_ratios + + T = self.chunk_duration - self.model.cfg.num_pad_frames + for resolution in warmup_resolutions: + for aspect_ratio in aspect_ratios: + H, W = VIDEO_RES_SIZE_INFO[resolution][aspect_ratio] + log.info(f"Warming up {resolution} {aspect_ratio}") + for _ in range(2): + self.model.encode(torch.randn(1, 3, T, H, W).cuda().to(torch.bfloat16)) + @property def dtype(self): return self.model.dtype @@ -86,6 +137,12 @@ def reset_dtype(self): @torch.inference_mode() def encode(self, state: torch.Tensor) -> torch.Tensor: + if self.resolutions is not None: + for resolution in self.resolutions: + if tuple(state.shape[3:]) in VIDEO_RES_SIZE_INFO[resolution].values(): + break + else: + raise ValueError(f"State shape {state.shape[2:]} is not in {self.resolutions}") in_dtype = state.dtype tcf = self._temporal_compression_factor # Add padding to the sequence length to make it divisible by diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32_test.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32_test.py index 5167277..6912123 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32_test.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32_test.py @@ -10,7 +10,9 @@ from cosmos_framework.utils.easy_io import easy_io from cosmos_framework.utils.helper_test import RunIf from cosmos_framework.configs.base.defaults.cluster import DefaultClusterConfig as CLUSTER_CONFIG -from cosmos_framework.configs.base.defaults.tokenizer import PRETRAINED_TOKENIZER_DCAE_PTH +from cosmos_framework.configs.base.defaults.tokenizer import ( + PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH, +) from cosmos_framework.configs.base.defaults.unittest import TOKENIZER_RECONSTRUCTION_VIDEO_PATH, UNITTEST_CONFIG from cosmos_framework.model.vfm.tokenizers.dc_ae.dc_ae_4x32x32 import DEFAULT_MODEL_NAME, DCAE4x32x32Interface from cosmos_framework.model.vfm.tokenizers.dc_ae.dc_ae_v import DCAEV, dc_ae_v_f32t4_encoder_causal_decoder_chunk_causal_4 @@ -27,7 +29,7 @@ RUN_SKIPPED_TEST_LOCALLY=1 pytest -s cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32_test.py -k test_dc_ae_local_checkpoint """ -VAE_PATH = PRETRAINED_TOKENIZER_DCAE_PTH +VAE_PATH = PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH LOCAL_CHECKPOINT = os.path.expanduser( "~/work/imaginaire4/logs/cosmos_4x32x32_0211/checkpoints/" "dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.1.pt" diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py index 2d31cae..0c0edd8 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py @@ -1,4 +1,17 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # SPDX-License-Identifier: OpenMDW-1.1 from dataclasses import dataclass, field @@ -688,7 +701,7 @@ def _visit(module: nn.Module) -> None: w, dtype=dtype, device=device, - ) + ).to(memory_format=torch.channels_last_3d) ) elif isinstance(module, ResBlock3d): _visit(module.conv1) @@ -783,7 +796,8 @@ def temporal_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: row = [] for i in tqdm(range(0, x.shape[2], overlap_size), desc="Tiled Encode", disable=not self.cfg.verbose): - tile = x[:, :, i : i + tile_size, :, :] + # Clone is required for compiled tokenizer to avoid recompilation (view has different memory strides). + tile = x[:, :, i : i + tile_size, :, :].clone() actual_t = tile.shape[2] remove_padding = False if actual_t < tile_size and self.cfg.compilable: @@ -897,7 +911,7 @@ def dc_ae_v_f32t4_encoder_causal_decoder_chunk_causal_4( elif name in [ "dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", ]: - latent_channels, num_pad_frames, temporal_remainder, scaling_factor = 64, 7, 1, 0.7103 + latent_channels, num_pad_frames, temporal_remainder, scaling_factor = 64, 7, 1, 0.5704 encoder_width_list = [0, 64, 128, 512, 1024, 1024, 1024] elif name in [ "dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_ops.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_ops.py index 4a96748..7ce226e 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_ops.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_ops.py @@ -1,4 +1,17 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # SPDX-License-Identifier: OpenMDW-1.1 import collections @@ -311,7 +324,6 @@ def forward( x = F.pad(x, self.custom_padding, mode=self.custom_padding_mode) if self.causal_chunk_length is not None: - B, C, T, H, W = x.shape assert T % self.causal_chunk_length == 0 assert self.conv.stride[0] == 1 diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_triton_rms_norm.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_triton_rms_norm.py index 2e00cbc..b2f5e2e 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_triton_rms_norm.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v_triton_rms_norm.py @@ -1,4 +1,17 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py b/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py index 98a322a..3446646 100644 --- a/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py +++ b/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py @@ -376,8 +376,10 @@ def __init__( chunk_duration: int = 1, spatial_compression_factor: int = 8, temporal_compression_factor: int = 1, + causal: bool = True, ): super().__init__(object_store_credential_path_pretrained=object_store_credential_path_pretrained) + self._causal = causal # Load the Flux VAE model, passing backend_args for S3 support vae_path_full = f"s3://{bucket_name}/{vae_path}" diff --git a/cosmos_framework/model/vfm/tokenizers/interface.py b/cosmos_framework/model/vfm/tokenizers/interface.py index 639d241..bd21af5 100644 --- a/cosmos_framework/model/vfm/tokenizers/interface.py +++ b/cosmos_framework/model/vfm/tokenizers/interface.py @@ -87,8 +87,12 @@ def compile_encode( warmup_resolutions: Sequence[str], output_dir: str, aspect_ratio: str | None = None, + backend: str | None = None, + mode: str | None = None, + fullgraph: bool | None = None, + dynamic: bool | None = None, ) -> None: - """AOT-compile the tokenizer for the given resolutions. + """Compile the tokenizer for the given resolutions. Subclasses that support AOT compilation should override this method. The default raises ``NotImplementedError``. @@ -98,6 +102,11 @@ def compile_encode( output_dir: Root directory where compiled artifacts are stored (typically ``config.job.path_local``). aspect_ratio: If given, only compile this single aspect ratio. + --- Only used if the tokenizer does not support AOT compilation --- + backend: Backend to use for compilation. + mode: Mode to use for compilation. + fullgraph: Whether to compile the full graph. + dynamic: Whether to compile the dynamic graph. """ raise NotImplementedError(f"{type(self).__name__} does not support compilation") @@ -106,8 +115,9 @@ def is_chunk_overlap(self): return False @property - def is_causal(self): - return True + def is_causal(self) -> bool: + # Subclasses set self._causal in their __init__ via the `causal` constructor argument. + return getattr(self, "_causal", True) class AudioTokenizerInterface(ABC): diff --git a/cosmos_framework/model/vfm/tokenizers/tokenization_qwen2.py b/cosmos_framework/model/vfm/tokenizers/tokenization_qwen2.py index c85f115..b3cde60 100644 --- a/cosmos_framework/model/vfm/tokenizers/tokenization_qwen2.py +++ b/cosmos_framework/model/vfm/tokenizers/tokenization_qwen2.py @@ -1,5 +1,5 @@ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 # Source Repository: https://github.com/ByteDance-Seed/Bagel @@ -173,7 +173,7 @@ def __init__( continue bpe_merges.append(tuple(line.split())) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) - + # NOTE: the cache can grow without bound and will get really large for long running processes # (esp. for texts of language that do not use space between word, e.g. Chinese); technically # not a memory leak but appears as one. # GPT2Tokenizer has the same problem, so let's be consistent. diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py index 7c1ba17..83ced39 100644 --- a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py +++ b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -"""UniAE S1 tokenizer wrapper for diffusion training (4x16x16 compression). +"""UniAE S3 tokenizer wrapper for diffusion training (4x16x16 compression). -Wraps the UniAE sparse autoencoder with DenseAutoencoderRuntime (SDPA compiled) +Wraps the UniAE sparse autoencoder with DenseAutoencoderRuntime (batched backend) to provide a VideoTokenizerInterface compatible with diffusion model training. Usage: @@ -13,10 +13,12 @@ vae_pth="s3://bucket0/pretrained/tokenizers/video/cosmos/...", object_store_credential_path_pretrained="credentials/gcp_checkpoint.secret", ) - latents = vae.encode(video) # [B, 3, T, H, W] -> [B, 48, T//4, H//16, W//16] - recon = vae.decode(latents) # [B, 48, T//4, H//16, W//16] -> [B, 3, T, H, W] + latents = vae.encode(video) # [B, 3, T, H, W] -> [B, 48, ceil(T/4), H//16, W//16] + recon = vae.decode(latents) # [B, 48, T_p, H//16, W//16] -> [B, 3, 4*T_p, H, W] """ +from collections.abc import Sequence + import torch from cosmos_framework.utils import log @@ -26,31 +28,30 @@ from cosmos_framework.model.tokenizer.models.sparse_autoencoder import AutoencoderKL from cosmos_framework.model.vfm.tokenizers.interface import VideoTokenizerInterface -# S1 architecture config (avoids importing cosmos_framework/configs/base which pulls in loss deps) -_S1_ARCH = dict( +# S3 architecture config (avoids importing configs/base which pulls in loss deps) +_S3_ARCH = dict( patch_size=(4, 16, 16), in_channels=3072, out_channels=3072, + # Encoder encoder_model_channels=1152, encoder_num_blocks=27, encoder_num_heads=16, encoder_mlp_channels=4304, - encoder_attn_mode="full", - encoder_window_size=None, encoder_pe_mode="joint", encoder_qk_rms_norm=False, encoder_use_bias=True, encoder_use_rms_norm=False, + # Decoder decoder_model_channels=1152, decoder_num_blocks=27, decoder_num_heads=16, decoder_mlp_channels=4304, - decoder_attn_mode="full", - decoder_window_size=None, decoder_pe_mode="joint", decoder_qk_rms_norm=True, decoder_use_bias=False, decoder_use_rms_norm=True, + # Common settings use_decoder=True, quantizer_type="rq", quantizer_codebook_size=65536, @@ -66,16 +67,21 @@ inference_kv_cache_size=0, use_quantizer=False, use_dual_latent=False, - use_text_alignment=True, - use_post_text_alignment=True, + use_text_alignment=False, + use_post_text_alignment=False, ) +# Backward-compat alias: S1 tokenizer had text-alignment disabled. +# Exposed for CI tests that check arch-dict safety (no stale attn keys, all keys +# valid in AutoencoderKL, text-alignment flags are False). +_S1_ARCH = {**_S3_ARCH, "use_text_alignment": False, "use_post_text_alignment": False} + class UniAEVAE: - """UniAE S1 VAE wrapper for diffusion training. + """UniAE S3 VAE wrapper for diffusion training. Loads the UniAE sparse autoencoder checkpoint, wraps it with - DenseAutoencoderRuntime (SDPA backend for compile-friendly inference), + DenseAutoencoderRuntime (batched backend for compile-friendly inference), and provides encode/decode in the standard [B, C, T, H, W] format. """ @@ -86,30 +92,62 @@ def __init__( object_store_credential_path_pretrained: str = "", dtype: torch.dtype = torch.bfloat16, device: str = "cuda", - backend: str = "sdpa", + backend: str = "batched", + pad_frames: int = 1, + pixel_trim: bool = True, + chunk_size: int = 16, ): self.dtype = dtype self.device = device self.z_dim = z_dim + self._pad_frames = pad_frames + self._pixel_trim = pixel_trim self._spatial_compression_factor = 16 self._temporal_compression_factor = 4 # make compatible with meta device autoencoder = AutoencoderKL( - **_S1_ARCH, + **_S3_ARCH, latent_channels=z_dim, quantizer_feature_dim=z_dim, ) - autoencoder.eval() autoencoder.to(device=device, dtype=dtype) # Load checkpoint + if vae_pth and get_rank() == 0: + backend_args = {"backend": "s3", "s3_credential_path": object_store_credential_path_pretrained} + state_dict = easy_io.load(vae_pth, backend_args=backend_args, map_location="cpu", weights_only=False) + if "model" in state_dict: + model_state = state_dict["model"] + elif "state_dict" in state_dict: + model_state = state_dict["state_dict"] + else: + model_state = state_dict + # Checkpoint may be saved from a wrapper with a 'network.' prefix — strip it. + if any(k.startswith("network.") for k in model_state): + model_state = { + k[len("network.") :] if k.startswith("network.") else k: v for k, v in model_state.items() + } + missing, unexpected = autoencoder.load_state_dict(model_state, strict=False) + if missing: + log.warning(f"Missing keys: {len(missing)} (e.g., {missing[:3]})") + if unexpected: + log.warning(f"Unexpected keys: {len(unexpected)} (e.g., {unexpected[:3]})") + log.info(f"Loaded checkpoint from {vae_pth}") + elif vae_pth: + autoencoder.to_empty(device=device) if vae_pth: - self._load_checkpoint(autoencoder, vae_pth, object_store_credential_path_pretrained, device) + sync_model_states(autoencoder) # Wrap with dense runtime for fast inference - self.dense_runtime = DenseAutoencoderRuntime.from_autoencoder(autoencoder, backend=backend) + self.dense_runtime = DenseAutoencoderRuntime.from_autoencoder( + autoencoder, + backend=backend, + pad_frames=self._pad_frames, + pixel_trim=self._pixel_trim, + chunk_size=chunk_size, + ) self.dense_runtime.eval() # Freeze all parameters @@ -117,59 +155,9 @@ def __init__( param.requires_grad = False log.info( - f"UniAE S1 loaded: {self.count_param() / 1e6:.1f}M params, " - f"backend={backend}, dtype={dtype}, device={device}" + f"UniAE loaded: {self.count_param() / 1e6:.1f}M params, backend={backend}, dtype={dtype}, device={device}" ) - def _load_checkpoint(self, model, pretrained_path, credential_path, device): - """Load checkpoint from local path or S3.""" - if get_rank() == 0: - if pretrained_path.startswith("s3://"): - backend_args = { - "backend": "s3", - "s3_credential_path": credential_path, - } - else: - backend_args = None - - ckpt = easy_io.load( - pretrained_path, - backend_args=backend_args, - map_location=device, - ) - - # Handle different checkpoint formats - if isinstance(ckpt, dict): - if "model" in ckpt: - state_dict = ckpt["model"] - elif "state_dict" in ckpt: - state_dict = ckpt["state_dict"] - elif "network" in ckpt: - state_dict = ckpt["network"] - else: - state_dict = ckpt - else: - state_dict = ckpt - - # Strip common prefixes - cleaned = {} - for k, v in state_dict.items(): - for prefix in ["network.", "module.", "model."]: - if k.startswith(prefix): - k = k[len(prefix) :] - cleaned[k] = v - - missing, unexpected = model.load_state_dict(cleaned, strict=False) - if missing: - log.warning(f"Missing keys: {len(missing)} (e.g., {missing[:3]})") - if unexpected: - log.warning(f"Unexpected keys: {len(unexpected)} (e.g., {unexpected[:3]})") - log.info(f"Loaded checkpoint from {pretrained_path}") - else: - model.to_empty(device=device) - - sync_model_states(model) - def count_param(self) -> int: return sum(p.numel() for p in self.dense_runtime.parameters()) @@ -177,44 +165,37 @@ def count_param(self) -> int: def encode(self, video: torch.Tensor) -> torch.Tensor: """Encode image or video to latent space. - For images (T=1 or 4D input), the input is repeated to 4 frames - since the non-causal tokenizer requires a full temporal patch. + Boundary padding and latent trimming are handled by DenseAutoencoderRuntime + via pad_frames (default 4) and pixel_trim. The full clip is encoded as a + single chunk. Args: - video: [B, 3, T, H, W] or [B, 3, H, W] (image) in range [-1, 1] + video: [B, 3, T, H, W] or [B, 3, H, W] (image) in range [-1, 1]. + T must satisfy T % 4 == 0. Returns: - latent: [B, z_dim, T//4, H//16, W//16] - For single-image input, T//4 = 1. + latent: [B, z_dim, ceil(T/4), H//16, W//16] + For single-image input, ceil(T/4) = 1. """ + tc = self._temporal_compression_factor # 4 + # Handle image input: [B, C, H, W] -> [B, C, 4, H, W] - is_image = video.ndim == 4 - if is_image: - video = video.unsqueeze(2) - video = torch.nn.functional.pad( - video, (0, 0, 0, 0, 0, self._temporal_compression_factor - 1), mode="constant", value=0.0 - ) + if video.ndim == 4: + video = video.unsqueeze(2).expand(-1, -1, tc, -1, -1).clone() B, C, T, H, W = video.shape - tc = self._temporal_compression_factor - - # For non-causal tokenizer, repeat last frame to fill last temporal patch - if T % tc != 0: - pad_t = tc - T % tc - last_frame = video[:, :, -1:].expand(-1, -1, pad_t, -1, -1) - video = torch.cat([video, last_frame], dim=2) - T = T + pad_t + full_chunk_size = self.dense_runtime.encoder_chunk_spec.raw_frames + chunk_size = full_chunk_size - 2 * self._pad_frames # Convert to channels-last [B, T, H, W, C] for dense runtime video_cl = video.permute(0, 2, 3, 4, 1).contiguous().to(dtype=self.dtype) - # Encode: returns [B, T_p, H_p, W_p, 2*z_dim] moments - moments = self.dense_runtime.encode(video_cl, sample_posterior=False) - - # Take mean (first half of channels) for deterministic encoding - mean, logvar = moments.chunk(2, dim=-1) + # Encode full clip as one chunk; dense_runtime handles boundary padding via pad_frames. + # Returns [B, T//4, H_p, W_p, 2*z_dim] (boundary latents trimmed inside dense_runtime). + moments = self.dense_runtime.encode(video_cl, sample_posterior=False, chunk_raw_frames=chunk_size) - # Convert to [B, z_dim, T_p, H_p, W_p] + # Take mean for deterministic encoding; convert to [B, z_dim, T_p, H_p, W_p] + mean, _ = moments.chunk(2, dim=-1) return mean.permute(0, 4, 1, 2, 3).contiguous() @torch.inference_mode() @@ -230,15 +211,18 @@ def decode(self, latent: torch.Tensor) -> torch.Tensor: # Convert to channels-last [B, T_p, H_p, W_p, z_dim] latent_cl = latent.permute(0, 2, 3, 4, 1).contiguous().to(dtype=self.dtype) - # Decode: returns [B, T, H, W, C] channels-last - decoded = self.dense_runtime.decode(latent_cl) + # Decode full clip as one chunk to avoid seam artifacts. + T_latent = latent.shape[2] + chunk_raw_frames = T_latent * self._temporal_compression_factor + decoded = self.dense_runtime.decode(latent_cl, chunk_raw_frames=chunk_raw_frames) # Convert to [B, C, T, H, W] and clamp video = decoded.permute(0, 4, 1, 2, 3).contiguous() return video.clamp(-1, 1).float() def get_latent_num_frames(self, num_pixel_frames: int) -> int: - return num_pixel_frames // self._temporal_compression_factor + tc = self._temporal_compression_factor + return (num_pixel_frames + tc - 1) // tc def get_pixel_num_frames(self, num_latent_frames: int) -> int: return num_latent_frames * self._temporal_compression_factor @@ -255,11 +239,18 @@ def __init__( encode_chunk_frames: int = 16, spatial_compression_factor: int = 16, temporal_compression_factor: int = 4, + pad_frames: int = 0, + pixel_trim: bool = True, + backend: str = "batched_with_padding", + causal: bool = False, ): super().__init__(object_store_credential_path_pretrained) + self._causal = causal + assert not self._causal, "UniAEVAEInterface is a non-causal tokenizer; causal must be False." self._spatial_compression_factor = spatial_compression_factor self._temporal_compression_factor = temporal_compression_factor self.encode_chunk_frames = encode_chunk_frames + # unused parameter self.use_streaming_encode = False vae_full_path = vae_path @@ -269,7 +260,12 @@ def __init__( self.vae = UniAEVAE( vae_pth=vae_full_path, object_store_credential_path_pretrained=object_store_credential_path_pretrained, + pad_frames=pad_frames, + pixel_trim=pixel_trim, + backend=backend, + chunk_size=self.encode_chunk_frames, ) + self.is_compiled = False def reset_dtype(self): pass @@ -295,6 +291,28 @@ def compile_encode_for_cudagraphs( self.vae.dense_runtime._encode_chunk_core = torch.compile( self.vae.dense_runtime._encode_chunk_core, **compile_kwargs ) + self.is_compiled = True + + @torch.inference_mode() + def compile_encode( + self, + warmup_resolutions: Sequence[str], + output_dir: str | None = None, + aspect_ratio: str | None = None, + backend: str | None = "inductor", + mode: str | None = "reduce-overhead", + fullgraph: bool = False, + dynamic: bool = False, + ) -> None: + """Compile the encode function for the given resolutions.""" + if self.is_compiled: + log.warning("Tokenizer is already compiled, skipping compilation.") + return + + if backend is None: + raise ValueError("backend must be provided") + + self.compile_encode_for_cudagraphs(mode=mode, fullgraph=fullgraph, dynamic=dynamic, backend=backend) def decode(self, latent: torch.Tensor) -> torch.Tensor: return self.vae.decode(latent) @@ -327,8 +345,4 @@ def latent_chunk_duration(self): @property def latent_ch(self) -> int: - return 48 - - @property - def is_causal(self): - return False + return self.vae.z_dim diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py index 9feb2a3..0a855a3 100644 --- a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py +++ b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py @@ -4,19 +4,20 @@ # ----------------------------------------------------------------------------- """ -Tests for UniAE S1 tokenizer (4x16x16). +Tests for UniAE S3 tokenizer (4x16x16). Usage: # Basic encode/decode test with random data - CUDA_VISIBLE_DEVICES=0 RUN_SKIPPED_TEST_LOCALLY=1 pytest -s cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py -k test_uniae_s1 + CUDA_VISIBLE_DEVICES=0 RUN_SKIPPED_TEST_LOCALLY=1 pytest -s cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py -k test_uniae_s3 - # Full reconstruction test with real video (saves uniae_recon.mp4) + # Full reconstruction test with real video (saves uniae_s3_recon.mp4) CUDA_VISIBLE_DEVICES=0 RUN_SKIPPED_TEST_LOCALLY=1 pytest -s cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py -k test_local_video Note: On this machine, CUDA device 0 = RTX 6000 Ada (48GB), device 1 = T400 (2GB). Always use CUDA_VISIBLE_DEVICES=0 for the RTX 6000. """ +import inspect import os import numpy as np @@ -25,9 +26,10 @@ from cosmos_framework.utils.easy_io import easy_io from cosmos_framework.utils.helper_test import RunIf +from cosmos_framework.model.tokenizer.models.sparse_autoencoder import AutoencoderKL from cosmos_framework.configs.base.defaults.cluster import DefaultClusterConfig as CLUSTER_CONFIG from cosmos_framework.configs.base.defaults.unittest import TOKENIZER_RECONSTRUCTION_VIDEO_PATH, UNITTEST_CONFIG -from cosmos_framework.model.vfm.tokenizers.uniae.noncausal_4x16x16 import UniAEVAE +from cosmos_framework.model.vfm.tokenizers.uniae.noncausal_4x16x16 import _S1_ARCH, UniAEVAE from cosmos_framework.model.vfm.tokenizers.unittest_utils import ( numpy2tensor, pad_video_batch, @@ -35,41 +37,78 @@ unpad_video_batch, ) -UNIAE_S1_PATH = ( - "s3://bucket0/pretrained/tokenizers/video/cosmos/" - "uniae4x16x16_c48_t8to24_64to512p_fps_all_encoder_noncausal_decoder_noncausal_nogan_best_s1.pt" +UNIAE_S3_PATH = ( + "s3://bucket1/uniae/tok_experiments/" + "uniae_s3_prod32_ditval_video_b1_50k_r1/checkpoints/iter_000050000.pt" ) +@pytest.mark.L0 +def test_uniae_s1_arch_matches_autoencoder_signature() -> None: + """The VFM wrapper should not pass stale tokenizer-training kwargs.""" + legacy_attention_keys = { + "encoder_attn_mode", + "encoder_window_size", + "decoder_attn_mode", + "decoder_window_size", + } + signature_keys = set(inspect.signature(AutoencoderKL.__init__).parameters) + + assert legacy_attention_keys.isdisjoint(_S1_ARCH) + assert set(_S1_ARCH).issubset(signature_keys) + assert _S1_ARCH["use_text_alignment"] is False + assert _S1_ARCH["use_post_text_alignment"] is False + + +@pytest.mark.L0 +@pytest.mark.parametrize( + ("num_pixel_frames", "expected_latent_frames"), + [ + (1, 1), + (2, 1), + (4, 1), + (5, 2), + (7, 2), + (8, 2), + (13, 4), + (16, 4), + ], +) +def test_uniae_latent_num_frames_matches_noncausal_padding( + num_pixel_frames: int, + expected_latent_frames: int, +) -> None: + """Frame-count helper should match encode's pad-to-multiple behavior.""" + vae = UniAEVAE.__new__(UniAEVAE) + vae._temporal_compression_factor = 4 + + assert vae.get_latent_num_frames(num_pixel_frames) == expected_latent_frames + + @pytest.mark.L0 @pytest.mark.skipif(os.getenv("RUN_SKIPPED_TEST_LOCALLY") != "1", reason="local_test_only") -def test_uniae_s1(): - """Basic encode/decode test with random data.""" +def test_uniae_s3(): + """Basic shape check: encode/decode with random data for a few T values.""" vae = UniAEVAE( - vae_pth=UNIAE_S1_PATH, + vae_pth=UNIAE_S3_PATH, object_store_credential_path_pretrained=CLUSTER_CONFIG.object_store_credential_pretrained, device="cuda", + dtype=torch.bfloat16, ) - print(f"\n[UniAE S1] Model parameters: {vae.count_param() / 1e6:.2f}M") + print(f"\n[UniAE S3] Model parameters: {vae.count_param() / 1e6:.2f}M") H, W = 256, 256 - for T in [4, 16, 32]: - print(f"\n[UniAE S1] Testing with T={T} frames, H={H}, W={W}") - video = torch.randn(1, 3, T, H, W, device="cuda") + for T in [4, 52, 100, 148]: + video = torch.randn(1, 3, T, H, W, device="cuda", dtype=torch.bfloat16) latents = vae.encode(video) - print(f" Input video shape: {video.shape} -> Latent shape: {latents.shape}") - print(f" Latent stats: mean={latents.mean():.4f}, std={latents.std():.4f}") video_recon = vae.decode(latents) - print(f" Reconstructed video shape: {video_recon.shape}") - - # Verify shapes - expected_T_latent = T // 4 - expected_H_latent = H // 16 - expected_W_latent = W // 16 - assert latents.shape == (1, 48, expected_T_latent, expected_H_latent, expected_W_latent), ( - f"Expected latent shape (1, 48, {expected_T_latent}, {expected_H_latent}, {expected_W_latent}), " - f"got {latents.shape}" + + expected_T_latent = vae.get_latent_num_frames(T) + assert latents.shape == (1, vae.z_dim, expected_T_latent, H // 16, W // 16), ( + f"T={T}: unexpected latent shape {tuple(latents.shape)}" ) + assert video_recon.shape == (1, 3, T, H, W), f"T={T}: unexpected recon shape {tuple(video_recon.shape)}" + print(f" T={T:3d} latent={tuple(latents.shape[1:])} recon={tuple(video_recon.shape[2:])} OK") @pytest.mark.L0 @@ -81,59 +120,85 @@ def test_uniae_s1(): ) @pytest.mark.skipif(os.getenv("RUN_SKIPPED_TEST_LOCALLY") != "1", reason="local_test_only") def test_local_video(): - """Full reconstruction test with a real video — saves output to logs/uniae_recon.mp4.""" + """Reconstruction test with real video for T in [52,56,...,148]; plots frame-wise PSNR per T.""" + import matplotlib.pyplot as plt + vae = UniAEVAE( - vae_pth=UNIAE_S1_PATH, + vae_pth=UNIAE_S3_PATH, object_store_credential_path_pretrained=CLUSTER_CONFIG.object_store_credential_pretrained, device="cuda", dtype=torch.bfloat16, ) - # Load video as numpy array (T, H, W, C) in range [0, 255] - video_in_numpy = easy_io.load( + # Load enough frames to cover all T values + T_values = list(range(52, 149, 4)) # 52, 56, ..., 148 + max_T = max(T_values) + video_full = easy_io.load( os.path.join(f"s3://{UNITTEST_CONFIG.object_store_bucket_data}", TOKENIZER_RECONSTRUCTION_VIDEO_PATH), backend_args={ "backend": "s3", "s3_credential_path": UNITTEST_CONFIG.object_store_credential_data, }, - )[0][:32] # Take 32 frames (divisible by 4) - - # Pad video to meet stride alignment requirements - padded_video_batch, crop_region = pad_video_batch( - video_in_numpy[np.newaxis, ...], # Add batch dimension - temporal_align=4, # Temporal compression factor - spatial_align=16, # Spatial compression factor - causal_mode=False, # UniAE is non-causal - only_pad_end=True, - ) - - # Convert to tensor format (B, C, T, H, W) in range [-1, 1] - video_tensor = numpy2tensor(padded_video_batch) - - # Encode and decode - print(f"\n[UniAE S1 Local Video] Input tensor shape: {video_tensor.shape}") - latents = vae.encode(video_tensor) - print(f"[UniAE S1 Local Video] Latent shape: {latents.shape}") - print(f"[UniAE S1 Local Video] Latent statistics: mean={latents.mean():.4f}, std={latents.std():.4f}") - video_recon = vae.decode(latents) - print(f"[UniAE S1 Local Video] Reconstructed shape: {video_recon.shape}") - - # Convert back to numpy and unpad - video_recon_numpy = tensor2numpy(video_recon) - video_recon_unpadded = unpad_video_batch(video_recon_numpy, crop_region) + )[0] # [T_total, H, W, C] + available_T = video_full.shape[0] + print(f"\n[UniAE S3] Video loaded: {video_full.shape}, using T up to {min(max_T, available_T)}") + + os.makedirs("logs", exist_ok=True) + cmap = plt.get_cmap("viridis") + fig, ax = plt.subplots(figsize=(14, 6)) + mean_psnrs = [] + + for i, T in enumerate(T_values): + if T > available_T: + print(f" T={T}: skipped (video only has {available_T} frames)") + continue + + video_in_numpy = video_full[:T] # [T, H, W, C] + + padded_video_batch, crop_region = pad_video_batch( + video_in_numpy[np.newaxis, ...], + temporal_align=4, + spatial_align=16, + causal_mode=False, + only_pad_end=True, + ) + video_tensor = numpy2tensor(padded_video_batch).cuda() - # Compute PSNR - gt = video_in_numpy[: video_recon_unpadded.shape[1]].astype(np.float32) - recon = video_recon_unpadded[0].astype(np.float32) - mse = np.mean((gt - recon) ** 2) - psnr = 10 * np.log10(255**2 / max(mse, 1e-10)) - print(f"[UniAE S1 Local Video] PSNR: {psnr:.2f} dB") + latents = vae.encode(video_tensor) + video_recon = vae.decode(latents) - # Save reconstruction - output_path = os.path.expanduser("logs/uniae_recon.mp4") - os.makedirs(os.path.dirname(output_path), exist_ok=True) - easy_io.dump(video_recon_unpadded[0].astype("uint8"), output_path) - print(f"[UniAE S1 Local Video] Saved reconstruction to: {output_path}") + video_recon_numpy = tensor2numpy(video_recon) + video_recon_unpadded = unpad_video_batch(video_recon_numpy, crop_region) + + gt = video_in_numpy[: video_recon_unpadded.shape[1]].astype(np.float32) + recon = video_recon_unpadded[0].astype(np.float32) + mse_per_frame = np.mean((gt - recon) ** 2, axis=(1, 2, 3)) + psnr_per_frame = 10 * np.log10(255**2 / np.maximum(mse_per_frame, 1e-10)) + mean_psnr = float(psnr_per_frame.mean()) + mean_psnrs.append((T, mean_psnr)) + + assert mean_psnr >= 30.0, f"T={T}: mean PSNR {mean_psnr:.2f} dB < 30 dB threshold" + color = cmap(i / len(T_values)) + ax.plot(psnr_per_frame, color=color, alpha=0.7, linewidth=0.8, label=f"T={T} ({mean_psnr:.1f}dB)") + print(f" T={T:3d} latent={tuple(latents.shape[1:])} mean PSNR={mean_psnr:.2f} dB") + + # Save reconstructed video + video_path = f"logs/uniae_s3_recon_T{T:03d}.mp4" + easy_io.dump(video_recon_unpadded[0].astype("uint8"), video_path) + print(f" saved {video_path}") + + ax.set_xlabel("Frame index") + ax.set_ylabel("PSNR (dB)") + ax.set_title("UniAE S3 frame-wise PSNR — real video, T ∈ [52, 148] step 4") + ax.legend(loc="upper right", fontsize=6, ncol=4) + fig.tight_layout() + plot_path = "logs/uniae_s3_local_framewise_psnr.png" + fig.savefig(plot_path, dpi=150) + plt.close(fig) + print(f"\n[UniAE S3] Plot saved to {plot_path}") + if mean_psnrs: + psnrs = [p for _, p in mean_psnrs] + print(f"[UniAE S3] Mean PSNR range: {min(psnrs):.2f} – {max(psnrs):.2f} dB") """ @@ -155,7 +220,7 @@ def test_local_image(): from PIL import Image vae = UniAEVAE( - vae_pth=UNIAE_S1_PATH, + vae_pth=UNIAE_S3_PATH, object_store_credential_path_pretrained=CLUSTER_CONFIG.object_store_credential_pretrained, device="cuda", dtype=torch.bfloat16, @@ -171,7 +236,7 @@ def test_local_image(): )[0][0] # First frame: (H, W, C) in [0, 255] H, W, C = video_in_numpy.shape - print(f"\n[UniAE S1 Image] Original image shape: ({H}, {W}, {C})") + print(f"\n[UniAE S3 Image] Original image shape: ({H}, {W}, {C})") # Pad spatial dimensions to be divisible by 16 pad_h = (16 - H % 16) % 16 @@ -184,14 +249,14 @@ def test_local_image(): image_tensor = torch.from_numpy(video_in_numpy).float().permute(2, 0, 1) / 127.5 - 1.0 # (C, H, W) image_batch = image_tensor.unsqueeze(0).cuda() # (1, C, H, W) - print(f"[UniAE S1 Image] Input tensor shape: {image_batch.shape}") + print(f"[UniAE S3 Image] Input tensor shape: {image_batch.shape}") # Encode and decode (encode handles repeat to 4 frames internally) latents = vae.encode(image_batch) - print(f"[UniAE S1 Image] Latent shape: {latents.shape}") - print(f"[UniAE S1 Image] Latent statistics: mean={latents.mean():.4f}, std={latents.std():.4f}") + print(f"[UniAE S3 Image] Latent shape: {latents.shape}") + print(f"[UniAE S3 Image] Latent statistics: mean={latents.mean():.4f}, std={latents.std():.4f}") video_recon = vae.decode(latents) - print(f"[UniAE S1 Image] Reconstructed shape: {video_recon.shape}") + print(f"[UniAE S3 Image] Reconstructed shape: {video_recon.shape}") # Take the first frame as reconstructed image recon_image = video_recon[0, :, 0].clamp(-1, 1) # (C, H, W) @@ -205,7 +270,7 @@ def test_local_image(): recon_f = recon_numpy.astype(np.float32) mse = np.mean((gt - recon_f) ** 2) psnr = 10 * np.log10(255**2 / max(mse, 1e-10)) - print(f"[UniAE S1 Image] PSNR: {psnr:.2f} dB") + print(f"[UniAE S3 Image] PSNR: {psnr:.2f} dB") # Save original and reconstruction side by side output_path = os.path.expanduser("logs/uniae_image_recon.png") @@ -216,4 +281,4 @@ def test_local_image(): side_by_side.paste(orig_img, (0, 0)) side_by_side.paste(recon_img, (W + 10, 0)) side_by_side.save(output_path) - print(f"[UniAE S1 Image] Saved side-by-side to: {output_path}") + print(f"[UniAE S3 Image] Saved side-by-side to: {output_path}") diff --git a/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py b/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py index d838bd7..542c90f 100644 --- a/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py +++ b/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py @@ -757,7 +757,10 @@ def __init__( use_channels_last_memory_format: bool = False, spatial_compression_factor: int = 8, temporal_compression_factor: int = 4, + causal: bool = True, ): + self._causal = causal + assert self._causal, "Wan2pt1VAEInterface is a causal tokenizer; causal must be True." vae_path_full = f"s3://{bucket_name}/{vae_path}" self.keep_decoder_cache = keep_decoder_cache self.keep_encoder_cache = keep_encoder_cache diff --git a/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py b/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py index 954f05a..becde09 100644 --- a/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py +++ b/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py @@ -1323,6 +1323,7 @@ def __init__( # with older configurations. temporal_window: int | None = None, encode_bucket_multiple: int | None = None, + causal: bool = True, ): # Remove temporal_window and encode_bucket_multiple once they have been # removed from the uploaded HuggingFace checkpoint. @@ -1345,9 +1346,7 @@ def __init__( self.chunk_duration = chunk_duration - # Local-path support: skip the s3:// prefix when bucket_name is empty - # so OSS users can point vae_path at an absolute local file. - vae_path_full = f"s3://{bucket_name}/{vae_path}" if bucket_name else vae_path + vae_path_full = f"s3://{bucket_name}/{vae_path}" self.model = WanVAE( dtype=torch.bfloat16, is_amp=False, @@ -1367,6 +1366,8 @@ def __init__( self._spatial_compression_factor = spatial_compression_factor self._temporal_compression_factor = temporal_compression_factor + self._causal = causal + assert self._causal, "Wan2pt2VAEInterface is a causal tokenizer; causal must be True." @property def dtype(self) -> torch.dtype: @@ -1417,6 +1418,8 @@ def compile_encode( warmup_resolutions: Sequence[str], output_dir: str, aspect_ratio: str | None = None, + # ignores torch compile args + **kwargs, ) -> None: """AOT-compile the tokenizer's chunk-level encode for every resolution. diff --git a/cosmos_framework/model/vfm/upsampler/prompts.py b/cosmos_framework/model/vfm/upsampler/prompts.py index 42289da..fd3def1 100644 --- a/cosmos_framework/model/vfm/upsampler/prompts.py +++ b/cosmos_framework/model/vfm/upsampler/prompts.py @@ -908,7 +908,7 @@ def is_upsampled_prompt(prompt: str) -> bool: the native upsampler again. Used by inference callers (e.g. - ``cosmos_framework.inference.OmniInference._iter_predictions``) to decide + ``cosmos3.inference.OmniInference._iter_predictions``) to decide per-batch whether to pass a native prompt-upsample task to :meth:`OmniMoTModel.generate_samples_from_batch`. Two motivating cases produce already-upsampled prompts: diff --git a/cosmos_framework/model/vfm/utils/safetensors_loader.py b/cosmos_framework/model/vfm/utils/safetensors_loader.py index 19ae6b3..15b8774 100644 --- a/cosmos_framework/model/vfm/utils/safetensors_loader.py +++ b/cosmos_framework/model/vfm/utils/safetensors_loader.py @@ -988,7 +988,7 @@ def load_language_model( if tie_embeddings: # The `*ForCausalLM` classes in - # `projects/cosmos3/vfm/models/mot/unified_mot.py` override + # `cosmos_framework/model/vfm/mot/unified_mot.py` override # `get_input_embeddings` (canonical HF idiom) to return the inner # `model.embed_tokens`, so this call returns a real `nn.Embedding` # rather than raising `NotImplementedError`. @@ -1200,7 +1200,7 @@ def load_vfm_model( r"""Load a complete Cosmos3 VFM checkpoint (safetensors) into a Cosmos3VFMNetwork. Loads the *entire* state of a - :class:`~projects.cosmos3.vfm.models.mot.cosmos3_vfm_network.Cosmos3VFMNetwork` + :class:`~cosmos_framework.model.vfm.mot.cosmos3_vfm_network.Cosmos3VFMNetwork` in one shot: - the language tower (``language_model.*``), which carries the diff --git a/cosmos_framework/model/vfm/utils/safetensors_loader_test.py b/cosmos_framework/model/vfm/utils/safetensors_loader_test.py index 151db3c..2f154ff 100644 --- a/cosmos_framework/model/vfm/utils/safetensors_loader_test.py +++ b/cosmos_framework/model/vfm/utils/safetensors_loader_test.py @@ -64,12 +64,12 @@ def _make_safetensors(tmp_path: Path, tensors: dict[str, torch.Tensor]) -> Path: return ckpt_dir - +# NOTE on ``parallel_dims`` in ``load_vlm_model`` tests: # # The single-rank CPU fallback is reached by passing ``parallel_dims=None`` # (the documented escape hatch — see ``load_vlm_model`` docstring). All # end-to-end tests below use that path; multi-rank behavior is covered in -# the GPU-marked tests under ``projects/cosmos3/vfm/models/mot/``. +# the GPU-marked tests under ``cosmos_framework/model/vfm/mot/``. # # Do NOT introduce a "fake" ``ParallelDims`` MagicMock fixture for this # fallback: ``MagicMock.__getitem__`` returns another MagicMock rather than diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/configuration_qwen3_vl.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/configuration_qwen3_vl.py index 9eaa380..03d8e55 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/configuration_qwen3_vl.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/configuration_qwen3_vl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 # Source Repository: https://github.com/huggingface/transformers diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py index c2ccc3f..3736d23 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 # Source Repository: https://github.com/huggingface/transformers @@ -33,6 +33,21 @@ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +# "default" rope type was removed from ROPE_INIT_FUNCTIONS in transformers>=5.x +if "default" not in ROPE_INIT_FUNCTIONS: + + def _default_rope_init(config, device=None, **kwargs): + base = config.rope_theta + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + dim = int(head_dim * partial_rotary_factor) + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + + ROPE_INIT_FUNCTIONS["default"] = _default_rope_init from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import is_torchdynamo_compiling diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py index 5a8ecff..a25d96d 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # Core masking functions extracted from transformers.masking_utils for BAGEL compatibility # Original Copyright 2025 HuggingFace Inc. team. Licensed under the Apache License, Version 2.0 diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/video_processing_qwen3_vl.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/video_processing_qwen3_vl.py index 717ef0a..9cfec3c 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/video_processing_qwen3_vl.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/video_processing_qwen3_vl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 # Source Repository: https://github.com/huggingface/transformers diff --git a/cosmos_framework/model/vfm/vlm_model.py b/cosmos_framework/model/vfm/vlm_model.py index 163904e..d1ce765 100644 --- a/cosmos_framework/model/vfm/vlm_model.py +++ b/cosmos_framework/model/vfm/vlm_model.py @@ -124,7 +124,7 @@ def _get_overlay_config(model_type: str) -> tuple[list[str], Callable[[str], boo def _get_vision_encoder_modules(model: nn.Module, model_type: str) -> list: if model_type in _QWEN_VL_TYPES: - + # NOTE: intentional semantic change from `model_utils.get_model_vision_encoder`, # which returns only [patch_embed, blocks]. Qwen3-VL adds a learnable `pos_embed` # (nn.Embedding — see qwen3_vl.py Qwen3VLVisionModel); leaving it trainable while # freezing the rest of the vision encoder contradicts the intent of @@ -395,18 +395,8 @@ def _init_vlm(self, config: VLMModelConfig, checkpoint) -> None: # ── g. Load pretrain weights ── if load_pretrain_weights: - if policy.backbone.safetensors_path: - safetensors_local_path = maybe_download_hf_model_from_s3( - policy.backbone.safetensors_path, - checkpoint.load_from_object_store.credentials, - checkpoint.load_from_object_store.bucket, - include_model_weights=True, - ) - else: - safetensors_local_path = local_path - hf_model.load_weights( - checkpoint_path=safetensors_local_path, + checkpoint_path=local_path, credential_path=None, # local path after download parallel_dims=parallel_dims if torch.distributed.is_initialized() else None, ) diff --git a/cosmos_framework/scripts/train.py b/cosmos_framework/scripts/train.py index 7d19e1c..1da3039 100644 --- a/cosmos_framework/scripts/train.py +++ b/cosmos_framework/scripts/train.py @@ -1,28 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -"""SFT training entrypoint backed by the structured TOML dataclass. - -Sole input is ``--sft-toml `` — no ``--config`` or interface_toml flow. - -Usage:: - - torchrun --nproc_per_node= -m cosmos_framework.scripts.train \\ - --sft-toml=examples/toml/sft_config/.toml \\ - -- optimizer.lr=1e-5 trainer.max_iter=200 - -The TOML is loaded via ``SFTExperimentConfig.from_toml`` (structural validation, -raises on unknown keys), then -``cosmos_framework.configs.toml_config.sft_config.load_experiment_from_toml`` picks the -base ``config.py`` from ``[job].task`` (``vfm`` → ``cosmos_framework/configs/base/config.py``, -``vlm`` → ``cosmos_framework/configs/base/vlm/config.py``), resolves ``[job].experiment`` -against the Hydra ``ConfigStore``, and overlays every other TOML key as a Hydra -override. Trailing ``key.path=value`` positionals are applied last (so they -win over TOML). -""" - -from __future__ import annotations - import argparse import os import traceback @@ -30,15 +8,13 @@ import torch from loguru import logger as logging -from cosmos_framework.utils.config import Config +from cosmos_framework.utils.config import Config, load_config, pretty_print_overrides from cosmos_framework.utils.lazy_config import LazyConfig, instantiate from cosmos_framework.utils.serialization import to_yaml from cosmos_framework.utils import distributed from cosmos_framework.utils.context_managers import data_loader_init, distributed_init, model_init from cosmos_framework.utils.launch import log_reproducible_setup from cosmos_framework.utils.training_telemetry import telemetry -from cosmos_framework.configs.toml_config.sft_config import load_experiment_from_toml - # --------------------------------------------------------------------------- # --deterministic: mirrors launch_vfm.sh determinism settings. @@ -50,7 +26,7 @@ # and torch backend flags take effect. # 2. _apply_deterministic_config_overrides() — after load_config but before # config.freeze(), so the config mutations land before trainer.__init__ -# re-applies cudnn from config (imaginaire/trainer.py:125-126). +# re-applies cudnn from config (cosmos_framework/trainer.py:125-126). # # PYTHONHASHSEED must be set externally (Python locks it at interpreter startup); # we only warn when it's missing. @@ -69,6 +45,9 @@ def _setup_deterministic_env_and_backends() -> None: # CUBLAS_WORKSPACE_CONFIG must be set before any CUBLAS init, hence script entry. # ":4096:8" is the value recommended by PyTorch's `torch.use_deterministic_algorithms` # docs for CUDA >= 10.2 — without it, deterministic cuBLAS GEMMs raise RuntimeError. + # Refs: + # - https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html + # - https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -122,16 +101,7 @@ def _walk(cfg, mutations: dict) -> int: n += _walk(item, mutations) return n - # persistent_workers=False is needed alongside num_workers=0 — PyTorch's - # DataLoader rejects (num_workers=0, persistent_workers=True) with - # ValueError. Nested dataloaders (e.g. PackingDataLoader → RankPartitionedDataLoader) - # pass the kwargs straight to torch.utils.data.DataLoader so they trip on this. - dl_overrides = { - "num_workers": 0, - "prefetch_factor": None, - "persistent_workers": False, - "detshuffle": True, - } + dl_overrides = {"num_workers": 0, "prefetch_factor": None, "detshuffle": True} n_dl = _walk(config.dataloader_train, dl_overrides) + _walk(config.dataloader_val, dl_overrides) def _force_compile_disabled(cfg) -> int: @@ -188,7 +158,7 @@ def launch(config: Config, args: argparse.Namespace) -> None: # Apply --deterministic config-level overrides before validate/freeze/trainer-init # so (a) validate inspects the config the trainer will actually consume, and # (b) trainer.__init__ doesn't undo the script-level backends settings - # (imaginaire/trainer.py:125-126 re-applies cudnn from config). + # (cosmos_framework/trainer.py:125-126 re-applies cudnn from config). if args.deterministic: _apply_deterministic_config_overrides(config) # Check that the config is valid @@ -225,28 +195,20 @@ def launch(config: Config, args: argparse.Namespace) -> None: if __name__ == "__main__": - parser = argparse.ArgumentParser(description="SFT training (structured TOML)") - parser.add_argument( - "--sft-toml", - required=True, - help=( - "Path to an SFT structured-dataclass TOML — see " - "cosmos_framework/configs/toml_config/sft_config.py " - "(SFTExperimentConfig)." - ), - ) + # Usage: torchrun --nproc_per_node=1 -m scripts.train --config=projects//configs/config.py + + # Get the config file from the input arguments. + parser = argparse.ArgumentParser(description="Training") + parser.add_argument("--config", help="Path to the config file", required=False) parser.add_argument( "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, nargs=argparse.REMAINDER, - default=[], - help=( - "Extra Hydra-style dotted-path overrides applied AFTER the TOML " - "values (so they win). Use the standard Hydra syntax, e.g. " - "'optimizer.lr=1e-5 trainer.max_iter=200 " - "model.config.parallelism.data_parallel_shard_degree=4'. " - "Prefix with '--' to make argparse stop interpreting the rest as " - "flags." - ), ) parser.add_argument( "--dryrun", @@ -278,14 +240,12 @@ def launch(config: Config, args: argparse.Namespace) -> None: if args.deterministic: _setup_deterministic_env_and_backends() - config = load_experiment_from_toml(args.sft_toml, extra_overrides=args.opts) - - # log_reproducible_setup reads args.config for telemetry; this entrypoint - # only takes --sft-toml, so alias it so the launch info records the TOML. - args.config = args.sft_toml + config = load_config(args.config, args.opts, enable_one_logger=True) if args.dryrun: - logging.info("Config:\n" + config.pretty_print(use_color=True)) + logging.info( + "Config:\n" + config.pretty_print(use_color=True) + "\n" + pretty_print_overrides(args.opts, use_color=True) + ) os.makedirs(config.job.path_local, exist_ok=True) try: to_yaml(config, f"{config.job.path_local}/config.yaml") @@ -295,4 +255,5 @@ def launch(config: Config, args: argparse.Namespace) -> None: LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") print(f"{config.job.path_local}/config.yaml") else: + # Launch the training job. launch(config, args) diff --git a/cosmos_framework/tools/flops/qwen3_vl.py b/cosmos_framework/tools/flops/qwen3_vl.py index e4663cc..c84a244 100644 --- a/cosmos_framework/tools/flops/qwen3_vl.py +++ b/cosmos_framework/tools/flops/qwen3_vl.py @@ -516,7 +516,7 @@ def compute_qwen3vl_flops( flops_breakdown["vision_encoder"] = 0 # Embedding layer FLOPs - + # NOTE: Only text tokens need embeddings. Visual tokens are already embedded by vision encoder. text_tokens = total_tokens - visual_tokens if include_embeddings: # Embedding lookup: typically counted as 0 or hidden_size operations per token diff --git a/cosmos_framework/tools/visualize/video.py b/cosmos_framework/tools/visualize/video.py index 889bb7e..8603985 100644 --- a/cosmos_framework/tools/visualize/video.py +++ b/cosmos_framework/tools/visualize/video.py @@ -13,19 +13,16 @@ def save_video(grid, video_name, fps=30): - import imageio + import cv2 + import ffmpegcv grid = (grid * 255).astype(np.uint8) grid = np.transpose(grid, (1, 2, 3, 0)) - imageio.mimsave( - video_name, - list(grid), - format="mp4", - fps=float(fps), - codec="libx264", - pixelformat="yuv420p", - macro_block_size=1, - ) + with ffmpegcv.VideoWriter(video_name, "h264", fps) as writer: + for frame in grid: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + writer.write(frame) def save_img_or_video( diff --git a/cosmos_framework/trainer/__init__.py b/cosmos_framework/trainer/__init__.py index 878d499..80916ec 100644 --- a/cosmos_framework/trainer/__init__.py +++ b/cosmos_framework/trainer/__init__.py @@ -28,6 +28,10 @@ from cosmos_framework.utils.checkpointer import Checkpointer from cosmos_framework.utils.misc import StragglerDetectorV2 +# COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger +from cosmos_framework.utils.one_logger.one_logger_utils import initialize_one_logger_from_imaginaire_config + +# COSMOS-RELEASE-END-IGNORE class ImaginaireTrainer: @@ -109,6 +113,12 @@ def __init__(self, config): # Initialize cuDNN. torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark + # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger + # OneLogger - initialize one_logger before instantiating CallBackGroup + enable_one_logger = os.environ.get("ENABLE_ONELOGGER", "FALSE").lower() == "true" + if enable_one_logger: + initialize_one_logger_from_imaginaire_config(config) + # COSMOS-RELEASE-END-IGNORE # Initialize the callback functions. self.callbacks = callback.CallBackGroup(config=config, trainer=self) # Initialize the model checkpointer. @@ -224,7 +234,6 @@ def train( model_ddp = model else: raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") - log.info("Starting training...") sm_carveout = int(os.environ.get("GROUPED_MM_SM_CARVEOUT", "0")) if sm_carveout: diff --git a/cosmos_framework/utils/callback.py b/cosmos_framework/utils/callback.py index 0f2a219..373013d 100644 --- a/cosmos_framework/utils/callback.py +++ b/cosmos_framework/utils/callback.py @@ -19,6 +19,10 @@ from cosmos_framework.utils import distributed, log, misc, wandb_util from cosmos_framework.utils.misc import get_local_tensor_if_DTensor +# COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger +from cosmos_framework.utils.one_logger.one_logger_utils import get_one_logger + +# COSMOS-RELEASE-END-IGNORE try: from megatron.core import parallel_state @@ -391,7 +395,7 @@ def on_training_step_end( loss: torch.Tensor, iteration: int = 0, ) -> None: - + # FIXME - this is not correct when using gradient accumulation since self.start_iteration_time is updated every batch # but this is only called when the optimizer is updated, so it's only the time for the last batch. self.elapsed_iteration_time += time.time() - self.start_iteration_time @@ -600,3 +604,104 @@ def on_after_dataloading(self, iteration: int = 0) -> None: torch.cuda.nvtx.range_pop() +# COSMOS-RELEASE-BEGIN-IGNORE +class OneLoggerCallback(Callback): + """Callback for OneLogger""" + + def __init__( + self, + config: Optional["Config"] = None, + trainer: Optional["ImaginaireTrainer"] = None, + ) -> None: + super().__init__(config, trainer) + + self.one_logger = get_one_logger() + self.one_logger.on_app_start(set_barrier=False, app_start_time=round(time.time() * 1000)) + + def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: + try: + batch_size = self.config.dataloader_train.batch_size + except Exception: + batch_size = 1 + if parallel_state is None or not parallel_state.is_initialized(): + data_parallel_size = 1 + else: + data_parallel_size = parallel_state.get_data_parallel_world_size() + global_batch_size = batch_size * data_parallel_size + + self.one_logger.on_train_start( + set_barrier=False, + train_iterations_start=iteration, + train_samples_start=iteration * global_batch_size, + ) + + def on_training_step_batch_start( + self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0 + ) -> None: + self.one_logger.on_train_batch_start(set_barrier=False) + + def on_optimizer_init_start(self) -> None: + self.one_logger.on_optimizer_init_start(set_barrier=False) + + def on_optimizer_init_end(self) -> None: + self.one_logger.on_optimizer_init_end(set_barrier=False) + + def on_training_step_batch_end( + self, + model: ImaginaireModel, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.one_logger.on_train_batch_end(set_barrier=False) + + def on_validation_start( + self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 + ) -> None: + self.one_logger.on_validation_start(set_barrier=False) + + def on_validation_step_start( + self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0 + ) -> None: + self.one_logger.on_validation_batch_start(set_barrier=False) + + def on_validation_step_end( + self, + model: ImaginaireModel, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int = 0, + ) -> None: + self.one_logger.on_validation_batch_end(set_barrier=False) + + def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None: + self.one_logger.on_validation_end(set_barrier=False) + + def on_load_checkpoint_start(self, model: ImaginaireModel) -> None: + self.one_logger.on_load_checkpoint_start(set_barrier=False) + + def on_load_checkpoint_end( + self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: Optional[str] = None + ) -> None: + self.one_logger.on_load_checkpoint_end(set_barrier=False) + + def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None: + self.one_logger.on_save_checkpoint_start(global_step=iteration) + + def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None: + self.one_logger.on_save_checkpoint_end(global_step=iteration) + + def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None: + self.one_logger.on_save_checkpoint_success(global_step=iteration) + + def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None: + self.one_logger.on_train_end(set_barrier=False) + + def on_app_end(self) -> None: + self.one_logger.on_app_end() + self.one_logger.finish() + + +# COSMOS-RELEASE-END-IGNORE diff --git a/cosmos_framework/utils/checkpoint_db.py b/cosmos_framework/utils/checkpoint_db.py index e036ffb..50b4667 100644 --- a/cosmos_framework/utils/checkpoint_db.py +++ b/cosmos_framework/utils/checkpoint_db.py @@ -57,7 +57,7 @@ import uuid from abc import ABC, abstractmethod from pathlib import Path -from typing import Annotated, Callable, TypeAlias +from typing import Annotated, TypeAlias import pydantic from typing_extensions import Self, override @@ -148,6 +148,8 @@ def _hf_download(cmd_args: list[str]) -> str: is_rank0 = os.environ.get("RANK", "0") == "0" cmd = [ "uvx", + "--with", + "click", f"hf@{HF_VERSION}", "download", "--format=json", @@ -268,7 +270,7 @@ def _download(self) -> str: class CheckpointConfig(pydantic.BaseModel): """Config for checkpoint.""" - model_config = pydantic.ConfigDict(extra="forbid", frozen=True, arbitrary_types_allowed=True) + model_config = pydantic.ConfigDict(extra="forbid", frozen=True) uuid: str """Checkpoint UUID.""" @@ -292,15 +294,6 @@ class CheckpointConfig(pydantic.BaseModel): hf: CheckpointHf """Config for checkpoint on Hugging Face.""" - post_download: Callable[[str], None] | None = pydantic.Field(default=None, exclude=True) - """Optional callback invoked with the local path after a successful download. - - Used to materialize derived artifacts inside the cache directory so downstream - loaders see the expected file layout (e.g. wrap a safetensors export back into - a legacy ``.ckpt`` for loaders that only read ``torch.load`` checkpoints). - Must be idempotent — invoked on every ``download()`` call. - """ - @property def full_name(self) -> str: """Return full name for debugging.""" @@ -311,17 +304,8 @@ def download(self) -> str: if INTERNAL: return self.s3.uri - include = getattr(self.hf, "include", ()) - if include: - _config_patterns = {"*.json", "*.txt", "*.yaml", "*.yml", "*.md"} - kind = "tokenizer/config files" if set(include).issubset(_config_patterns) else "files" - log.info(f"Downloading {self.hf.repository} {kind} ({', '.join(include)})") - else: - log.info(f"Downloading checkpoint {self.full_name}") - path = self.hf.download() - if self.post_download is not None: - self.post_download(path) - return path + log.info(f"Downloading checkpoint {self.full_name}") + return self.hf.download() @classmethod def maybe_from_uri(cls, uri: str) -> Self | None: @@ -441,11 +425,6 @@ def download_checkpoint_v2(checkpoint_uri: str, *, check_exists: bool = True) -> Similar to 'download_checkpoint', but unknown S3 URIs are passed through. """ - # Local-path short-circuit: if the URI exists on disk, return it as-is - # without consulting the registry. Prevents the registry from rewriting - # a known basename (e.g. Wan2.2_VAE.pth) into an s3:// URI we can't open. - if os.path.exists(checkpoint_uri): - return checkpoint_uri if INTERNAL: return checkpoint_uri if (checkpoint := CheckpointConfig.maybe_from_uri(sanitize_uri(checkpoint_uri))) is not None: diff --git a/cosmos_framework/utils/checkpointer.py b/cosmos_framework/utils/checkpointer.py index 5561e01..71cd202 100644 --- a/cosmos_framework/utils/checkpointer.py +++ b/cosmos_framework/utils/checkpointer.py @@ -39,9 +39,6 @@ def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, c """ # Set the callback functions. self.callbacks = callbacks - - - self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" self.checkpoint_dir_object_store = f"{config_job.path}/checkpoints" self.save_to_object_store = config_checkpoint.save_to_object_store.enabled diff --git a/cosmos_framework/utils/config.py b/cosmos_framework/utils/config.py index 33fc9a9..e8e008a 100644 --- a/cosmos_framework/utils/config.py +++ b/cosmos_framework/utils/config.py @@ -346,7 +346,7 @@ class NVTXConfig: @make_freezable @attrs.define(slots=False) class StragglerDetectionConfig: - """Config for the Straggler detection tool.""" + """Config for Straggler detection tool: https://gitlab-master.nvidia.com/dl/gwe/fault_tolerance_related/straggler/-/tree/cupti?ref_type=heads""" # Enable the Straggler Detection. enabled: bool = False @@ -512,76 +512,19 @@ def validate(self) -> None: distributed.broadcast(job_name_tensor, 0) self.job.name = job_name_tensor.cpu().numpy().tobytes().decode("utf-8") - assert self.job.project != "" assert self.job.group != "" assert self.job.name != "" -def _reload_make_config_for_registrations(root_cfg: "Config") -> None: - """Run ``make_config()`` once for import-time registrations (same intent as loading ``config.py``). - - Deserialized YAML/TOML instantiates attrs ``Config`` with ``__class__.__module__`` set to the - module that defines the class (often ``…defaults.config``). ``load_callable`` splits on the - last dot, which turns that into ``import …defaults`` + ``getattr(..., "config")`` — the - ``defaults.config`` submodule, which often has no ``make_config``. The entrypoint with - ``make_config`` is typically the sibling module ``….config``. - """ - from cosmos_framework.utils.serialization import load_callable - - cls_mod = type(root_cfg).__module__ - - def _try_make_config(mod: object) -> bool: - mk = getattr(mod, "make_config", None) - if mk is None: - return False - _ = mk() - return True - - if cls_mod.endswith(".defaults.config"): - sibling = cls_mod[: -len(".defaults.config")] + ".config" - try: - if _try_make_config(importlib.import_module(sibling)): - return - except ModuleNotFoundError: - pass - - try: - if _try_make_config(load_callable(cls_mod)): - return - except (AssertionError, AttributeError, ModuleNotFoundError): - pass - - try: - if _try_make_config(importlib.import_module(cls_mod)): - return - except ModuleNotFoundError: - pass - - raise AttributeError( - f"No make_config() found for Config class module {cls_mod!r}. " - "YAML/TOML export must match a tree whose Python package exposes make_config " - "(e.g. cosmos_framework.configs.base.vlm.config next to cosmos_framework.configs.base.vlm.defaults.config)." - ) - - def load_config(config_path: str, opts: list[str], enable_one_logger: bool = False) -> Config: - from cosmos_framework.utils.serialization import from_toml, from_yaml + from cosmos_framework.utils.serialization import from_yaml, load_callable t1 = time.monotonic_ns() if config_path.endswith(".yaml"): config = from_yaml(config_path) - # Import-time registrations (dataloaders, experiments, …): YAML root class - # typically lives in …defaults.config; make_config() is on sibling …config. - _reload_make_config_for_registrations(config) - - from cosmos_framework.utils.config_helper import override - - config = override(config, opts, remove_defaults=True) - elif config_path.endswith(".toml"): - config = from_toml(config_path) - # TOML is the same exported structured schema as YAML. - _reload_make_config_for_registrations(config) + # for registration of dataloaders, etc. + _ = load_callable(config.__module__).make_config() from cosmos_framework.utils.config_helper import override @@ -607,7 +550,7 @@ def load_config(config_path: str, opts: list[str], enable_one_logger: bool = Fal def _load_py_config(config_path: str, opts: list[str], validate: bool = True) -> Config: - + # NOTE: circular dependency from cosmos_framework.utils.config_helper import get_config_module, override t1 = time.monotonic_ns() diff --git a/cosmos_framework/utils/context_managers.py b/cosmos_framework/utils/context_managers.py index 794c0bb..6b19e94 100644 --- a/cosmos_framework/utils/context_managers.py +++ b/cosmos_framework/utils/context_managers.py @@ -8,6 +8,14 @@ from cosmos_framework.utils.misc import timer +# COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry +from cosmos_framework.utils.one_logger.one_logger_context_managers import data_loader_init as one_logger_data_loader_init +from cosmos_framework.utils.one_logger.one_logger_context_managers import model_init as one_logger_model_init +from cosmos_framework.utils.training_telemetry.context_managers import data_loader_init as telemetry_data_loader_init +from cosmos_framework.utils.training_telemetry.context_managers import distributed_init as telemetry_distributed_init +from cosmos_framework.utils.training_telemetry.context_managers import model_init as telemetry_model_init + +# COSMOS-RELEASE-END-IGNORE @contextmanager @@ -36,6 +44,10 @@ def data_loader_init() -> Generator[None, None, None]: Wrap the data loader initialization with multiple context managers used for telemetry and one logger. """ contexts = [ + # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry + one_logger_data_loader_init(), + telemetry_data_loader_init(), + # COSMOS-RELEASE-END-IGNORE timer("init_data_loader"), ] with ExitStack() as stack: @@ -48,6 +60,10 @@ def model_init(set_barrier: bool = False) -> Generator[None, None, None]: Wrap the instantiation of the model with multiple context managers used for telemetry and one logger. """ contexts = [ + # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry + one_logger_model_init(set_barrier=set_barrier), + telemetry_model_init(), + # COSMOS-RELEASE-END-IGNORE timer("init_model"), ] with ExitStack() as stack: @@ -60,6 +76,9 @@ def distributed_init() -> Generator[None, None, None]: Wrap the distributed initialization, used for telemetry and timers """ contexts = [ + # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry + telemetry_distributed_init(), + # COSMOS-RELEASE-END-IGNORE timer("init_distributed"), ] with ExitStack() as stack: diff --git a/cosmos_framework/utils/device.py b/cosmos_framework/utils/device.py index e87674d..7bc2f88 100644 --- a/cosmos_framework/utils/device.py +++ b/cosmos_framework/utils/device.py @@ -85,20 +85,16 @@ def gpu0_has_80gb_or_less(): class Device: - _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore def __init__(self, device_idx: int): - super().__init__() self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) def get_name(self) -> str: - return pynvml.nvmlDeviceGetName(self.handle) def get_cpu_affinity(self) -> list[int]: - affinity_string = "" for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): # assume nvml returns list of 64 bit ints diff --git a/cosmos_framework/utils/distributed.py b/cosmos_framework/utils/distributed.py index 9f95223..64e0de7 100644 --- a/cosmos_framework/utils/distributed.py +++ b/cosmos_framework/utils/distributed.py @@ -54,7 +54,7 @@ def init() -> int | None: timeout_timedelta = timedelta(seconds=int(timeout_seconds)) dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta) log.critical( - f"Initialized distributed runtime with local rank {local_rank} with timeout {timeout_seconds}", + f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}", rank0_only=False, ) # Increase the L2 fetch granularity for faster speed. @@ -65,7 +65,7 @@ def init() -> int | None: p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05)) - log.info(f"Distributed setup with {get_world_size()} GPUs.") + log.info(f"Training with {get_world_size()} GPUs.") def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: diff --git a/cosmos_framework/utils/easy_io/backends/local_backend.py b/cosmos_framework/utils/easy_io/backends/local_backend.py index 80d05b8..3c3af94 100644 --- a/cosmos_framework/utils/easy_io/backends/local_backend.py +++ b/cosmos_framework/utils/easy_io/backends/local_backend.py @@ -207,7 +207,6 @@ def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> >>> backend.join_path(filepath1, filepath2, filepath3) '/path/of/dir/dir2/path/of/file' """ - # TODO, if filepath or filepaths are Path, should return Path return osp.join(filepath, *filepaths) @contextmanager diff --git a/cosmos_framework/utils/easy_io/easy_io.py b/cosmos_framework/utils/easy_io/easy_io.py index 686ae30..521ba64 100644 --- a/cosmos_framework/utils/easy_io/easy_io.py +++ b/cosmos_framework/utils/easy_io/easy_io.py @@ -137,7 +137,6 @@ def get_file_backend( prefix = "" if enable_singleton: - unique_key = f"{prefix}:{json.dumps(backend_args)}" if unique_key in backend_instances: return backend_instances[unique_key] diff --git a/cosmos_framework/utils/easy_io/easy_io_test.py b/cosmos_framework/utils/easy_io/easy_io_test.py index c2f3f52..4d7feff 100644 --- a/cosmos_framework/utils/easy_io/easy_io_test.py +++ b/cosmos_framework/utils/easy_io/easy_io_test.py @@ -22,12 +22,12 @@ def setup_s3(): @RunIf(requires_file="credentials/pbss_getty.secret") def test_s3_backend(): setup_s3() - for ith, _ in enumerate(easy_io.list_dir_or_file("s3://bucket6/")): + for ith, _ in enumerate(easy_io.list_dir_or_file("s3://checkpoints/")): if ith > 5: break - easy_io.copyfile_from_local("pyproject.toml", "s3://bucket6/pyproject.toml") - easy_io.remove("s3://bucket6/pyproject.toml") + easy_io.copyfile_from_local("pyproject.toml", "s3://checkpoints/pyproject.toml") + easy_io.remove("s3://checkpoints/pyproject.toml") @pytest.mark.L1("Requires data uploading to S3.") diff --git a/cosmos_framework/utils/ema.py b/cosmos_framework/utils/ema.py index 5ae9c4d..85bb817 100644 --- a/cosmos_framework/utils/ema.py +++ b/cosmos_framework/utils/ema.py @@ -104,7 +104,7 @@ class EMAModelTracker(torch.nn.Module): The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's implementation of EMA. There are no optimizable parameters. - TODO: multi-EMA weights. + TODO(snah): multi-EMA weights. Attributes: collected_params (list): temporarily stores the regular weights while in EMA mode. diff --git a/cosmos_framework/utils/env_parsers/cred_env_parser.py b/cosmos_framework/utils/env_parsers/cred_env_parser.py index 04810d9..7b96eb3 100644 --- a/cosmos_framework/utils/env_parsers/cred_env_parser.py +++ b/cosmos_framework/utils/env_parsers/cred_env_parser.py @@ -33,7 +33,8 @@ class CredentialEnvParser(EnvParser): PROD_TEAM_DIR_REGION_NAME = String(default="") PICASSO_AUTH_MODEL_REGISTRY_API_KEY = String(default="") - PICASSO_API_ENDPOINT_URL = String(default="https://invalid") + # COSMOS-RELEASE-REPLACE-NEXT: '"https://.*"' '"https://invalid"' + PICASSO_API_ENDPOINT_URL = String(default="https://meeocvslt2.execute-api.us-west-2.amazonaws.com") CRED_ENVS = CredentialEnvParser() diff --git a/cosmos_framework/utils/flags.py b/cosmos_framework/utils/flags.py index 05a22cf..91cd3e6 100644 --- a/cosmos_framework/utils/flags.py +++ b/cosmos_framework/utils/flags.py @@ -39,7 +39,8 @@ def _get_bool(name: str, default: bool) -> bool: This is used to make training dependencies optional. """ -INTERNAL: Final[bool] = _get_bool("COSMOS_INTERNAL", False) +# COSMOS-RELEASE-REPLACE-NEXT: TRAINING False +INTERNAL: Final[bool] = _get_bool("COSMOS_INTERNAL", TRAINING) """Whether to use internal (nvidia-only) resources (e.g. S3).""" SMOKE: Final[bool] = _get_bool("COSMOS_SMOKE", False) @@ -67,6 +68,10 @@ class Device(StrEnum): EXPERIMENTAL_CHECKPOINTS: Final[bool] = _get_bool("COSMOS_EXPERIMENTAL_CHECKPOINTS", INTERNAL) """Whether to enable experimental checkpoints.""" +# COSMOS-RELEASE-BEGIN-IGNORE +ENABLE_PI_CHECKPOINTS: Final[bool] = _get_bool("COSMOS_ENABLE_PI_CHECKPOINTS", False) +"""Whether to enable checkpoints from NVIDIA-DIR/Cosmos-Predict2.5-2B-PI-Private.""" +# COSMOS-RELEASE-END-IGNORE if INTERNAL: TRAINING = True diff --git a/cosmos_framework/utils/lazy_config/__init__.py b/cosmos_framework/utils/lazy_config/__init__.py index 34c3c42..d999bd0 100644 --- a/cosmos_framework/utils/lazy_config/__init__.py +++ b/cosmos_framework/utils/lazy_config/__init__.py @@ -15,7 +15,7 @@ PLACEHOLDER = None -class LazyDict(DictConfig): +class LazyDict(DictConfig): # NOTE: to differentiate between LazyDict & DictConfig def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/cosmos_framework/utils/lazy_config/lazy.py b/cosmos_framework/utils/lazy_config/lazy.py index a2af0e7..2834995 100644 --- a/cosmos_framework/utils/lazy_config/lazy.py +++ b/cosmos_framework/utils/lazy_config/lazy.py @@ -110,7 +110,7 @@ def _patch_import(): old_import = builtins.__import__ def find_relative_file(original_file, relative_import_path, level): - + # NOTE: "from . import x" is not handled. Because then it's unclear # if such import should produce `x` as a python module or DictConfig. # This can be discussed further if needed. relative_import_err = """ @@ -321,26 +321,6 @@ def is_serializable(item): except Exception as e: return False - # For classes / functions / bound methods we want the importable dotted - # path, not `repr(obj)` — the latter yields strings like - # `` or `` which break any - # downstream consumer that calls hydra.utils.instantiate on the loaded - # YAML (e.g. cosmos_framework.scripts.export_model). - from cosmos_framework.utils.lazy_config.registry import convert_target_to_string - - def _to_safe_string(value): - # Preserve primitives — `str(True)` is the literal string `"True"`, - # which yaml then quotes and downstream consumers parse as a string - # instead of the original bool/int/float. - if isinstance(value, (bool, int, float, str)) or value is None: - return value - try: - if callable(value): - return convert_target_to_string(value) - except Exception: - pass - return str(value) - # Function to convert unserializable items to strings def serialize_config(config): if isinstance(config, DictConfig): @@ -358,14 +338,14 @@ def serialize_config(config): serialize_config(value) else: if not is_serializable(value) and value is not None: - config[key] = _to_safe_string(value) + config[key] = str(value) elif isinstance(config, ListConfig): for i, item in enumerate(config): if isinstance(item, (DictConfig, ListConfig)): serialize_config(item) else: if not is_serializable(item) and item is not None: - config[i] = _to_safe_string(item) + config[i] = str(item) else: raise NotImplementedError("Input config must be a DictConfig or ListConfig.") return config diff --git a/cosmos_framework/utils/lazy_config/lazy_call.py b/cosmos_framework/utils/lazy_config/lazy_call.py index 7d0fc0e..528b93f 100644 --- a/cosmos_framework/utils/lazy_config/lazy_call.py +++ b/cosmos_framework/utils/lazy_config/lazy_call.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import collections.abc as abc import inspect diff --git a/cosmos_framework/utils/misc.py b/cosmos_framework/utils/misc.py index d3d9308..a7b2ab0 100644 --- a/cosmos_framework/utils/misc.py +++ b/cosmos_framework/utils/misc.py @@ -66,7 +66,6 @@ def to( assert device is not None or dtype is not None or memory_format is not None, ( "at least one of device, dtype, memory_format should be specified" ) - if isinstance(data, torch.Tensor): if ( memory_format == torch.channels_last @@ -542,7 +541,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): class StragglerDetectorV2: - """StragglerDetectorV2 is a class that allows you to easily integrate the "straggler" tool. + """StragglerDetectorV2 is a class that allows you to easily integrate "straggler" tool: + https://gitlab-master.nvidia.com/dl/gwe/fault_tolerance_related/straggler/-/tree/cupti?ref_type=heads. This tool detects stragglers using low-level CUPTI tool, which can gather kernel execution time with very low overhead. The execution times are compared across different ranks, as well as to the execution time of the exact same kernels in the past. @@ -579,9 +579,9 @@ def __init__( def initialize(self): if self.enabled: if not straggler: - raise RuntimeError( - "Please install the `straggler` package before using StragglerDetectionV2." + "Please install straggler package before using StragglerDetectionV2." + "Package can be installed from here: https://gitlab-master.nvidia.com/dl/osiris/straggler" ) straggler.Detector.initialize( diff --git a/cosmos_framework/utils/object_store.py b/cosmos_framework/utils/object_store.py index a79ae8e..58cd2b1 100644 --- a/cosmos_framework/utils/object_store.py +++ b/cosmos_framework/utils/object_store.py @@ -11,24 +11,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional from urllib.parse import urlparse -import boto3 import numpy as np import torch import yaml -from botocore.config import Config from PIL import Image -import cosmos_framework.utils.easy_io.backends.auto_auth as auto from cosmos_framework.utils import distributed, log from cosmos_framework.utils.easy_io import easy_io -GLOBAL_S3_CONFIG = Config( - retries={"max_attempts": 20, "mode": "adaptive"}, - connect_timeout=10, - read_timeout=60, - request_checksum_calculation="when_required", - response_checksum_validation="when_required", -) Image.MAX_IMAGE_PIXELS = None if TYPE_CHECKING: @@ -41,26 +31,18 @@ class ObjectStore: **Deprecated**. Use `easy_io` directly instead. Attributes: - client (botocore.client.S3): Object store client object. easy_io_backend: easy_io backend. bucket (str): Object store bucket name. """ def __init__(self, config_object_storage: ObjectStoreConfig): - - # extracts the easy_io backend instead of the boto3 S3 client. - with auto.open_auth(config_object_storage.credentials, "r") as file: - object_storage_config = auto.json_load_auth(file) - self.client = Boto3Wrapper( - "s3", - **object_storage_config, - ) self.easy_io_backend = easy_io.get_file_backend( backend_args={ "backend": "s3", "s3_credential_path": config_object_storage.credentials, "path_mapping": None, - } + }, + enable_singleton=True, ) self.bucket = config_object_storage.bucket @@ -158,7 +140,6 @@ def save_object( """ assert type is not None or save_func is not None with io.BytesIO() as buffer: - # Write to buffer for common data types. if type == "torch": torch.save(object, buffer) @@ -199,31 +180,6 @@ def object_exists(self, key: str) -> bool: return self.easy_io_backend.exists(filepath=self._translate_key(key=key)) -class Boto3Wrapper: - """ - This class serves as a wrapper around boto3.client in order to make boto3.client serializable. It's required to use - spawn method of creating DataLoader workers, which is in turn required to avoid segfaults when using Triton, e.g. - for torch.compile or custom kernels. - """ - - def __init__(self, *args, **kwargs): - self._args = args - self._kwargs = kwargs - self.client = None - - def __setstate__(self, state): - self.__dict__ = state - - def __getattr__(self, item): - is_worker = torch.utils.data.get_worker_info() is not None - client = ( - boto3.client(*self._args, **self._kwargs, config=GLOBAL_S3_CONFIG) if self.client is None else self.client - ) - if is_worker: - self.client = client - return getattr(client, item) - - def sync_s3_dir_to_local( s3_dir: str, s3_credential_path: str, @@ -241,7 +197,7 @@ def sync_s3_dir_to_local( ALL distributed workers using `distributed.barrier()`. Defaults to True. cache_dir (str, optional): The cache folder to sync the S3 directory to. If None, the environment variable `IMAGINAIRE_CACHE_DIR` (defaulting - to "~/.cache/imaginaire") will be used. + to "~/.cache/cosmos_framework") will be used. local_rank_sync (bool, optional): Whether to synchronize download across workers within the same node using a node-level barrier. This is useful when the cache directory is not shared across nodes. Defaults to False. @@ -275,7 +231,7 @@ def sync_s3_dir_to_local( # If the local directory is not specified, use the default cache directory cache_dir = ( - os.environ.get("IMAGINAIRE_CACHE_DIR", os.path.expanduser("~/.cache/imaginaire")) + os.environ.get("IMAGINAIRE_CACHE_DIR", os.path.expanduser("~/.cache/cosmos_framework")) if cache_dir is None else cache_dir ) @@ -363,7 +319,7 @@ def download_from_s3_with_cache( } ) cache_dir = ( - os.environ.get("IMAGINAIRE_CACHE_DIR", os.path.expanduser("~/.cache/imaginaire")) + os.environ.get("IMAGINAIRE_CACHE_DIR", os.path.expanduser("~/.cache/cosmos_framework")) if cache_dir is None else cache_dir ) diff --git a/cosmos_framework/utils/one_logger/one_logger_override_utils.py b/cosmos_framework/utils/one_logger/one_logger_override_utils.py index dd979d7..54e01d3 100644 --- a/cosmos_framework/utils/one_logger/one_logger_override_utils.py +++ b/cosmos_framework/utils/one_logger/one_logger_override_utils.py @@ -12,7 +12,7 @@ def override_one_logger_callback(config) -> None: - """Add OneLoggerCallback to imaginaire config""" + """Add OneLoggerCallback to cosmos_framework config""" # Enable OneLogger by environment variable. enable_onelogger = os.environ.get("ENABLE_ONELOGGER", "FALSE").lower() == "true" diff --git a/cosmos_framework/utils/one_logger/one_logger_utils.py b/cosmos_framework/utils/one_logger/one_logger_utils.py index 0a0457b..ac9d697 100644 --- a/cosmos_framework/utils/one_logger/one_logger_utils.py +++ b/cosmos_framework/utils/one_logger/one_logger_utils.py @@ -317,8 +317,10 @@ def _set_one_logger(self): self.one_logger = OneLogger(config=config) except BaseException: logger.info( - "WARNING: the `one_logger` package is required to enable e2e metrics tracking, " - "but it is not installed." + "WARNING: one_logger package is required to enable e2e metrics " + "tracking. please go to " + "https://confluence.nvidia.com/display/MLWFO/Package+Repositories" + " for details to install it" ) else: self.one_logger = None @@ -1142,7 +1144,7 @@ def on_save_checkpoint_end(self, set_barrier: bool = False, **metrics_input_kwar self._store_set(f"productive_time:{global_step}", productive_time) - + # NOTE: If on_save_checkpoint_success is called already, track productive metrics here if self._store_has_key(f"on_save_checkpoint_success:{global_step}"): successful_save_checkpoint_sync_finish_time = productive_time.pop( "successful_save_checkpoint_sync_finish_time" @@ -1210,7 +1212,7 @@ def on_save_checkpoint_success(self, set_barrier: bool = False, **metrics_input_ # Fetch productivity metrics cached on_save_checkpoint_start productive_metrics = self.one_logger.store_pop(f"productive_metrics:{global_step}") - + # NOTE: Only track *_sync_* metrics after on_save_checkpoint_end is called. # Check if on_save_checkpoint_end is called. if self._store_has_key(f"on_save_checkpoint_end:{global_step}"): productive_time = self._store_get(f"productive_time:{global_step}") diff --git a/cosmos_framework/utils/serialization.py b/cosmos_framework/utils/serialization.py index a962983..d009dba 100644 --- a/cosmos_framework/utils/serialization.py +++ b/cosmos_framework/utils/serialization.py @@ -6,12 +6,11 @@ import importlib import json import os -import tomllib from collections.abc import Callable as Callable2 from collections.abc import Mapping, Sequence from dataclasses import fields, is_dataclass from types import UnionType -from typing import Any, List, Literal, Optional, TypeVar, Union, get_args, get_origin +from typing import Any, List, Optional, TypeVar, Union, get_args, get_origin import attrs import torch @@ -39,19 +38,6 @@ def from_yaml(path: str | None = None, clazz: type | None = None, file_like_or_s raise ValueError("expected file_like_or_str or path to not be None") -def from_toml(path: str | None = None, clazz: type | None = None, file_like_or_str=None) -> T: - if path: - assert os.path.exists(path), f"{path} does not exist" - with open(path, "rb") as in_f: - return from_dict(tomllib.load(in_f), clazz=clazz) - elif file_like_or_str: - if isinstance(file_like_or_str, (bytes, bytearray)): - return from_dict(tomllib.loads(file_like_or_str.decode("utf-8")), clazz=clazz) - return from_dict(tomllib.loads(file_like_or_str), clazz=clazz) - else: - raise ValueError("expected file_like_or_str or path to not be None") - - def _yaml_safe(obj: Any) -> Any: # primitives if obj is None or isinstance(obj, (bool, int, float, str)): @@ -167,7 +153,6 @@ def is_optional(x: type) -> bool: def _to_dict_value(x: T, field_type: type, metadata: dict, field_name: str = ""): - t = type(x) # attrs specific @@ -196,7 +181,6 @@ def _to_dict_value(x: T, field_type: type, metadata: dict, field_name: str = "") # general python types + dataclasses + attrs # * meta types elif field_type == type or field_type == abc.ABCMeta: - return to_qualitified_name(x) elif get_origin(field_type) is type: return to_qualitified_name(x) @@ -267,7 +251,7 @@ def to_dict(x: T, field_name: str = "", hydra_compat: bool = True) -> dict: if hydra_compat: result["_target_"] = to_qualitified_name(x.__class__) for f in fields(x): - + # NOTE: defaults are unnecessary to encode if hydra_compat and f.name == "defaults": continue result[f.name] = _to_dict_value( @@ -286,7 +270,7 @@ def to_dict(x: T, field_name: str = "", hydra_compat: bool = True) -> dict: if hydra_compat: result["_target_"] = to_qualitified_name(x.__class__) for f in attrs.fields(x.__class__): - + # NOTE: defaults are unnecessary to encode if hydra_compat and f.name == "defaults": continue result[f.name] = _to_dict_value( @@ -306,7 +290,6 @@ def _from_dict_value( force_construct_target: bool | None = None, ): - is_dc_type = is_dataclass(field_type) is_attrs_type = is_attrs(field_type) origin = get_origin(field_type) or field_type @@ -337,7 +320,7 @@ def _from_dict_value( assert not isinstance(x, str) return from_dict(x, field_type, field_name=field_name) elif field_type in (DictConfig, LazyDict) or origin in (dict,): - + # NOTE: _recursive_ is the name of the flag for this behaviour construct_target = x.get("_recursive_", field_type == DictConfig) if force_construct_target is not None: construct_target = force_construct_target @@ -403,13 +386,6 @@ def _from_dict_value( return x elif field_type is type(None) or field_type == Any: # no typing return x - elif origin is Literal: - allowed = get_args(field_type) - if x not in allowed: - raise TypeError( - f"value {x!r} not in {field_type} (allowed={allowed}, field={field_name})" - ) - return x else: raise TypeError( f"unexpected type: {field_type} (origin={origin}, concrete_type={concrete_type}, args={args}, x={x})" diff --git a/cosmos_framework/utils/training_telemetry/__init__.py b/cosmos_framework/utils/training_telemetry/__init__.py index 28a81be..1600ad5 100644 --- a/cosmos_framework/utils/training_telemetry/__init__.py +++ b/cosmos_framework/utils/training_telemetry/__init__.py @@ -1,2 +1,13 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 +# ----------------------------------------------------------------------------- +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# This codebase constitutes NVIDIA proprietary technology and is strictly +# confidential. Any unauthorized reproduction, distribution, or disclosure +# of this code, in whole or in part, outside NVIDIA is strictly prohibited +# without prior written consent. +# +# For inquiries regarding the use of this code in other NVIDIA proprietary +# projects, please contact the Deep Imagination Research Team at +# dir@exchange.nvidia.com. +# ----------------------------------------------------------------------------- diff --git a/cosmos_framework/utils/training_telemetry/callback.py b/cosmos_framework/utils/training_telemetry/callback.py index 341792f..2951f0b 100644 --- a/cosmos_framework/utils/training_telemetry/callback.py +++ b/cosmos_framework/utils/training_telemetry/callback.py @@ -197,8 +197,8 @@ def on_training_step_end( average_forward_time=avg_forward_time, average_backward_time=avg_backward_time, average_dataloader_time=avg_dataloader_time, - tflops=0.0, - tokens_per_second=0.0, + tflops=0.0, # FIXME: is this available? + tokens_per_second=0.0, # FIXME: is this available? loss=loss.item(), batch_size=batch_size, ) diff --git a/cosmos_framework/utils/training_telemetry/utils.py b/cosmos_framework/utils/training_telemetry/utils.py index 374565e..8482b84 100644 --- a/cosmos_framework/utils/training_telemetry/utils.py +++ b/cosmos_framework/utils/training_telemetry/utils.py @@ -26,7 +26,10 @@ def import_training_telemetry() -> Optional[ModuleType]: __training_telemetry_module = importlib.import_module("training_telemetry") return __training_telemetry_module except ImportError as e: - logger.error(f"Telemetry is enabled but the `training_telemetry` package is not installed: {e}") + logger.error(f"Heimdall telemetry is enabled but package is not installed: {e}") + logger.info( + "Please install the package using `pip install aidot-training-telemetry --index-url=https://urm.nvidia.com/artifactory/api/pypi/nv-shared-pypi/simple`" + ) return None @@ -36,7 +39,9 @@ def set_telemetry_provider(local_path: str) -> Optional[Any]: """ global __enable_telemetry if not __enable_telemetry: - logger.info("Training telemetry is disabled. Set ENABLE_TELEMETRY=true to enable it.") + logger.info( + "Heimdall telemetry is disabled, if using Heimdall,consider setting ENABLE_TELEMETRY=true to enable it" + ) return None global __provider @@ -46,8 +51,7 @@ def set_telemetry_provider(local_path: str) -> Optional[Any]: training_telemetry = import_training_telemetry() if training_telemetry is None: logger.error( - "Training telemetry is enabled but the `training_telemetry` package is not installed. " - "Set ENABLE_TELEMETRY=false to disable, or install the package." + "Heimdall telemetry is enabled but package is not installed, consider setting ENABLE_TELEMETRY=false to disable it, or install the package using `pip install aidot-training-telemetry --index-url=https://urm.nvidia.com/artifactory/api/pypi/nv-shared-pypi/simple`" ) __enable_telemetry = False return None diff --git a/cosmos_framework/utils/vfm/hf_attention_cosmos.py b/cosmos_framework/utils/vfm/hf_attention_cosmos.py index 3037610..a25c43e 100644 --- a/cosmos_framework/utils/vfm/hf_attention_cosmos.py +++ b/cosmos_framework/utils/vfm/hf_attention_cosmos.py @@ -1,19 +1,19 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -"""HF ``ALL_ATTENTION_FUNCTIONS`` adapter delegating to ``imaginaire.attention``. +"""HF ``ALL_ATTENTION_FUNCTIONS`` adapter delegating to ``cosmos_framework.model.attention``. Registered as the ``"cosmos"`` entry in HF's attention dispatch. -``imaginaire.attention`` owns backend selection (cuDNN / NATTEN / flash2 / +``cosmos_framework.model.attention`` owns backend selection (cuDNN / NATTEN / flash2 / flash3); to fall back to HF's own flash_attention_2 set ``policy.attn_implementation=flash_attention_2``. Layout: HF passes Q/K/V as BHSD ``[B, num_heads, N, head_dim]`` and expects -BSHD output. ``imaginaire.attention`` is BSHD throughout, so we transpose on +BSHD output. ``cosmos_framework.model.attention`` is BSHD throughout, so we transpose on entry; output layout already matches HF's expected return. Strict guards (raise rather than silently break loss parity): -- ``dropout > 0`` — ``imaginaire.attention`` has no dropout parameter. +- ``dropout > 0`` — ``cosmos_framework.model.attention`` has no dropout parameter. Qwen3-VL has ``attention_dropout=0`` so this never triggers in practice. - ``attention_mask is not None`` — adapter expects causal mask via ``is_causal=True`` (and no padding, i.e. Qwen3-VL VLM training with @@ -70,7 +70,7 @@ def hf_attention_cosmos( v = value.transpose(1, 2) # Cast fp32 -> bf16 if needed. - # imaginaire's flash2/flash3/cuDNN backends only accept fp16/bf16; NATTEN + # cosmos_framework's flash2/flash3/cuDNN backends only accept fp16/bf16; NATTEN # also accepts fp32 but routing fp32 attention loses Tensor Core # acceleration (10-20x slower). HF's flash_attention_2 internally casts # fp32->bf16 and we replicate that so this adapter is a drop-in replacement diff --git a/cosmos_framework/utils/vfm/lora.py b/cosmos_framework/utils/vfm/lora.py index 8bffe56..365a369 100644 --- a/cosmos_framework/utils/vfm/lora.py +++ b/cosmos_framework/utils/vfm/lora.py @@ -1,5 +1,17 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Custom in-place LoRA injection for MoT-style models. diff --git a/cosmos_framework/utils/vfm/model_loader.py b/cosmos_framework/utils/vfm/model_loader.py index eae2487..c180ca3 100644 --- a/cosmos_framework/utils/vfm/model_loader.py +++ b/cosmos_framework/utils/vfm/model_loader.py @@ -280,7 +280,7 @@ def load_model_from_checkpoint( * **safetensors**: a directory containing one or more ``*.safetensors`` shards in the native Cosmos3 VFM state-dict layout. Loaded via - :func:`projects.cosmos3.vfm.models.utils.safetensors_loader.load_vfm_model`. + :func:`cosmos_framework.model.vfm.utils.safetensors_loader.load_vfm_model`. No ``/model`` suffix is appended. credential_path: Path to credentials file (if required for remote storage). Optional. enable_gcs_patch_in_boto3: Whether to enable the boto3 patch for GCS S3-compatibility. diff --git a/cosmos_framework/utils/vfm/monkey_patch.py b/cosmos_framework/utils/vfm/monkey_patch.py index 1320d0b..d3e9e7a 100644 --- a/cosmos_framework/utils/vfm/monkey_patch.py +++ b/cosmos_framework/utils/vfm/monkey_patch.py @@ -13,7 +13,7 @@ from cosmos_framework.utils import log -_EXPECTED_TRANSFORMERS_VERSION_PREFIX = "4.57." +_EXPECTED_TRANSFORMERS_VERSION = "4.57.1" def patch_qwen3_vl_forward(model): @@ -28,7 +28,7 @@ def patch_qwen3_vl_forward(model): model: The ``Qwen3VLModel`` instance (i.e. ``model.model.model`` when the outer model is ``HFModel``). """ - if not transformers.__version__.startswith(_EXPECTED_TRANSFORMERS_VERSION_PREFIX): + if transformers.__version__ != _EXPECTED_TRANSFORMERS_VERSION: raise ValueError(f"monkey patching transformers version {transformers.__version__} is not supported") if not isinstance(model, Qwen3VLModel): diff --git a/cosmos_framework/utils/vfm/optimizer.py b/cosmos_framework/utils/vfm/optimizer.py index 2fd7695..b9947a6 100644 --- a/cosmos_framework/utils/vfm/optimizer.py +++ b/cosmos_framework/utils/vfm/optimizer.py @@ -44,7 +44,7 @@ def _optimizer_cls( - ``"adam"`` / ``"adamw"``: forwarded to ``torch.optim.Adam`` / ``torch.optim.AdamW``. ``fused`` (if present in ``optimizer_kwargs``) flows through and selects the fused CUDA kernel. - - ``"fusedadam"``: NVIDIA's :class:`projects.cosmos3.vfm.utils.fused_adam.FusedAdam`. + - ``"fusedadam"``: NVIDIA's :class:`cosmos_framework.utils.vfm.fused_adam.FusedAdam`. It is fused by construction and rejects a ``fused`` kwarg, so any ``fused`` entry is popped before instantiation. We also force ``capturable=True`` and ``master_weights=True`` because those are the @@ -341,14 +341,6 @@ def state_dict(self) -> dict[str, Any]: ) def load_state_dict(self, state_dict: dict[str, Any]) -> None: - # Backward compat with old VLM checkpoints that prefix every key with - # "optimizer_0/" (the legacy list-of-optimizers layout; cosmos3 only - # ever ran with N=1). Strip the prefix transparently so those - # checkpoints continue to resume. - legacy_prefix = "optimizer_0/" - if any(k.startswith(legacy_prefix) for k in state_dict): - state_dict = {k.removeprefix(legacy_prefix): v for k, v in state_dict.items()} - set_optimizer_state_dict( model=self.model, optimizers=self.optimizers, diff --git a/cosmos_framework/utils/vfm/parallelism.py b/cosmos_framework/utils/vfm/parallelism.py index e6fe795..eb9dad2 100644 --- a/cosmos_framework/utils/vfm/parallelism.py +++ b/cosmos_framework/utils/vfm/parallelism.py @@ -27,8 +27,8 @@ - VFM inference — ``dp_shard`` + cfgp/cp overlays; replicate forced to 1. FSDP wrapping for VLM ``HFModel`` instances lives in -``projects.cosmos3.vfm.models.parallelize_vlm``; MoT wrapping lives in -``projects.cosmos3.vfm.models.mot.parallelize_unified_mot``. Both consume +``cosmos_framework.model.vfm.parallelize_vlm``; MoT wrapping lives in +``cosmos_framework.model.vfm.mot.parallelize_unified_mot``. Both consume ``ParallelDims`` from this module. """ diff --git a/cosmos_framework/utils/vfm/vlm/flop_calculator.py b/cosmos_framework/utils/vfm/vlm/flop_calculator.py index cf112c8..90e8aef 100644 --- a/cosmos_framework/utils/vfm/vlm/flop_calculator.py +++ b/cosmos_framework/utils/vfm/vlm/flop_calculator.py @@ -25,7 +25,6 @@ class FlopCalculator: # estimator to underestimate per-sample work and the dynamic batcher to # pack batches too large. Keep this False until the slope and intercept # are refit against is_causal=True benchmark data. - # benchmark runs and flip _IS_CAUSAL_FOR_CALIBRATION to True so this # calculator inherits the algorithmically correct FLOP count by default. _IS_CAUSAL_FOR_CALIBRATION: bool = False diff --git a/cosmos_framework/utils/vfm/vlm/pretrained_models_downloader.py b/cosmos_framework/utils/vfm/vlm/pretrained_models_downloader.py index a3921ae..54c18cb 100644 --- a/cosmos_framework/utils/vfm/vlm/pretrained_models_downloader.py +++ b/cosmos_framework/utils/vfm/vlm/pretrained_models_downloader.py @@ -193,16 +193,11 @@ def maybe_download_hf_model_from_s3( s3_prefix: str = "cosmos_reason2/hf_models", require_s3_exists: bool = False, ) -> str: - # Short-circuit when model_name_or_path is already a local directory — no - # S3 or HF Hub fetch is needed. Prevents opening credentials/*.secret - # in OSS/local-checkpoint smoke runs that already have the model on disk. - if os.path.isdir(model_name_or_path): - return model_name_or_path exclude_list = [".safetensors"] if not include_model_weights else [] s3_prefix = os.path.join(s3_prefix, model_name_or_path) # download the model from s3 to local cache if cache_dir is None: - cache_dir = os.path.expanduser(os.getenv("IMAGINAIRE_CACHE_DIR", "~/.cache/imaginaire")) + cache_dir = os.path.expanduser(os.getenv("IMAGINAIRE_CACHE_DIR", "~/.cache/cosmos_framework")) cache_dir = os.path.join(cache_dir, s3_prefix) diff --git a/cosmos_framework/utils/vlm/compute_flops_qwen3vl.py b/cosmos_framework/utils/vlm/compute_flops_qwen3vl.py index e64bb74..afe6051 100644 --- a/cosmos_framework/utils/vlm/compute_flops_qwen3vl.py +++ b/cosmos_framework/utils/vlm/compute_flops_qwen3vl.py @@ -8,7 +8,7 @@ given the model configuration and input specifications (total tokens, visual tokens, etc.). Usage: - from cosmos_framework.utils.vlm.compute_flops_qwen3vl import compute_qwen3vl_flops + from cosmos_framework.utils.scripts.compute_qwen3vl_flops import compute_qwen3vl_flops flops = compute_qwen3vl_flops( num_text_layers=32, @@ -480,7 +480,7 @@ def compute_qwen3vl_flops( flops_breakdown["vision_encoder"] = 0 # Embedding layer FLOPs - + # NOTE: Only text tokens need embeddings. Visual tokens are already embedded by vision encoder. text_tokens = total_tokens - visual_tokens if include_embeddings: # Embedding lookup: typically counted as 0 or hidden_size operations per token diff --git a/cosmos_framework/utils/vlm/dcp_checkpointer.py b/cosmos_framework/utils/vlm/dcp_checkpointer.py index 3ec7eef..6984cae 100644 --- a/cosmos_framework/utils/vlm/dcp_checkpointer.py +++ b/cosmos_framework/utils/vlm/dcp_checkpointer.py @@ -286,7 +286,6 @@ def load( iteration = 0 - if checkpoint_path is not None: self._check_checkpoint_exists(checkpoint_path) all_state_dicts = {} @@ -385,7 +384,6 @@ def _async_with_pinned_memory(self, checkpoint_file: str, state_dict: dict[str, self.staging = True self.staging_ckpt_file = checkpoint_file - self.maybe_wait_for_staging() def maybe_wait_for_staging(self) -> None: diff --git a/cosmos_framework/utils/vlm/distributed.py b/cosmos_framework/utils/vlm/distributed.py index a4cda9c..6b2dfff 100644 --- a/cosmos_framework/utils/vlm/distributed.py +++ b/cosmos_framework/utils/vlm/distributed.py @@ -141,7 +141,7 @@ def destroy_distributed(): # grads[i] = g.to_local() # # create bucket for all grads, we can allreduce them in one go -# +# # NOTE: why we don't set DTensor as bucket view? # # This is becuase we can't be sure that the training framework # # never release grad, or clean grad by set None. # # Create temporary bucket is a more reliable solution. @@ -215,7 +215,7 @@ def gradient_norm_clipping( # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`. # We can simply reduce the DTensor to get the total norm in this tensor's process group # and then convert it to a local tensor. - + # NOTE: It has two purposes: # 1. to make sure the total norm is computed correctly when PP is used (see below) # 2. to return a reduced total_norm tensor whose .item() would return the correct value if isinstance(total_norm, DTensor): diff --git a/cosmos_framework/utils/vlm/optimizer.py b/cosmos_framework/utils/vlm/optimizer.py index ebafe7f..fb74a7f 100644 --- a/cosmos_framework/utils/vlm/optimizer.py +++ b/cosmos_framework/utils/vlm/optimizer.py @@ -42,7 +42,6 @@ class OptimizerConfig: def _optimizer_cls(params: list[nn.Parameter], optimizer_kwargs: dict[str, Any], name: str): if name.lower() == "adam": - optimizer = torch.optim.Adam(params, **optimizer_kwargs) elif name.lower() == "adamw": optimizer = torch.optim.AdamW(params, **optimizer_kwargs) diff --git a/cosmos_framework/utils/vlm/pretrained_models_downloader.py b/cosmos_framework/utils/vlm/pretrained_models_downloader.py index dbc1667..5fff313 100644 --- a/cosmos_framework/utils/vlm/pretrained_models_downloader.py +++ b/cosmos_framework/utils/vlm/pretrained_models_downloader.py @@ -160,7 +160,7 @@ def maybe_download_hf_model_from_s3( s3_prefix = os.path.join(s3_prefix, model_name_or_path) # download the model from s3 to local cache if cache_dir is None: - cache_dir = os.path.expanduser(os.getenv("IMAGINAIRE_CACHE_DIR", "~/.cache/imaginaire")) + cache_dir = os.path.expanduser(os.getenv("IMAGINAIRE_CACHE_DIR", "~/.cache/cosmos_framework")) cache_dir = os.path.join(cache_dir, s3_prefix) From b73a28d2736dcd70e640a8179729aa552260c278 Mon Sep 17 00:00:00 2001 From: yangyangt Date: Wed, 3 Jun 2026 06:33:31 -0700 Subject: [PATCH 02/11] Apply release pipeline: strip COSMOS-RELEASE-IGNORE blocks and fix leaked proprietary headers Re-run of the release pipeline with two new transforms: - Strip COSMOS-RELEASE-BEGIN-IGNORE..END-IGNORE blocks (one_logger, cuDNN, internal notes). - Replace leaked NVIDIA proprietary/confidential headers with the OpenMDW SPDX header. Co-Authored-By: Claude Opus 4.7 --- .../data/vfm/packing_iterable_dataset.py | 11 +- cosmos_framework/data/vfm/sequence_packing.py | 10 -- cosmos_framework/model/attention/backends.py | 8 -- .../model/attention/cudnn/__init__.py | 6 - .../model/attention/cudnn/cudnn_forward.py | 15 --- .../model/attention/cudnn/functions.py | 4 - .../model/attention/cudnn/meta.py | 3 - cosmos_framework/model/attention/frontend.py | 5 - cosmos_framework/model/vfm/mot/attention.py | 20 ---- .../model/vfm/mot/attention_test.py | 9 -- cosmos_framework/trainer/__init__.py | 10 -- cosmos_framework/utils/callback.py | 105 ------------------ cosmos_framework/utils/context_managers.py | 19 ---- cosmos_framework/utils/flags.py | 4 - .../utils/training_telemetry/__init__.py | 15 +-- 15 files changed, 4 insertions(+), 240 deletions(-) diff --git a/cosmos_framework/data/vfm/packing_iterable_dataset.py b/cosmos_framework/data/vfm/packing_iterable_dataset.py index 715ce6e..ac30f0d 100644 --- a/cosmos_framework/data/vfm/packing_iterable_dataset.py +++ b/cosmos_framework/data/vfm/packing_iterable_dataset.py @@ -1,12 +1,5 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, without express written consent of -# NVIDIA is strictly prohibited. -# ----------------------------------------------------------------------------- +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 """ Abstract base class for pool-based token-budget bin-packing over multiple datasets. diff --git a/cosmos_framework/data/vfm/sequence_packing.py b/cosmos_framework/data/vfm/sequence_packing.py index ae8720d..aa4e275 100644 --- a/cosmos_framework/data/vfm/sequence_packing.py +++ b/cosmos_framework/data/vfm/sequence_packing.py @@ -2246,11 +2246,6 @@ def generate_natten_metadata( over layers (nn.ModuleList). """ - # COSMOS-RELEASE-BEGIN-IGNORE - # sequence-packed tensors containing only and exactly subsequences with sizes from - # token_shapes, in the same order, and with no padding in between. - # We should either make sure this never happens, or have static checks in place. - # COSMOS-RELEASE-END-IGNORE if token_shapes is None or len(token_shapes) < 1: raise ValueError("'token_shapes' is required for 'three_way' attention.") @@ -2273,11 +2268,6 @@ def filter_shape(shape: tuple) -> tuple: return tuple(x for x in shape if x > 1) # Infer token layout rank (dimensionality) - # COSMOS-RELEASE-BEGIN-IGNORE - # compresses that dimension into size 1, which gets filtered out. To avoid failing checks - # we need to take the maximum dimensionality over the entire batch. We'll assert each token - # shape matches that dimensionality later, if NATTEN is required for this batch. - # COSMOS-RELEASE-END-IGNORE num_dims = max([len(filter_shape(token_shape)) for token_shape in token_shapes]) # Single pass: check if all layers support this dimensionality and if any need processing diff --git a/cosmos_framework/model/attention/backends.py b/cosmos_framework/model/attention/backends.py index 74f3cf0..6877a59 100644 --- a/cosmos_framework/model/attention/backends.py +++ b/cosmos_framework/model/attention/backends.py @@ -22,11 +22,6 @@ from cosmos_framework.model.attention.utils.safe_ops import log from cosmos_framework.model.attention.utils.safe_ops.functools import lru_cache -# COSMOS-RELEASE-BEGIN-IGNORE -# isort: split -from cosmos_framework.model.attention.cudnn.checks import cudnn_attention_check - -# COSMOS-RELEASE-END-IGNORE BACKEND_CHECK_MAP = { "cudnn": cudnn_attention_check, # COSMOS-RELEASE-IGNORELINE @@ -143,9 +138,6 @@ def get_backend_list(arch_tag: int) -> list[str]: ] elif arch_tag in [100, 103]: default_backends = [ - # COSMOS-RELEASE-BEGIN-IGNORE - "cudnn", - # COSMOS-RELEASE-END-IGNORE "natten", "flash2", ] diff --git a/cosmos_framework/model/attention/cudnn/__init__.py b/cosmos_framework/model/attention/cudnn/__init__.py index 3b7453f..ca3697b 100644 --- a/cosmos_framework/model/attention/cudnn/__init__.py +++ b/cosmos_framework/model/attention/cudnn/__init__.py @@ -13,12 +13,6 @@ from cosmos_framework.model.attention.utils.safe_ops import log from cosmos_framework.model.attention.utils.version import version_at_least -# COSMOS-RELEASE-BEGIN-IGNORE -# (ahassani) [11-20-2025] Banning cuDNN until reliability issues are resolved. -# Versions checked: 91300, 91400, 91500 -# (ahassani) [12-01-2025] -# 91500 ran on both GB200 and H100 SXM. -# COSMOS-RELEASE-END-IGNORE CUDNN_DISALLOWED = True CUDNN_MIN_BACKEND_VERSION = 91300 diff --git a/cosmos_framework/model/attention/cudnn/cudnn_forward.py b/cosmos_framework/model/attention/cudnn/cudnn_forward.py index 909f9a6..a68ed00 100644 --- a/cosmos_framework/model/attention/cudnn/cudnn_forward.py +++ b/cosmos_framework/model/attention/cudnn/cudnn_forward.py @@ -45,17 +45,6 @@ def get_dtype_choices(arch_tag: int) -> dict: log.debug("cuDNN Attention is not supported because compute capability is below the minimum (8.0).") return {} - # COSMOS-RELEASE-BEGIN-IGNORE - ## not seem to work. - # if arch_tag in [90, 100]: - # log.debug(f"cuDNN Attention supports FP8 for {arch_tag=}.") - # return { - # torch.float16: cudnn.data_type.HALF, - # torch.bfloat16: cudnn.data_type.BFLOAT16, - # torch.float8_e4m3fn: cudnn.data_type.FP8_E4M3, - # torch.float8_e5m2: cudnn.data_type.FP8_E5M2, - # } - # COSMOS-RELEASE-END-IGNORE log.debug(f"cuDNN Attention only supports FP16 and BF16 for {arch_tag=}.") return { @@ -331,10 +320,6 @@ def cudnn_operation(q: Tensor, k: Tensor, v: Tensor, output: Tensor, lse: Tensor stream = torch.cuda.current_stream(q.device) cudnn.set_stream(handle=handle, stream=stream.cuda_stream) - # COSMOS-RELEASE-BEGIN-IGNORE - # caching allocator plays nicely with the LRU cache over this, but for now let's avoid - # premature optimization. - # COSMOS-RELEASE-END-IGNORE workspace = torch.zeros(workspace_size_bytes, device=device, dtype=torch.uint8) # [workspace_size_bytes] variant_pack = { diff --git a/cosmos_framework/model/attention/cudnn/functions.py b/cosmos_framework/model/attention/cudnn/functions.py index 3014d74..6d06c59 100644 --- a/cosmos_framework/model/attention/cudnn/functions.py +++ b/cosmos_framework/model/attention/cudnn/functions.py @@ -55,10 +55,6 @@ def forward( padding_Q = 0 padding_KV = 0 - # COSMOS-RELEASE-BEGIN-IGNORE - # but as of 11/12/2025 does not seem to fix any issues. Keeping here in case - # it ever comes back. - # COSMOS-RELEASE-END-IGNORE if CUDNN_PADDING_REQUIRED: Q_multiplier = 256 KV_multiplier = 256 diff --git a/cosmos_framework/model/attention/cudnn/meta.py b/cosmos_framework/model/attention/cudnn/meta.py index 61d8f5b..5cd1534 100644 --- a/cosmos_framework/model/attention/cudnn/meta.py +++ b/cosmos_framework/model/attention/cudnn/meta.py @@ -30,9 +30,6 @@ def get_fwd_dtypes(arch_tag: int) -> list[torch.dtype]: log.debug("cuDNN Attention is not supported because compute capability is below the minimum (8.0).") return [] - # COSMOS-RELEASE-BEGIN-IGNORE - ## not seem to work. - # COSMOS-RELEASE-END-IGNORE log.debug(f"cuDNN Attention only supports FP16 and BF16 for {arch_tag=}.") return [torch.float16, torch.bfloat16] diff --git a/cosmos_framework/model/attention/frontend.py b/cosmos_framework/model/attention/frontend.py index f2740ee..41f917b 100644 --- a/cosmos_framework/model/attention/frontend.py +++ b/cosmos_framework/model/attention/frontend.py @@ -28,11 +28,6 @@ from cosmos_framework.model.attention.utils.environment import filter_attention_merge_backends from cosmos_framework.model.attention.utils.safe_ops import log -# COSMOS-RELEASE-BEGIN-IGNORE -# isort: split -from cosmos_framework.model.attention.cudnn import cudnn_attention - -# COSMOS-RELEASE-END-IGNORE # Map backend names to their frontend attention API BACKEND_MAP = { diff --git a/cosmos_framework/model/vfm/mot/attention.py b/cosmos_framework/model/vfm/mot/attention.py index 50fd413..d0e12d4 100644 --- a/cosmos_framework/model/vfm/mot/attention.py +++ b/cosmos_framework/model/vfm/mot/attention.py @@ -98,16 +98,6 @@ def two_way_attention( sample_offsets = packed_query_states["sample_offsets"] - # COSMOS-RELEASE-BEGIN-IGNORE - # NOTE: we can only use the don't care causal mask when we know seqlen_Q == seqlen_KV. - # Since this is a varlen use case, we would need to statically check all Q and KV offsets - # are the same. - # We don't want to launch a kernel just to perform this check and slow down our model, - # and we don't want to just assume no one is going to copy this piece of code without - # reading this, and we definitely don't want to complicate the sequence_packing code so that - # it performs a static check when creating the packed sequence and metadata, so we can just rely - # on causal_q_offsets and causal_k_offsets being the same tensor. - # COSMOS-RELEASE-END-IGNORE use_dont_care_mask = causal_q_offsets is causal_k_offsets # NOTE: cosmos_framework attention is BSHD in, BSHD out @@ -186,16 +176,6 @@ def three_way_attention( ).reshape(-1) full_v[null_positions] = 0 - # COSMOS-RELEASE-BEGIN-IGNORE - # NOTE: we can only use the don't care causal mask when we know seqlen_Q == seqlen_KV. - # Since this is a varlen use case, we would need to statically check all Q and KV offsets - # are the same. - # We don't want to launch a kernel just to perform this check and slow down our model, - # and we don't want to just assume no one is going to copy this piece of code without - # reading this, and we definitely don't want to complicate the sequence_packing code so that - # it performs a static check when creating the packed sequence and metadata, so we can just rely - # on causal_q_offsets and causal_k_offsets being the same tensor. - # COSMOS-RELEASE-END-IGNORE use_dont_care_mask = causal_q_offsets is causal_k_offsets # NOTE: cosmos_framework attention is BSHD in, BSHD out diff --git a/cosmos_framework/model/vfm/mot/attention_test.py b/cosmos_framework/model/vfm/mot/attention_test.py index d41fdff..56c2dee 100644 --- a/cosmos_framework/model/vfm/mot/attention_test.py +++ b/cosmos_framework/model/vfm/mot/attention_test.py @@ -365,15 +365,6 @@ def forward(self, *args, **kwargs): ) -# COSMOS-RELEASE-BEGIN-IGNORE -# because we need GQA support, varlen, torch.compile, and we need it across architectures. -# Flash3 + torch.compile is banned because our container build of Flash3 doesn't support it, and -# patching on our end and lack of versioning on their end makes it very difficult to check for this -# at runtime. -# Flash2 varlen introduces instability in Blackwell, and is therefore banned. -# cuDNN is banned entirely until it can pass our tests. -# NATTEN must meet the version requirements for all the features to be available. -# COSMOS-RELEASE-END-IGNORE @pytest.mark.L0 @pytest.mark.skipif(not NATTEN_SUPPORTED, reason="NATTEN is not available, or too old.") def test_two_way_attention_cmp_flex_attn(): diff --git a/cosmos_framework/trainer/__init__.py b/cosmos_framework/trainer/__init__.py index 80916ec..5a63f22 100644 --- a/cosmos_framework/trainer/__init__.py +++ b/cosmos_framework/trainer/__init__.py @@ -28,10 +28,6 @@ from cosmos_framework.utils.checkpointer import Checkpointer from cosmos_framework.utils.misc import StragglerDetectorV2 -# COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger -from cosmos_framework.utils.one_logger.one_logger_utils import initialize_one_logger_from_imaginaire_config - -# COSMOS-RELEASE-END-IGNORE class ImaginaireTrainer: @@ -113,12 +109,6 @@ def __init__(self, config): # Initialize cuDNN. torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark - # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger - # OneLogger - initialize one_logger before instantiating CallBackGroup - enable_one_logger = os.environ.get("ENABLE_ONELOGGER", "FALSE").lower() == "true" - if enable_one_logger: - initialize_one_logger_from_imaginaire_config(config) - # COSMOS-RELEASE-END-IGNORE # Initialize the callback functions. self.callbacks = callback.CallBackGroup(config=config, trainer=self) # Initialize the model checkpointer. diff --git a/cosmos_framework/utils/callback.py b/cosmos_framework/utils/callback.py index 373013d..b940692 100644 --- a/cosmos_framework/utils/callback.py +++ b/cosmos_framework/utils/callback.py @@ -19,10 +19,6 @@ from cosmos_framework.utils import distributed, log, misc, wandb_util from cosmos_framework.utils.misc import get_local_tensor_if_DTensor -# COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger -from cosmos_framework.utils.one_logger.one_logger_utils import get_one_logger - -# COSMOS-RELEASE-END-IGNORE try: from megatron.core import parallel_state @@ -604,104 +600,3 @@ def on_after_dataloading(self, iteration: int = 0) -> None: torch.cuda.nvtx.range_pop() -# COSMOS-RELEASE-BEGIN-IGNORE -class OneLoggerCallback(Callback): - """Callback for OneLogger""" - - def __init__( - self, - config: Optional["Config"] = None, - trainer: Optional["ImaginaireTrainer"] = None, - ) -> None: - super().__init__(config, trainer) - - self.one_logger = get_one_logger() - self.one_logger.on_app_start(set_barrier=False, app_start_time=round(time.time() * 1000)) - - def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - try: - batch_size = self.config.dataloader_train.batch_size - except Exception: - batch_size = 1 - if parallel_state is None or not parallel_state.is_initialized(): - data_parallel_size = 1 - else: - data_parallel_size = parallel_state.get_data_parallel_world_size() - global_batch_size = batch_size * data_parallel_size - - self.one_logger.on_train_start( - set_barrier=False, - train_iterations_start=iteration, - train_samples_start=iteration * global_batch_size, - ) - - def on_training_step_batch_start( - self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0 - ) -> None: - self.one_logger.on_train_batch_start(set_barrier=False) - - def on_optimizer_init_start(self) -> None: - self.one_logger.on_optimizer_init_start(set_barrier=False) - - def on_optimizer_init_end(self) -> None: - self.one_logger.on_optimizer_init_end(set_barrier=False) - - def on_training_step_batch_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - self.one_logger.on_train_batch_end(set_barrier=False) - - def on_validation_start( - self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 - ) -> None: - self.one_logger.on_validation_start(set_barrier=False) - - def on_validation_step_start( - self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0 - ) -> None: - self.one_logger.on_validation_batch_start(set_barrier=False) - - def on_validation_step_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - self.one_logger.on_validation_batch_end(set_barrier=False) - - def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - self.one_logger.on_validation_end(set_barrier=False) - - def on_load_checkpoint_start(self, model: ImaginaireModel) -> None: - self.one_logger.on_load_checkpoint_start(set_barrier=False) - - def on_load_checkpoint_end( - self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: Optional[str] = None - ) -> None: - self.one_logger.on_load_checkpoint_end(set_barrier=False) - - def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - self.one_logger.on_save_checkpoint_start(global_step=iteration) - - def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - self.one_logger.on_save_checkpoint_end(global_step=iteration) - - def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None: - self.one_logger.on_save_checkpoint_success(global_step=iteration) - - def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - self.one_logger.on_train_end(set_barrier=False) - - def on_app_end(self) -> None: - self.one_logger.on_app_end() - self.one_logger.finish() - - -# COSMOS-RELEASE-END-IGNORE diff --git a/cosmos_framework/utils/context_managers.py b/cosmos_framework/utils/context_managers.py index 6b19e94..794c0bb 100644 --- a/cosmos_framework/utils/context_managers.py +++ b/cosmos_framework/utils/context_managers.py @@ -8,14 +8,6 @@ from cosmos_framework.utils.misc import timer -# COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry -from cosmos_framework.utils.one_logger.one_logger_context_managers import data_loader_init as one_logger_data_loader_init -from cosmos_framework.utils.one_logger.one_logger_context_managers import model_init as one_logger_model_init -from cosmos_framework.utils.training_telemetry.context_managers import data_loader_init as telemetry_data_loader_init -from cosmos_framework.utils.training_telemetry.context_managers import distributed_init as telemetry_distributed_init -from cosmos_framework.utils.training_telemetry.context_managers import model_init as telemetry_model_init - -# COSMOS-RELEASE-END-IGNORE @contextmanager @@ -44,10 +36,6 @@ def data_loader_init() -> Generator[None, None, None]: Wrap the data loader initialization with multiple context managers used for telemetry and one logger. """ contexts = [ - # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry - one_logger_data_loader_init(), - telemetry_data_loader_init(), - # COSMOS-RELEASE-END-IGNORE timer("init_data_loader"), ] with ExitStack() as stack: @@ -60,10 +48,6 @@ def model_init(set_barrier: bool = False) -> Generator[None, None, None]: Wrap the instantiation of the model with multiple context managers used for telemetry and one logger. """ contexts = [ - # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry - one_logger_model_init(set_barrier=set_barrier), - telemetry_model_init(), - # COSMOS-RELEASE-END-IGNORE timer("init_model"), ] with ExitStack() as stack: @@ -76,9 +60,6 @@ def distributed_init() -> Generator[None, None, None]: Wrap the distributed initialization, used for telemetry and timers """ contexts = [ - # COSMOS-RELEASE-BEGIN-IGNORE: remove one_logger and training_telemetry - telemetry_distributed_init(), - # COSMOS-RELEASE-END-IGNORE timer("init_distributed"), ] with ExitStack() as stack: diff --git a/cosmos_framework/utils/flags.py b/cosmos_framework/utils/flags.py index 91cd3e6..581b094 100644 --- a/cosmos_framework/utils/flags.py +++ b/cosmos_framework/utils/flags.py @@ -68,10 +68,6 @@ class Device(StrEnum): EXPERIMENTAL_CHECKPOINTS: Final[bool] = _get_bool("COSMOS_EXPERIMENTAL_CHECKPOINTS", INTERNAL) """Whether to enable experimental checkpoints.""" -# COSMOS-RELEASE-BEGIN-IGNORE -ENABLE_PI_CHECKPOINTS: Final[bool] = _get_bool("COSMOS_ENABLE_PI_CHECKPOINTS", False) -"""Whether to enable checkpoints from NVIDIA-DIR/Cosmos-Predict2.5-2B-PI-Private.""" -# COSMOS-RELEASE-END-IGNORE if INTERNAL: TRAINING = True diff --git a/cosmos_framework/utils/training_telemetry/__init__.py b/cosmos_framework/utils/training_telemetry/__init__.py index 1600ad5..28a81be 100644 --- a/cosmos_framework/utils/training_telemetry/__init__.py +++ b/cosmos_framework/utils/training_telemetry/__init__.py @@ -1,13 +1,2 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. -# -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 From d56702e3a32a0000383de861ebfc2da80ec965b9 Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 00:39:52 -0700 Subject: [PATCH 03/11] Strip dataloader_weighted_url IGNORE block from vlm config Re-run of the release pipeline: source added COSMOS-RELEASE-BEGIN-IGNORE markers around the dataloader_weighted_url imports and registration calls, which the pipeline now removes. Co-Authored-By: Claude Opus 4.7 --- cosmos_framework/configs/base/vlm/config.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cosmos_framework/configs/base/vlm/config.py b/cosmos_framework/configs/base/vlm/config.py index c5a087d..7354d3e 100644 --- a/cosmos_framework/configs/base/vlm/config.py +++ b/cosmos_framework/configs/base/vlm/config.py @@ -8,11 +8,6 @@ from cosmos_framework.configs.base.vlm.defaults.callbacks import register_callbacks from cosmos_framework.configs.base.vlm.defaults.config import Config from cosmos_framework.configs.base.vlm.defaults.dataloader import register_data_debug -from cosmos_framework.configs.base.vlm.defaults.dataloader_weighted_url import ( - register_data_recipe, - register_data_weighted_url, - register_data_weighted_url_with_text, -) from cosmos_framework.configs.base.vlm.defaults.model import register_model from cosmos_framework.configs.base.vlm.defaults.optimizer import register_optimizer, register_scheduler from cosmos_framework.configs.base.vlm.defaults.vlm_policy import register_vlm_policy @@ -47,10 +42,6 @@ def make_config() -> Config: register_model() register_vlm_policy() # Register dataloader configs - register_data_weighted_url() - register_data_recipe() - register_data_weighted_url_with_text() - register_data_debug() log.info("Registering optimizer, scheduler, checkpoint, ckpt type, and callbacks") register_optimizer() register_scheduler() From 5e9f05e1e70627a518f6011cdad505495be986c6 Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 02:00:13 -0700 Subject: [PATCH 04/11] Release: ship webdataset image augmentors; strip dataloading_monitor + data_registration deps - New: cosmos_framework/data/imaginaire/webdataset/augmentors/image/ (6 modules + flip.py + __init__). - configs/base/defaults/callbacks.py: new COSMOS-RELEASE-IGNORE block strips the dataloading_monitor import + DetailedDataLoadingSpeedMonitor callback, removing the dataloading_monitor / webdataset.utils.stream dependency. - data/vfm/augmentors/text_transforms_for_image.py: inlines _CAPTION_EMBEDDING_KEY_MAPPING_IMAGES, removing the data_sources.data_registration import. Co-Authored-By: Claude Opus 4.7 --- .../configs/base/defaults/callbacks.py | 5 - .../webdataset/augmentors/image/__init__.py | 2 + .../webdataset/augmentors/image/cropping.py | 150 +++++++++++++++ .../webdataset/augmentors/image/flip.py | 32 ++++ .../webdataset/augmentors/image/misc.py | 51 +++++ .../webdataset/augmentors/image/normalize.py | 36 ++++ .../webdataset/augmentors/image/padding.py | 60 ++++++ .../webdataset/augmentors/image/resize.py | 175 ++++++++++++++++++ .../augmentors/text_transforms_for_image.py | 6 +- 9 files changed, 511 insertions(+), 6 deletions(-) create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py create mode 100644 cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py diff --git a/cosmos_framework/configs/base/defaults/callbacks.py b/cosmos_framework/configs/base/defaults/callbacks.py index 602646d..805ffb5 100644 --- a/cosmos_framework/configs/base/defaults/callbacks.py +++ b/cosmos_framework/configs/base/defaults/callbacks.py @@ -10,7 +10,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer -from cosmos_framework.callbacks.dataloading_monitor import DetailedDataLoadingSpeedMonitor from cosmos_framework.callbacks.device_monitor import DeviceMonitor from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap @@ -49,10 +48,6 @@ param_count=L(ParamCount)( # use model save_s3="${upload_reproducible_setup}", ), - dataloader_speed=L(DetailedDataLoadingSpeedMonitor)( - every_n=100, - save_s3="${upload_reproducible_setup}", - ), wandb_val=L(WandBCallbackEval)( save_s3="${upload_reproducible_setup}", ), diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py new file mode 100644 index 0000000..28a81be --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py new file mode 100644 index 0000000..b34cb81 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/cropping.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F +from loguru import logger as logging + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class CenterCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs center crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + assert (self.args is not None) and ("size" in self.args), "Please specify size in args" + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + + # We also add the aug params we use. This will be useful for other transforms + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + return data_dict + + +class BottomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Crops rows from the bottom of the image/video to reach ``target_height``. + + The top of the frame is preserved (content is top-anchored). Width is unchanged. + Works for 3-D ``[C, H, W]`` images and 4-D ``[C, T, H, W]`` or ``[T, C, H, W]`` + videos — the last two dims are always treated as (H, W). + + Args: + data_dict (dict): Input data dict. ``self.args["target_height"]`` is the + desired output height. Source height must be ``>= target_height``. + + Returns: + data_dict (dict): Output dict where images are bottom-cropped and + ``image_size`` is updated to ``[target_h, w, orig_h, orig_w]`` to mirror + :class:`ReflectionPadding`'s contract. + """ + assert (self.args is not None) and ("target_height" in self.args), "Please specify target_height in args" + if self.output_keys is None: + self.output_keys = self.input_keys + + target_h = int(self.args["target_height"]) + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + assert orig_h >= target_h, ( + f"BottomCrop requires source height >= target_height: got orig_h={orig_h}, target_h={target_h}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + tensor = data_dict[inp_key] + # Slice the last 2 dims; the second-to-last dim is height regardless of + # whether the tensor is CHW, CTHW, or TCHW. + data_dict[out_key] = tensor[..., :target_h, :] + + if out_key != inp_key: + del data_dict[inp_key] + + data_dict["image_size"] = torch.tensor([target_h, orig_w, orig_h, orig_w], dtype=torch.float) + + return data_dict + + +class RandomCrop(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs random crop. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + We also save the cropping parameters in the aug_params dict + so that it will be used by other transforms. + """ + + img_size = obtain_augmentation_size(data_dict, self.args) + width, height = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + # Obtaining random crop coords + try: + crop_x0 = int(torch.randint(0, orig_w - width + 1, size=(1,)).item()) + crop_y0 = int(torch.randint(0, orig_h - height + 1, size=(1,)).item()) + except Exception as e: + logging.warning( + f"Random crop failed. Performing center crop, original_size(wxh): {orig_w}x{orig_h}, random_size(wxh): {width}x{height}" + ) + for key in self.input_keys: + data_dict[key] = transforms_F.center_crop(data_dict[key], [height, width]) + crop_x0 = (orig_w - width) // 2 + crop_y0 = (orig_h - height) // 2 + + # We also add the aug params we use. This will be useful for other transforms + cropping_params = { + "resize_w": orig_w, + "resize_h": orig_h, + "crop_x0": crop_x0, + "crop_y0": crop_y0, + "crop_w": width, + "crop_h": height, + } + + if "aug_params" not in data_dict: + data_dict["aug_params"] = dict() + + data_dict["aug_params"]["cropping"] = cropping_params + + # We must perform same random cropping for all input keys + for key in self.input_keys: + data_dict[key] = transforms_F.crop(data_dict[key], crop_y0, crop_x0, height, width) + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py new file mode 100644 index 0000000..8f0bb7d --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/flip.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor + + +class HorizontalFlip(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs horizontal flipping. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + flip_enabled = getattr(self.args, "enabled", True) + if flip_enabled: + p = getattr(self.args, "prob", 0.5) + coin_flip = torch.rand(1).item() > p + for key in self.input_keys: + if coin_flip: + data_dict[key] = transforms_F.hflip(data_dict[key]) + + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py new file mode 100644 index 0000000..d3e5216 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/misc.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Union + +import torch +from PIL import Image + + +def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: + r"""Function for obtaining the image size from the data dict. + + Args: + data_dict (dict): Input data dict + input_keys (list): List of input keys + Returns: + width (int): Width of the input image + height (int): Height of the input image + """ + + data1 = data_dict[input_keys[0]] + if isinstance(data1, Image.Image): + width, height = data1.size + elif isinstance(data1, torch.Tensor): + height, width = data1.size()[-2:] + else: + raise ValueError("data to random crop should be PIL Image or tensor") + + return width, height + + +def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]: + r"""Function for obtaining size of the augmentation. + When dealing with multi-aspect ratio dataloaders, we need to + find the augmentation size from the aspect ratio of the data. + If data_dict contains "_res_size_map" (e.g. from resolution sampling), + that map is used instead of augmentor_cfg["size"]. + + Args: + data_dict (dict): Input data dict + augmentor_cfg (dict): Augmentor config + Returns: + aug_size (int): Size of augmentation + """ + if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: + aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] + else: # Non-webdataset format + aspect_ratio = data_dict["aspect_ratio"] + if "_res_size_map" in data_dict: + return data_dict["_res_size_map"][aspect_ratio] + return augmentor_cfg["size"][aspect_ratio] diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py new file mode 100644 index 0000000..a949230 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/normalize.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor + + +class Normalize(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs data normalization. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + assert self.args is not None, "Please specify args" + + mean = self.args["mean"] + std = self.args["std"] + + for key in self.input_keys: + if isinstance(data_dict[key], torch.Tensor): + data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) + else: + data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() + + data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py new file mode 100644 index 0000000..e14d66f --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/padding.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import omegaconf +import torch +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class ReflectionPadding(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs reflection padding. This function also returns a padding mask. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are center cropped. + """ + + assert self.args is not None, "Please specify args in augmentation" + if self.output_keys is None: + self.output_keys = self.input_keys + + # Obtain image and augmentation sizes + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + target_size = obtain_augmentation_size(data_dict, self.args) + + assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" + target_w, target_h = target_size + + target_w = int(target_w) + target_h = int(target_h) + + # One-sided padding (bottom and right only, content stays at top-left) + padding_right = target_w - orig_w + padding_bottom = target_h - orig_h + padding_vals = [0, 0, padding_right, padding_bottom] + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: + # In this case, we can't perform reflection padding. This is because padding values + # are larger than the image size. So, perform edge padding instead. + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") + else: + # Perform reflection padding + data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") + + if out_key != inp_key: + del data_dict[inp_key] + + data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) + + return data_dict diff --git a/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py new file mode 100644 index 0000000..82cdea9 --- /dev/null +++ b/cosmos_framework/data/imaginaire/webdataset/augmentors/image/resize.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from typing import Optional + +import omegaconf +import torchvision.transforms.functional as transforms_F + +from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor +from cosmos_framework.data.imaginaire.webdataset.augmentors.image.misc import obtain_augmentation_size, obtain_image_size + + +class ResizeSmallestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to smaller side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=out_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSide(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs resizing to larger side + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + out_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(out_size, int), "Arg size in resize should be an integer" + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + + scaling_ratio = min(out_size / orig_w, out_size / orig_h) + target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] + + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeSmallestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the smaller ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the smaller of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(img_size, (tuple, omegaconf.listconfig.ListConfig)), ( + f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + ) + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert target_size[0] >= img_h and target_size[1] >= img_w, ( + f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=( + self.args["interpolation"] + if "interpolation" in self.args + else transforms_F.InterpolationMode.BICUBIC + ), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict + + +class ResizeLargestSideAspectPreserving(Augmentor): + def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: + super().__init__(input_keys, output_keys, args) + + def __call__(self, data_dict: dict) -> dict: + r"""Performs aspect-ratio preserving resizing. + Image is resized to the dimension which has the larger ratio of (size / target_size). + First we compute (w_img / w_target) and (h_img / h_target) and resize the image + to the dimension that has the larger of these ratios. + + Args: + data_dict (dict): Input data dict + Returns: + data_dict (dict): Output dict where images are resized + """ + + if self.output_keys is None: + self.output_keys = self.input_keys + assert self.args is not None, "Please specify args in augmentations" + + img_size = obtain_augmentation_size(data_dict, self.args) + assert isinstance(img_size, (tuple, omegaconf.listconfig.ListConfig)), ( + f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" + ) + img_w, img_h = img_size + + orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) + scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) + target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) + + assert target_size[0] <= img_h and target_size[1] <= img_w, ( + f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" + ) + + for inp_key, out_key in zip(self.input_keys, self.output_keys): + data_dict[out_key] = transforms_F.resize( + data_dict[inp_key], + size=target_size, # type: ignore + interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), + antialias=True, + ) + + if out_key != inp_key: + del data_dict[inp_key] + return data_dict diff --git a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py index edaef11..d38fae4 100644 --- a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py +++ b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py @@ -8,13 +8,17 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.v3_text_transforms import pad_and_resize from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.data_registration import _CAPTION_EMBEDDING_KEY_MAPPING_IMAGES # For the qwen captions, we have 3 variants: short, medium, long # In addition, for synthetic data, we create prompt embeddings as well. # There is quite a bit of entropy in the way prompt data is saved. # Captions are saved as "prompts", while the corresponding embeddings are saved as "original_prompt" # This part will be cleaned after synthetic data is cleaned to be in the same format as real data. +_CAPTION_EMBEDDING_KEY_MAPPING_IMAGES = { + "ai_v3p1": "ai_v3p1", + "qwen2p5_7b_v4": "qwen2p5_7b_v4", + "prompts": "qwen2p5_7b_v4", +} _AVAILABLE_QWEN_CAPTIONS = ["qwen2p5_7b_short", "qwen2p5_7b_medium", "qwen2p5_7b_long"] _AVAILABLE_QWEN3_30B_A3B_CAPTIONS = [ "qwen3_30b_a3b_short", From 4905b32e8037a8984c1a86a05e455672d89813be Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 02:17:03 -0700 Subject: [PATCH 05/11] Release: strip dataloading_monitor wiring (vlm callbacks); drop orphan nvlm pair - configs/base/vlm/defaults/callbacks.py: new COSMOS-RELEASE-IGNORE block removes the dataloading_monitor import + speed-monitor callback. - Drop nvlm_data_unify.py and nvlm_sample_loaders_and_part_filters.py: unreachable in the release (no CF code imports nvlm_data_unify), now excluded from the mapping. Co-Authored-By: Claude Opus 4.7 --- .../configs/base/vlm/defaults/callbacks.py | 5 - .../vfm/augmentors/vlm/nvlm_data_unify.py | 120 - .../nvlm_sample_loaders_and_part_filters.py | 2719 ----------------- 3 files changed, 2844 deletions(-) delete mode 100644 cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py delete mode 100644 cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py diff --git a/cosmos_framework/configs/base/vlm/defaults/callbacks.py b/cosmos_framework/configs/base/vlm/defaults/callbacks.py index f742972..3392205 100644 --- a/cosmos_framework/configs/base/vlm/defaults/callbacks.py +++ b/cosmos_framework/configs/base/vlm/defaults/callbacks.py @@ -12,7 +12,6 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback -from cosmos_framework.callbacks.dataloading_monitor import DetailedDataLoadingSpeedMonitor from cosmos_framework.callbacks.grad_clip import GradClip from cosmos_framework.callbacks.hf_export import HFExportCallback from cosmos_framework.callbacks.iter_speed import IterSpeed @@ -40,10 +39,6 @@ def register_callbacks(): param_count=L(ParamCount)( # use model save_s3="${upload_reproducible_setup}", ), - dataloader_speed=L(DetailedDataLoadingSpeedMonitor)( - every_n=100, - save_s3="${upload_reproducible_setup}", - ), grad_clip=L(GradClip)(clip_norm=1.0, force_finite=False), # use model learning_rate_logger=L(LearningRateLogger)(every_n=10), low_precision=L(LowPrecisionCallback)( diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py deleted file mode 100644 index c9ee9da..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_data_unify.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""Visual-Text Transformations or Augmentations.""" - -import io -from typing import Dict, Optional - -from PIL import Image - -from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.augmentors.vlm.nvlm_sample_loaders_and_part_filters import ( - get_data_class, - get_part_filter, - get_sample_loader, -) - - -class NVLMImageDataUnify(Augmentor): - """ - This augmentor is used to unify the data format of the nvlm data. - It will take the raw nvlm data tar and convert it to a dictionary with the following keys: - { - "__url__": str, - "__key__": str, - "data_class": str, - "images": List[PIL.Image.Image], - "text": str, - "words_boxes": Optional[List[List[int]]], - "words_text": Optional[List[str]], - "similarity_matrix": Optional[List[List[float]]], - } - """ - - def __init__( - self, - input_keys: list = ["raw_nvlm"], - output_keys: Optional[list] = [], - args: Optional[dict] = None, - data_path_prefix: list[str] = [ - "cosmos/ar/v2/nvlm/", - ], # prefix of the data in s3 - ) -> None: - super().__init__(input_keys, output_keys, args) - self.data_path_prefix = data_path_prefix - - def convert_image(self, img): - try: - if isinstance(img, bytes): - img = Image.open(io.BytesIO(img)).convert("RGB") - elif isinstance(img, Image.Image): - img = img.convert("RGB") - pass # Image is already in PIL format - elif isinstance(img, list): - for i in range(len(img)): - img[i], success = self.convert_image(img[i]) - if not success: - return Image.new("RGB", (256, 256), (0, 0, 0)), False - return img, True - else: - raise ValueError(f"Invalid image type: {type(img)}") - - success = True - except Exception as e: - log.warning(f"Error processing image: {e}. Creating an empty black image.", rank0_only=False) - img = Image.new("RGB", (256, 256), (0, 0, 0)) # Creates a 256x256 black image - success = False - return img, success - - def __call__(self, data_dict: Dict) -> Dict: - url = data_dict["__url__"] - data_path = "/".join(url.path.split("/")[:-1]) # remove the last part of the path - sample_loader = get_sample_loader(data_path) - part_filter = get_part_filter(data_path) - data_class = get_data_class(data_path) - assert sample_loader is not None and part_filter is not None and data_class is not None, ( - f"sample_loader({sample_loader}) or part_filter({part_filter}) or data_class({data_class}) is not found for {data_path}" - ) - - raw = {"__url__": url, "__key__": data_dict["__key__"]} - output = {"__url__": url, "__key__": data_dict["__key__"]} - for k, v in data_dict.items(): - ext = k.split(".")[-1] - if part_filter(ext): - raw[ext] = v - try: - output_converted = sample_loader(raw) - # Here output_converted will be a dictionary with the following keys: - # { - # "__key__": str, - # "image": PIL.Image.Image, - # "images": List[PIL.Image.Image], - # "text": str, - # "words_boxes": Optional - # "words_text": Optional - # "similarity_matrix": Optional - # } - except Exception as e: - log.warning( - f"Error in sample_loader: {e}, sample_loader: {sample_loader}, data_path: {data_path}, raw: {raw.keys()}, original_data_dict: {data_dict.keys()}, __url__: {url}, __key__: {data_dict['__key__']}" - ) - return None - - output.update(output_converted) - if "image" not in output_converted and "images" not in output_converted: - success = False - log.warning(f"image not found in {output_converted.keys()}") - if "image" in output_converted: # Single image case - img, success = self.convert_image(output["image"]) - output["images"] = [img] # What should be the format for the iamges - elif "images" in output_converted: - output["images"] = output_converted["images"] - output["images"], success = self.convert_image(output["images"]) - if not success: - log.warning(f"image conversion failed for {data_dict['__key__']} url: {url} | Skip this data") - return None - output["data_class"] = data_class - - return output diff --git a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py b/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py deleted file mode 100644 index 88f93bf..0000000 --- a/cosmos_framework/data/vfm/augmentors/vlm/nvlm_sample_loaders_and_part_filters.py +++ /dev/null @@ -1,2719 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -# Combined Sample Loaders -# Auto-generated script combining all sample_loader.py files (Dont edit this file! Edit the projects/cosmos/reasoning/v1/scripts/create_sample_loader_and_part_filter_file.py instead) - -import io - -import torch -from PIL import Image - -from cosmos_framework.utils import log -from cosmos_framework.data.vfm.data_sources.vlm.nvlm import data_path_mapping - -# This file was automatically generated by `nvgpt4 data prepare`. -# import torch - - -def sample_loader_0(raw: dict) -> dict: # Note: Images are already decoded to tensors - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_0(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/merged/CoT -# This file was automatically generated by `energon prepare`. - - -def sample_loader_1(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_1(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/merged/single-choice -# This file was automatically generated by `energon prepare`. - - -def sample_loader_2(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_2(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/extended-sci-3/CoT -# This file was automatically generated by `energon prepare`. - - -def sample_loader_3(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_3(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/users/zhuoliny/extended-sci/data/extended-sci-3/single-choice -# This file was automatically generated by `energon prepare`. - - -def sample_loader_4(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_4(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/SceMQA_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_5(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_5(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/vqa_collection_doc_text_st_chart_scale_textbook_LRV_Screen -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_6(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - key = raw["__key__"] - if "docvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = json_item["answer_weights"] - elif "textvqa" in key or "lrv_instruct" in key: - context = json_item["question"] - answers = json_item["answer"] - image = raw["jpg"] - answer_weights = None - elif "stvqa" in key: - context = json_item["question"] - answers = json_item["answers"] - image = raw["jpg"] - answer_weights = [1.0] * len(json_item["answers"]) - elif "chartqa" in key: - context = json_item["query"] - answers = json_item["label"] - image = raw["png"] - answer_weights = None - elif "screenqa" in key: - image = raw["jpg"] - context = json_item["question"] - answers = json_item["ground_truth"] - answer_weights = [1.0] * len(json_item["ground_truth"]) - elif "HME100K" in key: - image = raw["jpg"] - context = "Please write out the expression of the formula in the image using LaTeX format." - answers = json_item["latex_formula"] - answer_weights = None - else: # scale, textbook - image = raw["jpg"] - context = json_item["question"] - answers = json_item["answer"] - answer_weights = None - - return dict( - __key__=key, - image=image, - context=context, - answers=answers, - answer_weights=answer_weights, - ) - - -def part_filter_6(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "png") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/plotqa/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_7(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question_string"], # expected type: str - answers=j["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_7(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/clevr-math/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_8(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=str(j["answer"]), # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_8(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/MMC-Instruction/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_9(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"].strip(), # expected type: str - answers=j["gt_answer"].strip(), # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_9(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ocrvqa/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_10(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_10(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/dude/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_11(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=j["question"], - answers=j["answer"], - answer_weights=None, - ) - - -def part_filter_11(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/VisualMRC/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_12(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=j["answer"], # expected type: typing.Optional[typing.List[str]], default: None - answer_weights=None, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_12(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/mcvqa_collection_scienceqa_ai2d_geoqaplus_geometry3k_tqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_13(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - key = raw["__key__"] - - if "geoqa_plus" in key or "tqa" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=json_item["correct_answer_index"], - ) - elif "geometry3k" in key: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=ord(json_item["answer"].lower()) - 97, - ) - else: # science_qa, ai2d - image_key = "png" if "png" in raw else "jpg" - if image_key not in raw: - log.warning(f"Image key {image_key} not found in with raw keys: {raw.keys()}") - return dict( - __key__=raw["__key__"], # science_qa_sample_{idx} - image=raw[image_key], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["choices"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def part_filter_13(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/arxiv_qa/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_14(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - return dict( - __key__=raw["__key__"], # arxiv_qa_sample_{idx} - image=raw["jpg"], # expected type: torch.Tensor - context=json_item["question"], # expected type: str - choices=json_item["options"], # expected type: typing.Union[typing.List[str], NoneType], default: None - correct_choice_idx=json_item["correct_choice_index"], - ) - - -def part_filter_14(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/tabmwp/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_15(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - if json_item["question_type"] == "multi_choice": - correct_choice_idx = json_item["choices"].index(json_item["answer"]) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=json_item["choices"], - correct_choice_idx=correct_choice_idx, - ) - else: - # A temporary hack for non multi-choice samples. - # If correct_choice_idx=-1, we should route it to the VQAWebdataset dataloading method. - # (74.7% free-text questions, 25.3% multi-choice questions) - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=json_item["question"], - choices=[json_item["answer"]], - correct_choice_idx=-1, - ) - - -def part_filter_15(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ocr_vqa_aug/processed/ -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_16(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_16(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/dvqa_full/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_17(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_17(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/LLaVA-v1.5_shuffle/no_refcoco_vg_ocrvqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_18(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_18(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/more_data/infographics_vqa/processed/train/ -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_19(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_19(part: str) -> bool: - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/sharegpt4o/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_20(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_20(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/sparse_ocr_data/merged -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_21(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["png"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_21(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "png") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/nayeonl/data/blendv4/MetaMathQA/processed/train_text_image -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_22(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_22(part: str) -> bool: - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/nayeonl/data/blendv4/gsm8k/processed/train_text_image -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_23(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, # expected type: torch.Tensor - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_23(part: str) -> bool: - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/docmatix/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_24(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_24(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/bentham_hw_squad/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_25(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_25(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/WikiTableQA/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_26(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__="llava-{}".format(raw["__key__"]), images=[raw["jpg"]], texts=j["conversations"], similarity_matrix=None - ) - - -def part_filter_26(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/figureqa/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_27(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_27(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ai2d_combined_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_28(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_28(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/math_combined_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_29(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_29(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/robut_combined_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_30(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_30(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/llavar_20k_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_31(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_31(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/tallyqa_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_32(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_32(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ureader_ie_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_33(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_33(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/visual7w_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_34(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_34(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/mavis_math_rule_geo_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_35(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_35(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ureader_kg_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_36(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_36(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/ureader_qa_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_37(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_37(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ocr_multi_collection_cocotext_textocr_ReCTs -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_38(raw: dict) -> dict: - j = raw["json"] - - if "ReCTs" in raw["__key__"]: - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - else: # coco-text-multi, textocr-multi - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_38(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/pdfa-eng-wds/processed_word_len_500 -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_39(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_39(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/super_clevr_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_40(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_40(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/icon_qa_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_41(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_41(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/chartqa_aug -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_42(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_42(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_chartqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_43(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_43(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_docvqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_44(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_44(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/docvqa_text -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_45(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_45(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/textvqa_text -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_46(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_46(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/i2s-musicsheet -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_47(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_47(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/music -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_48(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_48(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/invoice -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_49(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_49(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/k12 -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_50(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_50(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/MTVQA -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_51(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - for i, turn in enumerate(json_item["conversations"]): - if i > 0 and turn["from"] == "human" and "" in turn["value"]: - turn["value"] = turn["value"].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_51(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/VisualWebInstruct -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_52(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_52(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/financeqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_53(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - # for i, turn in enumerate(json_item['conversations']): - # if i > 0 and turn['from'] == 'human' and '' in turn['value']: - # turn['value'] = turn['value'].replace("\n", "") - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_53(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/docreason -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_54(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_54(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_mtwi -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_55(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_55(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/geos_gpt -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_56(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_56(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/cauldron_vistext -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_57(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_57(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/memes -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_58(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_58(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_roadtext -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_59(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_59(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/indoor_qa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_60(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_60(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/colpali -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_61(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_61(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/pmc_vqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_62(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_62(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/pathvqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_63(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_63(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/sciqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_64(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_64(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/chinese_meme -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_65(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_65(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_hiertext -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_66(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_66(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/cauldron_cocoqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_67(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_67(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("img", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/cmm-math/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_68(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_68(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/mmtab/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_69(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_69(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/simchart9k/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_70(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_70(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/mapqa_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_71(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_71(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/llava-onevision/vizwiz_processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_72(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - images = [raw["jpg"]] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_72(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/gpt_infovqa -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_73(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_73(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/augmentations/viquae -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_74(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - img = Image.open(io.BytesIO(raw["img"])) - images = [img] - - return dict( - __key__="llava-{}".format(raw["__key__"]), - images=images, - texts=json_item["conversations"], - similarity_matrix=None, - ) - - -def part_filter_74(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "img") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/ccs_recaptioned/webdataset -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_75(raw: dict) -> dict: # Note: Images are already decoded to tensors - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_75(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/laion115m-clean -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_76(raw: dict) -> dict: # Note: Images are already decoded to tensors - if "text" in raw: - caption = raw["text"] - else: - caption = raw["json"]["caption"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - caption=caption, # expected type: str - ) - - -def part_filter_76(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg", "text") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/dvqa_full/processed_pt -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_77(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_77(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/docmatix/processed_pt -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_78(raw: dict) -> dict: # Note: Images are already decoded to tensors - json_item = raw["json"] - - total = len(json_item["conversations"]) // 2 - idx = random.randrange(total) # noqa: F821 - human = json_item["conversations"][idx * 2] - out = json_item["conversations"][idx * 2 + 1] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - context=human["value"].replace("\n", ""), - answers=out["value"], - answer_weights=None, - ) - - -def part_filter_78(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/VQAv2/stage1 - - -def sample_loader_79(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_79(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/Visual_Genome -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_80(raw: dict) -> dict: # Note: Images are already decoded to tensors - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=raw["json"]["question"], # expected type: str - answers=raw["json"]["answer"], # expected type: typing.Union[typing.List[str], NoneType], default: None - answer_weights=None, # expected type: typing.Union[torch.Tensor, NoneType], default: None - ) - - -def part_filter_80(part: str) -> bool: - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/pdfa-eng-wds/processed_word_len_300 -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_81(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=" ".join(j["lines"]["text"]), # expected type: str - ) - - -def part_filter_81(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/textocr/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_82(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_82(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/coco-text/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_83(raw: dict) -> dict: - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_83(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ArT/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_84(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_84(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ReCTs/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_85(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict(__key__=raw["__key__"], image=raw["jpg"], text=j["text"], words_boxes=j["quad_1k_normalized"]) - - -def part_filter_85(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/lsvt/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_86(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text=j["text"], - words_boxes=j["bbox_1k_normalized"], - ) - - -def part_filter_86(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/RCTW/processed -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_87(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - - quad = j["quad"] - quad = [val for point in quad for val in point] - - return dict( - image=raw["jpg"], # expected type: torch.Tensor - text=j["text"], # expected type: str - words_boxes=quad, # expected type: typing.Optional[torch.Tensor], default: None - ) - - -def part_filter_87(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/coco-text/processed_multi -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_88(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_88(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/textocr/processed_multi -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_89(raw: dict) -> dict: - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["bboxes_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_89(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("jpg", "json") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/nvlm/wdai/data/ReCTs/processed_multi -# This file was automatically generated by `nvgpt4 data prepare`. - - -def sample_loader_90(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - return dict( - __key__=raw["__key__"], - image=raw["jpg"], - text="", - words_boxes=j["quads_1k_normalized"], - words_text=j["texts"], - ) - - -def part_filter_90(part: str) -> bool: - # E.g. if your dataset contains jpeg, txt and json, but you won't use json, - # remove it from the list, such that it is not decoded. If you need all, keep as is - return part in ("json", "jpg") - - -# Loader for: /lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/vqa/VQAv2/stage1 - - -def sample_loader_91(raw: dict) -> dict: # Note: Images are already decoded to tensors - j = raw["json"] - if "answer" in j: - answers = [a[0] for a in j["answer"][0]] - answer_weights = torch.Tensor([float(a[1]) for a in j["answer"][0]]) - else: - answers = None - answer_weights = None - - return dict( - __key__=raw["__key__"], - image=raw["jpg"], # expected type: torch.Tensor - context=j["question"], # expected type: str - answers=answers, # expected type: typing.List[str] - answer_weights=answer_weights, # expected type: typing.Union[torch.Tensor, NoneType] - ) - - -def part_filter_91(part: str) -> bool: - # Filter for parts required by the sample_loader - return part in ("jpg", "json") - - -# Dataset -> Sample Loader Mapping -dataset_loader_mapping = { - "coco_train_val_restval": { - "sample_loader": "sample_loader_0", - "part_filter": "part_filter_0", - "data_class": "CaptioningWebdataset", - "data_weight": 0.01, - }, - "extended-sci/data/merged/CoT": { - "sample_loader": "sample_loader_1", - "part_filter": "part_filter_1", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "extended-sci/data/merged/single-choice": { - "sample_loader": "sample_loader_2", - "part_filter": "part_filter_2", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "extended-sci/data/extended-sci-3/CoT": { - "sample_loader": "sample_loader_3", - "part_filter": "part_filter_3", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - "extended-sci/data/extended-sci-3/single-choice": { - "sample_loader": "sample_loader_4", - "part_filter": "part_filter_4", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/SceMQA_processed": { - "sample_loader": "sample_loader_5", - "part_filter": "part_filter_5", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0006, - }, - "nvlm/wdai/data/vqa_collection_doc_text_st_chart_scale_textbook_LRV_Screen": { - "sample_loader": "sample_loader_6", - "part_filter": "part_filter_6", - "data_class": "VQAWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/plotqa/processed": { - "sample_loader": "sample_loader_7", - "part_filter": "part_filter_7", - "data_class": "VQAWebdataset", - "data_weight": 0.095, - }, - "nvlm/wdai/data/clevr-math/processed": { - "sample_loader": "sample_loader_8", - "part_filter": "part_filter_8", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/MMC-Instruction/processed": { - "sample_loader": "sample_loader_9", - "part_filter": "part_filter_9", - "data_class": "VQAWebdataset", - "data_weight": 0.07, - }, - "nvlm/wdai/data/ocrvqa/processed": { - "sample_loader": "sample_loader_10", - "part_filter": "part_filter_10", - "data_class": "VQAWebdataset", - "data_weight": 0.06, - }, - "nvlm/wdai/data/dude/processed": { - "sample_loader": "sample_loader_11", - "part_filter": "part_filter_11", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/VisualMRC/processed": { - "sample_loader": "sample_loader_12", - "part_filter": "part_filter_12", - "data_class": "VQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/mcvqa_collection_scienceqa_ai2d_geoqaplus_geometry3k_tqa": { - "sample_loader": "sample_loader_13", - "part_filter": "part_filter_13", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.025, - }, - "nvlm/wdai/data/arxiv_qa/processed": { - "sample_loader": "sample_loader_14", - "part_filter": "part_filter_14", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/tabmwp/processed": { - "sample_loader": "sample_loader_15", - "part_filter": "part_filter_15", - "data_class": "MultiChoiceVQAWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/ocr_vqa_aug/processed": { - "sample_loader": "sample_loader_16", - "part_filter": "part_filter_16", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/dvqa_full/processed": { - "sample_loader": "sample_loader_17", - "part_filter": "part_filter_17", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.055, - }, - "nvlm/wdai/data/LLaVA-v1.5_shuffle/no_refcoco_vg_ocrvqa": { - "sample_loader": "sample_loader_18", - "part_filter": "part_filter_18", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.085, - }, - "vqa/more_data/infographics_vqa/processed/train": { - "sample_loader": "sample_loader_19", - "part_filter": "part_filter_19", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/sharegpt4o/processed": { - "sample_loader": "sample_loader_20", - "part_filter": "part_filter_20", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/sparse_ocr_data/merged": { - "sample_loader": "sample_loader_21", - "part_filter": "part_filter_21", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.045, - }, - "nvlm/nayeonl/data/blendv4/MetaMathQA/processed/train_text_image": { - "sample_loader": "sample_loader_22", - "part_filter": "part_filter_22", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/nayeonl/data/blendv4/gsm8k/processed/train_text_image": { - "sample_loader": "sample_loader_23", - "part_filter": "part_filter_23", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/docmatix/processed": { - "sample_loader": "sample_loader_24", - "part_filter": "part_filter_24", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.1, - }, - "nvlm/wdai/data/bentham_hw_squad/processed": { - "sample_loader": "sample_loader_25", - "part_filter": "part_filter_25", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/WikiTableQA/processed": { - "sample_loader": "sample_loader_26", - "part_filter": "part_filter_26", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/figureqa/processed": { - "sample_loader": "sample_loader_27", - "part_filter": "part_filter_27", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ai2d_combined_processed": { - "sample_loader": "sample_loader_28", - "part_filter": "part_filter_28", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/math_combined_processed": { - "sample_loader": "sample_loader_29", - "part_filter": "part_filter_29", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.035, - }, - "nvlm/wdai/data/llava-onevision/robut_combined_processed": { - "sample_loader": "sample_loader_30", - "part_filter": "part_filter_30", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/llavar_20k_processed": { - "sample_loader": "sample_loader_31", - "part_filter": "part_filter_31", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/tallyqa_processed": { - "sample_loader": "sample_loader_32", - "part_filter": "part_filter_32", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/llava-onevision/ureader_ie_processed": { - "sample_loader": "sample_loader_33", - "part_filter": "part_filter_33", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/visual7w_processed": { - "sample_loader": "sample_loader_34", - "part_filter": "part_filter_34", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/llava-onevision/mavis_math_rule_geo_processed": { - "sample_loader": "sample_loader_35", - "part_filter": "part_filter_35", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/llava-onevision/ureader_kg_processed": { - "sample_loader": "sample_loader_36", - "part_filter": "part_filter_36", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/ureader_qa_processed": { - "sample_loader": "sample_loader_37", - "part_filter": "part_filter_37", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/ocr_multi_collection_cocotext_textocr_ReCTs": { - "sample_loader": "sample_loader_38", - "part_filter": "part_filter_38", - "data_class": "OCRWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_500": { - "sample_loader": "sample_loader_39", - "part_filter": "part_filter_39", - "data_class": "OCRWebdataset", - "data_weight": 0.015, - }, - "nvlm/wdai/data/llava-onevision/super_clevr_processed": { - "sample_loader": "sample_loader_40", - "part_filter": "part_filter_40", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/llava-onevision/icon_qa_processed": { - "sample_loader": "sample_loader_41", - "part_filter": "part_filter_41", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.009, - }, - "nvlm/wdai/data/augmentations/chartqa_aug": { - "sample_loader": "sample_loader_42", - "part_filter": "part_filter_42", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_chartqa": { - "sample_loader": "sample_loader_43", - "part_filter": "part_filter_43", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/gpt_docvqa": { - "sample_loader": "sample_loader_44", - "part_filter": "part_filter_44", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/docvqa_text": { - "sample_loader": "sample_loader_45", - "part_filter": "part_filter_45", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.006, - }, - "nvlm/wdai/data/augmentations/textvqa_text": { - "sample_loader": "sample_loader_46", - "part_filter": "part_filter_46", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/augmentations/i2s-musicsheet": { - "sample_loader": "sample_loader_47", - "part_filter": "part_filter_47", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "nvlm/wdai/data/augmentations/music": { - "sample_loader": "sample_loader_48", - "part_filter": "part_filter_48", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/invoice": { - "sample_loader": "sample_loader_49", - "part_filter": "part_filter_49", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/k12": { - "sample_loader": "sample_loader_50", - "part_filter": "part_filter_50", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.019, - }, - "nvlm/wdai/data/augmentations/MTVQA": { - "sample_loader": "sample_loader_51", - "part_filter": "part_filter_51", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/VisualWebInstruct": { - "sample_loader": "sample_loader_52", - "part_filter": "part_filter_52", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.028, - }, - "nvlm/wdai/data/augmentations/financeqa": { - "sample_loader": "sample_loader_53", - "part_filter": "part_filter_53", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/docreason": { - "sample_loader": "sample_loader_54", - "part_filter": "part_filter_54", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/gpt_mtwi": { - "sample_loader": "sample_loader_55", - "part_filter": "part_filter_55", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/geos_gpt": { - "sample_loader": "sample_loader_56", - "part_filter": "part_filter_56", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0001, - }, - "nvlm/wdai/data/augmentations/cauldron_vistext": { - "sample_loader": "sample_loader_57", - "part_filter": "part_filter_57", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/memes": { - "sample_loader": "sample_loader_58", - "part_filter": "part_filter_58", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/augmentations/gpt_roadtext": { - "sample_loader": "sample_loader_59", - "part_filter": "part_filter_59", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0002, - }, - "nvlm/wdai/data/augmentations/indoor_qa": { - "sample_loader": "sample_loader_60", - "part_filter": "part_filter_60", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/colpali": { - "sample_loader": "sample_loader_61", - "part_filter": "part_filter_61", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/augmentations/pmc_vqa": { - "sample_loader": "sample_loader_62", - "part_filter": "part_filter_62", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/augmentations/pathvqa": { - "sample_loader": "sample_loader_63", - "part_filter": "part_filter_63", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.004, - }, - "nvlm/wdai/data/augmentations/sciqa": { - "sample_loader": "sample_loader_64", - "part_filter": "part_filter_64", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.027, - }, - "nvlm/wdai/data/augmentations/chinese_meme": { - "sample_loader": "sample_loader_65", - "part_filter": "part_filter_65", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/gpt_hiertext": { - "sample_loader": "sample_loader_66", - "part_filter": "part_filter_66", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.003, - }, - "nvlm/wdai/data/augmentations/cauldron_cocoqa": { - "sample_loader": "sample_loader_67", - "part_filter": "part_filter_67", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.007, - }, - "nvlm/wdai/data/cmm-math/processed": { - "sample_loader": "sample_loader_68", - "part_filter": "part_filter_68", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/mmtab/processed": { - "sample_loader": "sample_loader_69", - "part_filter": "part_filter_69", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.008, - }, - "nvlm/wdai/data/simchart9k/processed": { - "sample_loader": "sample_loader_70", - "part_filter": "part_filter_70", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/llava-onevision/mapqa_processed": { - "sample_loader": "sample_loader_71", - "part_filter": "part_filter_71", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/llava-onevision/vizwiz_processed": { - "sample_loader": "sample_loader_72", - "part_filter": "part_filter_72", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/augmentations/gpt_infovqa": { - "sample_loader": "sample_loader_73", - "part_filter": "part_filter_73", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/augmentations/viquae": { - "sample_loader": "sample_loader_74", - "part_filter": "part_filter_74", - "data_class": "SimilarityInterleavedWebdataset", - "data_weight": 0.0005, - }, - "captioning/ccs_recaptioned/webdataset": { - "sample_loader": "sample_loader_75", - "part_filter": "part_filter_75", - "data_class": "CaptioningWebdataset", - "data_weight": 0.2, - }, - "captioning/laion115m-clean": { - "sample_loader": "sample_loader_76", - "part_filter": "part_filter_76", - "data_class": "CaptioningWebdataset", - "data_weight": 0.579, - }, - "nvlm/wdai/data/dvqa_full/processed_pt": { - "sample_loader": "sample_loader_77", - "part_filter": "part_filter_77", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/docmatix/processed_pt": { - "sample_loader": "sample_loader_78", - "part_filter": "part_filter_78", - "data_class": "VQAWebdataset", - "data_weight": 0.02, - }, - "vqa/VQAv2/stage1": { - "sample_loader": "sample_loader_91", - "part_filter": "part_filter_91", - "data_class": "VQAWebdataset", - "data_weight": 1.0, - }, - "vqa/Visual_Genome": { - "sample_loader": "sample_loader_80", - "part_filter": "part_filter_80", - "data_class": "VQAWebdataset", - "data_weight": 0.01, - }, - "nvlm/wdai/data/pdfa-eng-wds/processed_word_len_300": { - "sample_loader": "sample_loader_81", - "part_filter": "part_filter_81", - "data_class": "OCRWebdataset", - "data_weight": 0.08, - }, - "nvlm/wdai/data/textocr/processed": { - "sample_loader": "sample_loader_82", - "part_filter": "part_filter_82", - "data_class": "OCRWebdataset", - "data_weight": 0.02, - }, - "nvlm/wdai/data/coco-text/processed": { - "sample_loader": "sample_loader_83", - "part_filter": "part_filter_83", - "data_class": "OCRWebdataset", - "data_weight": 0.002, - }, - "nvlm/wdai/data/ArT/processed": { - "sample_loader": "sample_loader_84", - "part_filter": "part_filter_84", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/ReCTs/processed": { - "sample_loader": "sample_loader_85", - "part_filter": "part_filter_85", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/lsvt/processed": { - "sample_loader": "sample_loader_86", - "part_filter": "part_filter_86", - "data_class": "OCRWebdataset", - "data_weight": 0.005, - }, - "nvlm/wdai/data/RCTW/processed": { - "sample_loader": "sample_loader_87", - "part_filter": "part_filter_87", - "data_class": "OCRWebdataset", - "data_weight": 0.001, - }, - "nvlm/wdai/data/coco-text/processed_multi": { - "sample_loader": "sample_loader_88", - "part_filter": "part_filter_88", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, - "nvlm/wdai/data/textocr/processed_multi": { - "sample_loader": "sample_loader_89", - "part_filter": "part_filter_89", - "data_class": "OCRWebdataset", - "data_weight": 0.0004, - }, - "nvlm/wdai/data/ReCTs/processed_multi": { - "sample_loader": "sample_loader_90", - "part_filter": "part_filter_90", - "data_class": "OCRWebdataset", - "data_weight": 0.0003, - }, -} - - -def get_sample_loader(path): - """Returns the correct sample_loader function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("sample_loader")) - - -def get_part_filter(path): - """Returns the correct part_filter function for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return globals().get(dataset_loader_mapping.get(path, {}).get("part_filter")) - - -def get_data_class(path): - """Returns the correct data_class for a dataset.""" - if path not in dataset_loader_mapping: - path = data_path_mapping(path) - - assert path in dataset_loader_mapping, f"path {path} not in dataset_loader_mapping" - return dataset_loader_mapping[path]["data_class"] From a0b7e9bbd508c3a5cfe2001dc83a19cc12ae08fd Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 02:27:53 -0700 Subject: [PATCH 06/11] Release: drop orphan multiview_dataloader config multiview_dataloader is unreachable: nothing imports it, and its only effect would be to register the (unmapped) multiview_data_source / multiview_dataset. Now excluded from the mapping. Co-Authored-By: Claude Opus 4.7 --- .../base/defaults/multiview_dataloader.py | 150 ------------------ 1 file changed, 150 deletions(-) delete mode 100644 cosmos_framework/configs/base/defaults/multiview_dataloader.py diff --git a/cosmos_framework/configs/base/defaults/multiview_dataloader.py b/cosmos_framework/configs/base/defaults/multiview_dataloader.py deleted file mode 100644 index a646ac6..0000000 --- a/cosmos_framework/configs/base/defaults/multiview_dataloader.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -""" -Hydra ConfigStore registration for multiview dataloaders. - -Registers named dataloader configs that can be referenced via Hydra overrides -(e.g. ``{override /data_train: video_control_mads_multiview_0823_gcs_720p_10fps_93frames_7views}``) -or used as templates for inline ``L(get_multiview_video_loader)(...)`` in -experiment configs. - -Two naming conventions: - - **Transfer** (with control signal): - ``video_control_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` - - **Predict** (no control signal): - ``video_{dataset}_{store}_{res}_{fps}_{frames}_{views}`` -""" - -from hydra.core.config_store import ConfigStore - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.data.vfm.multiview.multiview_data_source import ( - DEFAULT_CAMERAS, - INDEX_TO_CAMERA_MAPPING, - TRANSFER_CAPTION_KEY_MAPPING, - TRANSFER_CONTROL_KEY_MAPPING, - TRANSFER_VIDEO_KEY_MAPPING, -) -from cosmos_framework.data.vfm.multiview.multiview_dataset import ( - MultiviewAugmentationConfig, - get_multiview_video_loader, -) - -# --------------------------------------------------------------------------- -# Camera view subsets -# --------------------------------------------------------------------------- - -CAMERA_VIEW_CONFIGS: dict[str, tuple[str, ...]] = { - "7views": DEFAULT_CAMERAS, - "1view_front": ("camera_front_wide_120fov",), - "4views": ( - "camera_front_wide_120fov", - "camera_cross_right_120fov", - "camera_rear_tele_30fov", - "camera_cross_left_120fov", - ), -} - -# --------------------------------------------------------------------------- -# Grid dimensions -# --------------------------------------------------------------------------- - -_TRANSFER_DATASETS = ["mads_multiview_0823"] -_OBJECT_STORES = ["gcs"] - -_RESOLUTIONS: list[tuple[str, tuple[int, int]]] = [ - ("720p", (720, 1280)), -] - -_FPS: list[tuple[str, int]] = [ - ("10fps", 1), # MADS transfer data is already at 10 fps -] - -_NUM_VIDEO_FRAMES: list[tuple[str, int]] = [ - ("29frames", 29), - ("61frames", 61), - ("93frames", 93), -] - - -def register_multiview_dataloaders() -> None: - """Register all multiview dataloader configs with Hydra ConfigStore.""" - - cs = ConfigStore.instance() - - # ----- Transfer dataloaders (with control signals) ----- - for dataset in _TRANSFER_DATASETS: - for object_store in _OBJECT_STORES: - for resolution_str, resolution_hw in _RESOLUTIONS: - for fps_str, downsample_factor in _FPS: - for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - name = ( - f"video_control_{dataset}_{object_store}_{resolution_str}_" - f"{fps_str}_{num_frames_str}_{views_str}" - ) - cs.store( - group="data_train", - package="dataloader_train", - name=name, - node=L(get_multiview_video_loader)( - dataset_name=dataset, - is_train=True, - augmentation_config=L(MultiviewAugmentationConfig)( - resolution_hw=resolution_hw, - fps_downsample_factor=downsample_factor, - num_video_frames=num_frames, - camera_keys=camera_keys, - camera_video_key_mapping=TRANSFER_VIDEO_KEY_MAPPING, - camera_caption_key_mapping=TRANSFER_CAPTION_KEY_MAPPING, - camera_control_key_mapping=TRANSFER_CONTROL_KEY_MAPPING, - position_to_camera_mapping=INDEX_TO_CAMERA_MAPPING, - single_caption_camera_name="camera_front_wide_120fov", - ), - ), - ) - - # ----- Predict dataloaders (no control signals, for future use) ----- - # These use named keys (video_camera_front_wide_120fov, etc.) and need - # different datasets (e.g. alpamayo_dec2024) with 30 fps native data. - # Uncomment and add predict datasets to the catalog when needed. - # - # _PREDICT_DATASETS = ["alpamayo_dec2024"] - # _PREDICT_FPS = [("10fps", 3), ("15fps", 2)] # 30 fps native → downsample - # for dataset in _PREDICT_DATASETS: - # for object_store in _OBJECT_STORES: - # for resolution_str, resolution_hw in _RESOLUTIONS: - # for fps_str, downsample_factor in _PREDICT_FPS: - # for num_frames_str, num_frames in _NUM_VIDEO_FRAMES: - # for views_str, camera_keys in CAMERA_VIEW_CONFIGS.items(): - # name = ( - # f"video_{dataset}_{object_store}_{resolution_str}_" - # f"{fps_str}_{num_frames_str}_{views_str}" - # ) - # cs.store( - # group="data_train", - # package="dataloader_train", - # name=name, - # node=L(get_multiview_video_loader)( - # dataset_name=dataset, - # is_train=True, - # augmentation_config=L(MultiviewAugmentationConfig)( - # resolution_hw=resolution_hw, - # fps_downsample_factor=downsample_factor, - # num_video_frames=num_frames, - # camera_keys=camera_keys, - # camera_video_key_mapping=PREDICT_VIDEO_KEY_MAPPING, - # camera_caption_key_mapping=PREDICT_CAPTION_KEY_MAPPING, - # camera_control_key_mapping=None, - # position_to_camera_mapping=None, - # single_caption_camera_name=None, - # ), - # ), - # ) - - -# Auto-register on import -register_multiview_dataloaders() From 0020ed38c250f23c68ad21c11421ca3c4fdd218a Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 02:42:57 -0700 Subject: [PATCH 07/11] Release: strip register_data_debug import from vlm config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Source added a COSMOS-RELEASE-IGNORE block around the dataloader.register_data_debug import + call site; pipeline now strips it. This leaves dataloader.py with zero CF importers — orphan candidate. Co-Authored-By: Claude Opus 4.7 --- cosmos_framework/configs/base/vlm/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cosmos_framework/configs/base/vlm/config.py b/cosmos_framework/configs/base/vlm/config.py index 7354d3e..4d1de85 100644 --- a/cosmos_framework/configs/base/vlm/config.py +++ b/cosmos_framework/configs/base/vlm/config.py @@ -7,7 +7,6 @@ from cosmos_framework.configs.base.defaults.checkpointer import register_checkpoint, register_ckpt_type from cosmos_framework.configs.base.vlm.defaults.callbacks import register_callbacks from cosmos_framework.configs.base.vlm.defaults.config import Config -from cosmos_framework.configs.base.vlm.defaults.dataloader import register_data_debug from cosmos_framework.configs.base.vlm.defaults.model import register_model from cosmos_framework.configs.base.vlm.defaults.optimizer import register_optimizer, register_scheduler from cosmos_framework.configs.base.vlm.defaults.vlm_policy import register_vlm_policy From f1550af3ed966a19cae3d3997804df77d6c93b46 Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 02:46:03 -0700 Subject: [PATCH 08/11] Release: drop orphan vlm defaults/dataloader.py After vlm/config.py was IGNORE-stripped of its register_data_debug import, dataloader.py has zero CF importers. It was also the only consumer of the unmapped vlm.{collate_fn,debug_data_qwen,dummy_data_qwen} modules, so dropping it resolves those three dangling imports as well. Co-Authored-By: Claude Opus 4.7 --- .../configs/base/vlm/defaults/dataloader.py | 80 ------------------- 1 file changed, 80 deletions(-) delete mode 100644 cosmos_framework/configs/base/vlm/defaults/dataloader.py diff --git a/cosmos_framework/configs/base/vlm/defaults/dataloader.py b/cosmos_framework/configs/base/vlm/defaults/dataloader.py deleted file mode 100644 index 36b878d..0000000 --- a/cosmos_framework/configs/base/vlm/defaults/dataloader.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -from torch.utils.data import DataLoader - -from cosmos_framework.utils.lazy_config import LazyCall as L -from cosmos_framework.utils.config_helper import ConfigStore -from cosmos_framework.data.vfm.vlm.collate_fn import custom_collate -from cosmos_framework.data.vfm.vlm.debug_data_qwen import DebugQwenDataset -from cosmos_framework.data.vfm.vlm.dummy_data_qwen import DummyQwenDataset -from cosmos_framework.data.vfm.processors import build_processor_lazy - - -# Debug dataset -def create_debug_dataloader_config_qwen( - num_images, loss_on_completion_only: bool = True, use_dummy_image: bool = False -): - return L(DataLoader)( - dataset=L(DebugQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_images=num_images, - seq_len="${model.config.policy.model_max_length}", - image_token_len="${model.config.policy.qwen_max_video_token_length}", - # use_dummy_image=use_dummy_image, - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def create_dummy_dataloader_config_qwen(): - return L(DataLoader)( - dataset=L(DummyQwenDataset)( - tokenizer=L(build_processor_lazy)( - tokenizer_type="${model.config.policy.backbone.model_name}", - credentials="${checkpoint.load_from_object_store.credentials}", - bucket="${checkpoint.load_from_object_store.bucket}", - ), - num_visual_tokens="${model.config.policy.qwen_max_video_token_length}", - total_tokens="${model.config.policy.model_max_length}", - batch_size="${dataloader_train.batch_size}", - ), - num_workers=8, - prefetch_factor=4, - batch_size=1, - sampler=None, - persistent_workers=False, - pin_memory=True, - collate_fn=custom_collate, - ) - - -def register_data_debug(): - cs = ConfigStore.instance() - for split in ["train", "val"]: - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="debug_image_data_qwen", # This data is from pixtral model output, expected to have low loss ~1.4 - node=create_debug_dataloader_config_qwen(1), - ) - cs.store( - group=f"data_{split}", - package=f"dataloader_{split}", - name="dummy_image_data_qwen", - node=create_dummy_dataloader_config_qwen(), - ) - - -def register_data(): - register_data_debug() From 1614dafa63a5c5b2a75685867e2b3886a89b9fc1 Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 03:05:00 -0700 Subject: [PATCH 09/11] Release: ship data/vfm/vlm/video_decoder_qwen Pulled in by the vfm augmentors (bytes_to_media, tokenize_data). Only project-internal imports are imaginaire.utils.log and the already-shipped qwen3vl_processor; no new external deps. Co-Authored-By: Claude Opus 4.7 --- .../data/vfm/vlm/video_decoder_qwen.py | 249 ++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 cosmos_framework/data/vfm/vlm/video_decoder_qwen.py diff --git a/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py new file mode 100644 index 0000000..12c9bcc --- /dev/null +++ b/cosmos_framework/data/vfm/vlm/video_decoder_qwen.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +""" +Copied from projects/cosmos/reason1/datasets/video_decoder_qwen.py +Changes: +1: remove hardcoded hyper-parameters for Qwen, now read it from processor +2: support skipping smart resize, since it may resize the video frames to be smaller than model input and frames will get resized up later in processor +""" + +import random +import re +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from typing import Callable, Optional + +import torch +from PIL import Image +from qwen_vl_utils.vision_process import smart_nframes, smart_resize +from torchcodec.decoders import VideoDecoder +from torchvision import transforms +from torchvision.transforms import InterpolationMode + +from cosmos_framework.utils import log +from cosmos_framework.data.vfm.processors.qwen3vl_processor import Qwen3VLProcessor + +Image.MAX_IMAGE_PIXELS = 933120000 +_VIDEO_EXTENSIONS = "mp4 avi webm mov".split() + +VIDEO_DECODER_OPTIONS = {} + + +def token_to_pixels(token_length: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert token length to pixels based on patch size and temporal patch size. + + Args: + token_length: Token length + patch_size: Patch size + temporal_patch_size: Temporal patch size, + for Qwen it has 3D conv, temporal patch size is 2; for other models like internVL or eagle er, the temporal patch size is 1 since their VIT is image encoder; + merge_size: Merge size, or called pixel shuffing factor; + for Qwen and internVL it is 2; for eagle er it is 1; + """ + merged_patch_size = patch_size * merge_size + return token_length * merged_patch_size**2 * temporal_patch_size + + +def pixels_to_token(pixels: int, patch_size: int = 14, temporal_patch_size: int = 2, merge_size: int = 2) -> int: + """Convert pixels to token length based on patch size and temporal patch size.""" + merged_patch_size = patch_size * merge_size + return pixels // merged_patch_size**2 // temporal_patch_size + + +def video_decoder_qwen( + num_threads: int = 0, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + random_augmentation: bool = False, + frame_count_random_range: Optional[list[int]] = None, + **kwargs, +) -> Callable: + """ + Sampling video frames similar to Qwen. It prioritizes matching the target FPS first and then resizing the video frames. + See https://github.com/kq-chen/qwen-vl-utils/blob/main/src/qwen_vl_utils/vision_process.py#L118 for more details. + + Args: + key: Video file name/key + data: Video binary data + min_fps_thres: Minimum FPS threshold + max_fps_thres: Maximum FPS threshold + target_fps: Target FPS + min_video_token_length: Minimum token length + max_video_token_length: Maximum token length + num_threads: Number of threads for the torchcodec video decoder + random_augmentation: Whether to randomize the FPS and max_video_token_length + frame_count_random_range: Random frame count range + + Returns: + dict with video frames tensor and target FPS + """ + + video_decoder_configured = partial( + _video_decoder_qwen_func, + min_fps_thres=min_fps_thres, + max_fps_thres=max_fps_thres, + num_threads=num_threads, + target_fps=target_fps, + min_video_token_length=min_video_token_length, + max_video_token_length=max_video_token_length, + random_augmentation=random_augmentation, + frame_count_random_range=frame_count_random_range, + ) + + return video_decoder_configured + + +def _video_decoder_qwen_func( + key: str, + data: bytes, + processor: Qwen3VLProcessor, + min_fps_thres: int = 4, + max_fps_thres: int = 60, + target_fps: float = 2.0, + min_video_token_length: int = 16, + max_video_token_length: int = 8192, + num_threads: int = 0, + random_augmentation: bool = False, + fps_random_range: list[float] = [0.5, 1.5], + max_video_token_length_random_range: list[float] = [0.75, 1.25], + frame_count_random_range: Optional[list[int]] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + decoding_timeout: int = 60, + **kwargs, +) -> dict | None: + """Actual video decoder function. + + Args: + key (str): Video file name/key + data (bytes): Video binary data + min_fps_thres (int, optional): Minimum FPS threshold. Defaults to 4. + max_fps_thres (int, optional): Maximum FPS threshold. Defaults to 60. + target_fps (float, optional): Target FPS. Defaults to 2.0. + min_video_token_length (int, optional): Minimum token length. Defaults to 16. + max_video_token_length (int, optional): Maximum token length. Defaults to 8192. + num_threads (int, optional): Number of threads for the torchcodec video decoder. Defaults to 0. + random_augmentation (bool, optional): Whether to randomize the FPS and max_video_token_length. Defaults to False. + fps_random_range (list[float], optional): Random FPS range. Defaults to [10.0, 24.0]. + max_video_token_length_random_range (list[float], optional): Random max_video_token_length range. Defaults to [0.75, 1.25]. + frame_count_random_range (list[int], optional): Random frame count range. If provided, take priority over fps_random_range. + start_frame (Optional[int], optional): Start frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + end_frame (Optional[int], optional): End frame. Defaults to None. If both start_frame and end_frame are provided, the video will be decoded from start_frame to end_frame. + decoding_timeout (int, optional): Timeout in seconds. Defaults to 60. + Raises: + ValueError: Video fps lower than 1, skipping + ValueError: Video fps lower than min_fps_thres, skipping + ValueError: Video fps higher than max_fps_thres, skipping + + Returns: + dict | None: Dictionary with video frames tensor and target FPS + """ + # Check video extension + extension = re.sub(r".*[.]", "", key) + if extension.lower() not in _VIDEO_EXTENSIONS: + return None + + # Read video with torchcodec + video_reader = VideoDecoder(data, num_ffmpeg_threads=num_threads) + total_frames = video_reader.metadata.num_frames + video_fps = video_reader.metadata.average_fps + + # torchcodec returns ``None`` for containers that don't store frame count + # or average fps (e.g. some MKV/WebM streams). Downstream arithmetic + # (``total_frames - 1``, ``video_fps < 1``, ...) would TypeError on None; + # surface a ValueError so the dataloader's skip path handles it uniformly. + if total_frames is None or video_fps is None: + raise ValueError(f"torchcodec missing metadata (num_frames={total_frames}, average_fps={video_fps}), skipping") + + if start_frame is not None and end_frame is not None: + total_frames = end_frame - start_frame + + if video_fps < 1: + raise ValueError("Video fps lower than 1, skipping") + if video_fps < min_fps_thres: + raise ValueError(f"Video fps {video_fps} lower than {min_fps_thres}, skipping") + if video_fps > max_fps_thres: + raise ValueError(f"Video fps {video_fps} higher than {max_fps_thres}, skipping") + + if random_augmentation: + if frame_count_random_range is not None: + # Random number of frames + min_frames_range, max_frames_range = frame_count_random_range + min_frames_range = min(min_frames_range, total_frames) + max_frames_range = min(max_frames_range, total_frames) + target_frames = random.uniform(min_frames_range, max_frames_range) + target_fps = target_frames / total_frames * video_fps + else: + # randomize fps + target_fps = ( + random.uniform(fps_random_range[0], fps_random_range[1]) * target_fps + if random.random() < 0.5 + else target_fps + ) + # randomize max_video_token_length + max_video_token_length = int( + random.uniform(max_video_token_length_random_range[0], max_video_token_length_random_range[1]) + * max_video_token_length + ) + log.debug(f"random_augmentation: max_video_token_length: {max_video_token_length}, target_fps: {target_fps}") + + patch_size = processor.patch_size + min_height_width = processor.min_height_width + temporal_patch_size = processor.temporal_patch_size + merge_size = processor.merge_size + min_pixels: int = token_to_pixels(min_video_token_length, patch_size, temporal_patch_size, merge_size) + max_pixels: int = token_to_pixels(max_video_token_length, patch_size, temporal_patch_size, merge_size) + max_frames: int = max_pixels // (min_height_width) ** 2 // temporal_patch_size + + # sample based on target fps + nframes = smart_nframes(dict(fps=target_fps), total_frames=total_frames, video_fps=video_fps) + nframes = min(nframes, max_frames) + if start_frame is not None and end_frame is not None: + idx = torch.linspace(start_frame, end_frame - 1, nframes).round().long().tolist() # [nframes] + else: + idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() # [nframes] + + def _decode_video() -> torch.Tensor: + return video_reader.get_frames_at(indices=idx).data # [T, C, H, W] uint8 + + # Use ThreadPoolExecutor to run video decoding with a timeout. + # If the thread is stuck, abandon it immediately. + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit(_decode_video) + try: + video_frames = future.result(timeout=decoding_timeout) + executor.shutdown(wait=False) + except TimeoutError as e: + log.warning(f"[{key}] Video decoding timed out after {decoding_timeout} seconds") + executor.shutdown(wait=False) + return None + + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + + # recompute max_pixels based on number of sampled frames + nframes, _, height, width = video_frames.shape + max_pixels = max_pixels // nframes + if processor.use_smart_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_size * merge_size, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + log.debug( + f"resized_height: {resized_height}, resized_width: {resized_width} | original height: {height}, original width: {width}" + ) + video_frames = transforms.functional.resize( + video_frames, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() # [T,C,H,W] + video_frames = video_frames.permute(1, 0, 2, 3) # [C,T,H,W] + + return dict(videos=video_frames, fps=sample_fps) From b8b6a806ff190498c42940281f697162a413eae7 Mon Sep 17 00:00:00 2001 From: yangyangt Date: Fri, 5 Jun 2026 03:30:38 -0700 Subject: [PATCH 10/11] Release: ship tokenizer/evaluation/reconstruction_metrics Required by model/tokenizer/models/__init__.py (calculate_psnr re-export). FVDMetric is commented out in source (depended on the unmapped FVD + paths modules). A residual lazy import of tokenizer.evaluation.metric inside TokenizerMetric's codebook-usage branch remains dangling; only that code path is affected. Co-Authored-By: Claude Opus 4.7 --- .../evaluation/reconstruction_metrics.py | 497 ++++++++++++++++++ 1 file changed, 497 insertions(+) create mode 100644 cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py diff --git a/cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py b/cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py new file mode 100644 index 0000000..66db4fb --- /dev/null +++ b/cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py @@ -0,0 +1,497 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Metric computation for tokenizer evaluation. + +This module provides metrics for evaluating tokenizer quality: + - PSNRMetric: Peak signal-to-noise ratio (using torchmetrics) + - SSIMMetric: Structural similarity index (using torchmetrics) + - LPIPSMetric: Learned perceptual image patch similarity + - TokenizerMetric: Composite metric that includes codebook usage via compute_codebook_usage +""" + +from __future__ import annotations + +from typing import Any + +import torch +import torch.nn as nn + +# Import torchmetrics for SSIM and LPIPS +try: + from torchmetrics.image import StructuralSimilarityIndexMeasure + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + + HAS_TORCHMETRICS = True +except ImportError: + HAS_TORCHMETRICS = False + +# Standard batch keys +INPUT_KEY = "inputs" # [0, 1] range for PSNR/SSIM +RECON_KEY = "reconstructions" # [0, 1] range for PSNR/SSIM +INPUT_RAW_KEY = "inputs_raw" # [-1, 1] range for LPIPS +RECON_RAW_KEY = "reconstructions_raw" # [-1, 1] range for LPIPS + + +class TokenizerMetric(nn.Module): + """Composite metric module for tokenizer evaluation. + + Combines multiple metrics and computes them in a single forward pass. + + Args: + compute_psnr: Whether to compute PSNR. + compute_ssim: Whether to compute SSIM. + compute_lpips: Whether to compute LPIPS. + compute_code_usage: Whether to compute codebook usage. + """ + + def __init__( + self, + compute_psnr: bool = True, + compute_ssim: bool = True, + compute_lpips: bool = False, + compute_code_usage: bool = False, + num_codes: int = 65536, + ) -> None: + super().__init__() + self.compute_psnr = compute_psnr + self.compute_ssim = compute_ssim + self.compute_lpips = compute_lpips + self.compute_code_usage = compute_code_usage + self.num_codes = num_codes + + if compute_psnr: + self.psnr = PSNRMetric() + if compute_ssim: + self.ssim = SSIMMetric() + if compute_lpips: + self.lpips = LPIPSMetric() + + def forward( + self, + inputs: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + iteration: int, + ) -> dict[str, Any]: + """Compute all enabled metrics. + + Args: + inputs: Input batch with original images/videos. Should contain: + - "inputs": [0, 1] range for PSNR/SSIM + - "inputs_raw": [-1, 1] range for LPIPS + output_batch: Output batch with reconstructions. Should contain: + - "reconstructions": [0, 1] range for PSNR/SSIM + - "reconstructions_raw": [-1, 1] range for LPIPS + iteration: Current iteration. + + Returns: + Dictionary of metric values. PSNR/SSIM/LPIPS return dicts with 'sum' and 'count' + for proper distributed averaging. + """ + metrics = {} + + # [0, 1] range data for PSNR/SSIM + original = inputs.get(INPUT_KEY) + recon = output_batch.get(RECON_KEY) + + # [-1, 1] range data for LPIPS + original_raw = inputs.get(INPUT_RAW_KEY) + recon_raw = output_batch.get(RECON_RAW_KEY) + + if original is None or recon is None: + return metrics + + if self.compute_psnr: + metrics["psnr"] = self.psnr(original, recon) + + if self.compute_ssim: + metrics["ssim"] = self.ssim(original, recon) + + if self.compute_lpips: + # Use [-1, 1] range data for LPIPS + # Fall back to converting [0, 1] to [-1, 1] if raw data not available + if original_raw is not None and recon_raw is not None: + metrics["lpips"] = self.lpips(original_raw, recon_raw) + else: + # Convert [0, 1] to [-1, 1] if raw data not provided + original_lpips = original * 2.0 - 1.0 + recon_lpips = recon * 2.0 - 1.0 + metrics["lpips"] = self.lpips(original_lpips, recon_lpips) + + if self.compute_code_usage: + quant_info = output_batch.get("quant_info") + if quant_info is not None: + indices = quant_info.get("indices") + if indices is not None: + from cosmos_framework.model.tokenizer.evaluation.metric import compute_codebook_usage + + code_stats = compute_codebook_usage(indices, self.num_codes) + metrics["code_perplexity"] = code_stats["perplexity"] + metrics["code_active_ratio"] = code_stats["active_ratio"] + metrics["code_active_count"] = code_stats["active_codes"] + + return metrics + + +class PSNRMetric(nn.Module): + """Peak Signal-to-Noise Ratio metric. + + Computes PSNR between original and reconstructed images. + Expects inputs in [0, 1] range (already normalized by caller). + + Uses per-sample MSE calculation on uint8 [0, 255] range: + - Convert [0, 1] float to [0, 255] uint8 + - Compute MSE per sample on uint8 values (average over C, H, W dimensions) + - Compute PSNR per sample with max_val=255 + - Return dict with sum and count for proper distributed averaging + """ + + def __init__(self) -> None: + super().__init__() + + def forward(self, original: torch.Tensor, reconstructed: torch.Tensor) -> dict: + """Compute PSNR between original and reconstructed tensors. + + Args: + original: Original tensor in [0, 1] range. Shape: (B, C, H, W) or (B, T, C, H, W). + reconstructed: Reconstructed tensor in [0, 1] range. + + Returns: + Dict with 'sum' (sum of per-sample PSNRs) and 'count' (number of samples) + for proper distributed averaging. + """ + # Handle video format by flattening batch and time dimensions + if original.dim() == 5: # (B, T, C, H, W) + b, t, c, h, w = original.shape + original = original.reshape(b * t, c, h, w) + reconstructed = reconstructed.reshape(b * t, c, h, w) + + # Convert to uint8 [0, 255] range + original_uint8 = (original.clamp(0, 1) * 255).byte() + reconstructed_uint8 = (reconstructed.clamp(0, 1) * 255).byte() + + # Compute per-sample MSE on uint8 values (as float for precision) + mse = torch.mean((original_uint8.float() - reconstructed_uint8.float()) ** 2, dim=[1, 2, 3]) # (B,) + + # Handle zero MSE (identical images) - use max PSNR of 100 dB + max_psnr = 100.0 + mse = torch.where( + mse == 0, + torch.tensor(10.0 ** (-max_psnr / 10.0) * 255.0 * 255.0, device=mse.device, dtype=mse.dtype), + mse, + ) + + # Compute PSNR per sample with max_val=255 + psnr = 20 * torch.log10(255.0 / torch.sqrt(mse)) + + # Return sum and count for proper distributed averaging + return {"sum": psnr.sum().item(), "count": psnr.shape[0]} + + +class SSIMMetric(nn.Module): + """Structural Similarity Index metric. + + Uses torchmetrics for SSIM computation. + Expects inputs in [0, 1] range (already normalized by caller). + """ + + def __init__(self) -> None: + super().__init__() + if HAS_TORCHMETRICS: + # data_range=1.0 for [0, 1] normalized images + self._ssim_metric = StructuralSimilarityIndexMeasure( + data_range=1.0, + sync_on_compute=False, + dist_sync_on_step=False, + ) + else: + self._ssim_metric = None + + def forward(self, original: torch.Tensor, reconstructed: torch.Tensor) -> dict: + """Compute SSIM between original and reconstructed tensors. + + Args: + original: Original tensor in [0, 1] range. Shape: (B, C, H, W) or (B, T, C, H, W). + reconstructed: Reconstructed tensor in [0, 1] range. + + Returns: + Dict with 'sum' (sum of per-sample SSIMs) and 'count' (number of samples) + for proper distributed averaging. + """ + if not HAS_TORCHMETRICS or self._ssim_metric is None: + return {"sum": 0.0, "count": 0} + + # Handle video by flattening temporal dimension + if original.dim() == 5: # B, T, C, H, W + b, t, c, h, w = original.shape + original = original.reshape(b * t, c, h, w) + reconstructed = reconstructed.reshape(b * t, c, h, w) + + # Clamp to [0, 1] range and convert to float32 for SSIM computation + original = original.clamp(0, 1).float() + reconstructed = reconstructed.clamp(0, 1).float() + + batch_size = original.shape[0] + + # Move metric to correct device + self._ssim_metric = self._ssim_metric.to(original.device) + + # Reset metric state before computing to avoid accumulation from previous calls + self._ssim_metric.reset() + + # Compute SSIM for each sample individually to get per-sample values + # We need to reset between samples to avoid state accumulation + ssim_sum = 0.0 + for i in range(batch_size): + orig_i = original[i : i + 1] + recon_i = reconstructed[i : i + 1] + # Update with single sample + self._ssim_metric.update(recon_i, orig_i) + # Compute returns the value for accumulated samples (just 1 here) + ssim_val = self._ssim_metric.compute() + ssim_sum += ssim_val.item() + # Reset for next sample to avoid accumulation + self._ssim_metric.reset() + + # Return sum and count for proper distributed averaging + return {"sum": ssim_sum, "count": batch_size} + + +class LPIPSMetric(nn.Module): + """Learned Perceptual Image Patch Similarity metric. + + Uses torchmetrics LPIPS with VGG backbone. + Expects inputs in [-1, 1] range for LPIPS computation. + Note: The forward() method expects [-1, 1] range directly (no conversion needed). + """ + + def __init__(self, net_type: str = "vgg") -> None: + super().__init__() + if HAS_TORCHMETRICS: + # LPIPS expects inputs in [-1, 1] range + self._lpips_metric = LearnedPerceptualImagePatchSimilarity( + net_type=net_type, + sync_on_compute=False, + dist_sync_on_step=False, + ) + else: + self._lpips_metric = None + + def forward(self, original: torch.Tensor, reconstructed: torch.Tensor) -> dict: + """Compute LPIPS between original and reconstructed tensors. + + Args: + original: Original tensor in [-1, 1] range. Shape: (B, C, H, W) or (B, T, C, H, W). + reconstructed: Reconstructed tensor in [-1, 1] range. + + Returns: + Dict with 'sum' (sum of per-sample LPIPS) and 'count' (number of samples) + for proper distributed averaging. + """ + if not HAS_TORCHMETRICS or self._lpips_metric is None: + return {"sum": 0.0, "count": 0} + + # Handle video by flattening temporal dimension + if original.dim() == 5: # B, T, C, H, W + b, t, c, h, w = original.shape + original = original.reshape(b * t, c, h, w) + reconstructed = reconstructed.reshape(b * t, c, h, w) + + # LPIPS expects [-1, 1] range - clamp and convert to float32 + original_lpips = original.clamp(-1.0, 1.0).float() + reconstructed_lpips = reconstructed.clamp(-1.0, 1.0).float() + + batch_size = original.shape[0] + + # Move metric to correct device + self._lpips_metric = self._lpips_metric.to(original.device) + + # Reset metric state before computing to avoid accumulation from previous calls + self._lpips_metric.reset() + + # Compute LPIPS for each sample individually + lpips_sum = 0.0 + for i in range(batch_size): + orig_i = original_lpips[i : i + 1] + recon_i = reconstructed_lpips[i : i + 1] + # Update with single sample + self._lpips_metric.update(recon_i, orig_i) + # Compute returns the value for accumulated samples (just 1 here) + lpips_val = self._lpips_metric.compute() + lpips_sum += lpips_val.item() + # Reset for next sample to avoid accumulation + self._lpips_metric.reset() + + # Return sum and count for proper distributed averaging + return {"sum": lpips_sum, "count": batch_size} + + +def calculate_psnr( + original: torch.Tensor | list[torch.Tensor], + reconstructed: torch.Tensor | list[torch.Tensor], +) -> torch.Tensor: + """Calculate PSNR between two tensors or lists of tensors. + + This is a standalone function for use in evaluation and training logging. + Expects inputs already in [0, 1] range. Converts to uint8 [0, 255] internally. + + Supports multiple input formats: + - Lists of tensors (variable-size images from sparse_to_img_list) + - 5D tensors (B, T, C, H, W) for video + - 4D tensors (B, C, H, W) for batched images + - 3D tensors (C, H, W) for single images + + Args: + original: Original image(s) in [0, 1] range. Can be tensor or list of tensors. + reconstructed: Reconstructed image(s) in [0, 1] range. Must match original format. + + Returns: + PSNR value as a tensor (scalar, for distributed gathering). + """ + # Handle lists of tensors (from sparse_to_img_list) + if isinstance(original, list) and isinstance(reconstructed, list): + if len(original) != len(reconstructed): + raise ValueError(f"Image lists must have the same length. Got {len(original)} and {len(reconstructed)}") + + psnr_values = [] + for orig, rec in zip(original, reconstructed): + psnr_values.append(calculate_psnr(orig, rec)) + + # Average PSNR across all images + return sum(psnr_values) / len(psnr_values) + + # At this point, both should be tensors + if original.shape != reconstructed.shape: + raise ValueError(f"Images must have the same shape. Got {original.shape} and {reconstructed.shape}") + + # Handle 3D tensor (C, H, W) - add batch dimension + if original.dim() == 3: + original = original.unsqueeze(0) + reconstructed = reconstructed.unsqueeze(0) + + # Handle 5D tensor (B, T, C, H, W) - flatten batch and time + if original.dim() == 5: + b, t = original.shape[:2] + original = original.reshape(b * t, *original.shape[2:]) + reconstructed = reconstructed.reshape(b * t, *reconstructed.shape[2:]) + + # Now we have 4D tensors (B, C, H, W) + # Convert to uint8 [0, 255] range + original_uint8 = (original.detach().clamp(0, 1) * 255).byte() + reconstructed_uint8 = (reconstructed.detach().clamp(0, 1) * 255).byte() + + # Compute MSE per sample on uint8 values + mse = torch.mean((original_uint8.float() - reconstructed_uint8.float()) ** 2, dim=[1, 2, 3]) + + # Handle zero MSE (identical images) - cap at 100 dB + max_psnr = 100.0 + mse = torch.where( + mse == 0, + torch.tensor(10.0 ** (-max_psnr / 10.0) * 255.0 * 255.0, device=mse.device, dtype=mse.dtype), + mse, + ) + + # Compute PSNR with max_val=255 + psnr = 20 * torch.log10(torch.tensor(255.0, device=mse.device, dtype=mse.dtype)) - 10 * torch.log10(mse) + + # Return mean PSNR + return psnr.mean() + + +class Rank0FIDMetric(nn.Module): + """FID metric that runs only on rank 0 to avoid distributed sync issues. + + Uses torchmetrics FrechetInceptionDistance internally but only computes + on rank 0's data to avoid NCCL collective operation mismatches caused by + torchmetrics/torch-fidelity's internal distributed synchronization. + + Note: FID is computed only on rank 0's portion of the data (1/world_size), + which may be less representative than full dataset FID, but avoids + distributed synchronization issues. + + Usage: + fid = Rank0FIDMetric(rank=rank).to(device) + + # During evaluation loop (only rank 0 updates) + for batch in dataloader: + fid.update(real_images, fake_images) + + # Compute FID (only rank 0 has valid result) + if rank == 0: + fid_value = fid.compute() + """ + + def __init__(self, rank: int = 0, feature_dim: int = 2048) -> None: + super().__init__() + self.rank = rank + self.feature_dim = feature_dim + self._fid_metric = None + + # Only initialize FID metric on rank 0 + if self.rank == 0: + try: + from torchmetrics.image.fid import FrechetInceptionDistance + + # normalize=True means input is [0, 1] float, not uint8 + self._fid_metric = FrechetInceptionDistance( + feature=feature_dim, + normalize=True, + sync_on_compute=False, + dist_sync_on_step=False, + ) + except ImportError: + pass + + @torch.no_grad() + def update(self, real_images: torch.Tensor, fake_images: torch.Tensor) -> None: + """Update FID statistics with a batch of real and fake images. + + Only updates on rank 0. + + Args: + real_images: Real images in [0, 1] range, shape (B, C, H, W) or (B, T, C, H, W) + fake_images: Fake/reconstructed images in [0, 1] range + """ + if self.rank != 0 or self._fid_metric is None: + return + + # Handle video format by flattening batch and time dimensions + if real_images.dim() == 5: # (B, T, C, H, W) + real_images = real_images.reshape(-1, *real_images.shape[2:]) + fake_images = fake_images.reshape(-1, *fake_images.shape[2:]) + + # Move metric to same device as images + device = real_images.device + self._fid_metric = self._fid_metric.to(device) + + # torchmetrics FID update + self._fid_metric.update(real_images, real=True) + self._fid_metric.update(fake_images, real=False) + + def compute(self) -> torch.Tensor: + """Compute FID from accumulated statistics. + + Only valid on rank 0. + + Returns: + FID value as a scalar tensor (inf if not rank 0 or metric unavailable) + """ + if self.rank != 0 or self._fid_metric is None: + return torch.tensor(float("inf")) + + return self._fid_metric.compute() + + def reset(self) -> None: + """Reset accumulated statistics.""" + if self._fid_metric is not None: + self._fid_metric.reset() + + +__all__ = [ + "TokenizerMetric", + "PSNRMetric", + "SSIMMetric", + "LPIPSMetric", + "Rank0FIDMetric", + "calculate_psnr", +] From 4284d7ffa6e0263bb0d4041e8683dd3c2f6a6510 Mon Sep 17 00:00:00 2001 From: yangyangt Date: Mon, 8 Jun 2026 05:01:43 -0700 Subject: [PATCH 11/11] Release: resolve dangling imports (ship processors/metric/video_preprocess, drop broken tests) Ship nemotron3densevl/nemotronvl processors, tokenizer evaluation/metric, and vfm/video_preprocess to satisfy previously-dangling imports. Drop test files that import unshipped helpers (helper_test/unittest_utils, scripts.eval). CF now has no dangling cosmos_framework module imports. Co-Authored-By: Claude Opus 4.7 --- .../processors/nemotron3densevl_processor.py | 249 ++++++++ .../vlm/processors/nemotronvl_processor.py | 553 ++++++++++++++++++ .../model/tokenizer/evaluation/metric.py | 433 ++++++++++++++ .../model/vfm/tokenizers/audio/avae_test.py | 229 -------- .../dc_ae/cosmos_ae_4x32x32_compile_test.py | 281 --------- .../tokenizers/dc_ae/dc_ae_4x32x32_test.py | 220 ------- .../uniae/noncausal_4x16x16_test.py | 284 --------- cosmos_framework/scripts/eval_utils_test.py | 386 ------------ .../utils/easy_io/easy_io_test.py | 70 --- .../utils/vfm/video_preprocess.py | 32 + 10 files changed, 1267 insertions(+), 1470 deletions(-) create mode 100644 cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py create mode 100644 cosmos_framework/data/vlm/processors/nemotronvl_processor.py create mode 100644 cosmos_framework/model/tokenizer/evaluation/metric.py delete mode 100644 cosmos_framework/model/vfm/tokenizers/audio/avae_test.py delete mode 100644 cosmos_framework/model/vfm/tokenizers/dc_ae/cosmos_ae_4x32x32_compile_test.py delete mode 100644 cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32_test.py delete mode 100644 cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16_test.py delete mode 100644 cosmos_framework/scripts/eval_utils_test.py delete mode 100644 cosmos_framework/utils/easy_io/easy_io_test.py create mode 100644 cosmos_framework/utils/vfm/video_preprocess.py diff --git a/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py b/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py new file mode 100644 index 0000000..08c5d37 --- /dev/null +++ b/cosmos_framework/data/vlm/processors/nemotron3densevl_processor.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from qwen_vl_utils.vision_process import smart_resize +from transformers.models.auto.processing_auto import AutoProcessor + +from cosmos_framework.utils import log +from cosmos_framework.utils.vlm.pretrained_models_downloader import maybe_download_hf_model_from_s3 + + +def convert_string_content_to_list_content(messages: List[Dict]) -> List[Dict]: + """ + Convert the string content to a list of dicts. + """ + for message_id, message in enumerate(messages): + if isinstance(message["content"], str): + messages[message_id]["content"] = [{"type": "text", "text": message["content"]}] + return messages + + +def maybe_parse_video_content( + messages: List[Dict], +) -> tuple[int, Optional[list[float]], Optional[list[int]], Optional[list[list[int]]]]: + """ + Convert the string content to a list of dicts. + """ + num_video = 0 + video_fps = [] + video_total_num_frames = [] + video_frames_indices = [] + for message_id, message in enumerate(messages): + if isinstance(message["content"], list): + for sub_content in message["content"]: + if sub_content.get("type", "") == "video" and isinstance(sub_content["video"], list): + num_video += 1 + fps = sub_content.get("fps", None) + if fps is None: + log.critical( + f"fps is None for video {sub_content}. Better to set the fps explicitly", rank0_only=False + ) + video_fps.append(fps) + video_total_num_frames.append(len(sub_content["video"])) + video_frames_indices.append(list(range(video_total_num_frames[-1]))) + return num_video, video_fps, video_total_num_frames, video_frames_indices + + +class Nemotron3DenseVLProcessor: + # This is a wrapper around the AutoProcessor class to add some helper functions + def __init__( + self, + name="Qwen/Qwen3-VL-2B-Init", + credentials: str = "./credentials/s3_training.secret", + bucket: str = "bucket4", + cache_dir: str = None, + ): + self.name = name + if os.path.isdir(name): + model_name_or_path_local = name + else: + model_name_or_path_local = maybe_download_hf_model_from_s3( + name, credentials, bucket, include_model_weights=False + ) + + self.processor = AutoProcessor.from_pretrained(model_name_or_path_local, trust_remote_code=True) + log.info("Successfully loaded processor from local cache") + + if hasattr(self.processor, "image_token"): + self.image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token) + else: + self.image_token_id = None + if hasattr(self.processor, "video_token"): + self.video_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.processor.video_token) + else: + self.video_token_id = None + self.eos_id = self.processor.tokenizer.eos_token_id + self.pad_id = self.processor.tokenizer.pad_token_id + self.vision_end_id = self.processor.tokenizer.convert_tokens_to_ids("") + + # Helper attributes for the dataloader video decoding function + self.shortest_edge = self.processor.image_processor.size["shortest_edge"] + self.min_height_width = int(np.sqrt(self.shortest_edge)) + self.patch_size = self.processor.video_processor.patch_size + self.temporal_patch_size = self.processor.video_processor.temporal_patch_size + self.merge_size = self.processor.video_processor.merge_size + self.use_smart_resize = True + if self.pad_id is None: + self.pad_id = self.eos_id + + def apply_chat_template( + self, + messages, + add_generation_prompt=False, + return_tensors="pt", + tokenize=True, + **kwargs, + ): + """ + Return: + inputs: dict + input_ids: torch.Tensor, shape: (N_token) + attention_mask: torch.Tensor, shape: (N_token) + texts: str, the raw text + image_sizes: torch.Tensor, shape (N_img, 2) + pixel_values: torch.Tensor, shape (N_img_patch, 3, 224, 224) + """ + + # messages = [msg for msg in messages if msg.get("role") != "system"] + assert tokenize, "tokenize must be True" + assert return_tensors == "pt", "return_tensors must be pt" + # Note: this tokenizer does not support "content": str, it always expect "content" entry to be a list of dicts + messages = convert_string_content_to_list_content(messages) + kwargs = {} + for message_id, message in enumerate(messages): + if isinstance(message["content"], list): + for sub_content in message["content"]: + if sub_content.get("type", "") == "image": + image = sub_content["image"] + max_pixels = sub_content.get("max_pixels", self.processor.image_processor.size["longest_edge"]) + min_pixels = sub_content.get("min_pixels", self.processor.image_processor.size["shortest_edge"]) + assert isinstance(image, Image.Image), ( + "image must be a url string for now, not support list of images for one content" + ) + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=32, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + sub_content["image"] = image + + num_video, video_fps, video_total_num_frames, video_frames_indices = maybe_parse_video_content(messages) + if num_video > 0: + # Here we add the args to avoid the error: + # File "/usr/local/lib/python3.12/dist-packages/transformers/video_processing_utils.py", line 321, in _decode_and_sample_videos + # raise ValueError( + # ValueError: Sampling frames from a list of images is not supported! Set `do_sample_frames=False`. + # kwargs["videos_kwargs"] = dict(do_sample_frames=False) + assert num_video == 1, "only support one video for now" + fps = video_fps[0] + total_num_frames = video_total_num_frames[0] + frames_indices = video_frames_indices[0] + kwargs.update( + { + "do_sample_frames": False, + "video_metadata": dict(fps=fps, total_num_frames=total_num_frames, frames_indices=frames_indices), + } + ) + + inputs = self.processor.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + return_dict=True, + return_tensors=return_tensors, + # padding="max_length", + # max_length=16000, + # truncation=False, + **kwargs, + ) + + # Convert batch features into single features + # By default, the processor returns a batch of features, but we use processor in dataloader, so we need to convert it to single features + inputs["input_ids"] = inputs["input_ids"][0] # [N_token] + inputs["attention_mask"] = inputs["attention_mask"][0] # [N_token] + num_image_tokens = inputs["input_ids"] == self.image_token_id # [N_token] + num_video_tokens = inputs["input_ids"] == self.video_token_id # [N_token] + return inputs + + def add_assistant_tokens_mask(self, tokens): + """ + Add a mask to the assistant tokens. + This is used to mask out tokens that are not generated by the assistant (e.g., system prompts, user prompts, chat templates), such that in the loss computation, only the tokens generated by the assistant are used. + If there are multiple turns in the conversation, the mask will mask all the assistant tokens in each turn. + + Args: + tokens (Union[List[int], torch.Tensor]): The tokens to add the mask to. + Returns: + Union[List[bool], torch.Tensor]: The mask. True for tokens generated by the assistant (i.e. should apply loss on), False for tokens not generated by the assistant. + """ + if isinstance(tokens, torch.Tensor) and tokens.ndim == 2: + mask = torch.stack( + [self.add_assistant_tokens_mask(tokens[i]) for i in range(tokens.shape[0])] + ) # [B,N_token] + assert mask.shape == tokens.shape + return mask + np_tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.array(tokens) + assert np_tokens.ndim == 1 + + # Constants defining bos, eos and fixed offsets. + BOS_TOKEN = "<|im_start|>" + EOS_TOKEN = "<|im_end|>" + ROLE = "assistant" + # Offsets: skip the bos + "assistant\n" (always 3 tokens) and include the eos (+1) for supervision + START_OFFSET = 3 + END_OFFSET = 1 + + # Retrieve token IDs for the markers and the role. + bos_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOS_TOKEN) + eos_token_id = self.processor.tokenizer.convert_tokens_to_ids(EOS_TOKEN) + role_id = self.processor.tokenizer.convert_tokens_to_ids(ROLE) + role_ids = self.processor.tokenizer.encode( + ROLE, add_special_tokens=False + ) # In case the role_id corresponds to multiple tokens, decode it back to string for accurate comparison + think_start_id = self.processor.tokenizer.convert_tokens_to_ids("") + think_end_id = self.processor.tokenizer.convert_tokens_to_ids("") + + # Locate all positions where the start and end markers appear. + start_indices = np.where(np_tokens == bos_token_id)[0] + end_indices = np.where(np_tokens == eos_token_id)[0] + + # Initialize the mask with False values. + masks = np.zeros_like(np_tokens, dtype=bool) + assert len(start_indices) == len(end_indices) + # For each pair of bos/eos, check if the role is 'assistant' + # and apply the mask accordingly. + for start, end in zip(start_indices, end_indices): + end_pos = None + if np_tokens[start + 1] == role_id: + # Mask tokens from after the assistant header (start+3) to include the end marker (end+1) + masks[start + START_OFFSET : end + END_OFFSET] = True + end_pos = start + START_OFFSET + elif all(np_tokens[start + 1 : start + 1 + len(role_ids)] == role_ids): + masks[start + START_OFFSET + len(role_ids) - 1 : end + END_OFFSET] = True + end_pos = start + START_OFFSET + len(role_ids) - 1 + if end_pos is not None and np_tokens[end_pos] == think_start_id: + masks[end_pos] = False + if np_tokens[end_pos + 1] == think_end_id: + masks[end_pos + 1] = False + + assert masks.shape == np_tokens.shape + if isinstance(tokens, torch.Tensor): + return torch.from_numpy(masks) + else: + return masks.tolist() + + def encode(self, *args, **kwargs): + return self.processor.encode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.processor.decode(*args, **kwargs) diff --git a/cosmos_framework/data/vlm/processors/nemotronvl_processor.py b/cosmos_framework/data/vlm/processors/nemotronvl_processor.py new file mode 100644 index 0000000..2edb222 --- /dev/null +++ b/cosmos_framework/data/vlm/processors/nemotronvl_processor.py @@ -0,0 +1,553 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +import os +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from transformers.models.auto.processing_auto import AutoProcessor +from transformers.processing_utils import VideosKwargs +from transformers.video_utils import VideoMetadata + +from cosmos_framework.utils import log +from cosmos_framework.utils.vlm.pretrained_models_downloader import maybe_download_hf_model_from_s3 + +nemotron_chat_template = """ +{%- set ns = namespace(enable_thinking=false, has_sys_prompt=false, non_tool_system_content='', has_video=false, explicit_think_requested=false) -%} +{%- set msg = namespace(content='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.has_sys_prompt = true -%} + {# Extract system content without tool flags #} + {%- if message['content'] is string -%} + {%- set ns.non_tool_system_content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set ns.non_tool_system_content = '' -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content + content['text'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '') -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.non_tool_system_content = ns.non_tool_system_content.strip() -%} + {%- endif -%} + {%- endif -%} + {# Check for video content in all messages #} + {%- if message['content'] is not string -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'video' or content['type'] == 'video_url' -%} + {%- set ns.has_video = true -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- if message['content'] is string -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in message['content'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in message['content'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- else -%} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in content['text'].replace('', '') -%} + {%- set ns.enable_thinking = true -%} + {%- set ns.explicit_think_requested = true -%} + {%- elif '/no_think' in content['text'] -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endfor -%} + +{{- bos_token -}} +{%- if messages[0]['role'] != 'system' -%} + {{- 'System\n' -}} +{%- else -%} + {{- 'System\n' + ns.non_tool_system_content }} +{%- endif -%} + +{%- if tools -%} + {%- if ns.non_tool_system_content != '' -%} + {{- '\n\n' -}} + {%- endif -%} + {{- 'You can use the following tools to assist the user if required:\n' -}} + {{- '[' -}} + {%- for tool in tools -%} + {{- (tool.function if tool.function is defined else tool) | tojson -}} + {{- ', ' if not loop.last else '' -}} + {%- endfor -%} + {{- ']\n\n' -}} + + {{- 'If you decide to call any tool(s), use the following format:\n' -}} + {{- '[{"name": "tool_name1", "arguments": "tool_args1"}, ' -}} + {{- '{"name": "tool_name2", "arguments": "tool_args2"}]\n\n' -}} + + {{- 'The user will execute tool-calls and return responses from tool(s) in this format:\n' -}} + {{- '[{"response": "tool_response1"}, ' -}} + {{- '{"response": "tool_response2"}]\n\n' -}} + + {{- 'Based on the tool responses, you can call additional tools if needed, ' -}} + {{- 'correct tool calls if any errors are found, or just respond to the user.' -}} +{%- endif -%} +{{- '\n' -}} + +{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%} + +{# Prevent no user or assistant message #} +{%- if messages|length == 0 -%} + {%- set messages = [{'role': 'user', 'content': ''}] -%} +{%- endif -%} + +{%- for message in messages %} + {%- if message['content'] is string -%} + {%- set msg.content = message['content'].replace('', '<_end_think>').replace('/think', '').replace('/no_think', '').replace('<_end_think>', '').strip() -%} + {%- else -%} + {%- set msg.content = '' -%} + {%- set mm_content = '' -%} + {%- set counters = namespace(images=0, videos=0) -%} + + {%- for content in message['content'] -%} + {%- if content['type'] == 'image' -%} + {%- set counters.images = counters.images + 1 -%} + {%- elif content['type'] == 'video' -%} + {%- set counters.videos = counters.videos + 1 -%} + {%- elif content['type'] == 'text' -%} + {%- set msg.content = msg.content + content['text'] -%} + {%- endif -%} + {%- endfor -%} + {%- if '' in msg.content -%} + {%- set counters.images = 0 -%} + {%- endif -%} + {%- if '