fix: wire h0_indices into Lightning Attention decode for state-pool indexing#75

Open
Emre-Dinc wants to merge 1 commit into inclusionAI:main from Emre-Dinc:fix/la-decode-state-pool-indexing

@Emre-Dinc

📌 Description

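Lightning Attention decode ignores s_offsets and indexes the state by the flattened batch_idx directly. The code says so in a comment: s_offsets is "passed to kernel but not actually used". The bug stays hidden whenever s_offsets is the identity mapping torch.arange(B), which is what the existing tests use.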

What changed:

  • Both decode kernels: index h0_source via h0_indices[i_n] * HV + i_hv instead of batch_idx
  • Grid launch uses active B * H instead of h0_source.shape[0]
  • Compile cache key includes pool_dim0 so different pool sizes recompile correctly
  • Benchmark caller updated to match new cache signature
  • Added s.shape[0] % H == 0 validation
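
The indexing change in the list above can be illustrated with a small pure-Python sketch. It assumes the state pool is stored as flat rows and that the old `batch_idx` was the flattened `i_n * HV + i_hv`; the function names are hypothetical and this is not the kernel code.

```python
def old_state_row(i_n: int, i_hv: int, HV: int) -> int:
    """Assumed old behaviour: flattened batch index, s_offsets ignored."""
    return i_n * HV + i_hv


def new_state_row(h0_indices: list[int], i_n: int, i_hv: int, HV: int) -> int:
    """Behaviour described in this PR: look up the pool slot first."""
    return h0_indices[i_n] * HV + i_hv


def grid_size(B: int, H: int) -> int:
    """One program per active (batch, head) pair, not per pool row."""
    return B * H


HV = 2
identity = [0, 1, 2, 3]
offsets = [2, 0, 5, 1]  # pool_size=6, batch=4, as in the new tests

pairs = [(n, hv) for n in range(4) for hv in range(HV)]
old_rows = [old_state_row(n, hv, HV) for n, hv in pairs]
identity_rows = [new_state_row(identity, n, hv, HV) for n, hv in pairs]
pooled_rows = [new_state_row(offsets, n, hv, HV) for n, hv in pairs]
```

With identity offsets the old and new row indices coincide, which is why identity-only tests cannot tell the two apart. With `offsets=[2, 0, 5, 1]` the rows diverge and reach past `B * HV`, so the pool must be sized independently of the active batch.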

Not in scope: Pad semantics (offset == -1) - can be added as a follow-up mirroring kda_decode.py.

🔍 Related Issues

Unblocks state-pool-based serving integration (ref: SGLang #22109).

🚀 Pull Request Checklist

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing.

Added tests/test_la_decode_pool.py with 5 tests:

  • Identity offsets (baseline)
  • Non-identity offsets (pool_size=6, batch=4, offsets=[2,0,5,1])
  • Reversed offsets (pool_size=batch, offsets=[3,2,1,0])
  • State writeback verification (active slots updated, inactive slots untouched)
  • Big batch path (B=33, pool_size=40) - exercises big_batch kernel variant
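
The writeback check in the list above might look roughly like the following sketch. The `toy_decode_step` function and its `+ 1.0` update are placeholders chosen for illustration, not the Lightning Attention recurrence or the code in tests/test_la_decode_pool.py; only the slot arithmetic follows the PR description.

```python
def toy_decode_step(pool, offsets, H):
    """Update only the pool rows addressed by offsets; return per-row outputs.

    The "+ 1.0" update is a stand-in for the real state recurrence.
    """
    out = []
    for slot in offsets:
        for h in range(H):
            row = slot * H + h
            pool[row] = pool[row] + 1.0
            out.append(pool[row])
    return out


H, pool_size = 2, 6
offsets = [2, 0, 5, 1]
pool = [float(i) for i in range(pool_size * H)]
before = list(pool)

out = toy_decode_step(pool, offsets, H)

# Rows that belong to an active slot; every other row must stay untouched.
active = {slot * H + h for slot in offsets for h in range(H)}
for row in range(pool_size * H):
    if row in active:
        assert pool[row] == before[row] + 1.0
    else:
        assert pool[row] == before[row]
```

Snapshotting the pool before the call and comparing row by row afterwards catches both failure modes: an active slot that was not written back and an inactive slot that was clobbered.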

Full test suite: 378 passed, 51 skipped, 0 failed on B200 (SM100A). All warnings pre-existing.

Reviewer Notes

The existing tests all use torch.arange(B) for s_offsets, making the bug invisible. The new tests use non-identity and larger-than-batch pool configurations to expose and verify the fix.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements indirect indexing for state pool access in the Lightning Attention decode kernels, ensuring that s_offsets are correctly utilized. The changes include updates to both small and big batch kernels, adjustments to the compilation caching logic, and the addition of a comprehensive test suite covering various offset scenarios. A review comment suggests adding explicit dimensionality assertions for the state and output tensors to align with the implementation and prevent potential runtime errors stemming from misleading docstrings.

Comment thread cula/ops/la_decode.py
h0_source = s

# Validate state pool dimensions
assert s.shape[0] % H == 0, f"s.shape[0] must be divisible by H={H}, got {s.shape[0]}"

medium

The implementation assumes that s is a 3D tensor with shape [pool_size * H, V, K] and out is a 3D tensor with shape [B, H, V]. However, the docstrings (lines 585-586 and 607-608) incorrectly describe them as 4D tensors. To prevent runtime indexing errors or incorrect results when users follow the docstrings, it is recommended to explicitly validate the dimensionality of these tensors here.

    # Validate state pool dimensions
    assert s.ndim == 3, f"s must be a 3D tensor [pool_size * H, V, K], got {s.ndim}D"
    assert out.ndim == 3, f"out must be a 3D tensor [B, H, V], got {out.ndim}D"
    assert s.shape[0] % H == 0, f"s.shape[0] must be divisible by H={H}, got {s.shape[0]}"
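
As a rough illustration of what these checks guard against, the sketch below applies the same three conditions to plain shape tuples and derives the pool size. The helper name `validate_pool_shapes` and the use of `ValueError` instead of `assert` are assumptions made for this example; they are not part of cula/ops/la_decode.py.

```python
def validate_pool_shapes(s_shape, out_shape, H):
    """Check the layouts named in the review and return the implied pool size.

    s_shape is expected to be (pool_size * H, V, K); out_shape is (B, H, V).
    """
    if len(s_shape) != 3:
        raise ValueError(
            f"s must be a 3D tensor [pool_size * H, V, K], got {len(s_shape)}D"
        )
    if len(out_shape) != 3:
        raise ValueError(f"out must be a 3D tensor [B, H, V], got {len(out_shape)}D")
    if s_shape[0] % H != 0:
        raise ValueError(
            f"s.shape[0] must be divisible by H={H}, got {s_shape[0]}"
        )
    return s_shape[0] // H


# A pool of 6 slots with H=2 heads is accepted.
pool_size = validate_pool_shapes((12, 64, 128), (4, 2, 64), H=2)

# A 4D state, as the docstrings describe it, is rejected up front.
try:
    validate_pool_shapes((6, 2, 64, 128), (4, 2, 64), H=2)
    rejected_4d = False
except ValueError:
    rejected_4d = True
```

Raising an exception rather than asserting keeps the check active under `python -O`, which may or may not matter for this code path.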

@Emre-Dinc Emre-Dinc force-pushed the fix/la-decode-state-pool-indexing branch from 9b6a77d to c922b16 May 21, 2026 12:40
@Emre-Dinc Emre-Dinc changed the title from "[KDA] sm90 GVA enhance (#64)" to "fix: wire h0_indices into Lightning Attention decode for state-pool indexing" May 21, 2026