Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
aec2c90
add SP deny list instead of allow
kashif Mar 5, 2026
49e0310
hub-kernel lazy registration before validation and tests
kashif Mar 8, 2026
2f8e77c
Merge branch 'master' into sp_attn_deny
kashif Mar 8, 2026
ce69dc0
Merge branch 'master' into sp_attn_deny
tohtana Mar 10, 2026
952c3ae
Merge branch 'master' into sp_attn_deny
kashif Mar 13, 2026
6e3f2cb
Merge branch 'master' into sp_attn_deny
kashif Mar 14, 2026
874ec62
position_ids generation and flex_attention BlockMask
kashif Mar 14, 2026
7d0a136
refactor
kashif Mar 14, 2026
5868135
update comments
kashif Mar 14, 2026
b0e05f0
do not check for has_packed_samples
kashif Mar 16, 2026
89058fc
raise error instead of warning
kashif Mar 17, 2026
463cb30
cache BlockMask
kashif Mar 17, 2026
5cf053d
Merge branch 'master' into sp_attn_deny
loadams Mar 23, 2026
e250492
Merge branch 'master' into sp_attn_deny
sfc-gh-truwase Mar 27, 2026
8f53e6c
flex_attention: lazy imports to fix pickle issue and add test
kashif Mar 27, 2026
d4c8953
fix flex_attention BlockMask guard and add position_ids warning in fo…
kashif Mar 27, 2026
12bbb9c
flex_attention test: use local tiny model, fix _compile gradient issue
kashif Mar 27, 2026
4386253
don't compile create_block_mask inside forward()
kashif Mar 27, 2026
3ed1ff5
assert position_ids in forward() instead of warning
kashif Mar 28, 2026
6cb71e9
Merge branch 'master' into sp_attn_deny
sfc-gh-truwase Mar 28, 2026
b82dc59
Merge branch 'master' into sp_attn_deny
stas00 Mar 31, 2026
5470bed
Merge branch 'master' into sp_attn_deny
stas00 Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 116 additions & 32 deletions deepspeed/runtime/sequence_parallel/ulysses_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import deepspeed.comm as dist
import importlib.metadata
import math
import re
import torch
import torch.distributed.nn

