From 3b382ddac999b6345ed7e4ede8b351f6ac47e207 Mon Sep 17 00:00:00 2001
From: Kovbo
Date: Thu, 9 Apr 2026 23:59:02 +0000
Subject: [PATCH 1/4] fix grad norm explosion

---
 src/art/megatron/flex_attention.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/art/megatron/flex_attention.py b/src/art/megatron/flex_attention.py
index 0e9c5255..be81b40f 100644
--- a/src/art/megatron/flex_attention.py
+++ b/src/art/megatron/flex_attention.py
@@ -29,9 +29,8 @@ class FlexAttentionWrapper(torch.nn.Module):
     """Compiled `flex_attention` wrapper with Torchtitan-style inductor options."""
 
     # Torchtitan inductor options for compiling flex attention.
-    _compile_options = {
+    _compile_options: dict[str, Any] = {
         "max_autotune": True,
-        "coordinate_descent_tuning": True,
         "triton.cudagraphs": False,
     }
     # Skip Inductor's flex_decoding specialization: it has triggered both
@@ -72,14 +71,11 @@ def forward(
         )
 
 
-_compiled_create_block_mask = torch.compile(create_block_mask)
-
-
 def create_shared_prefix_attention_state(
     group_ids: Tensor,
     parent_ids: Tensor,
 ) -> SharedPrefixAttentionState:
-    """Build a compiled block mask for ART shared-prefix packing.
+    """Build a block mask for ART shared-prefix packing.
 
     Initialized on the device of the group_ids tensor.
 
@@ -102,7 +98,15 @@ def _shared_prefix_mask(
         parent_prefix = parent_ids[batch_idx, query_idx] == group_ids[batch_idx, kv_idx]
         return (query_idx >= kv_idx) & (same_group | parent_prefix)
 
-    block_mask = _compiled_create_block_mask(
+    # NOTE: build the BlockMask eagerly, NOT through torch.compile.
+    # `_shared_prefix_mask` is a fresh closure on every call that captures
+    # different `group_ids` / `parent_ids` tensors from the enclosing scope.
+    # torch.compile's cache can reuse a compiled BlockMask built against stale
+    # closure captures, producing a mask whose block structure mismatches the
+    # forward — the compiled flex_attention backward then computes gradients
+    # over the wrong regions and produces astronomical-but-finite grads on a
+    # subset of micro-batches. Keep this call eager.
+    block_mask = create_block_mask(
         _shared_prefix_mask,
         group_ids.shape[0],
         None,

From df481791a30401ec0371bea5195b8a52c82ab091 Mon Sep 17 00:00:00 2001
From: Kovbo
Date: Fri, 10 Apr 2026 00:18:47 +0000
Subject: [PATCH 2/4] fix types

---
 src/art/megatron/flex_attention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/art/megatron/flex_attention.py b/src/art/megatron/flex_attention.py
index be81b40f..673e28bd 100644
--- a/src/art/megatron/flex_attention.py
+++ b/src/art/megatron/flex_attention.py
@@ -13,6 +13,7 @@
 from torch import Tensor
 from torch.nn.attention.flex_attention import (
     BlockMask,
+    FlexKernelOptions,
     create_block_mask,
     flex_attention,
 )
@@ -38,7 +39,7 @@ class FlexAttentionWrapper(torch.nn.Module):
     # failures (create_flex_decoding_kernel). The regular flex_attention
     # kernel autotunes against the actual hardware smem budget, so this
     # stays GPU-agnostic.
-    _kernel_options = {
+    _kernel_options: FlexKernelOptions = {
         "FORCE_USE_FLEX_ATTENTION": True,
     }
     _compiled_flex_attention: ClassVar = torch.compile(

From 003d4da6a9dcd0347fffb6219c3a287a2c55ec24 Mon Sep 17 00:00:00 2001
From: FurtherAI
Date: Mon, 13 Apr 2026 19:27:56 +0000
Subject: [PATCH 3/4] Use aot_eager for shared-prefix block masks

---
 src/art/megatron/flex_attention.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/src/art/megatron/flex_attention.py b/src/art/megatron/flex_attention.py
index 673e28bd..795f65a0 100644
--- a/src/art/megatron/flex_attention.py
+++ b/src/art/megatron/flex_attention.py
@@ -72,6 +72,11 @@ def forward(
         )
 
 
+# Sequence-length churn can break the Inductor backend here. Keep this
+# on aot_eager instead.
+_compiled_create_block_mask = torch.compile(create_block_mask, backend="aot_eager")
+
+
 def create_shared_prefix_attention_state(
     group_ids: Tensor,
     parent_ids: Tensor,
@@ -99,15 +104,7 @@ def _shared_prefix_mask(
         parent_prefix = parent_ids[batch_idx, query_idx] == group_ids[batch_idx, kv_idx]
         return (query_idx >= kv_idx) & (same_group | parent_prefix)
 
-    # NOTE: build the BlockMask eagerly, NOT through torch.compile.
-    # `_shared_prefix_mask` is a fresh closure on every call that captures
-    # different `group_ids` / `parent_ids` tensors from the enclosing scope.
-    # torch.compile's cache can reuse a compiled BlockMask built against stale
-    # closure captures, producing a mask whose block structure mismatches the
-    # forward — the compiled flex_attention backward then computes gradients
-    # over the wrong regions and produces astronomical-but-finite grads on a
-    # subset of micro-batches. Keep this call eager.
-    block_mask = create_block_mask(
+    block_mask = _compiled_create_block_mask(
         _shared_prefix_mask,
         group_ids.shape[0],
         None,

From 2e4bb3da25bf2da71f7890af47eba4be4fcdc2a4 Mon Sep 17 00:00:00 2001
From: Kovbo
Date: Tue, 14 Apr 2026 20:25:50 +0000
Subject: [PATCH 4/4] add coordinate_descent_tuning back

---
 src/art/megatron/flex_attention.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/art/megatron/flex_attention.py b/src/art/megatron/flex_attention.py
index da44b789..e5f583d3 100644
--- a/src/art/megatron/flex_attention.py
+++ b/src/art/megatron/flex_attention.py
@@ -36,6 +36,7 @@ class FlexAttentionWrapper(torch.nn.Module):
     # Torchtitan inductor options for compiling flex attention.
     _compile_options: ClassVar[CompileOptions] = {
         "max_autotune": True,
+        "coordinate_descent_tuning": True,
         "triton.cudagraphs": False,
     }
     # Skip Inductor's flex_decoding specialization: it has triggered both