Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,11 @@ def _remap_layer_name(name: str) -> list[str]:
"kimi_k25": "text_config",
}

# Multimodal model types that plugin (vLLM) mode serves natively.
# For these, get_hf_config must NOT collapse the config down to the text
# sub-config — the full multimodal config is needed by the plugin path.
_PLUGIN_SUPPORTED_MULTIMODAL_MODELS: set[str] = {
    "kimi_k25",
}


def get_hf_config(model: str, trust_remote_code: bool = False) -> PretrainedConfig:
config_dict, _ = PretrainedConfig.get_config_dict(
Expand All @@ -480,10 +485,18 @@ def _get_hf_token() -> str | None:
return token
return None

multimodal_model_types = _MULTIMODAL_MODEL_TYPES
if is_vllm():
# Avoid mutating module-level state
multimodal_model_types = {
name: text_key
for name, text_key in _MULTIMODAL_MODEL_TYPES.items()
if name not in _PLUGIN_SUPPORTED_MULTIMODAL_MODELS
}
# For multimodal models, extract the text sub-config so the rest of ATOM
# (which is text-only today) works transparently.
if model_type in _MULTIMODAL_MODEL_TYPES:
text_config_key = _MULTIMODAL_MODEL_TYPES[model_type]
if model_type in multimodal_model_types:
text_config_key = multimodal_model_types[model_type]
text_config_dict = config_dict.get(text_config_key, {}).copy()
# Remove auto_map to avoid trust_remote_code issues
text_config_dict.pop("auto_map", None)
Expand Down
132 changes: 132 additions & 0 deletions atom/model_config/kimi_k25.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Comment thread
wuhuikx marked this conversation as resolved.
"""
Kimi-K2.5 Model Configuration.

This configuration supports video-chunk as an internal modality type.
A video-chunk is the smallest independently processable unit of video.
"""

from transformers import DeepseekV3Config
from transformers.configuration_utils import PretrainedConfig


class KimiK25VisionConfig(PretrainedConfig):
    """Configuration for the Kimi-K2.5 vision tower and multimodal projector.

    Stores the vision-transformer hyper-parameters (patch size, positional
    embedding layout, attention/merge behaviour over video chunks) together
    with the multimodal projector settings.
    """

    model_type = "kimi_k25_vision"

    def __init__(
        self,
        # Vision Tower
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        init_pos_emb_time: int = 4,
        pos_emb_type: str = "divided_fixed",
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        video_attn_type: str = "spatial_temporal",
        merge_type: str = "sd2_tpool",
        # MM Projector
        mm_projector_type: str = "patchmerger",
        mm_hidden_size: int | None = None,
        projector_hidden_act: str = "gelu",
        projector_ln_eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- Vision tower ---
        self.patch_size = patch_size
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        self.init_pos_emb_time = init_pos_emb_time
        self.pos_emb_type = pos_emb_type
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.merge_kernel_size = merge_kernel_size
        self.video_attn_type = video_attn_type
        self.merge_type = merge_type

        # --- MM projector ---
        self.mm_projector_type = mm_projector_type
        # When no projector size is given, fall back to the tower's own
        # hidden size (KimiK25Config may later replace this default).
        self.mm_hidden_size = hidden_size if mm_hidden_size is None else mm_hidden_size
        self.projector_hidden_act = projector_hidden_act
        self.projector_ln_eps = projector_ln_eps


class KimiK25Config(PretrainedConfig):
    """Kimi-K2.5 model configuration.

    Kimi-K2.5 extends Kimi-K2 with vision support using video-chunks.
    A video-chunk consists of multiple consecutive frames
    that are processed together with temporal pooling.

    Args:
        vision_config: Configuration for the vision tower and projector.
        text_config: Configuration for the text model (DeepseekV3).
        ignore_index: The ignore index for the loss function.
        media_placeholder_token_id: The token ID for media placeholders.
        pad_token_id: The token ID for padding.
        use_unified_vision_chunk: Flag stored on the config; presumably
            selects a unified vision-chunk processing path — TODO confirm
            against the model implementation.
        video_placeholder: Placeholder string for video content; presumably
            substituted during prompt processing — TODO confirm against the
            processor.
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        vision_config: dict | KimiK25VisionConfig | None = None,
        text_config: dict | DeepseekV3Config | None = None,
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        use_unified_vision_chunk: bool = False,
        video_placeholder: str = "<|kimi_k25_video_placeholder|>",
        **kwargs,
    ) -> None:
        # Vision config: accept a ready-made config object, a plain dict,
        # or None (KimiK25VisionConfig defaults).
        if vision_config is None:
            vision_config = KimiK25VisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = KimiK25VisionConfig(**vision_config)
        self.vision_config: KimiK25VisionConfig = vision_config

        # Text config: same tri-state handling, backed by DeepseekV3.
        if text_config is None:
            text_config = DeepseekV3Config()
        elif isinstance(text_config, dict):
            text_config = DeepseekV3Config(**text_config)
        self.text_config: DeepseekV3Config = text_config

        # Set mm_hidden_size to text hidden size if not explicitly set.
        # NOTE(review): "not explicitly set" is detected by equality with the
        # vision hidden size (KimiK25VisionConfig defaults mm_hidden_size to
        # its own hidden_size when None), so an mm_hidden_size explicitly set
        # equal to the vision hidden size is also overridden here — confirm
        # this is intended.
        if self.vision_config.mm_hidden_size == self.vision_config.hidden_size:
            self.vision_config.mm_hidden_size = self.text_config.hidden_size

        # Other config
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
        self.use_unified_vision_chunk = use_unified_vision_chunk
        self.video_placeholder = video_placeholder

        # Surface the text model's quantization config at the top level.
        if getattr(self.text_config, "quantization_config", None) is not None:
            self.quantization_config = self.text_config.quantization_config

        super().__init__(pad_token_id=pad_token_id, **kwargs)

    @property
    def hidden_size(self) -> int:
        """Get hidden size from text config for compatibility."""
        return self.text_config.hidden_size

    @property
    def vocab_size(self) -> int:
        """Get vocab size from text config for compatibility."""
        return self.text_config.vocab_size


# Public API of this module.
__all__ = ["KimiK25Config", "KimiK25VisionConfig"]
15 changes: 11 additions & 4 deletions atom/plugin/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,10 +1154,17 @@ def init_method_under_plugin_mode(
max_num_pages_per_req = self.vllm_config.model_config.max_model_len
max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.num_attention_heads = (
config.model_config.hf_config.num_attention_heads
// get_tp_group().world_size
)

hf_config = config.model_config.hf_config
text_config = getattr(hf_config, "text_config", None)
num_attention_heads = getattr(
hf_config, "num_attention_heads", None
) or getattr(text_config, "num_attention_heads", None)
assert (
num_attention_heads is not None
), "num_attention_heads is not found in config"

self.num_attention_heads = num_attention_heads // get_tp_group().world_size
self.padded_num_attention_heads = max(self.num_attention_heads, _MLA_MIN_HEADS)
self.block_size = kv_cache_spec.block_size
self.max_bs = max_num_reqs
Expand Down
1 change: 1 addition & 0 deletions atom/plugin/vllm/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"Qwen3NextForCausalLM": "atom.models.qwen3_next:Qwen3NextForCausalLM",
"Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5MoeForConditionalGeneration_",
"Qwen3_5ForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5ForConditionalGeneration_",
"KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration_",
}


Expand Down
Loading
Loading