Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def forward( # type: ignore
x: torch.Tensor,
token_specs: Sequence[TokenSpec],
attention_patterns: Sequence[AttentionPattern],
attention_mask: torch.Tensor | None = None,
key_padding_mask: torch.Tensor | None = None,
freqs: torch.Tensor | None = None,
) -> torch.Tensor:
"""Apply mixed attention with flexible token-name-based patterns.
Expand All @@ -70,10 +70,12 @@ def forward( # type: ignore
token groups (queries) attend to which other token groups (keys/values).
The provided patterns must be exhaustive and non-overlapping. This means every
token group defined in `token_specs` must be a query in exactly one pattern.
attention_mask: Optional attention mask (not currently supported)
key_padding_mask: Optional boolean mask of shape ``(batch_size, n_tokens)``.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it makes more sense to encode the # of padding tokens as part of the TokenSpec instead of adding a separate input. Have you thought about that?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally I think this makes sense.

However, the TokenSpec currently treats every element in the batch the same, i.e. the length of each sub-sequence is assumed to be identical across the batch. What we need is masking with a different amount of padding for each element in the batch.

I could see 3 approaches:
A) Extend Mixed Attention: Flatten the variable-size batch dimension into the sequence dimension, re-implement the mixed attention using variable-length flash attention, and extend TokenSpec to per-example sequence length specifications.
B) Put mask in TokenSpec nevertheless, mixing mask specification per element in batch and size specification shared among batch.
C) Keep per-batch mask outside of TokenSpec

IMO A) increases complexity unnecessarily at the moment, although it might be worth looking into variable-length flash attention in more detail separately from this PR.
I think mixing as in B) is confusing and error prone.
Therefore I tend towards C) keeping as is.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true, I didn't think of that. Thanks for explaining!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true, I didn't think of that. Thanks for explaining!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true, I didn't think of that. Thanks for explaining!

``True`` indicates a real token; ``False`` indicates a padding token that should not be attended to.
The mask is sliced per attention pattern to cover only the key/value tokens of that pattern.
freqs: RoPE frequencies for positional encoding
"""
self._validate_inputs(x, token_specs, attention_patterns, attention_mask, freqs)
self._validate_inputs(x, token_specs, attention_patterns, key_padding_mask, freqs)

