Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
985d803
Add design spec: video input for reasoner model-mode inference
foreverlms Jun 8, 2026
2d2490f
Add implementation plan: video input for reasoner model-mode
foreverlms Jun 8, 2026
6ece15f
feat(reasoner): add video_* sampling fields + mutual-exclusion valida…
foreverlms Jun 8, 2026
78a25a2
refactor(reasoner): video_fps PositiveFloat + construction-time mutua…
foreverlms Jun 8, 2026
6859a70
feat(reasoner): add video_* defaults (null) to reasoner sample_args
foreverlms Jun 8, 2026
5626909
feat(reasoner): video branch in prepare_multimodal_reasoner_inputs
foreverlms Jun 8, 2026
86a7e98
feat(reasoner): accept video tensors in _impl_generate_reasoner_text
foreverlms Jun 8, 2026
3f41015
feat(reasoner): forward video tensors through generate_reasoner_text …
foreverlms Jun 8, 2026
c7d9874
fix(reasoner): revert out-of-scope param additions to Nemotron genera…
foreverlms Jun 8, 2026
ff1d67d
feat(reasoner): videos param + video chat block in OmniMoTModel.gener…
foreverlms Jun 8, 2026
dbd7e86
docs(reasoner): update generate_reasoner_text docstring for video path
foreverlms Jun 8, 2026
c31261c
feat(reasoner): route mp4 vision_path to video conditioning in infere…
foreverlms Jun 8, 2026
769f465
docs(reasoner): document video input + add reasoner_video example
foreverlms Jun 8, 2026
663112e
docs(reasoner): clarify vision_path comment covers video too
foreverlms Jun 8, 2026
1d957bb
fix(reasoner): decode video frames for Qwen3VLProcessor; reduce knobs…
foreverlms Jun 8, 2026
e9aa7f5
fix(reasoner): decode video via torchvision.io + smart_nframes (drop …
foreverlms Jun 8, 2026
19bd716
fix(reasoner): emit reasoner_videos uniformly per-sample; mark design…
foreverlms Jun 8, 2026
92e3491
chore(reasoner): untrack video-reasoner spec/plan docs (keep in-repo,…
foreverlms Jun 8, 2026
eb20347
docs(reasoner): regenerate inference.md TOC for Reasoner section (pre…
foreverlms Jun 8, 2026
6e5830a
test(reasoner): cover video modality in get_sample_data tests; accoun…
foreverlms Jun 9, 2026
7b71889
Merge branch 'main' into maoshengl/video_reasoner_inference
foreverlms Jun 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cosmos_framework/inference/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def _build_vision_data(self, model_config: "OmniMoTModelConfig", sample_meta: Sa
if self.vision_path and "://" in self.vision_path:
raise ValueError("Must call `download()` before building vision data")

# Reasoner mode treats ``vision_path`` as a PIL image source; resolution/fps/num_frames are unused.
# Reasoner mode treats ``vision_path`` as an image (PIL) or video (mp4) source; resolution/fps/num_frames are unused.
if sample_meta.model_mode.is_reasoner:
self.condition_frame_indexes_vision = self.condition_frame_indexes_vision or []
self.condition_video_keep = self.condition_video_keep or "first"
Expand Down Expand Up @@ -609,6 +609,7 @@ class ReasonerDataArgs(ArgsBase):
top_p: _ReasonerTopP | None = None
repetition_penalty: _ReasonerRepetitionPenalty | None = None
presence_penalty: float | None = None
video_fps: pydantic.PositiveFloat | None = None


class ReasonerDataOverrides(OverridesBase):
Expand All @@ -629,6 +630,8 @@ class ReasonerDataOverrides(OverridesBase):
"""CTRL/HF-style multiplicative repetition penalty (>0). ``1.0`` is identity."""
presence_penalty: float | None = None
"""Additive presence penalty (any sign). ``0.0`` is identity."""
video_fps: pydantic.PositiveFloat | None = None
"""Frames per second to sample from a video vision_path. None -> decoder default (2.0)."""

def _build_reasoner_data(self, model_config: "OmniMoTModelConfig", sample_meta: SampleMeta):
if not sample_meta.model_mode.is_reasoner:
Expand Down
11 changes: 11 additions & 0 deletions cosmos_framework/inference/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ModelMode,
OmniSampleOverrides,
OmniSetupOverrides,
ReasonerDataOverrides,
)
from cosmos_framework.inference.common.config import structure_config

Expand Down Expand Up @@ -156,3 +157,13 @@ def test_sample_args(tmp_path: Path):
assert text2image_args.num_steps == 50
assert text2image_args.guidance == 4.0
assert text2image_args.shift == 3.0


def test_reasoner_video_fps_defaults_none():
ov = ReasonerDataOverrides()
assert ov.video_fps is None


def test_reasoner_video_fps_accepts_positive_float():
ov = ReasonerDataOverrides(video_fps=2.0)
assert ov.video_fps == 2.0
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
"top_k": null,
"top_p": null,
"repetition_penalty": 1.0,
"presence_penalty": 0.0
"presence_penalty": 0.0,
"video_fps": null
}
71 changes: 60 additions & 11 deletions cosmos_framework/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@
import cattrs.preconf.json
import safetensors.torch
import torch
import torchvision.io
from PIL import Image
from qwen_vl_utils.vision_process import smart_nframes
from torch.utils._pytree import tree_map_only
from torch.utils.data import Dataset
from typing_extensions import Self

from cosmos_framework.configs.base.defaults.compile import CompileConfig
from cosmos_framework.configs.base.defaults.parallelism import ParallelismConfig
from cosmos_framework.inference.args import (
ModelMode,
NegativeMetadataMode,
OmniSampleArgs,
OmniSetupArgs,
)
from cosmos_framework.inference.common.args import (
VIDEO_EXTENSIONS,
CheckpointType,
ConfigFileType,
ParallelismArgs,
Expand All @@ -46,13 +51,11 @@
pil_to_conditioning_frames,
resize_pil_image,
)
from cosmos_framework.utils import log
from cosmos_framework.tools.visualize.video import save_img_or_video
from cosmos_framework.configs.base.defaults.compile import CompileConfig
from cosmos_framework.configs.base.defaults.parallelism import ParallelismConfig
from cosmos_framework.model.vfm.omni_mot_model import OmniMoTModel
from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_IMAGE_EDITING
from cosmos_framework.model.vfm.upsampler.prompts import is_upsampled_prompt
from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import _SYSTEM_PROMPT_IMAGE_EDITING
from cosmos_framework.tools.visualize.video import save_img_or_video
from cosmos_framework.utils import log

if TYPE_CHECKING:
from cosmos_framework.configs.base.defaults.model_config import OmniMoTModelConfig
Expand Down Expand Up @@ -463,14 +466,44 @@ def _get_prompt_sample_data(sample_args: OmniSampleArgs, model: OmniMoTModel, *,
return out


def _decode_reasoner_video(vision_path: str, video_fps: float | None) -> dict[str, Any]:
"""Decode a local video file into the frame-list payload the Qwen3-VL processor expects.

Returns ``{"frames": [PIL.Image, ...], "fps": float}``. Uses the same
``torchvision.io.read_video`` decode the rest of the inference path relies on
(no ``decord`` dependency), then uniformly samples frames toward ``video_fps``
(default 2.0) via Qwen's ``smart_nframes``. The repo ``Qwen3VLProcessor`` runs
with ``do_sample_frames=False``, so it consumes this pre-sampled frame list
as-is and handles its own per-frame resize."""
frames, _, info = torchvision.io.read_video(str(vision_path), pts_unit="sec") # [T,H,W,C] uint8
total_frames = int(frames.shape[0])
if total_frames == 0:
raise ValueError(f"Decoded zero frames from reasoner video: {vision_path}")
src_fps = float(info.get("video_fps") or 0.0) or 1.0
target_fps = video_fps if video_fps is not None else 2.0
nframes = smart_nframes({"fps": target_fps}, total_frames=total_frames, video_fps=src_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
pil_frames = [Image.fromarray(frames[i].numpy()) for i in idx]
sample_fps = nframes / total_frames * src_fps
return {"frames": pil_frames, "fps": sample_fps}


def _get_reasoner_sample_data(sample_args: OmniSampleArgs, model: OmniMoTModel) -> dict[str, Any]:
"""Sample batch for reasoner text generation: prompt + optional conditioning image."""
"""Sample batch for reasoner text generation: prompt + optional conditioning image or video."""
image: Image.Image | None = None
video: dict[str, Any] | None = None
if sample_args.vision_path is not None:
image = Image.open(sample_args.vision_path).convert("RGB")
if Path(sample_args.vision_path).suffix.lower() in VIDEO_EXTENSIONS:
video = _decode_reasoner_video(str(sample_args.vision_path), sample_args.video_fps)
else:
image = Image.open(sample_args.vision_path).convert("RGB")
# Both keys are emitted for every sample (``None`` when absent) so the batch
# builder can positionally align them and the three-way homogeneity check in
# ``_generate_reasoner_batch`` reliably detects an image/video/text mix.
return {
model.input_caption_key: [sample_args.prompt],
"reasoner_images": [image],
"reasoner_videos": [video],
}


Expand Down Expand Up @@ -1655,13 +1688,28 @@ def _generate_reasoner_batch(

prompts: list[str] = data_batch[self.model.input_caption_key]
raw_images: list[Image.Image | None] = data_batch["reasoner_images"]
n_set = sum(img is not None for img in raw_images)
if 0 < n_set < len(raw_images):
raw_videos: list[dict[str, Any] | None] | None = data_batch.get("reasoner_videos")

n_img = sum(img is not None for img in raw_images)
n_vid = sum(v is not None for v in (raw_videos or []))
if n_img and n_vid:
raise ValueError(
"Reasoner batch mixes image- and video-conditioned samples. Split into separate batches."
)
if 0 < n_img < len(raw_images):
raise ValueError(
"Reasoner batch mixes image-conditioned and text-only samples "
f"({n_set}/{len(raw_images)} have vision_path). Split into separate batches."
f"({n_img}/{len(raw_images)} have an image vision_path). Split into separate batches."
)
if raw_videos is not None and 0 < n_vid < len(raw_videos):
raise ValueError(
"Reasoner batch mixes video-conditioned and text-only samples "
f"({n_vid}/{len(raw_videos)} have a video vision_path). Split into separate batches."
)
images: list[Image.Image] | None = cast(list[Image.Image], raw_images) if n_set == len(raw_images) else None
images: list[Image.Image] | None = cast(list[Image.Image], raw_images) if n_img == len(raw_images) else None
videos: list[dict[str, Any]] | None = (
cast(list[dict[str, Any]], raw_videos) if raw_videos is not None and n_vid == len(raw_videos) else None
)

try:
with sync_distributed_errors():
Expand All @@ -1686,6 +1734,7 @@ def _generate_reasoner_batch(
prompts,
max_new_tokens=sample_args_list[0].max_new_tokens,
images=images,
videos=videos,
do_sample=sample_args_list[0].do_sample,
temperature=sample_args_list[0].temperature,
top_k=sample_args_list[0].top_k,
Expand Down
33 changes: 31 additions & 2 deletions cosmos_framework/inference/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def _make_reasoner_sample_args(**overrides: Any) -> SimpleNamespace:
model_mode=ModelMode.REASONER,
prompt="Describe a robotic arm.",
vision_path=None,
video_fps=None,
max_new_tokens=8,
do_sample=False,
temperature=1.0,
Expand All @@ -189,7 +190,11 @@ def test_get_sample_data_reasoner_text_only() -> None:

out = inference.get_sample_data(sample_args, model, device="cpu")

assert out == {"caption": ["Describe a robotic arm."], "reasoner_images": [None]}
assert out == {
"caption": ["Describe a robotic arm."],
"reasoner_images": [None],
"reasoner_videos": [None],
}


@pytest.mark.L0
Expand All @@ -205,13 +210,35 @@ def test_get_sample_data_reasoner_with_image(tmp_path: Path) -> None:

out = inference.get_sample_data(sample_args, model, device="cpu")

assert list(out) == ["caption", "reasoner_images"]
assert list(out) == ["caption", "reasoner_images", "reasoner_videos"]
assert out["caption"] == ["Describe a robotic arm."]
assert out["reasoner_videos"] == [None]
assert len(out["reasoner_images"]) == 1
assert out["reasoner_images"][0].size == (8, 8)
assert out["reasoner_images"][0].mode == "RGB"


@pytest.mark.L0
def test_get_sample_data_reasoner_with_video(monkeypatch: pytest.MonkeyPatch) -> None:
"""A video ``vision_path`` routes through ``_decode_reasoner_video`` into ``reasoner_videos``.

The decoder is monkeypatched (real decode needs torchvision + an actual clip);
this asserts the routing/contract, not the decode itself."""
from cosmos_framework.inference import inference

decoded = {"frames": ["F0", "F1"], "fps": 2.0}
monkeypatch.setattr(inference, "_decode_reasoner_video", lambda path, fps: decoded)
model = SimpleNamespace(input_caption_key="caption")
sample_args = _make_reasoner_sample_args(vision_path="/tmp/clip.mp4", video_fps=2.0)

out = inference.get_sample_data(sample_args, model, device="cpu")

assert out["caption"] == ["Describe a robotic arm."]
assert out["reasoner_videos"] == [decoded]
assert out["reasoner_images"] == [None]
assert "video_sampling_kwargs" not in out


@pytest.mark.L0
def test_reasoner_defaults_json_round_trip() -> None:
import json as _json
Expand Down Expand Up @@ -349,3 +376,5 @@ def test_reasoner_defaults_validate_against_overrides() -> None:
filtered = {k: v for k, v in defaults.items() if k in OmniSampleOverrides.model_fields}
assert set(defaults) - set(filtered) == set(), f"defaults has unknown fields: {set(defaults) - set(filtered)}"
OmniSampleOverrides.model_validate(filtered)


12 changes: 9 additions & 3 deletions cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ def generate_reasoner_text(
*,
pixel_values: torch.Tensor | None = None,
image_grid_thw: torch.Tensor | None = None,
pixel_values_videos: torch.Tensor | None = None,
video_grid_thw: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
eos_token_id: int | list[int] | None = None,
pad_token_id: int | None = None,
Expand All @@ -296,9 +298,11 @@ def generate_reasoner_text(
prompts through this single entry point: pass
``pixel_values`` + ``image_grid_thw`` (and optionally
``attention_mask``) for image-conditioned prefill via the Qwen3-VL
visual encoder, or omit them for text-only prefill. Uses the
und-pathway weights (those WITHOUT the ``_moe_gen`` suffix) plus
``embed_tokens`` / ``norm`` / ``lm_head``; the generation pathway
visual encoder, or omit them for text-only prefill. Video
conditioning is also supported via ``pixel_values_videos`` +
``video_grid_thw``; the image and video pairs are mutually exclusive.
Uses the und-pathway weights (those WITHOUT the ``_moe_gen`` suffix)
plus ``embed_tokens`` / ``norm`` / ``lm_head``; the generation pathway
and all VFM-level multimodal embedders / heads (``vae2llm``,
``llm2vae``, ``sound2llm``, etc.) are bypassed.

Expand Down Expand Up @@ -327,6 +331,8 @@ def generate_reasoner_text(
max_new_tokens=max_new_tokens,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
attention_mask=attention_mask,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
Expand Down
29 changes: 24 additions & 5 deletions cosmos_framework/model/vfm/mot/unified_mot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,8 @@ def _impl_generate_reasoner_text(
*,
pixel_values: torch.Tensor | None = None,
image_grid_thw: torch.Tensor | None = None,
pixel_values_videos: torch.Tensor | None = None,
video_grid_thw: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
eos_token_id: int | list[int] | None = None,
pad_token_id: int | None = None,
Expand Down Expand Up @@ -1550,10 +1552,9 @@ def _impl_generate_reasoner_text(
``Qwen3VLProcessor`` emits — pass it through unchanged.
Moved to the prompt's device internally. ``None`` (default)
means text-only prompt; in that case the multimodal prefill
path is skipped entirely. Videos are *not* supported here —
this function has no ``pixel_values_videos`` / ``video_grid_thw``
parameters; for I2V conditioning, frames must be passed as
images.
path is skipped entirely. For video conditioning, pass ``pixel_values_videos`` +
``video_grid_thw`` instead (mutually exclusive with the image
pair).
image_grid_thw: Optional ``[num_images, 3]`` long tensor giving
``(t, h, w)`` — the temporal / height / width feature-grid
size per image as produced by ``Qwen3VLProcessor`` (``t`` is
Expand Down Expand Up @@ -1643,11 +1644,15 @@ def _impl_generate_reasoner_text(

if (pixel_values is None) != (image_grid_thw is None):
raise ValueError("pixel_values and image_grid_thw must be provided together.")
if (pixel_values_videos is None) != (video_grid_thw is None):
raise ValueError("pixel_values_videos and video_grid_thw must be provided together.")
if pixel_values is not None and pixel_values_videos is not None:
raise ValueError("Reasoner conditions on one medium at a time: pass image OR video, not both.")

_prefill_start = time.time()

mrope_position_deltas: torch.Tensor | None = None
if pixel_values is None:
if pixel_values is None and pixel_values_videos is None:
hidden = model.reasoner_forward(input_ids, cache=cache) # [B,T_prompt,hidden_size]
else:
if not hasattr(causal_lm, "visual"):
Expand All @@ -1663,6 +1668,8 @@ def _impl_generate_reasoner_text(
input_ids=input_ids,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
attention_mask=attention_mask,
)
hidden = model.reasoner_forward(
Expand Down Expand Up @@ -1936,6 +1943,8 @@ def generate_reasoner_text(
*,
pixel_values: torch.Tensor | None = None,
image_grid_thw: torch.Tensor | None = None,
pixel_values_videos: torch.Tensor | None = None,
video_grid_thw: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
eos_token_id: int | list[int] | None = None,
pad_token_id: int | None = None,
Expand All @@ -1956,6 +1965,8 @@ def generate_reasoner_text(
the Qwen3-VL visual encoder; omit them for text-only prefill. The
two arguments are mutually required: passing exactly one raises
``ValueError`` inside :func:`_impl_generate_reasoner_text`.
Video conditioning is also supported via ``pixel_values_videos`` +
``video_grid_thw``; the image and video pairs are mutually exclusive.

Uses the und-pathway weights (those WITHOUT the ``_moe_gen`` suffix)
plus the model-level ``embed_tokens`` / ``norm`` / ``lm_head``, and —
Expand All @@ -1970,6 +1981,8 @@ def generate_reasoner_text(
max_new_tokens=max_new_tokens,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
attention_mask=attention_mask,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
Expand Down Expand Up @@ -2064,6 +2077,8 @@ def generate_reasoner_text(
*,
pixel_values: torch.Tensor | None = None,
image_grid_thw: torch.Tensor | None = None,
pixel_values_videos: torch.Tensor | None = None,
video_grid_thw: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
eos_token_id: int | list[int] | None = None,
pad_token_id: int | None = None,
Expand All @@ -2084,6 +2099,8 @@ def generate_reasoner_text(
the Qwen3-VL visual encoder; omit them for text-only prefill. The
two arguments are mutually required: passing exactly one raises
``ValueError`` inside :func:`_impl_generate_reasoner_text`.
Video conditioning is also supported via ``pixel_values_videos`` +
``video_grid_thw``; the image and video pairs are mutually exclusive.

Uses the und-pathway weights (those WITHOUT the ``_moe_gen`` suffix)
plus the model-level ``embed_tokens`` / ``norm`` / ``lm_head``, and —
Expand All @@ -2099,6 +2116,8 @@ def generate_reasoner_text(
max_new_tokens=max_new_tokens,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
attention_mask=attention_mask,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
Expand Down
Loading