Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cosmos_framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

6 changes: 2 additions & 4 deletions cosmos_framework/auxiliary/guardrail/common/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from cosmos_framework.auxiliary.guardrail.common.core import GuardrailRunner
from cosmos_framework.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter
from cosmos_framework.auxiliary.guardrail.qwen3guard.qwen3guard import Qwen3Guard
from cosmos_framework.auxiliary.guardrail.video_content_safety_filter.video_content_safety_filter import (
VideoContentSafetyFilter,
)
from cosmos_framework.utils import log


Expand All @@ -27,7 +24,8 @@ def create_video_guardrail_runner(offload_model_to_cpu: bool = False) -> Guardra
"""Create the video guardrail runner."""
return GuardrailRunner(
safety_models=[
# VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu), # Too many false positives
# VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu)
# Too many false positives, add back when fixed
],
postprocessors=[RetinaFaceFilter(offload_model_to_cpu=offload_model_to_cpu)],
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Copyright (c) 2019
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

15 changes: 14 additions & 1 deletion cosmos_framework/callbacks/compile_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Training callback that defers AOT compilation of the VAE tokenizer.

The actual compilation logic lives in
:meth:`~projects.cosmos3.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`.
:meth:`~cosmos_framework.model.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`.
This module provides a :class:`CompileTokenizer` callback that invokes it
at the right point during training (after ``compile_after_iterations``
steps, to avoid NCCL timeouts during CUDA/cuDNN warm-up).
Expand All @@ -21,6 +21,7 @@
"""

from collections.abc import Sequence
from typing import Literal

import torch

Expand All @@ -43,6 +44,10 @@ def __init__(
enabled: bool = False,
compile_after_iterations: int = 3,
warmup_resolutions: Sequence[str] | None = None,
backend: Literal["cudagraphs", "inductor"] = "inductor",
mode: Literal["reduce-overhead", "max-autotune"] | None = "reduce-overhead",
fullgraph: bool = False,
dynamic: bool = False,
):
"""
Args:
Expand All @@ -60,6 +65,10 @@ def __init__(
self.compile_after_iterations: int = compile_after_iterations
self.skip_counter: int = 0
self.warmup_resolutions: Sequence[str] | None = warmup_resolutions
self.backend: Literal["cudagraphs", "inductor"] = backend
self.mode: Literal["reduce-overhead", "max-autotune"] | None = mode
self.fullgraph: bool = fullgraph
self.dynamic: bool = dynamic

if self.enabled:
if self.warmup_resolutions is None:
Expand Down Expand Up @@ -101,6 +110,10 @@ def on_training_step_start(
tokenizer.compile_encode(
self.warmup_resolutions,
output_dir=self.config.job.path_local,
backend=self.backend,
mode=self.mode,
fullgraph=self.fullgraph,
dynamic=self.dynamic,
)

self.skip_counter += 1
1 change: 0 additions & 1 deletion cosmos_framework/callbacks/data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def on_training_step_end(

# Handle case where dataset_name gets batched into a list
if isinstance(dataset_name, list):

assert len(dataset_name) == 1, "dataset_name should be a list of 1"
dataset_name = dataset_name[0]

Expand Down
18 changes: 6 additions & 12 deletions cosmos_framework/callbacks/dataloader_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,14 @@ class DataLoaderStateCallback(Callback):
def __init__(
self,
distributor_type: str | None = None,
name: str = "",
) -> None:
super().__init__()
self.distributor_type = distributor_type
self.name = name
self.config: Any = None
self.state: dict[int, NoReplaceShardlistState] = {}
self.verbose = True

def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None:
if "sample_worker_id" not in data_batch:
return # batch has no position metadata (shuffle=False or iterable data_source)
worker_ids = data_batch["sample_worker_id"].tolist() # [B]
epochs = data_batch["sample_epoch"].tolist() # [B]
indices = data_batch["sample_index"].tolist() # [B]
Expand All @@ -50,8 +46,6 @@ def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None:
):
self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index)

_ACTIVE_DISTRIBUTOR_TYPES = ("no_replace",)

def on_training_step_batch_end(
self,
model: ImaginaireModel,
Expand All @@ -60,7 +54,7 @@ def on_training_step_batch_end(
loss: torch.Tensor,
iteration: int = 0,
) -> None:
if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type == "no_replace":
self._update_state_from_batch(data_batch)

def on_training_step_end(
Expand All @@ -71,7 +65,7 @@ def on_training_step_end(
loss: torch.Tensor,
iteration: int = 0,
) -> None:
if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type == "no_replace":
if self.verbose:
if iteration % self.config.trainer.logging_iter == 0:
msg = "\n"
Expand All @@ -80,10 +74,10 @@ def on_training_step_end(
log.info(msg)

def has_checkpoint_state(self) -> bool:
return self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES
return self.distributor_type == "no_replace"

def state_dict(self) -> dict[int, dict[str, int]]:
if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type != "no_replace":
return {}

state_dict: dict[int, dict[str, int]] = {}
Expand All @@ -96,7 +90,7 @@ def state_dict(self) -> dict[int, dict[str, int]]:
return state_dict

def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None:
if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type != "no_replace":
return

if not state_dict:
Expand All @@ -110,4 +104,4 @@ def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None:
self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index)
os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch)
os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index)
log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}")
log.info(f"Loaded no replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}")
4 changes: 0 additions & 4 deletions cosmos_framework/callbacks/every_n_draw_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration):
tag = "ema" if self.is_ema else "reg"

log.debug("starting data and condition model", rank0_only=False)


