Skip to content
33 changes: 20 additions & 13 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,22 +895,29 @@ def __post_init__(self):
# assert os.path.isdir(self.model)

assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = get_hf_config(self.model)
if is_plugin_mode():
# plugin mode
assert (
self.plugin_config is not None
), "plugin_config is required in plugin mode"
self.hf_config = self.plugin_config.model_config.hf_config
else:
self.hf_config = get_hf_config(self.model)

self.generation_config = get_generation_config(self.model)
if self.generation_config is not None:
if (
eos_ids := getattr(self.generation_config, "eos_token_id", None)
) is not None:
self.stop_token_ids = (
[eos_ids] if isinstance(eos_ids, int) else eos_ids
)
if not hasattr(self.hf_config, "rope_parameters"):
# Compatible with both transformers < 5
rope_params = getattr(self.hf_config, "rope_scaling", {})
if rope_params is None:
rope_params = {}
rope_params["rope_theta"] = getattr(self.hf_config, "rope_theta", None)
rope_params["rope_type"] = getattr(self.hf_config, "rope_type", "default")
rope_params = getattr(self.hf_config, "rope_scaling", {}) or {}
rope_params["rope_theta"] = self.hf_config.rope_theta
rope_params["rope_type"] = getattr(rope_params, "rope_type", "default")
self.hf_config.rope_parameters = rope_params

