fix: wire h0_indices into Lightning Attention decode for state-pool indexing#75
fix: wire h0_indices into Lightning Attention decode for state-pool indexing#75Emre-Dinc wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements indirect indexing for state pool access in the Lightning Attention decode kernels, ensuring that s_offsets are correctly utilized. The changes include updates to both small and big batch kernels, adjustments to the compilation caching logic, and the addition of a comprehensive test suite covering various offset scenarios. A review comment suggests adding explicit dimensionality assertions for the state and output tensors to align with the implementation and prevent potential runtime errors stemming from misleading docstrings.
| h0_source = s | ||
|
|
||
| # Validate state pool dimensions | ||
| assert s.shape[0] % H == 0, f"s.shape[0] must be divisible by H={H}, got {s.shape[0]}" |
There was a problem hiding this comment.
The implementation assumes that s is a 3D tensor with shape [pool_size * H, V, K] and out is a 3D tensor with shape [B, H, V]. However, the docstrings (lines 585-586 and 607-608) incorrectly describe them as 4D tensors. To prevent runtime indexing errors or incorrect results when users follow the docstrings, it is recommended to explicitly validate the dimensionality of these tensors here.
# Validate state pool dimensions
assert s.ndim == 3, f"s must be a 3D tensor [pool_size * H, V, K], got {s.ndim}D"
assert out.ndim == 3, f"out must be a 3D tensor [B, H, V], got {out.ndim}D"
assert s.shape[0] % H == 0, f"s.shape[0] must be divisible by H={H}, got {s.shape[0]}"9b6a77d to
c922b16
Compare
📌 Description
Lightning Attention decode ignores
s_offsetsand indexes state by flattenedbatch_idxdirectly (the code comments this explicitly: "passed to kernel but not actually used").What changed:
h0_sourceviah0_indices[i_n] * HV + i_hvinstead ofbatch_idxB * Hinstead ofh0_source.shape[0]pool_dim0so different pool sizes recompile correctlys.shape[0] % H == 0validationNot in scope: Pad semantics (
offset == -1) - can be added as a follow-up mirroringkda_decode.py.🔍 Related Issues
Unblocks state-pool-based serving integration (ref: SGLang #22109).
🚀 Pull Request Checklist
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
Added
tests/test_la_decode_pool.pywith 5 tests:big_batchkernel variantFull test suite: 378 passed, 51 skipped, 0 failed on B200 (SM100A). All warnings pre-existing.
Reviewer Notes
The existing tests all use
torch.arange(B)fors_offsets, making the bug invisible. The new tests use non-identity and larger-than-batch pool configurations to expose and verify the fix.