[WIP] [KDA] support GVA for delta_h and fwd_o#73
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the flash-linear-attention baseline to v0.5.0 and introduces support for Grouped Value Attention (GVA) across the chunk_delta_h and fwd_o kernels. The implementation includes updated indexing logic to map value heads to QK heads and extends benchmark scripts to support configurable head counts. Documentation and benchmark results have been refreshed to reflect performance improvements on Blackwell and Hopper architectures. Feedback was provided to include explicit assertions validating that the number of value heads is a multiple of the QK heads and that head dimensions are restricted to 128, as required by the current kernel tiling logic.
| HV = u.shape[2] | ||
| V_dim = u.shape[3] | ||
| BT = chunk_size | ||
| is_varlen = cu_seqlens is not None |
There was a problem hiding this comment.
For Grouped Value Attention (GVA) to work correctly with the current head-mapping logic (i_h = hidx // (HV // H)), the number of value heads (HV) must be a multiple of the number of QK heads (H). Additionally, since the kernel tiling is hardcoded for specific dimensions, we should also validate that V_dim matches the expected 128.
| is_varlen = cu_seqlens is not None | |
| is_varlen = cu_seqlens is not None | |
| assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" | |
| assert K_dim == 128 and V_dim == 128, f"current kernel only supports head_dim=128, got K={K_dim}, V={V_dim}" |
📌 Description
🔍 Related Issues
#55
🚀 Pull Request Checklist
Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
⚡ Performance
Reviewer Notes