Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cosmos_framework/configs/base/defaults/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@

"""User-facing parallelism degrees shared by VFM and VLM trainers."""

from typing import Literal

import attrs
import torch

# Allowed values for ``ParallelismConfig.attention_io_layout``: the tensor
# layout used at the attention boundary when context parallelism is enabled.
AttentionIOLayout = Literal["sequence_sharded", "replicated"]

# Canonical mapping from precision string (used in user-facing configs and
# threaded through OmegaConf) to ``torch.dtype``. Consumed by sites that
# need to translate ``precision`` / ``fsdp_master_dtype`` into concrete
# torch dtypes (e.g. ``MixedPrecisionPolicy``, ``HFModel`` meta-init).


PRECISION_TO_TORCH_DTYPE: dict[str, torch.dtype] = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
Expand All @@ -31,6 +37,15 @@ class ParallelismConfig:
# Number of ranks for context parallelism.
context_parallel_shard_degree: int = 1

# Tensor layout at the attention boundary when CP is enabled. Both
# layouts may run the attention kernel with head-sharded Q/K/V:
# ``sequence_sharded`` keeps surrounding projections/MLP sequence-sharded
# with Ulysses-style all-to-all into/out of attention, while
# ``replicated`` keeps current-frame hidden states replicated, slices
# local heads before attention, then reduces/gathers attention output back
# to replicated hidden states.
attention_io_layout: AttentionIOLayout = "sequence_sharded"

# Number of ranks for CFG parallelism.
cfg_parallel_shard_degree: int = 1

Expand Down
43 changes: 36 additions & 7 deletions cosmos_framework/data/vfm/augmentor_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import cosmos_framework.data.vfm.augmentors.append_fps_frames_for_image as append_fps_frames_for_image
import cosmos_framework.data.vfm.augmentors.audio_caption as audio_caption
import cosmos_framework.data.vfm.augmentors.caption_filter as caption_filter
import cosmos_framework.data.vfm.augmentors.cropping as cosmos_cropping
import cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps as duration_fps_text_timestamps
import cosmos_framework.data.vfm.augmentors.image_resolution_filter as image_resolution_filter
import cosmos_framework.data.vfm.augmentors.merge_datadict as merge_datadict
Expand All @@ -25,6 +26,9 @@
from cosmos_framework.data.vfm.augmentors import sequence_plan
from cosmos_framework.data.vfm.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO

# UniAE requires spatial dimensions divisible by (spatial_compression * patch_spatial) = 16 * 2 = 32.
UNIAE_SPATIAL_MULTIPLE = 32

AUGMENTOR_OPTIONS = {}

