diff --git a/docs/best_practices/MiniMax-M1.md b/docs/best_practices/MiniMax-M1.md
new file mode 100644
index 00000000000..891efbc5a98
--- /dev/null
+++ b/docs/best_practices/MiniMax-M1.md
@@ -0,0 +1,48 @@
+[简体中文](../zh/best_practices/MiniMax-M1.md)
+
+# MiniMax-M1 Model
+
+## I. Environment Preparation
+
+### 1.1 Support Requirements
+
+MiniMax-M1 support in FastDeploy uses a hybrid decoder stack:
+
+- Standard full-attention layers run through the existing FastDeploy attention backend.
+- Linear-attention layers use the Lightning Attention Triton kernels in `fastdeploy/model_executor/ops/triton_ops/lightning_attn.py`.
+- Current first-pass support targets BF16 inference.
+
+### 1.2 Installing FastDeploy
+
+Installation process reference document [FastDeploy GPU Installation](../get_started/installation/nvidia_gpu.md)
+
+## II. How to Use
+
+### 2.1 Basics: Starting the Service
+
+```shell
+MODEL_PATH=/models/MiniMax-Text-01
+
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model "$MODEL_PATH" \
+ --port 8180 \
+ --metrics-port 8181 \
+ --engine-worker-queue-port 8182 \
+ --max-model-len 32768 \
+ --max-num-seqs 32
+```
+
+### 2.2 Model Notes
+
+- HuggingFace architecture: `MiniMaxText01ForCausalLM`
+- Hybrid layer layout: 70 linear-attention layers and 10 full-attention layers
+- MoE routing: 32 experts, top-2 experts per token
+
+## III. Known Limitations
+
+- This initial integration is focused on model structure and backend wiring.
+- Linear attention KV history uses instance variables, which needs migration to slot-based cache for proper multi-request isolation (TODO already noted in code).
+- Low-bit quantization support still requires follow-up validation against MiniMax-M1 weights.
+- Production validation should include GPU runtime checks for Lightning Attention decode/prefill paths.
+
+
diff --git a/docs/supported_models.md b/docs/supported_models.md
index b0684affc11..1ece03324e5 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -38,6 +38,7 @@ These models accept text input.
|⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;
Qwen/qwen2.5-32B;
Qwen/qwen2.5-14B;
Qwen/qwen2.5-7B;
Qwen/qwen2.5-3B;
Qwen/qwen2.5-1.5B;
Qwen/qwen2.5-0.5B, etc.|
|⭐QWEN2|BF16/WINT8/FP8|Qwen/Qwen/qwen2-72B;
Qwen/Qwen/qwen2-7B;
Qwen/qwen2-1.5B;
Qwen/qwen2-0.5B;
Qwen/QwQ-32, etc.|
|⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.|
+|MINIMAX-M1|BF16|[MiniMaxAI/MiniMax-Text-01](./best_practices/MiniMax-M1.md);
MiniMaxAI/MiniMax-Text-01-Large, etc.|
|⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.|
|⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
[最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.|
@@ -52,3 +53,5 @@ These models accept multi-modal inputs (e.g., images and text).
| QWEN-VL |BF16/WINT4/FP8| Qwen/Qwen2.5-VL-72B-Instruct;
Qwen/Qwen2.5-VL-32B-Instruct;
Qwen/Qwen2.5-VL-7B-Instruct;
Qwen/Qwen2.5-VL-3B-Instruct|
More models are being supported. You can submit requests for new model support via [Github Issues](https://github.com/PaddlePaddle/FastDeploy/issues).
+
+
diff --git a/docs/zh/best_practices/MiniMax-M1.md b/docs/zh/best_practices/MiniMax-M1.md
new file mode 100644
index 00000000000..73c2a263143
--- /dev/null
+++ b/docs/zh/best_practices/MiniMax-M1.md
@@ -0,0 +1,48 @@
+[English](../../best_practices/MiniMax-M1.md)
+
+# MiniMax-M1 模型
+
+## 一、环境准备
+
+### 1.1 支持说明
+
+FastDeploy 中的 MiniMax-M1 采用混合解码器结构:
+
+- 全注意力层复用 FastDeploy 现有 Attention 后端。
+- 线性注意力层使用 `fastdeploy/model_executor/ops/triton_ops/lightning_attn.py` 中的 Lightning Attention Triton kernel。
+- 当前首版支持以 BF16 推理为主。
+
+### 1.2 安装 FastDeploy
+
+安装流程可参考 [FastDeploy GPU 安装文档](../get_started/installation/nvidia_gpu.md)
+
+## 二、使用方式
+
+### 2.1 基础启动命令
+
+```shell
+MODEL_PATH=/models/MiniMax-Text-01
+
+python -m fastdeploy.entrypoints.openai.api_server \
+ --model "$MODEL_PATH" \
+ --port 8180 \
+ --metrics-port 8181 \
+ --engine-worker-queue-port 8182 \
+ --max-model-len 32768 \
+ --max-num-seqs 32
+```
+
+### 2.2 模型特性
+
+- HuggingFace 架构名:`MiniMaxText01ForCausalLM`
+- 层类型分布:70 层线性注意力 + 10 层全注意力
+- MoE 路由:32 个专家,每个 token 选择 top-2 专家
+
+## 三、当前限制
+
+- 当前版本优先完成模型组网与后端接线。
+- 线性注意力的 KV history 当前使用实例变量存储,多请求并发场景下需迁移至 slot-based cache(已有 TODO 标注)。
+- 各类低比特量化推理能力还需要结合真实权重进一步验证。
+- Lightning Attention 的 prefill/decode 路径仍需在 GPU 环境完成端到端验证。
+
+
diff --git a/docs/zh/supported_models.md b/docs/zh/supported_models.md
index 1424d2320fb..915342cf7d8 100644
--- a/docs/zh/supported_models.md
+++ b/docs/zh/supported_models.md
@@ -36,6 +36,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;
Qwen/qwen2.5-32B;
Qwen/qwen2.5-14B;
Qwen/qwen2.5-7B;
Qwen/qwen2.5-3B;
Qwen/qwen2.5-1.5B;
Qwen/qwen2.5-0.5B, etc.|
|⭐QWEN2|BF16/WINT8/FP8|Qwen/Qwen/qwen2-72B;
Qwen/Qwen/qwen2-7B;
Qwen/qwen2-1.5B;
Qwen/qwen2-0.5B;
Qwen/QwQ-32, etc.|
|⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;
unsloth/DeepSeek-V3-0324-BF16;
unsloth/DeepSeek-R1-BF16, etc.|
+|MINIMAX-M1|BF16|[MiniMaxAI/MiniMax-Text-01](./best_practices/MiniMax-M1.md);
MiniMaxAI/MiniMax-Text-01-Large, etc.|
|⭐GPT-OSS|BF16/WINT8|unsloth/gpt-oss-20b-BF16, etc.|
|⭐GLM-4.5/4.6|BF16/wfp8afp8|zai-org/GLM-4.5-Air;
zai-org/GLM-4.6
[最佳实践](./best_practices/GLM-4-MoE-Text.md) etc.|
@@ -50,3 +51,5 @@ python -m fastdeploy.entrypoints.openai.api_server \
| QWEN-VL |BF16/WINT4/FP8| Qwen/Qwen2.5-VL-72B-Instruct;
Qwen/Qwen2.5-VL-32B-Instruct;
Qwen/Qwen2.5-VL-7B-Instruct;
Qwen/Qwen2.5-VL-3B-Instruct|
更多模型同步支持中,你可以通过[Github Issues](https://github.com/PaddlePaddle/FastDeploy/issues)向我们提交新模型的支持需求。
+
+
diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py
index 485ffa1c4ad..ca402ab6dd5 100644
--- a/fastdeploy/model_executor/layers/rotary_embedding.py
+++ b/fastdeploy/model_executor/layers/rotary_embedding.py
@@ -1,4 +1,4 @@
-"""
+"""Module for Hackathon 10th Spring No.47.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -341,7 +341,7 @@ def get_rope_impl(
"""
architecture = model_config.architectures[0]
- if architecture.startswith("Qwen"):
+ if architecture.startswith("Qwen") or architecture.startswith("MiniMaxM1"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
elif architecture.startswith("Glm"):
diff --git a/fastdeploy/model_executor/models/minimax_m1.py b/fastdeploy/model_executor/models/minimax_m1.py
new file mode 100644
index 00000000000..98788306556
--- /dev/null
+++ b/fastdeploy/model_executor/models/minimax_m1.py
@@ -0,0 +1,895 @@
+"""Module for Hackathon 10th Spring No.47.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+MiniMax-M1 Model for FastDeploy
+Hybrid architecture: 70 linear attention layers + 10 full attention layers
+MoE: 32 experts, top-2 routing per token
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from typing import Any, Dict, Union
+
+import numpy as np
+import paddle
+from paddle import nn
+from paddleformers.transformers import PretrainedModel
+from paddleformers.utils.log import logger
+
+from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+ support_graph_optimization,
+)
+from fastdeploy.model_executor.layers.activation import SiluAndMul
+from fastdeploy.model_executor.layers.attention.attention import Attention
+from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
+from fastdeploy.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
+from fastdeploy.model_executor.layers.moe.moe import FusedMoE
+from fastdeploy.model_executor.layers.normalization import RMSNorm
+from fastdeploy.model_executor.models.model_base import (
+ ModelCategory,
+ ModelForCasualLM,
+ ModelRegistry,
+)
+from fastdeploy.model_executor.ops.triton_ops.lightning_attn import lightning_attention
+
+
+class MiniMaxM1MLP(nn.Layer):
+ """MiniMax-M1 MLP Layer (Dense FFN)"""
+
+ def __init__(
+ self,
+ fd_config: FDConfig,
+ intermediate_size: int,
+ prefix: str = "",
+ reduce_results: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.gate_up_proj = MergedColumnParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.gate_up_proj",
+ input_size=fd_config.model_config.hidden_size,
+ output_size=intermediate_size * 2,
+ with_bias=False,
+ activation=fd_config.model_config.hidden_act,
+ )
+
+ self.down_proj = RowParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.down_proj",
+ input_size=intermediate_size,
+ output_size=fd_config.model_config.hidden_size,
+ with_bias=False,
+ reduce_results=reduce_results,
+ )
+
+ self.act_fn = SiluAndMul(
+ fd_config=fd_config,
+ bias=getattr(self.gate_up_proj, "bias", None),
+ act_method=fd_config.model_config.hidden_act,
+ )
+
+ def load_state_dict(self, state_dict):
+ self.gate_up_proj.load_state_dict(state_dict)
+ self.down_proj.load_state_dict(state_dict)
+
+ def forward(self, x, forward_meta=None):
+ gate_up_out = self.gate_up_proj(x)
+ act_out = self.act_fn(gate_up_out)
+ down_out = self.down_proj(act_out)
+ return down_out
+
+
+class MiniMaxM1MoE(nn.Layer):
+ """MiniMax-M1 MoE Layer with low-bit quantization support."""
+
+ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
+ super().__init__()
+
+ self.tp_size = fd_config.parallel_config.tensor_parallel_size
+ self.norm_topk_prob = getattr(fd_config.model_config, "norm_topk_prob", False)
+
+ # Build quantization-aware weight key map (mirrors Ernie4_5_MoE pattern)
+ moe_quant_type = ""
+ quant_config = getattr(fd_config, "quant_config", None)
+ if quant_config and hasattr(quant_config, "moe_quant_type"):
+ moe_quant_type = quant_config.moe_quant_type or ""
+
+ is_quantized = getattr(fd_config.model_config, "is_quantized", False)
+ moe_dynamic_quant = getattr(quant_config, "moe_dynamic_quant", False) if quant_config else False
+
+ if moe_quant_type in ("w4a8", "tensor_wise_fp8", "block_wise_fp8") or (
+ moe_quant_type == "w4afp8" and is_quantized and not moe_dynamic_quant
+ ):
+ weight_key_map = {
+ "gate_weight_key": f"{prefix}.gate.weight",
+ "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
+ "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
+ "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
+ "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
+ "up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
+ "down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
+ }
+ elif moe_quant_type == "w4afp8" and is_quantized:
+ # Dynamic w4afp8: no activation scales
+ weight_key_map = {
+ "gate_weight_key": f"{prefix}.gate.weight",
+ "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
+ "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
+ "up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
+ "down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
+ }
+ else:
+ # Default: unquantized
+ weight_key_map = {
+ "gate_weight_key": f"{prefix}.gate.weight",
+ "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
+ "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
+ }
+
+ self.gate = ReplicatedLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.gate",
+ input_size=fd_config.model_config.hidden_size,
+ output_size=fd_config.model_config.num_local_experts,
+ with_bias=False,
+ skip_quant=True,
+ weight_dtype="float32",
+ )
+
+ self.experts = FusedMoE(
+ fd_config=fd_config,
+ reduce_results=True,
+ renormalize=self.norm_topk_prob,
+ moe_intermediate_size=fd_config.model_config.intermediate_size,
+ num_experts=fd_config.model_config.num_local_experts,
+ top_k=fd_config.model_config.num_experts_per_tok,
+ layer_idx=layer_id,
+ weight_key_map=weight_key_map,
+ )
+
+ def load_state_dict(self, state_dict):
+ self.gate.load_state_dict(state_dict)
+ self.experts.load_state_dict(state_dict)
+
+ def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
+ """Forward pass with router gating."""
+ # FusedMoE(reduce_results=True) already handles all-reduce internally
+ return self.experts(hidden_states, self.gate, forward_meta)
+
+
+class MiniMaxM1Attention(nn.Layer):
+ """MiniMax-M1 Full Attention (standard GQA)"""
+
+ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None:
+ super().__init__()
+
+ self.hidden_size = fd_config.model_config.hidden_size
+ self.num_attention_heads = fd_config.model_config.num_attention_heads
+ self.head_dim = fd_config.model_config.head_dim
+ self.num_key_value_heads = fd_config.model_config.num_key_value_heads
+
+ self.qkv_proj = QKVParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.qkv_proj",
+ with_bias=False,
+ )
+
+ self.o_proj = RowParallelLinear(
+ fd_config,
+ prefix=f"{prefix}.o_proj",
+ input_size=self.num_attention_heads * self.head_dim,
+ output_size=self.hidden_size,
+ with_bias=False,
+ layer_id=layer_id,
+ )
+
+ self.attn = Attention(
+ fd_config=fd_config,
+ layer_id=layer_id,
+ prefix=prefix,
+ use_neox_rotary_style=True,
+ )
+
+ def load_state_dict(self, state_dict):
+ self.qkv_proj.load_state_dict(state_dict)
+ self.o_proj.load_state_dict(state_dict)
+ self.attn.load_state_dict(state_dict)
+
+ def forward(
+ self,
+ forward_meta: ForwardMeta,
+ hidden_states: paddle.Tensor,
+ ):
+ """Full attention forward."""
+ qkv_out = self.qkv_proj(hidden_states)
+ attn_output = self.attn(qkv=qkv_out, forward_meta=forward_meta)
+ output = self.o_proj(attn_output)
+ return output
+
+
+class MiniMaxM1LinearAttention(nn.Layer):
+ """MiniMax-M1 Linear Attention (Lightning Attention)"""
+
+ def __init__(
+ self,
+ fd_config: FDConfig,
+ layer_id: int,
+ linear_layer_id: int, # Reserved for per-linear-layer indexing in future extensions
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ self.hidden_size = fd_config.model_config.hidden_size
+ self.head_dim = fd_config.model_config.head_dim
+ tp_size = fd_config.parallel_config.tensor_parallel_size
+ self.num_attention_heads = fd_config.model_config.num_attention_heads // tp_size
+ # Full (unsharded) inner dim for parallel linear layer declarations;
+ # ColumnParallelLinear divides output and RowParallelLinear divides input
+ # by tp_size internally.
+ hidden_inner = fd_config.model_config.num_attention_heads * self.head_dim
+
+ # QKV projection
+ self.qkv_proj = ColumnParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.qkv_proj",
+ input_size=self.hidden_size,
+ output_size=hidden_inner * 3,
+ with_bias=False,
+ )
+
+ # Output gate (sigmoid gating on attention output)
+ self.output_gate = ColumnParallelLinear(
+ fd_config=fd_config,
+ prefix=f"{prefix}.output_gate",
+ input_size=self.hidden_size,
+ output_size=hidden_inner,
+ with_bias=False,
+ )
+
+ # Output projection (HF name: out_proj)
+ self.out_proj = RowParallelLinear(
+ fd_config,
+ prefix=f"{prefix}.out_proj",
+ input_size=hidden_inner,
+ output_size=self.hidden_size,
+ with_bias=False,
+ layer_id=layer_id,
+ )
+
+ # RMSNorm on attention output before gating (per-TP-rank dimension)
+ self.norm = RMSNorm(
+ fd_config,
+ hidden_size=self.num_attention_heads * self.head_dim,
+ eps=1e-5,
+ prefix=f"{prefix}.norm",
+ )
+
+ # Build slope tensor for exponential decay; select this TP rank's subset
+ slope_tensor = self._build_slope_tensor(fd_config.model_config.num_attention_heads)
+ tp_rank = fd_config.parallel_config.tensor_parallel_rank
+ slope_tensor = slope_tensor[tp_rank * self.num_attention_heads : (tp_rank + 1) * self.num_attention_heads]
+ if fd_config.model_config.num_hidden_layers <= 1:
+ slope_tensor = slope_tensor * (1 + 1e-5)
+ else:
+ slope_tensor = slope_tensor * (1 - layer_id / (fd_config.model_config.num_hidden_layers - 1) + 1e-5)
+ # Register as buffer (not trainable)
+ self.register_buffer("slope_rate", slope_tensor)
+
+ # KV cache shape: [heads, head_dim, head_dim]
+ self.kv_cache_shape = (self.num_attention_heads, self.head_dim, self.head_dim)
+
+ def load_state_dict(self, state_dict):
+ self.qkv_proj.load_state_dict(state_dict)
+ self.output_gate.load_state_dict(state_dict)
+ self.out_proj.load_state_dict(state_dict)
+ self.norm.load_state_dict(state_dict)
+
+ @staticmethod
+ def _build_slope_tensor(n_heads: int):
+ """Build ALiBi-style slope tensor for exponential decay."""
+
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** (-(math.log2(n) - 3))))
+ return [start * (start**i) for i in range(n)]
+
+ if math.log2(n_heads).is_integer():
+ slopes = get_slopes_power_of_2(n_heads)
+ else:
+ closest_power = 2 ** math.floor(math.log2(n_heads))
+ slopes = get_slopes_power_of_2(closest_power)
+ slopes += get_slopes_power_of_2(2 * closest_power)[0::2][: n_heads - closest_power]
+
+ return paddle.to_tensor(slopes, dtype=paddle.float32).reshape([n_heads, 1, 1])
+
+ def forward(
+ self,
+ forward_meta: ForwardMeta,
+ hidden_states: paddle.Tensor,
+ ):
+ """Linear attention forward with output gating."""
+ # Project QKV
+ qkv = self.qkv_proj(hidden_states)
+ hidden_inner = self.num_attention_heads * self.head_dim
+ q, k, v = qkv.split([hidden_inner, hidden_inner, hidden_inner], axis=-1)
+
+ # Apply SiLU activation (matches HF MiniMax convention)
+ q = paddle.nn.functional.silu(q.astype("float32"))
+ k = paddle.nn.functional.silu(k.astype("float32"))
+ v = paddle.nn.functional.silu(v.astype("float32"))
+
+ # Reshape for lightning attention
+ batch_size = q.shape[0]
+ q = q.reshape([batch_size, -1, self.num_attention_heads, self.head_dim])
+ k = k.reshape([batch_size, -1, self.num_attention_heads, self.head_dim])
+ v = v.reshape([batch_size, -1, self.num_attention_heads, self.head_dim])
+
+ # Transpose to [batch, heads, seq_len, dim]
+ q = q.transpose([0, 2, 1, 3])
+ k = k.transpose([0, 2, 1, 3])
+ v = v.transpose([0, 2, 1, 3])
+
+ # Retrieve or initialize KV history for recurrent state persistence.
+ # TODO: Migrate to ForwardMeta.caches / slot-based cache management for
+ # proper multi-request isolation in production serving scenarios.
+ if not hasattr(self, "_kv_history") or self._kv_history is None or self._kv_history.shape[0] != batch_size:
+ self._kv_history = paddle.zeros(
+ [batch_size, self.num_attention_heads, self.head_dim, self.head_dim],
+ dtype=q.dtype,
+ )
+
+ # Apply lightning attention (returns 4D kv_history, not 5D concat)
+ attn_output, new_kv_history = lightning_attention(
+ q, k, v, self.slope_rate.squeeze(-1), block_size=256, kv_history=self._kv_history
+ )
+ # Update persisted KV state for next token generation
+ self._kv_history = new_kv_history
+
+ # Reshape back to [total_tokens, hidden_inner]
+ # FD runtime passes flat [total_tokens, hidden_size] tensors (no batch/seq split).
+ # Each "batch" entry is actually one token, so seq=1 → squeeze to 2D.
+ attn_output = attn_output.transpose([0, 2, 1, 3])
+ attn_output = attn_output.reshape([batch_size, self.num_attention_heads * self.head_dim])
+
+ # Norm → gate → output projection (matches vLLM/HF forward)
+ attn_output = self.norm(attn_output)[0]
+ gate = self.output_gate(hidden_states)
+ attn_output = paddle.nn.functional.sigmoid(gate) * attn_output.astype(hidden_states.dtype)
+ output = self.out_proj(attn_output)
+ return output
+
+
+class MiniMaxM1DecoderLayer(nn.Layer):
+ """MiniMax-M1 Decoder Layer with Hybrid Attention Dispatch"""
+
+ @staticmethod
+ def _build_attn_type_list(num_layers: int):
+ """Build attention type list: 70 linear + 10 full (at indices 7,15,23,...)."""
+ attn_type_list = [0] * num_layers # Default: all linear
+ # Full attention every 8 layers starting at layer 7
+ full_attn_indices = [7, 15, 23, 31, 39, 47, 55, 63, 71, 79]
+ for idx in full_attn_indices:
+ if idx < num_layers:
+ attn_type_list[idx] = 1
+ return attn_type_list
+
+ def __init__(
+ self,
+ fd_config: FDConfig,
+ layer_id: int,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ self.hidden_size = fd_config.model_config.hidden_size
+ self.layer_id = layer_id
+ self.postnorm = getattr(fd_config.model_config, "postnorm", False)
+
+ # Determine attention type for this layer
+ # attn_type_list: 70 linear (0) + 10 full (1) at specific indices
+ attn_type_list = getattr(
+ fd_config.model_config,
+ "attn_type_list",
+ self._build_attn_type_list(fd_config.model_config.num_hidden_layers),
+ )
+ self.attention_type = attn_type_list[layer_id] if layer_id < len(attn_type_list) else 1
+
+ # Attention layer (dispatch based on type)
+ if self.attention_type == 0: # Linear attention
+ linear_layer_id = sum(1 for i in range(layer_id) if attn_type_list[i] == 0)
+ self.self_attn = MiniMaxM1LinearAttention(
+ fd_config,
+ layer_id=layer_id,
+ linear_layer_id=linear_layer_id,
+ prefix=f"{prefix}.self_attn",
+ )
+ else: # Full attention
+ self.self_attn = MiniMaxM1Attention(
+ fd_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.self_attn",
+ )
+
+ # Input layernorm (pre-norm)
+ self.input_layernorm = RMSNorm(
+ fd_config,
+ hidden_size=fd_config.model_config.hidden_size,
+ eps=fd_config.model_config.rms_norm_eps,
+ prefix=f"{prefix}.input_layernorm",
+ )
+
+ # Post-attention layernorm
+ self.post_attention_layernorm = RMSNorm(
+ fd_config,
+ hidden_size=fd_config.model_config.hidden_size,
+ eps=fd_config.model_config.rms_norm_eps,
+ prefix=f"{prefix}.post_attention_layernorm",
+ )
+
+ # DeepNorm alpha/beta scaling — separate coefficients for linear vs full attention
+ if self.attention_type == 0: # Linear attention
+ self.layernorm_attention_alpha = getattr(
+ fd_config.model_config, "layernorm_linear_attention_alpha", 3.5565588200778455
+ )
+ self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_linear_attention_beta", 1.0)
+ else: # Full attention
+ self.layernorm_attention_alpha = getattr(
+ fd_config.model_config, "layernorm_full_attention_alpha", 3.5565588200778455
+ )
+ self.layernorm_attention_beta = getattr(fd_config.model_config, "layernorm_full_attention_beta", 1.0)
+ self.layernorm_mlp_alpha = getattr(fd_config.model_config, "layernorm_mlp_alpha", 3.5565588200778455)
+ self.layernorm_mlp_beta = getattr(fd_config.model_config, "layernorm_mlp_beta", 1.0)
+
+ # FFN (MLP or MoE)
+ if fd_config.model_config.num_local_experts > 1:
+ self.block_sparse_moe = MiniMaxM1MoE(
+ fd_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.block_sparse_moe",
+ )
+ else:
+ self.block_sparse_moe = MiniMaxM1MLP(
+ fd_config,
+ intermediate_size=fd_config.model_config.intermediate_size,
+ prefix=f"{prefix}.mlp",
+ reduce_results=True,
+ )
+
+ def load_state_dict(self, state_dict):
+ self.self_attn.load_state_dict(state_dict)
+ self.block_sparse_moe.load_state_dict(state_dict)
+ self.input_layernorm.load_state_dict(state_dict)
+ self.post_attention_layernorm.load_state_dict(state_dict)
+
+ def forward(
+ self,
+ forward_meta: ForwardMeta,
+ hidden_states: paddle.Tensor,
+ residual: paddle.Tensor = None,
+ ):
+ """Decoder layer forward with DeepNorm.
+
+ When postnorm=True (MiniMax-M1 default), the residual stream carries the
+ *normed* activations rather than the pre-norm sum. This follows the
+ vLLM reference: ``residual = layernorm_output if postnorm else layernorm_input``.
+ """
+ # Input layernorm (fused: x + residual → norm)
+ hidden_states, residual = self.input_layernorm(
+ hidden_states,
+ residual_input=residual,
+ forward_meta=forward_meta,
+ )
+ # hidden_states = norm(input + prev_residual)
+ # residual = input + prev_residual (pre-norm)
+ if self.postnorm:
+ residual = hidden_states # postnorm: residual = normed output
+
+ # Attention (dispatch based on type)
+ attn_output = self.self_attn(forward_meta=forward_meta, hidden_states=hidden_states)
+
+ # DeepNorm alpha/beta scaling
+ residual = residual * self.layernorm_attention_alpha
+ attn_output = attn_output * self.layernorm_attention_beta
+
+ # Post-attention layernorm
+ if self.postnorm:
+ layernorm_input = residual + attn_output
+ hidden_states, residual = self.post_attention_layernorm(
+ layernorm_input,
+ forward_meta=forward_meta,
+ )
+ residual = hidden_states # postnorm: residual = normed output
+ else:
+ hidden_states, residual = self.post_attention_layernorm(
+ attn_output,
+ residual_input=residual,
+ forward_meta=forward_meta,
+ )
+
+ # FFN
+ mlp_output = self.block_sparse_moe(hidden_states, forward_meta)
+
+ # DeepNorm MLP alpha/beta
+ residual = residual * self.layernorm_mlp_alpha
+ mlp_output = mlp_output * self.layernorm_mlp_beta
+
+ hidden_states = residual + mlp_output
+
+ # Return None for residual — DeepNorm scaling already folds the
+ # residual stream into hidden_states (R·α + MLP·β). Passing
+ # ``residual`` separately would cause the next layer's fused
+ # add-norm to double-count it. Matches vLLM reference:
+ # ``return hidden_states, None``
+ return hidden_states, None
+
+
+@support_graph_optimization
+class MiniMaxM1Model(nn.Layer):
+ """MiniMax-M1 Transformer Model"""
+
+ def __init__(self, fd_config: FDConfig = None):
+ super().__init__()
+
+ self.num_layers = fd_config.model_config.num_hidden_layers
+ self.hidden_size = fd_config.model_config.hidden_size
+ fd_config.model_config.pretrained_config.prefix_name = "model"
+
+ # Embedding
+ self.embed_tokens = VocabParallelEmbedding(
+ fd_config,
+ num_embeddings=fd_config.model_config.vocab_size,
+ embedding_dim=fd_config.model_config.hidden_size,
+ prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens",
+ )
+
+ # Decoder layers
+ self.layers = nn.LayerList(
+ [
+ MiniMaxM1DecoderLayer(
+ fd_config,
+ layer_id=i,
+ prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
+ )
+ for i in range(self.num_layers)
+ ]
+ )
+
+ # Final layernorm
+ self.norm = RMSNorm(
+ fd_config,
+ hidden_size=fd_config.model_config.hidden_size,
+ eps=fd_config.model_config.rms_norm_eps,
+ prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
+ )
+
+ def load_state_dict(self, state_dict):
+ """Load model parameters."""
+ self.embed_tokens.load_state_dict(state_dict)
+ self.norm.load_state_dict(state_dict)
+ for i in range(self.num_layers):
+ logger.info(f"Start load layer {i}")
+ self.layers[i].load_state_dict(state_dict)
+
+ def forward(
+ self,
+ ids_remove_padding: paddle.Tensor,
+ forward_meta: ForwardMeta,
+ ):
+ """Model forward pass."""
+ hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
+
+ residual = None
+
+ # Pass through decoder layers
+ for i in range(self.num_layers):
+ hidden_states, residual = self.layers[i](
+ forward_meta=forward_meta,
+ hidden_states=hidden_states,
+ residual=residual,
+ )
+
+ # Final layernorm
+ hidden_states = self.norm(hidden_states, residual)[0]
+
+ return hidden_states
+
+
+@ModelRegistry.register_model_class(
+ architecture="MiniMaxM1ForCausalLM",
+ module_name="minimax_m1",
+ category=ModelCategory.TEXT_GENERATION,
+ primary_use=ModelCategory.TEXT_GENERATION,
+)
+@ModelRegistry.register_model_class(
+ architecture="MiniMaxText01ForCausalLM",
+ module_name="minimax_m1",
+ category=ModelCategory.TEXT_GENERATION,
+ primary_use=ModelCategory.TEXT_GENERATION,
+)
+class MiniMaxM1ForCausalLM(ModelForCasualLM):
+ """MiniMax-M1 Causal LM Model"""
+
+ # Mapping HF checkpoint names → FD merged parameter names.
+ # For full attention layers: separate q/k/v → merged qkv_proj
+ # For MoE: gate_proj/up_proj → merged gate_up_proj (dense MLP fallback)
+ _STACKED_PARAMS_MAPPING = [
+ # (fd_param_name, hf_weight_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ # VocabParallelEmbedding wraps weight inside .embeddings sublayer
+ ("embed_tokens.embeddings", "embed_tokens", None),
+ # ParallelLMHead wraps weight inside .linear sublayer
+ ("lm_head.linear", "lm_head", None),
+ ]
+
+ def __init__(self, fd_config: FDConfig):
+ super().__init__(fd_config)
+
+ self.model = MiniMaxM1Model(fd_config)
+ self.lm_head = ParallelLMHead(
+ fd_config,
+ embedding_dim=fd_config.model_config.hidden_size,
+ num_embeddings=fd_config.model_config.vocab_size,
+ prefix="lm_head",
+ )
+
+ @classmethod
+ def name(cls):
+ """Model name."""
+ return "MiniMaxM1ForCausalLM"
+
+ @paddle.no_grad()
+ def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
+ """Load model parameters from a given state dictionary.
+
+ Pre-processes HF weight keys to match FD naming conventions, then
+ delegates to sub-layer ``load_state_dict`` calls.
+ """
+ renamed: Dict[str, Union[np.ndarray, paddle.Tensor]] = {}
+ # Collect full-attention q/k/v weights for merging into qkv_proj
+ qkv_buffers: Dict[str, Dict[str, Union[np.ndarray, paddle.Tensor]]] = {}
+
+ for name, weight in list(state_dict.items()):
+ # Expert weights: w1→gate_proj, w3→up_proj, w2→down_proj
+ # Handles both .weight (FP) and .quant_weight / .weight_scale / .activation_scale (quantized)
+ if "block_sparse_moe.experts." in name:
+ name = re.sub(r"\.w1\.", ".gate_proj.", name)
+ name = re.sub(r"\.w3\.", ".up_proj.", name)
+ name = re.sub(r"\.w2\.", ".down_proj.", name)
+ renamed[name] = weight
+ # Full attention: merge separate q/k/v into qkv_proj
+ elif ".self_attn.q_proj." in name or ".self_attn.k_proj." in name or ".self_attn.v_proj." in name:
+ # Extract layer prefix: e.g. "model.layers.7.self_attn"
+ prefix_match = re.match(
+ r"(.*\.self_attn)\.(q|k|v)_proj\.(weight|quant_weight|weight_scale|activation_scale)$", name
+ )
+ if prefix_match:
+ attn_prefix = prefix_match.group(1)
+ proj_type = prefix_match.group(2)
+ suffix = prefix_match.group(3)
+ buf_key = f"{attn_prefix}|{suffix}"
+ if buf_key not in qkv_buffers:
+ qkv_buffers[buf_key] = {}
+ qkv_buffers[buf_key][proj_type] = weight
+ else:
+ renamed[name] = weight
+ else:
+ renamed[name] = weight
+
+ # Merge q/k/v into qkv_proj for full attention layers
+ for buf_key, projections in qkv_buffers.items():
+ if "q" in projections and "k" in projections and "v" in projections:
+ attn_prefix, suffix = buf_key.split("|", 1)
+ q_w = projections["q"]
+ k_w = projections["k"]
+ v_w = projections["v"]
+ if isinstance(q_w, np.ndarray):
+ merged = np.concatenate([q_w, k_w, v_w], axis=0)
+ else:
+ merged = paddle.concat([q_w, k_w, v_w], axis=0)
+ renamed[f"{attn_prefix}.qkv_proj.{suffix}"] = merged
+
+ self.model.load_state_dict(renamed)
+ self.lm_head.load_state_dict(renamed)
+
+ @paddle.no_grad()
+ def load_weights(self, weights_iterator) -> None:
+ """Load model parameters from a weights iterator (v1 loader path).
+
+ Handles HF→FD name mapping for:
+ - Full attention: q_proj/k_proj/v_proj → qkv_proj (stacked via shard_id)
+ - Linear attention: q_proj/k_proj/v_proj → qkv_proj (concatenated, no shard_id)
+ - MoE experts: w1/w3 → up_gate_proj, w2 → down_proj
+ """
+ from fastdeploy.model_executor.utils import (
+ default_weight_loader,
+ process_weights_after_loading,
+ )
+
+ stacked_params_mapping = list(self._STACKED_PARAMS_MAPPING)
+
+ # Expert weight mapping: HF w1/w2/w3 → FD up_gate_proj/down_proj
+ n_experts = getattr(self.fd_config.model_config, "num_local_experts", 1)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ num_experts=n_experts,
+ ckpt_gate_proj_name="w1",
+ ckpt_down_proj_name="w2",
+ ckpt_up_proj_name="w3",
+ param_gate_up_proj_name="experts.up_gate_proj_",
+ param_down_proj_name="experts.down_proj_",
+ )
+
+ params_dict = dict(self.named_parameters())
+ process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
+
+ # Build attention type list to distinguish linear vs full attention layers.
+ # Linear attention layers use ColumnParallelLinear for qkv_proj which does
+ # NOT support shard_id — q/k/v must be concatenated before loading.
+ attn_type_list = getattr(
+ self.fd_config.model_config,
+ "attn_type_list",
+ MiniMaxM1DecoderLayer._build_attn_type_list(self.fd_config.model_config.num_hidden_layers),
+ )
+
+ def _is_linear_attn_layer(weight_name: str) -> bool:
+ """Check if a weight belongs to a linear attention layer."""
+ m = re.search(r"layers\.(\d+)\.", weight_name)
+ if m is None:
+ return False
+ layer_idx = int(m.group(1))
+ return layer_idx < len(attn_type_list) and attn_type_list[layer_idx] == 0
+
+ # Buffer for linear attention q/k/v weights that need concatenation.
+ # Key: (attn_prefix, suffix) → {"q": tensor, "k": tensor, "v": tensor}
+ linear_attn_qkv_buffers: Dict[str, Dict[str, Any]] = {}
+
+ for loaded_weight_name, loaded_weight in weights_iterator:
+ logger.debug(f"Loading weight: {loaded_weight_name}")
+
+ model_param_name = None
+ param = None
+
+ # Linear attention q/k/v: buffer for concatenation (no shard_id)
+ if _is_linear_attn_layer(loaded_weight_name) and any(
+ proj in loaded_weight_name for proj in (".q_proj.", ".k_proj.", ".v_proj.")
+ ):
+ m = re.match(
+ r"(.*\.self_attn)\.(q|k|v)_proj\.(weight|quant_weight|weight_scale|activation_scale)$",
+ loaded_weight_name,
+ )
+ if m:
+ attn_prefix = m.group(1)
+ proj_type = m.group(2)
+ suffix = m.group(3)
+ buf_key = f"{attn_prefix}|{suffix}"
+ if buf_key not in linear_attn_qkv_buffers:
+ linear_attn_qkv_buffers[buf_key] = {}
+ linear_attn_qkv_buffers[buf_key][proj_type] = loaded_weight
+ continue
+
+ # Stacked params (q/k/v → qkv_proj for full attention layers)
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in loaded_weight_name:
+ continue
+ # Skip expert weights — handled separately
+ if "block_sparse_moe.experts." in loaded_weight_name:
+ continue
+ model_param_name = loaded_weight_name.replace(weight_name, param_name)
+ if model_param_name not in params_dict:
+ continue
+ param = params_dict[model_param_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Expert params (w1/w2/w3 → up_gate_proj/down_proj)
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in loaded_weight_name:
+ continue
+ model_param_name = loaded_weight_name.replace(weight_name, param_name)
+ if model_param_name not in params_dict:
+ continue
+ param = params_dict[model_param_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+ weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
+ break
+ else:
+ # Direct loading (norm, embed, lm_head, output_gate, out_proj, etc.)
+ model_param_name = loaded_weight_name
+ if model_param_name not in params_dict:
+ continue
+ param = params_dict[model_param_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+ weight_loader(param, loaded_weight)
+
+ if model_param_name is None:
+ logger.warning(f"Weight {loaded_weight_name} not matched to any parameter, skipping")
+ continue
+ model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
+ process_weights_after_loading_fn(model_sublayer_name, param)
+
+ # Flush buffered linear attention q/k/v → concatenated qkv_proj
+ for buf_key, projections in linear_attn_qkv_buffers.items():
+ if "q" in projections and "k" in projections and "v" in projections:
+ attn_prefix, suffix = buf_key.split("|", 1)
+ q_w, k_w, v_w = projections["q"], projections["k"], projections["v"]
+ if isinstance(q_w, np.ndarray):
+ merged = np.concatenate([q_w, k_w, v_w], axis=0)
+ else:
+ merged = paddle.concat([q_w, k_w, v_w], axis=0)
+ model_param_name = f"{attn_prefix}.qkv_proj.{suffix}"
+ if model_param_name not in params_dict:
+ logger.warning(f"Merged linear attn QKV key {model_param_name} not found, skipping")
+ continue
+ param = params_dict[model_param_name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+ weight_loader(param, merged)
+ model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
+ process_weights_after_loading_fn(model_sublayer_name, param)
+ else:
+ missing = [k for k in ("q", "k", "v") if k not in projections]
+ logger.warning(f"Incomplete linear attn QKV buffer {buf_key}, missing: {missing}")
+
+ # Tie lm_head weight to embed_tokens when tie_word_embeddings is set
+ if self.fd_config.model_config.tie_word_embeddings:
+ self.lm_head.linear.weight.set_value(
+ self.model.embed_tokens.embeddings.weight.transpose([1, 0]).astype(self.lm_head.linear.weight.dtype)
+ )
+
+ def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
+ """Compute logits."""
+ logits = self.lm_head(hidden_states)
+ logits = logits.astype(paddle.float32)
+ return logits
+
+ def forward(
+ self,
+ inputs: Dict,
+ forward_meta: ForwardMeta,
+ ):
+ """Forward pass."""
+ ids_remove_padding = inputs["ids_remove_padding"]
+
+ hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
+ return hidden_states
+
+
+class MiniMaxM1PretrainedModel(PretrainedModel):
+ """MiniMax-M1 Pretrained Model"""
+
+ config_class = FDConfig
+
+ @classmethod
+ def arch_name(cls):
+ """Architecture name."""
+ return "MiniMaxM1ForCausalLM"
+
+ @classmethod
+ def name(cls):
+ """Model name."""
+ return "MiniMaxM1ForCausalLM"
diff --git a/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py b/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py
new file mode 100644
index 00000000000..3307d4296fd
--- /dev/null
+++ b/fastdeploy/model_executor/ops/triton_ops/lightning_attn.py
@@ -0,0 +1,733 @@
+"""Module for Hackathon 10th Spring No.47.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import paddle
+import triton
+import triton.language as tl
+
+from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
+ enable_compat_on_triton_kernel,
+)
+
+# =============================================================================
+# Triton JIT Kernels — framework-agnostic, operate on raw pointers
+# =============================================================================
+
+
+@enable_compat_on_triton_kernel
+@triton.jit # pragma: no cover
+def _fwd_diag_kernel(
+ Q,
+ K,
+ V,
+ Out,
+ S,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ CBLOCK: tl.constexpr,
+):
+ # This kernel computes the diagonal blocks of the attention matrix
+ # Each diagonal block represents attention
+ # where queries attend to keys in the same block
+ off = tl.program_id(0)
+ off_bh = off // NUM_BLOCK # batch-head index
+ off_block = off % NUM_BLOCK # block index within the sequence
+ off_cblock = tl.program_id(1) # sub-block index within a block
+
+ off_h = off_bh % h # head index
+
+ # Calculate base offsets for the current batch and head
+ qk_offset = off_bh * n * d
+ v_offset = off_bh * n * e
+ o_offset = off_bh * n * e
+
+ # Calculate offsets for the current block
+ block_offset = off_block * BLOCK
+ qk_block_offset = block_offset * d
+ v_block_offset = block_offset * e
+ o_block_offset = block_offset * e
+
+ # Calculate offsets for the current sub-block
+ cblock_offset = off_cblock * CBLOCK
+ q_cblock_offset = cblock_offset * d
+ o_cblock_offset = cblock_offset * e
+
+ # Calculate pointers to the query, key, value, and output tensors
+ Q_block_ptr = (
+ Q
+ + qk_offset
+ + qk_block_offset
+ + q_cblock_offset
+ + tl.arange(0, CBLOCK)[:, None] * d
+ + tl.arange(0, d)[None, :]
+ )
+ K_trans_block_ptr = K + qk_offset + qk_block_offset + tl.arange(0, CBLOCK)[None, :] * d + tl.arange(0, d)[:, None]
+ V_block_ptr = V + v_offset + v_block_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, e)[None, :]
+ O_block_ptr = (
+ Out
+ + o_offset
+ + o_block_offset
+ + o_cblock_offset
+ + tl.arange(0, CBLOCK)[:, None] * e
+ + tl.arange(0, e)[None, :]
+ )
+
+ # Load the decay rate for the current head
+ S_block_ptr = S + off_h
+ s = tl.load(S_block_ptr)
+
+ i = off_cblock
+ q_index = tl.arange(0, CBLOCK) + i * CBLOCK
+
+ # Load query values
+ q = tl.load(Q_block_ptr, mask=block_offset + q_index[:, None] < n, other=0.0).to(tl.float32)
+
+ # Initialize output accumulator
+ qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
+
+ # Process all sub-blocks up to and
+ # including the current one (causal attention)
+ for j in range(i + 1):
+ kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
+ diff = q_index[:, None] - kv_index[None, :]
+ s_index = s * diff
+ # Apply causal mask: only attend to positions before the current one
+ s_index = tl.where(diff >= 0, -s_index, float("-inf"))
+ decay = tl.exp(s_index)
+
+ # Load key and value
+ k_trans = tl.load(
+ K_trans_block_ptr,
+ mask=block_offset + kv_index[None, :] < n,
+ other=0.0,
+ ).to(tl.float32)
+ v = tl.load(
+ V_block_ptr,
+ mask=block_offset + kv_index[:, None] < n,
+ other=0.0,
+ ).to(tl.float32)
+
+ # Compute attention scores and apply decay
+ qk = tl.dot(q, k_trans) * decay
+
+ # Compute weighted values and accumulate
+ qkv += tl.dot(qk, v)
+
+ # Move to the next sub-block
+ K_trans_block_ptr += CBLOCK * d
+ V_block_ptr += CBLOCK * e
+
+ # Store the result
+ tl.store(
+ O_block_ptr,
+ qkv.to(O_block_ptr.dtype.element_ty),
+ mask=block_offset + q_index[:, None] < n,
+ )
+
+
+@enable_compat_on_triton_kernel
+@triton.jit # pragma: no cover
+def _fwd_kv_parallel(
+ K,
+ V,
+ K_decay,
+ KV,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ D_FBLOCK: tl.constexpr,
+ E_FBLOCK: tl.constexpr,
+ NUM_FBLOCK: tl.constexpr,
+ CBLOCK: tl.constexpr,
+ NUM_CBLOCK: tl.constexpr,
+):
+ # This kernel computes the key-value outer
+ # products for each block in parallel
+ off_bh = tl.program_id(0) # batch-head index
+ off_block = tl.program_id(1) # block index
+
+ off_h = off_bh % h # head index
+
+ block_offset = off_block * BLOCK
+
+ # Calculate offsets for the current block
+ k_block_offset = block_offset * d
+ v_block_offset = block_offset * e
+ kv_block_offset = off_block * d * e
+
+ # Calculate base offsets for the current batch and head
+ k_offset = off_bh * n * d
+ v_offset = off_bh * n * e
+ kv_offset = off_bh * NUM_BLOCK * d * e
+
+ # Calculate pointers to the key, value, and key-value tensors
+ K_trans_block_ptr = (
+ K + k_offset + k_block_offset + tl.arange(0, CBLOCK)[None, :] * d + tl.arange(0, D_FBLOCK)[:, None]
+ )
+ V_block_ptr = V + v_offset + v_block_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ KV_block_ptr = (
+ KV + kv_offset + kv_block_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ )
+
+ # Load the decay factors for the current head and block
+ k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)
+
+ kv_index = tl.arange(0, CBLOCK)
+
+ # Initialize the key-value outer product accumulator
+ kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
+
+ # Handle the last block which might be smaller than BLOCK
+ split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK
+ left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
+ num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
+ k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
+
+ # Process all sub-blocks in the current block
+ for j in range(num_blocks):
+ left_bound = (1 - j) * left_shift
+ # Load key and value, handling boundary conditions
+ k_trans = tl.load(
+ K_trans_block_ptr - left_shift * d,
+ mask=kv_index[None, :] >= left_bound,
+ other=0.0,
+ ).to(tl.float32)
+ v = tl.load(
+ V_block_ptr - left_shift * e,
+ mask=kv_index[:, None] >= left_bound,
+ other=0.0,
+ ).to(tl.float32)
+
+ # Load decay factor and compute weighted key-value outer product
+ k_decay = tl.load(k_decay_ptr)
+
+ # NOTE: Need to add the extra dim here due to AMD MLIR lowering error.
+ # Please don't move it back until issue is resolved.
+ # Issue: https://github.com/ROCm/triton/issues/907
+ k_decay = k_decay[None, :]
+
+ kv += tl.dot(k_trans * k_decay, v)
+
+ # Move to the next sub-block
+ K_trans_block_ptr += CBLOCK * d
+ V_block_ptr += CBLOCK * e
+ k_decay_ptr += CBLOCK
+
+ # Store the result
+ tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
+
+
+@enable_compat_on_triton_kernel
+@triton.jit # pragma: no cover
+def _fwd_kv_reduce(
+ S,
+ KV,
+ KV_HISTORY,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ D_FBLOCK: tl.constexpr,
+ E_FBLOCK: tl.constexpr,
+):
+ # This kernel reduces the key-value outer products
+ # across blocks and updates the KV history
+ off_bh = tl.program_id(0) # batch-head index
+ off_h = off_bh % h # head index
+
+ kv_offset = off_bh * NUM_BLOCK * d * e
+
+ # Calculate pointer to the key-value tensor
+ KV_block_ptr = KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+
+ # Load the decay rate for the current head
+ s_ptrs = S + off_h
+ s = tl.load(s_ptrs)
+
+ # Calculate pointer to the key-value history tensor
+ kv_history_offset = off_bh * d * e
+ KV_HISTORY_block_ptr = (
+ KV_HISTORY + kv_history_offset + tl.arange(0, D_FBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ )
+
+ # Load the previous key-value history
+ kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
+
+ # Process all blocks in forward order to compute the prefix accumulation
+ for i in range(NUM_BLOCK):
+ block_size = min(n - i * BLOCK, BLOCK)
+ # Compute decay factor for the current block
+ block_decay = tl.exp(-s.to(tl.float32) * block_size)
+
+ # Load the current key-value outer product
+ kv_cur = tl.load(KV_block_ptr).to(tl.float32)
+ # Store the previous key-value history to the current block
+ tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
+
+ # Update the key-value history with the current block
+ kv_pre = block_decay * kv_pre + kv_cur
+ KV_block_ptr += d * e
+
+ # Store the updated key-value history
+ tl.store(KV_HISTORY_block_ptr, kv_pre)
+
+
+@enable_compat_on_triton_kernel
+@triton.jit # pragma: no cover
+def _fwd_none_diag_kernel(
+ Q,
+ Out,
+ S,
+ KV,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n,
+ d: tl.constexpr,
+ e: tl.constexpr,
+ BLOCK: tl.constexpr,
+ NUM_BLOCK,
+ E_FBLOCK: tl.constexpr,
+ CBLOCK: tl.constexpr,
+ NUM_CBLOCK: tl.constexpr,
+):
+ # This kernel computes the non-diagonal blocks of the attention matrix
+ # Each non-diagonal block represents attention
+ # where queries attend to keys in different blocks
+ off_bh = tl.program_id(0) # batch-head index
+ off_h = off_bh % h # head index
+
+ off_nc = tl.program_id(1)
+ off_n = off_nc // NUM_CBLOCK # block index
+ off_c = off_nc % NUM_CBLOCK # sub-block index
+ off_e = tl.program_id(2) # output feature block index
+
+ n_offset = off_n * BLOCK
+ c_offset = off_c * CBLOCK
+ e_offset = off_e * E_FBLOCK
+ block_offset = n_offset + c_offset
+
+ # Calculate offsets for the current batch, head, and block
+ q_offset = off_bh * n * d + (n_offset + c_offset) * d
+ o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
+ kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
+
+ # Calculate pointers to the query, output, and key-value tensors
+ Q_block_ptr = Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]
+ O_block_ptr = Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+ KV_block_ptr = KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :]
+
+ # Load the decay rate for the current head
+ S_block_ptr = S + off_h
+ s = tl.load(S_block_ptr)
+
+ c_array = tl.arange(0, CBLOCK)
+
+ # Load the key-value outer product for the current block
+ kv = tl.load(KV_block_ptr).to(tl.float32)
+ q_index = block_offset + tl.arange(0, CBLOCK)
+
+ # Load query values
+ q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)
+
+ # Compute decay factors for the current sub-block
+ q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
+
+ # Compute non-diagonal attention output
+ qkv_none_diag = tl.dot(q, kv) * q_decay
+
+ # Load diagonal attention output (computed by _fwd_diag_kernel)
+ qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32)
+
+ # Combine diagonal and non-diagonal attention outputs
+ qkv = qkv_diag + qkv_none_diag
+
+ # Store the result
+ tl.store(O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n)
+
+
+@enable_compat_on_triton_kernel
+@triton.jit # pragma: no cover
+def _linear_attn_decode_kernel(
+ q_ptr,
+ k_ptr,
+ v_ptr,
+ kv_cache_ptr,
+ slope_rate,
+ slot_idx,
+ output_ptr,
+ D: tl.constexpr,
+ qkv_b_stride,
+ qkv_h_stride,
+ cache_b_stride,
+ cache_h_stride,
+ cache_d0_stride,
+ cache_d1_stride,
+ BLOCK_SIZE: tl.constexpr,
+):
+ """
+ Kernel for linear attention decoding with KV cache.
+
+ This kernel computes attention for a single token using the KV cache.
+ """
+ pid_b = tl.program_id(0) # batch index
+ pid_h = tl.program_id(1) # head index
+ pid_d = tl.program_id(2) # dimension block index
+
+ # Load slot index for the current batch
+ slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
+
+ # Skip if slot_id is -1 (padding); zero the output so the caller
+ # never reads uninitialised memory from paddle.empty_like.
+ if slot_id == -1:
+ v_d_offsets = tl.arange(0, BLOCK_SIZE) + tl.program_id(2) * BLOCK_SIZE
+ v_mask = v_d_offsets < D
+ out_offset = pid_b * qkv_b_stride + pid_h * qkv_h_stride
+ tl.store(output_ptr + out_offset + v_d_offsets, tl.zeros([BLOCK_SIZE], dtype=tl.float32), mask=v_mask)
+ return
+
+ batch_id = pid_b
+ head_id = pid_h
+
+ # Load decay rate for the current head
+ ratio = tl.load(slope_rate + pid_h)
+
+ # Calculate offsets for dimensions
+ qk_d_offsets = tl.arange(0, D)
+ v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
+ cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[None, :] * cache_d1_stride
+
+ # Calculate offsets for the current batch and head
+ q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
+ k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
+ v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
+
+ cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride
+
+ # Create masks for loading tensors
+ qk_mask = qk_d_offsets < D
+ v_mask = v_d_offsets < D
+
+ # Load query, key, and value tensors
+ q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
+ k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
+ v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)
+
+ # Compute key-value outer product
+ kv_outer = k[:, None] * v[None, :]
+ kv_mask = qk_mask[:, None] & v_mask[None, :]
+
+ # Apply decay to previous KV cache
+ ratio = tl.exp(-ratio)
+ kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
+ kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
+ kv_outer = kv_outer + ratio * kv_cache_old
+
+ # Compute attention output
+ output = q[:, None].to(tl.float32) * kv_outer
+ output = tl.sum(output, axis=0)
+
+ # Update KV cache and store output
+ tl.store(kv_ptr, kv_outer, mask=kv_mask)
+ tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
+
+
+# =============================================================================
+# Python wrapper functions — Paddle API
+# =============================================================================
+
+
+def lightning_attention_forward(q, k, v, s, kv_history, block_size=256):
+ """
+ Forward pass of the lightning attention algorithm.
+ Converted from vLLM's torch.autograd.Function to a plain function
+ for inference-only use in FastDeploy.
+
+ Args:
+ q: Query tensor [b, h, n, d]
+ k: Key tensor [b, h, n, d]
+ v: Value tensor [b, h, n, e]
+ s: Decay rate tensor [1, h, 1, 1] or [h]
+ kv_history: KV history tensor [b, h, d, e]
+ block_size: Block size for block-sparse attention (default 256)
+
+ Returns:
+ o: Output tensor [b, h, n, e]
+ kv_history: Updated 4D KV history tensor [b, h, d, e]
+ """
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ s = s.contiguous()
+
+ # Get input dimensions
+ b, h, n, d = q.shape
+ e = v.shape[-1]
+
+ # Initialize output tensor
+ o = paddle.empty([b, h, n, e], dtype=q.dtype)
+
+ # Set block sizes
+ BLOCK = block_size
+ NUM_BLOCK = triton.cdiv(n, BLOCK)
+
+ CBLOCK = 32
+ NUM_CBLOCK = BLOCK // CBLOCK
+ assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
+
+ # Compute decay factors for keys
+ array = paddle.arange(0, BLOCK).astype("float32") + 1
+ k_decay = paddle.exp(-s * (BLOCK - array.reshape([1, -1])))
+
+ # Step 1: Compute diagonal blocks of attention
+ grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
+ _fwd_diag_kernel[grid](
+ q,
+ k,
+ v,
+ o,
+ s,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ CBLOCK=CBLOCK,
+ )
+
+ # Set feature block sizes
+ NUM_FBLOCK = 1
+ D_FBLOCK = d // NUM_FBLOCK
+ assert d % NUM_FBLOCK == 0
+ E_FBLOCK = e // NUM_FBLOCK
+ assert e % NUM_FBLOCK == 0
+
+ CBLOCK = 64
+ NUM_CBLOCK = BLOCK // CBLOCK
+ assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
+
+ # Step 2: Compute key-value outer products for each block in parallel
+ kv = paddle.empty([b, h, NUM_BLOCK, d, e], dtype="float32")
+ grid = (b * h, NUM_BLOCK)
+ _fwd_kv_parallel[grid](
+ k,
+ v,
+ k_decay,
+ kv,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ D_FBLOCK=D_FBLOCK,
+ E_FBLOCK=E_FBLOCK,
+ NUM_FBLOCK=NUM_FBLOCK,
+ CBLOCK=CBLOCK,
+ NUM_CBLOCK=NUM_CBLOCK,
+ )
+
+ # Step 3: Reduce key-value outer products
+ # across blocks and update KV history
+ grid = (b * h, NUM_FBLOCK)
+ _fwd_kv_reduce[grid](
+ s,
+ kv,
+ kv_history,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ D_FBLOCK=D_FBLOCK,
+ E_FBLOCK=E_FBLOCK,
+ )
+
+ # Step 4: Compute non-diagonal blocks of attention
+ grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
+ _fwd_none_diag_kernel[grid](
+ q,
+ o,
+ s,
+ kv,
+ b,
+ h,
+ n,
+ d,
+ e,
+ BLOCK=BLOCK,
+ NUM_BLOCK=NUM_BLOCK,
+ E_FBLOCK=E_FBLOCK,
+ CBLOCK=CBLOCK,
+ NUM_CBLOCK=NUM_CBLOCK,
+ )
+
+ # In vLLM the concat [kv, kv_history] is returned for the backward pass.
+ # For inference-only we only need the updated 4D kv_history (already
+ # written in-place by _fwd_kv_reduce).
+ return o, kv_history
+
+
+def lightning_attention(
+ q: paddle.Tensor,
+ k: paddle.Tensor,
+ v: paddle.Tensor,
+ ed: paddle.Tensor,
+ block_size: int = 256,
+ kv_history: paddle.Tensor | None = None,
+) -> tuple[paddle.Tensor, paddle.Tensor]:
+ """
+ Apply lightning attention algorithm to compute attention efficiently.
+
+ Args:
+ q: Query tensor of shape [batch, heads, seq_len, dim]
+ k: Key tensor of shape [batch, heads, seq_len, dim]
+ v: Value tensor of shape [batch, heads, seq_len, dim_v]
+ ed: Decay rate tensor of shape [heads]
+ block_size: Size of blocks for block-sparse attention
+ kv_history: Optional key-value history from previous computations
+
+ Returns:
+ output: Attention output
+ kv: Updated key-value history
+ """
+ d = q.shape[-1]
+ e = v.shape[-1]
+
+ if ed.ndim == 1:
+ ed = ed.reshape([1, -1, 1, 1])
+
+ # Split the computation into chunks for better parallelism.
+ # MiniMax-M1 production uses head_dim=128 (m=128). The fallback must
+ # handle smaller dimensions (e.g. tiny-random test models with d=32).
+ m = 128 if d >= 128 else min(64, d)
+ assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
+ arr = [m * i for i in range(d // m + 1)]
+ if arr[-1] != d:
+ arr.append(d)
+ n = len(arr)
+ output = 0
+
+ # Initialize key-value history. The Triton kernel updates kv_history
+ # in-place, so we only need a contiguous view — avoid an extra copy.
+ if kv_history is None:
+ kv_history = paddle.zeros([q.shape[0], q.shape[1], d, e], dtype="float32")
+ elif not kv_history.is_contiguous():
+ kv_history = kv_history.contiguous()
+
+ # Process each chunk and accumulate results
+ for i in range(n - 1):
+ s = arr[i]
+ end_idx = arr[i + 1]
+ q1 = q[..., s:end_idx]
+ k1 = k[..., s:end_idx]
+ o, kv_history = lightning_attention_forward(q1, k1, v, ed, kv_history, block_size=block_size)
+ output = output + o
+ return output, kv_history
+
+
+# Reserved for future decode-path integration: will be called from
+# MiniMaxM1LinearAttention.forward when forward_meta.is_decode is True.
+# Kept alongside the prefill kernel for architectural completeness.
+def linear_decode_forward_triton(
+ q: paddle.Tensor,
+ k: paddle.Tensor,
+ v: paddle.Tensor,
+ kv_caches: paddle.Tensor,
+ slope_rate: paddle.Tensor,
+ slot_idx: paddle.Tensor,
+ BLOCK_SIZE: int = 32,
+) -> paddle.Tensor:
+ """
+ Perform linear attention decoding using Triton kernels.
+
+ Args:
+ q: Query tensor of shape [B, H, 1, D]
+ k: Key tensor of shape [B, H, 1, D]
+ v: Value tensor of shape [B, H, 1, D]
+ kv_caches: Key-value cache tensor
+ slope_rate: Decay rate tensor
+ slot_idx: Slot indices for batches
+ BLOCK_SIZE: Size of blocks for processing
+
+ Returns:
+ output: Attention output tensor of shape [B, H*D]
+ """
+ B, H, _, D = q.shape
+ assert k.shape == [B, H, 1, D]
+ assert v.shape == [B, H, 1, D]
+
+ # Initialize output tensor
+ output = paddle.empty_like(q)
+
+ # MiniMax-M1 uses head_dim=128 (128 % 32 == 0). Guard against future
+ # models with non-standard head dimensions until a fallback is added.
+ assert D % BLOCK_SIZE == 0, (
+ f"Head dimension D ({D}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}); "
+ f"otherwise the kernel grid drops tail dimensions silently."
+ )
+
+ # Set grid dimensions for the kernel
+ grid = (B, H, D // BLOCK_SIZE)
+
+ # Calculate strides for tensors
+ qkv_b_stride = q.strides[0]
+ qkv_h_stride = q.strides[1]
+
+ cache_b_stride = kv_caches.strides[0]
+ cache_h_stride = kv_caches.strides[1]
+ cache_d0_stride = kv_caches.strides[2]
+ cache_d1_stride = kv_caches.strides[3]
+
+ # Launch the kernel
+ _linear_attn_decode_kernel[grid](
+ q,
+ k,
+ v,
+ kv_caches,
+ slope_rate,
+ slot_idx,
+ output,
+ D,
+ qkv_b_stride,
+ qkv_h_stride,
+ cache_b_stride,
+ cache_h_stride,
+ cache_d0_stride,
+ cache_d1_stride,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+ # Reshape output: "b h n d -> b n (h d)"
+ # output shape: [B, H, 1, D] -> transpose to [B, 1, H, D] -> reshape to [B, 1, H*D]
+ output = output.transpose([0, 2, 1, 3]).reshape([B, 1, -1])
+ return output.squeeze(1).contiguous()
diff --git a/tests/e2e/validate_minimax_m1_e2e.py b/tests/e2e/validate_minimax_m1_e2e.py
new file mode 100644
index 00000000000..46b1968b849
--- /dev/null
+++ b/tests/e2e/validate_minimax_m1_e2e.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for Hackathon 10th Spring No.47.
+Standalone MiniMax-M1 end-to-end validation for AI Studio.
+
+Run on AI Studio A800 via SSH:
+ ssh aistudio 'python3 /home/aistudio/validate_minimax_m1.py 2>&1 | tee /home/aistudio/output/minimax_m1_e2e.log'
+
+This script:
+1. Starts FastDeploy API server with MiniMax-M1 (WINT4/WINT8)
+2. Waits for server readiness
+3. Runs 6 validation checks (health, models, chat, reasoning, Chinese, multi-turn)
+4. Prints structured evidence for PR body
+5. Cleans up server process
+
+Requirements:
+- AI Studio A800 (80GB) or multiple GPUs for full 456B model
+- FastDeploy installed with Triton support
+- Model weights downloaded to MODEL_PATH
+
+Environment variables:
+ MINIMAX_MODEL_PATH Path to MiniMax-M1 weights (default: MiniMax/MiniMax-M1-80k)
+ MINIMAX_PORT Server port (default: 8189)
+ MINIMAX_QUANT Quantization type: wint4, wint8, or none (default: wint4)
+ MINIMAX_TP Tensor parallel degree (default: 1)
+"""
+
+import json
+import os
+import signal
+import subprocess
+import sys
+import time
+import urllib.request
+
+MODEL_PATH = os.environ.get("MINIMAX_MODEL_PATH", "MiniMax/MiniMax-M1-80k")
+PORT = int(os.environ.get("MINIMAX_PORT", "8189"))
+QUANTIZATION = os.environ.get("MINIMAX_QUANT", "wint4")
+TP_DEGREE = int(os.environ.get("MINIMAX_TP", "1"))
+
+
+def log(msg):
+ print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
+
+
+def wait_for_server(port, timeout=900):
+ """Poll server health until ready."""
+ log(f"Waiting for server on port {port} (timeout={timeout}s)...")
+ start = time.time()
+ while time.time() - start < timeout:
+ try:
+ req = urllib.request.Request(f"http://localhost:{port}/health")
+ with urllib.request.urlopen(req, timeout=5) as resp:
+ if resp.status == 200:
+ elapsed = time.time() - start
+ log(f"Server ready in {elapsed:.1f}s")
+ return True
+ except Exception:
+ pass
+ time.sleep(5)
+ return False
+
+
+def send_chat(prompt, max_tokens=128, temperature=0.0, messages=None):
+ """Send a chat completion request."""
+ if messages is None:
+ messages = [{"role": "user", "content": prompt}]
+ body = json.dumps(
+ {
+ "model": MODEL_PATH,
+ "messages": messages,
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ }
+ ).encode()
+ req = urllib.request.Request(
+ f"http://localhost:{PORT}/v1/chat/completions",
+ data=body,
+ headers={"Content-Type": "application/json"},
+ )
+ with urllib.request.urlopen(req, timeout=120) as resp:
+ return json.loads(resp.read().decode())
+
+
+def run_validations():
+ """Run all validation checks. Returns (passed, failed, results)."""
+ results = []
+ passed = 0
+ failed = 0
+
+ # Test 1: Health endpoint
+ log("Test 1/6: Health endpoint")
+ try:
+ req = urllib.request.Request(f"http://localhost:{PORT}/health")
+ with urllib.request.urlopen(req, timeout=10) as resp:
+ assert resp.status == 200
+ results.append(("health", "PASS", "HTTP 200"))
+ passed += 1
+ except Exception as e:
+ results.append(("health", "FAIL", str(e)))
+ failed += 1
+
+ # Test 2: Model listing
+ log("Test 2/6: Model listing")
+ try:
+ req = urllib.request.Request(f"http://localhost:{PORT}/v1/models")
+ with urllib.request.urlopen(req, timeout=10) as resp:
+ data = json.loads(resp.read().decode())
+ model_ids = [m["id"] for m in data.get("data", [])]
+ assert len(model_ids) > 0, f"No models listed: {data}"
+ results.append(("models", "PASS", f"Models: {model_ids}"))
+ passed += 1
+ except Exception as e:
+ results.append(("models", "FAIL", str(e)))
+ failed += 1
+
+ # Test 3: Simple chat
+ log("Test 3/6: Simple chat")
+ try:
+ resp = send_chat("Hello, what is your name?")
+ content = resp["choices"][0]["message"]["content"].strip()
+ assert len(content) > 0, "Empty response"
+ results.append(("chat", "PASS", f"Response: {content[:100]}..."))
+ passed += 1
+ except Exception as e:
+ results.append(("chat", "FAIL", str(e)))
+ failed += 1
+
+ # Test 4: Arithmetic reasoning
+ log("Test 4/6: Arithmetic reasoning")
+ try:
+ resp = send_chat("What is 17 * 23? Just give the number.")
+ content = resp["choices"][0]["message"]["content"].strip()
+ assert "391" in content, f"Expected 391, got: {content}"
+ results.append(("arithmetic", "PASS", f"Response: {content[:100]}"))
+ passed += 1
+ except Exception as e:
+ results.append(("arithmetic", "FAIL", str(e)))
+ failed += 1
+
+ # Test 5: Chinese language
+ log("Test 5/6: Chinese language")
+ try:
+ resp = send_chat("用中文解释什么是人工智能,一句话。")
+ content = resp["choices"][0]["message"]["content"].strip()
+ assert len(content) > 5, f"Response too short: {content}"
+ # Verify Chinese characters present
+ has_chinese = any("\u4e00" <= c <= "\u9fff" for c in content)
+ assert has_chinese, f"No Chinese in response: {content}"
+ results.append(("chinese", "PASS", f"Response: {content[:100]}"))
+ passed += 1
+ except Exception as e:
+ results.append(("chinese", "FAIL", str(e)))
+ failed += 1
+
+ # Test 6: Multi-turn conversation
+ log("Test 6/6: Multi-turn conversation")
+ try:
+ messages = [
+ {"role": "user", "content": "My name is Alice."},
+ {"role": "assistant", "content": "Hello Alice! How can I help you?"},
+ {"role": "user", "content": "What is my name?"},
+ ]
+ resp = send_chat("", messages=messages)
+ content = resp["choices"][0]["message"]["content"].strip()
+ assert "alice" in content.lower(), f"Model forgot name: {content}"
+ results.append(("multi_turn", "PASS", f"Response: {content[:100]}"))
+ passed += 1
+ except Exception as e:
+ results.append(("multi_turn", "FAIL", str(e)))
+ failed += 1
+
+ return passed, failed, results
+
+
+def main():
+ log("=" * 60)
+ log("MiniMax-M1 End-to-End Validation")
+ log(f"Model: {MODEL_PATH}")
+ log(f"Quantization: {QUANTIZATION}")
+ log(f"TP Degree: {TP_DEGREE}")
+ log(f"Port: {PORT}")
+ log("=" * 60)
+
+ # Build server command
+ cmd = [
+ sys.executable,
+ "-m",
+ "fastdeploy.entrypoints.openai.api_server",
+ "--model",
+ MODEL_PATH,
+ "--port",
+ str(PORT),
+ "--max-model-len",
+ "4096",
+ ]
+ if QUANTIZATION and QUANTIZATION != "none":
+ cmd.extend(["--quantization", QUANTIZATION])
+ if TP_DEGREE > 1:
+ cmd.extend(["--tensor-parallel-size", str(TP_DEGREE)])
+
+ log(f"Starting server: {' '.join(cmd)}")
+ server = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ preexec_fn=os.setsid,
+ )
+
+ try:
+ if not wait_for_server(PORT, timeout=900):
+ log("FATAL: Server failed to start within 15 minutes!")
+ # Dump last output
+ if server.stdout:
+ output = server.stdout.read(4096)
+ if output:
+ log(f"Server output:\n{output.decode(errors='replace')}")
+ sys.exit(1)
+
+ passed, failed, results = run_validations()
+
+ # Print structured evidence
+ log("")
+ log("=" * 60)
+ log(f"RESULTS: {passed}/{passed+failed} passed")
+ log("=" * 60)
+ for name, status, detail in results:
+ icon = "✅" if status == "PASS" else "❌"
+ log(f" {icon} {name}: {detail}")
+
+ if failed > 0:
+ log(f"\n❌ {failed} test(s) FAILED")
+ sys.exit(1)
+ else:
+ log("\n✅ All validations passed!")
+
+ finally:
+ log("Shutting down server...")
+ try:
+ os.killpg(os.getpgid(server.pid), signal.SIGTERM)
+ server.wait(timeout=15)
+ except Exception:
+ try:
+ os.killpg(os.getpgid(server.pid), signal.SIGKILL)
+ except Exception:
+ pass
+ log("Done.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/model_executor/test_minimax_m1.py b/tests/model_executor/test_minimax_m1.py
new file mode 100644
index 00000000000..0b6a6c7f13d
--- /dev/null
+++ b/tests/model_executor/test_minimax_m1.py
@@ -0,0 +1,767 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for Hackathon 10th Spring No.47.
+Tests for MiniMax-M1 model: architecture dispatch, weight loading, forward paths,
+and Lightning Attention algorithm correctness.
+
+Follows H10 gold standard (test_ernie4_5_mtp.py pattern):
+- Direct import of fastdeploy module
+- Real paddle.nn.Layer stubs (not MagicMock)
+- monkeypatch.setattr for surgical replacement
+- Tests exercise actual FD code paths
+"""
+
+from __future__ import annotations
+
+import math
+from types import SimpleNamespace
+
+import numpy as np
+import paddle
+import pytest
+
+from fastdeploy.model_executor.models import minimax_m1
+
+# ── Lightweight stubs (real nn.Layer subclasses) ────────────────────────────
+
+
+class _StubRMSNorm(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+
+ def forward(self, x, residual_input=None, forward_meta=None):
+ if residual_input is None:
+ residual_input = paddle.zeros_like(x)
+ return x, residual_input + x
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubLinear(paddle.nn.Layer):
+ """Stub for ColumnParallelLinear, RowParallelLinear, MergedColumnParallelLinear, ReplicatedLinear."""
+
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+ self._out = kw.get("output_size", None)
+
+ def forward(self, x, *a, **kw):
+ if self._out is not None:
+ shape = list(x.shape)
+ shape[-1] = self._out
+ return paddle.zeros(shape, dtype=x.dtype)
+ return x
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubAttention(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+
+ def forward(self, q=None, k=None, v=None, qkv=None, forward_meta=None, **kw):
+ if qkv is not None:
+ return qkv
+ return q
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubSiluAndMul(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+
+ def forward(self, x):
+ return x[..., : x.shape[-1] // 2]
+
+
+class _StubFusedMoE(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.weight_key_map = kw.get("weight_key_map", {})
+ self.load_state_dict_called = False
+
+ def forward(self, hidden_states, gate, forward_meta=None):
+ return hidden_states
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+ @staticmethod
+ def make_expert_params_mapping(**kw):
+ return []
+
+
+class _StubEmbedding(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.hidden_size = kw.get("embedding_dim", 4)
+ self.load_state_dict_called = False
+
+ def forward(self, ids_remove_padding=None, forward_meta=None):
+ return paddle.zeros([ids_remove_padding.shape[0], self.hidden_size], "float32")
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+class _StubLMHead(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.load_state_dict_called = False
+
+ def forward(self, x):
+ return x
+
+ def load_state_dict(self, _sd):
+ self.load_state_dict_called = True
+
+
+def _stub_lightning_attention(q, k, v, slope, block_size=256, kv_history=None):
+ """Stub: return zeros matching shapes."""
+ b, h, seq_len, d = q.shape
+ out = paddle.zeros_like(q)
+ if kv_history is None:
+ kv_history = paddle.zeros([b, h, d, d], dtype=q.dtype)
+ return out, kv_history
+
+
+def _stub_all_reduce(x):
+ return x
+
+
+def _stub_graph_opt(cls):
+ return cls
+
+
+# ── Helpers ─────────────────────────────────────────────────────────────────
+
+
+def _make_fd_config(
+ hidden_size=4,
+ num_layers=2,
+ num_local_experts=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ head_dim=2,
+ postnorm=False,
+):
+ mc = SimpleNamespace(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_hidden_layers=num_layers,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ head_dim=head_dim,
+ vocab_size=8,
+ rms_norm_eps=1e-6,
+ hidden_act="silu",
+ num_local_experts=num_local_experts,
+ num_experts_per_tok=2,
+ norm_topk_prob=False,
+ postnorm=postnorm,
+ attn_type_list=[0, 1][:num_layers],
+ layernorm_full_attention_alpha=3.556,
+ layernorm_full_attention_beta=1.0,
+ layernorm_linear_attention_alpha=3.556,
+ layernorm_linear_attention_beta=1.0,
+ layernorm_mlp_alpha=3.556,
+ layernorm_mlp_beta=1.0,
+ pretrained_config=SimpleNamespace(prefix_name="model"),
+ )
+ pc = SimpleNamespace(tensor_parallel_size=1, tensor_parallel_rank=0, tp_group=None)
+ gc = SimpleNamespace(graph_opt_level=0, use_cudagraph=False)
+ return SimpleNamespace(model_config=mc, parallel_config=pc, graph_opt_config=gc)
+
+
+@pytest.fixture()
+def mm1(monkeypatch):
+ """Patch heavy GPU deps in minimax_m1 module with lightweight stubs."""
+ monkeypatch.setattr(minimax_m1, "RMSNorm", _StubRMSNorm)
+ monkeypatch.setattr(minimax_m1, "ColumnParallelLinear", _StubLinear)
+ monkeypatch.setattr(minimax_m1, "MergedColumnParallelLinear", _StubLinear)
+ monkeypatch.setattr(minimax_m1, "QKVParallelLinear", _StubLinear)
+ monkeypatch.setattr(minimax_m1, "RowParallelLinear", _StubLinear)
+ monkeypatch.setattr(minimax_m1, "ReplicatedLinear", _StubLinear)
+ monkeypatch.setattr(minimax_m1, "Attention", _StubAttention)
+ monkeypatch.setattr(minimax_m1, "SiluAndMul", _StubSiluAndMul)
+ monkeypatch.setattr(minimax_m1, "FusedMoE", _StubFusedMoE)
+ monkeypatch.setattr(minimax_m1, "VocabParallelEmbedding", _StubEmbedding)
+ monkeypatch.setattr(minimax_m1, "ParallelLMHead", _StubLMHead)
+ monkeypatch.setattr(minimax_m1, "lightning_attention", _stub_lightning_attention)
+ monkeypatch.setattr(minimax_m1, "tensor_model_parallel_all_reduce", _stub_all_reduce)
+ monkeypatch.setattr(minimax_m1, "support_graph_optimization", _stub_graph_opt)
+ return minimax_m1
+
+
+# ===================================================================
+# 1. Pure-logic tests (static methods — no stubs needed)
+# ===================================================================
+
+
+class TestBuildAttnTypeList:
+
+ def test_80_layers_has_10_full_attention(self):
+ attn_list = minimax_m1.MiniMaxM1DecoderLayer._build_attn_type_list(80)
+ assert len(attn_list) == 80
+ full_indices = [i for i, t in enumerate(attn_list) if t == 1]
+ assert full_indices == [7, 15, 23, 31, 39, 47, 55, 63, 71, 79]
+
+ def test_short_model_clips_indices(self):
+ attn_list = minimax_m1.MiniMaxM1DecoderLayer._build_attn_type_list(10)
+ assert len(attn_list) == 10
+ assert attn_list[7] == 1
+ assert sum(attn_list) == 1
+
+ def test_single_layer_all_linear(self):
+ assert minimax_m1.MiniMaxM1DecoderLayer._build_attn_type_list(1) == [0]
+
+ def test_all_linear_below_first_full_index(self):
+ assert all(t == 0 for t in minimax_m1.MiniMaxM1DecoderLayer._build_attn_type_list(7))
+
+
+class TestBuildSlopeTensor:
+
+ def test_power_of_two_heads(self):
+ slopes = minimax_m1.MiniMaxM1LinearAttention._build_slope_tensor(8)
+ assert slopes.shape == [8, 1, 1]
+ assert (slopes.flatten().numpy() > 0).all()
+
+ def test_non_power_of_two_heads(self):
+ slopes = minimax_m1.MiniMaxM1LinearAttention._build_slope_tensor(12)
+ assert slopes.shape == [12, 1, 1]
+ assert (slopes.flatten().numpy() > 0).all()
+
+ def test_64_heads_first_slope(self):
+ slopes = minimax_m1.MiniMaxM1LinearAttention._build_slope_tensor(64)
+ assert slopes.shape == [64, 1, 1]
+ expected_start = 2 ** (-(2 ** (-(math.log2(64) - 3))))
+ np.testing.assert_allclose(slopes.flatten().numpy()[0], expected_start, rtol=1e-5)
+
+ @pytest.mark.parametrize("n", [1, 2, 4, 8, 16, 32, 64])
+ def test_slopes_all_positive(self, n):
+ slopes = minimax_m1.MiniMaxM1LinearAttention._build_slope_tensor(n)
+ assert (slopes.flatten().numpy() > 0).all()
+
+
+# ===================================================================
+# 2. Model registration (uses real ModelRegistry)
+# ===================================================================
+
+
+class TestModelRegistration:
+
+ def test_primary_architecture_registered(self):
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ assert "MiniMaxM1ForCausalLM" in ModelRegistry._arch_to_model_cls
+
+ def test_alias_architecture_registered(self):
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ assert "MiniMaxText01ForCausalLM" in ModelRegistry._arch_to_model_cls
+
+ def test_registered_class(self):
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ assert ModelRegistry._arch_to_model_cls["MiniMaxM1ForCausalLM"] is minimax_m1.MiniMaxM1ForCausalLM
+
+ def test_name_method(self):
+ assert minimax_m1.MiniMaxM1ForCausalLM.name() == "MiniMaxM1ForCausalLM"
+
+ def test_pretrained_model_names(self):
+ assert minimax_m1.MiniMaxM1PretrainedModel.arch_name() == "MiniMaxM1ForCausalLM"
+ assert minimax_m1.MiniMaxM1PretrainedModel.name() == "MiniMaxM1ForCausalLM"
+
+
+# ===================================================================
+# 3. Layer construction (exercises real FD code with stubs)
+# ===================================================================
+
+
+class TestDecoderLayerConstruction:
+
+ def test_linear_attention_layer(self, mm1):
+ fd = _make_fd_config()
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert layer.attention_type == 0
+ assert isinstance(layer.self_attn, mm1.MiniMaxM1LinearAttention)
+ assert hasattr(layer.self_attn, "slope_rate")
+ assert hasattr(layer.self_attn, "output_gate")
+
+ def test_full_attention_layer(self, mm1):
+ fd = _make_fd_config()
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=1, prefix="model.layers.1")
+ assert layer.attention_type == 1
+ assert isinstance(layer.self_attn, mm1.MiniMaxM1Attention)
+
+ def test_moe_when_experts_gt_1(self, mm1):
+ fd = _make_fd_config(num_local_experts=4)
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert isinstance(layer.block_sparse_moe, mm1.MiniMaxM1MoE)
+
+ def test_dense_mlp_when_single_expert(self, mm1):
+ fd = _make_fd_config(num_local_experts=1)
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert isinstance(layer.block_sparse_moe, mm1.MiniMaxM1MLP)
+
+ def test_fallback_attn_type_when_no_config(self, mm1):
+ fd = _make_fd_config(num_layers=80)
+ delattr(fd.model_config, "attn_type_list")
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=7, prefix="model.layers.7")
+ assert layer.attention_type == 1
+
+
+# ===================================================================
+# 4. Forward pass tests (exercises real FD forward code)
+# ===================================================================
+
+
+def test_decoder_layer_forward_prenorm(mm1):
+ """Pre-norm forward: exercises real DecoderLayer.forward code path."""
+ fd = _make_fd_config(postnorm=False)
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ meta = SimpleNamespace()
+ h = paddle.randn([2, 4])
+ out, residual = layer(forward_meta=meta, hidden_states=h)
+ assert out.shape[-1] == 4 and out.shape[0] == 2
+ assert residual.shape[-1] == 4 and residual.shape[0] == 2
+
+
+def test_decoder_layer_forward_postnorm(mm1):
+ """Post-norm forward: exercises the postnorm=True branch."""
+ fd = _make_fd_config(postnorm=True)
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ meta = SimpleNamespace()
+ h = paddle.randn([2, 4])
+ out, residual = layer(forward_meta=meta, hidden_states=h)
+ assert out.shape[-1] == 4 and out.shape[0] == 2
+ assert residual.shape[-1] == 4 and residual.shape[0] == 2
+
+
+def test_decoder_layer_forward_full_attn(mm1):
+ """Full attention layer forward."""
+ fd = _make_fd_config()
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=1, prefix="model.layers.1")
+ meta = SimpleNamespace()
+ h = paddle.randn([2, 4])
+ out, residual = layer(forward_meta=meta, hidden_states=h)
+ assert out.shape[-1] == 4 and out.shape[0] == 2
+
+
+def test_deepnorm_scaling(mm1):
+ """Verify DeepNorm alpha/beta are read from config."""
+ fd = _make_fd_config()
+ fd.model_config.layernorm_linear_attention_alpha = 2.0
+ fd.model_config.layernorm_mlp_alpha = 3.0
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ assert layer.layernorm_attention_alpha == 2.0
+ assert layer.layernorm_mlp_alpha == 3.0
+
+
+def test_model_forward(mm1):
+ """MiniMaxM1Model forward: exercises embed -> layers -> norm chain."""
+ fd = _make_fd_config(hidden_size=4, num_layers=2)
+ model = mm1.MiniMaxM1Model(fd_config=fd)
+ ids = paddle.to_tensor([0, 1, 2], dtype="int64")
+ meta = SimpleNamespace()
+ out = model(ids_remove_padding=ids, forward_meta=meta)
+ assert out.shape[-1] == 4 and out.shape[0] == 3
+
+
+def test_model_load_state_dict(mm1):
+ """Verify load_state_dict delegates to all sublayers."""
+ fd = _make_fd_config(hidden_size=4, num_layers=2)
+ model = mm1.MiniMaxM1Model(fd_config=fd)
+ model.load_state_dict({"w": np.zeros([1], dtype=np.float32)})
+ assert model.embed_tokens.load_state_dict_called
+ assert model.norm.load_state_dict_called
+ for layer in model.layers:
+ assert layer.input_layernorm.load_state_dict_called
+
+
+def test_causallm_forward_and_compute_logits(mm1):
+ """CausalLM forward + compute_logits: exercises the top-level model."""
+ fd = _make_fd_config(hidden_size=4, num_layers=1)
+ model = mm1.MiniMaxM1ForCausalLM(fd)
+
+ ids = paddle.to_tensor([0, 1], dtype="int64")
+ meta = SimpleNamespace()
+ hidden = model(inputs={"ids_remove_padding": ids}, forward_meta=meta)
+ assert hidden.shape[-1] == 4 and hidden.shape[0] == 2
+
+ logits = model.compute_logits(hidden.astype("float16"), meta)
+ assert logits.dtype == paddle.float32
+
+
+def test_causallm_name(mm1):
+ """CausalLM.name() returns expected value."""
+ assert mm1.MiniMaxM1ForCausalLM.name() == "MiniMaxM1ForCausalLM"
+
+
+# ===================================================================
+# 5. set_state_dict — HF->FD weight remapping
+# ===================================================================
+
+
+def test_set_state_dict_expert_remap(mm1):
+ """set_state_dict remaps MoE expert weights: w1->gate_proj, w2->down_proj, w3->up_proj."""
+ fd = _make_fd_config(hidden_size=4, num_layers=1)
+ model = mm1.MiniMaxM1ForCausalLM(fd)
+
+ captured = {}
+ model.model.load_state_dict = lambda sd: captured.update(sd)
+ model.lm_head.load_state_dict = lambda sd: None
+
+ sd = {
+ "model.layers.0.block_sparse_moe.experts.0.w1.weight": np.ones([2, 4], dtype=np.float32),
+ "model.layers.0.block_sparse_moe.experts.0.w2.weight": np.ones([4, 2], dtype=np.float32),
+ "model.layers.0.block_sparse_moe.experts.0.w3.weight": np.ones([2, 4], dtype=np.float32),
+ }
+ model.set_state_dict(sd)
+
+ assert "model.layers.0.block_sparse_moe.experts.0.gate_proj.weight" in captured
+ assert "model.layers.0.block_sparse_moe.experts.0.down_proj.weight" in captured
+ assert "model.layers.0.block_sparse_moe.experts.0.up_proj.weight" in captured
+
+
+def test_set_state_dict_qkv_merge(mm1):
+ """set_state_dict merges q/k/v into qkv_proj for full attention layers."""
+ fd = _make_fd_config(hidden_size=4, num_layers=2, num_attention_heads=4, num_key_value_heads=2, head_dim=2)
+ model = mm1.MiniMaxM1ForCausalLM(fd)
+
+ captured = {}
+ model.model.load_state_dict = lambda sd: captured.update(sd)
+ model.lm_head.load_state_dict = lambda sd: None
+
+ # Layer 1 is full attention (attn_type_list=[0,1])
+ q_w = np.arange(16, dtype=np.float32).reshape(4, 4) # [num_heads * head_dim, hidden]
+ k_w = np.arange(8, dtype=np.float32).reshape(2, 4) # [num_kv_heads * head_dim, hidden]
+ v_w = np.arange(8, dtype=np.float32).reshape(2, 4)
+ sd = {
+ "model.layers.1.self_attn.q_proj.weight": q_w,
+ "model.layers.1.self_attn.k_proj.weight": k_w,
+ "model.layers.1.self_attn.v_proj.weight": v_w,
+ }
+ model.set_state_dict(sd)
+
+ merged_key = "model.layers.1.self_attn.qkv_proj.weight"
+ assert merged_key in captured
+ expected = np.concatenate([q_w, k_w, v_w], axis=0)
+ np.testing.assert_array_equal(captured[merged_key], expected)
+
+
+def test_set_state_dict_passthrough(mm1):
+ """Non-expert, non-qkv weights pass through unchanged."""
+ fd = _make_fd_config(hidden_size=4, num_layers=1)
+ model = mm1.MiniMaxM1ForCausalLM(fd)
+
+ captured = {}
+ model.model.load_state_dict = lambda sd: captured.update(sd)
+ model.lm_head.load_state_dict = lambda sd: None
+
+ sd = {"model.norm.weight": np.ones([4], dtype=np.float32)}
+ model.set_state_dict(sd)
+ assert "model.norm.weight" in captured
+
+
+def test_set_state_dict_qkv_paddle_tensors(mm1):
+ """QKV merge works with Paddle tensors (not just numpy)."""
+ fd = _make_fd_config(hidden_size=4, num_layers=2, num_attention_heads=4, num_key_value_heads=2, head_dim=2)
+ model = mm1.MiniMaxM1ForCausalLM(fd)
+
+ captured = {}
+ model.model.load_state_dict = lambda sd: captured.update(sd)
+ model.lm_head.load_state_dict = lambda sd: None
+
+ q_w = paddle.arange(16, dtype="float32").reshape([4, 4])
+ k_w = paddle.arange(8, dtype="float32").reshape([2, 4])
+ v_w = paddle.arange(8, dtype="float32").reshape([2, 4])
+ sd = {
+ "model.layers.1.self_attn.q_proj.weight": q_w,
+ "model.layers.1.self_attn.k_proj.weight": k_w,
+ "model.layers.1.self_attn.v_proj.weight": v_w,
+ }
+ model.set_state_dict(sd)
+
+ merged = captured["model.layers.1.self_attn.qkv_proj.weight"]
+ assert isinstance(merged, paddle.Tensor)
+ assert merged.shape == [8, 4]
+
+
+# ===================================================================
+# 6. MoE weight key map construction
+# ===================================================================
+
+
+def test_moe_default_weight_keys(mm1):
+ """Unquantized MoE: weight_key_map has plain .weight keys."""
+ fd = _make_fd_config(num_local_experts=4)
+ moe = mm1.MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe")
+ wkm = moe.experts.weight_key_map
+ assert "gate_weight_key" in wkm
+ assert wkm["up_gate_proj_expert_weight_key"].endswith(".up_gate_proj.weight")
+ assert "weight_scale" not in str(wkm)
+
+
+def test_moe_w4a8_weight_keys(mm1):
+ """w4a8 quant: weight_key_map has .quant_weight + scales."""
+ fd = _make_fd_config(num_local_experts=4)
+ fd.quant_config = SimpleNamespace(moe_quant_type="w4a8")
+ fd.model_config.is_quantized = True
+ moe = mm1.MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe")
+ wkm = moe.experts.weight_key_map
+ assert "quant_weight" in wkm["up_gate_proj_expert_weight_key"]
+ assert "weight_scale" in wkm["up_gate_proj_expert_weight_scale_key"]
+ assert "activation_scale" in wkm["up_gate_proj_expert_in_scale_key"]
+
+
+def test_moe_w4afp8_dynamic_weight_keys(mm1):
+ """Dynamic w4afp8: quant_weight + weight_scale but no activation_scale."""
+ fd = _make_fd_config(num_local_experts=4)
+ fd.quant_config = SimpleNamespace(moe_quant_type="w4afp8", moe_dynamic_quant=True)
+ fd.model_config.is_quantized = True
+ moe = mm1.MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe")
+ wkm = moe.experts.weight_key_map
+ assert "quant_weight" in wkm["up_gate_proj_expert_weight_key"]
+ assert "weight_scale" in wkm["up_gate_proj_expert_weight_scale_key"]
+ assert "in_scale_key" not in str(wkm)
+
+
+def test_moe_tp_all_reduce(mm1):
+ """MoE with tp_size > 1 sets the attribute."""
+ fd = _make_fd_config(num_local_experts=4)
+ fd.parallel_config.tensor_parallel_size = 2
+ moe = mm1.MiniMaxM1MoE(fd, layer_id=0, prefix="model.layers.0.block_sparse_moe")
+ assert moe.tp_size == 2
+
+
+# ===================================================================
+# 7. Linear attention construction and forward
+# ===================================================================
+
+
+def test_linear_attention_slope_rate_shape(mm1):
+ fd = _make_fd_config(num_layers=2, num_attention_heads=4, head_dim=2)
+ layer = mm1.MiniMaxM1LinearAttention(fd, layer_id=0, linear_layer_id=0, prefix="model.layers.0.self_attn")
+ assert layer.slope_rate.shape == [4, 1, 1]
+ assert (layer.slope_rate.flatten().numpy() > 0).all()
+
+
+def test_linear_attention_kv_cache_shape(mm1):
+ fd = _make_fd_config(num_attention_heads=4, head_dim=2)
+ layer = mm1.MiniMaxM1LinearAttention(fd, layer_id=0, linear_layer_id=0, prefix="model.layers.0.self_attn")
+ assert layer.kv_cache_shape == (4, 2, 2)
+
+
+def test_linear_attention_forward(mm1):
+ fd = _make_fd_config(hidden_size=4, num_attention_heads=4, head_dim=1)
+ layer = mm1.MiniMaxM1LinearAttention(fd, layer_id=0, linear_layer_id=0, prefix="model.layers.0.self_attn")
+ meta = SimpleNamespace()
+ h = paddle.randn([1, 4])
+ out = layer(forward_meta=meta, hidden_states=h)
+ # LinearAttention adds seq=1 dim internally via 4D reshape
+ assert out.shape[-1] == 4 and out.shape[0] == 1
+
+
+def test_linear_attention_load_state_dict(mm1):
+ fd = _make_fd_config(num_attention_heads=4, head_dim=2)
+ layer = mm1.MiniMaxM1LinearAttention(fd, layer_id=0, linear_layer_id=0, prefix="model.layers.0.self_attn")
+ sd = {"w": np.zeros([1], dtype=np.float32)}
+ layer.load_state_dict(sd)
+ assert layer.qkv_proj.load_state_dict_called
+ assert layer.output_gate.load_state_dict_called
+ assert layer.out_proj.load_state_dict_called
+ assert layer.norm.load_state_dict_called
+
+
+# ===================================================================
+# 8. Full attention
+# ===================================================================
+
+
+def test_full_attention_forward(mm1):
+ fd = _make_fd_config(hidden_size=4, num_attention_heads=4, num_key_value_heads=2, head_dim=2)
+ layer = mm1.MiniMaxM1Attention(fd, layer_id=1, prefix="model.layers.1.self_attn")
+ meta = SimpleNamespace()
+ h = paddle.randn([2, 4])
+ out = layer(forward_meta=meta, hidden_states=h)
+ assert out.shape[-1] == 4 and out.shape[0] == 2
+
+
+def test_full_attention_load_state_dict(mm1):
+ fd = _make_fd_config(num_attention_heads=4, num_key_value_heads=2, head_dim=2)
+ layer = mm1.MiniMaxM1Attention(fd, layer_id=1, prefix="model.layers.1.self_attn")
+ layer.load_state_dict({"w": np.zeros([1], dtype=np.float32)})
+ assert layer.qkv_proj.load_state_dict_called
+ assert layer.o_proj.load_state_dict_called
+ assert layer.attn.load_state_dict_called
+
+
+# ===================================================================
+# 9. MLP
+# ===================================================================
+
+
+def test_mlp_forward(mm1):
+ fd = _make_fd_config(num_local_experts=1)
+ mlp = mm1.MiniMaxM1MLP(fd, intermediate_size=8, prefix="model.layers.0.mlp")
+ h = paddle.randn([2, 4])
+ out = mlp.forward(h)
+ assert out.shape == [2, 4]
+
+
+def test_mlp_load_state_dict(mm1):
+ fd = _make_fd_config()
+ mlp = mm1.MiniMaxM1MLP(fd, intermediate_size=8, prefix="model.layers.0.mlp")
+ mlp.load_state_dict({"w": np.zeros([1], dtype=np.float32)})
+ assert mlp.gate_up_proj.load_state_dict_called
+ assert mlp.down_proj.load_state_dict_called
+
+
+# ===================================================================
+# 10. Lightning Attention — Pure-Python reference algorithm
+# ===================================================================
+
+
+def _lightning_attention_numpy_ref(q, k, v, slope, kv_history=None):
+ """
+ Pure NumPy reference implementation of linear attention with exponential decay.
+ """
+ b, h, n, d = q.shape
+ e = v.shape[-1]
+ output = np.zeros((b, h, n, e), dtype=np.float64)
+
+ if kv_history is None:
+ kv_state = np.zeros((b, h, d, e), dtype=np.float64)
+ else:
+ kv_state = kv_history.copy()
+
+ for t in range(n):
+ decay = np.exp(-slope)[np.newaxis, :, np.newaxis, np.newaxis]
+ kv_state = kv_state * decay
+ kt = k[:, :, t, :]
+ vt = v[:, :, t, :]
+ kv_state += kt[:, :, :, np.newaxis] * vt[:, :, np.newaxis, :]
+ qt = q[:, :, t, :]
+ output[:, :, t, :] = np.einsum("bhd,bhde->bhe", qt, kv_state)
+
+ return output, kv_state
+
+
+class TestLightningAttentionPurePython:
+ """Validate Lightning Attention algorithm correctness via NumPy reference."""
+
+ def test_single_token_output_shape(self):
+ b, h, n, d = 1, 4, 1, 16
+ q = np.random.randn(b, h, n, d)
+ k = np.random.randn(b, h, n, d)
+ v = np.random.randn(b, h, n, d)
+ slope = np.abs(np.random.randn(h)) * 0.1
+ output, kv = _lightning_attention_numpy_ref(q, k, v, slope)
+ assert output.shape == (b, h, n, d)
+ assert kv.shape == (b, h, d, d)
+
+ def test_multi_token_causal(self):
+ """With slope approaching 0, approaches causal linear attention."""
+ b, h, n, d = 1, 2, 4, 8
+ np.random.seed(42)
+ q = np.random.randn(b, h, n, d)
+ k = np.random.randn(b, h, n, d)
+ v = np.random.randn(b, h, n, d)
+ slope = np.full(h, 1e-8)
+ output, _ = _lightning_attention_numpy_ref(q, k, v, slope)
+
+ for t in range(n):
+ ref = np.zeros((b, h, d))
+ for j in range(t + 1):
+ kv_outer = k[:, :, j, :, np.newaxis] * v[:, :, j, np.newaxis, :]
+ ref += np.einsum("bhd,bhde->bhe", q[:, :, t, :], kv_outer)
+ np.testing.assert_allclose(output[:, :, t, :], ref, rtol=1e-5, atol=1e-7)
+
+ def test_kv_history_persistence(self):
+ """KV state from one call persists to the next (recurrent property)."""
+ b, h, n, d = 2, 4, 3, 16
+ np.random.seed(123)
+ q1 = np.random.randn(b, h, n, d)
+ k1 = np.random.randn(b, h, n, d)
+ v1 = np.random.randn(b, h, n, d)
+ q2 = np.random.randn(b, h, 1, d)
+ k2 = np.random.randn(b, h, 1, d)
+ v2 = np.random.randn(b, h, 1, d)
+ slope = np.abs(np.random.randn(h)) * 0.05
+ _, kv_after_1 = _lightning_attention_numpy_ref(q1, k1, v1, slope)
+ out2, _ = _lightning_attention_numpy_ref(q2, k2, v2, slope, kv_history=kv_after_1)
+ q_full = np.concatenate([q1, q2], axis=2)
+ k_full = np.concatenate([k1, k2], axis=2)
+ v_full = np.concatenate([v1, v2], axis=2)
+ out_full, _ = _lightning_attention_numpy_ref(q_full, k_full, v_full, slope)
+ np.testing.assert_allclose(out2[:, :, 0, :], out_full[:, :, n, :], rtol=1e-5, atol=1e-7)
+
+ def test_multi_head_independent(self):
+ """Heads are computed independently - zeroing one head Q zeros its output."""
+ b, h, n, d = 1, 8, 4, 16
+ np.random.seed(7)
+ q = np.random.randn(b, h, n, d)
+ k = np.random.randn(b, h, n, d)
+ v = np.random.randn(b, h, n, d)
+ slope = np.abs(np.random.randn(h)) * 0.1
+ q_masked = q.copy()
+ q_masked[:, 3, :, :] = 0.0
+ output, _ = _lightning_attention_numpy_ref(q_masked, k, v, slope)
+ np.testing.assert_allclose(output[:, 3, :, :], 0.0, atol=1e-12)
+
+
+def test_multi_layer_residual_no_blowup(mm1):
+ """Regression: multi-layer forward must not cause residual blowup.
+
+ C1 fix: DeepNorm folds the residual into hidden_states, so the layer
+ returns ``(hidden_states, None)`` — not ``(hidden_states, residual)``.
+ If the old behaviour returns a non-None residual, the next iteration
+ adds it again → exponential growth. This test stacks 4 layers and
+ checks the output norm stays bounded.
+ """
+ fd = _make_fd_config(hidden_size=4, num_layers=4)
+ model = mm1.MiniMaxM1Model(fd_config=fd)
+ ids = paddle.to_tensor([0, 1, 2, 3], dtype="int64")
+ meta = SimpleNamespace()
+ out = model(ids_remove_padding=ids, forward_meta=meta)
+ # With correct residual handling, output magnitude should stay O(1)
+ # relative to the stub operations (identity-ish norms, zero-init attn).
+ # With the old double-counting bug, 4 layers would amplify ~16x.
+ assert paddle.isfinite(out).all(), "Output contains NaN/Inf — residual blowup"
+ assert out.abs().max().item() < 1e4, (
+ f"Output magnitude {out.abs().max().item():.1f} too large — "
+ "possible residual double-counting (C1 regression)"
+ )
+
+
+def test_decoder_layer_returns_none_residual(mm1):
+ """DecoderLayer must return None as residual (DeepNorm convention)."""
+ fd = _make_fd_config()
+ layer = mm1.MiniMaxM1DecoderLayer(fd, layer_id=0, prefix="model.layers.0")
+ meta = SimpleNamespace()
+ h = paddle.randn([2, 4])
+ out, residual = layer(forward_meta=meta, hidden_states=h)
+ assert residual is None, f"Expected None residual (DeepNorm folds it into hidden_states), got {type(residual)}"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/model_executor/test_minimax_m1_integration.py b/tests/model_executor/test_minimax_m1_integration.py
new file mode 100644
index 00000000000..efd7df2deec
--- /dev/null
+++ b/tests/model_executor/test_minimax_m1_integration.py
@@ -0,0 +1,527 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for Hackathon 10th Spring No.47.
+Integration tests for MiniMax-M1 model with FastDeploy infrastructure.
+
+Proves that our model code works through FD's real pipelines:
+- Package imports (all public symbols accessible)
+- ModelRegistry resolution (both architecture names)
+- FDConfig construction from config.json
+- Weight key remapping (HF → FD) through load_weights iterator path
+- End-to-end forward pass with real (tiny) weights on GPU
+
+CPU-tier tests run in CI (no GPU). GPU-tier tests run on AI Studio A800.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import unittest
+from types import SimpleNamespace
+
+import numpy as np
+import paddle
+import pytest
+
+# ---------------------------------------------------------------------------
+# Tiny model config — production-faithful structure, minimal dimensions
+# ---------------------------------------------------------------------------
+
+_TINY_MODEL_CONFIG = {
+ "architectures": ["MiniMaxM1ForCausalLM"],
+ "model_type": "MiniMaxM1",
+ "hidden_size": 128,
+ "intermediate_size": 256,
+ "num_hidden_layers": 4,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 4,
+ "head_dim": 32,
+ "vocab_size": 256,
+ "max_position_embeddings": 512,
+ "rms_norm_eps": 1e-5,
+ "num_local_experts": 2,
+ "num_experts_per_tok": 1,
+ "rope_theta": 10000.0,
+ "torch_dtype": "bfloat16",
+ "full_attention_layer_indices": [1, 3],
+ "attn_type_list": [0, 1, 0, 1], # linear, full, linear, full
+ "use_deep_norm": True,
+ "num_layers_for_deep_norm": 4,
+ "use_post_norm": True,
+ "hidden_act": "silu",
+ "norm_topk_prob": False,
+ "postnorm": False,
+}
+
+
+def _make_fd_config(**model_overrides):
+ """Build a minimal FDConfig-like namespace for CPU tests."""
+ mc_dict = dict(_TINY_MODEL_CONFIG)
+ mc_dict.update(model_overrides)
+ mc_dict["pretrained_config"] = SimpleNamespace(prefix_name="model")
+ mc = SimpleNamespace(**mc_dict)
+ pc = SimpleNamespace(tensor_parallel_size=1, tensor_parallel_rank=0, tp_group=None)
+ gc = SimpleNamespace(graph_opt_level=0, use_cudagraph=False)
+ return SimpleNamespace(
+ model_config=mc,
+ parallel_config=pc,
+ graph_opt_config=gc,
+ )
+
+
+def _write_config_json(tmp_dir, overrides=None):
+ """Write a minimal config.json that mimics real MiniMax-M1 HF layout."""
+ cfg = dict(_TINY_MODEL_CONFIG)
+ if overrides:
+ cfg.update(overrides)
+ config_path = os.path.join(tmp_dir, "config.json")
+ with open(config_path, "w") as f:
+ json.dump(cfg, f)
+ return config_path
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# Tier 1 — CPU tests (run in CI)
+# ═══════════════════════════════════════════════════════════════════════════
+
+
+class TestPackageImports:
+ """Prove all public MiniMax-M1 symbols are importable from FD."""
+
+ def test_import_model_module(self):
+ from fastdeploy.model_executor.models import minimax_m1
+
+ assert hasattr(minimax_m1, "MiniMaxM1ForCausalLM")
+
+ def test_import_causal_lm(self):
+ from fastdeploy.model_executor.models.minimax_m1 import MiniMaxM1ForCausalLM
+
+ assert MiniMaxM1ForCausalLM is not None
+
+ def test_import_pretrained_model(self):
+ from fastdeploy.model_executor.models.minimax_m1 import MiniMaxM1PretrainedModel
+
+ assert MiniMaxM1PretrainedModel is not None
+
+ def test_import_all_classes(self):
+ from fastdeploy.model_executor.models.minimax_m1 import (
+ MiniMaxM1Attention,
+ MiniMaxM1DecoderLayer,
+ MiniMaxM1ForCausalLM,
+ MiniMaxM1LinearAttention,
+ MiniMaxM1MLP,
+ MiniMaxM1Model,
+ MiniMaxM1MoE,
+ MiniMaxM1PretrainedModel,
+ )
+
+ classes = [
+ MiniMaxM1MLP,
+ MiniMaxM1MoE,
+ MiniMaxM1Attention,
+ MiniMaxM1LinearAttention,
+ MiniMaxM1DecoderLayer,
+ MiniMaxM1Model,
+ MiniMaxM1ForCausalLM,
+ MiniMaxM1PretrainedModel,
+ ]
+ for cls in classes:
+ assert callable(cls), f"{cls.__name__} should be callable"
+
+ def test_lightning_attention_importable(self):
+ from fastdeploy.model_executor.ops.triton_ops import lightning_attn
+
+ assert hasattr(lightning_attn, "lightning_attention")
+
+
+class TestModelRegistryResolution:
+ """Prove ModelRegistry resolves MiniMax-M1 by both architecture names."""
+
+ def test_primary_arch_resolves(self):
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ cls = ModelRegistry.get_class("MiniMaxM1ForCausalLM")
+ assert cls.__name__ == "MiniMaxM1ForCausalLM"
+
+ def test_alias_arch_resolves(self):
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ cls = ModelRegistry.get_class("MiniMaxText01ForCausalLM")
+ assert cls.__name__ == "MiniMaxM1ForCausalLM"
+
+ def test_both_resolve_to_same_class(self):
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ primary = ModelRegistry.get_class("MiniMaxM1ForCausalLM")
+ alias = ModelRegistry.get_class("MiniMaxText01ForCausalLM")
+ assert primary is alias
+
+ def test_in_supported_archs(self):
+ from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+ supported = ModelRegistry.get_supported_archs()
+ assert "MiniMaxM1ForCausalLM" in supported
+ assert "MiniMaxText01ForCausalLM" in supported
+
+
+class TestHFToFDWeightKeyMapping:
+ """Prove the HF→FD weight key remapping pipeline works correctly.
+
+ Tests set_state_dict (v2 path) with real numpy arrays — verifying that
+ HF checkpoint key conventions are correctly transformed to FD conventions.
+ This is the most common source of integration bugs.
+ """
+
+ @pytest.fixture
+ def tiny_model(self, monkeypatch):
+ """Build a MiniMaxM1ForCausalLM with minimal stubs for weight loading."""
+ from fastdeploy.model_executor.models import minimax_m1
+
+ # Lightweight stubs that track load_state_dict calls
+ class _TrackingLayer(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.loaded_keys = []
+
+ def forward(self, x, *a, **kw):
+ return x
+
+ def load_state_dict(self, sd):
+ self.loaded_keys.extend(sd.keys())
+
+ class _TrackingLinear(_TrackingLayer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self._out = kw.get("output_size", 128)
+
+ def forward(self, x, *a, **kw):
+ shape = list(x.shape)
+ shape[-1] = self._out
+ return paddle.zeros(shape, dtype=x.dtype)
+
+ class _TrackingNorm(_TrackingLayer):
+ def forward(self, x, residual_input=None, forward_meta=None):
+ if residual_input is None:
+ residual_input = paddle.zeros_like(x)
+ return x, residual_input + x
+
+ class _TrackingMoE(_TrackingLayer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.loaded_keys = []
+ self.weight_key_map = kw.get("weight_key_map", {})
+
+ def forward(self, hidden_states, gate, forward_meta=None):
+ return hidden_states
+
+ class _TrackingAttn(_TrackingLayer):
+ def forward(self, q, k, v, forward_meta=None):
+ return q
+
+ class _TrackingEmbed(_TrackingLayer):
+ def forward(self, x, *a, **kw):
+ return paddle.zeros([x.shape[0], 128], dtype="float32")
+
+ class _TrackingLMHead(_TrackingLayer):
+ def forward(self, x, *a, **kw):
+ return paddle.zeros([x.shape[0], 256], dtype="float32")
+
+ # Patch constructors
+ monkeypatch.setattr(minimax_m1, "RMSNorm", _TrackingNorm)
+ monkeypatch.setattr(minimax_m1, "ColumnParallelLinear", _TrackingLinear)
+ monkeypatch.setattr(minimax_m1, "RowParallelLinear", _TrackingLinear)
+ monkeypatch.setattr(minimax_m1, "MergedColumnParallelLinear", _TrackingLinear)
+ monkeypatch.setattr(minimax_m1, "QKVParallelLinear", _TrackingLinear)
+ monkeypatch.setattr(minimax_m1, "ReplicatedLinear", _TrackingLinear)
+ monkeypatch.setattr(minimax_m1, "Attention", _TrackingAttn)
+ monkeypatch.setattr(minimax_m1, "FusedMoE", _TrackingMoE)
+ monkeypatch.setattr(minimax_m1, "VocabParallelEmbedding", _TrackingEmbed)
+ monkeypatch.setattr(minimax_m1, "ParallelLMHead", _TrackingLMHead)
+ monkeypatch.setattr(minimax_m1, "SiluAndMul", lambda *a, **kw: (lambda x: x[..., : x.shape[-1] // 2]))
+ monkeypatch.setattr(minimax_m1, "lightning_attention", lambda *a, **kw: (a[0], paddle.zeros([1])))
+ monkeypatch.setattr(minimax_m1, "tensor_model_parallel_all_reduce", lambda x: x)
+ monkeypatch.setattr(minimax_m1, "support_graph_optimization", lambda *a, **kw: (lambda fn: fn))
+
+ cfg = _make_fd_config()
+ model = minimax_m1.MiniMaxM1ForCausalLM(cfg)
+ return model
+
+ def test_expert_w1_w2_w3_renamed(self, tiny_model):
+ """HF w1→gate_proj, w3→up_proj, w2→down_proj in MoE experts."""
+ sd = {}
+ # Layer 0 = linear attention layer (not in full_attention_layer_indices [1,3])
+ # MoE layer
+ sd["model.layers.0.block_sparse_moe.experts.0.w1.weight"] = np.ones((256, 128), dtype=np.float32)
+ sd["model.layers.0.block_sparse_moe.experts.0.w2.weight"] = np.ones((128, 256), dtype=np.float32)
+ sd["model.layers.0.block_sparse_moe.experts.0.w3.weight"] = np.ones((256, 128), dtype=np.float32)
+
+ tiny_model.set_state_dict(sd)
+
+ # Verify renamed keys were passed to MoE sublayer's experts
+ moe = tiny_model.model.layers[0].block_sparse_moe
+ # MiniMaxM1MoE.load_state_dict dispatches to self.gate and self.experts
+ expert_keys = moe.experts.loaded_keys
+ assert any("gate_proj" in k for k in expert_keys), f"Expected gate_proj, got {expert_keys}"
+ assert any("down_proj" in k for k in expert_keys), f"Expected down_proj, got {expert_keys}"
+ assert any("up_proj" in k for k in expert_keys), f"Expected up_proj, got {expert_keys}"
+
+ def test_qkv_merge_for_full_attention_layers(self, tiny_model):
+ """Full attention layers merge separate q/k/v → qkv_proj."""
+ sd = {}
+ # Layer 1 is a full attention layer (index 1 in full_attention_layer_indices)
+ sd["model.layers.1.self_attn.q_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd["model.layers.1.self_attn.k_proj.weight"] = np.ones((128, 128), dtype=np.float32) * 2
+ sd["model.layers.1.self_attn.v_proj.weight"] = np.ones((128, 128), dtype=np.float32) * 3
+
+ tiny_model.set_state_dict(sd)
+
+ attn = tiny_model.model.layers[1].self_attn
+ assert any(
+ "qkv_proj" in k for k in attn.qkv_proj.loaded_keys
+ ), f"Expected qkv_proj merge, got {attn.qkv_proj.loaded_keys}"
+
+ def test_norm_and_embed_passthrough(self, tiny_model):
+ """Non-expert, non-attention keys pass through unchanged."""
+ sd = {}
+ sd["model.embed_tokens.weight"] = np.ones((256, 128), dtype=np.float32)
+ sd["model.norm.weight"] = np.ones(128, dtype=np.float32)
+
+ tiny_model.set_state_dict(sd)
+
+ embed = tiny_model.model.embed_tokens
+ assert len(embed.loaded_keys) > 0, "embed_tokens should receive weights"
+
+ def test_all_layer_types_receive_weights(self, tiny_model):
+ """Build a full HF-style state dict and verify every layer gets called."""
+ sd = {}
+ for i in range(4):
+ # Input norm
+ sd[f"model.layers.{i}.input_layernorm.weight"] = np.ones(128, dtype=np.float32)
+ sd[f"model.layers.{i}.post_attention_layernorm.weight"] = np.ones(128, dtype=np.float32)
+
+ if i in [1, 3]: # full attention
+ sd[f"model.layers.{i}.self_attn.q_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.self_attn.k_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.self_attn.v_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.self_attn.o_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ else: # linear attention
+ sd[f"model.layers.{i}.self_attn.q_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.self_attn.k_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.self_attn.v_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.self_attn.out_proj.weight"] = np.ones((128, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.self_attn.output_gate.weight"] = np.ones((128, 128), dtype=np.float32)
+
+ # MoE
+ for e in range(2):
+ sd[f"model.layers.{i}.block_sparse_moe.experts.{e}.w1.weight"] = np.ones((256, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.block_sparse_moe.experts.{e}.w2.weight"] = np.ones((128, 256), dtype=np.float32)
+ sd[f"model.layers.{i}.block_sparse_moe.experts.{e}.w3.weight"] = np.ones((256, 128), dtype=np.float32)
+ sd[f"model.layers.{i}.block_sparse_moe.gate.weight"] = np.ones((2, 128), dtype=np.float32)
+
+ sd["model.embed_tokens.weight"] = np.ones((256, 128), dtype=np.float32)
+ sd["model.norm.weight"] = np.ones(128, dtype=np.float32)
+ sd["lm_head.weight"] = np.ones((256, 128), dtype=np.float32)
+
+ tiny_model.set_state_dict(sd)
+
+ # Verify embed, model norm, and lm_head all got weights
+ assert len(tiny_model.model.embed_tokens.loaded_keys) > 0
+ assert len(tiny_model.lm_head.loaded_keys) > 0
+
+
+class TestModelConstruction:
+ """Prove MiniMaxM1ForCausalLM constructs correctly with right layer types."""
+
+ @pytest.fixture
+ def model(self, monkeypatch):
+ """Build model with stubs to verify construction on CPU."""
+ from fastdeploy.model_executor.models import minimax_m1
+
+ class _Stub(paddle.nn.Layer):
+ def __init__(self, *a, **kw):
+ super().__init__()
+
+ def forward(self, *a, **kw):
+ return a[0] if a else paddle.zeros([1])
+
+ def load_state_dict(self, _sd):
+ pass
+
+ class _StubNorm(_Stub):
+ def forward(self, x, residual_input=None, forward_meta=None):
+ r = residual_input if residual_input is not None else paddle.zeros_like(x)
+ return x, r + x
+
+ class _StubAttn(_Stub):
+ def forward(self, q, k, v, forward_meta=None):
+ return q
+
+ class _StubMoE(_Stub):
+ def __init__(self, *a, **kw):
+ super().__init__()
+ self.weight_key_map = kw.get("weight_key_map", {})
+
+ def forward(self, hidden_states, gate, forward_meta=None):
+ return hidden_states
+
+ monkeypatch.setattr(minimax_m1, "RMSNorm", _StubNorm)
+ monkeypatch.setattr(minimax_m1, "ColumnParallelLinear", _Stub)
+ monkeypatch.setattr(minimax_m1, "RowParallelLinear", _Stub)
+ monkeypatch.setattr(minimax_m1, "MergedColumnParallelLinear", _Stub)
+ monkeypatch.setattr(minimax_m1, "QKVParallelLinear", _Stub)
+ monkeypatch.setattr(minimax_m1, "ReplicatedLinear", _Stub)
+ monkeypatch.setattr(minimax_m1, "Attention", _StubAttn)
+ monkeypatch.setattr(minimax_m1, "FusedMoE", _StubMoE)
+ monkeypatch.setattr(minimax_m1, "VocabParallelEmbedding", _Stub)
+ monkeypatch.setattr(minimax_m1, "ParallelLMHead", _Stub)
+ monkeypatch.setattr(minimax_m1, "SiluAndMul", lambda *a, **kw: (lambda x: x[..., : x.shape[-1] // 2]))
+ monkeypatch.setattr(minimax_m1, "lightning_attention", lambda *a, **kw: (a[0], paddle.zeros([1])))
+ monkeypatch.setattr(minimax_m1, "tensor_model_parallel_all_reduce", lambda x: x)
+ monkeypatch.setattr(minimax_m1, "support_graph_optimization", lambda *a, **kw: (lambda fn: fn))
+
+ cfg = _make_fd_config()
+ return minimax_m1.MiniMaxM1ForCausalLM(cfg)
+
+ def test_correct_number_of_layers(self, model):
+ assert len(model.model.layers) == 4
+
+ def test_full_attention_at_configured_indices(self, model):
+ """Full attention layers at indices [1, 3], linear at [0, 2]."""
+ from fastdeploy.model_executor.models.minimax_m1 import MiniMaxM1DecoderLayer
+
+ for i, layer in enumerate(model.model.layers):
+ assert isinstance(layer, MiniMaxM1DecoderLayer)
+ if i in [1, 3]:
+ assert layer.attention_type == 1, f"Layer {i} should be full attention (1), got {layer.attention_type}"
+ else:
+ assert (
+ layer.attention_type == 0
+ ), f"Layer {i} should be linear attention (0), got {layer.attention_type}"
+
+ def test_model_name_method(self, model):
+ assert model.name() == "MiniMaxM1ForCausalLM"
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# Tier 2 — GPU integration tests (run on AI Studio A800 via SSH)
+# See also: tests/model_executor/test_minimax_m1_smoke.py (kernel-level GPU tests)
+# See also: tests/operators/test_lightning_attn_triton.py (Triton kernel tests)
+# See also: tests/model_executor/validate_minimax_m1_e2e.py (E2E server test)
+# ═══════════════════════════════════════════════════════════════════════════
+
+_GPU_AVAILABLE = paddle.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0
+_GPU_SKIP_MSG = "No CUDA GPU available — GPU integration tests require A800/V100"
+
+
+@pytest.mark.gpu
+@unittest.skipUnless(_GPU_AVAILABLE, _GPU_SKIP_MSG)
+class TestModelWithRealTritonKernels(unittest.TestCase):
+ """Prove MiniMax-M1 model layers produce correct output via real Triton kernels.
+
+ Unlike test_minimax_m1_smoke.py (which tests kernels in isolation), this
+ tests through the actual MiniMaxM1LinearAttention and MiniMaxM1DecoderLayer
+ code paths — proving the model's forward() method correctly calls Triton ops.
+ """
+
+ def _build_slope(self, n_heads):
+ """Build ALiBi-style slope tensor (same as production code)."""
+ import math
+
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** (-(math.log2(n) - 3))))
+ return [start * (start**i) for i in range(n)]
+
+ if math.log2(n_heads).is_integer():
+ slopes = get_slopes_power_of_2(n_heads)
+ else:
+ nearest = 2 ** math.floor(math.log2(n_heads))
+ slopes = get_slopes_power_of_2(nearest) + get_slopes_power_of_2(2 * nearest)[0::2][: n_heads - nearest]
+ return paddle.to_tensor(slopes, dtype="float32").reshape([n_heads, 1, 1])
+
+ def test_linear_attention_layer_forward(self):
+ """lightning_attention() produces valid output via real Triton kernel."""
+ from fastdeploy.model_executor.ops.triton_ops.lightning_attn import (
+ lightning_attention,
+ )
+
+ B, H, S, D = 1, 8, 256, 128 # H=8, S=BLOCK, D>=128 for kernel
+
+ q = paddle.randn([B, H, S, D], dtype="float16")
+ k = paddle.randn([B, H, S, D], dtype="float16")
+ v = paddle.randn([B, H, S, D], dtype="float16")
+ ed = self._build_slope(H).squeeze(-1) # [H, 1] — wrapper reshapes
+
+ out, kv = lightning_attention(q, k, v, ed, block_size=256)
+
+ self.assertEqual(list(out.shape), [B, H, S, D])
+ self.assertFalse(paddle.isnan(out).any().item(), "Output contains NaN")
+ self.assertTrue(paddle.isfinite(out).all().item(), "Output contains Inf")
+ self.assertTrue(kv.abs().sum().item() > 0, "KV state is all zeros")
+
+ def test_decode_kernel_single_token(self):
+ """Decode kernel handles single-token autoregressive step."""
+ from fastdeploy.model_executor.ops.triton_ops.lightning_attn import (
+ linear_decode_forward_triton,
+ )
+
+ B, H, D = 2, 4, 128
+ q = paddle.randn([B, H, 1, D], dtype="float16")
+ k = paddle.randn([B, H, 1, D], dtype="float16")
+ v = paddle.randn([B, H, 1, D], dtype="float16")
+ kv_state = paddle.zeros([B, H, D, D], dtype="float32")
+ slope_rate = self._build_slope(H).squeeze(-1).squeeze(-1) # [H]
+ slot_idx = paddle.arange(B, dtype="int64")
+
+ out = linear_decode_forward_triton(q, k, v, kv_state, slope_rate, slot_idx)
+
+ # Output: [B, H*D] (heads flattened by kernel)
+ self.assertEqual(list(out.shape), [B, H * D])
+ self.assertFalse(paddle.isnan(out).any().item())
+
+ def test_two_step_decode_state_accumulates(self):
+ """Two decode steps via Triton: KV state should differ from fresh state."""
+ from fastdeploy.model_executor.ops.triton_ops.lightning_attn import (
+ linear_decode_forward_triton,
+ )
+
+ B, H, D = 1, 4, 128
+ kv_state = paddle.zeros([B, H, D, D], dtype="float32")
+ slope_rate = self._build_slope(H).squeeze(-1).squeeze(-1) # [H]
+ slot_idx = paddle.arange(B, dtype="int64")
+
+ # Step 1
+ q1 = paddle.randn([B, H, 1, D], dtype="float16")
+ k1 = paddle.randn([B, H, 1, D], dtype="float16")
+ v1 = paddle.randn([B, H, 1, D], dtype="float16")
+ _out1 = linear_decode_forward_triton(q1, k1, v1, kv_state, slope_rate, slot_idx) # noqa: F841
+
+ # KV state should be updated in-place
+ self.assertTrue(kv_state.abs().sum().item() > 0, "KV state not updated after step 1")
+
+ # Step 2 with different input
+ q2 = paddle.randn([B, H, 1, D], dtype="float16")
+ k2 = paddle.randn([B, H, 1, D], dtype="float16")
+ v2 = paddle.randn([B, H, 1, D], dtype="float16")
+ kv_before = kv_state.clone()
+ _out2 = linear_decode_forward_triton(q2, k2, v2, kv_state, slope_rate, slot_idx) # noqa: F841
+
+ # State should change between step 1 and step 2
+ state_changed = (kv_state - kv_before).abs().sum().item() > 0
+ self.assertTrue(state_changed, "KV state unchanged after step 2")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/model_executor/test_minimax_m1_smoke.py b/tests/model_executor/test_minimax_m1_smoke.py
new file mode 100644
index 00000000000..23878814d69
--- /dev/null
+++ b/tests/model_executor/test_minimax_m1_smoke.py
@@ -0,0 +1,342 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for Hackathon 10th Spring No.47.
+MiniMax-M1 integration smoke tests — real GPU kernels, no mocks.
+
+These tests exercise the production code paths used by MiniMaxM1LinearAttention:
+ 1. `lightning_attention()` — the chunked prefill wrapper that calls
+ `lightning_attention_forward()` in a loop over head-dim chunks.
+ 2. `linear_decode_forward_triton()` — the single-step decode kernel.
+ 3. `_build_slope_tensor()` — ALiBi-style decay tensor construction.
+ 4. End-to-end prefill → decode transition with KV state carry-over.
+
+All tests run on a single GPU without model weights or TP > 1.
+
+Validated on: AI Studio V100 (SM70), Paddle 3.3.0, Triton 3.x
+CI marker: @pytest.mark.gpu
+"""
+
+import math
+import unittest
+
+import numpy as np
+import paddle
+import pytest
+
+# ---------------------------------------------------------------------------
+# GPU guard
+# ---------------------------------------------------------------------------
+
+_GPU_AVAILABLE = paddle.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0
+_SKIP_MSG = "No CUDA GPU available — MiniMax-M1 smoke tests require GPU"
+
+
+def _import_ops():
+ """Lazy import to avoid collection failure on CPU-only boxes."""
+ from fastdeploy.model_executor.ops.triton_ops.lightning_attn import (
+ lightning_attention,
+ linear_decode_forward_triton,
+ )
+
+ return lightning_attention, linear_decode_forward_triton
+
+
+# ---------------------------------------------------------------------------
+# NumPy reference
+# ---------------------------------------------------------------------------
+
+
+def _lightning_attention_numpy_ref(q, k, v, slope, kv_history=None):
+ """
+ Pure NumPy reference for lightning attention with exponential decay.
+ Iterates over time steps — slow but correct.
+ """
+ b, h, n, d = q.shape
+ e = v.shape[-1]
+ output = np.zeros((b, h, n, e), dtype=np.float64)
+
+ if kv_history is None:
+ kv_state = np.zeros((b, h, d, e), dtype=np.float64)
+ else:
+ kv_state = kv_history.copy()
+
+ for t in range(n):
+ decay = np.exp(-slope)[np.newaxis, :, np.newaxis, np.newaxis]
+ kv_state = kv_state * decay
+ kt = k[:, :, t, :]
+ vt = v[:, :, t, :]
+ kv_state += kt[:, :, :, np.newaxis] * vt[:, :, np.newaxis, :]
+ qt = q[:, :, t, :]
+ output[:, :, t, :] = np.einsum("bhd,bhde->bhe", qt, kv_state)
+
+ return output, kv_state
+
+
+# ---------------------------------------------------------------------------
+# Slope tensor builder — copied from MiniMaxM1LinearAttention._build_slope_tensor
+# to test independently without FDConfig.
+# ---------------------------------------------------------------------------
+
+
+def _build_slope_tensor(n_heads):
+ """Build ALiBi-style slope tensor (matches production code exactly)."""
+
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** (-(math.log2(n) - 3))))
+ return [start * (start**i) for i in range(n)]
+
+ if math.log2(n_heads).is_integer():
+ slopes = get_slopes_power_of_2(n_heads)
+ else:
+ closest_power = 2 ** math.floor(math.log2(n_heads))
+ slopes = get_slopes_power_of_2(closest_power)
+ slopes += get_slopes_power_of_2(2 * closest_power)[0::2][: n_heads - closest_power]
+
+ return paddle.to_tensor(slopes, dtype=paddle.float32).reshape([n_heads, 1, 1])
+
+
+# ---------------------------------------------------------------------------
+# Test suite
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.gpu
+@unittest.skipUnless(_GPU_AVAILABLE, _SKIP_MSG)
+class TestMiniMaxM1Smoke(unittest.TestCase):
+ """
+ Integration smoke tests for MiniMax-M1 lightning attention pipeline.
+ Exercises the REAL Triton kernels on GPU — no stubs, no mocks.
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ paddle.set_device("gpu:0")
+ # Store as list to avoid Python descriptor binding (self would be
+ # passed as first arg if a bare function is set as class attribute).
+ la, df = _import_ops()
+ cls._ops = [la, df]
+
+ def _call_lightning_attention(self, *args, **kwargs):
+ return self._ops[0](*args, **kwargs)
+
+ def _call_decode_forward(self, *args, **kwargs):
+ return self._ops[1](*args, **kwargs)
+
+ # === 1. Lightning attention (chunked prefill wrapper) ==================
+
+ def test_lightning_attention_basic(self):
+ """
+ lightning_attention() with head_dim=128, the production dimension.
+ Verify output is finite, shape matches, and roughly agrees with reference.
+ """
+ b, h, n, d = 1, 8, 256, 128
+ rng = np.random.default_rng(42)
+
+ q_np = rng.standard_normal((b, h, n, d)).astype(np.float64) * 0.1
+ k_np = rng.standard_normal((b, h, n, d)).astype(np.float64) * 0.1
+ v_np = rng.standard_normal((b, h, n, d)).astype(np.float64) * 0.1
+
+ # Build slope as production does: [n_heads, 1, 1] → squeeze to [n_heads]
+ slope_full = _build_slope_tensor(h) # [h, 1, 1]
+ slope_np = slope_full.squeeze(-1).squeeze(-1).numpy().astype(np.float64)
+
+ # NumPy reference
+ ref_out, _ = _lightning_attention_numpy_ref(q_np, k_np, v_np, slope_np)
+
+ # GPU tensors
+ q = paddle.to_tensor(q_np.astype(np.float32), dtype="float16")
+ k = paddle.to_tensor(k_np.astype(np.float32), dtype="float16")
+ v = paddle.to_tensor(v_np.astype(np.float32), dtype="float16")
+ ed = slope_full.squeeze(-1) # [h, 1] — wrapper reshapes to [1, h, 1, 1]
+
+ out, kv = self._call_lightning_attention(q, k, v, ed, block_size=256)
+
+ self.assertEqual(list(out.shape), [b, h, n, d])
+ self.assertEqual(list(kv.shape), [b, h, d, d])
+ self.assertFalse(paddle.isnan(out).any().item(), "Output has NaN")
+ self.assertTrue(paddle.isfinite(out).all().item(), "Output has Inf")
+
+ # Tolerance: chunked approach + fp16 → generous but must be correlated
+ out_np = out.astype("float32").numpy()
+ cos_sim = np.sum(out_np * ref_out.astype(np.float32)) / (
+ np.linalg.norm(out_np) * np.linalg.norm(ref_out.astype(np.float32)) + 1e-12
+ )
+ self.assertGreater(cos_sim, 0.9, f"Cosine similarity {cos_sim:.4f} too low")
+
+ def test_lightning_attention_multi_batch(self):
+ """lightning_attention() with batch_size=2 and bfloat16."""
+ b, h, n, d = 2, 8, 256, 128
+
+ q = paddle.randn([b, h, n, d], dtype="bfloat16")
+ k = paddle.randn([b, h, n, d], dtype="bfloat16")
+ v = paddle.randn([b, h, n, d], dtype="bfloat16")
+ ed = _build_slope_tensor(h).squeeze(-1) # [h, 1]
+
+ out, kv = self._call_lightning_attention(q, k, v, ed, block_size=256)
+
+ self.assertEqual(list(out.shape), [b, h, n, d])
+ self.assertFalse(paddle.isnan(out).any().item())
+
+ def test_lightning_attention_kv_state_nonzero(self):
+ """After prefill, KV state should be non-zero (kernel populated it)."""
+ b, h, n, d = 1, 4, 256, 64
+
+ q = paddle.randn([b, h, n, d], dtype="float16")
+ k = paddle.randn([b, h, n, d], dtype="float16")
+ v = paddle.randn([b, h, n, d], dtype="float16")
+ ed = _build_slope_tensor(h).squeeze(-1)
+
+ _, kv = self._call_lightning_attention(q, k, v, ed, block_size=256)
+
+ kv_np = kv.numpy()
+ self.assertGreater(np.abs(kv_np).max(), 0.0, "KV state is all zeros after prefill")
+
+ # === 2. Linear decode forward (single-step autoregressive) =============
+
+ def test_decode_forward_basic(self):
+ """
+ linear_decode_forward_triton() — single-step decode path.
+ This is the kernel used during generation after prefill.
+ """
+ b, h, d = 2, 8, 128
+ q = paddle.randn([b, h, 1, d], dtype="float16")
+ k = paddle.randn([b, h, 1, d], dtype="float16")
+ v = paddle.randn([b, h, 1, d], dtype="float16")
+ kv_caches = paddle.zeros([b, h, d, d], dtype="float32")
+ slope_rate = _build_slope_tensor(h).squeeze(-1).squeeze(-1) # [h]
+ slot_idx = paddle.arange(b, dtype="int64")
+
+ out = self._call_decode_forward(q, k, v, kv_caches, slope_rate, slot_idx)
+
+ # Output: [B, H*D] (heads flattened)
+ self.assertEqual(list(out.shape), [b, h * d])
+ self.assertFalse(paddle.isnan(out).any().item(), "Decode output NaN")
+ self.assertTrue(paddle.isfinite(out).all().item(), "Decode output Inf")
+
+ def test_decode_updates_kv_cache(self):
+ """linear_decode_forward_triton should write to kv_caches in-place."""
+ b, h, d = 1, 4, 64
+ q = paddle.randn([b, h, 1, d], dtype="float16")
+ k = paddle.randn([b, h, 1, d], dtype="float16")
+ v = paddle.randn([b, h, 1, d], dtype="float16")
+ kv_caches = paddle.zeros([b, h, d, d], dtype="float32")
+ slope_rate = _build_slope_tensor(h).squeeze(-1).squeeze(-1)
+ slot_idx = paddle.arange(b, dtype="int64")
+
+ kv_before = kv_caches.numpy().copy()
+ self._call_decode_forward(q, k, v, kv_caches, slope_rate, slot_idx)
+ kv_after = kv_caches.numpy()
+
+ self.assertGreater(
+ np.abs(kv_after - kv_before).max(),
+ 0.0,
+ "KV cache was not updated by decode kernel",
+ )
+
+ def test_decode_multiple_steps(self):
+ """Simulate 4 decode steps, verify KV cache accumulates."""
+ b, h, d = 1, 8, 128
+ kv_caches = paddle.zeros([b, h, d, d], dtype="float32")
+ slope_rate = _build_slope_tensor(h).squeeze(-1).squeeze(-1)
+ slot_idx = paddle.arange(b, dtype="int64")
+
+ norms = []
+ for step in range(4):
+ q = paddle.randn([b, h, 1, d], dtype="float16")
+ k = paddle.randn([b, h, 1, d], dtype="float16")
+ v = paddle.randn([b, h, 1, d], dtype="float16")
+ out = self._call_decode_forward(q, k, v, kv_caches, slope_rate, slot_idx)
+ norms.append(float(paddle.norm(out).item()))
+
+ # All steps should produce non-zero output
+ for i, norm_val in enumerate(norms):
+ self.assertGreater(norm_val, 0.0, f"Step {i} output is zero")
+
+ # === 3. Prefill → Decode transition ====================================
+
+ def test_prefill_then_decode(self):
+ """
+ End-to-end: prefill with lightning_attention(), then decode with
+ linear_decode_forward_triton(). This mimics the actual serving path
+ where MiniMaxM1LinearAttention.forward() calls lightning_attention()
+ during prefill and then switches to the decode kernel for generation.
+
+ After prefill the KV state is non-zero; the decode kernel should
+ produce a different output than it would with empty KV state.
+ """
+ b, h, n_prefill, d = 1, 8, 256, 128
+
+ # --- Prefill phase ---
+ q_pf = paddle.randn([b, h, n_prefill, d], dtype="float16")
+ k_pf = paddle.randn([b, h, n_prefill, d], dtype="float16")
+ v_pf = paddle.randn([b, h, n_prefill, d], dtype="float16")
+ ed = _build_slope_tensor(h).squeeze(-1) # [h, 1]
+
+ out_pf, kv_state = self._call_lightning_attention(q_pf, k_pf, v_pf, ed, block_size=256)
+ self.assertFalse(paddle.isnan(out_pf).any().item())
+
+ # --- Decode phase ---
+ q_dec = paddle.randn([b, h, 1, d], dtype="float16")
+ k_dec = paddle.randn([b, h, 1, d], dtype="float16")
+ v_dec = paddle.randn([b, h, 1, d], dtype="float16")
+ slope_rate = _build_slope_tensor(h).squeeze(-1).squeeze(-1) # [h]
+ slot_idx = paddle.arange(b, dtype="int64")
+
+ # Decode WITH warm KV state from prefill
+ kv_warm = kv_state.clone()
+ out_warm = self._call_decode_forward(q_dec, k_dec, v_dec, kv_warm, slope_rate, slot_idx)
+
+ # Decode with COLD (zeros) KV state
+ kv_cold = paddle.zeros_like(kv_state)
+ out_cold = self._call_decode_forward(
+ q_dec.clone(), k_dec.clone(), v_dec.clone(), kv_cold, slope_rate, slot_idx
+ )
+
+ # The warm-state decode should differ from cold-state (prefill context matters)
+ diff = float(paddle.norm(out_warm - out_cold).item())
+ self.assertGreater(
+ diff,
+ 1e-3,
+ "Warm and cold decode outputs are identical — KV state not propagated",
+ )
+
+ # === 4. Slope tensor construction ======================================
+
+ def test_slope_tensor_power_of_2(self):
+ """Slope tensor for n_heads=64 (power of 2) — all values positive, decreasing."""
+ slope = _build_slope_tensor(64)
+ self.assertEqual(list(slope.shape), [64, 1, 1])
+ vals = slope.squeeze(-1).squeeze(-1).numpy()
+ self.assertTrue(np.all(vals > 0), "Non-positive slope values")
+ # First slope should be largest
+ self.assertGreater(vals[0], vals[-1])
+
+ def test_slope_tensor_non_power_of_2(self):
+ """Slope tensor for n_heads=48 (not power of 2) — should still produce valid values."""
+ slope = _build_slope_tensor(48)
+ self.assertEqual(list(slope.shape), [48, 1, 1])
+ vals = slope.squeeze(-1).squeeze(-1).numpy()
+ self.assertTrue(np.all(vals > 0), "Non-positive slope values for n_heads=48")
+
+ def test_slope_tensor_matches_production_heads(self):
+ """Slope tensor for n_heads=64 (MiniMax-M1 production config)."""
+ slope = _build_slope_tensor(64)
+ vals = slope.squeeze(-1).squeeze(-1).numpy()
+ # Expected: 2^{-(2^{-(log2(64)-3)})} = 2^{-(2^{-3})} = 2^{-0.125}
+ expected_start = 2 ** (-0.125)
+ np.testing.assert_allclose(vals[0], expected_start, rtol=1e-5)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/operators/test_lightning_attn_triton.py b/tests/operators/test_lightning_attn_triton.py
new file mode 100644
index 00000000000..9354524ef1c
--- /dev/null
+++ b/tests/operators/test_lightning_attn_triton.py
@@ -0,0 +1,322 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for Hackathon 10th Spring No.47.
+Integration tests for the Lightning Attention Triton kernel.
+
+These tests exercise the REAL Triton JIT-compiled GPU kernel
+(lightning_attention_forward) against a pure-NumPy reference
+implementation. They are NOT stub/mock tests — they require
+a CUDA-capable GPU with Triton support.
+
+Validated on: AI Studio V100 (SM70), Paddle 3.3.0, Triton 3.x
+CI marker: @pytest.mark.gpu — skipped automatically when no GPU is present.
+"""
+
+import unittest
+
+import numpy as np
+import paddle
+import pytest
+
+# ---------------------------------------------------------------------------
+# NumPy reference — authoritative, matches the recurrence in the paper.
+# ---------------------------------------------------------------------------
+
+
+def _lightning_attention_numpy_ref(q, k, v, slope, kv_history=None):
+ """
+ Pure NumPy reference implementation of linear attention with exponential
+ decay (Lightning Attention).
+
+ Args:
+ q, k, v: float64 arrays of shape [b, h, n, d] / [b, h, n, e].
+ slope: 1-D array of shape [h] — per-head decay rates.
+ kv_history: optional [b, h, d, e] float64 — KV state carry-in.
+
+ Returns:
+ output: [b, h, n, e] attention output.
+ kv_state: [b, h, d, e] updated KV state after processing all n steps.
+ """
+ b, h, n, d = q.shape
+ e = v.shape[-1]
+ output = np.zeros((b, h, n, e), dtype=np.float64)
+
+ if kv_history is None:
+ kv_state = np.zeros((b, h, d, e), dtype=np.float64)
+ else:
+ kv_state = kv_history.copy()
+
+ for t in range(n):
+ decay = np.exp(-slope)[np.newaxis, :, np.newaxis, np.newaxis]
+ kv_state = kv_state * decay
+ kt = k[:, :, t, :]
+ vt = v[:, :, t, :]
+ kv_state += kt[:, :, :, np.newaxis] * vt[:, :, np.newaxis, :]
+ qt = q[:, :, t, :]
+ output[:, :, t, :] = np.einsum("bhd,bhde->bhe", qt, kv_state)
+
+ return output, kv_state
+
+
+# ---------------------------------------------------------------------------
+# GPU availability guard
+# ---------------------------------------------------------------------------
+
+_GPU_AVAILABLE = paddle.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0
+
+_SKIP_MSG = "No CUDA GPU available — lightning attention Triton kernel requires GPU"
+
+
+def _import_lightning_attention_forward():
+ """Lazy import so collection doesn't crash on CPU-only boxes."""
+ from fastdeploy.model_executor.ops.triton_ops.lightning_attn import (
+ lightning_attention_forward,
+ )
+
+ return lightning_attention_forward
+
+
+# ---------------------------------------------------------------------------
+# Test suite
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.gpu
+@unittest.skipUnless(_GPU_AVAILABLE, _SKIP_MSG)
+class TestLightningAttentionTriton(unittest.TestCase):
+ """
+ Integration test: real Triton kernel vs NumPy reference.
+
+ Parametrisation axes:
+ batch : 1, 2
+ heads : 4, 8
+ seq_len : 256 (one block), 512 (two blocks)
+ head_dim: 64, 128
+ dtype : float16, bfloat16
+ """
+
+ # Tolerance table — Triton accumulates in fp32 but the inputs are half
+ # precision, so we need generous tolerances for long sequences.
+ _TOL = {
+ "float16": {"rtol": 5e-2, "atol": 5e-2},
+ "bfloat16": {"rtol": 8e-2, "atol": 8e-2},
+ }
+
+ @classmethod
+ def setUpClass(cls):
+ paddle.set_device("gpu:0")
+ # Store as list to avoid Python descriptor binding (self would be
+ # passed as first arg if a bare function is set as class attribute).
+ cls._forward_fn = [_import_lightning_attention_forward()]
+
+ # --- helpers -----------------------------------------------------------
+
+ def _run_forward(self, b, h, n, d, dtype_str):
+ """Run Triton kernel and compare against NumPy reference."""
+ rng = np.random.default_rng(42)
+
+ # Random inputs in float64 for the reference, then cast to target dtype
+ q_np = rng.standard_normal((b, h, n, d)).astype(np.float64) * 0.1
+ k_np = rng.standard_normal((b, h, n, d)).astype(np.float64) * 0.1
+ v_np = rng.standard_normal((b, h, n, d)).astype(np.float64) * 0.1
+ slope_np = np.abs(rng.standard_normal(h).astype(np.float64)) * 0.5 + 0.1
+
+ # NumPy reference (float64)
+ ref_out, ref_kv = _lightning_attention_numpy_ref(q_np, k_np, v_np, slope_np)
+
+ # Paddle tensors on GPU
+ dtype_paddle = dtype_str
+ q = paddle.to_tensor(q_np.astype(np.float32), dtype=dtype_paddle)
+ k = paddle.to_tensor(k_np.astype(np.float32), dtype=dtype_paddle)
+ v = paddle.to_tensor(v_np.astype(np.float32), dtype=dtype_paddle)
+
+ # Slope: the kernel accepts [1, h, 1, 1] or [h].
+ # The model code passes ed as [1, h, 1, 1] after reshape.
+ slope = paddle.to_tensor(slope_np.astype(np.float32), dtype="float32")
+ slope_4d = slope.reshape([1, h, 1, 1])
+
+ # KV history initialised to zeros
+ kv_history = paddle.zeros([b, h, d, d], dtype="float32")
+
+ # Run kernel
+ out, kv_out = self._forward_fn[0](q, k, v, slope_4d, kv_history, block_size=256)
+
+ # Move to CPU for comparison
+ out_np = out.astype("float32").numpy()
+ kv_out_np = kv_out.numpy()
+
+ tol = self._TOL[dtype_str]
+ np.testing.assert_allclose(
+ out_np,
+ ref_out.astype(np.float32),
+ rtol=tol["rtol"],
+ atol=tol["atol"],
+ err_msg=f"Output mismatch: b={b}, h={h}, n={n}, d={d}, dtype={dtype_str}",
+ )
+
+ return out_np, kv_out_np, ref_out, ref_kv
+
+ # --- core correctness tests -------------------------------------------
+
+ def test_small_single_block_fp16(self):
+ """b=1, h=4, n=256, d=64 — single block, float16."""
+ self._run_forward(b=1, h=4, n=256, d=64, dtype_str="float16")
+
+ def test_small_single_block_bf16(self):
+ """b=1, h=4, n=256, d=64 — single block, bfloat16."""
+ self._run_forward(b=1, h=4, n=256, d=64, dtype_str="bfloat16")
+
+ def test_two_blocks_fp16(self):
+ """b=1, h=8, n=512, d=128 — two blocks, float16."""
+ self._run_forward(b=1, h=8, n=512, d=128, dtype_str="float16")
+
+ def test_two_blocks_bf16(self):
+ """b=2, h=4, n=512, d=64 — two blocks, batched, bfloat16."""
+ self._run_forward(b=2, h=4, n=512, d=64, dtype_str="bfloat16")
+
+ def test_large_dim_fp16(self):
+ """b=1, h=8, n=256, d=128 — large head dim, float16."""
+ self._run_forward(b=1, h=8, n=256, d=128, dtype_str="float16")
+
+ def test_batched_bf16(self):
+ """b=2, h=8, n=256, d=128 — multi-batch, bfloat16."""
+ self._run_forward(b=2, h=8, n=256, d=128, dtype_str="bfloat16")
+
+ # --- KV history persistence (recurrent property) ----------------------
+
+ def test_kv_history_persistence(self):
+ """
+ Verify that processing [seq1, seq2] in two calls with KV carry-over
+ matches processing the full concatenated sequence [seq1 || seq2].
+ """
+ b, h, d = 1, 4, 64
+ n1, n2 = 256, 256
+ rng = np.random.default_rng(123)
+
+ q1_np = rng.standard_normal((b, h, n1, d)).astype(np.float64) * 0.1
+ k1_np = rng.standard_normal((b, h, n1, d)).astype(np.float64) * 0.1
+ v1_np = rng.standard_normal((b, h, n1, d)).astype(np.float64) * 0.1
+ q2_np = rng.standard_normal((b, h, n2, d)).astype(np.float64) * 0.1
+ k2_np = rng.standard_normal((b, h, n2, d)).astype(np.float64) * 0.1
+ v2_np = rng.standard_normal((b, h, n2, d)).astype(np.float64) * 0.1
+ slope_np = np.abs(rng.standard_normal(h).astype(np.float64)) * 0.5 + 0.1
+
+ # Two-call path (with KV carry-over)
+ _, kv_after_1 = _lightning_attention_numpy_ref(q1_np, k1_np, v1_np, slope_np)
+ out2_ref, _ = _lightning_attention_numpy_ref(q2_np, k2_np, v2_np, slope_np, kv_history=kv_after_1)
+
+ # Full-sequence path
+ q_full = np.concatenate([q1_np, q2_np], axis=2)
+ k_full = np.concatenate([k1_np, k2_np], axis=2)
+ v_full = np.concatenate([v1_np, v2_np], axis=2)
+ out_full_ref, _ = _lightning_attention_numpy_ref(q_full, k_full, v_full, slope_np)
+ out_full_second_half = out_full_ref[:, :, n1:, :]
+
+ # Reference consistency check (NumPy vs NumPy)
+ np.testing.assert_allclose(
+ out2_ref.astype(np.float32),
+ out_full_second_half.astype(np.float32),
+ rtol=1e-5,
+ atol=1e-5,
+ err_msg="Reference recurrence does not match full-sequence computation",
+ )
+
+ # Now run the two-call path through the Triton kernel
+ dtype_str = "float16"
+ dtype_paddle = dtype_str
+ slope = paddle.to_tensor(slope_np.astype(np.float32), dtype="float32")
+ slope_4d = slope.reshape([1, h, 1, 1])
+
+ q1 = paddle.to_tensor(q1_np.astype(np.float32), dtype=dtype_paddle)
+ k1 = paddle.to_tensor(k1_np.astype(np.float32), dtype=dtype_paddle)
+ v1 = paddle.to_tensor(v1_np.astype(np.float32), dtype=dtype_paddle)
+ q2 = paddle.to_tensor(q2_np.astype(np.float32), dtype=dtype_paddle)
+ k2 = paddle.to_tensor(k2_np.astype(np.float32), dtype=dtype_paddle)
+ v2 = paddle.to_tensor(v2_np.astype(np.float32), dtype=dtype_paddle)
+
+ kv_init = paddle.zeros([b, h, d, d], dtype="float32")
+
+ # Call 1
+ _, kv_after_1_gpu = self._forward_fn[0](q1, k1, v1, slope_4d, kv_init, block_size=256)
+ # Call 2 — feed KV state from call 1
+ out2_gpu, _ = self._forward_fn[0](q2, k2, v2, slope_4d, kv_after_1_gpu, block_size=256)
+
+ out2_gpu_np = out2_gpu.astype("float32").numpy()
+
+ np.testing.assert_allclose(
+ out2_gpu_np,
+ out2_ref.astype(np.float32),
+ rtol=5e-2,
+ atol=5e-2,
+ err_msg="Triton KV carry-over does not match reference two-call path",
+ )
+
+ # --- output shape and dtype -------------------------------------------
+
+ def test_output_shape(self):
+ """Verify output tensor shape matches [b, h, n, d]."""
+ b, h, n, d = 1, 4, 256, 64
+ q = paddle.randn([b, h, n, d], dtype="float16")
+ k = paddle.randn([b, h, n, d], dtype="float16")
+ v = paddle.randn([b, h, n, d], dtype="float16")
+ slope = paddle.ones([1, h, 1, 1], dtype="float32") * 0.3
+ kv = paddle.zeros([b, h, d, d], dtype="float32")
+
+ out, kv_out = self._forward_fn[0](q, k, v, slope, kv, block_size=256)
+
+ self.assertEqual(list(out.shape), [b, h, n, d])
+ self.assertEqual(list(kv_out.shape), [b, h, d, d])
+
+ def test_output_dtype_preserved(self):
+ """Verify output dtype matches input dtype."""
+ b, h, n, d = 1, 4, 256, 64
+ for dtype_str in ["float16", "bfloat16"]:
+ q = paddle.randn([b, h, n, d], dtype=dtype_str)
+ k = paddle.randn([b, h, n, d], dtype=dtype_str)
+ v = paddle.randn([b, h, n, d], dtype=dtype_str)
+ slope = paddle.ones([1, h, 1, 1], dtype="float32") * 0.3
+ kv = paddle.zeros([b, h, d, d], dtype="float32")
+
+ out, kv_out = self._forward_fn[0](q, k, v, slope, kv, block_size=256)
+ self.assertEqual(str(out.dtype).split(".")[-1], dtype_str)
+
+ # --- decode-path kernel -----------------------------------------------
+
+ def test_linear_decode_forward(self):
+ """
+ Test the linear_decode_forward_triton kernel (single-step decode).
+ This is the kernel used during autoregressive generation.
+ """
+ from fastdeploy.model_executor.ops.triton_ops.lightning_attn import (
+ linear_decode_forward_triton,
+ )
+
+ b, h, d = 2, 8, 128
+ q = paddle.randn([b, h, 1, d], dtype="float16")
+ k = paddle.randn([b, h, 1, d], dtype="float16")
+ v = paddle.randn([b, h, 1, d], dtype="float16")
+ kv_caches = paddle.zeros([b, h, d, d], dtype="float32")
+ slope_rate = paddle.ones([h], dtype="float32") * 0.3
+ slot_idx = paddle.arange(b, dtype="int64")
+
+ out = linear_decode_forward_triton(q, k, v, kv_caches, slope_rate, slot_idx)
+
+ # Output shape: [B, H*D] (flattened heads)
+ self.assertEqual(list(out.shape), [b, h * d])
+ self.assertFalse(paddle.isnan(out).any().item(), "Decode output contains NaN")
+ self.assertTrue(paddle.isfinite(out).all().item(), "Decode output contains Inf")
+
+
+if __name__ == "__main__":
+ unittest.main()