Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: int | tuple[int, ...],
loaded_shard_id: int | tuple[int, ...] | None = None,
):
# Support loading multiple consecutive shards in a single tensor.
# This mirrors vLLM's behavior for packed modules like QKV.
Expand Down Expand Up @@ -592,6 +592,33 @@ def weight_loader(
current_offset += shard_size
return

if loaded_shard_id is None:
# No shard id was given: the loaded weight is already fused on disk.
# Copy it directly when the shapes match; otherwise split it and load
# each output shard individually.
param_data = param.data
# Check whether this parameter is a scale (weight_scale or input_scale)
is_scale_param = param is getattr(
self, "weight_scale", None
) or param is getattr(self, "input_scale", None)

# A fused weight can be copied as-is only if its shape matches the parameter's
if param_data.shape == loaded_weight.shape:
# Shapes match - direct copy
param.weight_loader_process(param_data, loaded_weight)
return

# Otherwise, split the fused weight and load each output shard
Comment thread
wuhuikx marked this conversation as resolved.
current_offset = 0
for shard_id, output_size in enumerate(self.output_sizes):
shard_size = output_size
if is_scale_param and self.quant_type == QuantType.per_1x128:
shard_size //= 128

shard = loaded_weight.narrow(self.tp_dim, current_offset, shard_size)
self.weight_loader(param, shard, shard_id)
current_offset += shard_size
return
Comment thread
wuhuikx marked this conversation as resolved.

param_data = param.data
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
Expand Down
134 changes: 48 additions & 86 deletions atom/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch import nn


from aiter.dist.parallel_state import get_tensor_model_parallel_rank
from atom.config import QuantizationConfig, Config

from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled
Expand Down Expand Up @@ -42,105 +41,68 @@
)
from atom.model_ops.split_chunk import fused_split_chunk_zeros

GDN = Qwen3NextGatedDeltaNet
if is_vllm():
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator,
MambaStateDtypeCalculator,
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
)
from vllm.distributed import get_tensor_model_parallel_rank

class Qwen3NextGatedDeltaNetVllm(Qwen3NextGatedDeltaNet, MambaBase):

def __init__(
self, config, quant_config=None, speculative_config=None, prefix=""
):
super().__init__(config, quant_config, speculative_config, prefix)
self.model_config = config.plugin_config.vllm_config.model_config
self.cache_config = config.plugin_config.vllm_config.cache_config
self.tp_rank = get_tensor_model_parallel_rank()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self

def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:

return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[key_dim, key_dim, value_dim, value_dim],
bias=False,
quant_config=quant_config,
prefix=prefix,
)

def create_ba_proj(
self,
hidden_size: int,
num_v_heads: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
# Qwen3.5 has separate in_proj_b and in_proj_a weights in the
# checkpoint, which are loaded into the fused in_proj_ba parameter
# via stacked_params_mapping with shard_id 0 and 1 respectively.
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[num_v_heads] * 2,
bias=False,
quant_config=quant_config,
prefix=prefix,
)

def create_qkvzba_proj(self):
self.in_proj_qkvz = self.create_qkvz_proj(
hidden_size=self.hidden_size,
key_dim=self.key_dim,
value_dim=self.value_dim,
quant_config=self.quant_config,
prefix=f"{self.prefix}.in_proj_qkvz",
)

self.in_proj_ba = self.create_ba_proj(
hidden_size=self.hidden_size,
num_v_heads=self.num_v_heads,
quant_config=self.quant_config,
prefix=f"{self.prefix}.in_proj_ba",
)
class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):

def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def create_qkvz_proj(
self,
hidden_size: int,
key_dim: int,
value_dim: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:

return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[key_dim, key_dim, value_dim, value_dim],
bias=False,
quant_config=quant_config,
prefix=prefix,
)

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.gated_delta_net_state_shape(
self.tp_size,
self.num_k_heads,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
self.conv_kernel_size,
self.num_spec,
)
def create_ba_proj(
self,
hidden_size: int,
num_v_heads: int,
quant_config: QuantizationConfig | None,
prefix: str,
) -> MergedColumnParallelLinear:
# Qwen3.5 has separate in_proj_b and in_proj_a weights in the
# checkpoint, which are loaded into the fused in_proj_ba parameter
# via stacked_params_mapping with shard_id 0 and 1 respectively.
return MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[num_v_heads] * 2,
bias=False,
quant_config=quant_config,
prefix=prefix,
)

GDN = Qwen3NextGatedDeltaNetVllm
def create_qkvzba_proj(self, quant_config, prefix):
self.in_proj_qkvz = self.create_qkvz_proj(
hidden_size=self.hidden_size,
key_dim=self.key_dim,
value_dim=self.value_dim,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_qkvz",
)

self.in_proj_ba = self.create_ba_proj(
hidden_size=self.hidden_size,
num_v_heads=self.num_v_heads,
quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba",
)

class Qwen3_5GatedDeltaNet(GDN):
def fix_query_key_value_ordering(
self,
mixed_qkvz: torch.Tensor,
Expand Down
Loading
Loading