From 4e95d69e1097ae9c0a4ed5d1bb022429496a75a8 Mon Sep 17 00:00:00 2001
From: erfanzar
Date: Wed, 26 Nov 2025 18:11:02 +0000
Subject: [PATCH] fix(rpa-v3): add sliding window mask to h64 kernel and
 attention_sink to h128

Fixes #1169

This PR fixes two issues in the ragged paged attention v3 kernels:

1. **kernel_hd64.py (h64)**: Added the missing sliding window mask in the
   kernel. The original code only skipped fetching KV blocks outside the
   window but didn't apply token-level masking within partially-covered
   blocks.
2. **kernel.py (h128)**: Added attention_sink support following the same
   pattern as the h64 kernel. Attention sinks allow the model to "dump"
   attention to a virtual token that doesn't contribute to the output.
   Uses LEFT concatenation semantics where sink logits are prepended
   before softmax, then removed after (see the sketch below).
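
A minimal, self-contained sketch of the LEFT concatenation semantics
(illustrative only: toy shapes and a zero sink logit; it mirrors the ref
impl change in kernel.py below, not the Pallas kernel):

```python
import jax
import jax.numpy as jnp

num_q_heads, q_len, kv_len = 4, 3, 8
attn = jax.random.normal(jax.random.PRNGKey(0), (num_q_heads, q_len, kv_len))
attention_sink = jnp.zeros((num_q_heads,), dtype=jnp.float32)  # one logit per q head

# Prepend the per-head sink logit on the LEFT of the key axis...
sink = jnp.repeat(attention_sink.reshape(num_q_heads, 1, 1), q_len, axis=1)
probs = jax.nn.softmax(jnp.concatenate([sink, attn], axis=2), axis=-1)
# ...then drop it: the sink absorbs probability mass but never produces output.
probs = probs[..., 1:]
# Each row of probs now sums to less than 1; the missing mass went to the sink.
```
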
Changes:
- kernel_hd64.py: Added an `if sliding_window is not None` mask in
  flash_attention
- kernel.py: Added an attention_sink parameter to all functions (ref impl,
  kernel, prepare_inputs, validation, main function)
- kernel.py: Initialize m_prev with the sink values and l_prev with 1.0 for
  proper online softmax tracking across blocks when using attention_sink
---
 .../ragged_paged_attention/v3/kernel.py       | 58 +++++++++++++++++--
 .../ragged_paged_attention/v3/kernel_hd64.py  |  3 +
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py b/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
index f10e7962e..981c2eb26 100644
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
@@ -35,6 +35,7 @@ def ref_ragged_paged_attention(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
+    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
@@ -56,6 +57,7 @@ def ref_ragged_paged_attention(
         page_indices,
         cu_q_lens,
         distribution,
+        attention_sink,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
@@ -143,7 +145,18 @@ def ref_ragged_paged_attention(
         if soft_cap is not None:
             attn = soft_cap * jnp.tanh(attn / soft_cap)
         attn += jnp.where(mask, mask_value, 0.0)
-        attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+
+        if attention_sink is not None:
+            reshaped_attention_sink = attention_sink.reshape(
+                actual_num_q_heads, 1, 1)
+            reshaped_attention_sink = jnp.repeat(reshaped_attention_sink,
+                                                 q_len,
+                                                 axis=1)
+            attn = jnp.concat([reshaped_attention_sink, attn], axis=2)
+            attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
+            attn = attn[..., 1:]
+        else:
+            attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
         out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
 
         if v_scale is not None:
@@ -236,6 +249,7 @@ def _ragged_paged_attention_kernel(
     q_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
     kv_cache_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
+    attention_sink_ref,  # [actual_num_kv_heads, num_q_heads_per_kv_head, head_dim]
     # Output
     o_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     updated_kv_cache_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -371,7 +385,15 @@ def load_with_init(ref, init_val):
             s = soft_cap * jnp.tanh(s / soft_cap)
         s += jnp.where(mask, mask_value, 0.0)
         s_rowmax = jnp.max(s, axis=1, keepdims=True)
-        m_prev = load_with_init(head_m_ref, -jnp.inf)
+
+        if attention_sink_ref is not None:
+            sinks = attention_sink_ref[kv_head_idx]
+            actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
+            m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
+            m_prev = jnp.where(bkv_idx == 0, m_prev_init, head_m_ref[...])
+        else:
+            m_prev = load_with_init(head_m_ref, -jnp.inf)
+
         m_curr = jnp.maximum(m_prev, s_rowmax)
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
@@ -382,7 +404,7 @@ def load_with_init(ref, init_val):
 
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
-        l_prev = load_with_init(head_l_ref, 0.0)
+        l_prev = load_with_init(head_l_ref, 1.0 if attention_sink_ref is not None else 0.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
         o_prev = load_with_init(head_acc_ref, 0.0)
@@ -960,6 +982,7 @@ def prepare_inputs(
     Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
     v: jax.
     Array,  # [max_num_tokens, actual_num_kv_heads, actual_head_dim],
+    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads],
 ):
     max_num_tokens, actual_num_q_heads, actual_head_dim = q.shape
     actual_num_kv_heads = k.shape[1]
@@ -995,7 +1018,13 @@ def prepare_inputs(
          .swapaxes(0, 1))
     # TODO(kyuyeunk, chengjiyao): Add kv quantization here.
     kv = merge_kv(k, v)
-    return q, kv
+
+    if attention_sink is not None:
+        attention_sink = attention_sink.reshape(
+            (-1, num_q_heads_per_kv_head, 1))
+        attention_sink = jnp.repeat(attention_sink, head_dim, -1)
+
+    return q, kv, attention_sink
 
 
 def prepare_outputs(
@@ -1033,6 +1062,7 @@ def dynamic_validate_inputs(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
+    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
@@ -1060,6 +1090,7 @@ def dynamic_validate_inputs(
         page_indices,
         cu_q_lens,
         distribution,
+        attention_sink,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
@@ -1123,6 +1154,7 @@ def static_validate_inputs(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
+    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
@@ -1155,6 +1187,15 @@ def static_validate_inputs(
         raise ValueError(
             f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}"
         )
+    if attention_sink is not None:
+        if attention_sink.shape[0] != q.shape[1]:
+            raise ValueError(
+                f"Expected {attention_sink.shape[0]=} to be equal to"
+                f" {q.shape[1]=} (num_q_heads).")
+        if attention_sink.dtype != jnp.float32:
+            raise ValueError(
+                f"Expected {attention_sink.dtype=} to be equal to {jnp.float32=}."
+            )
 
     actual_head_dim = q.shape[2]
     actual_num_q_heads = q.shape[1]
@@ -1278,6 +1319,7 @@ def ragged_paged_attention(
     page_indices: jax.Array,  # i32[max_num_seqs * pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     distribution: jax.Array,  # i32[3]
+    attention_sink: jax.Array | None = None,  # f32[actual_num_q_heads]
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
@@ -1338,6 +1380,7 @@ def ragged_paged_attention(
         page_indices,
         cu_q_lens,
         distribution,
+        attention_sink,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
@@ -1356,7 +1399,7 @@ def ragged_paged_attention(
     actual_num_kv_heads = k.shape[1]
     actual_num_q_heads_per_kv_head = actual_num_q_heads // actual_num_kv_heads
 
-    q, kv = prepare_inputs(q, k, v)
+    q, kv, attention_sink = prepare_inputs(q, k, v, attention_sink)
     (
         _,
         max_num_tokens,
@@ -1395,6 +1438,8 @@ def ragged_paged_attention(
         pl.BlockSpec(memory_space=pltpu.HBM),
         pl.BlockSpec(memory_space=pltpu.HBM),
         pl.BlockSpec(memory_space=pltpu.HBM),
+        None if attention_sink is None else pl.BlockSpec(
+            memory_space=pltpu.VMEM)
     ]
 
     out_specs = [
@@ -1493,7 +1538,8 @@ def ragged_paged_attention(
             name=scope_name,
         ))
 
-    output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
+    output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
+                                      attention_sink)
     return (
         prepare_outputs(output, actual_num_q_heads_per_kv_head,
                         actual_head_dim),
diff --git a/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py b/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
index dec422143..3032b677e 100644
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -391,7 +391,10 @@ def load_with_init(ref, init_val):
                   lax.broadcasted_iota(jnp.int32, s.shape, 0) //
                   num_q_heads_per_kv_head)
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
+
         mask = q_span < k_span
+        if sliding_window is not None:
+            mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
         if soft_cap is not None:
             s = soft_cap * jnp.tanh(s / soft_cap)
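
Note for reviewers (illustration only, not part of the patch): the new sliding
window condition masks keys that fall outside a trailing window of
`sliding_window` positions, in addition to the causal mask. A standalone
sketch with toy sizes, ignoring the kernel's block and head-interleaving
layout:

```python
import jax.numpy as jnp
from jax import lax

seq_len, sliding_window = 6, 3
shape = (seq_len, seq_len)
q_span = lax.broadcasted_iota(jnp.int32, shape, 0)  # query position per row
k_span = lax.broadcasted_iota(jnp.int32, shape, 1)  # key position per column

mask = q_span < k_span                                          # causal: no future keys
mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)  # keys older than the window
# True entries receive mask_value before the softmax, so each query attends
# only to itself and the (sliding_window - 1) keys immediately before it.
print(mask.astype(jnp.int32))
```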