-
Notifications
You must be signed in to change notification settings - Fork 25
[draft][plugin] sgl radix attn backend #296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -621,22 +621,29 @@ def __post_init__(self): | |
| # assert os.path.isdir(self.model) | ||
|
|
||
| assert 1 <= self.tensor_parallel_size <= 8 | ||
| self.hf_config = get_hf_config(self.model) | ||
| if is_plugin_mode(): | ||
| # plugin mode | ||
| assert ( | ||
| self.plugin_config is not None | ||
| ), "plugin_config is required in plugin mode" | ||
| self.hf_config = self.plugin_config.model_config.hf_config | ||
| else: | ||
| self.hf_config = get_hf_config(self.model) | ||
|
|
||
| self.generation_config = get_generation_config(self.model) | ||
| if self.generation_config is not None: | ||
| if ( | ||
| eos_ids := getattr(self.generation_config, "eos_token_id", None) | ||
| ) is not None: | ||
| self.stop_token_ids = ( | ||
| [eos_ids] if isinstance(eos_ids, int) else eos_ids | ||
| ) | ||
| if not hasattr(self.hf_config, "rope_parameters"): | ||
| # Compatible with both transformers < 5 | ||
| rope_params = getattr(self.hf_config, "rope_scaling", {}) | ||
| if rope_params is None: | ||
| rope_params = {} | ||
| rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None) | ||
| rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default") | ||
| rope_params = getattr(self.hf_config, "rope_scaling", {}) or {} | ||
| rope_params["rope_theta"] = self.hf_config.rope_theta | ||
| rope_params["rope_type"] = getattr(rope_params, "rope_type", "default") | ||
| self.hf_config.rope_parameters = rope_params | ||
|
Comment on lines
642
to
646
|
||
|
|
||
| self.generation_config = get_generation_config(self.model) | ||
| if self.generation_config is not None: | ||
| if ( | ||
| eos_ids := getattr(self.generation_config, "eos_token_id", None) | ||
| ) is not None: | ||
| self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids | ||
| self.quant_config = get_quant_config(self.hf_config) | ||
| hf_config_max_position_embeddings = getattr( | ||
| self.hf_config, "max_position_embeddings", 8192 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| from .base_attention import BaseAttention | ||
| from atom.plugin.prepare import is_plugin_mode, is_sglang | ||
| from atom.models.utils import maybe_prefix | ||
| from atom.utils import envs | ||
|
|
||
|
|
||
| class RadixAttention(BaseAttention): | ||
|
|
@@ -47,7 +48,6 @@ def __init__( | |
| prefix=prefix, | ||
| **kwargs, | ||
| ) | ||
| self.rotary_emb = rotary_emb | ||
|
|
||
| if is_sglang(): | ||
| from sglang.srt.layers.radix_attention import RadixAttention | ||
|
|
@@ -64,6 +64,8 @@ def __init__( | |
| raise NotImplementedError( | ||
| "RadixAttention is only supported for plugin mode for sglang for now" | ||
| ) | ||
| # if True, save cache will be done in rope | ||
| self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM | ||
|
|
||
| def forward_impl_plugin_mode( | ||
| self, | ||
|
|
@@ -82,10 +84,14 @@ def forward_impl_plugin_mode( | |
| # for sglang, forward_batch is required | ||
| forward_batch = kwargs.get("forward_batch", None) | ||
| assert forward_batch is not None, "forward_batch is required for sglang" | ||
| if self.rotary_emb is not None: | ||
| assert positions is not None, "positions is required for ROPE" | ||
| query, key = self.rotary_emb(positions, query, key) | ||
| return self.attn(q=query, k=key, v=value, forward_batch=forward_batch) | ||
| # forward_batch contains the filed attn_backend, which will find the backend registered in ATOM | ||
| return self.attn( | ||
|
Comment on lines
+87
to
+88
|
||
| query, | ||
| key, | ||
| value, | ||
| forward_batch=forward_batch, | ||
| save_kv_cache=not self.use_aiter_rope_fused_qknorm, | ||
| ) | ||
| else: | ||
| raise NotImplementedError( | ||
| "RadixAttention is only supported for plugin mode for sglang for now" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size | ||
|
|
||
| # from atom.model_ops.rotary_embedding import get_rope | ||
| from aiter.rotary_embedding import get_rope | ||
| from aiter.rotary_embedding import get_rope, AiterFusedSetKVBufferArg | ||
| from atom.config import Config, QuantizationConfig | ||
| from atom.model_ops.activation import SiluAndMul | ||
|
|
||
|
|
@@ -34,6 +34,7 @@ | |
| from atom.utils import envs | ||
| from torch import nn | ||
| from atom.model_loader.loader import load_model_in_plugin_mode | ||
| from atom.plugin.prepare import is_sglang | ||
|
|
||
| # import torch.distributed as dist | ||
| from transformers import PretrainedConfig | ||
|
|
@@ -42,6 +43,7 @@ | |
| ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION = ( | ||
| envs.ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION | ||
| ) | ||
| ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE = envs.ATOM_ROPE_FUSED_QKNORM | ||
|
|
||
|
|
||
| class Qwen3MoeMLP(nn.Module): | ||
|
|
@@ -228,6 +230,65 @@ def __init__( | |
|
|
||
| self.kv_cache_dtype = kv_cache_dtype | ||
| self.layer_num = layer_num | ||
| self.k_scale = torch.tensor([1.0], dtype=torch.float32) | ||
| self.v_scale = torch.tensor([1.0], dtype=torch.float32) | ||
|
|
||
|
Comment on lines
+233
to
+235
|
||
| def forward_sgl_plugin_mode( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi, @ZhiweiYan-96 @zejunchen-zejun This section should be refined: the execution details should be contained inside the radix attention backend rather than in the model's forward pass. |
||
| self, | ||
| positions: torch.Tensor, | ||
| qkv: torch.Tensor, | ||
| **model_kwargs: dict[str, Any] | None, | ||
| ): | ||
| if ENABLE_AITER_ROPE_FUSED_QKNORM_FOR_SGL_PLUGIN_MODE: | ||
| forward_batch = model_kwargs.get("forward_batch", None) | ||
| assert forward_batch is not None, "forward_batch is required for sglang" | ||
| k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( | ||
| self.layer_num | ||
| ) | ||
| block_size = 1024 # Default fallback | ||
| if hasattr(forward_batch, "attn_backend") and hasattr( | ||
| forward_batch.attn_backend, "page_size" | ||
| ): | ||
| block_size = forward_batch.attn_backend.page_size | ||
| elif hasattr(forward_batch.token_to_kv_pool, "allocator") and hasattr( | ||
| forward_batch.token_to_kv_pool.allocator, "page_size" | ||
| ): | ||
| block_size = forward_batch.token_to_kv_pool.allocator.page_size | ||
| elif hasattr(forward_batch.token_to_kv_pool, "page_size"): | ||
| block_size = forward_batch.token_to_kv_pool.page_size | ||
| x = 16 // k_buffer.element_size() | ||
| aiter_fused_set_kv_buffer_arg = AiterFusedSetKVBufferArg( | ||
| kv_cache=(k_buffer, v_buffer), | ||
| cache_loc=forward_batch.out_cache_loc, | ||
| k_scale=self.k_scale, | ||
| v_scale=self.v_scale, | ||
| return_kv=True, | ||
| use_shuffle_layout=True, | ||
| block_size=block_size, | ||
| x=x, | ||
| ) | ||
| q, k, v = self.rotary_emb( | ||
| qkv, | ||
| self.q_norm.weight, | ||
| self.k_norm.weight, | ||
| positions, | ||
| self.num_heads, | ||
| self.num_kv_heads, | ||
| self.q_norm.eps, | ||
| fused_set_kv_buffer_arg=aiter_fused_set_kv_buffer_arg, | ||
| ) | ||
| else: | ||
| q, k, v = torch.split( | ||
| qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1 | ||
| ) | ||
| # Add qk-norm | ||
| q = self.q_norm(q) | ||
| k = self.k_norm(k) | ||
|
|
||
| q, k = self.rotary_emb(positions, q, k) | ||
|
|
||
| attn_output = self.attn(q, k, v, positions=positions, **model_kwargs) | ||
| return attn_output | ||
|
|
||
| def forward( | ||
| self, | ||
|
|
@@ -245,13 +306,18 @@ def forward( | |
| query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv | ||
| ) | ||
| else: | ||
| # Add qk-norm | ||
| q = self.q_norm(q) | ||
| k = self.k_norm(k) | ||
| if is_sglang(): | ||
| attn_output = self.forward_sgl_plugin_mode( | ||
| positions, qkv, **model_kwargs | ||
| ) | ||
| else: | ||
| # Add qk-norm | ||
| q = self.q_norm(q) | ||
| k = self.k_norm(k) | ||
|
|
||
| attn_output = self.attn( | ||
| query=q, key=k, value=v, positions=positions, **model_kwargs | ||
| ) | ||
| attn_output = self.attn( | ||
| query=q, key=k, value=v, positions=positions, **model_kwargs | ||
| ) | ||
| output = self.o_proj(attn_output) | ||
| return output | ||
|
|
||
|
|
@@ -265,7 +331,7 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> No | |
| self.hidden_size = config.hidden_size | ||
| rope_params = config.rope_parameters | ||
| rope_theta = rope_params["rope_theta"] | ||
| rope_scaling = rope_params | ||
| rope_scaling = None if rope_params["rope_type"] == "default" else rope_params | ||
| kv_cache_dtype = atom_config.kv_cache_dtype | ||
| max_position_embeddings = getattr(config, "max_position_embeddings", 8192) | ||
| # DecoderLayers are created with `make_layers` which passes the prefix | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Should this follow the atom config's `__post_init__` logic here?