diff --git a/cosmos_framework/configs/base/defaults/parallelism.py b/cosmos_framework/configs/base/defaults/parallelism.py index ffce654..59b3eff 100644 --- a/cosmos_framework/configs/base/defaults/parallelism.py +++ b/cosmos_framework/configs/base/defaults/parallelism.py @@ -3,13 +3,19 @@ """User-facing parallelism degrees shared by VFM and VLM trainers.""" +from typing import Literal + import attrs import torch +AttentionIOLayout = Literal["sequence_sharded", "replicated"] + # Canonical mapping from precision string (used in user-facing configs and # threaded through OmegaConf) to ``torch.dtype``. Consumed by sites that # need to translate ``precision`` / ``fsdp_master_dtype`` into concrete # torch dtypes (e.g. ``MixedPrecisionPolicy``, ``HFModel`` meta-init). + + PRECISION_TO_TORCH_DTYPE: dict[str, torch.dtype] = { "bfloat16": torch.bfloat16, "float16": torch.float16, @@ -31,6 +37,15 @@ class ParallelismConfig: # Number of ranks for context parallelism. context_parallel_shard_degree: int = 1 + # Tensor layout at the attention boundary when CP is enabled. Both + # layouts may run the attention kernel with head-sharded Q/K/V: + # ``sequence_sharded`` keeps surrounding projections/MLP sequence-sharded + # with Ulysses-style all-to-all into/out of attention, while + # ``replicated`` keeps current-frame hidden states replicated, slices + # local heads before attention, then reduces/gathers attention output back + # to replicated hidden states. + attention_io_layout: AttentionIOLayout = "sequence_sharded" + # Number of ranks for CFG parallelism. cfg_parallel_shard_degree: int = 1 diff --git a/cosmos_framework/data/vfm/augmentor_provider.py b/cosmos_framework/data/vfm/augmentor_provider.py index 3e3d785..fa859e4 100644 --- a/cosmos_framework/data/vfm/augmentor_provider.py +++ b/cosmos_framework/data/vfm/augmentor_provider.py @@ -10,6 +10,7 @@ import cosmos_framework.data.vfm.augmentors.append_fps_frames_for_image as append_fps_frames_for_image import cosmos_framework.data.vfm.augmentors.audio_caption as audio_caption import cosmos_framework.data.vfm.augmentors.caption_filter as caption_filter +import cosmos_framework.data.vfm.augmentors.cropping as cosmos_cropping import cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps as duration_fps_text_timestamps import cosmos_framework.data.vfm.augmentors.image_resolution_filter as image_resolution_filter import cosmos_framework.data.vfm.augmentors.merge_datadict as merge_datadict @@ -25,6 +26,9 @@ from cosmos_framework.data.vfm.augmentors import sequence_plan from cosmos_framework.data.vfm.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO +# UniAE requires spatial dimensions divisible by (spatial_compression * patch_spatial) = 16 * 2 = 32. +UNIAE_SPATIAL_MULTIPLE = 32 + AUGMENTOR_OPTIONS = {} CAMERA_MOVEMENT_PHRASES = [ @@ -617,9 +621,20 @@ def get_video_augmentor_v3( input_keys=["video"], args={"size": VIDEO_RES_SIZE_INFO[resolution]}, ), - "reflection_padding": L(padding.ReflectionPadding)( - input_keys=["video"], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + **( + { + "reflection_padding": L(padding.ReflectionPadding)( + input_keys=["video"], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ) + } + if causal_vae + else { + "crop_to_multiple": L(cosmos_cropping.CropToMultiple)( + input_keys=["video"], + args={"multiple": UNIAE_SPATIAL_MULTIPLE}, + ) + } ), "text_transform": L(text_transforms_for_video.TextTransformForVideoWithFullFrames)( input_keys=["metas", "ai_caption", "sequence_plan"], @@ -781,6 +796,8 @@ def get_video_augmentor_v3_json_caption( use_dynamic_fps: bool = False, max_stride: int = 3, min_stride: int = 1, + min_fps: float = 10.0, + max_fps: float = 60.0, use_system_prompt: bool = False, resize_on_read: bool = False, dataset_resolution_type: str = "all", @@ -842,7 +859,6 @@ def get_video_augmentor_v3_json_caption( uniae_pad_frames = kwargs.get("uniae_pad_frames", None) uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) - print("Running video_augmentor_v3_json_caption...") augmentors = { # Caption parsing runs BEFORE video parsing so that VideoParsingChunkedFrames can # decode only the frames belonging to a randomly sampled caption chunk. @@ -863,6 +879,8 @@ def get_video_augmentor_v3_json_caption( "use_dynamic_fps": use_dynamic_fps, "max_stride": max_stride, "min_stride": min_stride, + "min_fps": min_fps, + "max_fps": max_fps, "seek_mode": "exact", "dataset_resolution_type": dataset_resolution_type, "resolution": resolution, @@ -908,9 +926,20 @@ def get_video_augmentor_v3_json_caption( input_keys=["video"], args={"size": VIDEO_RES_SIZE_INFO[resolution]}, ), - "reflection_padding": L(padding.ReflectionPadding)( - input_keys=["video"], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + **( + { + "reflection_padding": L(padding.ReflectionPadding)( + input_keys=["video"], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ) + } + if causal_vae + else { + "crop_to_multiple": L(cosmos_cropping.CropToMultiple)( + input_keys=["video"], + args={"multiple": UNIAE_SPATIAL_MULTIPLE}, + ) + } ), # Duration/FPS timestamp augmentor - appends metadata like "The video is 2.5 seconds long and is of 24 FPS." # To customize the template or separator, add them to the args dict below: diff --git a/cosmos_framework/data/vfm/augmentors/cropping.py b/cosmos_framework/data/vfm/augmentors/cropping.py index ce65744..12b826c 100644 --- a/cosmos_framework/data/vfm/augmentors/cropping.py +++ b/cosmos_framework/data/vfm/augmentors/cropping.py @@ -73,7 +73,9 @@ def __call__(self, data_dict: dict) -> dict: # log.info(f"Data cropped from ({h}, {w}) to ({new_h}, {new_w})") data_dict[key] = transforms_F.crop(data, top=top, left=left, height=new_h, width=new_w) - # Store final dimensions for downstream use (e.g., resolution text info) + # Store final dimensions for downstream use (e.g., ResolutionTextInfo) + # Use the same image_size format as ReflectionPadding: [target_h, target_w, orig_h, orig_w] + data_dict["image_size"] = torch.tensor([new_h, new_w, h, w], dtype=torch.float) data_dict["final_height"] = new_h data_dict["final_width"] = new_w diff --git a/cosmos_framework/data/vfm/augmentors/video_parsing.py b/cosmos_framework/data/vfm/augmentors/video_parsing.py index 25a5580..1304f6c 100644 --- a/cosmos_framework/data/vfm/augmentors/video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/video_parsing.py @@ -442,8 +442,9 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.resolution_tier = _DATASET_RESOLUTION_TIER.get(self.dataset_resolution_type) # VAE temporal alignment mode. - # causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2) - # causal_vae=False: align to 4N (non-causal VAE, e.g. UniAE) + # causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: align to 1+effective_chunk_frames*N (UniAE with chunk structure) + # or 4N (generic non-causal VAE) self.causal_vae = args.get("causal_vae", True) self.target_resolution_key = None if args.get("resolution") is None else str(args["resolution"]) self.uniae_pad_frames = None if args.get("uniae_pad_frames") is None else int(args["uniae_pad_frames"]) diff --git a/cosmos_framework/model/tokenizer/models/dense_runtime.py b/cosmos_framework/model/tokenizer/models/dense_runtime.py index 59c6b9d..9382028 100644 --- a/cosmos_framework/model/tokenizer/models/dense_runtime.py +++ b/cosmos_framework/model/tokenizer/models/dense_runtime.py @@ -461,6 +461,15 @@ def decode( When ``pixel_trim`` is enabled and ``pad_frames > 0``, the latent contains boundary tokens from encoding. After decoding, the corresponding boundary pixel frames are trimmed from each chunk. + + **Output shape contract**: + - Video (``temporal_patches > 1``): ``[B, T, H, W, C]`` where T is the + total number of decoded pixel frames across all chunks (after trim). + - Image (``temporal_patches == 1``): ``[B, 1, H, W, C]``. The image + latent is decoded into ``patch_time`` identical frames (it was encoded + from ``patch_time`` copies of the same frame); only the last frame is + kept. This differs from pre-``dense_runtime`` behaviour where the + full ``[B, patch_time, H, W, C]`` was returned. """ if self.decoder_cache_spec.patch_frames != 0: raise NotImplementedError("Dense runtime decoder V1 does not support KV cache.") @@ -481,21 +490,20 @@ def decode( pad_frames = self.pad_frames trim_pixel = self.pixel_trim and pad_frames > 0 - patch_time = self.patch_size[0] # Images were encoded as a single latent (no noncausal first chunk). # Videos have temporal_patches > 1: latent[0] is the noncausal first frame. is_image = temporal_patches == 1 + # Patch 0 is always a single-latent chunk — either the noncausal first + # frame (video) or the sole image latent. Both were encoded from + # [frame × patch_time] copies, so all decoded frames are equivalent; + # keep the last one. For images temporal_patches == 1, so the loop + # below is empty and this is the only chunk. decoded_chunks: list[torch.Tensor] = [] + decoded_first = self._decode_latent_chunk(latent[:, 0:1]) # [B, patch_time, H, W, C] + decoded_chunks.append(decoded_first[:, -1:]) - if not is_image: - # Noncausal first latent: decode → patch_time pixel frames, keep last - # (the reconstructed original first frame). - first_latent = latent[:, 0:1] - decoded_first = self._decode_latent_chunk(first_latent) # [B, patch_time, H, W, C] - decoded_chunks.append(decoded_first[:, -1:]) - - for start_patch in range(0 if is_image else 1, temporal_patches, chunk_patch_frames): + for start_patch in range(1, temporal_patches, chunk_patch_frames): end_patch = min(start_patch + chunk_patch_frames, temporal_patches) latent_chunk = latent[:, start_patch:end_patch] decoded_chunk = self._decode_latent_chunk(latent_chunk) diff --git a/cosmos_framework/model/vfm/mot/context_parallel_utils.py b/cosmos_framework/model/vfm/mot/context_parallel_utils.py index 96bf607..10b05fd 100644 --- a/cosmos_framework/model/vfm/mot/context_parallel_utils.py +++ b/cosmos_framework/model/vfm/mot/context_parallel_utils.py @@ -237,7 +237,7 @@ def gather_seq_scatter_heads( x: shape of [z, seq, h, ...] seq_dim: the dimension to gather head_dim: the dimension to scatter - cp_mesh: ulysses sequence parallelism size + cp_mesh: sequence-sharded context-parallel mesh Returns: torch.Tensor: shape of gathered and scattered tensor """ @@ -260,7 +260,7 @@ def gather_heads_scatter_seq( x (torch.Tensor): shape of [bsz, seq, h/n, ...] head_dim (int): the dimension to gather seq_dim (int): the dimension to scatter - cp_mesh (DeviceMesh): ulysses sequence parallelism size + cp_mesh (DeviceMesh): sequence-sharded context-parallel mesh splits (List[torch.Tensor], optional): Manual splits for variable length scattering Returns: diff --git a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py index 909a1d0..f1a8f29 100644 --- a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py +++ b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py @@ -124,6 +124,7 @@ def __init__(self, language_model, config: Cosmos3VFMNetworkConfig): self.num_kv_heads = text_config.num_key_value_heads self.head_dim = text_config.head_dim self.num_hidden_layers = text_config.num_hidden_layers + self.attention_io_layout = "sequence_sharded" self.predict_text_tokens = config.predict_text_tokens if config.natten_parameter_list is not None and config.joint_attn_implementation != "three_way": @@ -1039,6 +1040,22 @@ def forward( # all downstream supertoken geometry automatically in sync with the pack. num_action_tokens_per_supertoken = packed_seq.num_action_tokens_per_supertoken + replicated_attention_io_cp = ( + self.attention_io_layout == "replicated" + and self.parallel_dims is not None + and self.parallel_dims.cp_enabled + ) + # ``sequence_sharded`` attention I/O shards the token sequence, so + # packing must pad sequence lengths to the CP size and the input/output + # sequence helpers need the CP mesh. ``replicated`` attention I/O keeps + # current-frame sequences replicated and uses the CP mesh later inside + # attention to slice local heads, so the effective sequence-sharding + # world size is 1 here. + sequence_shard_parallel_dims = None if replicated_attention_io_cp else self.parallel_dims + sequence_shard_world_size = ( + 1 if replicated_attention_io_cp else (self.parallel_dims.cp_size if self.parallel_dims else 1) + ) + input_pack, attention_meta, natten_metadata_list = build_packed_sequence( self.config.joint_attn_implementation, packed_sequence=packed_sequence, @@ -1053,7 +1070,7 @@ def forward( num_layers=self.num_hidden_layers, token_shapes=packed_seq.vision.token_shapes, natten_parameter_list=self.natten_parameter_list, - cp_world_size=self.parallel_dims.cp_size if self.parallel_dims else 1, + cp_world_size=sequence_shard_world_size, video_temporal_causal=self.video_temporal_causal, skip_natten_metadata=memory is not None and not memory.requires_natten_metadata(), vision_token_shapes=vision_token_shapes, @@ -1067,7 +1084,7 @@ def forward( attn_implementation=self.config.joint_attn_implementation, input_pack=input_pack, position_ids=packed_seq.position_ids, - parallel_dims=self.parallel_dims, + parallel_dims=sequence_shard_parallel_dims, ) packed_outputs, lbl_metadata = self.language_model( @@ -1079,7 +1096,7 @@ def forward( ) last_hidden_state = get_context_parallel_last_hidden_state( packed_outputs=packed_outputs, - parallel_dims=self.parallel_dims, + parallel_dims=sequence_shard_parallel_dims, ) # [N_total,hidden_size] output_dict = dict() diff --git a/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py b/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py index e72d127..b22a5d2 100644 --- a/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py +++ b/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -from typing import Callable - """FSDP / activation-checkpointing / torch.compile pass for the unified MoT. The activation-checkpointing implementation here mirrors the torchtitan SAC @@ -14,6 +12,7 @@ """ import re +from typing import Callable import torch import torch.nn as nn @@ -27,11 +26,15 @@ create_selective_checkpoint_contexts, ) +from cosmos_framework.utils import log from cosmos_framework.configs.base.defaults.activation_checkpointing import ActivationCheckpointingConfig from cosmos_framework.configs.base.defaults.compile import CompileConfig from cosmos_framework.data.vfm.sequence_packing import ( FactoredSequencePack, JointSequencePack, + from_und_gen_splits, + get_gen_seq, + get_und_seq, ) from cosmos_framework.model.vfm.mot.attention import SplitInfo, dispatch_attention from cosmos_framework.model.vfm.mot.context_parallel_utils import context_parallel_attention @@ -51,7 +54,7 @@ class ContextParallelDispatch(nn.Module): the inner ``wrapped_dispatch`` with Ulysses-style all-to-all communication. This includes the AR frame 1+ gen-only path — the inner dispatch routes to ``attention_AR_gen_only`` which operates on the - head-sharded tensors produced by the all-to-all. + local-head tensors produced by the all-to-all. All cache writes flow through the ``MemoryState`` interface; neither this class nor the CP attention functions write to the cache directly. @@ -90,6 +93,131 @@ def forward( ) +class ARReplicatedIODispatch(nn.Module): + """AR CP dispatch for replicated attention I/O with local-head attention. + + ``Replicated I/O`` means the caller-side tensors at the attention boundary + are replicated across CP ranks. It does **not** mean attention compute is + replicated. For AR frame 1+, this wrapper slices the replicated current + Q/K/V to this rank's local Q/KV heads and runs attention against the local + KV-head cache. + + Shape flow for AR frame 1+: + before slicing: + q: [S,H,D], k/v: [S,H_kv,D], cached k/v: [B,S_hist,H_kv/CP,D] + after local head slicing: + q: [S,H/CP,D], k/v: [S,H_kv/CP,D], cached k/v: [B,S_hist,H_kv/CP,D] + after local attention: + out_local: [S,H/CP*D] + after sharded o_proj in PackedAttentionMoT: + out: [S,hidden_size] + + Current-frame hidden states stay replicated. For AR frame 1+, this wrapper + delegates to the existing memory-aware AR attention for local heads, then + returns the local current-frame attention output so ``PackedAttentionMoT`` + can apply the corresponding ``o_proj`` column slice. Frame 0 and non-AR + paths delegate unchanged; frame 0 seeds the local KV-head cache through + ``ARMemoryState.write_for_layer``. + """ + + def __init__( + self, + cp_mesh, + wrapped_dispatch: Callable = dispatch_attention, + ) -> None: + super().__init__() + self.cp_mesh = cp_mesh + self.wrapped_dispatch = wrapped_dispatch + + def _head_slices(self, q_heads: int, kv_heads: int) -> tuple[slice, slice]: + cp_group = self.cp_mesh.get_group() + cp_rank = torch.distributed.get_rank(cp_group) + cp_size = torch.distributed.get_world_size(cp_group) + assert kv_heads % cp_size == 0, ( + f"replicated attention_io_layout requires num_kv_heads({kv_heads}) % cp_size({cp_size}) == 0. " + f"num_kv_heads={kv_heads} is the upper bound for useful local-head attention CP." + ) + assert q_heads % kv_heads == 0, f"Query heads ({q_heads}) must be divisible by KV heads ({kv_heads})" + kv_heads_per_rank = kv_heads // cp_size + q_heads_per_kv_head = q_heads // kv_heads + q_heads_per_rank = kv_heads_per_rank * q_heads_per_kv_head + kv_start = cp_rank * kv_heads_per_rank + kv_end = kv_start + kv_heads_per_rank + q_start = cp_rank * q_heads_per_rank + q_end = q_start + q_heads_per_rank + return slice(q_start, q_end), slice(kv_start, kv_end) + + def _slice_local_heads( + self, + packed_query_states: FactoredSequencePack | JointSequencePack, + packed_key_states: FactoredSequencePack | JointSequencePack, + packed_value_states: FactoredSequencePack | JointSequencePack, + ) -> tuple[ + FactoredSequencePack | JointSequencePack, + FactoredSequencePack | JointSequencePack, + FactoredSequencePack | JointSequencePack, + ]: + # Input heads are full and sequence-replicated on every CP rank: + # q: [S,H,D], k/v: [S,H_kv,D]. + q_und_seq = get_und_seq(packed_query_states) # [S_und,H,D] + q_gen_seq = get_gen_seq(packed_query_states) # [S_curr,H,D] + k_und_seq = get_und_seq(packed_key_states) # [S_und,H_kv,D] + k_gen_seq = get_gen_seq(packed_key_states) # [S_curr,H_kv,D] + v_und_seq = get_und_seq(packed_value_states) # [S_und,H_kv,D] + v_gen_seq = get_gen_seq(packed_value_states) # [S_curr,H_kv,D] + + # Slice the contiguous Q-head group that corresponds to this rank's + # contiguous KV-head group: q -> [S,H/CP,D], k/v -> [S,H_kv/CP,D]. + q_slice, kv_slice = self._head_slices(q_gen_seq.shape[1], k_gen_seq.shape[1]) + q_und_local = q_und_seq[:, q_slice, :].contiguous() # [S_und,H_local,D] + q_gen_local = q_gen_seq[:, q_slice, :].contiguous() # [S_curr,H_local,D] + k_und_local = k_und_seq[:, kv_slice, :].contiguous() # [S_und,H_kv_local,D] + k_gen_local = k_gen_seq[:, kv_slice, :].contiguous() # [S_curr,H_kv_local,D] + v_und_local = v_und_seq[:, kv_slice, :].contiguous() # [S_und,H_kv_local,D] + v_gen_local = v_gen_seq[:, kv_slice, :].contiguous() # [S_curr,H_kv_local,D] + + local_query_pack = from_und_gen_splits(q_und_local, q_gen_local, packed_query_states) + local_key_pack = from_und_gen_splits(k_und_local, k_gen_local, packed_key_states) + local_value_pack = from_und_gen_splits(v_und_local, v_gen_local, packed_value_states) + return local_query_pack, local_key_pack, local_value_pack + + def forward( + self, + packed_query_states: FactoredSequencePack | JointSequencePack, + packed_key_states: FactoredSequencePack | JointSequencePack, + packed_value_states: FactoredSequencePack | JointSequencePack, + attention_mask: BlockMask | SplitInfo, + natten_metadata: dict | None = None, + memory_value: MemoryValue | None = None, + ) -> tuple[FactoredSequencePack | JointSequencePack, KVToStore | None]: + if memory_value is None or getattr(memory_value, "frame_idx", 0) <= 0: + return self.wrapped_dispatch( + packed_query_states, + packed_key_states, + packed_value_states, + attention_mask, + natten_metadata=natten_metadata, + memory_value=memory_value, + ) + if getattr(memory_value, "for_cuda_graphs", False): + raise ValueError("replicated attention_io_layout does not support ARMemoryState(for_cuda_graphs=True)") + + local_query_pack, local_key_pack, local_value_pack = self._slice_local_heads( + packed_query_states, + packed_key_states, + packed_value_states, + ) + local_output_pack, kv_to_store = self.wrapped_dispatch( + local_query_pack, + local_key_pack, + local_value_pack, + attention_mask, + natten_metadata=natten_metadata, + memory_value=memory_value, + ) + return local_output_pack, kv_to_store + + def _apply_selective_ac( module: nn.Module, ac: ActivationCheckpointingConfig, @@ -222,6 +350,38 @@ def apply_cp( return model +def apply_replicated_attention_io_cp( + model: nn.Module, + parallel_dims: ParallelDims, +) -> nn.Module: + """Install replicated-attention-IO context parallelism on every attention layer.""" + cp_mesh = parallel_dims.cp_mesh + cp_size = parallel_dims.cp_size + first_block = next(iter(model.model.layers.children())) + first_attn = first_block.self_attn + num_kv_heads = int(first_attn.num_key_value_heads) + num_attention_heads = int(first_attn.num_attention_heads) + assert num_kv_heads % cp_size == 0, ( + f"replicated attention_io_layout requires num_kv_heads({num_kv_heads}) % cp_size({cp_size}) == 0. " + f"num_kv_heads={num_kv_heads} is the upper bound for useful local-head attention CP." + ) + log.info( + "[replicated attention I/O CP] enabled " + f"cp_size={cp_size}, num_kv_heads={num_kv_heads}, num_attention_heads={num_attention_heads}, " + f"kv_heads_per_rank={num_kv_heads // cp_size}, max_useful_cp_size={num_kv_heads}", + rank0_only=True, + ) + for _, block in model.model.layers.named_children(): + attn = block.self_attn + attn.replicated_attention_io_local_head_o_proj = True + attn.replicated_attention_io_cp_mesh = cp_mesh + attn.dispatch_attention_fn = ARReplicatedIODispatch( + cp_mesh, + wrapped_dispatch=attn.dispatch_attention_fn, + ) + return model + + def apply_fsdp( model: nn.Module, parallel_dims: ParallelDims, @@ -253,6 +413,7 @@ def parallelize_unified_mot( parallel_dims: ParallelDims | None, compile_config: CompileConfig, ac_config: ActivationCheckpointingConfig, + attention_io_layout: str = "sequence_sharded", ) -> nn.Module: """Optimize the model using CP, FSDP, activation checkpointing, and torch.compile. @@ -271,10 +432,16 @@ def parallelize_unified_mot( back to the dataclass defaults (mode="selective", save the ``save_ops_regex`` ops, mode="full", save only the outputs of each transformer block). + attention_io_layout: Tensor layout at the attention boundary under CP. """ if parallel_dims is not None and parallel_dims.cp_enabled: - apply_cp(model, parallel_dims) + if attention_io_layout == "replicated": + apply_replicated_attention_io_cp(model, parallel_dims) + elif attention_io_layout == "sequence_sharded": + apply_cp(model, parallel_dims) + else: + raise ValueError(f"Unsupported attention_io_layout={attention_io_layout!r}") apply_ac(model, ac_config) if compile_config.enabled: apply_compile(model, compile_config) diff --git a/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py b/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py index 704bbce..746afa0 100644 --- a/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py +++ b/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py @@ -47,6 +47,7 @@ def parallelize_vfm_network( parallel_dims: ParallelDims | None, compile_config: CompileConfig, ac_config: ActivationCheckpointingConfig, + attention_io_layout: str = "sequence_sharded", ) -> torch.nn.Module: """Optimize the model using FSDP, CP, activation checkpointing, and torch.compile. @@ -62,7 +63,9 @@ def parallelize_vfm_network( ``OmniMoTModelConfig.sac``. Forwarded to ``parallelize_unified_mot``; ``None`` falls back to the ``ActivationCheckpointingConfig`` defaults. + attention_io_layout: Tensor layout at the attention boundary under CP. """ + model.attention_io_layout = attention_io_layout if parallel_dims is not None and parallel_dims.cp_enabled: model.parallel_dims = parallel_dims @@ -71,6 +74,7 @@ def parallelize_vfm_network( parallel_dims=parallel_dims, compile_config=compile_config, ac_config=ac_config, + attention_io_layout=attention_io_layout, ) if compile_config.enabled and compile_config.compiled_region == "all": diff --git a/cosmos_framework/model/vfm/mot/unified_mot.py b/cosmos_framework/model/vfm/mot/unified_mot.py index 4ead628..f7c6ab8 100644 --- a/cosmos_framework/model/vfm/mot/unified_mot.py +++ b/cosmos_framework/model/vfm/mot/unified_mot.py @@ -9,6 +9,7 @@ import torch from torch import nn +from torch.distributed import ProcessGroup from cosmos_framework.model.attention import attention as imaginaire_attention from cosmos_framework.model.attention.masks import CausalType @@ -431,6 +432,21 @@ def _transform_text_dict(self, text_dict: Mapping[str, Any]) -> Mapping[str, Any # ----------------------------------------------------------------------------- +def _apply_head_sharded_o_proj( + local_attn_output: torch.Tensor, # [N,H_local*D] + projection: nn.Linear, + feature_slice: slice, + cp_group: ProcessGroup, +) -> torch.Tensor: # [N,hidden_size] + """Apply one local input-column slice of ``projection`` and sum partial outputs.""" + local_weight = projection.weight[:, feature_slice] # [hidden_size,H_local*D] + out = torch.nn.functional.linear(local_attn_output, local_weight, bias=None) # [N,hidden_size] + torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM, group=cp_group) + if projection.bias is not None: + out = out + projection.bias # [N,hidden_size] + return out + + class PackedAttentionMoT(nn.Module): """ Dual-pathway packed attention for MoT architectures. @@ -502,6 +518,36 @@ def __init__( self._apply_rotary_pos_emb = layer_types.apply_rotary_pos_emb self.dispatch_attention_fn = dispatch_attention + self.replicated_attention_io_local_head_o_proj = False + self.replicated_attention_io_cp_mesh: Any | None = None + + def _replicated_attention_io_q_feature_slice(self) -> slice: + cp_mesh = self.replicated_attention_io_cp_mesh + assert cp_mesh is not None, "replicated attention I/O requires a CP mesh" + cp_group = cp_mesh.get_group() + cp_rank = torch.distributed.get_rank(cp_group) + cp_size = torch.distributed.get_world_size(cp_group) + assert self.num_key_value_heads % cp_size == 0, ( + f"cp_size({cp_size}) must divide num_key_value_heads({self.num_key_value_heads})" + ) + assert self.num_attention_heads % self.num_key_value_heads == 0, ( + f"num_attention_heads({self.num_attention_heads}) must be divisible by " + f"num_key_value_heads({self.num_key_value_heads})" + ) + kv_heads_per_rank = self.num_key_value_heads // cp_size + q_heads_per_kv_head = self.num_attention_heads // self.num_key_value_heads + q_heads_per_rank = kv_heads_per_rank * q_heads_per_kv_head + q_start = cp_rank * q_heads_per_rank + q_end = q_start + q_heads_per_rank + return slice(q_start * self.head_dim, q_end * self.head_dim) + + def _uses_replicated_attention_io_local_head_o_proj(self, memory_value: MemoryValue | None) -> bool: + return ( + self.replicated_attention_io_local_head_o_proj + and memory_value is not None + and getattr(memory_value, "frame_idx", 0) > 0 + and not getattr(memory_value, "for_cuda_graphs", False) + ) def forward( self, @@ -587,7 +633,7 @@ def forward( # Produce kv_to_store for MemoryState.write_for_layer() when the # dispatch didn't already provide one (e.g. standard or AR frame-0 - # non-CP paths). CP dispatch returns head-sharded kv_to_store + # non-CP paths). CP dispatch returns local KV-head kv_to_store # directly, so kv_to_store is already non-None in that case. # # Gradient detach is NOT done here; each MemoryState.write_for_layer() @@ -603,9 +649,39 @@ def forward( v_und[:und_len].unsqueeze(0), ) - # Apply projections directly to get final results - und_seq = self.o_proj(get_und_seq(packed_attn_output)) # [N_und,hidden_size] - gen_seq = self.o_proj_moe_gen(get_gen_seq(packed_attn_output)) # [N_gen,hidden_size] + # Attention compute is local-head under both sequence-sharded and + # replicated attention I/O layouts. The difference here is the output + # layout returned to this module. Replicated attention I/O returns only + # this rank's local heads from AR frame 1+ attention: + # gen [N_gen,H_local*D] and und [0,H_local*D]. We therefore apply the + # matching o_proj input-column slice and all-reduce partial outputs back + # to replicated hidden states. The else path receives full attention + # heads at this boundary, so regular o_proj applies: + # und [N_und,H*D] -> [N_und,hidden_size], + # gen [N_gen,H*D] -> [N_gen,hidden_size]. + if self._uses_replicated_attention_io_local_head_o_proj(memory_value): + local_und_attn = get_und_seq(packed_attn_output) # [0,H_local*D] + local_gen_attn = get_gen_seq(packed_attn_output) # [N_gen,H_local*D] + assert local_und_attn.shape[0] == 0, "replicated attention I/O only supports gen-only frame 1+ attention" + feature_slice = self._replicated_attention_io_q_feature_slice() + assert feature_slice.start is not None and feature_slice.stop is not None + expected_local_features = feature_slice.stop - feature_slice.start + assert local_gen_attn.shape[-1] == expected_local_features, ( + f"Expected local attention features {expected_local_features}, got {local_gen_attn.shape[-1]}" + ) + cp_mesh = self.replicated_attention_io_cp_mesh + assert cp_mesh is not None, "replicated attention I/O requires a CP mesh" + cp_group = cp_mesh.get_group() + und_seq = local_gen_attn.new_empty((0, self.hidden_size)) # [0,hidden_size] + gen_seq = _apply_head_sharded_o_proj( + local_gen_attn, + self.o_proj_moe_gen, + feature_slice, + cp_group, + ) # [N_gen,hidden_size] + else: + und_seq = self.o_proj(get_und_seq(packed_attn_output)) # [N_und,hidden_size] + gen_seq = self.o_proj_moe_gen(get_gen_seq(packed_attn_output)) # [N_gen,hidden_size] return from_und_gen_splits(und_seq, gen_seq, pack), kv_to_store # [N_und+N_gen,hidden_size] def reasoner_forward( diff --git a/cosmos_framework/model/vfm/omni_mot_model.py b/cosmos_framework/model/vfm/omni_mot_model.py index 652b930..64f5270 100644 --- a/cosmos_framework/model/vfm/omni_mot_model.py +++ b/cosmos_framework/model/vfm/omni_mot_model.py @@ -237,6 +237,7 @@ def build_net(self, dtype: torch.dtype): parallel_dims=self.parallel_dims, compile_config=self.config.compile, ac_config=self.config.activation_checkpointing, + attention_io_layout=self.config.parallelism.attention_io_layout, ) with misc.timer("meta to cuda and broadcast model states"): @@ -2485,12 +2486,14 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): # Run sampler for all samples at once. sampler = sampler or self.sampler scheduler_type = self.config.rectified_flow_inference_config.scheduler_type - if scheduler_type == "unipc": + if isinstance(sampler, FixedStepSampler): + log.info(f"Using sampler: FixedStep (t_list={sampler.t_list}, sample_type={sampler.sample_type})") + elif scheduler_type == "unipc": log.info(f"Using sampler: UniPC (shift={shift}, num_steps={num_steps})") else: log.info(f"Using sampler: EDM (sigma_max={sigma_max}, num_steps={num_steps})") - if scheduler_type == "unipc": + if isinstance(sampler, FixedStepSampler) or scheduler_type == "unipc": latents = sampler( velocity_fn, initial_noise, diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py index 18b242a..0a1f76a 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py @@ -160,7 +160,7 @@ def decode(self, latent: torch.Tensor) -> torch.Tensor: def get_latent_num_frames(self, num_pixel_frames: int) -> int: return (num_pixel_frames + self.model.cfg.num_pad_frames) // self._temporal_compression_factor - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: return num_latent_frames * self._temporal_compression_factor - self.model.cfg.num_pad_frames @property diff --git a/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py b/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py index 3446646..85188b4 100644 --- a/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py +++ b/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py @@ -450,7 +450,7 @@ def get_latent_num_frames(self, num_pixel_frames: int) -> int: """Get number of latent frames from pixel frames.""" return num_pixel_frames # Flux VAE doesn't compress temporally - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: """Get number of pixel frames from latent frames.""" return num_latent_frames # Flux VAE doesn't compress temporally diff --git a/cosmos_framework/model/vfm/tokenizers/interface.py b/cosmos_framework/model/vfm/tokenizers/interface.py index 3c023bc..d307626 100644 --- a/cosmos_framework/model/vfm/tokenizers/interface.py +++ b/cosmos_framework/model/vfm/tokenizers/interface.py @@ -49,7 +49,7 @@ def get_latent_num_frames(self, num_pixel_frames: int) -> int: pass @abstractmethod - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: pass def get_latent_temporal_positions( diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py index 5e9be83..9ab1eb9 100644 --- a/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py +++ b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py @@ -311,6 +311,44 @@ def align_uniae_num_video_frames( return num_video_frames +def ceil_uniae_num_video_frames( + num_video_frames: int, + uniae_chunk_frames: int | Mapping[str, int], + *, + pad_frames: int, + temporal_compression_factor: int, + resolution: str | None = None, + spatial_shape: tuple[int, int] | None = None, + target_resolution_key: str | None = None, + missing_resolution_message: str = ( + "spatial_shape or target resolution must be provided for resolution-keyed UniAE chunks" + ), +) -> int: + """Round up to the nearest valid UniAE noncausal count, preserving valid partial tails.""" + if num_video_frames < 1: + return 0 + + for candidate in range(num_video_frames, num_video_frames + temporal_compression_factor + 1): + aligned_candidate = align_uniae_num_video_frames( + candidate, + uniae_chunk_frames, + pad_frames=pad_frames, + temporal_compression_factor=temporal_compression_factor, + resolution=resolution, + spatial_shape=spatial_shape, + target_resolution_key=target_resolution_key, + missing_resolution_message=missing_resolution_message, + ) + if aligned_candidate == candidate: + return candidate + + raise RuntimeError( + "Failed to find a valid UniAE frame count within one temporal-compression window: " + f"{num_video_frames=}, {uniae_chunk_frames=}, {pad_frames=}, {temporal_compression_factor=}, " + f"{resolution=}, {spatial_shape=}, {target_resolution_key=}." + ) + + def _validate_full_chunk( full_chunk: int, *, diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/frame_math_test.py b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math_test.py new file mode 100644 index 0000000..b233ac1 --- /dev/null +++ b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math_test.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + align_uniae_num_video_frames, + ceil_uniae_num_video_frames, +) + + +def test_ceil_uniae_num_video_frames_preserves_valid_partial_tail() -> None: + assert ( + ceil_uniae_num_video_frames( + 17, + {"480": 16}, + pad_frames=1, + temporal_compression_factor=4, + resolution="480", + ) + == 17 + ) + + +def test_ceil_uniae_num_video_frames_uses_next_valid_partial_tail() -> None: + assert ( + align_uniae_num_video_frames( + 24, + {"480": 16}, + pad_frames=1, + temporal_compression_factor=4, + resolution="480", + ) + == 21 + ) + assert ( + ceil_uniae_num_video_frames( + 24, + {"480": 16}, + pad_frames=1, + temporal_compression_factor=4, + resolution="480", + ) + == 25 + ) diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py index 57b168e..445f88f 100644 --- a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py +++ b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py @@ -511,6 +511,10 @@ def latent_chunk_duration(self): "Use encode_chunk_frames[res_key] // temporal_compression_factor. Will be removed in a future MR." ) + @property + def pad_frames(self) -> int: + return self.vae._pad_frames + @property def latent_ch(self) -> int: return self.vae.z_dim diff --git a/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py b/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py index 542c90f..f1368a0 100644 --- a/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py +++ b/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py @@ -814,7 +814,7 @@ def decode(self, latent: torch.Tensor) -> torch.Tensor: # latent: [B,C,T_latent def get_latent_num_frames(self, num_pixel_frames: int) -> int: return 1 + (num_pixel_frames - 1) // 4 - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: return (num_latent_frames - 1) * 4 + 1 @property diff --git a/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py b/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py index 12dc6d5..a796830 100644 --- a/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py +++ b/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py @@ -1653,7 +1653,7 @@ def _get_ref_caches( def get_latent_num_frames(self, num_pixel_frames: int) -> int: return 1 + (num_pixel_frames - 1) // 4 - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: return (num_latent_frames - 1) * 4 + 1 @property diff --git a/cosmos_framework/utils/vfm/model_loader.py b/cosmos_framework/utils/vfm/model_loader.py index c180ca3..eaabf04 100644 --- a/cosmos_framework/utils/vfm/model_loader.py +++ b/cosmos_framework/utils/vfm/model_loader.py @@ -194,6 +194,9 @@ def _load_model( start_time = time.time() state_dict = ModelWrapper(model).state_dict() + if any(key.startswith("net_teacher.") for key in state_dict): + log.info("Dropping net_teacher.* keys from inference load target; distillation checkpoints do not save them.") + state_dict = {key: value for key, value in state_dict.items() if not key.startswith("net_teacher.")} if checkpoint_path.startswith("s3://"): storage_reader = S3StorageReader( @@ -351,6 +354,9 @@ def load_model_from_checkpoint( # Disable EMA for inference. config.model.config.ema.enabled = False + if hasattr(config.model.config, "load_teacher_weights"): + log.info("Setting load_teacher_weights=False for inference to skip teacher checkpoint download.") + config.model.config.load_teacher_weights = False config.validate() config.freeze() # type: ignore