Expand Down Expand Up @@ -121,6 +122,10 @@ def __init__(
self.skip_all_but_last_attention_debug_mode = False
self.rotating_layer_counter = 0 # used for dev work

self.core_attn_implementation = None # set by register_with_transformers
self._flex_block_mask_cached = None # cached BlockMask for flex_attention
self._flex_block_mask_cache_key = None # (batch_size, seq_len) for cache invalidation

self.local_q_head_count = attn_head_count // self.world_size

# if we have 4 kv heads and sp 8, we need to replicate kv heads 2x
Expand Down Expand Up @@ -272,23 +277,21 @@ def forward(
key = rearrange(key, "bs hc sl hs -> sl bs hc hs") # .contiguous()
value = rearrange(value, "bs hc sl hs -> sl bs hc hs") # .contiguous()

# core attn like FA2 expects an unsharded `position_ids` - without which packed samples
# will return loss=nan.
#
# XXX: need to figure out if we can do the same for SDPA - as it doesn't require this and
# wants an attention mask, so possibly doing this for FA2 only?
#
# Ideally we would pass the original unsharded position_ids - but we have no way to pass
# it here as HF Transformers drops unexpected keys in `batch` - so either we need to stash
# it somewhere in UlyssesSPDataLoaderAdapter and retrieve it here or we could gather it once
# per batch and stash it inside `module` arg - I already have a machinery to figure out
# which layer number is being called below in the skip_all_but_last_attention_debug_mode
# code where rotating_layer_counter is used - so we could calculate it on the first layer
# and re-use on the remaining layers
if "position_ids" in kwargs:
position_ids_list = [torch.empty_like(kwargs["position_ids"]) for _ in range(self.world_size)]
dist.all_gather(position_ids_list, kwargs["position_ids"], group=self.process_group)
kwargs["position_ids"] = torch.cat(position_ids_list, dim=1)
# All attention backends need unsharded position_ids after the all-to-all.
# FA2 uses them for packed-sequence detection (flash_varlen_fn), sdpa/flex_attention
# need them to be monotonically increasing so causal masking works correctly.
# UlyssesSPDataLoaderAdapter ensures position_ids are in the batch before sharding,
# so after gathering here they reconstruct to the correct global positions.
assert "position_ids" in kwargs, (
"Ulysses SP requires position_ids in every forward() call so that after all_gather "
"causal masking works correctly. Without them each rank generates local [0..chunk_len-1] "
"positions which, after gathering, look like packed sequences and break attention. "
"For non-packed sequences: position_ids = torch.arange(seq_len) per sample. "
"For packed sequences: position_ids must reset at document boundaries. "
"Ensure your data collator or UlyssesSPDataLoaderAdapter includes position_ids.")
position_ids_list = [torch.empty_like(kwargs["position_ids"]) for _ in range(self.world_size)]
dist.all_gather(position_ids_list, kwargs["position_ids"], group=self.process_group)
kwargs["position_ids"] = torch.cat(position_ids_list, dim=1)

# please don't remove the white-space vertical alignment in the error message
assert query.shape == self.required_query_shape, (
Expand All @@ -311,6 +314,41 @@ def forward(
if self.kv_replication_factor > 1:
module.num_key_value_groups = query_layer.size(-3) // key_layer.size(-3)

# For flex_attention: the wrapper preserved the BlockMask from the model, but it
# was built for the local shard's sequence length. Rebuild it for the full gathered
# sequence length after the all-to-all.
# XXX: currently hardcodes a causal mask_mod — models with sliding window or other
# non-standard patterns would need the mask_mod extracted from the original BlockMask.
if self.core_attn_implementation == "flex_attention":
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
Comment thread
kashif marked this conversation as resolved.
if isinstance(attention_mask, BlockMask):
seq_len = query_layer.shape[2]
batch_size = query_layer.shape[0]
cache_key = (batch_size, seq_len)

# Cache the BlockMask — create_block_mask is expensive and the mask is the
# same for all layers within a forward pass. Only rebuild when dimensions change.
if self._flex_block_mask_cache_key != cache_key:

def causal_mask(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx

# Don't compile create_block_mask here — it runs inside the model's
# forward pass where flex_attention already uses torch.compile, and
# nesting compiled contexts causes gradient explosion in the backward
# pass. The BlockMask is cached so creation cost is negligible.
self._flex_block_mask_cached = create_block_mask(
mask_mod=causal_mask,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=seq_len,
device=query_layer.device,
)
self._flex_block_mask_cache_key = cache_key

attention_mask = self._flex_block_mask_cached

if not self.skip_all_but_last_attention_debug_mode:
# expects: [bs hc_l sl hs]
context_layer, attn_weights = self.attn(module, query_layer, key_layer, value_layer, attention_mask, *args,
Expand Down Expand Up @@ -411,15 +449,34 @@ def register_with_transformers(
# if we don't have the model yet at this stage
hf_model_config = AutoConfig.from_pretrained(model_name_or_path)

supported_attn_implementation = ["flash_attention_2", "flash_attention_3", "sdpa"]
if core_attn_implementation not in supported_attn_implementation:
# notes on the excluded ones:
# - eager: The problem is that `eager` wants an attention_mask and it creates the wrong attention mask it seems if we don't provide one - it's possible that we could somehow solve this, but it's also unlikely someone will want to use the slow eager attention with sequence parallelism
# - flex_attention: haven't tried

model_attn_implementation = getattr(hf_model_config, "_attn_implementation", None)
if model_attn_implementation is not None and model_attn_implementation != core_attn_implementation:
raise ValueError(
f"core_attn_implementation='{core_attn_implementation}' does not match "
f"model config attn_implementation='{model_attn_implementation}'. "
"Set both to the same value so sequence-parallel wrapper can intercept the active attention path.")

# eager always materializes a 4D attention_mask (O(n²) memory) and cannot fall back
# to is_causal=True like sdpa — so it's incompatible with SP which discards masks.
unsupported_attn_implementation = ["eager", "paged|eager"]
if core_attn_implementation in unsupported_attn_implementation:
raise ValueError(
f"{core_attn_implementation} attn_implementation isn't currently supported by Ulysses sequence"
f" parallelism. Set core_attn_implementation arg to one of {supported_attn_implementation}.")
f" parallelism because it requires a 4D attention_mask (O(n²) memory)."
f" Use any flash attention variant, 'flex_attention', 'sdpa',"
f" or a hub-hosted kernel (e.g. 'kernels-community/flash-attn2').")

# Hub kernels (e.g. kernels-community/flash-attn2) are registered lazily in transformers.
# Ensure registration happens before validating against ALL_ATTENTION_FUNCTIONS.
is_hub_kernel_attn = (isinstance(core_attn_implementation, str) and re.search(
r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", core_attn_implementation) is not None)
if is_hub_kernel_attn:
try:
from transformers.modeling_flash_attention_utils import lazy_import_flash_attention
except ImportError as e:
raise ImportError("Hub kernel attention requires a transformers version exposing "
"`transformers.modeling_flash_attention_utils.lazy_import_flash_attention`.") from e
lazy_import_flash_attention(core_attn_implementation)

if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS:
raise ValueError(
Expand Down Expand Up @@ -448,6 +505,7 @@ def register_with_transformers(
global_seq_length=global_seq_length,
disable_in_eval=disable_in_eval,
)
uattn.core_attn_implementation = core_attn_implementation

def uattn_wrapper(
module: torch.nn.Module,
Expand All @@ -459,27 +517,41 @@ def uattn_wrapper(
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:

# We are relying on position_ids for SP to work so attention_mask has to be None
# the problem is that HF currently doesn't know anything about ALL_ATTENTION_FUNCTIONS["ulysses"] so it doesn't make a special case like for "flash_attention_2" and "sdpa" and it creates an attention mask on the fly and it breaks things.
attention_mask = None
# SP relies on position_ids (not attention_mask) for causal masking.
# HF doesn't know about the SP wrapper, so it creates an attention_mask for
# the local shard's sequence length — which is invalid after the SP all-to-all
# gathers the full sequence. A 4D mask at full sequence length would also be
# O(n²) memory. So we discard 4D tensor masks.
#
# Keep BlockMask (flex_attention) — it's a compressed sparse representation.
# It will be rebuilt for the full gathered sequence in forward().
_is_block_mask = False
if core_attn_implementation == "flex_attention":
from torch.nn.attention.flex_attention import BlockMask
_is_block_mask = isinstance(attention_mask, BlockMask)

if not _is_block_mask:
attention_mask = None

attn_output, attn_weights = uattn(
module,
query,
key,
value,
attention_mask,
# XXX: fixme
*args,
**kwargs,
)
return attn_output, attn_weights

# We don't do: ALL_ATTENTION_FUNCTIONS.register("ulysses", uattn_wrapper)
# The problem with this approach is that we are missing on all the special use cases in HF Transformers that do things like: if self.config._attn_implementation == "flash_attention_2": ...
# So instead we hack `ALL_ATTENTION_FUNCTIONS` to override all existing keys with our implementation, since it only gets used at the point of calling the attention and that's what we want, all other code branches relying on the original core `attn_implementation` will still be executed. This is what we called "Being John Malkovich"
for key in ALL_ATTENTION_FUNCTIONS.keys():
ALL_ATTENTION_FUNCTIONS[key] = uattn_wrapper
# The problem with that approach is that we'd miss all the special-case branches in
# HF Transformers that check `if self.config._attn_implementation == "flash_attention_2": ...`
# So instead we override the requested core implementation key in ALL_ATTENTION_FUNCTIONS
# with our wrapper. All other code paths relying on the original core attn_implementation
# will still be executed — we only intercept at the point of calling attention.
# This is what we called "Being John Malkovich".
ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper

return mpu

Expand Down Expand Up @@ -574,6 +646,18 @@ def refill(self):
micro_batches = defaultdict(dict)
# XXX: replace with more efficient all-to-all?

# position_ids must exist before sharding so that after all_gather in
# UlyssesSPAttentionHF.forward() they reconstruct to correct global positions.
# Without them, the Trainer generates local [0,...,chunk_len-1] per rank AFTER
# sharding, which after all_gather looks like packed sequences and breaks
# sdpa/flex_attention causal masking.
if "position_ids" not in batch:
Comment thread
kashif marked this conversation as resolved.
raise ValueError("Ulysses SP requires `position_ids` in every dataloader batch so that "
"each token retains its correct global position after sequence sharding. "
"For non-packed sequences: position_ids = torch.arange(seq_len) per sample. "
"For packed sequences: position_ids must reset at document boundaries. "
"Ensure your data collator includes position_ids in its output.")

# we have batches of variable seqlen so in order to do all_gather on batches - we need to know the exact length of each tensor on each rank
seqlen = torch.tensor(batch["input_ids"].shape[1], dtype=torch.int64, device=self.device)
seqlens = [torch.zeros(1, dtype=torch.int64, device=self.device) for _ in range(self.sp_world_size)]
Expand Down
Loading
Loading