[Feature] Intra-Card Context Parallelism to improve B=1 long-sequence performance
Problem
FlashKDA's kernel2 (recurrence) launches with grid (N, H) — one block per head per sequence. At B=1 (e.g., single long-sequence prefill), SM utilization is critically low:
| Config | kernel2 blocks | H20 SMs | Utilization |
|---|---|---|---|
| B=1, H=16 | 16 | 78 | 20.5% |
| B=1, H=32 | 32 | 78 | 41.0% |
This causes FlashKDA to be slower than Triton baselines (fla_chunk_kda) at B=1:
| SeqLen | H | FlashKDA | fla_chunk_kda | Ratio (fla_chunk_kda / FlashKDA) |
|---|---|---|---|---|
| 64k | 16 | 9.51 ms | 6.63 ms | 0.70× |
| 256k | 16 | 37.86 ms | 26.48 ms | 0.70× |
| 512k | 16 | 75.67 ms | 52.97 ms | 0.70× |
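The utilization figures in the first table follow directly from the grid shape. A minimal sketch of that arithmetic, assuming each block occupies one SM and no two blocks share an SM (the function name is hypothetical, not part of FlashKDA):

```python
def kernel2_utilization(num_seqs: int, num_heads: int, num_sms: int) -> float:
    """Fraction of SMs that receive a block when the grid is (N, H).

    Assumption: one block per SM, so at most num_sms blocks run at once.
    """
    blocks = num_seqs * num_heads
    return min(blocks, num_sms) / num_sms


if __name__ == "__main__":
    for heads in (16, 32):
        util = kernel2_utilization(num_seqs=1, num_heads=heads, num_sms=78)
        print(f"B=1, H={heads}: {util:.1%}")
```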
Proposed Solution
Adapt FlashQLA's Intra-Card Context Parallelism strategy — split long sequences into sub-segments and process them in parallel:
- Sequence splitting: Automatically compute optimal sub-segment length based on SM count and head count
- Two-pass forward:
  - Pass 1: Run all sub-segments with `h0=0` in parallel, capture each sub-segment's `final_state`
  - Pass 2: Chain the `final_state` values as `initial_state` for subsequent segments, re-run in parallel to produce correct output
- Safety guarantee: Analytically estimate gate decay via A_log; only enable CP when initial state contribution is negligible
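The two-pass scheme and the decay check can be illustrated on a toy model. The sketch below is not FlashKDA's recurrence: it assumes a scalar state with a constant per-step decay (`s_t = decay * s_{t-1} + k_t * v_t`), and the sequential loops stand in for parallel kernel launches. All function names are hypothetical. It shows why the approach only works when the decay accumulated over one sub-segment is negligible: the pass-1 final state of a segment ignores everything before that segment, and that omission is scaled by `decay ** seg_len`.

```python
import random


def run_segment(q, k, v, decay, h0=0.0):
    """Toy gated recurrence: s_t = decay * s_{t-1} + k_t * v_t, o_t = q_t * s_t."""
    s, out = h0, []
    for qt, kt, vt in zip(q, k, v):
        s = decay * s + kt * vt
        out.append(qt * s)
    return out, s


def two_pass(q, k, v, decay, seg_len):
    bounds = [(a, min(a + seg_len, len(q))) for a in range(0, len(q), seg_len)]
    # Pass 1: every segment starts from a zero state; keep only the final states.
    finals = [run_segment(q[a:b], k[a:b], v[a:b], decay)[1] for a, b in bounds]
    # Pass 2: segment i starts from the pass-1 final state of segment i-1.
    inits = [0.0] + finals[:-1]
    out = []
    for (a, b), h0 in zip(bounds, inits):
        out.extend(run_segment(q[a:b], k[a:b], v[a:b], decay, h0)[0])
    return out


def carry_over(decay, seg_len):
    """Factor by which a state entering a segment is scaled by the segment's end."""
    return decay ** seg_len


random.seed(0)
T, SEG_LEN, DECAY = 512, 128, 0.8
q = [random.uniform(-1, 1) for _ in range(T)]
k = [random.uniform(-1, 1) for _ in range(T)]
v = [random.uniform(-1, 1) for _ in range(T)]

reference, _ = run_segment(q, k, v, DECAY)      # one sequential pass
approx = two_pass(q, k, v, DECAY, SEG_LEN)      # segmented two-pass result
max_err = max(abs(a - b) for a, b in zip(reference, approx))
```

With `DECAY = 0.8` and `SEG_LEN = 128` the carry-over factor is far below floating-point noise for typical tolerances, so `approx` matches `reference` closely. With a decay near 1 the same code produces visible errors from the third segment onward, which is the case the safety check is meant to exclude.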
Advantages
- Pure Python-level implementation — no CUDA kernel modifications required
- Leverages the existing `cu_seqlens` + `initial_state` interface
- Automatically decides whether to engage CP; zero overhead for batched scenarios
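A sketch of how the engage-or-skip decision and the sub-segment boundaries might be computed at the Python level. Everything here is an assumption for illustration: the function name, the chunk alignment of 64, the minimum segment length, and the rule of only splitting when B=1 are hypothetical, not taken from FlashKDA or FlashQLA. The returned list uses the same cumulative-offset layout as `cu_seqlens`, so each sub-segment could be presented to the kernel as a separate sequence.

```python
def plan_segments(seq_len, num_seqs, num_heads, num_sms,
                  chunk=64, min_seg_len=4096):
    """Return cu_seqlens-style boundaries, or None when CP should stay off."""
    # Batched input, or enough heads to fill the GPU already: skip CP entirely.
    if num_seqs > 1 or num_heads >= num_sms:
        return None
    # Number of segments needed so that segments * heads covers every SM.
    target = -(-num_sms // num_heads)            # ceiling division
    seg_len = -(-seq_len // target)
    seg_len = -(-seg_len // chunk) * chunk       # round up to a chunk multiple
    seg_len = max(seg_len, min_seg_len)
    if seg_len >= seq_len:                       # sequence too short to split
        return None
    return list(range(0, seq_len, seg_len)) + [seq_len]


if __name__ == "__main__":
    print(plan_segments(seq_len=65536, num_seqs=1, num_heads=16, num_sms=78))
```

For the 64k, H=16 case on a 78-SM device this yields five sub-segments, so the grid grows from 16 to 80 blocks. The decay check would still have to pass before the plan is used.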