CAMERA_MOVEMENT_PHRASES = [
Expand Down Expand Up @@ -617,9 +621,20 @@ def get_video_augmentor_v3(
input_keys=["video"],
args={"size": VIDEO_RES_SIZE_INFO[resolution]},
),
"reflection_padding": L(padding.ReflectionPadding)(
input_keys=["video"],
args={"size": VIDEO_RES_SIZE_INFO[resolution]},
**(
{
"reflection_padding": L(padding.ReflectionPadding)(
input_keys=["video"],
args={"size": VIDEO_RES_SIZE_INFO[resolution]},
)
}
if causal_vae
else {
"crop_to_multiple": L(cosmos_cropping.CropToMultiple)(
input_keys=["video"],
args={"multiple": UNIAE_SPATIAL_MULTIPLE},
)
}
),
"text_transform": L(text_transforms_for_video.TextTransformForVideoWithFullFrames)(
input_keys=["metas", "ai_caption", "sequence_plan"],
Expand Down Expand Up @@ -781,6 +796,8 @@ def get_video_augmentor_v3_json_caption(
use_dynamic_fps: bool = False,
max_stride: int = 3,
min_stride: int = 1,
min_fps: float = 10.0,
max_fps: float = 60.0,
use_system_prompt: bool = False,
resize_on_read: bool = False,
dataset_resolution_type: str = "all",
Expand Down Expand Up @@ -842,7 +859,6 @@ def get_video_augmentor_v3_json_caption(
uniae_pad_frames = kwargs.get("uniae_pad_frames", None)
uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None)

print("Running video_augmentor_v3_json_caption...")
augmentors = {
# Caption parsing runs BEFORE video parsing so that VideoParsingChunkedFrames can
# decode only the frames belonging to a randomly sampled caption chunk.
Expand All @@ -863,6 +879,8 @@ def get_video_augmentor_v3_json_caption(
"use_dynamic_fps": use_dynamic_fps,
"max_stride": max_stride,
"min_stride": min_stride,
"min_fps": min_fps,
"max_fps": max_fps,
"seek_mode": "exact",
"dataset_resolution_type": dataset_resolution_type,
"resolution": resolution,
Expand Down Expand Up @@ -908,9 +926,20 @@ def get_video_augmentor_v3_json_caption(
input_keys=["video"],
args={"size": VIDEO_RES_SIZE_INFO[resolution]},
),
"reflection_padding": L(padding.ReflectionPadding)(
input_keys=["video"],
args={"size": VIDEO_RES_SIZE_INFO[resolution]},
**(
{
"reflection_padding": L(padding.ReflectionPadding)(
input_keys=["video"],
args={"size": VIDEO_RES_SIZE_INFO[resolution]},
)
}
if causal_vae
else {
"crop_to_multiple": L(cosmos_cropping.CropToMultiple)(
input_keys=["video"],
args={"multiple": UNIAE_SPATIAL_MULTIPLE},
)
}
),
# Duration/FPS timestamp augmentor - appends metadata like "The video is 2.5 seconds long and is of 24 FPS."
# To customize the template or separator, add them to the args dict below:
Expand Down
4 changes: 3 additions & 1 deletion cosmos_framework/data/vfm/augmentors/cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def __call__(self, data_dict: dict) -> dict:
# log.info(f"Data cropped from ({h}, {w}) to ({new_h}, {new_w})")
data_dict[key] = transforms_F.crop(data, top=top, left=left, height=new_h, width=new_w)

# Store final dimensions for downstream use (e.g., resolution text info)
# Store final dimensions for downstream use (e.g., ResolutionTextInfo)
# Use the same image_size format as ReflectionPadding: [target_h, target_w, orig_h, orig_w]
data_dict["image_size"] = torch.tensor([new_h, new_w, h, w], dtype=torch.float)
data_dict["final_height"] = new_h
data_dict["final_width"] = new_w

Expand Down
5 changes: 3 additions & 2 deletions cosmos_framework/data/vfm/augmentors/video_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,9 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O
self.resolution_tier = _DATASET_RESOLUTION_TIER.get(self.dataset_resolution_type)

# VAE temporal alignment mode.
# causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2)
# causal_vae=False: align to 4N (non-causal VAE, e.g. UniAE)
# causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2)
# causal_vae=False: align to 1+effective_chunk_frames*N (UniAE with chunk structure)
# or 4N (generic non-causal VAE)
self.causal_vae = args.get("causal_vae", True)
self.target_resolution_key = None if args.get("resolution") is None else str(args["resolution"])
self.uniae_pad_frames = None if args.get("uniae_pad_frames") is None else int(args["uniae_pad_frames"])
Expand Down
26 changes: 17 additions & 9 deletions cosmos_framework/model/tokenizer/models/dense_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,15 @@ def decode(
When ``pixel_trim`` is enabled and ``pad_frames > 0``, the latent
contains boundary tokens from encoding. After decoding, the
corresponding boundary pixel frames are trimmed from each chunk.

**Output shape contract**:
- Video (``temporal_patches > 1``): ``[B, T, H, W, C]`` where T is the
total number of decoded pixel frames across all chunks (after trim).
- Image (``temporal_patches == 1``): ``[B, 1, H, W, C]``. The image
latent is decoded into ``patch_time`` identical frames (it was encoded
from ``patch_time`` copies of the same frame); only the last frame is
kept. This differs from pre-``dense_runtime`` behaviour where the
full ``[B, patch_time, H, W, C]`` was returned.
"""
if self.decoder_cache_spec.patch_frames != 0:
raise NotImplementedError("Dense runtime decoder V1 does not support KV cache.")
Expand All @@ -481,21 +490,20 @@ def decode(
pad_frames = self.pad_frames
trim_pixel = self.pixel_trim and pad_frames > 0

patch_time = self.patch_size[0]
# Images were encoded as a single latent (no noncausal first chunk).
# Videos have temporal_patches > 1: latent[0] is the noncausal first frame.
is_image = temporal_patches == 1

# Patch 0 is always a single-latent chunk — either the noncausal first
# frame (video) or the sole image latent. Both were encoded from
# [frame × patch_time] copies, so all decoded frames are equivalent;
# keep the last one. For images, temporal_patches == 1, so the loop
# below never executes and this is the only chunk.
decoded_chunks: list[torch.Tensor] = []
decoded_first = self._decode_latent_chunk(latent[:, 0:1]) # [B, patch_time, H, W, C]
decoded_chunks.append(decoded_first[:, -1:])

if not is_image:
# Noncausal first latent: decode → patch_time pixel frames, keep last
# (the reconstructed original first frame).
first_latent = latent[:, 0:1]
decoded_first = self._decode_latent_chunk(first_latent) # [B, patch_time, H, W, C]
decoded_chunks.append(decoded_first[:, -1:])

for start_patch in range(0 if is_image else 1, temporal_patches, chunk_patch_frames):
for start_patch in range(1, temporal_patches, chunk_patch_frames):
end_patch = min(start_patch + chunk_patch_frames, temporal_patches)
latent_chunk = latent[:, start_patch:end_patch]
decoded_chunk = self._decode_latent_chunk(latent_chunk)
Expand Down
4 changes: 2 additions & 2 deletions cosmos_framework/model/vfm/mot/context_parallel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def gather_seq_scatter_heads(
x: shape of [z, seq, h, ...]
seq_dim: the dimension to gather
head_dim: the dimension to scatter
cp_mesh: ulysses sequence parallelism size
cp_mesh: sequence-sharded context-parallel mesh
Returns:
torch.Tensor: shape of gathered and scattered tensor
"""
Expand All @@ -260,7 +260,7 @@ def gather_heads_scatter_seq(
x (torch.Tensor): shape of [bsz, seq, h/n, ...]
head_dim (int): the dimension to gather
seq_dim (int): the dimension to scatter
cp_mesh (DeviceMesh): ulysses sequence parallelism size
cp_mesh (DeviceMesh): sequence-sharded context-parallel mesh
splits (List[torch.Tensor], optional): Manual splits for variable length scattering

Returns:
Expand Down
23 changes: 20 additions & 3 deletions cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(self, language_model, config: Cosmos3VFMNetworkConfig):
self.num_kv_heads = text_config.num_key_value_heads
self.head_dim = text_config.head_dim
self.num_hidden_layers = text_config.num_hidden_layers
self.attention_io_layout = "sequence_sharded"
self.predict_text_tokens = config.predict_text_tokens

if config.natten_parameter_list is not None and config.joint_attn_implementation != "three_way":
Expand Down Expand Up @@ -1039,6 +1040,22 @@ def forward(
# all downstream supertoken geometry automatically in sync with the pack.
num_action_tokens_per_supertoken = packed_seq.num_action_tokens_per_supertoken

replicated_attention_io_cp = (
self.attention_io_layout == "replicated"
and self.parallel_dims is not None
and self.parallel_dims.cp_enabled
)
# ``sequence_sharded`` attention I/O shards the token sequence, so
# packing must pad sequence lengths to the CP size and the input/output
# sequence helpers need the CP mesh. ``replicated`` attention I/O keeps
# current-frame sequences replicated and uses the CP mesh later inside
# attention to slice local heads, so the effective sequence-sharding
# world size is 1 here.
sequence_shard_parallel_dims = None if replicated_attention_io_cp else self.parallel_dims
sequence_shard_world_size = (
1 if replicated_attention_io_cp else (self.parallel_dims.cp_size if self.parallel_dims else 1)
)

input_pack, attention_meta, natten_metadata_list = build_packed_sequence(
self.config.joint_attn_implementation,
packed_sequence=packed_sequence,
Expand All @@ -1053,7 +1070,7 @@ def forward(
num_layers=self.num_hidden_layers,
token_shapes=packed_seq.vision.token_shapes,
natten_parameter_list=self.natten_parameter_list,
cp_world_size=self.parallel_dims.cp_size if self.parallel_dims else 1,
cp_world_size=sequence_shard_world_size,
video_temporal_causal=self.video_temporal_causal,
skip_natten_metadata=memory is not None and not memory.requires_natten_metadata(),
vision_token_shapes=vision_token_shapes,
Expand All @@ -1067,7 +1084,7 @@ def forward(
attn_implementation=self.config.joint_attn_implementation,
input_pack=input_pack,
position_ids=packed_seq.position_ids,
parallel_dims=self.parallel_dims,
parallel_dims=sequence_shard_parallel_dims,
)

packed_outputs, lbl_metadata = self.language_model(
Expand All @@ -1079,7 +1096,7 @@ def forward(
)
last_hidden_state = get_context_parallel_last_hidden_state(
packed_outputs=packed_outputs,
parallel_dims=self.parallel_dims,
parallel_dims=sequence_shard_parallel_dims,
) # [N_total,hidden_size]
output_dict = dict()

Expand Down
Loading