32 changes: 24 additions & 8 deletions atom/config.py
@@ -609,19 +609,35 @@ def __post_init__(self):
self.kv_cache_block_size % 16 == 0 or self.kv_cache_block_size == 1
), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1"
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = get_hf_config(self.model)
if is_plugin_mode():
# plugin mode
assert (
self.plugin_config is not None
), "plugin_config is required in plugin mode"
self.hf_config = self.plugin_config.model_config.hf_config
else:
self.hf_config = get_hf_config(self.model)

self.generation_config = get_generation_config(self.model)
if self.generation_config is not None:
if (
eos_ids := getattr(self.generation_config, "eos_token_id", None)
) is not None:
self.stop_token_ids = (
[eos_ids] if isinstance(eos_ids, int) else eos_ids
)
if not hasattr(self.hf_config, "rope_parameters"):
# Compatible with both transformers < 5
rope_params = getattr(self.hf_config, "rope_scaling", {})
rope_params = getattr(self.hf_config, "rope_scaling", {}) or {}
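# some HF configs set rope_scaling to None, so fall back to an empty dict before augmenting it below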
rope_params["rope_theta"] = self.hf_config.rope_theta
rope_params["rope_type"] = getattr(rope_params, "rope_type", "default")
self.hf_config.rope_parameters = rope_params

self.generation_config = get_generation_config(self.model)
if self.generation_config is not None:
if (
eos_ids := getattr(self.generation_config, "eos_token_id", None)
) is not None:
self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids
# if self.generation_config is not None:
# if (
# eos_ids := getattr(self.generation_config, "eos_token_id", None)
# ) is not None:
# self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids
self.quant_config = get_quant_config(self.hf_config)
hf_config_max_position_embeddings = getattr(
self.hf_config, "max_position_embeddings", 8192
86 changes: 69 additions & 17 deletions atom/model_ops/radix_attention.py
@@ -11,6 +11,8 @@
from atom.models.utils import maybe_prefix
from atom.utils import envs

from aiter.rotary_embedding import AiterFusedSetKVBufferArg


class RadixAttention(BaseAttention):
"""
@@ -50,22 +52,30 @@ def __init__(
)

if is_sglang():
from sglang.srt.layers.radix_attention import RadixAttention
self.rotary_emb = rotary_emb
self.layer_num = layer_num

self.k_scale = torch.tensor([1.0], dtype=torch.float32)
self.v_scale = torch.tensor([1.0], dtype=torch.float32)

# if True, save cache will be done in rope
self.use_rope_fused_qknorm = envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION

from sglang.srt.layers.radix_attention import RadixAttention
self.attn = RadixAttention(
num_heads=num_heads,
head_dim=head_dim,
scaling=scale,
num_kv_heads=num_kv_heads,
layer_id=layer_num,
layer_id=self.layer_num,
prefix=maybe_prefix(prefix, "attn"),
)
else:
raise NotImplementedError(
"RadixAttention is only supported for plugin mode for sglang for now"
)
# if True, save cache will be done in rope
self.use_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE
# self.use_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE

def forward_impl_plugin_mode(
self,
@@ -78,24 +88,64 @@ def forward_impl_plugin_mode(
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor = None,
q_scale: torch.Tensor = None,
qkv: torch.Tensor = None,
**kwargs,
):
if is_sglang():
# for sglang, forward_batch is required
forward_batch = kwargs.get("forward_batch", None)
assert forward_batch is not None, "forward_batch is required for sglang"
save_kv_cache = not self.use_rope_fused_qknorm
return self.attn(
query,
key,
value,
forward_batch=forward_batch,
save_kv_cache=save_kv_cache,
# for sglang, forward_batch is required
forward_batch = kwargs.get("forward_batch", None)
assert forward_batch is not None, "forward_batch is required for sglang"

if self.use_rope_fused_qknorm:
k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
self.layer_num
)
else:
raise NotImplementedError(
"RadixAttention is only supported for plugin mode for sglang for now"
block_size = 1024 # Default fallback
if hasattr(forward_batch, "attn_backend") and hasattr(
forward_batch.attn_backend, "page_size"
):
block_size = forward_batch.attn_backend.page_size
elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr(
forward_batch.token_to_kv_pool.allocator, "page_size"
):
block_size = forward_batch.token_to_kv_pool.allocator.page_size
elif hasattr(forward_batch.token_to_kv_pool, "page_size"):
block_size = forward_batch.token_to_kv_pool.page_size
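            # x is the packing factor of the shuffled KV layout: 16 bytes per packed chunk
            # divided by the cache element size (assumes the usual paged-KV
            # [..., head_size // x, block_size, x] convention)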
x = 16 // k_buffer.element_size()
aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg(
kv_cache=(k_buffer, v_buffer),
cache_loc=forward_batch.out_cache_loc,
k_scale=self.k_scale,
v_scale=self.v_scale,
return_kv=True,
use_shuffle_layout=True,
block_size=block_size,
x=x,
)
q, k, v = self.rotary_emb(
qkv,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.num_heads,
self.num_kv_heads,
self.q_norm.eps,
fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg,
)
else:
# calculate the q and k with rotary embedding
assert self.rotary_emb is not None, "rotary_emb is required"
q, k = self.rotary_emb(positions, query, key)
v = value

save_kv_cache = not self.use_rope_fused_qknorm
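        # (when the fused RoPE path above has already written K/V into the pool,
        # the attention backend must not write the cache a second time)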
return self.attn(
q,
k,
v,
forward_batch=forward_batch,
save_kv_cache=save_kv_cache,
)


def forward(
self,
@@ -104,6 +154,7 @@ def forward(
value: torch.Tensor,
positions: torch.Tensor = None,
q_scale: Optional[torch.Tensor] = None,
qkv: torch.Tensor = None,
**kwargs,
):
if is_plugin_mode():
@@ -113,6 +164,7 @@
value=value,
positions=positions,
q_scale=q_scale,
qkv=qkv,
**kwargs,
)
else:
95 changes: 16 additions & 79 deletions atom/models/qwen3_moe.py
@@ -5,7 +5,7 @@
from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size

# from atom.model_ops.rotary_embedding import get_rope
from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg
from aiter.rotary_embedding import get_rope
from atom.config import Config, QuantizationConfig
from atom.model_ops.activation import SiluAndMul

@@ -30,7 +30,6 @@
from atom.utils.decorators import support_torch_compile
from torch import nn
from atom.model_loader.loader import load_model_in_plugin_mode
from atom.plugin.prepare import is_sglang

# import torch.distributed as dist
from transformers import PretrainedConfig
@@ -39,7 +38,6 @@
ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = (
envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION
)
ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE


class Qwen3MoeMLP(nn.Module):
@@ -226,65 +224,6 @@ def __init__(

self.kv_cache_dtype = kv_cache_dtype
self.layer_num = layer_num
self.k_scale = torch.tensor([1.0], dtype=torch.float32)
self.v_scale = torch.tensor([1.0], dtype=torch.float32)

def forward_sgl_plugin_mode(
self,
positions: torch.Tensor,
qkv: torch.Tensor,
**model_kwargs: dict[str, Any] | None,
):
if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE:
forward_batch = model_kwargs.get("forward_batch", None)
assert forward_batch is not None, "forward_batch is required for sglang"
k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer(
self.layer_num
)
block_size = 1024 # Default fallback
if hasattr(forward_batch, "attn_backend") and hasattr(
forward_batch.attn_backend, "page_size"
):
block_size = forward_batch.attn_backend.page_size
elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr(
forward_batch.token_to_kv_pool.allocator, "page_size"
):
block_size = forward_batch.token_to_kv_pool.allocator.page_size
elif hasattr(forward_batch.token_to_kv_pool, "page_size"):
block_size = forward_batch.token_to_kv_pool.page_size
x = 16 // k_buffer.element_size()
aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg(
kv_cache=(k_buffer, v_buffer),
cache_loc=forward_batch.out_cache_loc,
k_scale=self.k_scale,
v_scale=self.v_scale,
return_kv=True,
use_shuffle_layout=True,
block_size=block_size,
x=x,
)
q, k, v = self.rotary_emb(
qkv,
self.q_norm.weight,
self.k_norm.weight,
positions,
self.num_heads,
self.num_kv_heads,
self.q_norm.eps,
fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg,
)
else:
q, k, v = torch.split(
qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1
)
# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)

q, k = self.rotary_emb(positions, q, k)

attn_output = self.attn(q, k, v, positions=positions, **model_kwargs)
return attn_output

def forward(
self,
@@ -295,25 +234,23 @@ def forward(
qkv = self.qkv_proj(hidden_states)
q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION:
q, k, v = torch.split(
qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1
)
attn_output = self.attn(
query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv
)
attn_output = self.attn(query=q,
key=k,
value=v,
positions=positions,
q_scale=None,
qkv=qkv,
**model_kwargs)
else:
if is_sglang():
attn_output = self.forward_sgl_plugin_mode(
positions, qkv, **model_kwargs
)
else:
# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)
# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)

attn_output = self.attn(
query=q, key=k, value=v, positions=positions, **model_kwargs
)
attn_output = self.attn(query=q,
key=k,
value=v,
positions=positions,
**model_kwargs)
output = self.o_proj(attn_output)
return output

1 change: 0 additions & 1 deletion atom/utils/envs.py
@@ -42,7 +42,6 @@
"ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1"
)
== "1",
"ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE": lambda: os.getenv("ATOM_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE", "0") == "1",
}


53 changes: 53 additions & 0 deletions bench_qwen.sh
@@ -0,0 +1,53 @@
#!/bin/bash

MODEL=/mnt/raid0/pretrained_model/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8

RANGE_RATIO=0.8

# 1K/1K
ISL=1024
OSL=1024
CON=128
#CON=128
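# NUM: total prompts, sized at 4x the concurrency so each slot serves several requests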
NUM=$(( CON * 4 ))

# 4K/1K
#ISL=4000
#OSL=1000
#CON=128
#CON=64
#NUM=$(( CON * 4 ))

# 10K/1K
#ISL=10000
#OSL=1000
#CON=64
#CON=32
#NUM=$(( CON * 4 ))

echo "ATOM Model=${MODEL}"
echo "ATOM ISL=${ISL}, OSL=${OSL}, NUM=${NUM}, CON=${CON} RANGE_RATIO=${RANGE_RATIO}"

sleep 2

# git clone https://github.com/kimbochen/bench_serving.git
python bench_serving/benchmark_serving.py \
--model=$MODEL \
--backend=vllm \
--base-url=http://localhost:8000 \
--dataset-name=random \
--random-input-len=$ISL \
--random-output-len=$OSL \
--random-range-ratio ${RANGE_RATIO} \
--num-prompts=${NUM} \
--max-concurrency=${CON} \
--request-rate=inf \
--ignore-eos \
--save-result \
--percentile-metrics="ttft,tpot,itl,e2el" \
--result-dir=./ \
2>&1 | tee log.bench.log

echo "ATOM Model=${MODEL}"
echo "ATOM ISL=${ISL}, OSL=${OSL}, NUM=${NUM}, CON=${CON}"
rm -rf ./*.json