182 changes: 182 additions & 0 deletions tests/operators/test_moe_wna16_marlin_gemm.py
@@ -0,0 +1,182 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import (
MoeWna16MarlinGemmApi,
gptq_marlin_repack,
tritonmoe_preprocess_func,
)

paddle.seed(42)
np.random.seed(42)


def _get_sm_version():
"""Return GPU compute capability as a float, e.g. 8.0 for SM80."""
if not paddle.is_compiled_with_cuda():
return 0.0
try:
prop = paddle.device.cuda.get_device_properties()
return float(f"{prop.major}.{prop.minor}")
except Exception:
return 0.0


def _quantize_to_uint4b8(weight_fp16):
"""Per-channel symmetric quantization to uint4b8 (zero_point=8)."""
K, N = weight_fp16.shape
W = weight_fp16.astype(np.float32)
amax = np.maximum(np.abs(W).max(axis=0), 1e-10)
scales = (amax / 7.0).astype(np.float32)
q = np.clip(np.round(W / scales + 8.0), 0, 15).astype(np.uint8)
return q, scales.reshape(1, N).astype(np.float16)


def _pack_gptq_int32(q_values):
"""Pack [K,N] uint8 values into GPTQ int32 layout [K//8, N]."""
K, N = q_values.shape
packed = np.zeros((K // 8, N), dtype=np.int32)
for offset in range(8):
packed |= q_values[offset::8, :].astype(np.int32) << (4 * offset)
return packed


def _dequantize_uint4b8(q_values, scales):
"""Dequantize uint4b8 values to float32."""
return (q_values.astype(np.float32) - 8.0) * scales.astype(np.float32)


def _build_marlin_weights(weights_list, K, N):
"""Quantize, GPTQ-pack, and Marlin-repack a list of expert weights."""
perm = paddle.empty([0], dtype="int32")
marlin_per_expert, all_q, all_s = [], [], []
for w_fp16 in weights_list:
q, s = _quantize_to_uint4b8(w_fp16)
all_q.append(q)
all_s.append(s)
packed_t = paddle.to_tensor(_pack_gptq_int32(q), dtype="int32", place=paddle.CUDAPlace(0))
marlin_per_expert.append(gptq_marlin_repack(packed_t, perm, size_k=K, size_n=N, num_bits=4))
b_q_weight = paddle.stack(marlin_per_expert, axis=0)
b_scales = paddle.to_tensor(np.stack(all_s, axis=0), dtype="float16", place=paddle.CUDAPlace(0))
return b_q_weight, b_scales, all_q, all_s


❓ Question: _quantize_to_uint4b8 returns scales with shape (1, N) (via .reshape(1, N)); once collected in all_s and stacked, b_scales has shape [E, 1, N].

But in the NumPy reference (line 140), inp["scales"][ids[i, j]] retrieves the (1, N) scales, and the broadcast in _dequantize_uint4b8(q_vals, scales) is probably correct. Still, please confirm which b_scales layout the kernel actually expects ([E, 1, N] or [E, N]), so that a dimension mismatch is not silently masked by broadcasting.
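
A minimal check of that concern, assuming only the helpers above (the kernel-side layout still needs to be verified against the source): squeezing the per-expert scales to 1-D must leave the dequantized reference unchanged, so an explicit squeeze makes any silent broadcast visible.

# Hypothetical layout check, not part of the PR: (1, N) and (N,) scales
# must dequantize identically, so an [E, N] vs [E, 1, N] mismatch cannot
# hide behind NumPy broadcasting.
w = (np.random.randn(256, 256) * 0.05).astype(np.float16)
q, s = _quantize_to_uint4b8(w)                  # s.shape == (1, 256)
ref_2d = _dequantize_uint4b8(q, s)              # (1, N) broadcast over (K, N)
ref_1d = _dequantize_uint4b8(q, s.reshape(-1))  # explicit 1-D scales
np.testing.assert_array_equal(ref_2d, ref_1d)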

@unittest.skipUnless(
paddle.is_compiled_with_cuda() and _get_sm_version() >= 8.0,
"Requires CUDA GPU with SM80+ (Ampere or newer)",
)
class TestMoeWna16MarlinGemm(unittest.TestCase):
"""Tests for moe_wna16_marlin_gemm — quantized MoE GEMM correctness."""

E, K, N, BLOCK_M = 8, 256, 256, 16

def setUp(self):
paddle.set_device("gpu")

def _make_inputs(self, M=16, top_k=1, seed=42):
"""Build all tensors needed by MoeWna16MarlinGemmApi."""
np.random.seed(seed)
a_np = (np.random.randn(M, self.K) * 0.1).astype(np.float16)
a = paddle.to_tensor(a_np, dtype="float16", place=paddle.CUDAPlace(0))
ws = [(np.random.randn(self.K, self.N) * 0.05).astype(np.float16) for _ in range(self.E)]
b_q_weight, b_scales, q_vals, scales = _build_marlin_weights(ws, self.K, self.N)
topk_ids_np = np.random.randint(0, self.E, size=(M, top_k)).astype(np.int64)
topk_ids = paddle.to_tensor(topk_ids_np, dtype="int64", place=paddle.CUDAPlace(0))
topk_w_np = np.random.rand(M, top_k).astype(np.float32)
topk_weights = paddle.to_tensor(topk_w_np, dtype="float32", place=paddle.CUDAPlace(0))
sorted_ids, expert_ids, ntokens_pp = tritonmoe_preprocess_func(topk_ids, self.E, self.BLOCK_M)
return dict(
a=a,
a_np=a_np,
b_q_weight=b_q_weight,
b_scales=b_scales,
topk_ids=topk_ids,
topk_ids_np=topk_ids_np,
topk_weights=topk_weights,
❓ Question: the workspace size is hardcoded to 528, a magic number of unclear origin.

Does this value come from a fixed requirement inside the kernel? Consider adding a comment explaining it, or deriving it dynamically from the interface under test; otherwise a change in the kernel implementation could fail silently. A hedged derivation sketch follows the end of _make_inputs below.

topk_w_np=topk_w_np,
sorted_ids=sorted_ids,
expert_ids=expert_ids,
ntokens_pp=ntokens_pp,
workspace=paddle.empty([528], dtype="int32"),
q_vals=q_vals,
scales=scales,
)
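
One hedged way to address the workspace question above: derive the size from the device rather than hardcoding it. The per-SM constant below is an illustrative assumption, not a value taken from the moe_wna16_marlin_gemm source.

# Illustrative sketch only: assumes the kernel needs a handful of int32
# slots per SM and keeps 528 as a floor so current behavior is unchanged.
# The real requirement should be confirmed against the kernel source.
def _workspace_size(floor=528, ints_per_sm=4):
    props = paddle.device.cuda.get_device_properties()
    return max(floor, props.multi_processor_count * ints_per_sm)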

def _check_output(self, M, top_k, mul_topk_weights=False):
"""Run kernel, compute NumPy reference, assert_allclose."""
inp = self._make_inputs(M=M, top_k=top_k)
out = MoeWna16MarlinGemmApi(
inp["a"],
c_or_none=None,
b_q_weight=inp["b_q_weight"],
b_scales=inp["b_scales"],
global_scale_or_none=None,
b_zeros_or_none=None,
g_idx_or_none=None,
perm_or_none=None,
workspace=inp["workspace"],
sorted_token_ids=inp["sorted_ids"],
expert_ids=inp["expert_ids"],
num_tokens_post_padded=inp["ntokens_pp"],
topk_weights=inp["topk_weights"],
moe_block_size=self.BLOCK_M,
top_k=top_k,
mul_topk_weights=mul_topk_weights,
is_ep=False,
b_q_type_str="uint4b8",
size_m=M,
❓ Question: the accuracy tolerance atol=2e-1 (0.2) is loose.

Typical error for an FP16 quantized GEMM is usually < 0.05; atol=0.2 can barely catch even obvious numerical errors, which weakens the test. Consider tightening to atol=5e-2 (or adding a comment explaining why 0.2 is needed). A rough error budget follows the end of _check_output below.

size_n=self.N,
size_k=self.K,
is_k_full=True,
use_atomic_add=True,
use_fp32_reduce=True,
is_zp_float=False,
)[0]
# NumPy reference
a_fp32 = inp["a_np"].astype(np.float32)
ids, w = inp["topk_ids_np"], inp["topk_w_np"]
ref = np.zeros((M * top_k, self.N), dtype=np.float32)
for i in range(M):
for j in range(top_k):
W_deq = _dequantize_uint4b8(inp["q_vals"][ids[i, j]], inp["scales"][ids[i, j]])
row = a_fp32[i] @ W_deq
if mul_topk_weights:
row *= w[i, j]
ref[i * top_k + j] = row
self.assertEqual(list(out.shape), [M * top_k, self.N])
self.assertEqual(out.dtype, paddle.float16)
np.testing.assert_allclose(out.numpy().astype(np.float32), ref, rtol=1e-1, atol=2e-1)
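
A back-of-envelope error budget for the tolerance question above, using the distributions this test draws from (inputs randn * 0.1, weights randn * 0.05); the constants are rough estimates, not measured values.

# Rough estimate, assuming independent uniform rounding errors per weight.
K = 256
a_std = 0.1                  # inputs are randn * 0.1
amax = 0.05 * 2.9            # typical max |w| over 256 draws of N(0, 0.05^2)
scale = amax / 7.0           # per-channel scale from _quantize_to_uint4b8
w_err = scale / 12 ** 0.5    # std of uniform rounding error in [-s/2, s/2]
dot_err = K ** 0.5 * a_std * w_err
print(f"RMS error per output element ~ {dot_err:.3f}")  # ~0.01

A 3-sigma bound of roughly 0.03 is consistent with the suggested atol=5e-2 rather than 2e-1.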

def test_topk1(self):
"""top_k=1, no weight multiplication."""
self._check_output(M=16, top_k=1, mul_topk_weights=False)

def test_topk2_mul_weights(self):
"""top_k=2 with routing weight scaling."""
self._check_output(M=16, top_k=2, mul_topk_weights=True)

def test_various_sizes(self):
"""Multiple M values with top_k=1."""
for M in (1, 8, 32):
with self.subTest(M=M):
self._check_output(M=M, top_k=1)


if __name__ == "__main__":
unittest.main()