# Initial Projection
q, k, v = einops.rearrange(
Expand All @@ -91,7 +93,15 @@ def forward( # type: ignore
}
spec_size_map = {spec.name: spec.size for spec in token_specs}

token_outputs = self._process_pattern_batched(attention_patterns, q, k, v, token_slices, spec_size_map) # type: ignore[arg-type]
token_outputs = self._process_pattern_batched(
attention_patterns=attention_patterns,
q=q,
k=k,
v=v,
token_slices=token_slices, # type: ignore[arg-type]
spec_size_map=spec_size_map, # type: ignore[arg-type]
key_padding_mask=key_padding_mask,
)

# Final assembly and output projection
output_parts = [token_outputs[spec.name] for spec in token_specs]
Expand All @@ -107,6 +117,7 @@ def _process_pattern_batched(
v: torch.Tensor,
token_slices: dict[str, slice],
spec_size_map: dict[str, int],
key_padding_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Efficient mixed attention implementation that batches compatible (same shape) attention patterns."""
# Group compatible patterns
Expand All @@ -126,8 +137,26 @@ def _process_pattern_batched(
q_batched = torch.cat(qs, dim=0)
k_batched = torch.cat(ks, dim=0)
v_batched = torch.cat(vs, dim=0)

attn_mask_batched: torch.Tensor | None = None
if key_padding_mask is not None:
# For each pattern, slice the mask to its KV tokens (mirroring how k/v are assembled).
# Shape (batch, 1, 1, kv_len): the 1s let PyTorch broadcast the same key mask
# across all heads and all query positions (every query sees the same valid key set).
per_pattern_masks = []
for patt in group:
kv_bool = torch.cat(
[key_padding_mask[:, token_slices[name]] for name in patt.key_value_tokens], dim=1
)
per_pattern_masks.append(kv_bool[:, None, None, :]) # (batch, 1, 1, kv_len)
attn_mask_batched = torch.cat(per_pattern_masks, dim=0)

output_batched = F.scaled_dot_product_attention(
q_batched, k_batched, v_batched, dropout_p=self.dropout if self.training else 0.0
q_batched,
k_batched,
v_batched,
attn_mask=attn_mask_batched,
dropout_p=self.dropout if self.training else 0.0,
)
# Undo the batch concatenation
output_chunks = torch.chunk(output_batched, chunks=len(group), dim=0)
Expand All @@ -143,15 +172,24 @@ def _validate_inputs(
x: torch.Tensor,
token_specs: Sequence[TokenSpec],
attention_patterns: Sequence[AttentionPattern],
attention_mask: torch.Tensor | None,
key_padding_mask: torch.Tensor | None,
freqs: torch.Tensor | None,
) -> None:
"""Validate input consistency."""
if not self.use_rope == (freqs is not None):
raise ValueError(f"RoPE usage mismatch: self.use_rope = {self.use_rope}, but freqs is {freqs is not None}")

if attention_mask is not None:
raise NotImplementedError("Attention masks are not supported in this implementation.")
if key_padding_mask is not None:
if key_padding_mask.dtype != torch.bool:
raise ValueError(f"key_padding_mask must be a bool tensor, got {key_padding_mask.dtype}.")
if key_padding_mask.ndim != 2:
raise ValueError(
f"key_padding_mask must be 2D with shape (batch_size, n_tokens), got shape {tuple(key_padding_mask.shape)}."
)
if key_padding_mask.shape[1] != x.shape[1]:
raise ValueError(
f"key_padding_mask n_tokens dim ({key_padding_mask.shape[1]}) must match x sequence length ({x.shape[1]})."
)

expected_size = sum(spec.size for spec in token_specs)
if expected_size != x.shape[1]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ def forward(
self,
x: torch.Tensor,
token_specs: Sequence[TokenSpec],
key_padding_mask: torch.Tensor | None = None,
freqs: torch.Tensor | None = None,
) -> torch.Tensor:
"""Apply attention using the patterns defined by the subclass."""
self._validate(token_specs)
patterns = self._create_attention_patterns(token_specs)
return self.mixed_attention(x, token_specs, patterns, freqs=freqs) # type: ignore[no-any-return]
return self.mixed_attention(x, token_specs, patterns, key_padding_mask=key_padding_mask, freqs=freqs) # type: ignore[no-any-return]

@abstractmethod
def _create_attention_patterns(self, token_specs: Sequence[TokenSpec]) -> Sequence[AttentionPattern]:
Expand Down
86 changes: 86 additions & 0 deletions tests/unit/noether/modeling/modules/attention/test_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,89 @@ def test_rope_integration(self):

with patch("noether.modeling.modules.attention.anchor_attention.mixed.rope") as mock_rope:
mock_rope.side_effect = lambda t, freqs: t # Identity mock

def test_attn_mask_validation_errors(self, module):
    """Each malformed ``key_padding_mask`` must be rejected with a descriptive ValueError."""
    inputs = torch.randn(1, 10, 64)
    specs = [TokenSpec(name="surface_anchors", size=10)]
    pattern_list = [AttentionPattern(query_tokens=["surface_anchors"], key_value_tokens=["surface_anchors"])]

    # (bad mask, expected error-message fragment) pairs covering each validation branch.
    bad_masks = [
        (torch.ones(1, 10), "bool tensor"),  # wrong dtype: float instead of bool
        (torch.ones(1, 10, 1, dtype=torch.bool), "2D"),  # wrong rank: 3D instead of 2D
        (torch.ones(1, 5, dtype=torch.bool), "n_tokens dim"),  # wrong length: 5 != 10
    ]
    for mask, expected in bad_masks:
        with pytest.raises(ValueError, match=expected):
            module(inputs, specs, pattern_list, key_padding_mask=mask)

def test_attn_mask_matches_unpadded_run(self, module):
    """Masked batched run must match individual unpadded runs for each batch element.

    Three items with different numbers of real anchors are padded to size 3 and run
    together with a mask. Outputs at real positions must equal running each item alone
    without any mask or padding.
    """
    torch.manual_seed(42)
    dim = 64
    n_queries = 4
    max_anchors = 3

    # One (anchors, queries) pair per batch item; item i has (3 - i) real anchors.
    # Draw order matches the per-item order so the random stream is reproducible.
    items = [(torch.randn(1, max_anchors - i, dim), torch.randn(1, n_queries, dim)) for i in range(3)]

    padded_specs = [
        TokenSpec(name="surface_anchors", size=max_anchors),
        TokenSpec(name="surface_queries", size=n_queries),
    ]
    patterns = [
        AttentionPattern(
            query_tokens=["surface_anchors", "surface_queries"],
            key_value_tokens=["surface_anchors"],
        )
    ]

    # Pad every item's anchors up to max_anchors with zeros and stack into one batch.
    rows = []
    for anchors, queries in items:
        n_pad = max_anchors - anchors.shape[1]
        rows.append(torch.cat([anchors, torch.zeros(1, n_pad, dim), queries], dim=1))
    x_batched = torch.cat(rows, dim=0)

    # True = real token, False = padding. Queries are always real.
    mask_rows = []
    for anchors, _ in items:
        n_real = anchors.shape[1]
        mask_rows.append([True] * n_real + [False] * (max_anchors - n_real) + [True] * n_queries)
    key_mask = torch.tensor(mask_rows)

    out_batched = module(x_batched, padded_specs, patterns, key_padding_mask=key_mask)

    # Reference: run each item individually, unpadded and unmasked.
    reference = []
    for anchors, queries in items:
        specs = [
            TokenSpec(name="surface_anchors", size=anchors.shape[1]),
            TokenSpec(name="surface_queries", size=n_queries),
        ]
        reference.append(module(torch.cat([anchors, queries], dim=1), specs, patterns))

    # Item 0 has no padding: every position must match.
    assert torch.allclose(out_batched[0], reference[0][0], atol=1e-5)

    # Items 1 and 2: compare only real positions. In the batched layout queries start
    # at index max_anchors; in the unpadded layout they start right after the real anchors.
    for i in (1, 2):
        n_real = items[i][0].shape[1]
        real_anchor_idx = list(range(n_real))
        assert torch.allclose(out_batched[i, real_anchor_idx], reference[i][0, real_anchor_idx], atol=1e-5)
        batched_query_idx = list(range(max_anchors, max_anchors + n_queries))
        unpadded_query_idx = list(range(n_real, n_real + n_queries))
        assert torch.allclose(out_batched[i, batched_query_idx], reference[i][0, unpadded_query_idx], atol=1e-5)
Loading