self.generation_config = get_generation_config(self.model)
if self.generation_config is not None:
if (
eos_ids := getattr(self.generation_config, "eos_token_id", None)
) is not None:
self.stop_token_ids = [eos_ids] if isinstance(eos_ids, int) else eos_ids
self.quant_config = QuantizationConfig(
self.hf_config,
self.plugin_config.vllm_config if self.plugin_config is not None else None,
Expand Down
1 change: 1 addition & 0 deletions atom/model_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# it can be assigned to different attention ops.
# By default, PagedAttention is used.
# For sglang, RadixAttention will be assigned to Attention
# see register.py for details.
Attention = PagedAttention

__all__ = [
Expand Down
11 changes: 11 additions & 0 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,21 @@ def process_weights_after_loading(self):
if self.quant_type == QuantType.per_1x32:
self.weight_scale.data = fp4_utils.e8m0_shuffle(self.weight_scale.data)

_diag_forward_counter = 0

@mark_trace
def forward(
self, x: torch.Tensor, x_scale: Optional[torch.Tensor] = None, otype=dtypes.bf16
) -> torch.Tensor:
if LinearBase._diag_forward_counter < 5:
LinearBase._diag_forward_counter += 1
print(
f"[DIAG][LinearBase.forward] prefix={self.prefix} "
f"quant_type={self.quant_type} "
f"w_dtype={self.weight.dtype} "
f"w_scale_shape={tuple(self.weight_scale.shape) if self.weight_scale is not None else None} "
f"x shape={tuple(x.shape)}"
)
if self.quant_type.value == QuantType.No.value:
y = tgemm.mm(
x,
Expand Down
43 changes: 40 additions & 3 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,8 +1401,7 @@ def apply(
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
# Direct kernel call for non-EP/DP cases
return rocm_asm_moe_impl(
return torch.ops.aiter.rocm_aiter_fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
Expand Down Expand Up @@ -1611,7 +1610,20 @@ def process_weights_after_loading(self, layer: nn.Module) -> None:

def _process_block_quant(self, layer: nn.Module) -> None:
assert self.quant_config["is_dynamic"]
print(
f"[DIAG][Fp8MoE._process_block_quant] BEFORE normalize: "
f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} "
f"w13_scale dtype={layer.w13_weight_scale.dtype} shape={tuple(layer.w13_weight_scale.shape)} "
f"need_normalize={self.need_normalize_e4m3fn_to_e4m3fnuz}"
)
self._normalize_weights_and_scales(layer)
print(
f"[DIAG][Fp8MoE._process_block_quant] AFTER normalize: "
f"w13 dtype={layer.w13_weight.dtype} "
f"w13_scale min={layer.w13_weight_scale.data.min().item():.6f} "
f"max={layer.w13_weight_scale.data.max().item():.6f} "
f"mean={layer.w13_weight_scale.data.float().mean().item():.6f}"
)

if not self.need_normalize_e4m3fn_to_e4m3fnuz:
layer.w13_weight = nn.Parameter(layer.w13_weight.data, requires_grad=False)
Expand All @@ -1624,6 +1636,7 @@ def _process_block_quant(self, layer: nn.Module) -> None:
)

shuffle_weights(layer.w13_weight, layer.w2_weight)
print(f"[DIAG][Fp8MoE._process_block_quant] DONE shuffle")

def _process_channel_quant(self, layer: nn.Module) -> None:
"""PTPTC"""
Expand Down Expand Up @@ -1693,13 +1706,26 @@ def get_fused_moe_quant_config(
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
)
elif self.block_quant:
if self.quant_type == QuantType.per_1x128:
block_shape = [128, 128]
elif self.quant_type == QuantType.per_1x32:
block_shape = [1, 32]
else:
block_shape = None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=block_shape,
)
else:
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=None,
)

@mark_trace(prefix="fp8_moe", torch_compile=False)
Expand Down Expand Up @@ -1741,6 +1767,17 @@ def apply(
# per_Tensor doesn't support num_local_tokens, so fallback to
# rocm_aiter_fused_moe when using per-tensor or no modular kernel.
if self.quant_type == QuantType.per_Tensor or self.fused_experts is None:
if not hasattr(self, "_diag_apply_printed"):
self._diag_apply_printed = True
print(
f"[DIAG][Fp8MoE.apply] rocm_aiter path: "
f"quant_type={self.quant_type} "
f"w13 dtype={layer.w13_weight.dtype} shape={tuple(layer.w13_weight.shape)} "
f"w13_scale shape={tuple(layer.w13_weight_scale.shape)} "
f"w13_scale min={layer.w13_weight_scale.data.float().min().item():.6f} "
f"max={layer.w13_weight_scale.data.float().max().item():.6f} "
f"x dtype={x.dtype} shape={tuple(x.shape)}"
)
return torch.ops.aiter.rocm_aiter_fused_moe(
x,
layer.w13_weight,
Expand Down
30 changes: 25 additions & 5 deletions atom/model_ops/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .base_attention import BaseAttention
from atom.plugin.prepare import is_plugin_mode, is_sglang
from atom.models.utils import maybe_prefix
from atom.utils import envs


class RadixAttention(BaseAttention):
Expand Down Expand Up @@ -47,23 +48,37 @@ def __init__(
prefix=prefix,
**kwargs,
)
self.rotary_emb = rotary_emb

if is_sglang():
from sglang.srt.layers.radix_attention import RadixAttention

_v_head_dim = mla_modules.kv_lora_rank if (use_mla and mla_modules is not None) else head_dim

self.attn = RadixAttention(
num_heads=num_heads,
head_dim=head_dim,
scaling=scale,
num_kv_heads=num_kv_heads,
layer_id=layer_num,
v_head_dim=_v_head_dim,
prefix=maybe_prefix(prefix, "attn"),
)
if self.attn.k_scale is None:
self.attn.k_scale = torch.nn.Parameter(
torch.tensor([1.0], dtype=torch.float32, device="cuda"),
requires_grad=False,
)
if self.attn.v_scale is None:
self.attn.v_scale = torch.nn.Parameter(
torch.tensor([1.0], dtype=torch.float32, device="cuda"),
requires_grad=False,
)
else:
raise NotImplementedError(
"RadixAttention is only supported for plugin mode for sglang for now"
)
# if True, save cache will be done in rope
self.use_aiter_rope_fused_qknorm = envs.ATOM_ROPE_FUSED_QKNORM

def forward_impl_plugin_mode(
self,
Expand All @@ -81,11 +96,16 @@ def forward_impl_plugin_mode(
if is_sglang():
# for sglang, forward_batch is required
forward_batch = kwargs.get("forward_batch", None)
save_kv_cache = kwargs.get("save_kv_cache", not self.use_aiter_rope_fused_qknorm)
assert forward_batch is not None, "forward_batch is required for sglang"
if self.rotary_emb is not None:
assert positions is not None, "positions is required for ROPE"
query, key = self.rotary_emb(positions, query, key)
return self.attn(q=query, k=key, v=value, forward_batch=forward_batch)
# forward_batch contains the field attn_backend, which will find the backend registered in ATOM
return self.attn(
query,
key,
value,
forward_batch=forward_batch,
save_kv_cache=save_kv_cache,
)
else:
raise NotImplementedError(
"RadixAttention is only supported for plugin mode for sglang for now"
Expand Down
Loading