Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cosmos_framework/auxiliary/guardrail/common/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def create_text_guardrail_runner(offload_model_to_cpu: bool = False) -> Guardrai
def create_video_guardrail_runner(offload_model_to_cpu: bool = False) -> GuardrailRunner:
"""Create the video guardrail runner."""
return GuardrailRunner(
safety_models=[
# VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu), # Too many false positives
],
safety_models=[VideoContentSafetyFilter(offload_model_to_cpu=offload_model_to_cpu)],
postprocessors=[RetinaFaceFilter(offload_model_to_cpu=offload_model_to_cpu)],
)

Expand Down
15 changes: 14 additions & 1 deletion cosmos_framework/callbacks/compile_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Training callback that defers AOT compilation of the VAE tokenizer.

The actual compilation logic lives in
:meth:`~projects.cosmos3.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`.
:meth:`~cosmos_framework.model.vfm.tokenizers.wan2pt2_vae_4x16x16.Wan2pt2VAEInterface.compile_encode`.
This module provides a :class:`CompileTokenizer` callback that invokes it
at the right point during training (after ``compile_after_iterations``
steps, to avoid NCCL timeouts during CUDA/cuDNN warm-up).
Expand All @@ -21,6 +21,7 @@
"""

from collections.abc import Sequence
from typing import Literal

import torch

Expand All @@ -43,6 +44,10 @@ def __init__(
enabled: bool = False,
compile_after_iterations: int = 3,
warmup_resolutions: Sequence[str] | None = None,
backend: Literal["cudagraphs", "inductor"] = "inductor",
mode: Literal["reduce-overhead", "max-autotune"] | None = "reduce-overhead",
fullgraph: bool = False,
dynamic: bool = False,
):
"""
Args:
Expand All @@ -60,6 +65,10 @@ def __init__(
self.compile_after_iterations: int = compile_after_iterations
self.skip_counter: int = 0
self.warmup_resolutions: Sequence[str] | None = warmup_resolutions
self.backend: Literal["cudagraphs", "inductor"] = backend
self.mode: Literal["reduce-overhead", "max-autotune"] | None = mode
self.fullgraph: bool = fullgraph
self.dynamic: bool = dynamic

if self.enabled:
if self.warmup_resolutions is None:
Expand Down Expand Up @@ -101,6 +110,10 @@ def on_training_step_start(
tokenizer.compile_encode(
self.warmup_resolutions,
output_dir=self.config.job.path_local,
backend=self.backend,
mode=self.mode,
fullgraph=self.fullgraph,
dynamic=self.dynamic,
)

self.skip_counter += 1
1 change: 0 additions & 1 deletion cosmos_framework/callbacks/data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def on_training_step_end(

# Handle case where dataset_name gets batched into a list
if isinstance(dataset_name, list):

assert len(dataset_name) == 1, "dataset_name should be a list of 1"
dataset_name = dataset_name[0]

Expand Down
126 changes: 8 additions & 118 deletions cosmos_framework/callbacks/dataloader_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,14 @@ class DataLoaderStateCallback(Callback):
def __init__(
self,
distributor_type: str | None = None,
name: str = "",
) -> None:
super().__init__()
self.distributor_type = distributor_type
self.name = name
self.config: Any = None
self.state: dict[int, NoReplaceShardlistState] = {}
self.verbose = True

def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None:
if "sample_worker_id" not in data_batch:
return # batch has no position metadata (shuffle=False or iterable data_source)
worker_ids = data_batch["sample_worker_id"].tolist() # [B]
epochs = data_batch["sample_epoch"].tolist() # [B]
indices = data_batch["sample_index"].tolist() # [B]
Expand All @@ -50,8 +46,6 @@ def _update_state_from_batch(self, data_batch: dict[str, torch.Tensor]) -> None:
):
self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index)

_ACTIVE_DISTRIBUTOR_TYPES = ("no_replace", "data_packer")

def on_training_step_batch_end(
self,
model: ImaginaireModel,
Expand All @@ -60,7 +54,7 @@ def on_training_step_batch_end(
loss: torch.Tensor,
iteration: int = 0,
) -> None:
if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type == "no_replace":
self._update_state_from_batch(data_batch)

def on_training_step_end(
Expand All @@ -71,7 +65,7 @@ def on_training_step_end(
loss: torch.Tensor,
iteration: int = 0,
) -> None:
if self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type == "no_replace":
if self.verbose:
if iteration % self.config.trainer.logging_iter == 0:
msg = "\n"
Expand All @@ -80,10 +74,10 @@ def on_training_step_end(
log.info(msg)

def has_checkpoint_state(self) -> bool:
return self.distributor_type in self._ACTIVE_DISTRIBUTOR_TYPES
return self.distributor_type == "no_replace"

def state_dict(self) -> dict[int, dict[str, int]]:
if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type != "no_replace":
return {}

state_dict: dict[int, dict[str, int]] = {}
Expand All @@ -96,122 +90,18 @@ def state_dict(self) -> dict[int, dict[str, int]]:
return state_dict

def load_state_dict(self, state_dict: dict[int, dict[str, int]]) -> None:
if self.distributor_type not in self._ACTIVE_DISTRIBUTOR_TYPES:
if self.distributor_type != "no_replace":
return

if not state_dict:
log.info("No dataloader state found in checkpoint")
return

self.state = {}
# Build env var prefix. For data_packer, namespacing avoids conflicts
# when multiple DataPackerDataLoader instances share the same process
# (e.g. inside JointDataPackerDataLoader). name="" → original format.
_dp_pfx = f"DP_STATE_{self.name}_" if self.name else "DP_STATE_"
for worker_id, per_worker_state in state_dict.items():
epoch = per_worker_state["epoch"]
index = per_worker_state["index"]
self.state[worker_id] = NoReplaceShardlistState(epoch=epoch, index=index)
if self.distributor_type == "data_packer":
os.environ[f"{_dp_pfx}WORKER_{worker_id}_EPOCH"] = str(epoch)
os.environ[f"{_dp_pfx}WORKER_{worker_id}_INDEX"] = str(index)
log.info(f"Loaded data_packer dataloader state for worker {worker_id}: epoch={epoch}, index={index}")
else:
os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch)
os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index)
log.info(f"Loaded no_replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}")


class JointDataLoaderStateCallback(Callback):
"""Checkpoint/resume state for ``JointDataPackerDataLoader``.

Manages two levels of state in a single DCP checkpoint entry
(``checkpoint_component = "dataloader"``):

1. **Outer** ``global_id`` — the number of batches the outer loader has
yielded. Restored via ``outer_loader.set_start_iteration(global_id)``
so the deterministic dataset-selection sequence resumes from the correct
step.

2. **Inner** per-dataset, per-worker ``(epoch, index)`` — one
``DataLoaderStateCallback`` per inner loader, keyed by the dataset name.
Each inner callback sets namespaced env vars on ``load_state_dict`` so
workers fast-forward to the saved sample position.

Usage in experiment configs::

joint_loader = JointDataPackerDataLoader(dataloaders={...}, seed=42)
exp["dataloader_train"] = joint_loader
exp["trainer"]["callbacks"]["dataloader_state"] = JointDataLoaderStateCallback(
outer_loader=joint_loader,
distributor_type="data_packer",
)

The ``checkpoint_component = "dataloader"`` class attribute ensures the DCP
checkpointer's ``_DataloaderWrapper`` discovers exactly this callback (it
picks the first matching callback). Do **not** also register standalone
``DataLoaderStateCallback`` instances for the inner loaders — this class
already handles them all.
"""

checkpoint_component: str = "dataloader"

def __init__(
self,
outer_loader: Any,
distributor_type: str = "data_packer",
) -> None:
super().__init__()
self._outer = outer_loader
self._inner: dict[str, DataLoaderStateCallback] = {
name: DataLoaderStateCallback(distributor_type=distributor_type, name=name)
for name in outer_loader._names
}
self.config: Any = None

def _update_state_from_batch(self, batch: dict) -> None:
name = batch.get("dataset_name")
if name in self._inner:
self._inner[name]._update_state_from_batch(batch)

def on_training_step_batch_end(
self,
model: Any,
data_batch: dict,
output_batch: dict,
loss: Any,
iteration: int = 0,
) -> None:
self._update_state_from_batch(data_batch)

def on_training_step_end(
self,
model: Any,
data_batch: dict,
output_batch: dict,
loss: Any,
iteration: int = 0,
) -> None:
if self.config and iteration % self.config.trainer.logging_iter == 0:
msg = f"\nJointDataPackerDataLoader global_id={self._outer._global_id}\n"
for name, cb in self._inner.items():
for wid, state in cb.state.items():
msg += f" [{name}] worker {wid}: epoch={state.epoch}, index={state.index}\n"
log.info(msg)

def has_checkpoint_state(self) -> bool:
return True

def state_dict(self) -> dict:
return {
"global_id": self._outer._global_id,
**{name: cb.state_dict() for name, cb in self._inner.items()},
}

def load_state_dict(self, state: dict) -> None:
global_id = state.get("global_id", 0)
self._outer.set_start_iteration(global_id)
log.info(f"JointDataLoaderStateCallback: resumed outer global_id={global_id}")
for name, cb in self._inner.items():
if name in state:
cb.load_state_dict(state[name])
os.environ[f"NSL_STATE_WORKER_{worker_id}_EPOCH"] = str(epoch)
os.environ[f"NSL_STATE_WORKER_{worker_id}_INDEX"] = str(index)
log.info(f"Loaded no replace dataloader state for worker {worker_id}: epoch={epoch}, index={index}")
4 changes: 0 additions & 4 deletions cosmos_framework/callbacks/every_n_draw_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration):
tag = "ema" if self.is_ema else "reg"

log.debug("starting data and condition model", rank0_only=False)


data_clean = model.get_data_and_condition(data_batch)
raw_data = data_clean.raw_state_vision
x0 = data_clean.x0_tokens_vision
Expand Down Expand Up @@ -185,7 +183,6 @@ def x0_pred(self, trainer, model, data_batch, output_batch, loss, iteration):
log.debug(f"done denoising {sigma}", rank0_only=False)
mse_loss = distributed.dist_reduce_tensor(F.mse_loss(sample, x0))
mse_loss_list.append(mse_loss)

if hasattr(model, "decode"):
sample = model.decode(sample)
to_show.append(sample.float().cpu())
Expand Down Expand Up @@ -316,7 +313,6 @@ def sample(self, trainer, model, data_batch, output_batch, loss, iteration):
for sample_idx in range(data_clean.batch_size):
n_vis = num_items[sample_idx]
# First item(s) are condition, last item is generation target

# but we need to support multiple conditions per sample in the future. Current code
# can handle this without throwing an error.
condition_images.append(raw_data[vis_offset]) # source image (1, C, 1, H, W)
Expand Down
2 changes: 1 addition & 1 deletion cosmos_framework/callbacks/grad_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _clip_grad(
# `torch.distributed._tensor.ops.math_ops._NormPartial`.
# We can simply reduce the DTensor to get the total norm in this
# tensor's process group and then convert it to a local tensor.

# NOTE: It has two purposes:
# 1. to make sure the total norm is computed correctly when PP is used (see below)
# 2. to return a reduced mesh_norm tensor whose .item() would return the correct value
if isinstance(mesh_norm, DTensor):
Expand Down
4 changes: 2 additions & 2 deletions cosmos_framework/callbacks/hf_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ def on_save_checkpoint(self, model: Any, state_dict: dict[str, Any]) -> None:
if not isinstance(model, VLMModel):
# The legacy vlm/train.py path passes model_parts: list[nn.Module] (raw HF
# models without the VLMModel attribute structure). HF export requires the
# VLMModel wrapper, which is only available via the unified cosmos_framework/scripts/train.py path.
# VLMModel wrapper, which is only available via the unified scripts/train.py path.
if isinstance(model, list):
log.warning(
"[HFExportCallback] Received model_parts (list) instead of VLMModel. "
"HF export requires the unified training path (cosmos_framework/scripts/train.py). Skipping."
"HF export requires the unified training path (scripts/train.py). Skipping."
)
else:
log.warning(
Expand Down
1 change: 0 additions & 1 deletion cosmos_framework/callbacks/mfu.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def _ensure_initialised(self, model: ImaginaireModel) -> None:
ac_cfg = getattr(model_cfg, "activation_checkpointing", None)
ac_mode = getattr(ac_cfg, "mode", "none")


# Some activations don't need to be recomputed under selective AC, so
# we need to remove them from the FLOP computation.
self._use_activation_checkpointing = ac_mode != "none"
Expand Down
1 change: 0 additions & 1 deletion cosmos_framework/callbacks/wandb_log_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def on_validation_step_end(

# Handle case where dataset_name gets batched into a list
if isinstance(dataset_name, list):

assert len(dataset_name) == 1, "dataset_name should be a list of 1"
dataset_name = dataset_name[0]

Expand Down
2 changes: 1 addition & 1 deletion cosmos_framework/checkpoint/s3_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def __init__(
"""
super().__init__(
path=path,
sync_files=False,
sync_files=False, # FIXME: setting this to True makes the run fail (L#333: `os.fsync(stream.fileno())`)
**kwargs,
)
self.fs = S3FileSystem(credential_path, enable_gcs_patch_in_boto3=enable_gcs_patch_in_boto3) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion cosmos_framework/configs/base/base_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_config_init_experiment_mot(experiment_name):
Parameterized test to verify config initialization for multiple experiments.
PYTHONPATH=. torchrun --nproc_per_node=8 -m pytest -s cosmos_framework/configs/base/config_test_mot.py --L1
"""
config_file = "configs/base/config.py"
config_file = "cosmos_framework/configs/base/config.py"
config_module = get_config_module(config_file)
config = importlib.import_module(config_module).make_config()
config = override(
Expand Down
1 change: 0 additions & 1 deletion cosmos_framework/configs/base/defaults/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from cosmos_framework.utils.lazy_config import LazyCall as L
from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback
from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer

from cosmos_framework.callbacks.device_monitor import DeviceMonitor
from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample
from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap
Expand Down
22 changes: 16 additions & 6 deletions cosmos_framework/configs/base/defaults/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,24 @@ class ClusterConfig:

DefaultClusterConfig: ClusterConfig = ClusterConfig(
object_store_bucket_data="",
object_store_bucket_checkpoint="bucket-checkpoint",
object_store_bucket_pretrained="bucket-pretrained",
object_store_credential_data="credentials/data.secret",
object_store_credential_checkpoint="credentials/checkpoint.secret",
object_store_credential_pretrained="credentials/pretrained.secret",
object_store_bucket_checkpoint="bucket4",
object_store_bucket_pretrained="bucket4",
object_store_credential_data="credentials/s3_training.secret",
object_store_credential_checkpoint="credentials/s3_checkpoint.secret",
object_store_credential_pretrained="credentials/s3_checkpoint.secret",
)

DefaultClusterConfig: ClusterConfig = ClusterConfig(
object_store_bucket_data="",
object_store_bucket_checkpoint="bucket1",
object_store_bucket_pretrained="bucket0",
object_store_credential_data="credentials/gcp_checkpoint.secret",
object_store_credential_checkpoint="credentials/gcp_training.secret",
object_store_credential_pretrained="credentials/gcp_training.secret",
)


def register_cluster():
cs = ConfigStore.instance()
cs.store(group="cluster", package="job.cluster", name="default", node=DefaultClusterConfig)
cs.store(group="cluster", package="job.cluster", name="aws_iad_h100", node=DefaultClusterConfig)
cs.store(group="cluster", package="job.cluster", name="gcp_iad_gb200", node=DefaultClusterConfig)
2 changes: 1 addition & 1 deletion cosmos_framework/configs/base/defaults/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CompileConfig:
# (maps to ``torch.compile(dynamic=...)``). Defaults to True for training,
# which sees varying shapes across batches (sequence length, CP sharding, ...);
# specializing would recompile continuously. See ParallelismOverrides in
# cosmos_framework/inference/common/args.py for the inference-side rationale
# packages/cosmos3/cosmos3/common/args.py for the inference-side rationale
# (where dynamic=False is preferred for stable AR shapes).
compile_dynamic: bool = True

Expand Down
Loading