【Hackathon 9th No.53】add test_multi_head_latent_attention [cf]#7711
Thanks for your contribution!
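CI report generated from the code below (refreshed every 30 minutes):
1 Task overview: all executed tasks passed ✅, no required task failed; approval is recommended.
2 Task status summary
2.1 Required tasks: 0/0 passed
2.2 Optional tasks: 2/2 passed
3 Failure details (required only): none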
PaddlePaddle-bot left a comment
🤖 Paddle-CI-Agent | pr_review |
2026-05-03 21:45:04
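📋 Review summary
PR overview: adds unit tests for the multi_head_latent_attention operator, verifying decode-stage correctness under BF16/FP16 precision
Scope of changes: tests/operators/
Impact tags: [CI] [OP]
📝 PR convention check
The PR title does not use the official [Tag] format (it is currently 【Hackathon 9th No.53】...), and every section of the PR description is still a TODO placeholder; both need to be completed.
Suggested title (can be copied directly):
[CI] Add unit test for multi_head_latent_attention operator
Suggested PR description (can be copied directly; it must reproduce the full structure of the checklist §D2 template):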
## Motivation
Add unit tests for the `multi_head_latent_attention` (MLA decode attention) operator, verifying decode-stage correctness under both BF16 and FP16 precision and covering the paged KV cache + GQA scenario.
## Modifications
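- Add `tests/operators/test_multi_head_latent_attention.py`:
  - Implement a NumPy float64 reference implementation, `_reference_mla_decode`, supporting paged KV cache + GQA
  - Add the `TestMultiHeadLatentAttention` test class with two decode correctness test cases, BF16 and FP16
  - The tests run only when CUDA is available and SM ≥ 90 (H100+); they are skipped automatically in other environments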
## Usage or Command
```bash
python -m pytest tests/operators/test_multi_head_latent_attention.py -v
```
## Accuracy Tests
N/A (this PR only adds a test file and does not change the operator implementation)
## Checklist
- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues
| Severity | File | Summary |
|---|---|---|
| 🟡 Suggestion | tests/operators/test_multi_head_latent_attention.py:172 | `_check_output` only tests seq_len=5 (within a single block); no case crosses a block boundary |
| ❓ Question | tests/operators/test_multi_head_latent_attention.py:147 | The 11 consecutive None arguments lack per-line comments, so their meaning is unclear when the file is read on its own |
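Overall assessment
The test is well structured, the dimension derivation in the NumPy reference implementation is correct, and the skip logic is reasonable. Consider adding a case that crosses a block boundary and inline comments for the None arguments to improve maintainability.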
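As a rough illustration of what a NumPy float64 reference for paged-KV decode attention can look like, here is a minimal sketch. It is not the PR's `_reference_mla_decode`: the function name, the tensor layouts, the single shared latent KV head, the choice of taking the value as the first `v_dim` channels of each cached vector, and the `1/sqrt(qk_dim)` softmax scale are all assumptions made for the example.

```python
import numpy as np


def paged_decode_attention(q, kv_cache, block_table, seq_len, block_size, v_dim):
    """Hypothetical single-sequence decode attention over a paged KV cache.

    q:           [num_heads, qk_dim] query for the current decode step
    kv_cache:    [num_blocks, block_size, qk_dim] physical cache blocks
                 (assumed: one latent KV head shared by all query heads)
    block_table: [max_blocks] logical block index -> physical block index
    """
    q = q.astype(np.float64)

    # Map each logical position to (physical block, offset inside the block).
    pos = np.arange(seq_len)
    physical = block_table[pos // block_size]
    offset = pos % block_size
    kv = kv_cache[physical, offset].astype(np.float64)  # [seq_len, qk_dim]

    # Assumed scale; a real operator may use a different softmax scale.
    scores = q @ kv.T / np.sqrt(q.shape[-1])  # [num_heads, seq_len]
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)

    # Assumed: the value is the leading v_dim channels of the cached vector.
    return probs @ kv[:, :v_dim]  # [num_heads, v_dim]
```

Because the gather goes through `block_table`, a reference of this shape stays correct when logical blocks are stored out of order, which is what makes it usable as a baseline for a paged operator.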
max_dec_len_cpu,
max_len_kv_cpu,
None,
None,
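❓ Question: the 11 consecutive None arguments lack per-line comments.
The production code (mla_attention_backend.py:518-528) already annotates each None, whereas in the test file their meaning is unclear when the file is read on its own. Consider adding inline comments that mirror the production code: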
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # out_shifts
None, # out_smooths

    return args, query_ref, kv_ref, block_tables.numpy()

def _check_output(self, dtype_str, seq_len=5):
    """Run op and compare against NumPy reference."""
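🟡 Suggestion: `_check_output` defaults to seq_len=5, far smaller than block_size=64, so every position falls inside block 0 and the index `pos // block_size` in `block_tables_np[0, pos // block_size]` is always 0. The boundary logic for switching between paged KV blocks is therefore never exercised.
Consider adding a cross-block test case (seq_len=65, spanning the first and second blocks):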
def test_decode_correctness_bf16_multi_block(self):
    """BF16 multi-block paged KV cache boundary correctness."""
    self._check_output("bfloat16", seq_len=65)
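To see why seq_len=65 is enough to reach a second block while seq_len=5 is not, the arithmetic can be checked in isolation. The helper below is hypothetical and only assumes the block_size of 64 quoted in the comment above; it is not part of the test file.

```python
def blocks_touched(seq_len: int, block_size: int = 64) -> set[int]:
    """Logical block indices visited when reading positions 0..seq_len-1."""
    return {pos // block_size for pos in range(seq_len)}


single = blocks_touched(5)   # every position falls in logical block 0
full = blocks_touched(64)    # positions 0..63 still fit in block 0
multi = blocks_touched(65)   # position 64 is the first slot of block 1
print(single, full, multi)
```

A sequence of exactly 64 tokens still stays within one block, so 65 is the smallest length that forces a lookup in the second column of the block table.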