data_clean = model.get_data_and_condition(data_batch)
raw_data = data_clean.raw_state_vision
x0 = data_clean.x0_tokens_vision
Expand Down Expand Up @@ -185,7 +183,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration):
log.debug(f"done denoising {sigma}", rank0_only=False)
mse_loss = distributed.dist_reduce_tensor(F.mse_loss(sample, x0))
mse_loss_list.append(mse_loss)

if hasattr(model, "decode"):
sample = model.decode(sample)
to_show.append(sample.float().cpu())
Expand Down Expand Up @@ -316,7 +313,6 @@ def sample(self, trainer, model, data_batch, output_batch, loss, iteration):
for sample_idx in range(data_clean.batch_size):
n_vis = num_items[sample_idx]
# First item(s) are condition, last item is generation target

# but we need to support multiple conditions per sample in the future. Current code
# can handle this without throwing an error.
condition_images.append(raw_data[vis_offset]) # source image (1, C, 1, H, W)
Expand Down
2 changes: 1 addition & 1 deletion cosmos_framework/callbacks/grad_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _clip_grad(
# `torch.distributed._tensor.ops.math_ops._NormPartial`.
# We can simply reduce the DTensor to get the total norm in this
# tensor's process group and then convert it to a local tensor.

# NOTE: It has two purposes:
# 1. to make sure the total norm is computed correctly when PP is used (see below)
# 2. to return a reduced mesh_norm tensor whose .item() would return the correct value
if isinstance(mesh_norm, DTensor):
Expand Down
5 changes: 3 additions & 2 deletions cosmos_framework/callbacks/hf_export.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""HFExportCallback: export VLM DCP checkpoints to HuggingFace safetensors format.

Design notes
Expand Down Expand Up @@ -137,11 +138,11 @@ def on_save_checkpoint(self, model: Any, state_dict: dict[str, Any]) -> None:
if not isinstance(model, VLMModel):
# The legacy vlm/train.py path passes model_parts: list[nn.Module] (raw HF
# models without the VLMModel attribute structure). HF export requires the
# VLMModel wrapper, which is only available via the unified cosmos_framework/scripts/train.py path.
# VLMModel wrapper, which is only available via the unified scripts/train.py path.
if isinstance(model, list):
log.warning(
"[HFExportCallback] Received model_parts (list) instead of VLMModel. "
"HF export requires the unified training path (cosmos_framework/scripts/train.py). Skipping."
"HF export requires the unified training path (scripts/train.py). Skipping."
)
else:
log.warning(
Expand Down
1 change: 0 additions & 1 deletion cosmos_framework/callbacks/mfu.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def _ensure_initialised(self, model: ImaginaireModel) -> None:
ac_cfg = getattr(model_cfg, "activation_checkpointing", None)
ac_mode = getattr(ac_cfg, "mode", "none")


# Some activations don't need to be recomputed under selective AC, so
# we need to remove them from the FLOP computation.
self._use_activation_checkpointing = ac_mode != "none"
Expand Down
1 change: 0 additions & 1 deletion cosmos_framework/callbacks/wandb_log_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def on_validation_step_end(

# Handle case where dataset_name gets batched into a list
if isinstance(dataset_name, list):

assert len(dataset_name) == 1, "dataset_name should be a list of 1"
dataset_name = dataset_name[0]

Expand Down
34 changes: 10 additions & 24 deletions cosmos_framework/checkpoint/dcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
set_model_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.nn.modules.module import _IncompatibleKeys

from cosmos_framework.checkpoint.base import AbstractCheckpointer
from cosmos_framework.checkpoint.s3_filesystem import S3StorageReader, S3StorageWriter
Expand All @@ -85,11 +86,11 @@ def __init__(self, model: nn.Module) -> None:
def state_dict(self) -> dict[str, Any]:
return get_model_state_dict(self.model)

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
set_model_state_dict(
def load_state_dict(self, state_dict: dict[str, Any]) -> _IncompatibleKeys:
return set_model_state_dict(
self.model,
model_state_dict=state_dict,
options=StateDictOptions(strict=True),
options=StateDictOptions(strict=False),
)


Expand Down Expand Up @@ -539,28 +540,13 @@ def load(
"Ensure the model has net_ema submodule."
)
_state_dict[sd_key] = _state_dict[key_ema]
elif warm_start and any(str(s).startswith("net_ema") for s in self.keys_to_skip_loading):
# Only when net_ema.* is explicitly skipped on load (e.g. an HF->DCP
# init from convert_model_to_dcp that has only net.*): the skipped
# net_ema.* keep build_net() construction values (random init when
# vlm_config.pretrained_weights.enabled=False), which would seed EMA
# from random weights -> copy net.* -> net_ema.* so EMA starts from the
# freshly-loaded init. When net_ema.* IS loaded (e.g. a training DCP
# that carries a trained EMA), do NOT clobber it.
log.info("Warm start: net_ema. skipped on load -> resetting net_ema = net.")
for sd_key in list(_state_dict.keys()):
if sd_key.startswith("net."):
key_ema = "net_ema." + sd_key.removeprefix("net.")
if key_ema in _state_dict:
_state_dict[key_ema] = _state_dict[sd_key]
results = _model_wrapper.load_state_dict(_state_dict)
if results is not None:
if len(results.missing_keys) > 0:
raise ValueError(f"Missing keys (not found in checkpoint): {results.missing_keys}")
if len(results.unexpected_keys) > 0:
raise ValueError(
f"Unexpected keys (found in checkpoint but not in model): {results.unexpected_keys}"
)
if len(results.missing_keys) > 0:
raise ValueError(f"Missing keys (not found in checkpoint): {results.missing_keys}")
if len(results.unexpected_keys) > 0:
raise ValueError(
f"Unexpected keys (found in checkpoint but not in model): {results.unexpected_keys}"
)

elif key == "optim":
log.info("- Loading the optimizer...")
Expand Down
Loading