diff --git a/src/art/megatron/flex_attention.py b/src/art/megatron/flex_attention.py index 2e4ab5a0..e5f583d3 100644 --- a/src/art/megatron/flex_attention.py +++ b/src/art/megatron/flex_attention.py @@ -77,14 +77,16 @@ def forward( ) -_compiled_create_block_mask = torch.compile(create_block_mask) +# Sequence-length churn can break the Inductor backend here. Keep this +# on aot_eager instead. +_compiled_create_block_mask = torch.compile(create_block_mask, backend="aot_eager") def create_shared_prefix_attention_state( group_ids: Tensor, parent_ids: Tensor, ) -> SharedPrefixAttentionState: - """Build a compiled block mask for ART shared-prefix packing. + """Build a block mask for ART shared-prefix packing. Initialized on the device of the group_ids tensor.