[Feature] Intra-Card Context Parallelism to improve B=1 long-sequence performance #8

@yyq0210

Problem

FlashKDA's kernel2 (the recurrence kernel) launches with grid (N, H), i.e. one block per head per sequence. At B=1 (e.g., a single long-sequence prefill), SM utilization is critically low:

Config       kernel2 blocks   H20 SMs   Utilization
B=1, H=16    16               78        20.5%
B=1, H=32    32               78        41.0%
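The utilization column is simply the grid size divided by the SM count. A minimal sketch of that arithmetic, assuming at most one resident kernel2 block per SM (a simplification; real occupancy depends on the kernel's register and shared-memory use). `sm_utilization` is a hypothetical helper, not part of FlashKDA:

```python
def sm_utilization(num_seqs: int, num_heads: int, num_sms: int) -> float:
    """Fraction of SMs that receive a kernel2 block for a (num_seqs, num_heads) grid.

    Simplifying assumption: one resident block per SM, so utilization caps at 1.0.
    """
    blocks = num_seqs * num_heads  # grid (N, H): one block per (sequence, head)
    return min(blocks, num_sms) / num_sms


H20_SMS = 78  # SM count used in the table above

for heads in (16, 32):
    print(f"B=1, H={heads}: {sm_utilization(1, heads, H20_SMS):.1%}")
# prints "B=1, H=16: 20.5%" and "B=1, H=32: 41.0%"
```

Under the same assumption, any batch with N * H >= 78 already saturates the card, which is why the problem is specific to small B.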

As a result, FlashKDA is slower than the Triton baseline (fla_chunk_kda) at B=1:

SeqLen   H    FlashKDA   fla_chunk_kda   Ratio (fla_chunk_kda / FlashKDA time)
64k      16   9.51ms     6.63ms          0.70×
256k     16   37.86ms    26.48ms         0.70×
512k     16   75.67ms    52.97ms         0.70×

Proposed Solution

Adapt FlashQLA's Intra-Card Context Parallelism strategy, which splits long sequences into sub-segments and processes them in parallel:

  1. Sequence splitting: Automatically compute the optimal sub-segment length based on the SM count and head count
  2. Two-pass forward:
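    • Pass 1: Run all sub-segments in parallel with h0=0 and capture each sub-segment's final_state
    • Pass 2: Use each sub-segment's Pass-1 final_state as the initial_state of the next sub-segment, then re-run all sub-segments in parallel to produce the output (exact up to the contribution of earlier sub-segments that this chaining omits; see step 3)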
  3. Safety guarantee: Analytically estimate the gate decay from A_log, and enable CP only when the contribution of the initial state (i.e., of earlier sub-segments) is negligible
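The two passes, and the reason step 3 is needed, can be illustrated on a toy scalar recurrence h_t = a_t * h_{t-1} + x_t standing in for the per-head state update (the real KDA state is a matrix per head). This is a minimal sketch under stated assumptions, not FlashKDA or FlashQLA code: `scan`, `two_pass_cp`, `cp_is_safe` and the input `min_neg_log_decay` (a lower bound on -log a_t, which the proposal would derive from A_log by a method not specified here) are all hypothetical. Pass 2 seeds sub-segment i with sub-segment i-1's Pass-1 state, which leaves out everything before sub-segment i-1; that missing term is scaled by the product of a_t over sub-segment i-1, so the output is accurate only when that product is small.

```python
import math


def scan(a, x, h0=0.0):
    """Sequential toy recurrence h_t = a_t * h_{t-1} + x_t; returns (outputs, final_state)."""
    h, out = h0, []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return out, h


def two_pass_cp(a, x, cu_seqlens):
    """Two-pass context parallelism over the sub-segments delimited by cu_seqlens."""
    segs = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
    # Pass 1: every sub-segment starts from h0 = 0. Iterations are independent,
    # so on a GPU they would run in parallel. Only the final states are kept.
    finals = [scan(a[s:e], x[s:e])[1] for s, e in segs]
    # Pass 2: sub-segment i starts from sub-segment i-1's Pass-1 final state.
    # Also independent per segment; the first two segments come out exact.
    out = []
    for i, (s, e) in enumerate(segs):
        h0 = finals[i - 1] if i > 0 else 0.0
        out.extend(scan(a[s:e], x[s:e], h0)[0])
    return out


def cp_is_safe(min_neg_log_decay, seg_len, eps=1e-6):
    """True if prod(a_t) over seg_len steps is below eps, given log a_t <= -min_neg_log_decay."""
    return math.exp(-min_neg_log_decay * seg_len) < eps


# Weak decay (a_t = 0.5, 4-step segments): Pass 2 is visibly approximate from segment 2 on.
a, x = [0.5] * 12, [1.0] * 12
exact, _ = scan(a, x)
approx = two_pass_cp(a, x, [0, 4, 8, 12])
print(max(abs(p - q) for p, q in zip(exact, approx)))  # 0.05859375
print(cp_is_safe(-math.log(0.5), 4))  # False
```

With stronger decay (for example a_t = 0.01 and 8-step segments) the same comparison agrees to floating-point precision and `cp_is_safe` returns True. Longer sub-segments also make the check pass, at the cost of fewer parallel segments.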

Advantages

  • Pure Python-level implementation; no CUDA kernel modifications required
  • Leverages the existing cu_seqlens and initial_state interface
  • Automatically decides whether to engage CP; zero overhead for batched scenarios
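For the B=1 case, splitting and the engage decision need only the sequence length, head count and SM count, and the split maps directly onto a cu_seqlens offset list. A minimal sketch, assuming as many segments as fit in a single wave of one block per SM, a chunk size of 64 and a 4096-token floor on segment length (all illustrative choices, not values taken from FlashKDA or FlashQLA). `choose_segment_len`, `split_cu_seqlens` and `should_use_cp` are hypothetical helpers, and `decay_ok` stands for the result of the A_log-based safety check:

```python
import math


def choose_segment_len(seq_len, num_heads, num_sms, chunk=64, min_seg_len=4096):
    """Pick a sub-segment length so that num_segments * num_heads fits in one wave."""
    # Floor division avoids a partial second wave under the one-block-per-SM assumption.
    num_segments = max(1, num_sms // num_heads)
    seg_len = math.ceil(seq_len / num_segments)
    # Round up to a multiple of the chunk size so boundaries fall on chunk edges.
    seg_len = math.ceil(seg_len / chunk) * chunk
    # Never go below the floor, and never exceed the sequence itself (i.e. no split).
    return min(max(seg_len, min_seg_len), seq_len)


def split_cu_seqlens(seq_len, seg_len):
    """cu_seqlens-style offsets for one sequence cut into seg_len-sized pieces."""
    return list(range(0, seq_len, seg_len)) + [seq_len]


def should_use_cp(num_seqs, num_heads, num_sms, decay_ok):
    """Engage CP only when the (N, H) grid underfills the GPU and the decay check passed."""
    return num_seqs * num_heads < num_sms and decay_ok


seg_len = choose_segment_len(65536, num_heads=16, num_sms=78)
print(seg_len)                           # 16384
print(split_cu_seqlens(65536, seg_len))  # [0, 16384, 32768, 49152, 65536]
print(should_use_cp(8, 16, 78, True))    # False: the batched grid already fills 78 SMs
```

Here 4 segments × 16 heads gives 64 blocks instead of 16. Because `should_use_cp` returns False whenever N * H already covers the SMs, batched calls take the original path unchanged.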
