9 changes: 9 additions & 0 deletions gemma/gm/nn/_modules.py
@@ -268,6 +268,15 @@ def __call__(
# [batch_size, seq_len, cache_size]
attn_mask *= sliding_mask

# Ensure attention mask cache dimension matches logits cache dimension.
# The attn_mask may have been created with a larger cache_length during
# prefill, but the actual cache (and thus logits) may have been sliced.
# Shape: logits is [batch_size, seq_len, num_heads, cache_size]
actual_cache_size = logits.shape[-1]
if attn_mask.shape[-1] != actual_cache_size:
# Slice the attention mask to match the actual cache size being used.
attn_mask = attn_mask[..., :actual_cache_size]

# [batch_size, seq_len, num_heads, cache_size]
padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)

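To illustrate what the added slice does, here is a minimal standalone sketch (not part of this PR; the shapes are made up and kept small) showing that once the mask's cache dimension is trimmed to match the logits, `jnp.where` broadcasts cleanly:

import jax
from jax import numpy as jnp

batch_size, seq_len, num_heads = 1, 4, 2
mask_cache_size, logits_cache_size = 8, 6  # mask was built for a larger cache

attn_mask = jnp.ones((batch_size, seq_len, mask_cache_size), dtype=bool)
logits = jax.random.normal(
    jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, logits_cache_size)
)
K_MASK = jnp.finfo(logits.dtype).min

# Without the slice, (1, 4, 1, 8) cannot broadcast against (1, 4, 2, 6).
attn_mask = attn_mask[..., :logits.shape[-1]]  # now (1, 4, 6)
padded_logits = jnp.where(jnp.expand_dims(attn_mask, -2), logits, K_MASK)
assert padded_logits.shape == logits.shape  # (1, 4, 2, 6)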
101 changes: 101 additions & 0 deletions gemma/gm/nn/_modules_mask_test.py
@@ -0,0 +1,101 @@
# Copyright 2025 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test for attention mask broadcasting bug (Issue #407)."""

import jax
from jax import numpy as jnp
import pytest


class TestAttentionMaskBroadcasting:
"""Tests for attention mask shape compatibility."""

def test_attention_mask_broadcasting_shape_mismatch(self):
"""Reproduce ValueError from Issue #407: Incompatible shapes for broadcasting.

The bug occurs when attn_mask and logits have mismatched cache dimensions:
- attn_mask shape after expand_dims: (B, L, 1, cache_size_1)
- logits shape: (B, L, num_heads, cache_size_2)

where cache_size_1 != cache_size_2, causing broadcasting failure.
"""
# Simulate the scenario from the error:
# attn_mask originally: (1, 1447, 5234) -> after expand_dims: (1, 1447, 1, 5234)
# logits: (1, 1447, 8, 4096)
batch_size, seq_len, num_heads = 1, 1447, 8
mask_cache_size = 5234
logits_cache_size = 4096

# Original 3D attention mask (before expand_dims)
attn_mask = jnp.ones((batch_size, seq_len, mask_cache_size), dtype=bool)

# Logits from attention computation
logits = jnp.zeros((batch_size, seq_len, num_heads, logits_cache_size))

K_MASK = jnp.finfo(logits.dtype).min

# This is the line from _modules.py:277 that causes the error
with pytest.raises(ValueError, match="Incompatible shapes for broadcasting"):
_ = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)

def test_attention_mask_broadcasting_with_auto_slice_fix(self):
"""Test that automatic slicing fixes the shape mismatch (Issue #407 fix)."""
batch_size, seq_len, num_heads = 1, 1447, 8
mask_cache_size = 5234
logits_cache_size = 4096

# Original 3D attention mask (larger than needed)
attn_mask = jnp.ones((batch_size, seq_len, mask_cache_size), dtype=bool)

# Logits from attention computation
logits = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, logits_cache_size)
)

K_MASK = jnp.finfo(logits.dtype).min

# Apply the fix: slice attention mask to match logits cache size
actual_cache_size = logits.shape[-1]
if attn_mask.shape[-1] != actual_cache_size:
attn_mask = attn_mask[..., :actual_cache_size]

# Now this should work without errors
padded_logits = jnp.where(
(jnp.expand_dims(attn_mask, -2)), logits, K_MASK
)

assert padded_logits.shape == logits.shape
assert attn_mask.shape[-1] == logits.shape[-1]

def test_attention_mask_broadcasting_correct_shapes(self):
"""Test that broadcasting works when cache sizes match."""
batch_size, seq_len, num_heads, cache_size = 1, 100, 8, 512

# Properly shaped mask and logits
attn_mask = jnp.ones((batch_size, seq_len, cache_size), dtype=bool)
logits = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, cache_size)
)

K_MASK = jnp.finfo(logits.dtype).min

# This should work without errors
padded_logits = jnp.where(
(jnp.expand_dims(attn_mask, -2)), logits, K_MASK
)

assert padded_logits.shape == logits.shape
# Verify that where mask is True, we get logits; where False, we get K_MASK
assert jnp.all(padded_logits[0, 0, 0, :] == logits[0, 0, 0, :])