Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .agents/skills/cosmos3-inference/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ All paths below are relative to the cosmos3 package root (`../../../` from this
| Which model should I use? (Nano vs Super, memory, shift) | `README.md` § Models |
| Which modality? (t2i, t2v, i2v, examples) | `README.md` § Modalities |
| What parallelism preset? (latency vs throughput) | `README.md` § Inference |
| How do I lower GPU memory / offload to CPU? (`--offload-stages`) | `docs/inference.md` § CPU Offloading |
| What input fields are available? (prompt, vision_path, num_frames, ...) | `docs/inference.md` § Sample Arguments |
| What are the default parameter values? | `cosmos_framework/inference/defaults/<model_mode>/sample_args.json` (per-modality JSON) |
| How do I use custom defaults? | `docs/inference.md` § Custom Defaults |
Expand Down
1 change: 1 addition & 0 deletions .claude/skills/cosmos3-inference/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ All paths below are relative to the cosmos3 package root (`../../../` from this
| Which model should I use? (Nano vs Super, memory, shift) | `README.md` § Models |
| Which modality? (t2i, t2v, i2v, examples) | `README.md` § Modalities |
| What parallelism preset? (latency vs throughput) | `README.md` § Inference |
| How do I lower GPU memory / offload to CPU? (`--offload-stages`) | `docs/inference.md` § CPU Offloading |
| What input fields are available? (prompt, vision_path, num_frames, ...) | `docs/inference.md` § Sample Arguments |
| What are the default parameter values? | `cosmos_framework/inference/defaults/<model_mode>/sample_args.json` (per-modality JSON) |
| How do I use custom defaults? | `docs/inference.md` § Custom Defaults |
Expand Down
26 changes: 26 additions & 0 deletions cosmos_framework/inference/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,32 @@ def check_model_equal(actual: pydantic.BaseModel, expected: pydantic.BaseModel):
check_model_equal(OmniSetupOverrides.model_validate(args.model_dump()).build_setup(), args)


def test_offload_stages(tmp_path: Path):
    """Exercise the default, accepted and rejected values of ``offload_stages``."""
    output_dir = tmp_path / "outputs"

    def _make_overrides(**extra):
        # Single construction point so every case shares the same base arguments.
        return OmniSetupOverrides(
            checkpoint_path=DEFAULT_CHECKPOINT_NAME,
            output_dir=output_dir,
            **extra,
        )

    # With nothing requested, no stage is offloaded.
    default_setup = _make_overrides().build_setup()
    assert default_setup.offload_stages == ()

    # Every arena stage survives the trip through build_setup unchanged.
    all_stages = ("reasoner", "generator", "vae")
    offloaded_setup = _make_overrides(offload_stages=all_stages).build_setup()
    assert offloaded_setup.offload_stages == ("reasoner", "generator", "vae")

    # "guardrails" is controlled by a separate flag, so it is rejected here
    # just like any unknown stage name.
    for rejected_name in ("guardrails", "bogus"):
        with pytest.raises(pydantic.ValidationError):
            _make_overrides(offload_stages=(rejected_name,))


def test_sample_args(tmp_path: Path):
setup_args = OmniSetupOverrides(
checkpoint_path=DEFAULT_CHECKPOINT_NAME,
Expand Down
14 changes: 14 additions & 0 deletions cosmos_framework/inference/common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
Training = Suppress


# Names of the single-GPU CPU-offload stages that ``--offload-stages`` accepts.
# "guardrails" is intentionally not a member: guardrail offloading is controlled
# by its own dedicated flag, ``--offload-guardrail-models``.
OffloadStage = Literal["reasoner", "generator", "vae"]


# File extensions (lowercase, with the leading dot) grouped by media type.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
VIDEO_EXTENSIONS = [".mp4"]
# Concatenation of the two lists above: image extensions first, then video.
MEDIA_EXTENSIONS = IMAGE_EXTENSIONS + VIDEO_EXTENSIONS
Expand Down Expand Up @@ -712,6 +717,7 @@ class SetupArgs(ABC, CheckpointArgs, ParallelismArgs, GuardrailArgs):
warmup: pydantic.NonNegativeInt
max_model_len: pydantic.PositiveInt | None
max_num_seqs: pydantic.PositiveInt | None
offload_stages: tuple[OffloadStage, ...]

# Subclass must implement these fields/methods
# ------------------------------------------------------------
Expand Down Expand Up @@ -762,6 +768,14 @@ class SetupOverrides(ABC, CheckpointOverrides, ParallelismOverrides, GuardrailOv
max_num_seqs: pydantic.PositiveInt | None = 1
"""Maximum number of sequences per batch. When set, samples are packed into
batches by number of sequences."""
offload_stages: tuple[OffloadStage, ...] = ()
"""Single-GPU CPU-offload stages. Each named component is offloaded to pinned CPU
storage and staged into one reusable GPU arena only while in use, reducing peak
GPU memory. Choices: 'reasoner' / 'generator' (the MoT towers — enabling either
runs the understanding pathway once as a prefill that caches the per-layer K/V, then
runs the denoise loop generator-only) and 'vae' (the vision tokenizer, staged around
encode/decode). Empty = off (joint path, unchanged). Single-GPU only; incompatible
with CUDA graphs. Guardrail offloading has its own flag, --offload-guardrail-models."""

def _build_setup(self):
pass
Expand Down
Loading