【Hackathon 9th No.33】add test_moe_wna16_marlin_gemm [cf]#7708
【Hackathon 9th No.33】add test_moe_wna16_marlin_gemm [cf]#7708ghost wants to merge 1 commit intoPaddlePaddle:developfrom
Conversation
|
Thanks for your contribution! |
CI报告基于以下代码生成(30分钟更新一次): 1 任务总览所有已执行的 CI 任务均已通过 ✅。另有 7 个 Workflow 处于
2 任务状态汇总2.1 Required任务 : 0/0 通过
2.2 可选任务 — 2/2 通过
3 失败详情(仅 required)无 required 失败任务。 |
PaddlePaddle-bot
left a comment
There was a problem hiding this comment.
🤖 Paddle-CI-Agent | pr_review |
2026-05-03 21:50:49
📋 Review 摘要
PR 概述:为 moe_wna16_marlin_gemm 量化 MoE GEMM 算子新增单元测试
变更范围:tests/operators/
影响面 Tag:[OP] [CI]
📝 PR 规范检查
标题格式不符合规范:【Hackathon 9th No.33】 和 [cf] 均非官方 Tag;PR 描述 4 个段落(Motivation / Modifications / Usage or Command / Accuracy Tests)均为空占位符,不合规。
标题建议(可直接复制):
[OP] Add unit test for moe_wna16_marlin_gemm
PR 描述建议(可直接复制,必须复刻 checklist §D2 模板的完整结构):
## Motivation
为 `moe_wna16_marlin_gemm` 算子补充单元测试,验证 INT4(uint4b8)量化 MoE GEMM 的输出正确性,覆盖 top_k=1、top_k=2 及多种 M 维度场景。
## Modifications
新增 `tests/operators/test_moe_wna16_marlin_gemm.py`:
- 实现 `_quantize_to_uint4b8` / `_pack_gptq_int32` / `_dequantize_uint4b8` 等量化辅助函数
- 实现 `_build_marlin_weights` 对多 expert 权重进行 GPTQ-pack + Marlin-repack
- 新增 `TestMoeWna16MarlinGemm` 测试类,包含三个测试用例(`test_topk1`、`test_topk2_mul_weights`、`test_various_sizes`)
## Usage or Command
```bash
python -m pytest tests/operators/test_moe_wna16_marlin_gemm.py -v
```
## Accuracy Tests
N/A(本 PR 为功能正确性单元测试,不涉及精度基线对比)
## Checklist
- [ ] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.问题
| 级别 | 文件 | 概述 |
|---|---|---|
| ❓ 疑问 | tests/operators/test_moe_wna16_marlin_gemm.py:111 |
workspace 大小硬编码为 528,魔法数字来源不明 |
| ❓ 疑问 | tests/operators/test_moe_wna16_marlin_gemm.py:143 |
精度容差 atol=2e-1(0.2)偏大,降低测试有效性 |
| ❓ 疑问 | tests/operators/test_moe_wna16_marlin_gemm.py:79 |
b_scales shape 为 [E, 1, N],需确认 kernel 期望 layout |
总体评价
整体测试结构清晰,覆盖了多种场景(top_k=1/2、不同 M 值)。建议作者确认 workspace 大小来源、适当收紧精度容差,并修复 PR 标题与描述后合入。
| b_scales=b_scales, | ||
| topk_ids=topk_ids, | ||
| topk_ids_np=topk_ids_np, | ||
| topk_weights=topk_weights, |
There was a problem hiding this comment.
❓ 疑问 workspace 大小硬编码为 528,魔法数字来源不明。
该值是否来自 kernel 内部固定需求?建议加注释说明,或从被测接口中动态获取,否则若 kernel 实现变化将导致静默错误。
| mul_topk_weights=mul_topk_weights, | ||
| is_ep=False, | ||
| b_q_type_str="uint4b8", | ||
| size_m=M, |
There was a problem hiding this comment.
❓ 疑问 精度容差 atol=2e-1(0.2)偏大。
FP16 量化 GEMM 的典型误差通常 < 0.05,atol=0.2 几乎无法检出明显的数值错误,会降低测试的有效性。建议收紧为 atol=5e-2(或在注释中说明为何需要 0.2)。
| b_scales = paddle.to_tensor(np.stack(all_s, axis=0), dtype="float16", place=paddle.CUDAPlace(0)) | ||
| return b_q_weight, b_scales, all_q, all_s | ||
|
|
||
|
|
There was a problem hiding this comment.
❓ 疑问 _quantize_to_uint4b8 返回的 scales shape 为 (1, N)(通过 .reshape(1, N)),存入 all_s 后 stack 得到 b_scales shape 为 [E, 1, N]。
但在 NumPy reference 中(第 140 行)inp["scales"][ids[i, j]] 取出的是 (1, N) 的 scales,执行 _dequantize_uint4b8(q_vals, scales) 时 broadcast 可能正确,但建议确认实际 kernel 期望的 b_scales layout(是否为 [E, 1, N] 或 [E, N]),避免维度不匹配被 broadcast 静默掩盖。
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist