170 changes: 170 additions & 0 deletions tests/operators/test_moe_expert_ffn_wint2.py
@@ -0,0 +1,170 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import moe_expert_ffn_wint2

try:
from fastdeploy.model_executor.ops.gpu import winx_unzip as _winx_unzip_op

_HAS_WINX_UNZIP = True
except ImportError:
_HAS_WINX_UNZIP = False

paddle.seed(2026)
np.random.seed(2026)


def _cutlass_rearrange(w):
    """Apply CUTLASS WINT2 weight layout rearrangement."""
    shape = w.shape
    E, Kp, N = shape
    w = w.reshape([E, Kp // 16, 16, N // 8, 8])
    w = paddle.transpose(w, perm=[0, 3, 1, 4, 2])
    return w.reshape(shape)
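
A minimal round-trip sketch (an illustration, not a helper defined in this file), assuming the rearrangement is a pure index permutation of the packed uint8 tensor and reusing the paddle/numpy imports above; the inverse perm [0, 2, 4, 1, 3] is obtained by inverting [0, 3, 1, 4, 2], and the helper name is hypothetical:

def _cutlass_rearrange_inverse(w):
    """Hypothetical inverse of _cutlass_rearrange, for illustration only."""
    shape = w.shape
    E, Kp, N = shape
    w = w.reshape([E, N // 8, Kp // 16, 8, 16])
    w = paddle.transpose(w, perm=[0, 2, 4, 1, 3])
    return w.reshape(shape)

w = paddle.randint(0, 256, [2, 32, 16], dtype="int32").cast("uint8")
# Round-trip should recover the original packed bytes exactly.
np.testing.assert_array_equal(
    _cutlass_rearrange_inverse(_cutlass_rearrange(w)).numpy(), w.numpy()
)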


def _build_inputs(num_experts, hidden_size, inter_size, tokens_per_expert, dtype="bfloat16"):
    """Create correctly-shaped tensors for moe_expert_ffn_wint2."""
    gated_inter = inter_size * 2
    total_tokens = sum(tokens_per_expert)
    permute_input = paddle.randn([total_tokens, hidden_size], dtype=dtype)
    tokens_expert_prefix_sum = paddle.to_tensor(np.cumsum(tokens_per_expert).astype("int64"))

    w_up_raw = paddle.randint(0, 256, [num_experts, hidden_size // 4, gated_inter], dtype="int32").cast("uint8")
    w_down_raw = paddle.randint(0, 256, [num_experts, inter_size // 4, hidden_size], dtype="int32").cast("uint8")
    w_up = _cutlass_rearrange(w_up_raw)
    w_down = _cutlass_rearrange(w_down_raw)

    return dict(
        permute_input=permute_input,
        tokens_expert_prefix_sum=tokens_expert_prefix_sum,
        up_gate_proj_weight=w_up,
        down_proj_weight=w_down,
        up_gate_proj_bias=None,
        up_gate_proj_scale=paddle.randn([num_experts, gated_inter], dtype=dtype) * 0.01,
        down_proj_scale=paddle.randn([num_experts, hidden_size], dtype=dtype) * 0.01,
        up_gate_proj_local_scale=paddle.randint(
            0, 256, [num_experts, hidden_size // 128, gated_inter], dtype="int32"
        ).cast("uint8"),
        up_gate_proj_code_scale=paddle.randn([num_experts, gated_inter], dtype="float32") * 0.01,
        up_gate_proj_code_zp=paddle.randn([num_experts, gated_inter], dtype="float32") * 0.01,
        down_proj_local_scale=paddle.randint(
            0, 256, [num_experts, inter_size // 128, hidden_size], dtype="int32"
        ).cast("uint8"),
        down_proj_code_scale=paddle.randn([num_experts, hidden_size], dtype="float32") * 0.01,
        down_proj_code_zp=paddle.randn([num_experts, hidden_size], dtype="float32") * 0.01,
        _up_weight_raw=w_up_raw,
        _down_weight_raw=w_down_raw,
    )
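
The // 4 and // 128 divisors above presumably reflect four 2-bit weights packed per uint8 byte and a local-scale group size of 128 along the reduction dimension (an assumption read off the shapes, not stated in the op's docs); a small shape walk-through under that assumption, using the class constants E=4, H=256, INTER=128:

# Hypothetical walk-through (E=4, H=256, INTER=128, so gated_inter=256):
#   up_gate_proj_weight      -> [4, 256 // 4, 256] == [4, 64, 256]   (4 two-bit weights per byte)
#   up_gate_proj_local_scale -> [4, 256 // 128, 256] == [4, 2, 256]  (one scale per 128-row group)
#   down_proj_weight         -> [4, 128 // 4, 256] == [4, 32, 256]
inputs = _build_inputs(4, 256, 128, [4, 6, 2, 4])
assert list(inputs["up_gate_proj_weight"].shape) == [4, 64, 256]
assert list(inputs["down_proj_weight"].shape) == [4, 32, 256]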


def _call_op(inputs):
    """Invoke moe_expert_ffn_wint2."""
    return moe_expert_ffn_wint2(
        inputs["permute_input"],
        inputs["tokens_expert_prefix_sum"],
        inputs["up_gate_proj_weight"],
        inputs["down_proj_weight"],
        inputs["up_gate_proj_bias"],
        inputs["up_gate_proj_scale"],
        inputs["down_proj_scale"],
        inputs["up_gate_proj_local_scale"],
        inputs["up_gate_proj_code_scale"],
        inputs["up_gate_proj_code_zp"],
        inputs["down_proj_local_scale"],
        inputs["down_proj_code_scale"],
        inputs["down_proj_code_zp"],
        False,

❓ Question: the 14th positional argument False at the end of _call_op is passed without any comment explaining its meaning, which hurts readability; a reader cannot tell from the code what behavior this boolean controls.

Suggest adding an explanatory comment, for example:

        False,  # TODO: explain what this parameter means, e.g. use_bias=False

or adding descriptions of each parameter to the function docstring.

    )


def _reference_moe_expert_ffn(inputs):
    """Decomposed reference: winx_unzip dequant -> matmul -> SwiGLU -> matmul."""
    E = inputs["_up_weight_raw"].shape[0]

🟡 Suggestion: _up_weight_raw is only used here to obtain the expert count E; its actual weight data is never used by the reference implementation (what is passed to _winx_unzip_op is the rearranged up_gate_proj_weight).

Suggest changing this to:

E = inputs["up_gate_proj_weight"].shape[0]

and removing the two redundant fields _up_weight_raw and _down_weight_raw from the dict returned by _build_inputs, to avoid misleading readers.

    prefix = inputs["tokens_expert_prefix_sum"].numpy()
    starts = np.concatenate([[0], prefix[:-1]]).astype(int)

    dense_up_gate = _winx_unzip_op(
        inputs["up_gate_proj_weight"],
        inputs["up_gate_proj_local_scale"],
        inputs["up_gate_proj_code_scale"],
        inputs["up_gate_proj_code_zp"],
        inputs["up_gate_proj_scale"],
        "weight_only_int2",
    ).cast("float32")
    dense_down = _winx_unzip_op(
        inputs["down_proj_weight"],
        inputs["down_proj_local_scale"],
        inputs["down_proj_code_scale"],
        inputs["down_proj_code_zp"],
        inputs["down_proj_scale"],
        "weight_only_int2",
    ).cast("float32")

    outputs = []
    for e in range(E):
        start, end = int(starts[e]), int(prefix[e])
        if end <= start:
            continue

❓ Question: the if end <= start: continue branch (an expert assigned 0 tokens) is never triggered by any of the existing test cases (neither TOKENS=[4, 6, 2, 4] nor [1, 5, 3, 7] contains a zero allocation), so the correctness of this branch has not been verified.

Suggest adding a test case that covers a zero-token expert, for example:

def test_correctness_zero_token_expert(self):
    """Expert receiving 0 tokens is silently skipped."""
    self._check_output(tokens=[4, 0, 2, 4])

        x = inputs["permute_input"][start:end].cast("float32")
        fc1 = paddle.matmul(x, dense_up_gate[e])
        inter_size = fc1.shape[-1] // 2
        act = paddle.nn.functional.silu(fc1[:, :inter_size]) * fc1[:, inter_size:]
        outputs.append(paddle.matmul(act, dense_down[e]))
    if outputs:
        return paddle.concat(outputs, axis=0).cast(inputs["permute_input"].dtype)
    return paddle.zeros_like(inputs["permute_input"][:0])
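
The reference splits the fused up-gate output as [gate | up] columns and computes SwiGLU as silu(gate) * up; a two-column toy check of that split convention (hand-picked values, purely illustrative, not taken from the op):

fc1 = paddle.to_tensor([[2.0, 3.0]], dtype="float32")  # inter_size == 1: gate=2.0, up=3.0
act = paddle.nn.functional.silu(fc1[:, :1]) * fc1[:, 1:]
# silu(2.0) = 2.0 * sigmoid(2.0) ≈ 1.7616, so act ≈ 1.7616 * 3.0 ≈ 5.285
print(act.numpy())  # ~[[5.285]]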


class TestMoeExpertFFNWint2(unittest.TestCase):
    """Correctness tests for the WINT2 MoE FFN op."""

    E, H, INTER = 4, 256, 128
    TOKENS = [4, 6, 2, 4]

    def setUp(self):
        paddle.set_device("gpu")

    def _check_output(self, dtype="bfloat16", tokens=None):
        """Run op, compare against decomposed reference via assert_allclose."""
        tok = tokens or self.TOKENS
        inputs = _build_inputs(self.E, self.H, self.INTER, tok, dtype=dtype)
        out = _call_op(inputs).cast("float32").numpy()
        ref = _reference_moe_expert_ffn(inputs).cast("float32").numpy()
        self.assertEqual(list(out.shape), [sum(tok), self.H])
        np.testing.assert_allclose(out, ref, rtol=5e-2, atol=5e-2)

    @unittest.skipUnless(_HAS_WINX_UNZIP, "winx_unzip needed for reference")
    def test_correctness_bf16(self):
        """Fused output matches decomposed reference (bfloat16)."""
        self._check_output(dtype="bfloat16")

    @unittest.skipUnless(_HAS_WINX_UNZIP, "winx_unzip needed for reference")
    def test_correctness_fp16(self):
        """Fused output matches decomposed reference (float16)."""
        self._check_output(dtype="float16")

    @unittest.skipUnless(_HAS_WINX_UNZIP, "winx_unzip needed for reference")
    def test_correctness_uneven_tokens(self):
        """Correctness with uneven token distribution across experts."""
        self._check_output(tokens=[1, 5, 3, 7])


if __name__ == "__main__":
    unittest.main()