
Fix flex attention compilation failures in Megatron training #647

Merged
Kovbo merged 1 commit into main from fix-flex-attention-compilation on Apr 9, 2026

Conversation


@Kovbo commented Apr 9, 2026

Issue:
Megatron jobs were crashing during compilation of flex_attention, with two different errors depending on the setup:

  1. Shared-memory OOM during Triton autotune:
    No valid triton configs. OutOfMemoryError: out of resource: triton_flex_decoding
    Required: 312320 Hardware limit: 232448
  2. Symbolic-shape assertion inside Inductor's lowering:
    File ".../torch/_inductor/kernel/flex/flex_decoding.py", line 286, in create_flex_decoding_kernel
    V.graph.sizevars.check_leq(...)
    AssertionError

Root cause

Both errors come from the same wrong kernel. Inductor's flex_attention lowering has two Triton templates:

  • triton_flex_attention — the training kernel (Q and KV similar length).
  • triton_flex_decoding — the inference-decode kernel (short Q, long KV cache), with very large BLOCK_M configs and extra decode-specific symbolic-shape assertions.

With packed training sequences + shared-prefix block masks, query lengths are small and symbolic (s29, s64 in the logs), so Inductor's dispatch heuristic decided the shape "looks like decoding" and silently routed the call into create_flex_decoding_kernel. That kernel:

  • autotunes over configs up to BLOCK_M=1024, whose shared-memory footprint exceeds the per-SM hardware limit even on an H100 (the log shows 312320 bytes required vs. a 232448-byte limit), and
  • has decode-specific invariants it can't discharge against our symbolic packed-sequence shapes, which trip check_leq during lowering.

So we were never actually running the training kernel — Inductor was specializing us into the wrong one.

Fix

Pass kernel_options={"FORCE_USE_FLEX_ATTENTION": True} to the compiled call in FlexAttentionWrapper. This disables the flex_decoding dispatch so all calls go through the regular training kernel.
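A minimal sketch of the change (the wrapper shape and the `build_kernel_options` helper are hypothetical illustrations; only the `kernel_options={"FORCE_USE_FLEX_ATTENTION": True}` argument is what this PR actually adds):

```python
# Hypothetical sketch of the fix. The real FlexAttentionWrapper lives in the
# Megatron training code; the essential part is forwarding kernel_options to
# the compiled flex_attention call so Inductor never dispatches to the
# triton_flex_decoding template.

def build_kernel_options(extra=None):
    """Kernel options for the compiled flex_attention call.

    FORCE_USE_FLEX_ATTENTION tells Inductor's flex_attention lowering to
    always use the regular training kernel (triton_flex_attention) instead
    of routing short/symbolic query shapes into the decoding kernel.
    """
    opts = {"FORCE_USE_FLEX_ATTENTION": True}
    if extra:
        opts.update(extra)
    return opts

# Inside the wrapper's forward, the compiled call then looks roughly like:
#
#     out = self._compiled_flex_attention(
#         q, k, v,
#         block_mask=block_mask,
#         kernel_options=build_kernel_options(),
#     )
```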

@Kovbo requested a review from FurtherAI April 9, 2026 22:23
@Kovbo merged commit 4c03ad2 into main Apr 9, 2026
4 of 5 checks passed