diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index fa630961..c8dd52c0 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16' # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False +vae_spatial: -1 # default to total_device * 2 // (dp) # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" @@ -60,7 +61,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +attention: 'tokamax_flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring flash_min_seq_length: 0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. @@ -81,9 +82,7 @@ flash_block_sizes: { "block_q_dkv" : 512, "block_kv_dkv" : 512, "block_kv_dkv_compute" : 512, - "block_q_dq" : 512, - "block_kv_dq" : 512, - "use_fused_bwd_kernel": False, + "use_fused_bwd_kernel": True } # Use on v6e # flash_block_sizes: { diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 6d06218c..1b19a020 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -44,6 +44,7 @@ activations_dtype: 'bfloat16' # Replicates vae across devices instead of using the model's sharding annotations for sharding. 
replicate_vae: False +vae_spatial: -1 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision # Options are "DEFAULT", "HIGH", "HIGHEST" diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index d9d3af7c..83b0cf12 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -85,9 +85,20 @@ def get_git_commit_hash(): jax.config.update("jax_use_shardy_partitioner", True) -def call_pipeline(config, pipeline, prompt, negative_prompt): +def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=None): + """Call the pipeline with optional num_inference_steps override. + + Args: + config: The configuration object. + pipeline: The pipeline to call. + prompt: The prompt(s) to use. + negative_prompt: The negative prompt(s) to use. + num_inference_steps: Optional override for number of inference steps. + If None, uses config.num_inference_steps. + """ model_key = config.model_name model_type = config.model_type + steps = num_inference_steps if num_inference_steps is not None else config.num_inference_steps if model_type == "I2V": image = load_image(config.image_url) if model_key == WAN2_1: @@ -98,7 +109,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=steps, guidance_scale=config.guidance_scale, ) elif model_key == WAN2_2: @@ -109,7 +120,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=steps, guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, @@ -124,7 +135,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, 
width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=steps, guidance_scale=config.guidance_scale, use_cfg_cache=config.use_cfg_cache, ) @@ -135,7 +146,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=steps, guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, @@ -248,6 +259,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log(f"hardware: {jax.devices()[0].platform}") max_logging.log(f"number of devices: {jax.device_count()}") max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}") + max_logging.log(f"vae_spatial: {config.vae_spatial}") max_logging.log("============================================================") compile_time = time.perf_counter() - s0 @@ -276,15 +288,49 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log(f"generation time per video: {generation_time_per_video}") else: max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") - s0 = time.perf_counter() + if config.enable_profiler: + skip_steps = getattr(config, 'skip_first_n_steps_for_profiler', 0) + profiler_steps = getattr(config, 'profiler_steps', config.num_inference_steps) + profile_all = profiler_steps == -1 + steps_for_profile = config.num_inference_steps if profile_all else profiler_steps + + if profile_all: + max_logging.log(f"Profiler: profiling all {steps_for_profile} inference steps (profiler_steps=-1)") + else: + max_logging.log(f"Profiler: profiling {steps_for_profile} steps out of {config.num_inference_steps} total") + max_logging.log(f"Profiler: skip_first_n_steps={skip_steps}") + + def block_if_jax(x): + """Block until ready if x is a JAX array, 
otherwise no-op.""" + if hasattr(x, 'block_until_ready'): + x.block_until_ready() + return x + + for i in range(skip_steps): + max_logging.log(f"Profiler warmup iteration {i + 1}/{skip_steps}") + warmup_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=steps_for_profile) + # Block until warmup completes + jax.tree_util.tree_map(block_if_jax, warmup_videos) + + # Warm up GCS connection by flushing writer before starting profiler + if writer and jax.process_index() == 0: + max_logging.log("Flushing writer to warm up GCS connection before profiler...") + writer.flush() + + s0 = time.perf_counter() max_utils.activate_profiler(config) - videos = call_pipeline(config, pipeline, prompt, negative_prompt) + max_logging.log(f"Profiler: starting profiled run with {steps_for_profile} steps") + profiled_videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=steps_for_profile) + # Wait for all computation to finish before stopping profiler + jax.tree_util.tree_map(block_if_jax, profiled_videos) max_utils.deactivate_profiler(config) + max_utils.upload_profiler_traces(config) generation_time_with_profiler = time.perf_counter() - s0 max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") if writer and jax.process_index() == 0: writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) + max_logging.log("Profiler: completed (video not saved)") return saved_video_path diff --git a/src/maxdiffusion/kernels/__init__.py b/src/maxdiffusion/kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxdiffusion/kernels/splash_attention/base.py b/src/maxdiffusion/kernels/splash_attention/base.py new file mode 100644 index 00000000..4cd45090 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/base.py @@ -0,0 +1,285 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base functionality for Sparse Flash Attention.""" + +import functools +from typing import Final, NamedTuple, TypeAlias +import jax +import jax.numpy as jnp +import numpy as np +from . import splash_attention_mask_info as mask_info_lib + + +MaskInfo = mask_info_lib.MaskInfo + + +DEFAULT_MASK_VALUE: Final[float] = -0.7 * float( + np.finfo(np.dtype("float32")).max +) + + +class SegmentIds(NamedTuple): + """SegmentIds for Q and KV sequences. + + SegmentIds are a mechanism to ensure that there is no cross-attention between + segments (fraction of a sequence) that have been concatenated together into a + sequence. Each array is a list of ids (integers). Only tokens with the same + id are allowed to attend to each other. + + The static mask (e.g. causal) is "and-ed" with the segment id mask to form + the actual attention mask. It is important that the latter does not have any + all-zero rows (along dimension kv). Otherwise it would result in a invalid + softmax (the denominator would be 0). + This condition holds for causal self-attention because in this case segment + ids form a block diagonal matrix so at least one element in each row is set. + It is easy to break this condition with non-self-attention configurations. 
+ Attributes: + q: segment ids along the Q sequence + kv: segment ids along the KV sequence + """ + + q: jax.Array | jax.sharding.PartitionSpec # [q_seq_len] + kv: jax.Array | jax.sharding.PartitionSpec # [kv_seq_len] + + +# Return type of SplashAttention function that implements the custom vjp rule. +SplashCustomReturnType: TypeAlias = ( + jax.Array | tuple[jax.Array, dict[str, jax.Array]] +) + +SplashResidualsType = tuple[ + jax.Array, # q + jax.Array, # k + jax.Array, # v + SegmentIds | None, # segment_ids + jax.Array | None, # sinks + jax.Array, # out + jax.Array, # logsumexp + MaskInfo | None, # dkv_mask_info +] + + +def _attention_reference_impl( + q: jax.Array, + k: jax.Array, + v: jax.Array, + mask: jax.Array, + segment_ids: SegmentIds | None, + sinks: jax.Array | None, + mask_value: float, + save_residuals: bool, + attn_logits_soft_cap: float | None, +) -> SplashCustomReturnType: + logits = jnp.einsum("sd,td->st", q.astype(jnp.float32), k.astype(jnp.float32)) + + if segment_ids is not None: + mask = jnp.logical_and( + mask, segment_ids.q[:, None] == segment_ids.kv[None, :] + ) + + if attn_logits_soft_cap is not None: + logits = jnp.tanh(logits / attn_logits_soft_cap) + logits = logits * attn_logits_soft_cap + + if sinks is not None: + assert sinks.shape == () # should already be vmapped + + logits = jnp.where(mask, logits, mask_value) + m = logits.max(axis=-1) + sinks = None if sinks is None else sinks.astype(logits.dtype) + m = m if sinks is None else jnp.maximum(m, sinks) + s = jnp.exp(logits - m[..., None]) + l = s.sum(axis=-1) + (0 if sinks is None else jnp.exp(sinks - m)) + p = s / l[..., None] + + o = jnp.einsum("st,td->sd", p, v.astype(jnp.float32)) + + if save_residuals: + logsumexp = m + jnp.log(l) + return o, {"logsumexp": logsumexp, "max_logits": m} + return o + + +def _attention_reference_custom_bwd( + do, + q, + k, + v, + mask, + segment_ids, + sinks, + o, + logsumexp, + mask_value: float = DEFAULT_MASK_VALUE, + backward_impl: str = "vanilla", 
+ attn_logits_soft_cap: float | None = None, +) -> tuple[jax.Array, jax.Array, jax.Array, None, None, jax.Array | None]: + uncapped_logits = jnp.einsum( + "qc,kc->qk", q, k, preferred_element_type=jnp.float32 + ) + + if attn_logits_soft_cap is not None: + logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap) + logits = logits * attn_logits_soft_cap + else: + logits = uncapped_logits + + if segment_ids is not None: + mask = jnp.logical_and( + mask, segment_ids.q[:, None] == segment_ids.kv[None, :] + ) + logits = jnp.where(mask, logits, mask_value) + + p = jnp.exp(logits - logsumexp[..., None]) + do = do.astype(jnp.float32) # pytype: disable=attribute-error + dv = jnp.einsum("pt,pd->td", p, do).astype(v.dtype) + dp = jnp.einsum("pd,td->pt", do, v.astype(jnp.float32)) + + # These two ways of computing ds are mathematically equivalent. The first + # involves reducing over the head_dim dimension and the second involves + # reducing over a sequence dimension. They tend to produce slightly different + # numerics. 
+ if backward_impl == "flash": + di = jnp.sum(o.astype(jnp.float32) * do, axis=-1)[..., None] + else: + di = jnp.einsum("st,st->s", dp, p)[:, None] + ds = (dp - di) * p + if attn_logits_soft_cap is not None: + normalized = uncapped_logits / attn_logits_soft_cap + d = jnp.tanh(normalized) + g = ds * (1 - d) + ds = g + g * d + dk = jnp.einsum("sd,st->td", q.astype(jnp.float32), ds).astype(k.dtype) + dq = jnp.einsum("st,td->sd", ds, k.astype(jnp.float32)).astype(q.dtype) + dsinks = None + if sinks is not None: + sinks_exp = -jnp.exp( + sinks[..., None, None].astype(jnp.float32) + - logsumexp[..., None].astype(jnp.float32) + ) + dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2)) + return dq, dk, dv, None, None, dsinks + + +@functools.partial( + jax.jit, + static_argnames=[ + "mask_value", + "save_residuals", + "attn_logits_soft_cap", + "is_mqa", + ], +) +def attention_reference( + q: jax.Array, + k: jax.Array, + v: jax.Array, + mask: jax.Array, + segment_ids: SegmentIds | None = None, + sinks: jax.Array | None = None, + *, + is_mqa: bool, + mask_value: float = DEFAULT_MASK_VALUE, + save_residuals: bool = False, + attn_logits_soft_cap: float | None = None, +): + """A JIT-compiled reference implementation of attention, handles MQA and MHA.""" + attn_impl = functools.partial( + _attention_reference_impl, + mask_value=mask_value, + save_residuals=save_residuals, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + + if is_mqa: + func = jax.vmap(attn_impl, in_axes=(0, None, None, None, None, 0)) + else: + # In grouped attention (1 < num_kv_heads && num_kv_heads < num_q_heads). + # We interleave the KV heads across the Q heads. + # For example: for 8 Q heads and 4 KV heads: + # Q head [0, 1] see KV head 0 + # Q head [2, 3] see KV head 1 + # Q head [4, 5] see KV head 2 + # Q head [6, 7] see KV head 3 + + kv_heads, q_heads = k.shape[0], q.shape[0] + assert q_heads % kv_heads == 0 + + if kv_heads < q_heads: + # Repeat K and V heads to match the number of Q heads. 
+ q_heads_per_kv = q_heads // kv_heads + k = jnp.repeat(k, repeats=q_heads_per_kv, axis=0) + v = jnp.repeat(v, repeats=q_heads_per_kv, axis=0) + + func = jax.vmap(attn_impl, in_axes=(0, 0, 0, None, None, 0)) + + out = func(q, k, v, mask, segment_ids, sinks) + return out + + +@functools.partial( + jax.jit, static_argnames=["is_mqa", "backward_impl", "attn_logits_soft_cap"] +) +def attention_reference_vjp( + do, + q, + k, + v, + mask, + segment_ids, + sinks, + o, + logsumexp, + *, + is_mqa: bool, + backward_impl: str = "vanilla", + attn_logits_soft_cap: float | None = None, +): + """Wrapper for backward reference that handles GQA/MQA broadcasting and reduction.""" + bwd = functools.partial( + _attention_reference_custom_bwd, + backward_impl=backward_impl, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + + num_q_heads = q.shape[0] + num_kv_heads = 1 if is_mqa else k.shape[0] + + is_grouped = not is_mqa and num_kv_heads < num_q_heads + assert num_q_heads % num_kv_heads == 0 + head_multiplier = num_q_heads // num_kv_heads + if is_mqa: + bwd = jax.vmap(bwd, in_axes=(0, 0, None, None, None, None, 0, 0, 0)) + else: + bwd = jax.vmap(bwd, in_axes=(0, 0, 0, 0, None, None, 0, 0, 0)) + # Interleave the KV heads to match the corresponding Q heads. + if is_grouped: + k = jnp.repeat(k, head_multiplier, axis=0) + v = jnp.repeat(v, head_multiplier, axis=0) + + dq, dk, dv, _, _, dsinks = bwd( + do, q, k, v, mask, segment_ids, sinks, o, logsumexp + ) + + if is_mqa: + dk, dv = dk.sum(axis=0), dv.sum(axis=0) + elif is_grouped: + # Perform the sum reduction across the head_multiplier dimension only. + # So that the output still has KV heads. 
+ dk = dk.reshape(num_kv_heads, head_multiplier, *dk.shape[1:]) + dv = dv.reshape(num_kv_heads, head_multiplier, *dv.shape[1:]) + dk, dv = dk.sum(axis=1), dv.sum(axis=1) + + return dq, dk, dv, dsinks diff --git a/src/maxdiffusion/kernels/splash_attention/microbenchmarks.pdf b/src/maxdiffusion/kernels/splash_attention/microbenchmarks.pdf new file mode 100644 index 00000000..46b8036c Binary files /dev/null and b/src/maxdiffusion/kernels/splash_attention/microbenchmarks.pdf differ diff --git a/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py new file mode 100644 index 00000000..70cae13f --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py @@ -0,0 +1,724 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Implementation of Ring Attention.""" + +import functools +from typing import Any + +import jax +from jax import lax +from jax import tree_util +import jax.numpy as jnp +import numpy as np +from . import base +from . import splash_attention_kernel as splash_kernel +from . import splash_attention_mask as mask_lib +from . 
import splash_attention_mask_info as mask_info_lib + +P = jax.P +MaskInfo = mask_info_lib.MaskInfo +partial = functools.partial + +SegmentIds = base.SegmentIds +SplashConfig = splash_kernel.SplashConfig +SplashResidualsType = base.SplashResidualsType +SplashCustomReturnType = base.SplashCustomReturnType +MaskFunctionType = splash_kernel.MaskFunctionType +_splash_attention_forward = splash_kernel._splash_attention_forward # pylint: disable=protected-access +_splash_attention_bwd = splash_kernel._splash_attention_bwd # pylint: disable=protected-access + + +def _dynamic_slice_mask_info( + mask_info: MaskInfo, kv_shard_idx: jax.Array, ring_size: int +) -> MaskInfo: + """Slices MaskInfo for the current ring step.""" + + def slice_if_exists(arr: jax.Array | None): + if arr is None: + return None + + shard_len = int(arr.shape[-1]) // ring_size + start_idx = kv_shard_idx * shard_len + return lax.dynamic_slice_in_dim(arr, start_idx, shard_len, axis=-1) + + return MaskInfo( + mask_next=slice_if_exists(mask_info.mask_next), + active_rows=slice_if_exists(mask_info.active_rows), + active_cols=slice_if_exists(mask_info.active_cols), + num_active_blocks=slice_if_exists(mask_info.num_active_blocks), + block_mask=slice_if_exists(mask_info.block_mask), + partial_mask_blocks=mask_info.partial_mask_blocks, # partial mask blocks are global + q_sequence=mask_info.q_sequence, # Q sequence stays stationary + ) + + +def _ring_attention_forward( + fwd_mask_info: MaskInfo, + q: jax.Array, + k: jax.Array, + v: jax.Array, + segment_ids: SegmentIds | None, + mask_value: float, + is_mqa: bool, + config: SplashConfig | None, + mask_function: MaskFunctionType | None, + fwd_mask_sparsity: float, + *, + sinks: jax.Array | None = None, + ring_axis: str, + rotate_segment_ids: bool = True, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + + if q.shape[-1] != k.shape[-1]: + raise NotImplementedError( + "Queries and keys must have the same head dimension." 
+ ) + + if sinks is not None: + raise NotImplementedError("Sinks aren't supportd yet.") + + ring_axis_size = lax.axis_size(ring_axis) + ring_axis_idx = lax.axis_index(ring_axis) + + shift = partial( + lax.ppermute, + axis_name=ring_axis, + perm=[(i, (i + 1) % ring_axis_size) for i in range(ring_axis_size)], + ) + # for example, if ring size is 4 + # Device 3 => permute_idx 0, offset (3-0) % 4 = 3, + # permute_idx 1, offset (3-1) % 4 = 2, etc. + # Device 2 => permute_idx 0, offset (2-0) % 4 = 2, + # permute_idx 1, offset (2-1) % 4 = 1, etc. + # Device 1 => permute_idx 0, offset (1-0) % 4 = 1, + # permute_idx 1, offset (1-1) % 4 = 0, etc. + # Device 0 => permute_idx 0, offset (0-0) % 4 = 0, + # permute_idx 1, offset (0-1) % 4 = 3, etc. + + splash_fwd_partial = partial( + _splash_attention_forward, + save_residuals=True, + mask_value=mask_value, + is_mqa=is_mqa, + config=config, + mask_function=mask_function, + fwd_mask_sparsity=fwd_mask_sparsity, + max_logit_value=None, + ) + # Initial accumulator values + o_shape = q.shape + o_init = jnp.zeros(o_shape, dtype=jnp.float32) + l_init = jnp.zeros((o_shape[0], o_shape[1]), jnp.float32) + m_init = jnp.full_like(l_init, mask_value, dtype=jnp.float32) + + def body(carry, i: int)-> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, SegmentIds | None], None]: + m_prev, l_prev, o_prev, k_current, v_current, segment_ids_current = carry + + current_kv_shard_idx = (ring_axis_idx - i) % ring_axis_size + local_fwd_mask_info = _dynamic_slice_mask_info( + fwd_mask_info, current_kv_shard_idx, ring_axis_size + ) + k_next = shift(k_current) + v_next = shift(v_current) + + if segment_ids is not None and rotate_segment_ids: + kv_segment_ids_next = shift(segment_ids_current.kv) + segment_ids_next = SegmentIds(segment_ids.q, kv_segment_ids_next) + else: + segment_ids_next = segment_ids_current + + out_curr, stats = splash_fwd_partial( + local_fwd_mask_info, + q, + k_current, + v_current, + segment_ids=segment_ids_current, + 
sinks=sinks, + ) + lse_curr = stats["logsumexp"] + m_curr = stats["max_logits"] + l_curr = jnp.exp(lse_curr - m_curr) + o_curr = out_curr.astype(jnp.float32) * l_curr[..., None] + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + o_next = alpha[..., None] * o_prev + beta[..., None] * o_curr + return (m_next, l_next, o_next, k_next, v_next, segment_ids_next), None + + # Use lax.scan to get the final carry AND the collected sequence of (k,v) + # pairs + initial_carry = (m_init, l_init, o_init, k, v, segment_ids) + (m_final, l_final, o_final, _, _, _), _ = lax.scan( + body, + initial_carry, + xs=jnp.arange(0, ring_axis_size), + length=ring_axis_size, + unroll=True, + ) # type: ignore[arg-type] + # Final normalization + assert l_final.dtype == jnp.float32 + l_inv = jnp.where(l_final == 0.0, 0.0, 1.0 / l_final) + out = (o_final * l_inv[..., None]).astype(q.dtype) + # Final logsumexp for residuals + lse = jnp.log(l_final) + m_final + lse = jnp.where(l_final == 0.0, mask_value, lse) + + return out, (lse, m_final) + + +def _ring_attention_bwd( + mask_value: float, + is_mqa: bool, + config: SplashConfig | None, + mask_function: MaskFunctionType | None, + fwd_mask_sparsity: float, + dkv_mask_sparsity: float, + save_residuals: bool, + ring_axis: str, + rotate_segment_ids: bool, + # Residuals and gradients + res: Any, + do: jax.Array, +): + del save_residuals + (q, k, v, segment_ids, sinks, out, logsumexp, dkv_mask_info) = res + do = do.astype(jnp.float32) + + ring_axis_size = lax.axis_size(ring_axis) + ring_axis_idx = lax.axis_index(ring_axis) + + shift = partial( + lax.ppermute, + axis_name=ring_axis, + perm=[(i, (i + 1) % ring_axis_size) for i in range(ring_axis_size)], + ) + dq_accum = jnp.zeros_like(q, dtype=jnp.float32) + dk_accum = jnp.zeros_like(k, dtype=jnp.float32) + dv_accum = jnp.zeros_like(v, dtype=jnp.float32) + dsinks = sinks + + def body(carry, i: int): + ( + 
dq_accum, + dk_accum, + dv_accum, + k_current, + v_current, + segment_ids_current, + _, + ) = carry + k_next = shift(k_current) + v_next = shift(v_current) + + current_kv_shard_idx = (ring_axis_idx - i) % ring_axis_size + local_dkv_mask_info = _dynamic_slice_mask_info( + dkv_mask_info, current_kv_shard_idx, ring_axis_size + ) + if segment_ids is not None and rotate_segment_ids: + kv_segment_ids_next = shift(segment_ids_current.kv) + segment_ids_next = SegmentIds(segment_ids.q, kv_segment_ids_next) + else: + segment_ids_next = segment_ids_current + + residuals_for_chunk = ( + q, + k_current, + v_current, + segment_ids_current, + sinks, + out, + logsumexp, + local_dkv_mask_info, + ) + + attn_bwd = functools.partial( + _splash_attention_bwd, + save_residuals=False, + mask_value=mask_value, + is_mqa=is_mqa, + config=config, + mask_function=mask_function, + fwd_mask_sparsity=fwd_mask_sparsity, + dkv_mask_sparsity=dkv_mask_sparsity, + ) + _, _, dq_i, dk_i, dv_i, _, dsinks, _ = attn_bwd( + res=residuals_for_chunk, do=do + ) + dv_next = shift(dv_accum + dv_i.astype(dv_accum.dtype)) + dk_next = shift(dk_accum + dk_i.astype(dk_accum.dtype)) + dq_accum = dq_accum + dq_i.astype(dq_accum.dtype) + + return ( + dq_accum, + dk_next, + dv_next, + k_next, + v_next, + segment_ids_next, + dsinks, + ), None + + initial_carry = (dq_accum, dk_accum, dv_accum, k, v, segment_ids, dsinks) + (dq, dk, dv, _, _, _, dsinks), _ = lax.scan( + body, + initial_carry, + xs=jnp.arange(ring_axis_size), + length=ring_axis_size, + unroll=True, + ) + + if sinks is not None: + dsinks = jax.lax.psum(dsinks, axis_name=ring_axis) + + return ( + None, # fwd_mask_info + None, # dkv_mask_info + dq.astype(q.dtype), + dk.astype(k.dtype), + dv.astype(v.dtype), + dsinks, + None, + ) + + +def _ring_attention_fwd( + fwd_mask_info: MaskInfo, + dkv_mask_info: MaskInfo | None, + q: jax.Array, + k: jax.Array, + v: jax.Array, + segment_ids: SegmentIds | None, + sinks: jax.Array | None, + # nondiff_args + mask_value: 
float, # 1 + is_mqa: bool, # 2 + config: SplashConfig | None, # 3 + mask_function: MaskFunctionType | None, # 4 + fwd_mask_sparsity: float, # 5 + dkv_mask_sparsity: float, # 6 + save_residuals: bool, # 7 + ring_axis: str, # 8 + rotate_segment_ids: bool, # 9 +) -> tuple[jax.Array, SplashResidualsType]: + """Forward pass for the custom VJP of ring attention. + + This function is used by `jax.custom_vjp` to define the forward pass + of the ring attention computation, also returning residuals needed for + the backward pass. + + Args: + fwd_mask_info: Mask information for the forward pass. + dkv_mask_info: Mask information for the backward pass for dK/dV. + q: Query array. + k: Key array. + v: Value array. + segment_ids: Optional segment IDs for packed sequences. + sinks: Optional sink tokens. + mask_value: The value used for masked-out attention scores. + is_mqa: Whether Multi-Query Attention is used. + config: SplashAttention configuration. + mask_function: Optional function to apply additional masking. + fwd_mask_sparsity: Sparsity level of the forward mask. + save_residuals: Whether to save residuals for the backward pass. + ring_axis: The name of the jax axis used for the ring. + + Returns: + A tuple containing: + - The output of the ring attention computation. + - Residuals needed for the backward pass (`SplashResidualsType`). 
+ """ + del dkv_mask_sparsity + if save_residuals: + raise NotImplementedError("Higher-order AD not supported.") + + out, (logsumexp, max_logits) = _ring_attention_forward( + fwd_mask_info, + q, + k, + v, + segment_ids, + sinks=sinks, + mask_value=mask_value, + is_mqa=is_mqa, + config=config, + mask_function=mask_function, + fwd_mask_sparsity=fwd_mask_sparsity, + ring_axis=ring_axis, + rotate_segment_ids=rotate_segment_ids, + ) + residuals = (q, k, v, segment_ids, sinks, out, logsumexp, dkv_mask_info) + return out, residuals + + +@partial( + jax.custom_vjp, + nondiff_argnames=( + "mask_value", + "is_mqa", + "config", + "mask_function", + "fwd_mask_sparsity", + "dkv_mask_sparsity", + "save_residuals", + "ring_axis", + "rotate_segment_ids", + ), +) +def _ring_attention_custom( + fwd_mask_info: MaskInfo, + dkv_mask_info: MaskInfo | None, + q: jax.Array, + k: jax.Array, + v: jax.Array, + segment_ids: SegmentIds | None, + sinks: jax.Array | None, + mask_value: float, + is_mqa: bool, + config: SplashConfig | None, + mask_function: MaskFunctionType | None, + fwd_mask_sparsity: float, + dkv_mask_sparsity: float, + save_residuals: bool, + ring_axis: str, + rotate_segment_ids: bool , +) -> SplashCustomReturnType: + """Performs ring attention with a custom VJP. + + This function is a wrapper around `_ring_attention_forward` and is used + to define the custom gradient for ring attention. + + Args: + fwd_mask_info: Mask information for the forward pass. + dkv_mask_info: Mask information for the backward pass for dK/dV. + q: Query array. + k: Key array. + v: Value array. + segment_ids: Optional segment IDs for packed sequences. + sinks: Optional sink tokens. + mask_value: The value used for masked-out attention scores. + is_mqa: Whether Multi-Query Attention is used. + config: SplashAttention configuration. + mask_function: Optional function to apply additional masking. + fwd_mask_sparsity: Sparsity level of the forward mask. 
+ save_residuals: Whether to save residuals for the backward pass. + ring_axis: The name of the jax axis used for the ring. + rotate_segment_ids: Whether to rotate segment IDs along with K/V in ring attention. + This only possible when segment id for all KV shards are same, i.e ring attention is called in shard map. + Returns: + The output of the ring attention computation. + """ + del dkv_mask_info, dkv_mask_sparsity + out, _ = _ring_attention_forward( + fwd_mask_info, + q, + k, + v, + segment_ids, + sinks=sinks, + mask_value=mask_value, + is_mqa=is_mqa, + config=config, + mask_function=mask_function, + fwd_mask_sparsity=fwd_mask_sparsity, + ring_axis=ring_axis, + rotate_segment_ids=rotate_segment_ids, + ) + return out + + +_ring_attention_custom.defvjp(_ring_attention_fwd, _ring_attention_bwd) + + +def _has_axis(axis_name: str) -> bool: + try: + # We try to access the size of the axis. + # If it doesn't exist, JAX raises a NameError (or similar) immediately + # during tracing. + lax.axis_size(axis_name) + return True + except (NameError, ValueError): + return False + + +@partial( + jax.jit, + static_argnames=[ + "is_mqa", + "config", + "mask_value", + "mask_function", + "fwd_mask_sparsity", + "dkv_mask_sparsity", + "save_residuals", + "ring_axis", + "rotate_segment_ids", + ], +) +def _ring_attention( + fwd_mask_info: MaskInfo, + dkv_mask_info: MaskInfo | None, + q: jax.Array, + k: jax.Array, + v: jax.Array, + segment_ids: SegmentIds | None = None, + sinks: jax.Array | None = None, + *, + is_mqa: bool, + config: SplashConfig | None, + mask_value: float, + mask_function: MaskFunctionType | None, + fwd_mask_sparsity: float, + dkv_mask_sparsity: float, + save_residuals: bool = False, + ring_axis: str, + rotate_segment_ids: bool = True, +) -> SplashCustomReturnType: + """Performs ring attention using SplashAttention kernels. 
+ + This function orchestrates the ring attention mechanism by iterating through + shards of keys and values across devices along the specified `ring_axis`, + using `_splash_attention_forward` for each chunk. + + Args: + fwd_mask_info: Mask information for the forward pass. + dkv_mask_info: Mask information for the backward pass for dK/dV. + q: Query array. + k: Key array. + v: Value array. + segment_ids: Optional segment IDs for packed sequences. + sinks: Optional sink tokens. + is_mqa: Whether Multi-Query Attention is used. + config: SplashAttention configuration. + mask_value: The value used for masked-out attention scores. + mask_function: Optional function to apply additional masking. + fwd_mask_sparsity: Sparsity level of the forward mask. + save_residuals: Whether to save residuals for the backward pass. + ring_axis: The name of the jax axis used for the ring. + rotate_segment_ids: Whether to rotate segment IDs along with K/V in ring attention + + Returns: + The output of the ring attention computation. + + Raises: + ValueError: If the specified `ring_axis` does not exist. + """ + if not _has_axis(ring_axis): + raise ValueError(f"Ring axis {ring_axis} does not exist") + + return _ring_attention_custom( + fwd_mask_info, + dkv_mask_info, + q, + k, + v, + segment_ids, + sinks, + is_mqa=is_mqa, + config=config, + mask_value=mask_value, + mask_function=mask_function, + fwd_mask_sparsity=fwd_mask_sparsity, + dkv_mask_sparsity=dkv_mask_sparsity, + save_residuals=save_residuals, + ring_axis=ring_axis, + rotate_segment_ids=rotate_segment_ids, + ) + + +@jax.tree_util.register_pytree_node_class +class RingSplashAttentionKernel: + """Implements Ring Attention using SplashAttention for sequence parallelism. + + This kernel computes global attention by keeping Keys and Values distributed + across the `ring_axis`. Instead of gathering full sequences, it rotates K/V + shards between devices and accumulates results incrementally. 
This allows + processing sequence lengths that exceed single-device memory limits. + + Attributes: + fwd_mask_info: Mask information for the forward pass. + dkv_mask_info: Mask information for the backward pass for dK/dV. + ring_axis: The name of the jax axis used for the ring. + kwargs: Additional keyword arguments passed to the SplashAttentionKernel + constructor. + """ + + def __init__( + self, + fwd_mask_info: MaskInfo, + dkv_mask_info: MaskInfo | None, + ring_axis: str, + rotate_segment_ids: bool , + **kwargs, + ): + self.fwd_mask_info = fwd_mask_info + self.dkv_mask_info = dkv_mask_info + self.ring_axis = ring_axis + self.rotate_segment_ids = rotate_segment_ids + self.kwargs = kwargs + + def __call__(self, *args, **kwargs): + return _ring_attention( + self.fwd_mask_info, + self.dkv_mask_info, + *args, + **kwargs, + **self.kwargs, + ring_axis=self.ring_axis, + rotate_segment_ids=self.rotate_segment_ids, + ) + + def manual_sharding_spec(self): + """Ring attention expects MaskInfo to be sharded by `q_seq_shards`. + + Each q shard will need all the k/v shard's MaskInfo eventually, so we don't + shard it, but instead dynamic_slice the chunk that we need at each + iteration. 
+ """ + + spec = jax.sharding.PartitionSpec(self.ring_axis) + _resolve_spec = lambda x: spec if x is not None else None + + mask_info_specs = MaskInfo( # pytype: disable=wrong-arg-types + mask_next=_resolve_spec(self.fwd_mask_info.mask_next), + active_rows=_resolve_spec(self.fwd_mask_info.active_rows), + active_cols=_resolve_spec(self.fwd_mask_info.active_cols), + num_active_blocks=_resolve_spec(self.fwd_mask_info.num_active_blocks), + block_mask=_resolve_spec(self.fwd_mask_info.block_mask), + partial_mask_blocks=jax.sharding.PartitionSpec(), # replicated + q_sequence=_resolve_spec(self.fwd_mask_info.q_sequence), + ) + return RingSplashAttentionKernel( + mask_info_specs, + mask_info_specs if self.dkv_mask_info is not None else None, + ring_axis=self.ring_axis, + **self.kwargs, + ) + + def tree_flatten(self): + children = (self.fwd_mask_info, self.dkv_mask_info) + aux_data = self.kwargs.copy() + aux_data["ring_axis"] = self.ring_axis + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + fwd_mask_info, dkv_mask_info = children + dkv_mask_info = ( + mask_info_lib.MaskInfo(*dkv_mask_info) + if dkv_mask_info is not None + else None + ) + return cls( + mask_info_lib.MaskInfo(*fwd_mask_info), + dkv_mask_info, + **aux_data, + ) + + +def make_ring_attention( + mask: np.ndarray | mask_lib.Mask, + *, + config: SplashConfig | None = None, + is_mqa: bool, + save_residuals: bool = False, + mask_value: float = base.DEFAULT_MASK_VALUE, + downcast_smem_data: bool = True, + partial_mask_blocks_dtype: jax.typing.DTypeLike = np.int8, + ring_axis: str, + q_seq_shards: int = 1, + kv_seq_shards: int = 1, + rotate_segment_ids: bool = True, +): + """Creates a RingSplashAttentionKernel. + + Args: + mask: The attention mask. + config: SplashAttention configuration. If None, uses the default config. + is_mqa: Whether the model uses Multi-Query Attention. + save_residuals: Whether to save residuals for the backward pass. 
    mask_value: The value to use for masked-out attention scores.
    downcast_smem_data: Whether to downcast data in shared memory.
    partial_mask_blocks_dtype: The dtype for partial mask blocks.
    ring_axis: The name of the jax scan axis used for the ring.
    q_seq_shards: The number of shards for the query sequence dimension.
    kv_seq_shards: The number of shards for the key/value sequence dimension.
    rotate_segment_ids: Whether to rotate segment IDs along with K/V in ring
      attention. This is only possible when the segment ids for all KV shards
      are the same, i.e. when ring attention is called inside shard_map. A
      common scenario is padding applied to each shard independently, so all
      shards have the same segment ids.
  Returns:
    A RingSplashAttentionKernel instance.

  Raises:
    ValueError: If the mask shape is unexpected or ring_axis is not specified.
  """

  if len(mask.shape) != 2:
    raise ValueError(f"Unexpected mask shape: {mask.shape}")

  if isinstance(mask, np.ndarray):
    mask = mask_lib.NumpyMask(mask)

  if not isinstance(mask, (mask_lib.NumpyMask, mask_lib.FullMask)):
    raise NotImplementedError(
        f"Only NumpyMask and FullMask are supported, but got {type(mask)}."
+ ) + + if config is None: + config = SplashConfig.get_default() + + process_fn = partial( + mask_info_lib.process_mask, + downcast_smem_data=downcast_smem_data, + partial_mask_blocks_dtype=partial_mask_blocks_dtype, + q_seq_shards=q_seq_shards, + kv_seq_shards=kv_seq_shards, + ) + + fwd_mask_info, mask_function_fwd = process_fn( + mask, + (config.block_q, config.block_kv), + ) + fwd_mask_sparsity = float(np.mean(fwd_mask_info.block_mask != 0)) + fwd_mask_info = tree_util.tree_map(jnp.array, fwd_mask_info) + + dkv_mask_info = None + dkv_mask_sparsity = 0.0 + if config.has_backward_blocks: + bq_dkv, bkv_dkv = config.block_q_dkv, config.block_kv_dkv + dkv_mask_info, mask_function_dkv = process_fn( + mask, + (bq_dkv, bkv_dkv), + is_dkv=True, + return_dynamic_grid=config.dq_reduction_steps == 3, + ) + assert (mask_function_fwd is None) == (mask_function_dkv is None) + dkv_mask_sparsity = float(np.mean(dkv_mask_info.block_mask != 0)) + dkv_mask_info = tree_util.tree_map(jnp.array, dkv_mask_info) + + return RingSplashAttentionKernel( + fwd_mask_info, + dkv_mask_info, + ring_axis=ring_axis, + rotate_segment_ids=rotate_segment_ids, + config=config, + is_mqa=is_mqa, + save_residuals=save_residuals, + mask_value=mask_value, + mask_function=mask_function_fwd, + fwd_mask_sparsity=fwd_mask_sparsity, + dkv_mask_sparsity=dkv_mask_sparsity, + ) diff --git a/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel_test.py b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel_test.py new file mode 100644 index 00000000..da95a277 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel_test.py @@ -0,0 +1,176 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ring attention.""" + +import dataclasses +import functools + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import random +import jax.numpy as jnp +import numpy as np +from . import base +from . import ring_attention_kernel +from . import splash_attention_kernel as splash +from . import splash_attention_mask as mask_lib +from . import splash_attention_test_utils as test_utils + +P = jax.sharding.PartitionSpec +partial = functools.partial + +jax.config.parse_flags_with_absl() + + +class RingAttentionTest(test_utils.SplashAttentionTestCase): + + def setUp(self): + self.skipTest("no sharding on runners") + if jax.default_backend() != "tpu": + self.skipTest("Only supported on TPUs.") + + if len(jax.devices()) < 4: + self.skipTest("This test requires at least 4 devices.") + + super().setUp() + + @parameterized.product( + ring_size=[2], + num_heads=[1], + head_dim=[128, 256], + dtype=[jnp.bfloat16], + is_mqa=[False, True], + is_segmented=[False, True], + mask_type=["FULL", "CAUSAL"], + ) + def test_ring_attention( + self, + ring_size, + num_heads, + head_dim, + dtype, + is_mqa, + is_segmented, + mask_type, + ): + if len(jax.devices()) < ring_size: + self.skipTest( + f"This test requires {ring_size} devices, but has only" + f" {len(jax.devices())} devices available." 
+ ) + + # Mesh Creation and Input Generation + ring_axis = "ring" + devices = np.asarray(jax.devices()[:ring_size]).reshape(1, ring_size) + mesh = jax.sharding.Mesh(devices, ("heads", ring_axis)) + seq_len = 1024 * ring_size + + k1, k2, k3, k4 = random.split(random.key(0), 4) + scale = head_dim**-0.5 + q = random.normal(k1, (num_heads, seq_len, head_dim), dtype=dtype) * scale + if is_mqa: + k = random.normal(k2, (seq_len, head_dim), dtype=dtype) * scale + v = random.normal(k3, (seq_len, head_dim), dtype=dtype) * scale + else: + k = ( + random.normal(k2, (num_heads, seq_len, head_dim), dtype=dtype) + * scale + ) + v = ( + random.normal(k3, (num_heads, seq_len, head_dim), dtype=dtype) + * scale + ) + do = random.normal(k4, q.shape, dtype=dtype) * scale + + if mask_type == "CAUSAL": + mask = mask_lib.make_causal_mask((seq_len, seq_len)) + elif mask_type == "FULL": + mask = mask_lib.FullMask(_shape=(seq_len, seq_len)) + else: + raise ValueError(f"Unsupported mask type: {mask_type}") + + if is_segmented: + segment_ids = test_utils.create_segment_ids(seq_len) + segment_ids_spec = base.SegmentIds(q=P(ring_axis), kv=P(ring_axis)) + else: + segment_ids = segment_ids_spec = None + + # For ring attention, sequence dimension is sharded over 'ring' axis + q_spec = P(None, ring_axis, None) + kv_spec = P(ring_axis, None) if is_mqa else q_spec + + + splash_config = splash.SplashConfig.get_default() + splash_config = dataclasses.replace( + splash_config, + use_base2_exp=False, + fuse_reciprocal=True, + # TODO: Change fuse_reciprocal behavior for ring attention + # so we do the reciprocal after ring + ) + + ring_kernel = ring_attention_kernel.make_ring_attention( + mask, + is_mqa=is_mqa, + ring_axis=ring_axis, + config=splash_config, + save_residuals=False, + q_seq_shards=ring_size, + kv_seq_shards=ring_size, + ) + kernel_spec = ring_kernel.manual_sharding_spec() + + @partial( + jax.shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + 
segment_ids_spec, + ), + out_specs=q_spec, + check_vma=False, + ) + def ring_attn(ring_kernel, q, k, v, segment_ids): + out = ring_kernel(q, k, v, segment_ids) + return out + + ring_attn_ref = partial(base.attention_reference, is_mqa=is_mqa) + + with self.subTest("fwd"): + out = ring_attn(ring_kernel, q, k, v, segment_ids) + out_ref = ring_attn_ref(q, k, v, mask[:, :], segment_ids) + self._assert_allclose(out, out_ref, rtol=5e-3, atol=3e-3) + + with self.subTest("bwd"): + out, out_vjp = jax.vjp(ring_attn, ring_kernel, q, k, v, segment_ids) + out_ref, out_vjp_ref = jax.vjp( + ring_attn_ref, q, k, v, mask[:, :], segment_ids + ) + self._assert_allclose(out, out_ref, rtol=5e-3, atol=3e-3) + + _, dq, dk, dv, _ = out_vjp(do) + dq_ref, dk_ref, dv_ref, _, _ = out_vjp_ref(do.astype(jnp.float32)) + + self._assert_allclose(dq, dq_ref, rtol=1e-2, atol=1e-2) + self._assert_allclose(dk, dk_ref, rtol=1e-2, atol=1e-2) + self._assert_allclose(dv, dv_ref, rtol=1e-2, atol=1e-2) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py b/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py new file mode 100644 index 00000000..b125f533 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py @@ -0,0 +1,2173 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Implementation of Sparse Flash Attention, a.k.a. "Splash" attention.""" + +from collections.abc import Callable +import dataclasses +import enum +import functools +import json +import math +from typing import Any, NamedTuple + +import jax +from jax import ad_checkpoint +from jax import lax +from jax import tree_util +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np +from . import base +from . import splash_attention_mask as mask_lib +from . import splash_attention_mask_info as mask_info_lib + + +P = jax.P +MaskInfo = mask_info_lib.MaskInfo +partial = functools.partial +NUM_LANES = 128 +NUM_SUBLANES = 8 +# We predefine some useful dimension numbers for dot_general +NN_DIM_NUMBERS = (((1,), (0,)), ((), ())) # standard matmul +NT_DIM_NUMBERS = (((1,), (1,)), ((), ())) # RHS transposed + +LOG2E = math.log2(math.e) +LOG2E_INV = 1 / LOG2E + +# mypy: ignore-errors + + +def _not(x: jax.Array | bool) -> jax.Array | bool: + if isinstance(x, jax.Array): + return jnp.logical_not(x) + return not x + + +class SegmentIds(NamedTuple): + """SegmentIds for Q and KV sequences. + + SegmentIds are a mechanism to ensure that there is no cross-attention between + segments (fraction of a sequence) that have been concatenated together into a + sequence. Each array is a list of ids (integers). Only tokens with the same + id are allowed to attend to each other. + + The static mask (e.g. causal) is "and-ed" with the segment id mask to form + the actual attention mask. It is important that the latter does not have any + all-zero rows (along dimension kv). Otherwise it would result in a invalid + softmax (the denominator would be 0). + This condition holds for causal self-attention because in this case segment + ids form a block diagonal matrix so at least one element in each row is set. 
+ It is easy to break this condition with non-self-attention configurations. + Attributes: + q: segment ids along the Q sequence + kv: segment ids along the KV sequence + """ + + q: jax.Array # [q_seq_len] + kv: jax.Array # [kv_seq_len] + +MaskFunctionType = Callable[..., jax.Array] + + +def get_kernel_name( + is_mqa: bool, save_residuals: bool, is_segmented: bool, phase: str +) -> str: + """Returns a unique name for all SplashAttention kernel variants.""" + assert phase in ["dq", "dkv", "fwd"] + # Saving residuals is supported only for the fwd phase. + assert not save_residuals or phase == "fwd" + residuals = "_residuals" if save_residuals else "_no_residuals" + attention_type = "mqa" if is_mqa else "mha" + segments = "_segmented" if is_segmented else "" + return f"splash_{attention_type}_{phase}{segments}{residuals}" + + +# Splash attention implementation + + +# We use an IntEnum to make it JSON serializable as regen metadata. +class QKVLayout(enum.IntEnum): + HEAD_DIM_MINOR = enum.auto() # [..., seq_len, head_dim] + SEQ_MINOR = enum.auto() # [..., head_dim, seq_len] + + +def from_head_minor(vals: tuple[Any, ...], layout: QKVLayout): + if layout == QKVLayout.HEAD_DIM_MINOR: + return vals + return (*vals[:-2], vals[-1], vals[-2]) + + +@dataclasses.dataclass(frozen=True, slots=True) +class SplashConfig: + """Tile sizes parameterizing SplashAttention kernels. + + Those parameters have negligible effect on numerics, but affect performance + greatly. + + Note that changing the layouts only influences the physical layout that the + kernel will enforce. The logical interface to splash attention always takes + the head dimension as the minormost one. + """ + + block_q: int + block_kv: int + block_kv_compute: int | None = None + + block_q_dkv: int | None = None + block_kv_dkv: int | None = None + block_kv_dkv_compute: int | None = None + + # TODO: Remove these 3 params, they're only kept for backwards compatibility. 
+ block_q_dq: int | None = None + block_kv_dq: int | None = None + use_fused_bwd_kernel: bool = True + + q_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR + k_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR + v_layout: QKVLayout = QKVLayout.HEAD_DIM_MINOR + + fwd_cost_estimate: pl.CostEstimate | None = None + bwd_cost_estimate: pl.CostEstimate | None = None + + residual_checkpoint_name: str | None = None # whether to checkpoint outputs + attn_logits_soft_cap: float | None = None + fuse_reciprocal: bool = True # whether to compute o / lse inside the kernel + use_base2_exp: bool = True + max_logit_const: float | None = None + interpret: bool = False + # The fused bwd kernel accumulates dq at every grid step. To safely avoid + # read/write conflicts we conservatively avoid *any* in-kernel reductions. + # This parameter allows to override this behavior and specifies the number of + # reduction steps. For now, only 3 or all the kv steps are supported. + dq_reduction_steps: int | None = None + # An experimental scheduler that sometimes produces better softmax overlap. + use_experimental_scheduler: bool = False + + def __post_init__(self): + if self.block_kv_compute is None: + object.__setattr__(self, "block_kv_compute", self.block_kv) + if self.block_kv_dkv_compute is None: + object.__setattr__(self, "block_kv_dkv_compute", self.block_kv_dkv) + + if self.dq_reduction_steps is not None and self.dq_reduction_steps != 3: + raise ValueError( + f"Invalid dq_reduction_steps: {self.dq_reduction_steps}, only 3 or" + " None are supported." + ) + if not self.use_fused_bwd_kernel: + raise ValueError("Only the fused bwd kernel is supported.") + + @property + def has_backward_blocks(self) -> bool: + backward_blocks = ( + self.block_q_dkv, + self.block_kv_dkv, + self.block_kv_dkv_compute, + ) + return all(b is not None for b in backward_blocks) + + @classmethod + def get_default(cls): + # TODO: Select better parameters based on a heuristic. 
+ return SplashConfig( + block_q=128, + block_kv=128, + block_kv_compute=128, + block_q_dkv=128, + block_kv_dkv=128, + block_kv_dkv_compute=128, + block_q_dq=128, + block_kv_dq=128, + fuse_reciprocal=True, + ) + + +to_i32 = lambda x: x.astype(jnp.int32) + + +def _apply_mask_and_soft_cap( + qk: jax.Array, + mask_value: float, + mask_ref, + q_sequence_ref, + q_segment_ids_ref, + kv_segment_ids_ref, + *, + attn_logits_soft_cap: float | None, + k_slice: pl.Slice, + k_offset: int | jax.Array, + bq: int, + k_in_lanes=True, + mask_function=None, + has_partial_mask: bool = False, +) -> jax.Array | tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + assert mask_ref is None or q_sequence_ref is None + assert (q_sequence_ref is None) == (mask_function is None) + + masks = [] + if has_partial_mask: + if mask_ref is not None: + mask = mask_ref[:, k_slice] if k_in_lanes else mask_ref[k_slice, :] + masks.append(mask) + elif mask_function is not None: + # Compute the mask using the given q_sequence indices. + # KV indices are computed on the fly. This works because we only support Q + # sequence sharding. If we wanted to compute Q indices too, then we would + # need to keep into account the current shard along Q sequence. 
+ + if k_in_lanes: + assert q_sequence_ref.shape == (bq, NUM_LANES) + + k_sequence = k_offset + jax.lax.broadcasted_iota( + jnp.int32, (bq, k_slice.size), 1 + ) + + repeats, rem = divmod(k_slice.size, NUM_LANES) + assert rem == 0 + q_sequence = jnp.tile( + q_sequence_ref[...], (1, repeats) + ) # [bq, k_slice.size] + else: + assert q_sequence_ref.shape == (NUM_SUBLANES, bq) + + k_sequence = k_offset + jax.lax.broadcasted_iota( + jnp.int32, (k_slice.size, bq), 0 + ) + q_sequence = q_sequence_ref[:1, :] # [1, bq] + q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) + + assert q_sequence.shape == k_sequence.shape + computed_mask = mask_function(q_sequence, k_sequence) # pytype: disable=wrong-arg-count + if computed_mask.dtype != jnp.dtype(jnp.bool_): + raise ValueError( + "Mask function must return a boolean-valued array, but got:" + f" {computed_mask.dtype}" + ) + masks.append(computed_mask) + + if q_segment_ids_ref is not None: + if k_in_lanes: + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] + repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) + if rem: + raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") + q_ids = jnp.tile(q_segment_ids_ref[:], (1, repeats)) # [bq, bkv] + else: + assert bq == q_segment_ids_ref.shape[-1] + repeats, rem = divmod(bq, NUM_LANES) + if rem: + raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") + kv_ids = jnp.tile( + kv_segment_ids_ref[k_slice, :], (1, repeats) + ) # [k_slice, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] + masks.append(q_ids == kv_ids) + + def cap_logits(logits): + if attn_logits_soft_cap is not None: + logits = jnp.tanh(qk / attn_logits_soft_cap) + return logits * attn_logits_soft_cap + else: + return logits + + if masks: + mask = functools.reduce(jnp.logical_and, masks) + qk = cap_logits(qk) + qk = jnp.where(mask, qk, mask_value) + else: + qk = cap_logits(qk) + return qk + + +def flash_attention_kernel( + # Prefetched inputs + active_rows_ref, + 
active_cols_ref, + mask_next_ref, + bounds_start_ref, + bounds_end_ref, + block_mask_ref, + # Inputs + q_ref, + k_ref, + v_ref, + q_segment_ids_ref, + kv_segment_ids_ref, + sinks_ref, + mask_ref, + q_sequence_ref, + max_logit_value_ref, + # Outputs + o_ref, + logsumexp_ref, + l_linear_ref, + max_logits_ref, + # Scratch + m_scratch_ref, + l_scratch_ref, + o_scratch_ref, + *, + mask_value: float, + kv_steps: int, + bq: int, + bkv: int, + bkv_compute: int, + head_dim_v: int, + mask_function: MaskFunctionType | None, + fuse_reciprocal: bool, # config.fuse_reciprocal or not save_residuals + config: SplashConfig, +): + del mask_next_ref, active_rows_ref + float32 = jnp.float32 + HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR + attn_logits_soft_cap = config.attn_logits_soft_cap + if attn_logits_soft_cap is not None and config.use_base2_exp: + attn_logits_soft_cap *= LOG2E + + # If the head_dim_v is not a multiple of the number of lanes, it will be + # padded to that multiple with zeros. + head_dim_v_repeats = pl.cdiv(head_dim_v, NUM_LANES) + + grid_idx = pl.program_id(1) + h = pl.program_id(0) + + if block_mask_ref is not None: + should_not_mask = block_mask_ref[grid_idx].astype(jnp.int32) != 1 + should_initialize = bounds_start_ref[grid_idx].astype(jnp.bool_) + should_write = bounds_end_ref[grid_idx].astype(jnp.bool_) + j = active_cols_ref[grid_idx].astype(jnp.int32) + else: + should_not_mask = False + j = grid_idx % kv_steps + should_initialize = j == 0 + should_write = j == kv_steps - 1 + + max_logit_estimate = config.max_logit_const # potentially None + if max_logit_value_ref is not None: # already ensures max_logit_const is None + max_logit_estimate = max_logit_value_ref[0, h] + + @pl.when(should_initialize) + def init(): + o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref) + + sink = None + if sinks_ref is not None: + sink = sinks_ref[0, h].astype(m_scratch_ref.dtype) + + if sinks_ref is None and max_logit_estimate is None: + m_scratch_ref[...] 
= jnp.full_like(m_scratch_ref, mask_value) + l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + elif sinks_ref is None and max_logit_estimate is not None: + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, max_logit_estimate) + l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + elif sinks_ref is not None and max_logit_estimate is None: + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, sink) + l_scratch_ref[...] = jnp.ones_like(l_scratch_ref) + else: # sinks_ref is not None and max_logit_estimate is not None + exp = jnp.exp2 if config.use_base2_exp else jnp.exp + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, max_logit_estimate) + l_scratch_ref[...] = exp( + sink - jnp.full_like(l_scratch_ref, max_logit_estimate) + ) + + def body(kv_compute_index, _, has_partial_mask=False): + slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + assert m_prev.shape == (bq, NUM_LANES) + assert l_prev.shape == (bq, NUM_LANES) + + q = q_ref[...] 
if config.q_layout == HEAD_DIM_MINOR else q_ref[...].T + if config.use_base2_exp: + q *= LOG2E + + qk_dims = ( + NT_DIM_NUMBERS if config.k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS + ) + if config.k_layout == HEAD_DIM_MINOR: + k = k_ref[slice_k, :] + else: + k = k_ref[:, slice_k] + qk = lax.dot_general(q, k, qk_dims, preferred_element_type=float32) + + assert qk.shape == (bq, bkv_compute) + apply_mask_and_soft_cap = functools.partial( + _apply_mask_and_soft_cap, + qk, + mask_value, + mask_ref, + q_sequence_ref, + q_segment_ids_ref, + kv_segment_ids_ref, + attn_logits_soft_cap=attn_logits_soft_cap, + k_slice=slice_k, + k_offset=j * bkv + kv_compute_index * bkv_compute, + bq=bq, + mask_function=mask_function, + has_partial_mask=has_partial_mask, + ) + + qk = apply_mask_and_soft_cap() + + if max_logit_estimate is None: + m_curr = qk.max(axis=-1)[:, None] # pytype: disable=attribute-error + assert m_curr.shape == (bq, 1) + m_next = jnp.maximum(m_prev, m_curr) + assert m_next.shape == (bq, NUM_LANES) + else: + m_next = None + + bkv_repeats, rem = divmod(bkv_compute, NUM_LANES) + if rem != 0: + raise NotImplementedError( + f"{bkv_compute=} should be a multiple of {NUM_LANES}" + ) + + exp = jnp.exp2 if config.use_base2_exp else jnp.exp + if max_logit_estimate is None: + s_curr = exp(qk - jnp.tile(m_next, (1, bkv_repeats))) + else: + s_curr = exp(qk - max_logit_estimate) + assert s_curr.shape == (bq, bkv_compute) + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + assert l_curr.shape == (bq, NUM_LANES) + + if max_logit_estimate is None: + alpha = exp(m_prev - m_next) + l_next = l_curr + alpha * l_prev + m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + else: + alpha = None + l_scratch_ref[...] 
= l_curr + l_prev + + sv_dims = ( + NN_DIM_NUMBERS if config.v_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS + ) + if config.v_layout == HEAD_DIM_MINOR: + v = v_ref[slice_k, :] + else: + v = v_ref[:, slice_k] + o_curr = lax.dot_general(s_curr, v, sv_dims) + + if max_logit_estimate is None: + alpha_o = jnp.tile(alpha, (1, head_dim_v_repeats)) + alpha_o = alpha_o[..., : o_scratch_ref.shape[-1]] + o_scratch_ref[...] = alpha_o * o_scratch_ref[...] + o_curr + else: + o_scratch_ref[...] = o_scratch_ref[...] + o_curr + + assert bkv % bkv_compute == 0 + num_iters = ( + k_ref.shape[0 if config.k_layout == HEAD_DIM_MINOR else 1] // bkv_compute + ) + + @pl.when(should_not_mask) + def _(): + lax.fori_loop(0, num_iters, body, None, unroll=True) + + @pl.when(jnp.logical_not(should_not_mask)) + def _(): + lax.fori_loop( + 0, num_iters, partial(body, has_partial_mask=True), None, unroll=True + ) + + @pl.when(should_write) + def end(): + l = l_scratch_ref[...] + m = m_scratch_ref[...] + if fuse_reciprocal: # allows fusing reciprocal out of the kernel + l_inv = jnp.tile(1.0 / l, (1, head_dim_v_repeats)) + l_inv = l_inv[..., : o_scratch_ref.shape[-1]] + o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype) + else: + o_ref[...] = o_scratch_ref[...].astype(o_ref.dtype) + if logsumexp_ref is not None: + assert logsumexp_ref.shape == (bq, NUM_LANES) + log = jnp.log2 if config.use_base2_exp else jnp.log + logsumexp = m + log(l) + logsumexp_ref[...] = logsumexp.astype(logsumexp_ref.dtype) + if l_linear_ref is not None: + assert l_linear_ref.shape == (bq, NUM_LANES) + l_linear_ref[...] = l.astype(l_linear_ref.dtype) + if max_logits_ref is not None: + assert max_logits_ref.shape == (bq, NUM_LANES) + max_logits_ref[...] 
= m.astype(max_logits_ref.dtype) + + +def _div(dividend: int, divisor: int): + if divisor == 1: + return dividend + + return lax.div(dividend, divisor) + + +def _bytes(x: jax.Array | jax.ShapeDtypeStruct | None) -> int: + if x is None: + return 0 + + if jnp.issubdtype(x.dtype, jnp.floating): + info = jnp.finfo + elif jnp.issubdtype(x.dtype, jnp.integer): + info = jnp.iinfo + else: + raise ValueError(f"Unsupported dtype: {x.dtype}") + return math.ceil(math.prod(x.shape) * info(x.dtype).bits / 8) + + +def _splash_attention_forward( + mask_info: MaskInfo, + q: jax.Array, + k: jax.Array, + v: jax.Array, + segment_ids: base.SegmentIds | None, + sinks: jax.Array | None, + mask_value: float, + is_mqa: bool, + config: SplashConfig, + save_residuals: bool, + mask_function: MaskFunctionType | None, + fwd_mask_sparsity: float, + max_logit_value: jax.Array | None = None, +) -> base.SplashCustomReturnType: + num_q_heads, q_seq_len, head_dim_qk = q.shape + head_dim_v = v.shape[-1] + bq, bkv = config.block_q, config.block_kv + bkv_compute = config.block_kv_compute + fuse_reciprocal = config.fuse_reciprocal or not save_residuals + bounds_start, bounds_end = mask_info_lib.find_bounds(mask_info.active_rows) + + if is_mqa: + expected_kv_rank = 2 + num_kv_heads = 1 + else: + expected_kv_rank = 3 + num_kv_heads = k.shape[0] + + if len(k.shape) != expected_kv_rank: + raise ValueError( + f"Expected {expected_kv_rank}-dim 'key' tensor for MQA. Instead got a" + f" {len(k.shape)}-dim one." + ) + + if k.shape[-1] != head_dim_qk: + raise ValueError( + f"Expected 'key' head dimension to be: {head_dim_qk}. Instead got:" + f" {k.shape[-1]}." 
+ ) + + if not is_mqa and num_q_heads % num_kv_heads != 0: + raise ValueError( + f"In MHA, expected number of 'key' heads ({num_kv_heads}) to be a" + f" multiple of the number of 'query' heads ({num_q_heads})" + ) + + if k.shape[:-1] != v.shape[:-1]: + raise ValueError( + f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same " + "leading dimensions." + ) + + if bkv % bkv_compute: + raise ValueError(f"{bkv=} must be a multiple of {bkv_compute=}.") + if bkv_compute % NUM_LANES: + raise ValueError(f"{bkv_compute=} must be a multiple of {NUM_LANES}.") + + kv_seq_len = k.shape[-2] + kv_steps = kv_seq_len // bkv + q_heads_per_kv_head = num_q_heads // num_kv_heads + dynamic_grid = mask_info.active_rows is not None + + if segment_ids is not None: + assert isinstance(segment_ids.q, jax.Array) # for pytype + assert isinstance(segment_ids.kv, jax.Array) # for pytype + if segment_ids.q.shape != (q_seq_len,): + raise ValueError( + "Invalid shape for q segment_ids: " + f"{segment_ids.q.shape}. Expected: {(q_seq_len,)}" + ) + if segment_ids.kv.shape != (kv_seq_len,): + raise ValueError( + "Invalid shape for kv segment_ids: " + f"{segment_ids.kv.shape}. Expected: {(kv_seq_len,)}" + ) + if config.max_logit_const is not None and max_logit_value is not None: + raise ValueError( + f"Only one of {config.max_logit_const=} and" + f" {max_logit_value=} can be set." 
+ ) + if max_logit_value is not None: + if max_logit_value.shape not in ((), (1,), (num_q_heads,)): + raise ValueError( + "max_logit_value should be a 0,1-dim jax.Array of shape (), (1,) or" + f" ({num_q_heads=},) but got {jax.typeof(max_logit_value)}" + ) + max_logit_value = jnp.broadcast_to( + jnp.atleast_1d(max_logit_value), (num_q_heads,) + ) + + q_layout = config.q_layout + k_layout = config.k_layout + v_layout = config.v_layout + + def unravel(f): + def index_map(h, grid_idx, rows_ref, cols_ref, *_): + if dynamic_grid: + i = to_i32(rows_ref[grid_idx]) + j = to_i32(cols_ref[grid_idx]) + else: + i = grid_idx // kv_steps + j = grid_idx % kv_steps + return f(h, i, j) + + return index_map + + def create_kv_index_map(layout): + def index_map(h, i, j): + del i # Unused. + prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),) + return from_head_minor((*prefix, j, 0), layout) + + return index_map + + q_index_map = unravel(lambda h, i, j: from_head_minor((h, i, 0), q_layout)) + out_index_map = unravel(lambda h, i, j: (h, i, 0)) + k_index_map = unravel(create_kv_index_map(k_layout)) + v_index_map = unravel(create_kv_index_map(v_layout)) + + def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_): + del h, rows_ref, cols_ref # Unused. + next_m = to_i32(mask_next_ref[grid_idx]) + return next_m, 0, 0 + + q_segment_ids_index_map = unravel(lambda h, i, j: (i, 0)) + kv_segment_ids_index_map = unravel(lambda h, i, j: (0, j)) + + # Convert the logical shape from head-minor to sequence-minor. 
+ in_specs = [ + pl.BlockSpec( + from_head_minor((None, bq, head_dim_qk), q_layout), q_index_map + ), + pl.BlockSpec( + from_head_minor( + (bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), + k_layout, + ), + k_index_map, + ), + pl.BlockSpec( + from_head_minor( + (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), v_layout + ), + v_index_map, + ), + ] + if segment_ids is not None: + in_specs += [ + pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map), + pl.BlockSpec((NUM_SUBLANES, bkv), kv_segment_ids_index_map), + ] + q_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.q, (q_seq_len, NUM_LANES), (0,) + ) + kv_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.kv, (NUM_SUBLANES, kv_seq_len), (1,) + ) + else: + in_specs += [None, None] + q_segment_ids = kv_segment_ids = None + + if sinks is not None: + assert sinks.shape == (num_q_heads,), f"{sinks.shape=} != {num_q_heads=}" + # align sinks to sublanes to allow vmap and shard_map over the kernel + in_specs += [ + pl.BlockSpec( + (NUM_SUBLANES, num_q_heads), + lambda h, i, j, *_: (0, 0), + memory_space=pltpu.SMEM, + ) + ] + sinks = jnp.broadcast_to( + sinks.astype(jnp.float32)[None, :], (NUM_SUBLANES, num_q_heads) + ) + else: + in_specs += [None] + + if mask_info.partial_mask_blocks is not None: + in_specs.append(pl.BlockSpec((None, bq, bkv), mask_index_map)) + else: + in_specs.append(None) + + assert mask_info.partial_mask_blocks is None or mask_info.q_sequence is None + + if mask_info.q_sequence is not None: + q_sequence = jax.lax.broadcast_in_dim( + mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,) + ) + in_specs.append(pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map)) + else: + q_sequence = None + in_specs.append(None) + + if max_logit_value is not None: + # reshape to allow sublane selection for vmap-ping and shard_map-ping + max_logit_value = jnp.broadcast_to( + max_logit_value.astype(jnp.float32)[None, :], + (NUM_SUBLANES, num_q_heads), + ) + in_specs += [ + pl.BlockSpec( + 
(NUM_SUBLANES, num_q_heads), + lambda *_: (0, 0), + memory_space=pltpu.SMEM, + ) + ] + else: + in_specs.append(None) + + out_shapes = [ + jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), q.dtype), + ] + out_specs = [ + pl.BlockSpec((None, bq, head_dim_v), out_index_map), + ] + if save_residuals: + logsumexp_index_map = unravel(lambda h, i, j, *_: (h, i, 0)) + + out_shapes += [ + # logsumexp + jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32) + if fuse_reciprocal + else None, + # l_linear + jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32) + if not fuse_reciprocal + else None, + # max_logits + jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32), + ] + out_specs += [ + pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map) + if fuse_reciprocal + else None, + pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map) + if not fuse_reciprocal + else None, + pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map), + ] + else: + out_shapes += [None, None, None] + out_specs += [None, None, None] + + kernel_name = get_kernel_name( + is_mqa=is_mqa, + save_residuals=save_residuals, + is_segmented=segment_ids is not None, + phase="fwd", + ) + metadata = {"xprof_metadata": json.dumps(dataclasses.asdict(config))} + + def _fwd_cost_estimate( + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array | None, + kv_segment_ids: jax.Array | None, + partial_mask_blocks: jax.Array | None, + out_shapes: list[jax.ShapeDtypeStruct], + mask_sparsity: float, + ) -> pl.CostEstimate: + num_q_heads, q_seq_len, head_dim_qk = q.shape + kv_seq_len, head_dim_v = v.shape[-2:] + + matmul_flops = ( + 2 * q_seq_len * kv_seq_len * head_dim_qk + + 2 * q_seq_len * kv_seq_len * head_dim_v + ) + + # This is an upper bound because `mask_sparsity` is actually the mean + # sparsity of the non-fully masked **blocks**. 
+ total_flops = num_q_heads * matmul_flops * mask_sparsity + + # Count expensive exp() calls + transcendentals = num_q_heads * q_seq_len * kv_seq_len * mask_sparsity + + inputs_ = [q, k, v, q_segment_ids, kv_segment_ids, partial_mask_blocks] + input_bytes = sum(map(_bytes, inputs_)) + output_bytes = sum(map(_bytes, out_shapes)) + return pl.CostEstimate( + flops=int(total_flops), + transcendentals=int(transcendentals), + bytes_accessed=int(input_bytes + output_bytes), + ) + + vmem_inputs = [ + q, + k, + v, + q_segment_ids, + kv_segment_ids, + mask_info.partial_mask_blocks, + ] + cost_estimate = config.fwd_cost_estimate or _fwd_cost_estimate( + *vmem_inputs, out_shapes, fwd_mask_sparsity + ) + + if dynamic_grid: + num_active_blocks = mask_info.num_active_blocks[0] + grid = (num_q_heads, num_active_blocks) + is_empty_attention_block = num_active_blocks == 0 + else: + grid = (num_q_heads, kv_steps * (q_seq_len // bq)) + is_empty_attention_block = False + + with jax.named_scope(kernel_name): + all_out = pl.pallas_call( + partial( + flash_attention_kernel, + mask_value=mask_value, + kv_steps=kv_steps, + bq=bq, + bkv=bkv, + bkv_compute=bkv_compute, + head_dim_v=head_dim_v, + # note: fuse_reciprocal can only be False if save_residuals is True + # fuse_reciprocal = (config.fuse_reciprocal or not save_residuals) + fuse_reciprocal=fuse_reciprocal, + config=config, + mask_function=mask_function, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=6, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=[ + pltpu.VMEM((bq, NUM_LANES), jnp.float32), # m_scratch + pltpu.VMEM((bq, NUM_LANES), jnp.float32), # l_scratch + pltpu.VMEM((bq, head_dim_v), jnp.float32), # o_scratch + ], + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary"), + flags={ + "XLA_TPU_FORCE_LP_LLO_SCHEDULER": ( + config.use_experimental_scheduler + ) + }, + ), + out_shape=out_shapes, + name=kernel_name, + cost_estimate=cost_estimate, + 
interpret=config.interpret, + metadata=metadata, + )( + mask_info.active_rows, + mask_info.active_cols, + mask_info.mask_next, + bounds_start, + bounds_end, + mask_info.block_mask, + q if q_layout == QKVLayout.HEAD_DIM_MINOR else q.mT, + k if k_layout == QKVLayout.HEAD_DIM_MINOR else k.mT, + v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.mT, + q_segment_ids, + kv_segment_ids, + sinks, + mask_info.partial_mask_blocks, + q_sequence, + max_logit_value, + ) + out, logsumexp, l_linear, max_logits = all_out + + # If there is no compute to do within an attention block, then we want to + # initialize the output and residuals to default values. Otherwise, we will + # read uninitialized memory. This is a common case in ring attention. + def init_if_empty(x: jax.Array, value: float) -> jax.Array: + if not dynamic_grid: + return x + + return jnp.where(is_empty_attention_block, value, x) + + out = init_if_empty(out, 0.0) + + if save_residuals: + assert max_logits is not None + max_logits = init_if_empty(max_logits[..., 0], mask_value) + + if fuse_reciprocal: + assert logsumexp is not None + logsumexp = init_if_empty(logsumexp[..., 0], mask_value) + else: + assert l_linear is not None + log = jnp.log2 if config.use_base2_exp else jnp.log + + l = l_linear[..., 0] + logsumexp = max_logits + log(l) + out = (out / l[..., None]).astype(out.dtype) + else: + # If we're not saving residuals, then we can't fuse the reciprocal + # out of the kernel. 
+ assert fuse_reciprocal + + if config.residual_checkpoint_name is not None: + out = ad_checkpoint.checkpoint_name( + out, name=config.residual_checkpoint_name + ) + if logsumexp is not None: + logsumexp = ad_checkpoint.checkpoint_name( + logsumexp, name=config.residual_checkpoint_name + ) + if save_residuals: + stats = {"logsumexp": logsumexp, "max_logits": max_logits} + stats = jax.tree.map(jax.lax.stop_gradient, stats) + return out, stats + return out + + +@partial( + jax.custom_vjp, + nondiff_argnames=( + "save_residuals", + "mask_value", + "is_mqa", + "config", + "mask_function", + "fwd_mask_sparsity", + "dkv_mask_sparsity", + ), +) +def _splash_attention_custom( + fwd_mask_info: MaskInfo, + dkv_mask_info: MaskInfo | None, + q: jax.Array, + k: jax.Array, + v: jax.Array, + segment_ids: base.SegmentIds | None, + sinks: jax.Array | None, + save_residuals: bool, + mask_value: float, + is_mqa: bool, + config: SplashConfig, + mask_function: MaskFunctionType | None, + fwd_mask_sparsity: float, + dkv_mask_sparsity: float, + max_logit_value: jax.Array | None = None, +) -> base.SplashCustomReturnType: + # The forward function does not use the dq and dkv MaskInfos, it just forwards + # them to the backward function as residuals. This is a way to communicate + # arbitrary Arrays to the backward function. Since the three MaskInfos are + # constants there is no overhead in passing them to the backward function as + # residuals. When sharding computation MaskInfos are partitioned so both the + # forward and the backward kernels need to work on the relevant slice. If we + # recomputed the backward MaskInfos in the backward function from the numpy + # mask then we would not work with the MaskInfo slice relevant to the current + # device. 
+ del dkv_mask_info + + ret = _splash_attention_forward( # pytype: disable=wrong-arg-types + fwd_mask_info, + q, + k, + v, + segment_ids, + sinks, + mask_value=mask_value, + is_mqa=is_mqa, + config=config, + save_residuals=save_residuals, + mask_function=mask_function, + fwd_mask_sparsity=fwd_mask_sparsity, + max_logit_value=max_logit_value, + ) + if save_residuals: + out, stats = ret + if config.use_base2_exp: # for user, output values in natural base + stats["logsumexp"] = stats["logsumexp"] / LOG2E + stats["max_logits"] = stats["max_logits"] / LOG2E + return out, stats + else: + return ret + + +def _splash_attention_fwd( + fwd_mask_info: MaskInfo, + dkv_mask_info: MaskInfo | None, + q: jax.Array, + k: jax.Array, + v: jax.Array, + segment_ids: base.SegmentIds | None, + sinks: jax.Array | None, + save_residuals: bool, + mask_value: float, + is_mqa: bool, + config: SplashConfig, + mask_function: MaskFunctionType | None, + fwd_mask_sparsity: float, + dkv_mask_sparsity: float, + max_logit_value: jax.Array | None = None, +) -> tuple[tuple[jax.Array], base.SplashResidualsType]: + + # TODO: add some higher order AD check that isn't save_residuals based. 
+ # if save_residuals: + # raise NotImplementedError("Higher-order AD not supported.") + + out, stats = _splash_attention_forward( # pytype: disable=wrong-arg-types + fwd_mask_info, + q, + k, + v, + segment_ids, + sinks, + mask_value=mask_value, + is_mqa=is_mqa, + config=config, + save_residuals=True, + mask_function=mask_function, + fwd_mask_sparsity=fwd_mask_sparsity, + max_logit_value=max_logit_value, + ) + logsumexp = stats["logsumexp"] # save in the config base for the bwd pass + if config.use_base2_exp: # for user, output values in natural base + stats["logsumexp"] = stats["logsumexp"] / LOG2E + stats["max_logits"] = stats["max_logits"] / LOG2E + residuals = q, k, v, segment_ids, sinks, out, logsumexp, dkv_mask_info + if save_residuals: + return (out, stats), residuals + else: + return out, residuals + + +def _flash_attention_dq_kernel( + # Prefetched inputs + active_rows_ref, + active_cols_ref, + mask_next_ref, + bounds_start_ref, + bounds_end_ref, + block_mask_ref, + # Inputs + q_ref, + k_ref, + v_ref, + q_segment_ids_ref, + kv_segment_ids_ref, + logsumexp_ref, + do_ref, + di_ref, + mask_ref, + q_sequence_ref, + # Outputs + dq_scratch_ref, + dq_ref, + *, + mask_value: float, + kv_steps: int, + bq: int, + bkv: int, + mask_function: MaskFunctionType | None, + config: SplashConfig, +): + del mask_next_ref, active_rows_ref + float32 = jnp.float32 + HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR + attn_logits_soft_cap = config.attn_logits_soft_cap + if attn_logits_soft_cap is not None and config.use_base2_exp: + attn_logits_soft_cap *= LOG2E + + grid_idx = pl.program_id(1) + if block_mask_ref is not None: + kv_index = active_cols_ref[grid_idx].astype(jnp.int32) + should_not_mask = block_mask_ref[grid_idx].astype(jnp.int32) != 1 + should_initialize = bounds_start_ref[grid_idx].astype(jnp.bool_) + should_write = bounds_end_ref[grid_idx].astype(jnp.bool_) + else: + kv_index = grid_idx % kv_steps + should_not_mask = False + should_initialize = kv_index == 0 + 
should_write = kv_index == kv_steps - 1 + + @pl.when(should_initialize) + def init(): + dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref) + + def body(has_partial_mask: bool = False): + q = q_ref[...] if config.q_layout == HEAD_DIM_MINOR else q_ref[...].T + if config.use_base2_exp: + q *= LOG2E + # We keep k and v possibly transposed, since they are RHS of dots. + k = k_ref[...] + v = v_ref[...] + logsumexp = jnp.expand_dims(logsumexp_ref[0], -1) + do = do_ref[...] + di = jnp.expand_dims(di_ref[0], -1) + + qk_dims = ( + NT_DIM_NUMBERS if config.k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS + ) + qk_uncapped = lax.dot_general(q, k, qk_dims, preferred_element_type=float32) + + qk = _apply_mask_and_soft_cap( + qk_uncapped, + mask_value, + mask_ref, + q_sequence_ref, + q_segment_ids_ref, + kv_segment_ids_ref, + attn_logits_soft_cap=attn_logits_soft_cap, + k_slice=pl.ds(0, bkv), + k_offset=kv_index * bkv, + bq=bq, + mask_function=mask_function, + has_partial_mask=has_partial_mask, + ) + exp = jnp.exp2 if config.use_base2_exp else jnp.exp + p = exp(qk - logsumexp) + dp_dims = ( + NT_DIM_NUMBERS if config.v_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS + ) + dp = lax.dot_general( + do.astype(v.dtype), + v, + dp_dims, + preferred_element_type=jnp.float32, + ) + ds = (dp - di) * p + if attn_logits_soft_cap is not None: + normalized = qk_uncapped / attn_logits_soft_cap + d = jnp.tanh(normalized) + ds = ds * (1 - d * d) + + dq_dims = ( + NN_DIM_NUMBERS if config.k_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS + ) + dq_scratch_ref[...] += lax.dot_general( + ds.astype(k.dtype), + k, + dq_dims, + preferred_element_type=jnp.float32, + ) + + @pl.when(should_not_mask) + def _(): + body() + + @pl.when(jnp.logical_not(should_not_mask)) + def _(): + body(has_partial_mask=True) + + @pl.when(should_write) + def end(): + dq_ref[...] 
= dq_scratch_ref[...].astype(dq_ref.dtype) + + +def _flash_attention_dkv_kernel( + # Prefetched inputs + active_rows_ref, + active_cols_ref, + mask_next_ref, + bounds_start_ref, + bounds_end_ref, + block_mask_ref, + # Inputs + q_ref, + k_ref, + v_ref, + q_segment_ids_ref, + kv_segment_ids_ref, + logsumexp_ref, + do_ref, + di_ref, + mask_ref, + q_sequence_ref, + # aliases + dq_alias, + dk_alias, + dv_alias, + # Outputs + dq_ref, + dk_ref, + dv_ref, + # Scratch + dq_scratch_ref, + dk_scratch_ref, + dv_scratch_ref, + *, + mask_value: float, + q_steps: int, + bq: int, + bkv_compute: int, + bkv: int, + mask_function: MaskFunctionType | None, + q_heads_per_kv_head: int, + config: SplashConfig, +): + del mask_next_ref, active_cols_ref + HEAD_DIM_MINOR = QKVLayout.HEAD_DIM_MINOR + attn_logits_soft_cap = config.attn_logits_soft_cap + if attn_logits_soft_cap is not None and config.use_base2_exp: + attn_logits_soft_cap *= LOG2E + + if active_rows_ref is not None: + assert bounds_start_ref is not None + assert bounds_end_ref is not None + grid_idx = pl.program_id(1) + kv_index = active_rows_ref[grid_idx].astype(jnp.int32) + should_initialize = bounds_start_ref[grid_idx].astype(jnp.bool_) + should_write = bounds_end_ref[grid_idx].astype(jnp.bool_) + else: + kv_index, q_head, q_index = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + ) + grid_idx = (kv_index * q_steps) + q_index + should_initialize = q_index == 0 + should_write = True if q_steps <= 2 else q_index == q_steps - 1 + if q_heads_per_kv_head > 1: + q_head_index_per_kv_head = lax.rem(q_head, q_heads_per_kv_head) + should_initialize = jnp.logical_and( + should_initialize, q_head_index_per_kv_head == 0 + ) + should_write = jnp.logical_and( + should_write, q_head_index_per_kv_head == q_heads_per_kv_head - 1 + ) + + if block_mask_ref is not None: + should_not_mask = block_mask_ref[grid_idx].astype(jnp.int32) != 1 + should_run = block_mask_ref[grid_idx].astype(jnp.int32) != 0 + else: + should_not_mask = False 
+ should_run = True + + # TODO: Update docstring explaining the accumulation logic + + # Consider this situation: + # Q_heads: 0, 1, 2, 3, 4, 5, 6, 7 + # KV_heads: 0, 1, 2, 3 + # The gradient scratch buffers should be initialized for Q_heads 0, 2, 4, 6 + # (first Q_heads to 'see' a new KV_head). + # The gradient output buffers should be written for Q_heads 1, 3, 5, 7 (last + # Q_heads to 'see' the current KV_head). + + @pl.when(should_initialize) + def init(): + dk_scratch_ref[...] = jnp.zeros_like(dk_scratch_ref) + dv_scratch_ref[...] = jnp.zeros_like(dv_scratch_ref) + + def body(i, _, has_partial_mask=False): + + slice_k = pl.ds(i * bkv_compute, bkv_compute) + q = q_ref[...] # We keep q potentially transposed, since it's always RHS + if config.use_base2_exp: + scaled_q = q * LOG2E + else: + scaled_q = q + + def _load_kv(ref, layout): + if layout == HEAD_DIM_MINOR: + return ref[slice_k, :] + return ref[:, slice_k].T + + k = _load_kv(k_ref, config.k_layout) + v = _load_kv(v_ref, config.v_layout) + logsumexp = logsumexp_ref[:1, :] + do = do_ref[...] 
+ di = di_ref[:1, :] + + qk_dims = ( + NT_DIM_NUMBERS if config.q_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS + ) + qk_uncapped = lax.dot_general( + k, scaled_q, qk_dims, preferred_element_type=jnp.float32 + ) + + qk = _apply_mask_and_soft_cap( + qk_uncapped, + mask_value, + mask_ref, + q_sequence_ref, + q_segment_ids_ref, + kv_segment_ids_ref, + attn_logits_soft_cap=attn_logits_soft_cap, + k_slice=slice_k, + k_offset=kv_index * bkv + i * bkv_compute, + bq=bq, + k_in_lanes=False, + mask_function=mask_function, + has_partial_mask=has_partial_mask, + ) + exp = jnp.exp2 if config.use_base2_exp else jnp.exp + p = exp(qk - logsumexp) + dv = lax.dot(p.astype(do.dtype), do, preferred_element_type=jnp.float32) + dv = dv.astype(dv_scratch_ref.dtype) + dv_scratch_ref[slice_k, :] + dv_scratch_ref[slice_k, :] = dv + + dp = lax.dot_general( + v, + do, + NT_DIM_NUMBERS, + preferred_element_type=jnp.float32, + ) + ds = (dp - di) * p + if attn_logits_soft_cap is not None: + normalized = qk_uncapped / attn_logits_soft_cap + d = jnp.tanh(normalized) + ds = ds * (1 - d * d) + dk_dims = ( + NN_DIM_NUMBERS if config.q_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS + ) + dk = lax.dot_general( + ds.astype(do.dtype), q, dk_dims, preferred_element_type=jnp.float32 + ) + dk = dk.astype(dk_scratch_ref.dtype) + dk_scratch_ref[slice_k, :] + dk_scratch_ref[slice_k, :] = dk + if dq_scratch_ref is not None or dq_ref is not None: + dq = lax.dot_general( + ds.T.astype(k.dtype), + k, + NN_DIM_NUMBERS, + preferred_element_type=jnp.float32, + ) + if dq_scratch_ref is not None: + # Compute block size != memory block size + dq_scratch_ref[...] += dq + else: + # Compute block size == memory block size + if dq_alias is not None: + dq_ref[...] = dq_alias[...] + dq.astype(dq_ref.dtype) + else: + dq_ref[...] = dq.astype(dq_ref.dtype) + + if dq_scratch_ref is not None: + dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref) + elif dq_alias is not None: + dq_ref[...] = dq_alias[...] + else: + dq_ref[...] 
= jnp.zeros_like(dq_ref) + + num_iters = ( + k_ref.shape[0 if config.k_layout is HEAD_DIM_MINOR else 1] // bkv_compute + ) + + @pl.when(jnp.logical_and(should_not_mask, should_run)) + def _(): + lax.fori_loop(0, num_iters, body, None, unroll=True) + + @pl.when(jnp.logical_and(_not(should_not_mask), should_run)) + def _(): + lax.fori_loop( + 0, num_iters, partial(body, has_partial_mask=True), None, unroll=True + ) + + if dq_scratch_ref is not None: + if dq_alias is not None: + dq_ref[...] = dq_alias[...] + dq_scratch_ref[...].astype(dq_ref.dtype) + else: + dq_ref[...] = dq_scratch_ref[...].astype(dq_ref.dtype) + + if dk_alias is None: + assert dv_alias is None + + @pl.when(should_write) + def _(): + dk_ref[...] = dk_scratch_ref[...].astype(dk_ref.dtype) + dv_ref[...] = dv_scratch_ref[...].astype(dv_ref.dtype) + + else: + q_head = pl.program_id(0) + first_q_head_in_kv_group = lax.rem(q_head, q_heads_per_kv_head) == 0 + + @pl.when(jnp.logical_and(should_write, first_q_head_in_kv_group)) + def _(): + dk_ref[...] = dk_scratch_ref[...].astype(dk_ref.dtype) + dv_ref[...] = dv_scratch_ref[...].astype(dv_ref.dtype) + + @pl.when(jnp.logical_and(should_write, _not(first_q_head_in_kv_group))) + def _(): + dk_ref[...] = dk_alias[...] + dk_scratch_ref[...].astype(dk_ref.dtype) + dv_ref[...] = dv_alias[...] 
+ dv_scratch_ref[...].astype(dv_ref.dtype) + + +def _splash_attention_bwd_dkv( + q, + k, + v, + segment_ids, + logsumexp, + do, + di, + *, + bq: int, + bkv: int, + bkv_compute: int, + is_mqa: bool, + mask_info: MaskInfo, + mask_value: float, + mask_function: MaskFunctionType | None, + config: SplashConfig, + dkv_mask_sparsity: float, +): + num_q_heads, q_seq_len, head_dim_qk = q.shape + kv_seq_len, head_dim_v = v.shape[-2:] + num_kv_heads = 1 if is_mqa else k.shape[0] + dynamic_grid = mask_info.active_rows is not None + + bounds_start, bounds_end = mask_info_lib.find_bounds(mask_info.active_rows) + if bq > q_seq_len: + raise ValueError(f"{bq=} should not be greater than {q_seq_len=}") + if bkv > kv_seq_len: + raise ValueError(f"{bkv=} should not be greater than {kv_seq_len=}") + if bkv_compute > bkv: + raise ValueError(f"{bkv_compute=} should not be greater than {bkv=}") + if bkv % bkv_compute: + raise ValueError(f"{bkv=} should be a multiple of {bkv_compute=}") + + if not is_mqa and num_q_heads % num_kv_heads != 0: + raise ValueError( + f"In MHA, expected number of 'key' heads ({num_kv_heads}) to be a" + f" multiple of the number of 'query' heads ({num_q_heads})" + ) + + if k.shape[:-1] != v.shape[:-1]: + raise ValueError( + f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same " + "leading dimensions." + ) + + kv_steps = kv_seq_len // bkv + q_steps = q_seq_len // bq + q_heads_per_kv_head = num_q_heads // num_kv_heads + + if dynamic_grid: + + def unravel(f): + def index_map(h, grid_idx, rows_ref, cols_ref, *_): + j = to_i32(rows_ref[grid_idx]) + i = to_i32(cols_ref[grid_idx]) + return f(h, i, j) + + return index_map + + grid_size = mask_info.num_active_blocks[0] + grid = (num_q_heads, grid_size) + + def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_): + del h, rows_ref, cols_ref # Unused. 
+ next_m = to_i32(mask_next_ref[grid_idx]) + return next_m, 0, 0 + + else: + unravel = lambda f: lambda j, h, i, *_: f(h, i, j) + grid = (kv_steps, num_q_heads, q_steps) + + def mask_index_map(j, h, i, rows_ref, cols_ref, mask_next_ref=None, *_): + del h, rows_ref, cols_ref # Unused. + grid_idx = j * q_steps + i + next_m = to_i32(mask_next_ref[grid_idx]) + return next_m, 0, 0 + + q_index_map = unravel( + lambda h, i, j: from_head_minor((h, i, 0), config.q_layout) + ) + o_index_map = unravel(lambda h, i, j: (h, i, 0)) + + def create_kv_index_map(layout): + def index_map(h, i, j, *_): + del i # Unused. + prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),) + return from_head_minor((*prefix, j, 0), layout) + + return index_map + + k_index_map = unravel(create_kv_index_map(config.k_layout)) + v_index_map = unravel(create_kv_index_map(config.v_layout)) + + q_spec = pl.BlockSpec( + from_head_minor((None, bq, head_dim_qk), config.q_layout), q_index_map + ) + + o_spec = pl.BlockSpec((None, bq, head_dim_v), o_index_map) + k_spec = pl.BlockSpec( + from_head_minor( + (bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), + config.k_layout, + ), + k_index_map, + ) + + v_spec = pl.BlockSpec( + from_head_minor( + (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), + config.v_layout, + ), + v_index_map, + ) + + def create_dkv_index_map(h, i, j, *_): + del i # Unused. 
+ prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),) + return (*prefix, j, 0) + + dkv_index_map = unravel(create_dkv_index_map) + + dk_spec = pl.BlockSpec( + (bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), + dkv_index_map, + ) + + dv_spec = pl.BlockSpec( + (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), + dkv_index_map, + ) + mask_spec = pl.BlockSpec((None, bkv, bq), mask_index_map) + + q_segment_ids_index_map = unravel(lambda h, i, j: (0, i)) + if segment_ids is not None: + kv_segment_ids_index_map = unravel(lambda h, i, j: (j, 0)) + + q_segment_spec = pl.BlockSpec((NUM_SUBLANES, bq), q_segment_ids_index_map) + kv_segment_spec = pl.BlockSpec((bkv, NUM_LANES), kv_segment_ids_index_map) + q_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.q, (NUM_SUBLANES, q_seq_len), (1,) + ) + kv_segment_ids = jax.lax.broadcast_in_dim( + segment_ids.kv, (kv_seq_len, NUM_LANES), (0,) + ) + else: + q_segment_spec = kv_segment_spec = None + q_segment_ids = kv_segment_ids = None + + do_spec = o_spec + + logsumexp_index_map = unravel(lambda h, i, j: (h, 0, i)) + + assert logsumexp.shape == di.shape == (num_q_heads, q_seq_len) + # TODO: Remove the sublane expansion once Mosaic has all retilings + logsumexp_shape = (num_q_heads, NUM_SUBLANES, q_seq_len) + logsumexp = jnp.broadcast_to(jnp.expand_dims(logsumexp, -2), logsumexp_shape) + logsumexp_spec = pl.BlockSpec((None, NUM_SUBLANES, bq), logsumexp_index_map) + assert logsumexp.ndim == len(logsumexp_spec.block_shape) + + # TODO: Remove the sublane expansion once Mosaic has all retilings + di = jnp.broadcast_to(jnp.expand_dims(di, -2), logsumexp_shape) + di_spec = pl.BlockSpec((None, NUM_SUBLANES, bq), logsumexp_index_map) + assert di.ndim == len(di_spec.block_shape) + + in_specs = [ + q_spec, + k_spec, + v_spec, + q_segment_spec, + kv_segment_spec, + logsumexp_spec, + do_spec, + di_spec, + ] + if mask_info.partial_mask_blocks is not None: + in_specs.append(mask_spec) + else: + in_specs.append(None) + + 
if mask_info.q_sequence is not None: + in_specs.append(pl.BlockSpec((NUM_SUBLANES, bq), q_segment_ids_index_map)) + q_sequence = jax.lax.broadcast_in_dim( + mask_info.q_sequence, (NUM_SUBLANES, q_seq_len), (1,) + ) + else: + q_sequence = None + in_specs.append(None) + + dq_reduction_steps = config.dq_reduction_steps + if not dynamic_grid and kv_steps <= 3 and dq_reduction_steps == 3: + dq_reduction_steps = None + + dq = dq_alias_spec = None + if dq_reduction_steps == 3: + dq_index_map = unravel(lambda h, i, j: (j % 3, h, i, 0)) + dq_spec = pl.BlockSpec((None, None, bq, head_dim_qk), dq_index_map) + dq_alias_spec = dq_spec + dq_shape = jax.ShapeDtypeStruct((3, *q.shape), q.dtype) + dq = jnp.zeros_like(dq_shape) + else: + dq_index_map = unravel(lambda h, i, j: (j, h, i, 0)) + dq_spec = pl.BlockSpec((None, None, bq, head_dim_qk), dq_index_map) + # Only accumulate in fp32 if there's a small number of reduction steps. + q_dtype = q.dtype if kv_steps <= 4 else jnp.float32 + dq_shape = jax.ShapeDtypeStruct((kv_steps, *q.shape), q_dtype) + + in_specs += [dq_alias_spec] + + if bkv == bkv_compute: + dq_scratch = None + else: + dq_scratch = pltpu.VMEM((bq, head_dim_qk), jnp.float32) + + if dynamic_grid and q_heads_per_kv_head != 1: + # in/out aliasing to accumulate within kv groups. + in_specs += [dk_spec, dv_spec] + dk = lax.empty(k.shape, dtype=jnp.float32) + dv = lax.empty(v.shape, dtype=jnp.float32) + # Keep gradients in fp32 when accumulating over head groups. 
+ dk_type = dv_type = jnp.float32 + else: + in_specs += [None, None] + dk, dv = None, None + dk_type = k.dtype + dv_type = v.dtype + + out_shapes = [ + dq_shape, + jax.ShapeDtypeStruct(k.shape, dk_type), + jax.ShapeDtypeStruct(v.shape, dv_type), + ] + out_specs = [dq_spec, dk_spec, dv_spec] + + kernel = functools.partial( + _flash_attention_dkv_kernel, + mask_value=mask_value, + q_steps=q_steps, + bq=bq, + bkv_compute=bkv_compute, + config=config, + bkv=bkv, + mask_function=mask_function, + q_heads_per_kv_head=q_heads_per_kv_head, + ) + + kernel_name = get_kernel_name( + is_mqa=is_mqa, + save_residuals=False, + is_segmented=segment_ids is not None, + phase="dkv", + ) + metadata = { + "xprof_metadata": json.dumps( + dict( + block_q_dkv=bq, + block_kv_dkv=bkv, + block_kv_dkv_compute=bkv_compute, + q_layout=config.q_layout, + k_layout=config.k_layout, + v_layout=config.v_layout, + use_experimental_scheduler=config.use_experimental_scheduler, + ), + ) + } + args = [ + # scalar prefetch + mask_info.active_rows, + mask_info.active_cols, + mask_info.mask_next, + bounds_start, + bounds_end, + mask_info.block_mask, + # inputs + q if config.q_layout == QKVLayout.HEAD_DIM_MINOR else q.mT, + k if config.k_layout == QKVLayout.HEAD_DIM_MINOR else k.mT, + v if config.v_layout == QKVLayout.HEAD_DIM_MINOR else v.mT, + q_segment_ids, + kv_segment_ids, + logsumexp, + do, + di, + mask_info.partial_mask_blocks, + q_sequence, + ] + num_args = sum(1 for x in args if x is not None) + input_output_aliases = {} + if dq_reduction_steps == 3: + if dynamic_grid and q_heads_per_kv_head != 1: + input_output_aliases = {num_args: 0, num_args + 1: 1, num_args + 2: 2} + else: + input_output_aliases = {num_args: 0} + elif dynamic_grid and q_heads_per_kv_head != 1: + input_output_aliases = {num_args: 1, num_args + 1: 2} + + scratch_shapes = [ + dq_scratch, + pltpu.VMEM((bkv, head_dim_qk), jnp.float32), + pltpu.VMEM((bkv, head_dim_v), jnp.float32), + ] + + def _bwd_cost_estimate( + q: jax.Array, + k: 
jax.Array, + v: jax.Array, + q_segment_ids: jax.Array | None, + kv_segment_ids: jax.Array | None, + logsumexp: jax.Array, + do: jax.Array, + di: jax.Array, + partial_mask_blocks: jax.Array | None, + q_sequence: jax.Array | None, + out_shapes: list[jax.ShapeDtypeStruct], + mask_sparsity_factor: float, + ) -> pl.CostEstimate: + num_q_heads, q_seq_len, head_dim_qk = q.shape + kv_seq_len, head_dim_v = v.shape[-2:] + + total_matmul_flops_per_head = ( + 2 * q_seq_len * kv_seq_len * head_dim_qk # qk + + 2 * q_seq_len * kv_seq_len * head_dim_v # dv + + 2 * q_seq_len * kv_seq_len * head_dim_v # dp + + 2 * q_seq_len * kv_seq_len * head_dim_qk # dq + + 2 * q_seq_len * kv_seq_len * head_dim_qk # dk + ) + + estimated_flops = int( + total_matmul_flops_per_head * num_q_heads * mask_sparsity_factor + ) + + exp_flops = num_q_heads * q_seq_len * kv_seq_len * mask_sparsity_factor + if config.attn_logits_soft_cap is None: + tanh_flops = 0 + else: + tanh_flops = ( + 2 * num_q_heads * q_seq_len * kv_seq_len * mask_sparsity_factor + ) + estimated_transcendentals = int(exp_flops + tanh_flops) + + inputs_ = [ + q, + k, + v, + q_segment_ids, + kv_segment_ids, + logsumexp, + do, + di, + partial_mask_blocks, + q_sequence, + ] + input_bytes = sum(map(_bytes, inputs_)) + output_bytes = sum(map(_bytes, out_shapes)) + + estimated_bytes = input_bytes + output_bytes + + return pl.CostEstimate( + flops=estimated_flops, + transcendentals=estimated_transcendentals, + bytes_accessed=estimated_bytes, + ) + + cost_estimate = config.bwd_cost_estimate or _bwd_cost_estimate( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + logsumexp, + do, + di, + mask_info.partial_mask_blocks, + q_sequence, + out_shapes, + dkv_mask_sparsity, + ) + + with jax.named_scope(kernel_name): + dq_unreduced, dk, dv = pl.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=6, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=scratch_shapes, + ), + out_shape=out_shapes, + 
def _splash_attention_bwd(
    save_residuals: bool,
    mask_value: float,
    is_mqa: bool,
    config: SplashConfig,
    mask_function: MaskFunctionType | None,
    fwd_mask_sparsity: float,
    dkv_mask_sparsity: float,
    res: base.SplashResidualsType,
    do: jax.Array,
) -> tuple[
    MaskInfo | None,  # fwd_mask_info
    MaskInfo | None,  # dkv_mask_info
    jax.Array,  # dq (cotangent for q)
    jax.Array,  # dk (cotangent for k)
    jax.Array,  # dv (cotangent for v)
    base.SegmentIds | None,  # segment_ids
    jax.Array | None,  # sinks
    jax.Array | None,  # max_logit_estimate
]:
  """Backward rule for the splash attention custom VJP.

  Computes cotangents for q, k and v (and, when present, the attention
  sinks) from the residuals saved by the forward pass, using the fused
  dq/dk/dv kernel.

  Args:
    save_residuals: Unused here; present to mirror the fwd signature.
    mask_value: Value substituted for masked-out logits.
    is_mqa: Whether the kernel runs in multi-query attention mode.
    config: Kernel configuration; must define the backward block sizes.
    mask_function: Optional callable computing the mask on the fly.
    fwd_mask_sparsity: Unused in the backward pass.
    dkv_mask_sparsity: Fraction of non-empty mask blocks; forwarded to the
      backward kernel (it scales its compiler cost estimate).
    res: Residuals saved by the forward pass.
    do: Cotangent of the attention output.

  Returns:
    A tuple matching the structure of the fwd function's inputs;
    non-differentiable inputs (mask infos, segment ids, max logit
    estimate) receive None.
  """
  del save_residuals, fwd_mask_sparsity
  if not config.has_backward_blocks:
    raise ValueError("Need to specify backward blocks.")
  bq_dkv, bkv_dkv_memory, bkv_dkv_compute = (
      config.block_q_dkv,
      config.block_kv_dkv,
      config.block_kv_dkv_compute,
  )
  q, k, v, segment_ids, sinks, o, logsumexp, dkv_mask_info = res

  # di: [num_heads, q_seq_len] — per-row dot product of the output with its
  # cotangent, required by the flash-attention backward recurrence.
  di = jnp.einsum("hsd,hsd->hs", o.astype(jnp.float32), do.astype(jnp.float32))  # pytype: disable=attribute-error
  dq, dk, dv = _splash_attention_bwd_dkv(
      q,
      k,
      v,
      segment_ids,
      logsumexp,
      do,
      di,
      bq=bq_dkv,
      bkv=bkv_dkv_memory,
      bkv_compute=bkv_dkv_compute,
      is_mqa=is_mqa,
      mask_info=dkv_mask_info,
      mask_value=mask_value,
      mask_function=mask_function,
      config=config,
      dkv_mask_sparsity=dkv_mask_sparsity,
  )
  dsinks = None
  if sinks is not None:
    # When the forward pass used base-2 exponentials, logsumexp is stored in
    # base-2 units; rescale by LOG2E before taking the natural exp below.
    logsumexp_ = (logsumexp / LOG2E) if config.use_base2_exp else logsumexp
    sinks_exp = -jnp.exp(
        sinks[..., None, None].astype(jnp.float32)
        - logsumexp_[..., None].astype(jnp.float32)
    )
    dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2))
  # Match the signature of the fwd function.
  assert dq is not None
  return (
      None,  # fwd_mask_info
      None,  # dkv_mask_info
      dq,  # cotangent for q
      dk,  # cotangent for k
      dv,  # cotangent for v
      None,  # segment_ids
      dsinks,  # sinks
      None,  # max_logit_estimate
  )
  def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
    """Returns a value that can be used as a shard_map partition spec for the kernel.

    Args:
      sharding: A NamedSharding over the q-sequence dimension; its spec must
        have exactly one dimension and must divide the block-mask shape
        evenly across devices.

    Raises:
      ValueError: If the sharding does not divide the mask blocks evenly, or
        if it shards over more than the q-sequence dimension.
    """
    if self.fwd_mask_info.block_mask is not None:
      block_mask_shape = self.fwd_mask_info.block_mask.shape
      # shard_shape raises ValueError when the sharding does not evenly
      # divide the block mask; surface that with a clearer message.
      try:
        sharding.shard_shape(block_mask_shape)
      except ValueError as exc:
        raise ValueError(
            "The sharding must divide the mask blocks evenly between devices"
        ) from exc

    if len(sharding.spec) != 1:
      raise ValueError("Only q sequence sharding is supported.")

    # Populated mask-info fields get the q-sequence spec; absent fields stay
    # None so the spec pytree mirrors the mask-info pytree exactly.
    _resolve_spec = lambda x: sharding.spec if x is not None else None
    mask_info_specs = MaskInfo(  # pytype: disable=wrong-arg-types
        mask_next=_resolve_spec(self.fwd_mask_info.mask_next),
        active_rows=_resolve_spec(self.fwd_mask_info.active_rows),
        active_cols=_resolve_spec(self.fwd_mask_info.active_cols),
        num_active_blocks=_resolve_spec(self.fwd_mask_info.num_active_blocks),
        block_mask=_resolve_spec(self.fwd_mask_info.block_mask),
        partial_mask_blocks=jax.sharding.PartitionSpec()  # replicated
        if self.fwd_mask_info.partial_mask_blocks is not None
        else None,
        q_sequence=_resolve_spec(self.fwd_mask_info.q_sequence),
    )
    # The same spec tree is reused for the dkv mask info when it exists.
    return SplashAttentionKernel(
        mask_info_specs,
        mask_info_specs if self.dkv_mask_info is not None else None,
        **self.kwargs,
    )
def _make_splash_attention(
    mask: np.ndarray | mask_lib.Mask,
    *,
    config: SplashConfig | None = None,
    is_mqa: bool,
    save_residuals: bool = False,
    mask_value: float = base.DEFAULT_MASK_VALUE,
    downcast_smem_data: bool = True,
    partial_mask_blocks_dtype: jax.typing.DTypeLike = np.int8,
    q_seq_shards: int,
):
  """Builds a SplashAttentionKernel from a static (host-side) mask.

  Args:
    mask: 2D attention mask, either a numpy array or a mask_lib.Mask.
    config: Block-size/layout configuration; defaults to
      SplashConfig.get_default().
    is_mqa: Whether to build the multi-query attention variant.
    save_residuals: Whether the kernel returns residuals (e.g. logsumexp)
      alongside the output.
    mask_value: Value substituted for masked-out logits.
    downcast_smem_data: Whether to downcast SMEM mask metadata.
    partial_mask_blocks_dtype: Storage dtype for partially-masked blocks.
    q_seq_shards: Number of shards over the query sequence dimension.

  Returns:
    A SplashAttentionKernel wrapping the processed forward (and, when
    backward blocks are configured, dkv) mask info.

  Raises:
    ValueError: If the mask is not 2D.
  """
  if len(mask.shape) != 2:
    raise ValueError(f"Unexpected mask shape: {mask.shape}")

  if isinstance(mask, np.ndarray):
    mask = mask_lib.NumpyMask(mask)

  if config is None:
    config = SplashConfig.get_default()

  process_fn = partial(
      mask_info_lib.process_mask,
      downcast_smem_data=downcast_smem_data,
      partial_mask_blocks_dtype=partial_mask_blocks_dtype,
      q_seq_shards=q_seq_shards,
  )

  fwd_mask_info, mask_function_fwd = process_fn(
      mask,
      (config.block_q, config.block_kv),
  )
  # Fraction of non-empty mask blocks; feeds the kernel's cost estimate.
  fwd_mask_sparsity = float(np.mean(fwd_mask_info.block_mask != 0))
  fwd_mask_info = tree_util.tree_map(jnp.array, fwd_mask_info)

  dkv_mask_info = None
  if config.has_backward_blocks:
    bq_dkv, bkv_dkv = config.block_q_dkv, config.block_kv_dkv
    dkv_mask_info, mask_function_dkv = process_fn(
        mask,
        (bq_dkv, bkv_dkv),
        is_dkv=True,
        # The dynamic grid is only used with 3-step dq reduction.
        return_dynamic_grid=config.dq_reduction_steps == 3,
    )

    assert (mask_function_fwd is None) == (mask_function_dkv is None)

    dkv_mask_sparsity = float(np.mean(dkv_mask_info.block_mask != 0))
    dkv_mask_info = tree_util.tree_map(jnp.array, dkv_mask_info)
  else:
    dkv_mask_sparsity = 1.0

  return SplashAttentionKernel(
      fwd_mask_info,
      dkv_mask_info,
      config=config,
      is_mqa=is_mqa,
      save_residuals=save_residuals,
      mask_value=mask_value,
      mask_function=mask_function_fwd,
      fwd_mask_sparsity=fwd_mask_sparsity,
      dkv_mask_sparsity=dkv_mask_sparsity,
  )
jax.sharding.Mesh | None = None, + mask_spec: jax.sharding.PartitionSpec | None = None, + config: SplashConfig | None = None, + is_mqa: bool, + save_residuals: bool = False, + mask_value: float = base.DEFAULT_MASK_VALUE, + downcast_smem_data: bool = True, + partial_mask_blocks_dtype: jax.typing.DTypeLike = np.int8, +): + if (mesh is not None) != (mask_spec is not None): + raise ValueError( + "Either both or neither of mesh and mask_spec must be specified." + ) + + if mask_spec is not None and len(mask_spec) != 1: + raise ValueError("Only shard over the query sequence dimension.") + + if len(mask.shape) != 2: + raise ValueError(f"Unexpected mask shape: {mask.shape}") + + if config is None: + config = SplashConfig.get_default() + + # This is the only mode that supports the dynamic grid. + config = dataclasses.replace(config, dq_reduction_steps=3) + + def process_mask_shard(mask): + process_mask_fn = functools.partial( + mask_info_lib._process_dynamic_mask, + downcast_smem_data=downcast_smem_data, + partial_mask_blocks_dtype=partial_mask_blocks_dtype, + ) + + fwd_mask_info = process_mask_fn( + mask, (config.block_q, config.block_kv), is_dkv=False + ) + + dkv_mask_info = None + if config.has_backward_blocks: + dkv_mask_info = process_mask_fn( + mask, (config.block_q_dkv, config.block_kv_dkv), is_dkv=True + ) + + return fwd_mask_info, dkv_mask_info + + kwargs = dict( + config=config, + is_mqa=is_mqa, + save_residuals=save_residuals, + mask_value=mask_value, + mask_function=None, + fwd_mask_sparsity=1.0, + dkv_mask_sparsity=1.0, + ) + + # If the input mask is replicated we don't need to call shard_map. 
+ if mask_spec is None: + fwd_mask_info, dkv_mask_info = process_mask_shard(mask) + kernel = SplashAttentionKernel(fwd_mask_info, dkv_mask_info, **kwargs) + return kernel + + mask_info_specs = MaskInfo( # pytype: disable=wrong-arg-types + mask_next=mask_spec, + active_rows=None, + active_cols=None, + num_active_blocks=None, + block_mask=mask_spec, + partial_mask_blocks=mask_spec, + q_sequence=None, + ) + out_specs = ( + mask_info_specs, + mask_info_specs if config.has_backward_blocks else None, + ) + + @partial( + jax.shard_map, + mesh=mesh, + in_specs=mask_spec, + out_specs=out_specs, + check_vma=False, + ) + def process_all_shards(mask): + return process_mask_shard(mask) + + fwd_mask_info, dkv_mask_info = process_all_shards(mask) + kernel = SplashAttentionKernel(fwd_mask_info, dkv_mask_info, **kwargs) + kernel_spec = SplashAttentionKernel(*out_specs, **kwargs) + + return (kernel, kernel_spec) + + +make_splash_mha = partial(_make_splash_attention, is_mqa=False) +make_splash_mqa = partial(_make_splash_attention, is_mqa=True) + +make_splash_mha_single_device = partial(make_splash_mha, q_seq_shards=1) + +make_splash_mqa_single_device = partial(make_splash_mqa, q_seq_shards=1) + +make_dynamic_splash_mqa = partial(_make_dynamic_splash_attention, is_mqa=True) +make_dynamic_splash_mha = partial(_make_dynamic_splash_attention, is_mqa=False) diff --git a/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_sharded_test.py b/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_sharded_test.py new file mode 100644 index 00000000..3bd01fc4 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_sharded_test.py @@ -0,0 +1,251 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class PallasBaseTest(test_utils.SplashAttentionTestCase):
  """Base class gating Pallas splash-attention tests on hardware availability."""

  # NOTE(review): presumably consumed by the shared test base / kernel config
  # as the Pallas interpret-mode flag — not referenced directly in this file;
  # confirm against test_utils.SplashAttentionTestCase.
  INTERPRET = False

  def setUp(self):
    """Skips the test unless running on a TPU host with at least 4 devices."""
    super().setUp()
    if not test_utils.test_device_matches(["tpu"]):
      self.skipTest("Test requires TPU.")

    if len(jax.devices()) < 4:
      self.skipTest("This test requires at least 4 devices.")
+ if is_dynamic_mask: + self.skipTest("Dynamic masks not supported.") + + k1, k2, k3 = random.split(random.key(0), 3) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + if len(jax.devices()) < num_devices: + self.skipTest( + f"This test requires {num_devices} devices, but has only" + f" {len(jax.devices())} devices available." + ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = mask_lib.make_causal_mask((seq_len, seq_len)) + if is_dynamic_mask: + mask = jnp.array(mask) + + if is_segmented: + segment_ids = test_utils.create_segment_ids(seq_len) + segment_ids_spec = base.SegmentIds( + q=PartitionSpec("q_seq" if q_seq_shards > 1 else None), + kv=PartitionSpec(None), + ) + else: + segment_ids = segment_ids_spec = None + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + mask_spec = PartitionSpec("q_seq" if q_seq_shards > 1 else None) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + + if is_dynamic_mask: + kernel, kernel_spec = splash.make_dynamic_splash_mha( + mask, mesh=mesh, mask_spec=mask_spec + ) + else: + kernel = splash.make_splash_mha(mask, q_seq_shards=q_seq_shards) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, mask_spec) + ) + + @partial( + jax.shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + segment_ids_spec, + ), + out_specs=q_spec, + check_vma=False, + ) + def f(kernel, q, 
k, v, segment_ids): + return kernel(q, k, v, segment_ids) + + out = f(kernel, q, k, v, segment_ids) + out_ref = base.attention_reference(q, k, v, mask, segment_ids, is_mqa=False) + self._assert_allclose(out, out_ref, rtol=5e-3, atol=3e-3) + + @parameterized.product( + topology=[(2, 2), (1, 4), (4, 1)], + num_heads=[2, 4], + dtype=[jnp.bfloat16], + is_segmented=[False, True], + is_dynamic_mask=[False, True], + ) + def test_manual_partitioning_mha_bwd( + self, topology, num_heads, dtype, is_segmented, is_dynamic_mask + ): + # TODO: Re-enable once dynamic masks are fixed. + if is_dynamic_mask: + self.skipTest("Dynamic masks not supported.") + + assert num_heads % 2 == 0 + k1, k2, k3, k4 = random.split(random.key(0), 4) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = mask_lib.make_causal_mask((seq_len, seq_len)) + if is_dynamic_mask: + mask = jnp.array(mask) + + if is_segmented: + segment_ids = test_utils.create_segment_ids(seq_len) + segment_ids_spec = base.SegmentIds( + q=PartitionSpec("q_seq" if q_seq_shards > 1 else None), + kv=PartitionSpec(None), + ) + else: + segment_ids = segment_ids_spec = None + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + mask_spec = PartitionSpec("q_seq" if q_seq_shards > 1 else None) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + + if is_dynamic_mask: + kernel, 
kernel_spec = splash.make_dynamic_splash_mha( + mask, mesh=mesh, mask_spec=mask_spec + ) + else: + kernel = splash.make_splash_mha(mask, q_seq_shards=q_seq_shards) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, mask_spec) + ) + + @partial( + jax.shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + segment_ids_spec, + ), + out_specs=q_spec, + check_vma=False, + ) + def f(kernel, q, k, v, segment_ids): + return kernel(q, k, v, segment_ids) + + f_ref = partial(base.attention_reference, is_mqa=False) + + out, out_vjp = jax.vjp(f, kernel, q, k, v, segment_ids) + out_ref, out_vjp_ref = jax.vjp(f_ref, q, k, v, mask, segment_ids) + self._assert_allclose(out, out_ref, rtol=5e-3, atol=5e-3) + + do = random.uniform(k4, out.shape, dtype=out.dtype) + _, dq, dk, dv, _ = out_vjp(do) + dq_ref, dk_ref, dv_ref, _, _ = out_vjp_ref(do.astype(jnp.float32)) + + self._assert_allclose(dq, dq_ref, atol=8e-2, rtol=1e-2) + self._assert_allclose(dk, dk_ref, atol=8e-2, rtol=2e-2) + self._assert_allclose(dv, dv_ref, atol=8e-2, rtol=1e-2) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_test.py b/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_test.py new file mode 100644 index 00000000..ed033a80 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/splash_attention_kernel_test.py @@ -0,0 +1,636 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable +import dataclasses +import functools +from typing import Any, TypeVar + +from absl.testing import absltest +from absl.testing import parameterized +import hypothesis as hp +import hypothesis.strategies as hps +import jax +from jax import random +import jax.numpy as jnp +import numpy as np +from . import base +from . import splash_attention_kernel as splash +from . import splash_attention_mask as mask_lib +from . import splash_attention_test_utils as test_utils + + +jax.config.parse_flags_with_absl() + + +hp.settings.register_profile( + name="deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=15, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile(name="deterministic") + +partial = functools.partial +Draw = TypeVar("Draw", bound=Callable[[hps.SearchStrategy[Any]], Any]) + + +@dataclasses.dataclass +class ModelConfig: + q_seq_len: int + kv_seq_len: int + num_q_heads: int + num_kv_heads: int + head_dim_qk: int + head_dim_v: int + dtype: np.dtype + + +@hps.composite +def segment_ids_strategy(draw, seq_len: int) -> base.SegmentIds: + boundaries = hps.sets(hps.integers(1, seq_len - 1), min_size=1, max_size=4) + bounds = sorted(draw(boundaries)) + ids_array = np.empty((seq_len,), dtype=np.int32) + for i, (start, end) in enumerate(zip((0, *bounds), (*bounds, seq_len))): + # Not sure why, but short segments can trip things up + if end - start < 2: + end = start + 2 + ids_array[start:end] = i + return base.SegmentIds(ids_array, ids_array) + + +def seed_strategy() -> hps.SearchStrategy[int]: + return hps.integers(min_value=0, max_value=4) + + +class Mask: + + def get_mask(self) -> mask_lib.Mask: + raise NotImplementedError() + + +def full_mask_strategy( + q_seq_len: int, kv_seq_len: int +) 
-> hps.SearchStrategy[Mask]: + return hps.just(FullMask(q_seq_len, kv_seq_len)) + + +@dataclasses.dataclass +class SplitMask(Mask): + q_seq_len: int + kv_seq_len: int + + def get_mask(self) -> mask_lib.Mask: + mask = np.ones((self.q_seq_len, self.kv_seq_len)).astype(np.bool_) + mask[:, mask.shape[1] // 2 :] = False + return mask_lib.NumpyMask(mask) + + +def split_mask_strategy( + q_seq_len: int, kv_seq_len: int +) -> hps.SearchStrategy[Mask]: + return hps.just(SplitMask(q_seq_len, kv_seq_len)) + + +@dataclasses.dataclass +class FullMask(Mask): + q_seq_len: int + kv_seq_len: int + + def get_mask(self) -> mask_lib.Mask: + return mask_lib.FullMask((self.q_seq_len, self.kv_seq_len)) + + +def causal_mask_strategy( + q_seq_len: int, kv_seq_len: int +) -> hps.SearchStrategy[Mask]: + return hps.just(CausalMask(q_seq_len, kv_seq_len)) + + +@dataclasses.dataclass +class CausalMask(Mask): + q_seq_len: int + kv_seq_len: int + + def get_mask(self) -> mask_lib.Mask: + return mask_lib.CausalMask((self.q_seq_len, self.kv_seq_len)) + + +@dataclasses.dataclass +class LocalAttentionMask(Mask): + seq_len: int + left: int | None + right: int | None + offset: int + + def get_mask(self) -> mask_lib.Mask: + mask = mask_lib.LocalMask( + (self.seq_len, self.seq_len), + (self.left, self.right), + offset=self.offset, + ) + # Make sure that no row is full of zeros as this is leads to undefined + # softmax. 
+ diagonal = mask_lib.NumpyMask(np.identity(self.seq_len, dtype=np.bool_)) + return mask | diagonal + + +@hps.composite +def local_attention_mask_strategy(draw: Draw, seq_len: int) -> Mask: + left_window = draw( + hps.one_of(hps.none(), hps.integers(min_value=0, max_value=seq_len)) + ) + right_window = draw( + hps.one_of(hps.none(), hps.integers(min_value=0, max_value=seq_len)) + ) + offset = draw(hps.integers(min_value=-seq_len, max_value=seq_len - 1)) + return LocalAttentionMask(seq_len, left_window, right_window, offset=offset) + + +@dataclasses.dataclass +class RandomMask(Mask): + q_seq_len: int + kv_seq_len: int + sparsity: float + seed: int + + def get_mask(self) -> mask_lib.Mask: + mask = mask_lib.make_random_mask( + (self.q_seq_len, self.kv_seq_len), self.sparsity, self.seed + ) + # Make sure that no row is full of zeros as this is leads to undefined + # softmax. + mask[:, 0] = True + + return mask_lib.NumpyMask(mask) + + +@hps.composite +def random_mask_strategy(draw: Draw, q_seq_len: int, kv_seq_len: int) -> Mask: + rand = draw(hps.randoms()) + seed = rand.randint(0, 2**32 - 1) + sparsity = rand.uniform(0.01, 0.5) + return RandomMask(q_seq_len, kv_seq_len, sparsity, seed) + + +@dataclasses.dataclass +class ComposeMask(Mask): + left: Mask + right: Mask + op: Callable[[mask_lib.Mask, mask_lib.Mask], mask_lib.Mask] + + def get_mask(self) -> mask_lib.Mask: + return self.op(self.left.get_mask(), self.right.get_mask()) + + +@hps.composite +def compose_mask_strategy(draw: Draw, q_seq_len: int, kv_seq_len: int) -> Mask: + mask1 = draw(mask_strategy(q_seq_len, kv_seq_len)) + mask2 = draw(mask_strategy(q_seq_len, kv_seq_len)) + op = draw( + hps.one_of(hps.just(mask_lib.LogicalOr), hps.just(mask_lib.LogicalAnd)) + ) + return ComposeMask(mask1, mask2, op) + + +@hps.composite +def mask_strategy(draw: Draw, q_seq_len: int, kv_seq_len: int) -> Mask: + oneof = [ + causal_mask_strategy(q_seq_len, kv_seq_len), + full_mask_strategy(q_seq_len, kv_seq_len), + 
def check_mask_no_empty_rows(
    mask: mask_lib.Mask, segment_ids: splash.SegmentIds | None
):
  """Hypothesis-assume that every query row can attend to at least one key.

  A row with no attendable keys leads to an undefined softmax, so such
  drawn examples are rejected rather than tested.
  """
  dense = np.array(mask[:, :])

  if segment_ids is not None:
    # Attention is additionally restricted to same-segment (q, kv) pairs.
    same_segment = segment_ids.q[:, None] == segment_ids.kv[None, :]
    dense = dense & same_segment

  row_has_key = np.any(dense, axis=1)
  hp.assume(np.all(row_has_key))
all_block_shapes if bs <= kv_seq_len] + bq, bkv = ( + draw(hps.sampled_from(q_valid_block_shapes)), + draw(hps.sampled_from(kv_valid_block_shapes)), + ) + bkv_compute = draw( + hps.sampled_from([None, *[b for b in kv_valid_block_shapes if b <= bkv]]) + ) + if not include_bwd_blocks: + return splash.SplashConfig( + block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute, **layouts + ) + all_block_shapes = [128, 256] + q_valid_block_shapes = [bs for bs in all_block_shapes if bs <= q_seq_len] + kv_valid_block_shapes = [bs for bs in all_block_shapes if bs <= kv_seq_len] + bq_dkv, bkv_dkv = ( + draw(hps.sampled_from(q_valid_block_shapes)), + draw(hps.sampled_from(kv_valid_block_shapes)), + ) + block_kv_dkv_compute = draw( + hps.sampled_from( + [None, *[b for b in kv_valid_block_shapes if b <= bkv_dkv]] + ) + ) + return splash.SplashConfig( + block_q=bq, + block_kv=bkv, + block_kv_compute=bkv_compute, + block_q_dkv=bq_dkv, + block_kv_dkv=bkv_dkv, + block_kv_dkv_compute=block_kv_dkv_compute, + **layouts, + ) + + +def _generate_inputs( + data, + config: ModelConfig, + is_mqa: bool, + is_segmented: bool, + use_sinks: bool = False, +) -> tuple[ + jax.Array, + jax.Array, + jax.Array, + jax.Array | None, + splash.SegmentIds | None, + jax.Array, +]: + seed = data.draw(seed_strategy()) + key = random.key(seed) + k1, k2, k3, k_sinks, k_do = random.split(key, 5) + + q_shape = (config.num_q_heads, config.q_seq_len, config.head_dim_qk) + if is_mqa: + k_shape = (config.kv_seq_len, config.head_dim_qk) + v_shape = (config.kv_seq_len, config.head_dim_v) + else: + k_shape = (config.num_kv_heads, config.kv_seq_len, config.head_dim_qk) + v_shape = (config.num_kv_heads, config.kv_seq_len, config.head_dim_v) + + q = random.uniform(k1, q_shape, dtype=config.dtype) + k = random.uniform(k2, k_shape, dtype=config.dtype) + v = random.uniform(k3, v_shape, dtype=config.dtype) + + sinks = None + if use_sinks: + sinks = random.uniform(k_sinks, (config.num_q_heads,), dtype=config.dtype) + + segment_ids 
= None + if is_segmented: + hp.assume(config.q_seq_len == config.kv_seq_len) + segment_ids = data.draw(segment_ids_strategy(config.q_seq_len)) + + o_shape = (config.num_q_heads, config.q_seq_len, config.head_dim_v) + do = random.uniform(k_do, o_shape, dtype=config.dtype) + return (q, k, v, sinks, segment_ids, do) + + +def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: + return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0)) + + +@test_utils.thread_unsafe_test_class() # hypothesis is not thread safe +class SplashAttentionTest(test_utils.SplashAttentionTestCase): + + def setUp(self): + if jax.default_backend() != "tpu": + self.skipTest("Only supported on TPUs.") + super().setUp() + + @parameterized.product( + is_mqa=(False, True), + is_segmented=(False, True), + is_dynamic_mask=(False, True), + ) + @hp.given(hps.data()) + def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data): + model_config = data.draw(model_config_strategy()) + q_seq_len, kv_seq_len = model_config.q_seq_len, model_config.kv_seq_len + q, k, v, _, segment_ids, _ = _generate_inputs( + data, model_config, is_mqa, is_segmented + ) + attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) + mask = data.draw(mask_strategy(q_seq_len, kv_seq_len)).get_mask() + check_mask_no_empty_rows(mask, segment_ids) + if is_dynamic_mask: + mask = jnp.array(mask[:, :]) + config = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) + config = dataclasses.replace( + config, + attn_logits_soft_cap=attn_logits_soft_cap, + interpret=self.INTERPRET, + ) + + attn_ref = partial(base.attention_reference, is_mqa=is_mqa) + if is_mqa: + if not is_dynamic_mask: + make_mask_fn = splash.make_splash_mqa_single_device + else: + make_mask_fn = splash.make_dynamic_splash_mqa + else: + if not is_dynamic_mask: + make_mask_fn = splash.make_splash_mha_single_device + else: + make_mask_fn = splash.make_dynamic_splash_mha + + attn = make_mask_fn(mask, config=config) + 
+ o = attn(q, k, v, segment_ids) + o_ref = attn_ref( + q.astype(np.float32), + k.astype(np.float32), + v.astype(np.float32), + jnp.array(mask[:, :]), + segment_ids, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + self._assert_allclose(o, o_ref, atol=6e-3, rtol=3e-3) + + @parameterized.product( + is_mqa=(False, True), + is_segmented=(False, True), + is_dynamic_mask=(False, True), + use_base2_exp=(False, True), + use_max_logit_estimate=(None, "const", "value_1d", "value_2d"), + fuse_reciprocal=(True, False), + use_sinks=(False, True), + ) + @hp.given(hps.data()) + def test_splash_attention_fwd(self, is_mqa, is_segmented, is_dynamic_mask, + use_base2_exp, use_max_logit_estimate, + fuse_reciprocal, use_sinks, data): + model_config = data.draw(model_config_strategy()) + q_seq_len, kv_seq_len = model_config.q_seq_len, model_config.kv_seq_len + q, k, v, sinks, segment_ids, _ = _generate_inputs( + data, model_config, is_mqa, is_segmented, use_sinks + ) + attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) + mask = data.draw(mask_strategy(q_seq_len, kv_seq_len)).get_mask() + check_mask_no_empty_rows(mask, segment_ids) + if is_dynamic_mask: + mask = jnp.array(mask[:, :]) + config = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) + if is_mqa: + if not is_dynamic_mask: + make_mask_fn = splash.make_splash_mqa_single_device + else: + make_mask_fn = splash.make_dynamic_splash_mqa + else: + if not is_dynamic_mask: + make_mask_fn = splash.make_splash_mha_single_device + else: + make_mask_fn = splash.make_dynamic_splash_mha + + config = dataclasses.replace( + config, + fuse_reciprocal=fuse_reciprocal, + attn_logits_soft_cap=attn_logits_soft_cap, + use_base2_exp=use_base2_exp, + interpret=self.INTERPRET, + ) + + max_logit_value, max_val = None, 30.0 + if use_max_logit_estimate == "const": + config = dataclasses.replace(config, max_logit_const=max_val) + elif use_max_logit_estimate == "value_1d": + max_logit_value = max_val * jnp.ones((1,), dtype=jnp.bfloat16) 
+ elif use_max_logit_estimate == "value_2d": + max_logit_value = max_val * jnp.ones( + (model_config.num_q_heads,), dtype=jnp.bfloat16 + ) + attn = make_mask_fn(mask, config=config, save_residuals=True) + attn_ref = partial( + base.attention_reference, + is_mqa=is_mqa, + save_residuals=True, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + + o, stats = attn( + q, k, v, segment_ids, sinks, max_logit_value=max_logit_value + ) + + o_ref, stats_ref = attn_ref( + q.astype(jnp.float32), + k.astype(jnp.float32), + v.astype(jnp.float32), + jnp.array(mask[:, :]), + segment_ids, + sinks, + ) + + lse_tol = dict(atol=1e-3, rtol=3e-3) + max_logits_tol = dict(atol=1e-3, rtol=4e-3) + if use_sinks: + o_tol = dict(atol=8e-2, rtol=1e-1) + lse_tol['rtol'] = 6e-2 + elif (use_base2_exp or use_max_logit_estimate is not None + or not fuse_reciprocal): + o_tol = dict(atol=8e-3, rtol=3e-3) + else: + o_tol = dict(atol=4e-3, rtol=3e-3) + + self._assert_allclose(o, o_ref, **o_tol) + self._assert_allclose(stats["logsumexp"], + stats_ref["logsumexp"], **lse_tol) + if use_max_logit_estimate is None: + self._assert_allclose(stats["max_logits"], + stats_ref["max_logits"], **max_logits_tol) + + @parameterized.product( + is_mqa=(False, True), + is_segmented=(False, True), + is_dynamic_mask=(False, True), + # use_max_logit_estimate=(None, "const", "value_1d", "value_2d"), + use_max_logit_estimate=(None,), + use_sinks=(False, True), + dq_reduction_steps=(None, 3), + ) + @hp.given(hps.data()) + def test_splash_attention_bwd( + self, + is_mqa, + is_segmented, + is_dynamic_mask, + use_max_logit_estimate, + dq_reduction_steps, + use_sinks, + data, + ): + downcast_smem_data = data.draw(hp.strategies.booleans()) + fuse_reciprocal = data.draw(hp.strategies.booleans()) + use_base2_exp = data.draw(hp.strategies.booleans()) + + model_config = data.draw(model_config_strategy()) + q_seq_len, kv_seq_len = model_config.q_seq_len, model_config.kv_seq_len + q, k, v, sinks, segment_ids, do = _generate_inputs( + 
data, model_config, is_mqa, is_segmented, use_sinks=use_sinks + ) + attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy()) + mask = data.draw(mask_strategy(q_seq_len, kv_seq_len)).get_mask() + check_mask_no_empty_rows(mask, segment_ids) + if is_dynamic_mask: + mask = jnp.array(mask[:, :]) + config = data.draw( + block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True) + ) + + config = dataclasses.replace( + config, + fuse_reciprocal=fuse_reciprocal, + attn_logits_soft_cap=attn_logits_soft_cap, + interpret=self.INTERPRET, + use_base2_exp=use_base2_exp, + dq_reduction_steps=dq_reduction_steps, + ) + if is_mqa: + if not is_dynamic_mask: + make_mask_fn = splash.make_splash_mqa_single_device + else: + make_mask_fn = splash.make_dynamic_splash_mqa + else: + if not is_dynamic_mask: + make_mask_fn = splash.make_splash_mha_single_device + else: + make_mask_fn = splash.make_dynamic_splash_mha + + max_logit_value, max_val = None, 30.0 + if use_max_logit_estimate == "const": + config = dataclasses.replace(config, max_logit_const=max_val) + elif use_max_logit_estimate == "value_1d": + max_logit_value = max_val * jnp.ones((1,), dtype=jnp.bfloat16) + elif use_max_logit_estimate == "value_2d": + max_logit_value = max_val * jnp.ones( + (model_config.num_q_heads,), dtype=jnp.bfloat16 + ) + + attn = make_mask_fn( + mask, config=config, downcast_smem_data=downcast_smem_data + ) + + o, attn_vjp = jax.vjp(partial(attn, max_logit_value=max_logit_value), + q, k, v, segment_ids, sinks) + q32, k32, v32 = jax.tree.map(lambda x: x.astype(jnp.float32), (q, k, v)) + o_ref, stats_ref = base.attention_reference( + q32, + k32, + v32, + jnp.array(mask[:, :]), + segment_ids, + sinks, + is_mqa=is_mqa, + save_residuals=True, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + if use_sinks: + o_tol = dict(atol=1e-2, rtol=1e-1) + elif (use_base2_exp or use_max_logit_estimate is not None + or not fuse_reciprocal): + o_tol = dict(atol=8e-3, rtol=1e-2) + else: + o_tol = 
dict(atol=4e-3, rtol=3e-3) + self._assert_allclose(o, o_ref, **o_tol) + + dq, dk, dv, _, dsinks = attn_vjp(do) + dq_ref, dk_ref, dv_ref, dsinks_ref = base.attention_reference_vjp( + do.astype(jnp.float32), + q32, + k32, + v32, + jnp.array(mask[:, :]), + segment_ids, + sinks, + o.astype(jnp.float32), + stats_ref["logsumexp"], + is_mqa=is_mqa, + backward_impl="flash", + attn_logits_soft_cap=attn_logits_soft_cap, + ) + + dq_atol = 8e-2 if use_base2_exp else 2e-2 + dk_atol = 7e-2 if use_base2_exp else 2e-2 + dv_atol = 2e-2 if use_base2_exp else 2e-2 + self._assert_allclose(dq, dq_ref, atol=dq_atol, rtol=3e-2) + self._assert_allclose(dk, dk_ref, atol=dk_atol, rtol=3e-2) + self._assert_allclose(dv, dv_ref, atol=dv_atol, rtol=3e-2) + if use_sinks: + self._assert_allclose(dsinks, dsinks_ref, atol=4e-3, rtol=6e-3) + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/kernels/splash_attention/splash_attention_mask.py b/src/maxdiffusion/kernels/splash_attention/splash_attention_mask.py new file mode 100644 index 00000000..ce176af7 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/splash_attention_mask.py @@ -0,0 +1,513 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Mini-mask creation library.""" + +from collections.abc import Callable +import dataclasses +from typing import Any, Self + +import numpy as np + +# mypy: ignore-errors + + +class Mask: + """A base class for splash attention masks.""" + + @property + def shape(self) -> tuple[int, ...]: + raise NotImplementedError + + def __getitem__(self, idx) -> np.ndarray: + raise NotImplementedError + + def __bool__(self) -> bool: + raise NotImplementedError( + 'Conversion to bool is unsupported. Could be caused by using logical' + ' instead of bitwise operations on masks.' + ) + + def __or__(self, other: Self) -> Self: + if self.shape != other.shape: + raise ValueError( + f'Invalid shape for other: {other.shape}, expected: {self.shape}' + ) + return LogicalOr(self, other) + + def __and__(self, other: Self) -> Self: + if self.shape != other.shape: + raise ValueError( + f'Invalid shape for other: {other.shape}, expected: {self.shape}' + ) + return LogicalAnd(self, other) + + +def make_causal_mask(shape: tuple[int, int], offset: int = 0) -> np.ndarray: + """Makes a causal attention mask. + + Args: + shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). + offset: Offset of q start wrt kv. A positive offset shifts the bottom + triangle upward, a negative one shifts it downward. A negative offset + makes the first 'offset' rows of the attention matrix all 0s which leads + to undefined softmax. + + Returns: + The causal mask. 
+ """ + q_seq_len, kv_seq_len = shape + q_idx = np.arange(q_seq_len, dtype=np.int32) + kv_idx = np.arange(kv_seq_len, dtype=np.int32) + return (q_idx[:, None] + offset >= kv_idx[None, :]).astype(np.bool_) + + +def make_local_attention_mask( + shape: tuple[int, int], + window_size: tuple[int | None, int | None], + *, + offset: int = 0, +) -> np.ndarray: + """Makes a local attention mask.""" + q_seq_len, kv_seq_len = shape + q_idx = np.arange(q_seq_len, dtype=np.int32) + kv_idx = np.arange(kv_seq_len, dtype=np.int32) + mask = np.ones((q_seq_len, kv_seq_len), dtype=np.bool_) + left, right = window_size + if left is not None: + mask = mask & (q_idx[:, None] - left + offset <= kv_idx[None, :]) + if right is not None: + mask = mask & (q_idx[:, None] + right + offset >= kv_idx[None, :]) + return mask.astype(np.bool_) + + +def make_chunk_attention_mask( + shape: tuple[int, int], chunk_size: int +) -> np.ndarray: + """Makes a chunked causal attention mask. + + Args: + shape: The desired shape of the mask (q_seq_len, kv_seq_len). + chunk_size: The size of the attention chunks. + + Returns: + A boolean mask of shape `mask_shape` where True indicates attention is + allowed according to chunked causal rules, and False otherwise. + + Raises: + ValueError: If chunk_window_size is None or not positive. 
+ """ + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + + q_seq_len, kv_seq_len = shape + q_idx = np.arange(q_seq_len, dtype=np.int32) + kv_idx = np.arange(kv_seq_len, dtype=np.int32) + + # chunk mask calculation + same_chunk = (q_idx[:, None] // chunk_size) == (kv_idx[None, :] // chunk_size) + mask = same_chunk & (q_idx[:, None] >= kv_idx[None, :]) + return mask + + +def make_random_mask( + shape: tuple[int, int], sparsity: float, seed: int +) -> np.ndarray: + """Makes a random attention mask.""" + np.random.seed(seed) + return np.random.binomial(n=1, p=1.0 - sparsity, size=shape).astype(np.bool_) + + +@dataclasses.dataclass(slots=True) +class LogicalOr(Mask): + left: Mask + right: Mask + + def __init__(self, left: Mask, right: Mask): + if left.shape != right.shape: + raise ValueError('Masks must have the same shape') + self.left = left + self.right = right + + @property + def shape(self) -> tuple[int, ...]: + return self.left.shape + + def __getitem__(self, idx) -> np.ndarray: + return self.left[idx] | self.right[idx] + + def __hash__(self): + return hash((type(self),) + (self.left, self.right)) + + +@dataclasses.dataclass(slots=True) +class LogicalAnd(Mask): + left: Mask + right: Mask + + def __init__(self, left: Mask, right: Mask): + if left.shape != right.shape: + raise ValueError('Masks must have the same shape') + self.left = left + self.right = right + + @property + def shape(self) -> tuple[int, ...]: + return self.left.shape + + def __getitem__(self, idx) -> np.ndarray: + return self.left[idx] & self.right[idx] + + def __hash__(self): + return hash((type(self),) + (self.left, self.right)) + + +class _ComputableMask(Mask): + """Superclass for all masks that can be computed inside the kernel using a callable object. + + This subclass is designed to be used with Splash Attention. 
+ It allows the mask logic to be computed on-the-fly or fused into the attention + kernel, avoiding the memory cost of materializing the full + (sequence_length, sequence_length) boolean mask array, which can be excessive + for long sequences. + + Attributes: + _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). + offset: Offset of q start wrt kv. A positive offset shifts the bottom + triangle upward, a negative one shifts it downward. A negative offset + makes the first 'offset' rows of the attention matrix all 0s which leads + to undefined softmax. + q_sequence: Indices of Q sequence. q_sequence is reused across __getitem__ + calls which is important for compile-time performance. + mask_function: Function used by the SplashAttention kernel to compute the + mask rather than loading it. + """ + + _shape: tuple[int, int] + q_sequence: np.ndarray + mask_function: Callable[..., Any] + + def __init__( + self, + shape: tuple[int, int], + mask_function: Callable[..., Any], + shard_count: int = 1, + ): + self._shape = shape + self.mask_function = mask_function + q_seq_len = self.shape[0] + + if q_seq_len % (shard_count * shard_count) != 0: + raise ValueError( + f'Shard count squared ({shard_count * shard_count}) must' + f' divide Q seq_len ({self.shape[0]}) evenly.' 
+ ) + + self.q_sequence = np.arange(q_seq_len, dtype=np.int32) + + @property + def shape(self) -> tuple[int, ...]: + return self._shape + + def __getitem__(self, idx) -> np.ndarray: + if len(idx) != 2: + raise NotImplementedError(f'Unsupported slice: {idx}') + + q_slice, kv_slice = idx + if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice): + raise NotImplementedError(f'Unsupported slice: {idx}') + + q_slice = _fill_slice(q_slice, self.shape[0]) + kv_slice = _fill_slice(kv_slice, self.shape[1]) + + rows = self.q_sequence[q_slice] + cols = np.arange(kv_slice.start, kv_slice.stop) + + return self.mask_function(rows[:, None], cols[None, :]) + + def __eq__(self, other: object): + raise NotImplementedError() + + def __hash__(self): + raise NotImplementedError() + + +class CausalMask(_ComputableMask): + """Lazy causal mask, prevents the model from attending to future tokens. + + Attributes: + offset: Offset of q start wrt kv. A positive offset shifts the bottom + triangle upward, a negative one shifts it downward. A negative offset + makes the first 'offset' rows of the attention matrix all 0s which leads + to undefined softmax. + """ + + offset: int + + def __init__( + self, + shape: tuple[int, int], + offset: int = 0, + shard_count: int = 1, + ): + self.offset = offset + + def causal_mask_function(q_ids, kv_ids): + # When evaluating the mask in _process_mask we typically work with numpy + # array views. + # Avoid the addition when possible to avoid instantiating an actual array. 
+ if self.offset == 0: + return q_ids >= kv_ids + else: + return q_ids + self.offset >= kv_ids + + mask_function = causal_mask_function + + super().__init__( + shape=shape, + mask_function=mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.shape == other.shape + and self.offset == other.offset + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.offset, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) + + +class ChunkedCausalMask(_ComputableMask): + """Lazy chunked causal mask. + + Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ... tokens + attend to each other but not across chunks. + Llama4 models use interleaved chunk attention along with global attention. + + + Attributes: + chunk_size: The size of each attention chunk. + """ + + chunk_size: int + + def __init__( + self, + shape: tuple[int, int], + chunk_size: int, + shard_count: int = 1, + ): + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + self.chunk_size = chunk_size + + # Define the mask function for chunk attention + def chunked_causal_mask_function(q_ids, kv_ids): + """Computes the mask logic for the given slice indices.""" + # Condition 1: Same chunk + same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size) + + # Condition 2: Causal + causal = q_ids >= kv_ids + + return same_chunk & causal + + super().__init__( + shape=shape, + mask_function=chunked_causal_mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.shape == other.shape + and self.chunk_size == other.chunk_size + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.chunk_size, + 
self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) + + +class LocalMask(_ComputableMask): + """Lazy local mask, prevents model from attending to tokens outside window. + + Attributes: + window_size: Size of the two sides of the local window (None identifies no + limit for the given side). + offset: Offset of q start wrt kv. A positive offset shifts the bottom + triangle upward, a negative one shifts it downward. A negative offset + makes the first 'offset' rows of the attention matrix all 0s which leads + to undefined softmax. + """ + + window_size: tuple[int | None, int | None] + offset: int + + def __init__( + self, + shape: tuple[int, int], + window_size: tuple[int | None, int | None], + offset: int, + shard_count: int = 1, + ): + self.window_size = window_size + self.offset = offset + + def local_mask_function(q_ids, kv_ids): + """Computes the local attention mask for the given slice indices.""" + left_size, right_size = self.window_size + + assert q_ids.ndim == 2 + assert kv_ids.ndim == 2 + + if left_size is None and right_size is None: + return np.ones((q_ids.shape[0], kv_ids.shape[1]), dtype=np.bool_) + + # Avoid the addition when possible to avoid instantiating an actual array. 
+ if offset != 0: + shifted_q_ids = q_ids + self.offset + else: + shifted_q_ids = q_ids + + mask = None + if left_size is not None: + mask = shifted_q_ids - left_size <= kv_ids + if right_size is not None: + if mask is None: + mask = shifted_q_ids + right_size >= kv_ids + else: + mask &= shifted_q_ids + right_size >= kv_ids + return mask + + super().__init__( + shape=shape, + mask_function=local_mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return False + + return ( + self.shape == other.shape + and self.window_size == other.window_size + and self.offset == other.offset + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.window_size, + self.offset, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) + + +@dataclasses.dataclass(slots=True) +class NumpyMask(Mask): + """A mask backed by a dense numpy array.""" + + array: np.ndarray + + def __post_init__(self): + if self.array.ndim != 2: + raise ValueError('Expected a 2-dim array') + + if self.array.dtype != np.bool_: + raise ValueError('Mask must be a boolean array') + + @property + def shape(self) -> tuple[int, ...]: + return self.array.shape + + def __getitem__(self, idx) -> np.ndarray: + return self.array[idx] + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return np.array_equal(self.array, other.array, equal_nan=True) + + def __hash__(self): + return hash((type(self), self.array.tobytes())) + + +def _fill_slice(inp_slice: slice, size: int) -> slice: + assert inp_slice.step is None or inp_slice.step == 1 + start = 0 if inp_slice.start is None else inp_slice.start + stop = size if inp_slice.stop is None else inp_slice.stop + assert start >= 0 + assert stop <= size + return slice(start, stop, None) + + +@dataclasses.dataclass(frozen=True, slots=True) +class FullMask(Mask): + 
"""Lazy full mask, allows all tokens to attend to all other tokens.""" + + # TODO: Transform FullMask into a _ComputableMask. + + _shape: tuple[int, int] + + def __post_init__(self): + if not isinstance(self.shape, tuple): + raise ValueError(f'Unsupported shape type: {type(self.shape)}') + + @property + def shape(self) -> tuple[int, ...]: + return self._shape + + def __getitem__(self, idx) -> np.ndarray: + if len(idx) != 2: + raise NotImplementedError(f'Unsupported slice: {idx}') + i, j = idx + if not isinstance(i, slice) or not isinstance(j, slice): + raise NotImplementedError(f'Unsupported slice: {idx}') + i = _fill_slice(i, self.shape[0]) + j = _fill_slice(j, self.shape[1]) + return np.ones((i.stop - i.start, j.stop - j.start), dtype=np.bool_) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return self.shape == other.shape + + def __hash__(self): + return hash((type(self), self.shape)) diff --git a/src/maxdiffusion/kernels/splash_attention/splash_attention_mask_info.py b/src/maxdiffusion/kernels/splash_attention/splash_attention_mask_info.py new file mode 100644 index 00000000..a5d30b58 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/splash_attention_mask_info.py @@ -0,0 +1,577 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================

"""MaskInfo creation library.

Transforms dense splash-attention masks into the sparse `MaskInfo`
representation consumed by the splash attention kernels.
"""

import collections
import functools
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from . import splash_attention_mask as mask_lib

# mypy: ignore-errors

lax = jax.lax
MaskCallable = Any


def find_bounds(
    arr: jax.Array | np.ndarray,
) -> tuple[jax.Array | np.ndarray | None, jax.Array | np.ndarray | None]:
  """Finds the first and last block of each run of equal values along the last axis.

  Used to determine when to initialize/store the output: an entry of 1 in
  the first (resp. second) returned array marks the start (resp. end) of a
  run. Returns (None, None) when `arr` is None.
  """
  if arr is None:
    return None, None

  bounds_start = (arr != jnp.roll(arr, shift=1, axis=-1)).astype(jnp.int32)
  bounds_end = (arr != jnp.roll(arr, shift=-1, axis=-1)).astype(jnp.int32)
  # jnp.roll wraps around, so force the outermost entries to be run edges.
  bounds_start = bounds_start.at[0].set(1)
  bounds_end = bounds_end.at[-1].set(1)

  return bounds_start, bounds_end


# Logic for processing NumPy masks for kernels
class MaskInfo(NamedTuple):
  """Contains runtime masking information for the Splash attention kernel.

  The arrays, mask_next and block_mask are placed in TPU
  scalar-memory. This is a scarce resource so the mask creation logic attempts
  to shrink the data-type of these arrays to the smallest possible one.
  This can be: np.int32, np.int16 or np.int8.

  Attributes:
    mask_next: An integer[num_active_blocks] NumPy array where each entry
      contains the next mask block index in `partial_mask_blocks` to prefetch.
    active_rows: An integer[num_active_blocks] NumPy array where each entry
      contains the row index of the corresponding active block in the original
      mask.
    active_cols: An integer[num_active_blocks] NumPy array where each entry
      contains the column index of the corresponding active block in the
      original mask.
    block_mask: An integer[num_active_blocks] NumPy array where each entry is
      either 1 or 2. 1 means the corresponding block is full and 2 means the
      corresponding block is partially masked.
    num_active_blocks: An integer NumPy array whose entries are the number of
      active blocks in each (q_shard, kv_shard) slice of the original mask.
    partial_mask_blocks: An int8[num_partial_blocks, block_q, block_kv] NumPy
      array that contains the blocks of the original mask that contained both
      zeros and ones. The entries in `mask_next` point to indices in the first
      axis of this array.
    q_sequence: A i32[q_sequence_length] NumPy array. When using causal
      masking, this contains the list of indices that correspond to q tokens.
      For plain causal this is just np.arange(q_sequence_length).
  """

  mask_next: np.ndarray | jax.Array | None
  active_rows: np.ndarray | jax.Array | None
  active_cols: np.ndarray | jax.Array | None
  block_mask: np.ndarray | jax.Array | None
  num_active_blocks: np.ndarray | jax.Array | None
  partial_mask_blocks: np.ndarray | jax.Array | None
  q_sequence: np.ndarray | None


def _downcast_to_small_type(array: np.ndarray) -> np.ndarray:
  """Downcast numpy array.

  If possible, downcast the data-type of the input array to the smallest numpy
  type (among np.int16 and np.int8) that fits the content of the array.

  Args:
    array: the array to downcast

  Returns:
    The downcasted array.

  Raises:
    ValueError: if the input array is not np.int32 or if any of its elements
      is smaller than -1 (-1 is allowed as padding).
  """
  if array.dtype != np.int32:
    raise ValueError(f'Expected int32 input, but got {array.dtype}.')

  if not np.all(array >= -1):
    # Allow -1 for padding.
    raise ValueError('Expected array elements to be >= -1.')

  if array.size == 0:
    return array

  max_value = np.max(array)

  if max_value <= np.iinfo(np.int8).max:
    return array.astype(np.int8)
  elif max_value <= np.iinfo(np.int16).max:
    return array.astype(np.int16)
  else:
    return array.astype(np.int32)


def _check_mask(mask: mask_lib.Mask) -> None:
  """Check that the given mask is valid.

  A row of all zeros along the kv dimension would result in a division by zero
  when computing the softmax. This function is meant to protect against that
  case.

  Args:
    mask: the mask to check.

  Raises:
    ValueError: the mask is invalid.
  """

  assert len(mask.shape) == 2

  exception_message = (
      'Some rows of the mask (along the kv dimension) are all zeros.\nThis'
      ' would result in a division by zero when computing the attention'
      ' softmax.'
  )

  is_row_non_zero = np.zeros(mask.shape[0], dtype=np.bool_)
  for col in range(mask.shape[1]):
    # Mask only supports slice indices.
    is_row_non_zero = np.logical_or(
        is_row_non_zero,
        mask[(slice(0, mask.shape[0]), slice(col, col + 1))][:, 0],
    )
  if not is_row_non_zero.all():
    raise ValueError(exception_message)


class _HashableNDArray:
  """Helper to make a numpy array hashable: can be added associative containers.

  Attributes:
    array: The underlying numpy array.
  """

  __slots__ = ('array', '_hash')
  array: np.ndarray

  def __init__(self, array: np.ndarray):
    self.array = array
    # Cache the hash: the array is treated as immutable once wrapped.
    self._hash = hash(array.tobytes())

  def __hash__(self):
    return self._hash

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, _HashableNDArray):
      return NotImplemented
    return np.array_equal(self.array, other.array, equal_nan=True)


def _generate_shard_metadata(
    block_mask: np.ndarray,
    partial_blocks: np.ndarray,
    is_dkv: bool,
    return_dynamic_grid: bool,
):
  """Computes per-shard (active_rows, active_cols, mask_next, flat_block_mask, grid_size).

  `block_mask` holds 0 (empty), 1 (partial), 2 (full) per block;
  `partial_blocks` holds the partial-block id per block or -1.
  When `return_dynamic_grid` is False the full static grid is used and
  active_rows/active_cols/grid_size are None.
  """
  if is_dkv:
    block_mask = block_mask.mT
    partial_blocks = partial_blocks.mT

  if return_dynamic_grid:
    active_mask = block_mask > 0
    if is_dkv:
      # If an entire row is masked then that kv output tile won't be visited.
      # We extend the grid to visit these tiles to initialize them.
      active_mask[:, 0] |= ~active_mask.any(axis=1)
    active_indices = np.argwhere(active_mask)
    active_rows = active_indices[:, 0].astype(np.int32)
    active_cols = active_indices[:, 1].astype(np.int32)
    block_mask = block_mask[active_mask > 0]
    grid_size = active_rows.size
  else:
    active_indices = np.ndindex(block_mask.shape)
    active_rows = active_cols = grid_size = None

  partial_coords = np.argwhere(partial_blocks != -1)
  if partial_coords.size > 0:
    mask_next = []
    mask_coords_iter = iter([tuple(c) for c in partial_coords])
    first_m = coord_m = next(mask_coords_iter)

    for idx in active_indices:
      is_next_mask = tuple(idx) > tuple(coord_m)
      if is_next_mask:
        try:
          coord_m = next(mask_coords_iter)  # type: ignore
        except StopIteration:
          # Wrap around so every active block has a valid block to prefetch.
          coord_m = first_m
      mask_next.append(partial_blocks[coord_m])
  else:
    mask_next = np.full(block_mask.size, -1, dtype=np.int32)

  mask_next = np.array(mask_next, dtype=np.int32)
  flat_block_mask = block_mask.flatten()

  return active_rows, active_cols, mask_next, flat_block_mask, grid_size


def _process_dynamic_mask(
    mask: jax.Array,
    block_shape: tuple[int, int],
    is_dkv: bool,
    *,
    downcast_smem_data: bool = True,
    partial_mask_blocks_dtype: jax.typing.DTypeLike = np.int8,
) -> MaskInfo:
  """Process a dynamic mask to compute its local sparsity data.

  Note that this operates on a single shard of the mask.

  Args:
    mask: [q_seq_len, kv_seq_len] jax.Array representing a dense mask to
      process.
    block_shape: A Tuple[int, int] representing the shape of the Pallas grid
      block.
    is_dkv: True if we are processing the dKV mask
    downcast_smem_data: If True, downcast the scalar-memory data of MaskInfo to
      a data type smaller than np.int32 (if possible).
    partial_mask_blocks_dtype: dtype used to store the tiled mask blocks.

  Returns:
    `MaskInfo`, a sparse representation of the dense mask.

  Raises:
    ValueError: if the input mask is invalid or the block sizes are not
      compatible with the mask sizes.
  """
  if len(mask.shape) != 2:
    raise ValueError(f'Expected a 2-dim mask, instead got: {mask.shape}.')

  q_seq_len, kv_seq_len = mask.shape
  q_block_size, kv_block_size = block_shape
  q_blocks_count, q_mod = divmod(q_seq_len, q_block_size)
  kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_block_size)

  if q_mod != 0:
    raise ValueError(f'{q_block_size=} should divide {q_seq_len=}.')
  if kv_mod != 0:
    raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.')

  # Tile the last 2 dimensions of the mask into 2D tiles of size `block_shape`.
  mask_blocks = (
      mask.reshape(
          q_blocks_count,
          q_block_size,
          kv_blocks_count,
          kv_block_size,
      )
      .swapaxes(-2, -3)
      .astype(partial_mask_blocks_dtype)
  )

  # Per-block state: 0 = empty, 1 = partial, 2 = full.
  any_mask = jnp.any(mask_blocks, axis=(-1, -2)).astype(np.int32)
  all_mask = jnp.all(mask_blocks, axis=(-1, -2)).astype(np.int32)
  block_mask = any_mask + all_mask

  block_ids = jnp.arange(block_mask.size, dtype=np.int32).reshape(
      block_mask.shape
  )
  if is_dkv:
    block_mask = block_mask.swapaxes(-1, -2)
    block_ids = block_ids.swapaxes(-1, -2)
    mask_blocks = mask_blocks.swapaxes(-1, -2)

  active_mask = block_mask > 0
  if is_dkv:
    # If an entire row is masked then that kv output tile won't be visited.
    # We extend the grid to visit these tiles to initialize them.
    empty_rows = jnp.all(block_mask == 0, axis=-1)
    first_col = jnp.arange(block_mask.shape[1]) == 0
    active_mask |= (empty_rows[:, None] & first_col)

  num_active_blocks = active_mask.flatten().sum(keepdims=True)
  active_indices = jnp.argwhere(
      active_mask, size=active_mask.size, fill_value=-1
  )
  active_rows = active_indices[:, 0].astype(np.int32)
  active_cols = active_indices[:, 1].astype(np.int32)

  block_mask = block_mask[active_rows, active_cols]
  mask_next = block_ids.at[active_rows, active_cols].get(
      wrap_negative_indices=False
  )
  mask_next = jnp.where(block_mask == 1, mask_next, 0)

  # Mask out the blocks that aren't active.
  mask = (jnp.arange(block_mask.size) < num_active_blocks).astype(np.int32)
  block_mask = block_mask * mask

  # Collapsing because the block ids are linearized.
  mask_blocks = lax.collapse(mask_blocks, 0, 2)

  def _downcast(array: jax.Array, max_value: int) -> jax.Array:
    # Traced analogue of _downcast_to_small_type: the bound must be static.
    if array.size == 0:
      return array

    if array.dtype != np.int32:
      raise ValueError(f'Expected int32 input, but got {array.dtype}.')

    if max_value <= np.iinfo(np.int8).max:
      return array.astype(np.int8)
    elif max_value <= np.iinfo(np.int16).max:
      return array.astype(np.int16)
    else:
      return array.astype(np.int32)

  if downcast_smem_data:
    block_mask = block_mask.astype(np.int8)  # values are in the range [0, 1, 2]
    mask_next = _downcast(mask_next, q_blocks_count * kv_blocks_count)

  return MaskInfo(
      mask_next=mask_next,
      active_rows=active_rows,
      active_cols=active_cols,
      block_mask=block_mask,
      num_active_blocks=num_active_blocks,
      partial_mask_blocks=mask_blocks,
      q_sequence=None,
  )
+ is_dkv: True if we are processing the dKV mask + downcast_smem_data: If True, downcast the SMEM data of MaskInfo to a data + type smaller if possible. + q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is + launched. + + Returns: + `MaskInfo`, a sparse representation of the dense mask. + `MaskCallable`: a callable that, given Q and KV indices, returns + the value of the mask at those coordinates. + + Raises: + ValueError: if the input mask is invalid or the block sizes are not + compatible with the mask sizes. + """ + + if len(mask.shape) != 2: + raise ValueError(f'Expected a 2-dim mask, instead got: {mask.shape=}') + + q_seq_len, kv_seq_len = mask.shape + q_block_size, kv_block_size = block_shape + q_blocks_count, q_mod = divmod(q_seq_len, q_block_size) + kv_blocks_count, kv_mod = divmod(kv_seq_len, kv_block_size) + + if q_mod != 0: + raise ValueError(f'{q_block_size=} should divide {q_seq_len=}.') + if kv_mod != 0: + raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.') + + q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) + if mod != 0: + raise ValueError(f'{q_seq_shards=} should divide {q_seq_len=}.') + + q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) + if mod != 0: + raise ValueError(f'{q_block_size=} should divide {q_seq_len_per_shard=}.') + + kv_seq_len_per_shard, mod = divmod(kv_seq_len, kv_seq_shards) + if mod != 0: + raise ValueError(f'{kv_seq_shards=} should divide {kv_seq_len=}.') + + kv_blocks_per_shard, mod = divmod(kv_seq_len_per_shard, kv_block_size) + if mod != 0: + raise ValueError(f'{kv_block_size=} should divide {kv_seq_len_per_shard=}.') + + # TODO: checking the validity of the masks is slow for large masks. + # Disable it for now, reevaluate in the future. + + # The mask object either define q_sequence and mask_function or none of + # them. 
+ assert hasattr(mask, 'q_sequence') == hasattr(mask, 'mask_function') + + # If the mask object defines a q_sequence and a mask_function, then make use + # of these in the kernel rather. This is preferable over loading the mask + # from memory. When using a mask_function, then mask_next and + # partial_mask_blocks are left undefined and not used in the kernel. + if hasattr(mask, 'q_sequence') and hasattr(mask, 'mask_function'): + q_sequence = mask.q_sequence + mask_function = mask.mask_function + else: + q_sequence = mask_function = None + + # Identify the partial mask blocks and the value of the block mask for each + # block. + # Partial mask blocks are uniquified. When partitioning, all partial mask + # blocks are replicated across shards. + + blocked_shape = (q_blocks_count, kv_blocks_count) + state_grid = np.zeros(blocked_shape, dtype=np.int32) + partial_id_grid = np.full(blocked_shape, -1, dtype=np.int32) + + partial_blocks_map = collections.defaultdict(lambda: len(partial_blocks_map)) + unique_chunks = [] + + # Partition the dense mask into blocks and categorize them: + # 0 = Empty, 1 = Partial (mixed 0s and 1s), 2 = Full (all 1s). + # Partial blocks are deduplicated and stored in unique_chunks to save memory. 
+ for coords in np.ndindex((q_blocks_count, kv_blocks_count)): + (q_idx, kv_idx) = coords + chunk = mask[( + slice(q_idx * q_block_size, (q_idx + 1) * q_block_size), + slice(kv_idx * kv_block_size, (kv_idx + 1) * kv_block_size), + )] + if chunk.any(): + if chunk.all(): + state_grid[q_idx, kv_idx] = 2 + else: + state_grid[q_idx, kv_idx] = 1 + chunk_id = partial_blocks_map[_HashableNDArray(chunk)] + partial_id_grid[q_idx, kv_idx] = chunk_id + + if chunk_id == len(unique_chunks): + unique_chunks.append(chunk) + + full_mask = (state_grid == 2).all() + if full_mask: + return MaskInfo( + mask_next=None, + active_rows=None, + active_cols=None, + block_mask=None, + num_active_blocks=None, + partial_mask_blocks=None, + q_sequence=q_sequence, + ), None + + if unique_chunks: + partial_mask_blocks = np.stack(unique_chunks).astype( + partial_mask_blocks_dtype + ) + if is_dkv: + partial_mask_blocks = partial_mask_blocks.mT + else: + partial_mask_blocks = None + + # Work on a fraction of the mask at the time to compute the mask. This is + # needed to compute the correct data indices, which are relative to the + # current slice of the mask. + all_shards_metadata = [] + for q_shard_idx in range(q_seq_shards): + for kv_shard_idx in range(kv_seq_shards): + q_slice = slice( + q_shard_idx * q_blocks_per_shard, + (q_shard_idx + 1) * q_blocks_per_shard, + ) + kv_slice = slice( + kv_shard_idx * kv_blocks_per_shard, + (kv_shard_idx + 1) * kv_blocks_per_shard, + ) + metadata = _generate_shard_metadata( + state_grid[q_slice, kv_slice], + partial_id_grid[q_slice, kv_slice], + is_dkv, + return_dynamic_grid, + ) + all_shards_metadata.append(metadata) + + ( + active_rows_slices, + active_cols_slices, + mask_next_slices, + block_mask_slices, + num_active_blocks, + ) = zip(*all_shards_metadata) + + if return_dynamic_grid: + # Pad each slice to the largest number of active blocks in any shard. 
+ max_size = max(num_active_blocks) + pad_slice = lambda arr: np.pad( + arr, (0, max_size - arr.shape[0]), mode='constant', constant_values=-1 + ) + active_rows_slices = list(map(pad_slice, active_rows_slices)) + active_cols_slices = list(map(pad_slice, active_cols_slices)) + mask_next_slices = list(map(pad_slice, mask_next_slices)) + block_mask_slices = list(map(pad_slice, block_mask_slices)) + + # Concatenate the sequence shards. + active_rows = np.concatenate(active_rows_slices, axis=0) + active_cols = np.concatenate(active_cols_slices, axis=0) + num_active_blocks = np.array(num_active_blocks, dtype=np.int32) + + if downcast_smem_data: + active_rows = _downcast_to_small_type(active_rows) + active_cols = _downcast_to_small_type(active_cols) + else: + active_rows = active_cols = num_active_blocks = None + + mask_next = np.concatenate(mask_next_slices, axis=0) + block_mask = np.concatenate(block_mask_slices, axis=0) + + if downcast_smem_data: + mask_next = _downcast_to_small_type(mask_next) + block_mask = _downcast_to_small_type(block_mask) + + if partial_mask_blocks is None: + mask_next = None + + assert (mask_function is not None) == (q_sequence is not None) + # When the mask can be computed inside the kernel with a mask_function, + # there is no need to load it from memory. So mask_next and + # partial_mask_blocks are unused. 
+ return ( + MaskInfo( + mask_next=mask_next if mask_function is None else None, + active_rows=active_rows, + active_cols=active_cols, + block_mask=block_mask, + num_active_blocks=num_active_blocks, + partial_mask_blocks=partial_mask_blocks + if mask_function is None + else None, + q_sequence=q_sequence, + ), + mask_function, + ) + + +process_mask = functools.partial(_process_mask, is_dkv=False) +process_mask_dkv = functools.partial(_process_mask, is_dkv=True) + +process_dynamic_mask = functools.partial(_process_dynamic_mask, is_dkv=False) +process_dynamic_mask_dkv = functools.partial(_process_dynamic_mask, is_dkv=True) diff --git a/src/maxdiffusion/kernels/splash_attention/splash_attention_mask_test.py b/src/maxdiffusion/kernels/splash_attention/splash_attention_mask_test.py new file mode 100644 index 00000000..3fe1da30 --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/splash_attention_mask_test.py @@ -0,0 +1,1753 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import sys + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import numpy as np +from . import splash_attention_mask as mask_lib +from . import splash_attention_mask_info as mask_info_lib +from . 
import splash_attention_test_utils as test_utils + + +jax.config.parse_flags_with_absl() + +# pylint: disable=line-too-long + + +def _make_lazy_causal_mask(*args, **kwargs): + mask = mask_lib.CausalMask(*args, **kwargs) + return mask[:, :] + + +def _make_causal_mask(*args, **kwargs): + return mask_lib.make_causal_mask(*args, **kwargs) + + +def _make_lazy_local_attention_mask(*args, **kwargs): + mask = mask_lib.LocalMask(*args, **kwargs) + return mask[:, :] + + +def _make_local_attention_mask(*args, **kwargs): + return mask_lib.make_local_attention_mask(*args, **kwargs) + + +def _make_lazy_chunked_causal_mask(shape, chunk_size): + mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + return mask[:, :] + + +def _make_chunked_causal_mask(shape, chunk_size): + return mask_lib.make_chunk_attention_mask(shape=shape, chunk_size=chunk_size) + + +class SplashAttentionMaskTest(test_utils.SplashAttentionTestCase): + + def setUp(self): + if jax.default_backend() != "tpu": + self.skipTest("Only supported on TPUs.") + super().setUp() + + @parameterized.parameters([_make_lazy_causal_mask, _make_causal_mask]) + def test_causal_mask(self, make_causal_mask): + expected = np.array([[1]], dtype=np.bool_) + actual = make_causal_mask((1, 1)) + + with self.subTest("unit"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_causal_mask((4, 4)) + + with self.subTest("square"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_causal_mask((4, 6)) + + with self.subTest("wide_rectangle"): + self._assert_array_equal(actual, expected) + + actual = make_causal_mask((6, 4)) + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + ], 
+ dtype=np.bool_, + ) + + with self.subTest("tall_rectangle"): + self._assert_array_equal(actual, expected) + + actual = make_causal_mask((4, 4), -1) + expected = np.array( + [ + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + ], + dtype=np.bool_, + ) + + with self.subTest("negative_offset"): + self._assert_array_equal(actual, expected) + + actual = make_causal_mask((4, 4), 1) + expected = np.array( + [ + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + + with self.subTest("positive_offset"): + self._assert_array_equal(actual, expected) + + @parameterized.parameters( + [_make_lazy_local_attention_mask, _make_local_attention_mask] + ) + def test_local_attention_mask(self, make_local_attention_mask): + expected = np.array([[1]], dtype=np.bool_) + actual = make_local_attention_mask((1, 1), (0, None), offset=0) + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 4), (1, None), offset=0) + with self.subTest("left_1"): + self._assert_array_equal(actual, expected) + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 4), (None, 2), offset=0) + with self.subTest("right_2"): + self._assert_array_equal(actual, expected) + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 0, 0], + [1, 1, 1, 0], + [0, 1, 1, 1], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 4), (1, 1), offset=0) + with self.subTest("left_1_right_1"): + self._assert_array_equal(actual, expected) + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = 
make_local_attention_mask((4, 4), (1, 0), offset=0) + with self.subTest("left_1_right_0"): + self._assert_array_equal(actual, expected) + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 1, 0], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 4), (0, 2), offset=0) + with self.subTest("left_0_right_2"): + self._assert_array_equal(actual, expected) + + @parameterized.parameters( + [_make_lazy_local_attention_mask, _make_local_attention_mask] + ) + def test_local_attention_mask_wide_rectangle(self, make_local_attention_mask): + expected = np.array( + [ + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 6), (1, None), offset=0) + with self.subTest("left_1"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 6), (None, 2), offset=0) + with self.subTest("right_2"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 6), (1, 1), offset=0) + with self.subTest("left_1_right_1"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 6), (1, 0), offset=0) + with self.subTest("left_1_right_0"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 1, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 1, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((4, 6), (0, 2), offset=0) + with 
self.subTest("left_0_right_2"): + self._assert_array_equal(actual, expected) + + @parameterized.parameters( + [_make_lazy_local_attention_mask, _make_local_attention_mask] + ) + def test_local_attention_mask_tall_rectangle(self, make_local_attention_mask): + expected = np.array( + [ + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((6, 4), (1, None), offset=0) + with self.subTest("left_1"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((6, 4), (None, 2), offset=0) + with self.subTest("right_2"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 0, 0], + [1, 1, 1, 0], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((6, 4), (1, 1), offset=0) + with self.subTest("left_1_right_1"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((6, 4), (1, 0), offset=0) + with self.subTest("left_1_right_0"): + self._assert_array_equal(actual, expected) + + expected = np.array( + [ + [1, 1, 1, 0], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_local_attention_mask((6, 4), (0, 2), offset=0) + with self.subTest("left_0_right_2"): + self._assert_array_equal(actual, expected) + + @parameterized.product( + block_size=[(256, 256), (256, 128), (128, 256)], + shape=[(1024, 1024), (1024, 2048), (2048, 1024)], + ) + def test_lazy_causal_mask_chunking( + self, block_size: tuple[int, int], shape: tuple[int, int] + ): + dense_mask = 
mask_lib.make_causal_mask(shape=shape) + self._compare_masks( + dense_mask, + mask_lib.CausalMask(shape), + block_size, + ) + + @parameterized.parameters([ + ((256, 256), (1024, 1024), (128, None), 0), + ((256, 128), (1024, 1024), (128, None), 16), + ((128, 256), (1024, 1024), (128, None), 16), + ((256, 256), (1024, 1024), (128, 256), 0), + ((256, 128), (1024, 1024), (128, 256), 0), + ((128, 256), (1024, 1024), (128, 256), 16), + ((256, 256), (1024, 1024), (None, 256), 0), + ((256, 128), (1024, 1024), (None, 256), 32), + ((128, 256), (1024, 1024), (None, 256), 32), + # + ((256, 256), (1024, 2048), (128, None), 0), + ((256, 128), (1024, 2048), (128, None), 16), + ((128, 256), (1024, 2048), (128, None), 16), + ((256, 256), (1024, 2048), (128, 256), 0), + ((256, 128), (1024, 2048), (128, 256), 0), + ((128, 256), (1024, 2048), (128, 256), 16), + ((256, 256), (1024, 2048), (None, 256), 0), + ((256, 128), (1024, 2048), (None, 256), 32), + ((128, 256), (1024, 2048), (None, 256), 32), + # + ((256, 256), (2048, 1024), (128, None), 0), + ((256, 128), (2048, 1024), (128, None), 16), + ((128, 256), (2048, 1024), (128, None), 16), + ((256, 256), (2048, 1024), (128, 256), 0), + ((256, 128), (2048, 1024), (128, 256), 0), + ((128, 256), (2048, 1024), (128, 256), 16), + ((256, 256), (2048, 1024), (None, 256), 0), + ((256, 128), (2048, 1024), (None, 256), 32), + ((128, 256), (2048, 1024), (None, 256), 32), + ]) + def test_lazy_local_mask_chunking( + self, + block_size: tuple[int, int], + shape: tuple[int, int], + window_size: tuple[int | None, int | None], + offset: int, + ): + dense_mask = mask_lib.make_local_attention_mask( + shape, window_size, offset=offset + ) + self._compare_masks( + dense_mask, + mask_lib.LocalMask(shape, window_size, offset), + block_size, + ) + + @parameterized.parameters( + [_make_lazy_chunked_causal_mask, _make_chunked_causal_mask] + ) + def test_chunked_causal_mask(self, make_chunked_mask): + """Tests the chunked causal mask logic for various shapes and 
chunk sizes.""" + with self.subTest("unit"): + expected = np.array([[1]], dtype=np.bool_) + actual = make_chunked_mask(shape=(1, 1), chunk_size=1) + self._assert_array_equal(actual, expected) + actual = make_chunked_mask(shape=(1, 1), chunk_size=2) + self._assert_array_equal(actual, expected) + + with self.subTest("square_exact_chunks"): + # Chunk 0: [0, 1], Chunk 1: [2, 3] + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=2) + self._assert_array_equal(actual, expected) + + with self.subTest("square_uneven_chunks"): + expected = np.array( + [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(5, 5), chunk_size=3) + self._assert_array_equal(actual, expected) + + with self.subTest("wide_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 6), chunk_size=3) + self._assert_array_equal(actual, expected) + + with self.subTest("tall_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(6, 4), chunk_size=3) + self._assert_array_equal(actual, expected) + + with self.subTest("chunk_size_1"): + # Should only allow self-attention q==k and chunk_size == 1 + expected = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=1) + self._assert_array_equal(actual, expected) + + with self.subTest("chunk_size_greater_equal_seqlen"): + # Should behave like a normal causal mask + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + 
dtype=np.bool_, + ) + # Test chunk_size == seqlen + actual_eq = make_chunked_mask(shape=(4, 4), chunk_size=4) + self._assert_array_equal(actual_eq, expected) + # Test chunk_size > seqlen + actual_gt = make_chunked_mask(shape=(4, 4), chunk_size=5) + self._assert_array_equal(actual_gt, expected) + + @parameterized.product( + block_size=[(128, 128), (256, 128), (128, 256)], + shape=[(512, 512), (512, 1024), (1024, 512)], + chunk_size=[64, 128, 256, 512, 1024], + ) + def test_lazy_chunked_causal_mask_chunking( + self, + block_size: tuple[int, int], + shape: tuple[int, int], + chunk_size: int, + ): + """Compares lazy chunked mask evaluation against the dense version block-by-block.""" + q_len, kv_len = shape + # Adjust block size if it exceeds shape dimensions + adjusted_block_size = ( + min(block_size[0], q_len), + min(block_size[1], kv_len), + ) + + if ( + q_len % adjusted_block_size[0] != 0 + or kv_len % adjusted_block_size[1] != 0 + ): + self.skipTest( + f"Shape {shape} not divisible by block_size {adjusted_block_size}" + ) + + dense_mask = _make_chunked_causal_mask(shape=shape, chunk_size=chunk_size) + lazy_mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + self._compare_masks( + dense_mask, + lazy_mask, + adjusted_block_size, + ) + + def test_chunked_causal_mask_invalid_chunk_size(self): + """Tests that invalid chunk_size raises ValueError.""" + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=0) + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=-1) + with self.assertRaises(ValueError): + mask_lib.make_chunk_attention_mask(shape=(10, 10), chunk_size=0) + + def test_chunked_causal_mask_minimal_equality_hash(self): + """Tests for __eq__ and __hash__ of ChunkedCausalMask.""" + shape1, chunk_size1 = (128, 256), 16 + shape2, chunk_size2 = (128, 128), 32 # Different shape/chunk_size + + # Create three masks: two identical, one with different shape/chunk_size. 
+ mask1 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask2 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask_diff_shape = mask_lib.ChunkedCausalMask( + shape=shape2, chunk_size=chunk_size1 + ) + mask_diff_chunk = mask_lib.ChunkedCausalMask( + shape=shape1, chunk_size=chunk_size2 + ) + other_obj = object() + + # Test __eq__ + self.assertEqual(mask1, mask2) + self.assertNotEqual(mask1, mask_diff_shape) + self.assertNotEqual(mask1, mask_diff_chunk) + self.assertNotEqual(mask1, other_obj) + + # Test __hash__ of identical masks + self.assertEqual(hash(mask1), hash(mask2)) + + mask_set = {mask1, mask2, mask_diff_chunk} + self.assertLen(mask_set, 2) # mask1 and mask2 are duplicates + self.assertIn(mask1, mask_set) + self.assertIn(mask_diff_chunk, mask_set) + self.assertNotIn(mask_diff_shape, mask_set) + + def test_using_logical_operators_raises_exception(self): + if sys.version_info == (3, 14, 0, "candidate", 1): + # Fails due to Python bug on 3.14.0rc1 + # https://github.com/python/cpython/issues/137288 + self.skipTest("Expected failure.") + mask_1 = mask_lib.NumpyMask( + mask_lib.make_random_mask((256, 256), 0.5, seed=1) + ) + mask_2 = mask_lib.NumpyMask( + mask_lib.make_random_mask((256, 256), 0.5, seed=2) + ) + + with self.subTest("logical_or"): + with self.assertRaises(NotImplementedError): + res = mask_1 or mask_2 + del res + + with self.subTest("logical_and"): + with self.assertRaises(NotImplementedError): + res = mask_1 and mask_2 + del res + + @parameterized.parameters([((256, 256),), ((512, 256),), ((512, 256),)]) + def test_lazy_mask_or(self, shape: tuple[int, int]): + mask_1 = mask_lib.make_random_mask(shape, 0.5, seed=1) + mask_2 = mask_lib.make_random_mask(shape, 0.5, seed=2) + + lazy_or = mask_lib.NumpyMask(mask_1) | mask_lib.NumpyMask(mask_2) + dense = np.logical_or(mask_1, mask_2) + + self._compare_masks(dense, lazy_or, (256, 256)) + + @parameterized.parameters([((256, 256),), ((512, 256),), ((512, 256),)]) 
+ def test_lazy_mask_and(self, shape: tuple[int, int]): + mask_1 = mask_lib.make_random_mask(shape, 0.5, seed=1) + mask_2 = mask_lib.make_random_mask(shape, 0.5, seed=2) + + lazy_and = mask_lib.NumpyMask(mask_1) & mask_lib.NumpyMask(mask_2) + dense = np.logical_and(mask_1, mask_2) + + self._compare_masks(dense, lazy_and, (256, 256)) + + @parameterized.parameters([((256, 256),), ((512, 256),), ((512, 256),)]) + def test_lazy_full_mask(self, shape: tuple[int, int]): + lazy_full = mask_lib.FullMask(shape) + dense = np.ones(shape, dtype=np.bool_) + + self._compare_masks(dense, lazy_full, (256, 256)) + + def _compare_masks( + self, + dense_mask: np.ndarray, + lazy_mask: mask_lib.Mask, + block_size: tuple[int, int], + ): + self.assertEqual(dense_mask.shape, lazy_mask.shape) + + *prefix, width, height = dense_mask.shape + + assert width % block_size[0] == 0 + assert height % block_size[1] == 0 + + full_lazy_mask = lazy_mask[ + (*[slice(p) for p in prefix], slice(None), slice(None)) + ] + self._assert_array_equal(dense_mask, full_lazy_mask) + for i, j in np.ndindex(width // block_size[0], height // block_size[1]): + indexer = ( + *[slice(p) for p in prefix], + slice(i * block_size[0], (i + 1) * block_size[0]), + slice(j * block_size[1], (j + 1) * block_size[1]), + ) + dense_chunk = dense_mask[indexer] + lazy_chunk = lazy_mask[indexer] + self._assert_array_equal(dense_chunk, lazy_chunk) + + +class SplashAttentionMaskInfoTest(test_utils.SplashAttentionTestCase): + """Check the construction of MaskInfo from Mask.""" + + def _assert_mask_info_match( + self, actual: mask_info_lib.MaskInfo, expected: mask_info_lib.MaskInfo + ): + def _check_presence(actual, expected): + return self.assertEqual(actual is not None, expected is not None) + + # TODO: refactor so that all of MaskInfo is possibly None + _check_presence(actual.mask_next, expected.mask_next) + _check_presence(actual.partial_mask_blocks, expected.partial_mask_blocks) + _check_presence(actual.q_sequence, 
expected.q_sequence) + _check_presence(actual.block_mask, expected.block_mask) + _check_presence(actual.active_rows, expected.active_rows) + _check_presence(actual.active_cols, expected.active_cols) + + self._assert_array_equal( + actual.num_active_blocks, + expected.num_active_blocks, + err_msg="num_active_blocks", + verbose=True, + ) + self._assert_array_equal( + actual.block_mask, + expected.block_mask, + err_msg="block_mask", + verbose=True, + ) + self._assert_array_equal( + actual.active_rows, + expected.active_rows, + err_msg="active_rows", + verbose=True, + ) + self._assert_array_equal( + actual.active_cols, + expected.active_cols, + err_msg="active_cols", + verbose=True, + ) + self._assert_array_equal( + actual.mask_next, + expected.mask_next, + err_msg="mask_next", + verbose=True, + ) + self._assert_array_equal( + actual.partial_mask_blocks, + expected.partial_mask_blocks, + err_msg="partial_mask_blocks", + verbose=True, + ) + self._assert_array_equal( + actual.q_sequence, + expected.q_sequence, + err_msg="q_sequence", + verbose=True, + ) + + def _process_mask(self, *args, **kwargs): + mask_info, mask_function = mask_info_lib.process_mask(*args, **kwargs) + mask_info_dkv, dkv_mask_function = mask_info_lib.process_mask_dkv( + *args, **kwargs + ) + self.assertEqual(mask_function, dkv_mask_function) + return mask_info, mask_info_dkv, mask_function + + @parameterized.parameters((True,), (False,)) + def test_full_mask(self, is_lazy_mask: bool): + sequence_lengths = (64, 64) + block_shape = (16, 16) + + if is_lazy_mask: + full_mask = mask_lib.FullMask(sequence_lengths) + else: + full_mask = mask_lib.NumpyMask(np.ones(sequence_lengths, dtype=np.bool_)) + + mask_info, mask_info_dkv, mask_function = self._process_mask( + full_mask, block_shape + ) + self.assertIsNone(mask_function) + + expected_mask_info = mask_info_lib.MaskInfo( + None, + None, + None, + None, + None, + None, + None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + 
self._assert_mask_info_match(mask_info_dkv, expected_mask_info) + + def test_no_partial_mask_blocks(self): + sequence_lengths = (64, 64) + block_shape = (16, 16) + + mask = np.ones(sequence_lengths).astype(np.bool_) + mask[:32, 32:] = False + mask = mask_lib.NumpyMask(mask) + + mask_info, mask_info_dkv, mask_function = self._process_mask( + mask, block_shape + ) + self.assertIsNone(mask_function) + + expected_mask_info = mask_info_lib.MaskInfo( + mask_next=None, + active_rows=np.array( + [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=np.int8 + ), + active_cols=np.array( + [0, 1, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3], dtype=np.int8 + ), + block_mask=np.array( + [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=np.int8 + ), + num_active_blocks=np.array([12], dtype=np.int32), + partial_mask_blocks=None, + q_sequence=None, + ) + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + mask_next=None, + active_rows=np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3], dtype=np.int8 + ), + active_cols=np.array( + [0, 1, 2, 3, 0, 1, 2, 3, 2, 3, 2, 3], dtype=np.int8 + ), + block_mask=np.array( + [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=np.int8 + ), + num_active_blocks=np.array([12], dtype=np.int32), + partial_mask_blocks=None, + q_sequence=None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + @parameterized.product( + is_lazy_mask=[True, False], return_dynamic_grid=[True, False] + ) + def test_rectangular_wide_causal_mask( + self, is_lazy_mask: bool, return_dynamic_grid: bool + ): + sequence_lengths = (64, 128) + block_shape = (16, 16) + + if is_lazy_mask: + causal_mask = mask_lib.CausalMask(sequence_lengths) + else: + causal_mask = mask_lib.NumpyMask( + mask_lib.make_causal_mask(sequence_lengths) + ) + + args = (causal_mask, block_shape) + mask_info, mask_function = mask_info_lib.process_mask(*args) + mask_info_dkv, _ = mask_info_lib.process_mask_dkv( + *args, return_dynamic_grid=return_dynamic_grid + 
) + if is_lazy_mask: + self.assertIsNotNone(mask_function) + else: + self.assertIsNone(mask_function) + + expected_causal_mask_next = np.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.int8 + ) + expected_active_rows = np.array( + [0, 1, 1, 2, 2, 2, 3, 3, 3, 3], dtype=np.int8 + ) + expected_active_cols = np.array( + [0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=np.int8 + ) + expected_causal_block_mask = np.array( + [1, 2, 1, 2, 2, 1, 2, 2, 2, 1], dtype=np.int8 + ) + expected_num_active_blocks = np.array([10], dtype=np.int32) + + if not is_lazy_mask: + expected_mask_info = mask_info_lib.MaskInfo( + expected_causal_mask_next, + expected_active_rows, + expected_active_cols, + expected_causal_block_mask, + expected_num_active_blocks, + np.tri(*block_shape, dtype=np.int8)[None, ...], + None, + ) + else: + expected_mask_info = mask_info_lib.MaskInfo( + None, + expected_active_rows, + expected_active_cols, + expected_causal_block_mask, + expected_num_active_blocks, + None, + np.arange(sequence_lengths[0], dtype=np.int32), + ) + + if return_dynamic_grid: + expected_causal_mask_next_dkv = np.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.int8 + ) + # The grid is extended to visit empty rows to initialize dk/dv. 
+ expected_active_rows_dkv = np.array( + [0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 4, 5, 6, 7], dtype=np.int8 + ) + expected_active_cols_dkv = np.array( + [0, 1, 2, 3, 1, 2, 3, 2, 3, 3, 0, 0, 0, 0], dtype=np.int8 + ) + expected_causal_block_mask_dkv = np.array( + [1, 2, 2, 2, 1, 2, 2, 1, 2, 1, 0, 0, 0, 0], dtype=np.int8 + ) + expected_num_active_blocks_dkv = np.array([14], dtype=np.int32) + else: + expected_causal_mask_next_dkv = np.zeros((32,), dtype=np.int8) + expected_active_rows_dkv = None + expected_active_cols_dkv = None + expected_causal_block_mask_dkv = np.array( + [ + [1, 2, 2, 2], + [0, 1, 2, 2], + [0, 0, 1, 2], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + ], + dtype=np.int8, + ).flatten() + expected_num_active_blocks_dkv = None + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + expected_causal_mask_next_dkv if not is_lazy_mask else None, + expected_active_rows_dkv, + expected_active_cols_dkv, + expected_causal_block_mask_dkv, + expected_num_active_blocks_dkv, + np.tri(*block_shape, dtype=np.int8).T[None, ...] 
+ if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, + ) + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + @parameterized.parameters((True,), (False,)) + def test_rectangular_tall_causal_mask(self, is_lazy_mask: bool): + sequence_lengths = (128, 64) + block_shape = (16, 16) + + if is_lazy_mask: + causal_mask = mask_lib.CausalMask(sequence_lengths) + else: + causal_mask = mask_lib.NumpyMask( + mask_lib.make_causal_mask(sequence_lengths) + ) + + mask_info, mask_info_dkv, mask_function = self._process_mask( + causal_mask, block_shape + ) + if is_lazy_mask: + self.assertIsNotNone(mask_function) + else: + self.assertIsNone(mask_function) + + expected_causal_mask_next = np.array([0] * 26, dtype=np.int8) + expected_active_rows = np.array( + [ + 0, + 1, + 1, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 7, + 7, + 7, + 7, + ], + dtype=np.int8, + ) + expected_active_cols = np.array( + [ + 0, + 0, + 1, + 0, + 1, + 2, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + ], + dtype=np.int8, + ) + expected_causal_block_mask = np.array( + [1, 2, 1, 2, 2, 1, 2, 2, 2, 1] + [2] * 16, dtype=np.int8 + ) + expected_num_active_blocks = np.array([26], dtype=np.int32) + + expected_mask_info = mask_info_lib.MaskInfo( + expected_causal_mask_next if not is_lazy_mask else None, + expected_active_rows, + expected_active_cols, + expected_causal_block_mask, + expected_num_active_blocks, + np.tri(*block_shape, dtype=np.int8)[None, ...] 
+ if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, + ) + + expected_causal_mask_next_dkv = np.array([0] * 26, dtype=np.int8) + expected_active_rows_dkv = np.array( + [0] * 8 + [1] * 7 + [2] * 6 + [3] * 5, dtype=np.int8 + ) + expected_active_cols_dkv = np.concatenate( + [np.arange(8), np.arange(1, 8), np.arange(2, 8), np.arange(3, 8)], + dtype=np.int8, + ) + expected_causal_block_mask_dkv = np.array( + [1, 2, 2, 2, 2, 2, 2, 2] + + [1, 2, 2, 2, 2, 2, 2] + + [1, 2, 2, 2, 2, 2] + + [1, 2, 2, 2, 2], + dtype=np.int8, + ) + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + expected_causal_mask_next_dkv if not is_lazy_mask else None, + expected_active_rows_dkv, + expected_active_cols_dkv, + expected_causal_block_mask_dkv, + expected_num_active_blocks, + np.tri(*block_shape, dtype=np.int8).T[None, ...] + if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + @parameterized.parameters((True,), (False,)) + def test_local_mask(self, is_lazy_mask: bool): + sequence_lengths = (64, 64) + block_shape = (16, 16) + window_size = 8 + if is_lazy_mask: + local_mask = mask_lib.LocalMask( + sequence_lengths, + window_size=(window_size, window_size), + offset=0, + ) + else: + local_mask = mask_lib.NumpyMask( + mask_lib.make_local_attention_mask( + sequence_lengths, window_size=(window_size, window_size), offset=0 + ) + ) + + mask_info, mask_info_dkv, mask_function = self._process_mask( + local_mask, block_shape + ) + if is_lazy_mask: + self.assertIsNotNone(mask_function) + + expected_partial_mask_blocks = np.stack( + [ + np.triu( + np.tri(*block_shape, window_size, dtype=np.int8), -window_size + ), + np.tri(*block_shape, -window_size, dtype=np.int8), + np.triu(np.ones(block_shape, dtype=np.int8), window_size), + ], + ) + 
expected_local_mask_next = np.array( + [0, 1, 2, 0, 1, 2, 0, 1, 2, 0], dtype=np.int8 + ) + expected_active_rows = np.array( + [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], dtype=np.int8 + ) + expected_active_cols = np.array( + [0, 1, 0, 1, 2, 1, 2, 3, 2, 3], dtype=np.int8 + ) + expected_local_block_mask = np.array( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=np.int8 + ) + expected_num_active_blocks = np.array([10], dtype=np.int32) + + expected_mask_info = mask_info_lib.MaskInfo( + expected_local_mask_next if not is_lazy_mask else None, + expected_active_rows, + expected_active_cols, + expected_local_block_mask, + expected_num_active_blocks, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, + ) + + expected_local_mask_next_dkv = np.array( + [0, 2, 1, 0, 2, 1, 0, 2, 1, 0], dtype=np.int8 + ) + expected_active_rows_dkv = np.array( + [ + 0, + 0, + 1, + 1, + 1, + 2, + 2, + 2, + 3, + 3, + ], + dtype=np.int8, + ) + expected_active_cols_dkv = np.array( + [0, 1, 0, 1, 2, 1, 2, 3, 2, 3], dtype=np.int8 + ) + expected_local_block_mask_dkv = np.array( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=np.int8 + ) + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + expected_local_mask_next_dkv if not is_lazy_mask else None, + expected_active_rows_dkv, + expected_active_cols_dkv, + expected_local_block_mask_dkv, + expected_num_active_blocks, + expected_partial_mask_blocks.mT if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + @parameterized.parameters((True,), (False,)) + def test_local_mask_narrow(self, is_lazy_mask: bool): + sequence_lengths = (64, 64) + block_shape = (16, 16) + window_size = 8 + if is_lazy_mask: + local_mask = mask_lib.LocalMask( + sequence_lengths, + window_size=(window_size, 0), + offset=0, + ) + 
else: + local_mask = mask_lib.NumpyMask( + mask_lib.make_local_attention_mask( + sequence_lengths, window_size=(window_size, 0), offset=0 + ) + ) + + mask_info, mask_info_dkv, mask_function = self._process_mask( + local_mask, block_shape + ) + + if is_lazy_mask: + self.assertIsNotNone(mask_function) + + expected_partial_mask_blocks = np.stack( + [ + np.triu(np.tri(*block_shape, 0, dtype=np.int8), -window_size), + np.triu(np.ones(block_shape, dtype=np.int8), window_size), + ], + ) + + expected_local_mask_next = np.array([0, 1, 0, 1, 0, 1, 0], dtype=np.int8) + expected_active_rows = np.array([0, 1, 1, 2, 2, 3, 3], dtype=np.int8) + expected_active_cols = np.array([0, 0, 1, 1, 2, 2, 3], dtype=np.int8) + expected_local_block_mask = np.array([1, 1, 1, 1, 1, 1, 1], dtype=np.int8) + expected_num_active_blocks = np.array([7], dtype=np.int32) + + expected_mask_info = mask_info_lib.MaskInfo( + expected_local_mask_next if not is_lazy_mask else None, + expected_active_rows, + expected_active_cols, + expected_local_block_mask, + expected_num_active_blocks, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, + ) + expected_active_rows_dkv = np.array([0, 0, 1, 1, 2, 2, 3], dtype=np.int8) + expected_active_cols_dkv = np.array([0, 1, 1, 2, 2, 3, 3], dtype=np.int8) + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + expected_local_mask_next if not is_lazy_mask else None, + expected_active_rows_dkv, + expected_active_cols_dkv, + expected_local_block_mask, + expected_num_active_blocks, + expected_partial_mask_blocks.mT if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + def test_two_qseq_shards_causal_local_stacked(self): + sequence_lengths = (64, 64) + block_shape = (16, 16) + window_size = 8 + + 
causal_mask = mask_lib.make_causal_mask(sequence_lengths) + local_mask = mask_lib.make_local_attention_mask( + sequence_lengths, window_size=(window_size, window_size), offset=0 + ) + mask = np.concatenate((causal_mask, local_mask), axis=0) + mask = mask_lib.NumpyMask(mask) + + mask_info, mask_info_dkv, mask_function = self._process_mask( + mask, block_shape, q_seq_shards=2 + ) + self.assertIsNone(mask_function) + + expected_mask_next = np.concatenate( + [ + np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), # causal mask + np.array([1, 2, 3, 1, 2, 3, 1, 2, 3, 1]), # local mask + ], + axis=0, + dtype=np.int8, + ) + + expected_active_rows = np.concatenate( + [ + np.array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3]), + np.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3]), + ], + axis=0, + dtype=np.int8, + ) + + expected_active_cols = np.concatenate( + [ + np.array([0, 0, 1, 0, 1, 2, 0, 1, 2, 3]), + np.array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3]), + ], + axis=0, + dtype=np.int8, + ) + + expected_block_mask = np.concatenate( + [ + np.array([1, 2, 1, 2, 2, 1, 2, 2, 2, 1]), + np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + ], + axis=0, + dtype=np.int8, + ) + + expected_num_active_blocks = np.array([10, 10], dtype=np.int32) + + expected_partial_mask_blocks = np.stack([ + np.tri(*block_shape, dtype=np.int8), + np.triu( + np.tri(*block_shape, window_size, dtype=np.int8), + -window_size, + ), + np.tri(*block_shape, -window_size, dtype=np.int8), + np.triu(np.ones(block_shape, dtype=np.int8), window_size), + ]) + + expected_mask_info = mask_info_lib.MaskInfo( + expected_mask_next, + expected_active_rows, + expected_active_cols, + expected_block_mask, + expected_num_active_blocks, + expected_partial_mask_blocks, + None, + ) + + expected_mask_next_dkv = np.concatenate( + [ + np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), # causal mask + np.array([1, 3, 2, 1, 3, 2, 1, 3, 2, 1]), # local mask + ], + axis=0, + dtype=np.int8, + ) + + expected_active_rows_dkv = np.concatenate( + [ + np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 3]), + 
np.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3]), + ], + axis=0, + dtype=np.int8, + ) + + expected_active_cols_dkv = np.concatenate( + [ + np.array([0, 1, 2, 3, 1, 2, 3, 2, 3, 3]), # causal mask + np.array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3]), + ], # local mask + axis=0, + dtype=np.int8, + ) + + expected_block_mask_dkv = np.concatenate( + [ + np.array([1, 2, 2, 2, 1, 2, 2, 1, 2, 1]), # causal mask + np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + ], # local mask + axis=0, + dtype=np.int8, + ) + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + expected_mask_next_dkv, + expected_active_rows_dkv, + expected_active_cols_dkv, + expected_block_mask_dkv, + expected_num_active_blocks, + expected_partial_mask_blocks.mT, + None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + @parameterized.named_parameters( + dict( + testcase_name="q_seq_shards_2", + q_seq_shards=2, + kv_seq_shards=1, + ), + dict( + testcase_name="kv_seq_shards_2", + q_seq_shards=1, + kv_seq_shards=2, + ), + ) + def test_two_shards_local_wide_local_narrow_stacked( + self, q_seq_shards, kv_seq_shards + ): + sequence_lengths = (64, 64) + block_shape = (16, 16) + window_size = 8 + + local_mask_wide = mask_lib.make_local_attention_mask( + sequence_lengths, window_size=(window_size, window_size), offset=0 + ) + local_mask_narrow = mask_lib.make_local_attention_mask( + sequence_lengths, window_size=(window_size, 0), offset=0 + ) + + concat_axis = 0 if q_seq_shards > 1 else 1 + mask = np.concatenate((local_mask_wide, local_mask_narrow), axis=concat_axis) + + mask = mask_lib.NumpyMask(mask) + + mask_info, mask_info_dkv, mask_function = self._process_mask( + mask, + block_shape, + q_seq_shards=q_seq_shards, + kv_seq_shards=kv_seq_shards, + ) + self.assertIsNone(mask_function) + + expected_block_mask = np.concatenate( + [ + np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), # local wide block mask + np.array([1, 1, 1, 1, 1, 1, 1, -1, -1, -1]), # 
local narrow block mask + ], + axis=0, + dtype=np.int8, + ) + + expected_active_rows = np.concatenate( + [ + np.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3]), + np.array([0, 1, 1, 2, 2, 3, 3, -1, -1, -1]), + ], + axis=0, + dtype=np.int8, + ) + + expected_active_cols = np.concatenate( + [ + np.array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3]), + np.array([0, 0, 1, 1, 2, 2, 3, -1, -1, -1]), + ], + axis=0, + dtype=np.int8, + ) + + expected_num_active_blocks = np.array([10, 7], dtype=np.int32) + + block_wide_1 = np.triu( + np.tri(*block_shape, window_size, dtype=np.int8), -window_size + ) + block_wide_2 = np.tri(*block_shape, -window_size, dtype=np.int8) + block_wide_3 = np.triu(np.ones(block_shape, dtype=np.int8), window_size) + block_narrow = np.triu(np.tri(*block_shape, 0, dtype=np.int8), -window_size) + + if q_seq_shards == 2: + expected_partial_mask_blocks = np.stack( + [block_wide_1, block_wide_2, block_wide_3, block_narrow] + ).astype(np.int8) + + expected_mask_next = np.array( + [0, 1, 2, 0, 1, 2, 0, 1, 2, 0] # local wide mask + + [3, 2, 3, 2, 3, 2, 3, -1, -1, -1], # local narrow mask + dtype=np.int8, + ) + + expected_local_mask_next_dkv = np.array( + [0, 2, 1, 0, 2, 1, 0, 2, 1, 0] + + [3, 2, 3, 2, 3, 2, 3, -1, -1, -1], + dtype=np.int8, + ) + + else: + assert kv_seq_shards == 2 + # The global mask is different so the partial mask blocks are processed + # in a different order. 
+ expected_partial_mask_blocks = np.stack( + [block_wide_1, block_wide_2, block_narrow, block_wide_3], + ).astype(np.int8) + + expected_mask_next = np.array( + [0, 1, 3, 0, 1, 3, 0, 1, 3, 0] # local narrow mask + + [2, 3, 2, 3, 2, 3, 2, -1, -1, -1], # local wide mask + dtype=np.int8, + ) + + expected_local_mask_next_dkv = np.array( + [0, 3, 1, 0, 3, 1, 0, 3, 1, 0] + [2, 3, 2, 3, 2, 3, 2, -1, -1, -1], + dtype=np.int8, + ) + + expected_mask_info = mask_info_lib.MaskInfo( + expected_mask_next, + expected_active_rows, + expected_active_cols, + expected_block_mask, + expected_num_active_blocks, + expected_partial_mask_blocks, + None, + ) + + expected_active_rows_dkv = np.concatenate( + [ + np.array([ + 0, + 0, + 1, + 1, + 1, + 2, + 2, + 2, + 3, + 3, + ]), + np.array([0, 0, 1, 1, 2, 2, 3, -1, -1, -1]), + ], + axis=0, + dtype=np.int8, + ) + + expected_active_cols_dkv = np.concatenate( + [ + np.array([0, 1, 0, 1, 2, 1, 2, 3, 2, 3]), + np.array([0, 1, 1, 2, 2, 3, 3, -1, -1, -1]), + ], + axis=0, + dtype=np.int8, + ) + + expected_block_mask_dkv = np.concatenate( + [ + np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + np.array([1, 1, 1, 1, 1, 1, 1, -1, -1, -1]), + ], + axis=0, + dtype=np.int8, + ) + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + expected_local_mask_next_dkv, + expected_active_rows_dkv, + expected_active_cols_dkv, + expected_block_mask_dkv, + expected_num_active_blocks, + expected_partial_mask_blocks.mT, + None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + @parameterized.parameters(False, True) + def test_causal_two_q_shards_two_kv_shards(self, return_dynamic_grid): + q_seq_shards = kv_seq_shards = 2 + sequence_lengths = (64, 64) + block_shape = (16, 16) + + mask = mask_lib.make_causal_mask(sequence_lengths, 0) + mask = mask_lib.NumpyMask(mask) + + args = (mask, block_shape) + kwargs = { + "q_seq_shards": q_seq_shards, + "kv_seq_shards": kv_seq_shards, + } + 
mask_info, _ = mask_info_lib.process_mask(*args, **kwargs) + mask_info_dkv, _ = mask_info_lib.process_mask_dkv( + *args, + **kwargs, + return_dynamic_grid=return_dynamic_grid, + ) + + partial_mask_blocks = np.tri(*(block_shape), dtype=np.int8)[None] + expected_mask_info = mask_info_lib.MaskInfo( + mask_next=np.array( + [0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, -1], + dtype=np.int8, + ), + active_rows=np.array( + [0, 1, 1, -1, -1, -1, -1, -1, 0, 0, 1, 1, 0, 1, 1, -1], + dtype=np.int8, + ), + active_cols=np.array( + [0, 0, 1, -1, -1, -1, -1, -1, 0, 1, 0, 1, 0, 0, 1, -1], + dtype=np.int8, + ), + block_mask=np.array( + [1, 2, 1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 1, 2, 1, -1], + dtype=np.int8, + ), + num_active_blocks=np.array([3, 0, 4, 3], dtype=np.int32), + partial_mask_blocks=partial_mask_blocks, + q_sequence=None, + ) + if return_dynamic_grid: + expected_mask_info_dkv = mask_info_lib.MaskInfo( + mask_next=np.array( + [0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, -1], + dtype=np.int8, + ), + active_rows=np.array( + [0, 0, 1, -1, 0, 1, -1, -1, 0, 0, 1, 1, 0, 0, 1, -1], dtype=np.int8 + ), + active_cols=np.array( + [0, 1, 1, -1, 0, 0, -1, -1, 0, 1, 0, 1, 0, 1, 1, -1], dtype=np.int8 + ), + block_mask=np.array( + [1, 2, 1, -1, 0, 0, -1, -1, 2, 2, 2, 2, 1, 2, 1, -1], dtype=np.int8 + ), + num_active_blocks=np.array([3, 2, 4, 3], dtype=np.int32), + partial_mask_blocks=partial_mask_blocks.mT, + q_sequence=None, + ) + else: + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + mask_next=np.array( + [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0], + dtype=np.int8, + ), + active_rows=None, + active_cols=None, + block_mask=np.array( + [1, 2, 0, 1, 0, 0, 0, 0, 2, 2, 2, 2, 1, 2, 0, 1], dtype=np.int8 + ), + num_active_blocks=None, + partial_mask_blocks=partial_mask_blocks.mT, + q_sequence=None, + ) + + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + def 
test_huge_mask(self):
+    # Don't go too high with the mask size to avoid timeouts. Prefer covering
+    # multiple cases rather than one very large one. This configuration replicates
+    # a realistic training shape. In particular, a large number of head shards
+    # and interleaving contribute to increasing processing time.
+    sequence_length = (32 * 1024, 32 * 1024)
+    block_shape = (512, 1024)
+
+    num_shards = 16
+    causal_mask = mask_lib.CausalMask(
+        sequence_length, 0, shard_count=num_shards
+    )
+
+    mask_info, mask_function = mask_info_lib.process_mask(
+        causal_mask, block_shape, q_seq_shards=16
+    )
+
+    self.assertIsNotNone(mask_function)
+    self.assertIsNotNone(mask_info.block_mask)
+    self.assertIsNone(mask_info.mask_next)
+    self.assertIsNone(mask_info.partial_mask_blocks)
+    self.assertIsNotNone(mask_info.q_sequence)
+
+  def test_huge_mask2(self):
+    sequence_lengths = (32 * 1024, 32 * 1024)
+    block_shape = (1024, 1024)
+    window_size = 8
+
+    local_mask = mask_lib.LocalMask(
+        sequence_lengths,
+        window_size=(window_size, window_size),
+        offset=0,
+    )
+
+    mask_info, mask_function = mask_info_lib.process_mask(
+        local_mask, block_shape
+    )
+
+    self.assertIsNotNone(mask_function)
+    self.assertIsNotNone(mask_info.block_mask)
+    self.assertIsNone(mask_info.mask_next)
+    self.assertIsNone(mask_info.partial_mask_blocks)
+    self.assertIsNotNone(mask_info.q_sequence)
+
+  def test_process_invalid_mask(self):
+    """Masks with an all-0 row cause undefined softmax; reject them."""
+    sequence_length = 32
+
+    invalid_mask = np.ones((sequence_length, sequence_length), dtype=np.bool_)
+    invalid_mask[14, :] = False
+    invalid_mask = mask_lib.NumpyMask(invalid_mask)
+
+    with self.assertRaises(ValueError) as ctx:
+      mask_info_lib._check_mask(invalid_mask)
+
+    self.assertIn("softmax", str(ctx.exception))
+
+  def test_dynamic_mask(self):
+    q_seq_len, kv_seq_len = 8, 8
+    block_shape = (2, 4)
+
+    mask = _make_causal_mask((q_seq_len, kv_seq_len))
+
+    process_dynamic_mask_fn = jax.jit(
mask_info_lib.process_dynamic_mask, + static_argnames=["block_shape", "is_dkv"], + ) + + args = (mask, block_shape) + mask_info = process_dynamic_mask_fn(*args) + mask_info_dkv = process_dynamic_mask_fn(*args, is_dkv=True) + + expected_mask_next = np.array([0, 2, 0, 5, 0, 7, 0, 0], dtype=np.int8) + expected_block_mask = np.array([1, 1, 2, 1, 2, 1, 0, 0], dtype=np.int8) + expected_active_rows = np.array([0, 1, 2, 2, 3, 3, -1, -1], dtype=np.int32) + expected_active_cols = np.array([0, 0, 0, 1, 0, 1, -1, -1], dtype=np.int32) + expected_num_active_blocks = np.array([6], dtype=np.int32) + expected_partial_mask_blocks = np.array( + [ + [[1, 0, 0, 0], [1, 1, 0, 0]], + [[0, 0, 0, 0], [0, 0, 0, 0]], + [[1, 1, 1, 0], [1, 1, 1, 1]], + [[0, 0, 0, 0], [0, 0, 0, 0]], + [[1, 1, 1, 1], [1, 1, 1, 1]], + [[1, 0, 0, 0], [1, 1, 0, 0]], + [[1, 1, 1, 1], [1, 1, 1, 1]], + [[1, 1, 1, 0], [1, 1, 1, 1]], + ], + dtype=np.int8, + ) + + expected_mask_info = mask_info_lib.MaskInfo( + expected_mask_next, + expected_active_rows, + expected_active_cols, + expected_block_mask, + expected_num_active_blocks, + expected_partial_mask_blocks, + None, + ) + + expected_mask_next_dkv = np.array([0, 2, 0, 0, 5, 7, 0, 0], dtype=np.int8) + expected_active_rows_dkv = np.array([0, 0, 0, 0, 1, 1, -1, -1], dtype=np.int32) + expected_active_cols_dkv = np.array([0, 1, 2, 3, 2, 3, -1, -1], dtype=np.int32) + expected_block_mask_dkv = np.array([1, 1, 2, 2, 1, 1, 0, 0], dtype=np.int8) + expected_num_active_blocks_dkv = np.array([6], dtype=np.int32) + + expected_mask_info_dkv = mask_info_lib.MaskInfo( + expected_mask_next_dkv, + expected_active_rows_dkv, + expected_active_cols_dkv, + expected_block_mask_dkv, + expected_num_active_blocks_dkv, + expected_partial_mask_blocks.swapaxes(-1, -2), + None, + ) + self._assert_mask_info_match(mask_info, expected_mask_info) + self._assert_mask_info_match(mask_info_dkv, expected_mask_info_dkv) + + def test_find_bounds(self): + test_cases = [ + ("standard", [0, 0, 1, 1, 2], [1, 0, 1, 
0, 1], [0, 1, 0, 1, 1], 5), + ("homogeneous", [5, 5, 5, 5], [1, 0, 0, 0], [0, 0, 0, 1], 5), + ("alternating", [0, 1, 0, 1], [1, 1, 1, 1], [1, 1, 1, 1], 4), + ("wrap_around", [1, 0, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1], 4), + ("padding", [0, 0, -1], [1, 0, 0], [0, 1, 0], 2), + ] + + for name, arr, exp_start, exp_end, n in test_cases: + with self.subTest(name): + start, end = mask_info_lib.find_bounds(np.array(arr)) + np.testing.assert_array_equal(start[:n], np.array(exp_start)[:n]) + np.testing.assert_array_equal(end[:n], np.array(exp_end)[:n]) + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/kernels/splash_attention/splash_attention_test_utils.py b/src/maxdiffusion/kernels/splash_attention/splash_attention_test_utils.py new file mode 100644 index 00000000..56eb913f --- /dev/null +++ b/src/maxdiffusion/kernels/splash_attention/splash_attention_test_utils.py @@ -0,0 +1,88 @@ +# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np + +from . 
import base + + +def test_device_matches(devices: list[str]) -> bool: + """Returns True if the test device matches any of the given devices.""" + return any(d.lower() in jax.devices()[0].device_kind.lower() for d in devices) + + +def thread_unsafe_test_class(): + """Decorator that marks a TestCase class as thread-hostile.""" + + def f(klass): + assert issubclass(klass, unittest.TestCase), type(klass) + klass.thread_hostile = True + return klass + + return f + + +class SplashAttentionTestCase(parameterized.TestCase): + """Base class for SplashAttention tests.""" + + INTERPRET = False + + def setUp(self): + if self.INTERPRET and not test_device_matches(["cpu"]): + self.skipTest("Interpret mode only supported on CPU") + + super().setUp() + + def _assert_array_equal(self, x, y, **kwargs): + if x is None or y is None: + self.assertIsNone(x) + self.assertIsNone(y) + return + + self.assertTrue(jnp.isfinite(x).all()) + self.assertTrue(jnp.isfinite(y).all()) + + if x.dtype == np.dtype(jnp.bfloat16): + x = x.astype(np.float32) + if y.dtype == np.dtype(jnp.bfloat16): + y = y.astype(np.float32) + + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + np.testing.assert_array_equal(x, y, **kwargs) + + def _assert_allclose(self, x, y, **kwargs): + if x.dtype == np.dtype(jnp.bfloat16): + x = x.astype(np.float32) + if y.dtype == np.dtype(jnp.bfloat16): + y = y.astype(np.float32) + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + np.testing.assert_allclose(x, y, **kwargs) + + +def create_segment_ids(seq_len: int, num_breaks: int = 2) -> base.SegmentIds: + break_indices = np.random.choice( + range(1, seq_len), num_breaks, replace=False + ) + idxs = np.zeros(seq_len, dtype=np.int32) + idxs[break_indices] = 1 + + idxs = np.cumsum(idxs, dtype=np.int32) + return base.SegmentIds(q=idxs, kv=idxs) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 04b3869f..c3c11101 100644 --- 
a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -78,7 +78,13 @@ def l2norm_pytree(x): def activate_profiler(config): if jax.process_index() == 0 and config.enable_profiler: - jax.profiler.start_trace(config.tensorboard_dir) + # If tensorboard_dir is GCS, write profiler traces locally instead + profiler_path = config.tensorboard_dir + if config.tensorboard_dir.startswith("gs://"): + profiler_path = "/tmp/profiler_traces" + os.makedirs(profiler_path, exist_ok=True) + max_logging.log(f"Profiler: saving traces locally to {profiler_path} (GCS paths not supported)") + jax.profiler.start_trace(profiler_path) def deactivate_profiler(config): @@ -86,6 +92,16 @@ def deactivate_profiler(config): jax.profiler.stop_trace() +def upload_profiler_traces(config): + """No-op for now - profiler traces are saved locally""" + if jax.process_index() == 0 and config.enable_profiler: + if config.tensorboard_dir.startswith("gs://"): + max_logging.log("Profiler traces saved to: /tmp/profiler_traces") + max_logging.log("You can download them manually or use: gsutil -m rsync -r /tmp/profiler_traces/ " + config.tensorboard_dir.rstrip("/") + "/") + else: + max_logging.log(f"Profiler traces saved to: {config.tensorboard_dir}") + + def initialize_summary_writer(config): return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None @@ -94,7 +110,6 @@ def close_summary_writer(summary_writer): if jax.process_index() == 0: summary_writer.close() - def _prepare_metrics_for_json(metrics, step, run_name): """Converts metric dictionary into json supported types (e.g. 
float)""" metrics_dict = {} diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 9b738c18..18271838 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -24,8 +24,9 @@ from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel +from maxdiffusion.kernels.splash_attention import splash_attention_mask as tokamax_splash_attention_mask +from maxdiffusion.kernels.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel +from maxdiffusion.kernels.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel from einops import rearrange from .. import common_types, max_logging @@ -56,10 +57,34 @@ CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH +def _coerce_tokamax_block_sizes(block_sizes): + # Tokamax requires fused bwd; convert if needed. + if getattr(block_sizes, "use_fused_bwd_kernel", False): + return block_sizes + + # Fall back if some fields are missing. 
+ bq = block_sizes.block_q + bkv = getattr(block_sizes, "block_kv", bq) + bkv_compute = getattr(block_sizes, "block_kv_compute", bkv) + bq_dkv = getattr(block_sizes, "block_q_dkv", bq) + bkv_dkv = getattr(block_sizes, "block_kv_dkv", bkv) + bkv_dkv_compute = getattr(block_sizes, "block_kv_dkv_compute", bkv_compute) + return splash_attention_kernel.BlockSizes( + block_q=bq, + block_kv=bkv, + block_kv_compute=bkv_compute, + block_q_dkv=bq_dkv, + block_kv_dkv=bkv_dkv, + block_kv_dkv_compute=bkv_dkv_compute, + block_q_dq=None, + block_kv_dq=None, + use_fused_bwd_kernel=True, + ) + + def _maybe_aqt_einsum(quant: Quant): return jnp.einsum if quant is None else quant.einsum() - def _check_attention_inputs(query: Array, key: Array, value: Array) -> None: """Check attention inputs.""" @@ -77,11 +102,8 @@ def _reshape_data_from_cudnn_flash(tensor): def _reshape_data_for_cudnn_flash(tensor, heads): # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) - if len(tensor.shape) == 3: - batch, seq, dim_head = tensor.shape - tensor = tensor.reshape(batch, seq, heads, dim_head // heads) - else: - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + batch, seq, heads_and_dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) return tensor @@ -177,20 +199,17 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): return tensor, kv_size, seq_len - -def convert_to_tokamax_splash_config( - block_sizes: BlockSizes, - q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - residual_checkpoint_name: str | None = None, - attn_logits_soft_cap: float | None = None, - fuse_reciprocal: bool = True, - use_base2_exp: bool = False, - 
max_logit_const: float | None = None, - interpret: bool = False, - dq_reduction_steps: int | None = None, -) -> tokamax_splash_attention_kernel.SplashConfig: +def convert_to_tokamax_splash_config( block_sizes: BlockSizes, + q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + residual_checkpoint_name: str | None = None, + attn_logits_soft_cap: float | None = None, + fuse_reciprocal: bool = True, + use_base2_exp: bool = False, + max_logit_const: float | None = None, + interpret: bool = False, + dq_reduction_steps: int | None = None) -> tokamax_splash_attention_kernel.SplashConfig: assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports fused bwd kernel." return tokamax_splash_attention_kernel.SplashConfig( block_q=block_sizes.block_q, @@ -199,7 +218,7 @@ def convert_to_tokamax_splash_config( block_q_dkv=block_sizes.block_q_dkv, block_kv_dkv=block_sizes.block_kv_dkv, block_kv_dkv_compute=block_sizes.block_kv_dkv_compute, - block_q_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, + block_q_dq= None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq, use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel, q_layout=q_layout, @@ -228,7 +247,6 @@ def _tpu_flash_attention( attention_kernel: str = "flash", mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, - attention_mask: jax.Array = None, ) -> jax.Array: """TPU Flash Attention""" @@ -239,9 +257,13 @@ def _tpu_flash_attention( kv_max_block_size = key.shape[1] else: kv_max_block_size = q_max_block_size + # ensure that for cross attention we override the block sizes. 
if flash_block_sizes and key.shape[1] == query.shape[1]: block_sizes = flash_block_sizes + use_tokamax = attention_kernel in ["tokamax_flash", "tokamax_ring"] + if use_tokamax: + block_sizes = _coerce_tokamax_block_sizes(flash_block_sizes) else: block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size block_sizes = splash_attention_kernel.BlockSizes( @@ -270,6 +292,7 @@ def _tpu_flash_attention( check_rep=False, ) def wrap_flash_attention(query, key, value): + uses_fused_kernel = block_sizes.use_fused_bwd_kernel block_q_sizes = ( block_sizes.block_q, @@ -303,37 +326,27 @@ def wrap_flash_attention(query, key, value): kv_padded_len = key.shape[2] kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) - - # If attention_mask is provided, apply it to kv_segment_ids - if attention_mask is not None: - mask_len = min(key_seq_len, attention_mask.shape[1]) - kv_mask_for_batch = attention_mask[0, :mask_len] # (mask_len,) - # If key_seq_len > mask_len, pad the mask with 1s (assume remaining tokens are valid) - if key_seq_len > mask_len: - extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) - kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) # (key_seq_len,) - # Pad to kv_padded_len - if kv_padded_len > key_seq_len: - padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) - kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) # (kv_padded_len,) - else: - kv_mask_padded = kv_mask_for_batch - # Both are (kv_padded_len,) - element-wise multiplication - kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) - segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. 
if attention_kernel == "tokamax_flash": - mask = tokamax_splash_attention_mask.FullMask( - _shape=(query.shape[2], key.shape[2]), - ) + mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),) splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( mask=mask, q_seq_shards=1, # the sizes of the axis is sharding over seq_len config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), - save_residuals=True if attention_kernel == "ring" else False, + save_residuals=False, + ) + elif attention_kernel == "tokamax_ring": + mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),) + splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( + mask=mask, + is_mqa=False, + config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), + save_residuals=False, + ring_axis="context", + rotate_segment_ids=False, # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids ) else: splash_kernel = splash_attention_kernel.make_splash_mha( @@ -341,14 +354,16 @@ def wrap_flash_attention(query, key, value): head_shards=1, # the sizes of the axis is sharding over heads q_seq_shards=1, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, - save_residuals=True if attention_kernel == "ring" else False, - residual_checkpoint_name=residual_checkpoint_name, + save_residuals=True if "ring" in attention_kernel else False, + residual_checkpoint_name=residual_checkpoint_name ) + + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) if not mask_padding_tokens: segment_ids = None - if attention_kernel in ["flash", "tokamax_flash"]: + if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]: attention_output = vmapped_splash(query, key, value, segment_ids) else: if num_context_shards > 1: @@ -384,13 +399,12 @@ def 
ring_scan_body(carry, _): return (m, l, o, k_next, v_next), None initial_carry = (m, l, o, k1, v1) - (m_final, l_final, o_final, _, _), _ = jax.lax.scan( - ring_scan_body, initial_carry, None, length=num_context_shards - 1 - ) + (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_context_shards - 1) attention_output = o_final / l_final[..., None] else: raise ValueError("ring attention requires context > 1") + return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) devices_in_data_context = mesh.shape["data"] * mesh.shape["context"] @@ -491,12 +505,24 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m key = _reshape_data_for_cudnn_flash(key, heads) value = _reshape_data_for_cudnn_flash(value, heads) - axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD, D_KV)) - query = jax.lax.with_sharding_constraint(query, axis_names) - key = jax.lax.with_sharding_constraint(key, axis_names) - value = jax.lax.with_sharding_constraint(value, axis_names) + cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) + axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) + + query = nn.with_logical_constraint(query, axis_names) + key = nn.with_logical_constraint(key, axis_names) + value = nn.with_logical_constraint(value, axis_names) - out = dpa_layer(query, key, value, mask=None) + @functools.partial( + shard_map.shard_map, + mesh=mesh, + in_specs=(axis_names, axis_names, axis_names), + out_specs=axis_names, + check_rep=False, + ) + def wrap_flash_attention(query, key, value): + return jax.vmap(dpa_layer)(query, key, value, mask=None) + + out = wrap_flash_attention(query, key, value) return _reshape_data_from_cudnn_flash(out) @@ -520,7 +546,6 @@ def _apply_attention( dpa_layer: Callable, mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, - attention_mask: Array = None, ): """Routes to different attention kernels.""" _check_attention_inputs(query, 
key, value) @@ -553,21 +578,12 @@ def _apply_attention( attention_kernel, mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, - attention_mask=attention_mask, ) - elif attention_kernel == "ring": + elif "ring" in attention_kernel: return _tpu_flash_attention( - query, - key * scale, - value, - heads, - mesh, - axis_names_q, - axis_names_kv, - flash_block_sizes, - dtype, - attention_kernel, + query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, mask_padding_tokens=mask_padding_tokens, + residual_checkpoint_name=residual_checkpoint_name, ) elif attention_kernel == "cudnn_flash_te": return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) @@ -575,6 +591,7 @@ def _apply_attention( raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") + def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" num_kv, num_heads, k_features = key.shape[-3:] @@ -679,7 +696,6 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) -# New Class for Wan I2V class NNXSimpleFeedForward(nnx.Module): def __init__( @@ -750,25 +766,7 @@ def __init__( ): self.dpa_layer = None if attention_kernel == "cudnn_flash_te": - from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error - - jax.config.update("jax_use_shardy_partitioner", False) - - dpa_layer = DotProductAttention( - head_dim=dim_head, - num_attention_heads=heads, - num_gqa_groups=heads, - attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' - attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' - # attention_dropout=self.dropout_rate, - dropout_rng_name="aqt", - dtype=dtype, - 
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - scale_factor=scale, - transpose_batch_sequence=False, - ) - variables = {} - self.dpa_layer = functools.partial(dpa_layer.apply, variables) + raise NotImplementedError(f"{self} has not been tested with {attention_kernel}") self.mesh = mesh self.scale = scale @@ -787,7 +785,7 @@ def __init__( self.mask_padding_tokens = mask_padding_tokens self.residual_checkpoint_name = residual_checkpoint_name - def apply_attention(self, query: Array, key: Array, value: Array, attention_mask: Array = None): + def apply_attention(self, query: Array, key: Array, value: Array): return _apply_attention( query=query, key=key, @@ -808,7 +806,6 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask dpa_layer=self.dpa_layer, mask_padding_tokens=self.mask_padding_tokens, residual_checkpoint_name=self.residual_checkpoint_name, - attention_mask=attention_mask, ) @@ -833,9 +830,7 @@ def setup(self): if self.attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error - jax.config.update("jax_use_shardy_partitioner", False) - - dpa_layer = DotProductAttention( + self.dpa_layer = DotProductAttention( head_dim=self.dim_head, num_attention_heads=self.heads, num_gqa_groups=self.heads, @@ -849,10 +844,8 @@ def setup(self): scale_factor=self.scale, transpose_batch_sequence=False, ) - variables = {} - self.dpa_layer = functools.partial(dpa_layer.apply, variables) - def apply_attention(self, query: Array, key: Array, value: Array, attention_mask: Array = None): + def apply_attention(self, query: Array, key: Array, value: Array): return _apply_attention( query=query, key=key, @@ -871,7 +864,6 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, - attention_mask=attention_mask, ) @@ -906,9 
+898,12 @@ def __init__( mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, enable_jax_named_scopes: bool = False, - added_kv_proj_dim: Optional[int] = None, # New for I2V - image_seq_len: Optional[int] = None, # New for I2V + added_kv_proj_dim: Optional[int] = None, + image_seq_len: Optional[int] = None, ): + if attention_kernel == "cudnn_flash_te": + raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") + if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") self.dim_head = dim_head @@ -928,9 +923,8 @@ def __init__( else: axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) - self.added_kv_proj_dim = added_kv_proj_dim # New for I2V - self.image_seq_len = image_seq_len # New for I2V - + if attention_kernel == "tokamax_ring" and not is_self_attention: + attention_kernel = "tokamax_flash" # do not use ring attention for cross attention self.attention_op = NNXAttentionOp( mesh=mesh, attention_kernel=attention_kernel, @@ -1011,7 +1005,7 @@ def __init__( ), ) - self.drop_out = nnx.Dropout(dropout, deterministic=False) + self.drop_out = nnx.Dropout(dropout) self.norm_q = nnx.data(None) self.norm_k = nnx.data(None) @@ -1039,47 +1033,6 @@ def __init__( param_dtype=weights_dtype, ) - # New layers for I2V image conditioning - self.add_k_proj = nnx.data(None) - self.add_v_proj = nnx.data(None) - self.norm_added_k = nnx.data(None) - if self.added_kv_proj_dim is not None: - self.add_k_proj = nnx.Linear( - self.added_kv_proj_dim, - self.inner_dim, - rngs=rngs, - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - bias_init=nnx.with_partitioning( - nnx.initializers.zeros, - ("embed",), - ), - ) - self.add_v_proj = nnx.Linear( - self.added_kv_proj_dim, - self.inner_dim, - rngs=rngs, - dtype=dtype, - 
param_dtype=weights_dtype, - precision=precision, - bias_init=nnx.with_partitioning( - nnx.initializers.zeros, - ("embed",), - ), - ) - self.norm_added_k = nnx.RMSNorm( - num_features=self.inner_dim, - rngs=rngs, - epsilon=eps, - dtype=dtype, - param_dtype=weights_dtype, - scale_init=nnx.with_partitioning( - nnx.initializers.ones, - ("norm",), - ), - ) - def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]: # 1. Extract cos and sin, keeping them in native bfloat16 cos = jnp.real(freqs_cis).astype(xq.dtype) @@ -1123,126 +1076,45 @@ def __call__( hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) dtype = hidden_states.dtype - is_self_attention = encoder_hidden_states is None if encoder_hidden_states is None: encoder_hidden_states = hidden_states - is_i2v_cross_attention = self.added_kv_proj_dim is not None and not is_self_attention - - if not is_i2v_cross_attention: - with jax.named_scope("query_proj"): - query_proj = self.query(hidden_states) - with jax.named_scope("key_proj"): - key_proj = self.key(encoder_hidden_states) - with jax.named_scope("value_proj"): - value_proj = self.value(encoder_hidden_states) - - if self.qk_norm: - with self.conditional_named_scope("attn_q_norm"): - query_proj = self.norm_q(query_proj) - with self.conditional_named_scope("attn_k_norm"): - key_proj = self.norm_k(key_proj) - - if rotary_emb is not None: - with self.conditional_named_scope("attn_rope"): - query_proj = _unflatten_heads(query_proj, self.heads) - key_proj = _unflatten_heads(key_proj, self.heads) - value_proj = _unflatten_heads(value_proj, self.heads) - # output of _unflatten_heads Batch, heads, seq_len, head_dim - query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - - query_proj = checkpoint_name(query_proj, "query_proj") - key_proj = checkpoint_name(key_proj, "key_proj") - value_proj 
= checkpoint_name(value_proj, "value_proj") - - with jax.named_scope("apply_attention"): - attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - - else: - # NEW PATH for I2V CROSS-ATTENTION - with self.conditional_named_scope("proj_query"): - query_proj_raw = self.query(hidden_states) - - # Image embeddings are padded to multiples of 128 for TPU flash attention - # Calculate the padded length to correctly split image and text embeddings - if self.added_kv_proj_dim is not None: - alignment = 128 - if self.image_seq_len is not None: - image_seq_len_actual = self.image_seq_len - else: - image_seq_len_actual = 257 - padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment # 257 -> 384 - - if encoder_attention_mask is None: - padded_img_len = image_seq_len_actual - - encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :] - encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :] - - # Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention - # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384 - if encoder_attention_mask is not None: - encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len] - else: - # Fallback: no mask means treat all as valid (for dot product attention) - encoder_attention_mask_img = None - else: - # If no image_seq_len is specified, treat all as text - encoder_hidden_states_img = None - encoder_hidden_states_text = encoder_hidden_states - encoder_attention_mask_img = None - - if self.qk_norm: - with self.conditional_named_scope("attn_q_norm"): - query_proj_text = self.norm_q(query_proj_raw) - else: - query_proj_text = query_proj_raw - - # Text K/V - with self.conditional_named_scope("proj_key"): - key_proj_text = self.key(encoder_hidden_states_text) - if self.qk_norm: - with self.conditional_named_scope("attn_k_norm"): - key_proj_text = self.norm_k(key_proj_text) - with 
self.conditional_named_scope("proj_value"): - value_proj_text = self.value(encoder_hidden_states_text) - - # Image K/V (only if image embeddings are present) - if encoder_hidden_states_img is not None: - with self.conditional_named_scope("add_proj_k"): - key_proj_img = self.add_k_proj(encoder_hidden_states_img) - with self.conditional_named_scope("norm_add_k"): - key_proj_img = self.norm_added_k(key_proj_img) - with self.conditional_named_scope("add_proj_v"): - value_proj_img = self.add_v_proj(encoder_hidden_states_img) - query_proj_img = query_proj_raw - # Check norm_added_k too - # Checkpointing - query_proj_text = checkpoint_name(query_proj_text, "query_proj") - key_proj_text = checkpoint_name(key_proj_text, "key_proj_text") - value_proj_text = checkpoint_name(value_proj_text, "value_proj_text") - key_proj_img = checkpoint_name(key_proj_img, "key_proj_img") - value_proj_img = checkpoint_name(value_proj_img, "value_proj_img") - query_proj_img = checkpoint_name(query_proj_img, "query_proj_img") - - # Attention - tensors are (B, S, D) - with self.conditional_named_scope("cross_attn_text_apply"): - attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) - with self.conditional_named_scope("cross_attn_img_apply"): - # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens - attn_output_img = self.attention_op.apply_attention( - query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img - ) - - attn_output = attn_output_text + attn_output_img - else: - # No image embeddings, only text cross-attention - query_proj_text = checkpoint_name(query_proj_text, "query_proj") - key_proj_text = checkpoint_name(key_proj_text, "key_proj_text") - value_proj_text = checkpoint_name(value_proj_text, "value_proj_text") - - with self.conditional_named_scope("cross_attn_text_apply"): - attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) + with 
jax.named_scope("query_proj"): + query_proj = self.query(hidden_states) + with jax.named_scope("key_proj"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("value_proj"): + value_proj = self.value(encoder_hidden_states) + + if self.qk_norm: + with self.conditional_named_scope("attn_q_norm"): + query_proj = self.norm_q(query_proj) + with self.conditional_named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) + + if rotary_emb is not None: + with self.conditional_named_scope("attn_rope"): + axis_names_rope = nn.logical_to_mesh_axes((None, None, LENGTH, None)) + rotary_emb = jax.lax.with_sharding_constraint(rotary_emb, axis_names_rope) + query_proj = _unflatten_heads(query_proj, self.heads) + key_proj = _unflatten_heads(key_proj, self.heads) + value_proj = _unflatten_heads(value_proj, self.heads) + + # Enforce sequence parallelism on the new axis 2 (LENGTH) before doing the ROPE math + axis_names_qkv = nn.logical_to_mesh_axes((BATCH, HEAD, LENGTH, D_KV)) + query_proj = jax.lax.with_sharding_constraint(query_proj, axis_names_qkv) + key_proj = jax.lax.with_sharding_constraint(key_proj, axis_names_qkv) + value_proj = jax.lax.with_sharding_constraint(value_proj, axis_names_qkv) + + # output of _unflatten_heads Batch, heads, seq_len, head_dim + query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + + query_proj = checkpoint_name(query_proj, "query_proj") + key_proj = checkpoint_name(key_proj, "key_proj") + value_proj = checkpoint_name(value_proj, "value_proj") + + with jax.named_scope("apply_attention"): + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) attn_output = attn_output.astype(dtype=dtype) attn_output = checkpoint_name(attn_output, "attn_output") @@ -1363,6 +1235,7 @@ def setup(self): ) def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): + qkv_proj = self.qkv(hidden_states) B, L = hidden_states.shape[:2] H, D, K = self.heads, 
qkv_proj.shape[-1] // (self.heads * 3), 3 @@ -1374,6 +1247,7 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=Non key_proj = self.key_norm(key_proj) if encoder_hidden_states is not None: + encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states) B, L = encoder_hidden_states.shape[:2] H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3 @@ -1467,6 +1341,7 @@ class FlaxAttention(nn.Module): quant: Quant = None def setup(self): + if self.attention_kernel == "flash" and self.mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") inner_dim = self.dim_head * self.heads diff --git a/src/maxdiffusion/models/vae_flax.py b/src/maxdiffusion/models/vae_flax.py index 042ec275..86ec80b7 100644 --- a/src/maxdiffusion/models/vae_flax.py +++ b/src/maxdiffusion/models/vae_flax.py @@ -30,6 +30,7 @@ from .modeling_flax_utils import FlaxModelMixin + @flax.struct.dataclass class FlaxDecoderOutput(BaseOutput): """ @@ -933,6 +934,8 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r return FlaxDecoderOutput(sample=sample) + + class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution): pass diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 0328f6ac..9ecbf29c 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -19,12 +19,15 @@ import flax import jax import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P from jax import tree_util from flax import nnx -from ...configuration_utils import ConfigMixin -from ..modeling_flax_utils import FlaxModelMixin, get_activation -from ... 
import common_types -from ..vae_flax import ( + +# Absolute imports based on maxdiffusion root structure +from maxdiffusion.configuration_utils import ConfigMixin +from maxdiffusion.models.modeling_flax_utils import FlaxModelMixin, get_activation +from maxdiffusion import common_types +from maxdiffusion.models.vae_flax import ( FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput, @@ -67,85 +70,83 @@ def __eq__(self, other): class WanCausalConv3d(nnx.Module): + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + use_bias: bool = True, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + ): + self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") + self.stride = _canonicalize_tuple(stride, 3, "stride") + padding_tuple = _canonicalize_tuple(padding, 3, "padding") + + self._causal_padding = ( + (0, 0), + (2 * padding_tuple[0], 0), + (padding_tuple[1], padding_tuple[1]), + (padding_tuple[2], padding_tuple[2]), + (0, 0), + ) + self._depth_padding_before = self._causal_padding[1][0] + self.mesh = mesh + + # Weight sharding (Kernel is sharded along output channels) + num_fsdp_devices = mesh.shape["vae_spatial"] + kernel_sharding = (None, None, None, None, None) + if out_channels % num_fsdp_devices == 0: + kernel_sharding = (None, None, None, None, "vae_spatial") + + self.conv = nnx.Conv( + in_features=in_channels, + out_features=out_channels, + kernel_size=self.kernel_size, + strides=self.stride, + use_bias=use_bias, + padding="VALID", + rngs=rngs, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), kernel_sharding), + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, + ) - def __init__( - self, - rngs: nnx.Rngs, # rngs are 
required for initializing parameters, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, - padding: Union[int, Tuple[int, int, int]] = 0, - use_bias: bool = True, - mesh: jax.sharding.Mesh = None, - dtype: jnp.dtype = jnp.float32, - weights_dtype: jnp.dtype = jnp.float32, - precision: jax.lax.Precision = None, - ): - self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size") - self.stride = _canonicalize_tuple(stride, 3, "stride") - padding_tuple = _canonicalize_tuple(padding, 3, "padding") # (D, H, W) padding amounts - - self._causal_padding = ( - (0, 0), # Batch dimension - no padding - (2 * padding_tuple[0], 0), # Depth dimension - causal padding (pad only before) - (padding_tuple[1], padding_tuple[1]), # Height dimension - symmetric padding - (padding_tuple[2], padding_tuple[2]), # Width dimension - symmetric padding - (0, 0), # Channel dimension - no padding - ) - - # Store the amount of padding needed *before* the depth dimension for caching logic - self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] - - self.mesh = mesh - # Set sharding dynamically based on out_channels. 
- num_context_axis_devices = mesh.shape["context"] - kernel_sharding = (None, None, None, None, None) - if out_channels % num_context_axis_devices == 0: - kernel_sharding = (None, None, None, None, "conv_out") - - self.conv = nnx.Conv( - in_features=in_channels, - out_features=out_channels, - kernel_size=self.kernel_size, - strides=self.stride, - use_bias=use_bias, - padding="VALID", # Handle padding manually - rngs=rngs, - kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), kernel_sharding), - dtype=dtype, - param_dtype=weights_dtype, - precision=precision, - ) - - def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: - current_padding = list(self._causal_padding) # Mutable copy - padding_needed = self._depth_padding_before - - if cache_x is not None and padding_needed > 0: - # Ensure cache has same spatial/channel dims, potentially different depth - assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch" - cache_len = cache_x.shape[1] - x = jnp.concatenate([cache_x, x], axis=1) # Concat along depth (D) - - padding_needed -= cache_len - if padding_needed < 0: - # Cache longer than needed padding, trim from start - x = x[:, -padding_needed:, ...] 
- current_padding[1] = (0, 0) # No explicit padding needed now - else: - # Update depth padding needed - current_padding[1] = (padding_needed, 0) - - # Apply padding if any dimension requires it - padding_to_apply = tuple(current_padding) - if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): - x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) - else: - x_padded = x + def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> jax.Array: + # Sharding Width (index 3) + # Spec: (Batch, Time, Height, Width, Channels) + spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) + x = jax.lax.with_sharding_constraint(x, spatial_sharding) + + current_padding = list(self._causal_padding) + padding_needed = self._depth_padding_before + + if cache_x is not None and padding_needed > 0: + assert cache_x.shape[0] == x.shape[0] and cache_x.shape[2:] == x.shape[2:] + cache_len = cache_x.shape[1] + x = jnp.concatenate([cache_x, x], axis=1) + + padding_needed -= cache_len + if padding_needed < 0: + x = x[:, -padding_needed:, ...] 
+ current_padding[1] = (0, 0) + else: + current_padding[1] = (padding_needed, 0) + + padding_to_apply = tuple(current_padding) + if any(p > 0 for dim_pads in padding_to_apply for p in dim_pads): + x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) + else: + x_padded = x - out = self.conv(x_padded) - return out + out = self.conv(x_padded) + return out class WanRMS_norm(nnx.Module): @@ -759,7 +760,6 @@ def __init__( precision=precision, ) - @nnx.jit(static_argnames="feat_idx") def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): if feat_cache is not None: idx = feat_idx @@ -908,7 +908,6 @@ def __init__( precision=precision, ) - @nnx.jit(static_argnames="feat_idx") def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): if feat_cache is not None: idx = feat_idx @@ -946,49 +945,33 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): class AutoencoderKLWanCache: - - def __init__(self, module): - self.module = module - - def _count_conv3d(module): - count = 0 - node_types = nnx.graph.iter_graph([module]) - for _, value in node_types: - if isinstance(value, WanCausalConv3d): - count += 1 - return count - - self._conv_num = _count_conv3d(self.module.decoder) - self._enc_conv_num = _count_conv3d(self.module.encoder) - self.init_cache() - - def init_cache(self): - """Resets cache dictionaries and indices""" - self._feat_map = (None,) * self._conv_num - # cache encode - self._enc_feat_map = (None,) * self._enc_conv_num - + def __init__(self, module): + self.module = module + def _count_conv3d(m): + count = 0 + for _, value in nnx.graph.iter_graph([m]): + if isinstance(value, WanCausalConv3d): + count += 1 + return count + self._conv_num = _count_conv3d(self.module.decoder) + self._enc_conv_num = _count_conv3d(self.module.encoder) + self.init_cache() + + def init_cache(self): + self._feat_map = (None,) * self._conv_num + self._enc_feat_map = (None,) * self._enc_conv_num def _wan_cache_flatten(cache): - return 
(cache._feat_map, cache._enc_feat_map), (cache._conv_num, cache._enc_conv_num) - + return (cache._feat_map, cache._enc_feat_map), (cache._conv_num, cache._enc_conv_num) def _wan_cache_unflatten(aux, children): - conv_num, enc_conv_num = aux - feat_map, enc_feat_map = children - # Create a dummy object or one without module reference for JIT internal use - # We can't easily reconstruct 'module' but we don't need it for init_cache anymore - # if we store counts in aux. - # However, __init__ expects module. - # We will bypass __init__ for unflattening. - obj = AutoencoderKLWanCache.__new__(AutoencoderKLWanCache) - obj._conv_num = conv_num - obj._enc_conv_num = enc_conv_num - obj._feat_map = feat_map - obj._enc_feat_map = enc_feat_map - obj.module = None # module is not needed inside the trace for the cache logic now - return obj - + conv_num, enc_conv_num = aux + feat_map, enc_feat_map = children + obj = AutoencoderKLWanCache.__new__(AutoencoderKLWanCache) + obj._conv_num, obj._enc_conv_num = conv_num, enc_conv_num + obj._feat_map, obj._enc_feat_map = feat_map, enc_feat_map + obj.module = None + return obj tree_util.register_pytree_node(AutoencoderKLWanCache, _wan_cache_flatten, _wan_cache_unflatten) @@ -1102,7 +1085,9 @@ def __init__( weights_dtype=weights_dtype, precision=precision, ) + self.mesh = mesh + @nnx.jit def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): feat_cache.init_cache() if x.shape[-1] != 3: @@ -1114,17 +1099,58 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): iter_ = 1 + (t - 1) // 4 enc_feat_map = feat_cache._enc_feat_map - for i in range(iter_): - enc_conv_idx = 0 - if i == 0: - out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx) - else: - out_, enc_feat_map, enc_conv_idx = self.encoder( - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], + spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) + + # First iteration (i=0): 
size 1 + chunk_0 = x[:, :1, ...] + out_0, enc_feat_map, _ = self.encoder( + chunk_0, + feat_cache=enc_feat_map, + feat_idx=0 + ) + out_0 = jax.lax.with_sharding_constraint(out_0, spatial_sharding) + + if iter_ > 1: + # We must adjust enc_feat_map from None/'Rep'/'zeros' for scan shapes. + # By running chunk 1 outside the scan, the PyTree shapes will reach their stable state. + chunk_1 = x[:, 1:5, ...] + out_1, enc_feat_map, _ = self.encoder( + chunk_1, feat_cache=enc_feat_map, - feat_idx=enc_conv_idx, + feat_idx=0 ) - out = jnp.concatenate([out, out_], axis=1) + out_1 = jax.lax.with_sharding_constraint(out_1, spatial_sharding) + out_list = [out_0, out_1] + + if iter_ > 2: + # Prepare the remaining chunks (each size 4) to be scanned over + # x_rest shape: (B, (iter_-2)*4, H, W, C) + x_rest = x[:, 5:, ...] + # Reshape to (iter_-2, B, 4, H, W, C) for jax.lax.scan + x_scannable = x_rest.reshape(x_rest.shape[0], iter_ - 2, 4, x_rest.shape[2], x_rest.shape[3], x_rest.shape[4]) + x_scannable = jnp.transpose(x_scannable, (1, 0, 2, 3, 4, 5)) + + def scan_fn(carry, chunk): + current_feat_map = carry + out_chunk, next_feat_map, _ = self.encoder( + chunk, + feat_cache=current_feat_map, + feat_idx=0 + ) + out_chunk = jax.lax.with_sharding_constraint(out_chunk, spatial_sharding) + return next_feat_map, out_chunk + + enc_feat_map, out_rest = jax.lax.scan(scan_fn, enc_feat_map, x_scannable) + # out_rest shape: (iter_-2, B, T', H, W, C) -> transpose back + out_rest = jnp.transpose(out_rest, (1, 0, 2, 3, 4, 5)) + # reshape to (B, (iter_-2)*T', H, W, C) + out_rest = out_rest.reshape(out_rest.shape[0], -1, out_rest.shape[3], out_rest.shape[4], out_rest.shape[5]) + out_list.append(out_rest) + + out = jnp.concatenate(out_list, axis=1) + out = jax.lax.with_sharding_constraint(out, spatial_sharding) + else: + out = out_0 # Update back to the wrapper object if needed, but for result we use local vars feat_cache._enc_feat_map = enc_feat_map @@ -1145,6 +1171,7 @@ def encode( return 
(posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) + @nnx.jit def _decode( self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True ) -> Union[FlaxDecoderOutput, jax.Array]: @@ -1153,30 +1180,71 @@ def _decode( x = self.post_quant_conv(z) dec_feat_map = feat_cache._feat_map - - for i in range(iter_): - conv_idx = 0 - if i == 0: - out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) - else: - out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) - - # This is to bypass an issue where frame[1] should be frame[2] and vise versa. - # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. - # Most likely due to an incorrect reshaping in the decoder. - fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :] - # When batch_size is 0, expand batch dim for concatenation - # else, expand frame dim for concatenation so that batch dim stays intact. 
- axis = 0 - if fm1.shape[0] > 1: - axis = 1 - - if len(fm1.shape) == 4: - fm1 = jnp.expand_dims(fm1, axis=axis) - fm2 = jnp.expand_dims(fm2, axis=axis) - fm3 = jnp.expand_dims(fm3, axis=axis) - fm4 = jnp.expand_dims(fm4, axis=axis) - out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) + # NamedSharding for the Width axis (axis 3) + spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) + + # First chunk (i=0) + chunk_in_0 = jax.lax.with_sharding_constraint(x[:, 0:1, ...], spatial_sharding) + out_0, dec_feat_map, _ = self.decoder( + chunk_in_0, + feat_cache=dec_feat_map, + feat_idx=0 + ) + out_0 = jax.lax.with_sharding_constraint(out_0, spatial_sharding) + + if iter_ > 1: + # Run chunk 1 outside scan to properly form the cache shape + chunk_in_1 = jax.lax.with_sharding_constraint(x[:, 1:2, ...], spatial_sharding) + out_chunk_1, dec_feat_map, _ = self.decoder( + chunk_in_1, + feat_cache=dec_feat_map, + feat_idx=0 + ) + out_chunk_1 = jax.lax.with_sharding_constraint(out_chunk_1, spatial_sharding) + + # Frame re-sync logic for chunk 1 + fm1, fm2, fm3, fm4 = out_chunk_1[:, 0, ...], out_chunk_1[:, 1, ...], out_chunk_1[:, 2, ...], out_chunk_1[:, 3, ...] + axis = 1 if fm1.shape[0] > 1 else 0 + fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]] + out_1 = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1) + + out_list = [out_0, out_1] + + if iter_ > 2: + x_rest = x[:, 2:, ...] 
+ # Reshape for scan: (iter_-2, B, 1, H, W, C) + x_scannable = jnp.transpose(x_rest, (1, 0, 2, 3, 4)) + x_scannable = jnp.expand_dims(x_scannable, axis=2) + + def scan_fn(carry, chunk_in): + current_feat_map = carry + chunk_in = jax.lax.with_sharding_constraint(chunk_in, spatial_sharding) + out_chunk, next_feat_map, _ = self.decoder( + chunk_in, + feat_cache=current_feat_map, + feat_idx=0 + ) + out_chunk = jax.lax.with_sharding_constraint(out_chunk, spatial_sharding) + + # Frame re-sync logic + fm1, fm2, fm3, fm4 = out_chunk[:, 0, ...], out_chunk[:, 1, ...], out_chunk[:, 2, ...], out_chunk[:, 3, ...] + axis = 1 if fm1.shape[0] > 1 else 0 + fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]] + new_chunk = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1) + + return next_feat_map, new_chunk + + dec_feat_map, out_rest = jax.lax.scan(scan_fn, dec_feat_map, x_scannable) + + # out_rest is (iter_-2, B, 4, H, W, C) -> transpose back + out_rest = jnp.transpose(out_rest, (1, 0, 2, 3, 4, 5)) + out_rest = out_rest.reshape(out_rest.shape[0], -1, out_rest.shape[3], out_rest.shape[4], out_rest.shape[5]) + out_list.append(out_rest) + + out = jnp.concatenate(out_list, axis=1) + out = jax.lax.with_sharding_constraint(out, spatial_sharding) + else: + out = out_0 feat_cache._feat_map = dec_feat_map diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 86c9f9c2..6a3902d4 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -17,6 +17,7 @@ from functools import partial from maxdiffusion.image_processor import PipelineImageInput import numpy as np +import math import jax import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec as P @@ -219,6 +220,7 @@ def __init__( config: HyperParameters, image_processor: Optional[CLIPImageProcessor] = None, image_encoder: Optional[FlaxCLIPVisionModel] = None, + **kwargs, ): 
self.tokenizer = tokenizer self.text_encoder = text_encoder @@ -233,6 +235,9 @@ def __init__( self.image_processor = image_processor self.image_encoder = image_encoder + self.vae_mesh = kwargs.get("vae_mesh", mesh) + self.vae_logical_axis_rules = kwargs.get("vae_logical_axis_rules", config.logical_axis_rules) + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -268,7 +273,7 @@ def load_image_encoder(cls, config: HyperParameters): return image_processor, image_encoder @classmethod - def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, vae_logical_axis_rules: tuple = None): def create_model(rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( config.pretrained_model_name_or_path, @@ -287,7 +292,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters): # 2. retrieve the state shardings, mapping logical names to mesh axis names. 
logical_state_spec = nnx.get_partition_spec(state) - logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_rules = vae_logical_axis_rules if vae_logical_axis_rules is not None else config.logical_axis_rules + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, logical_rules) logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) params = state.to_pure_dict() state = dict(nnx.to_flat_state(state)) @@ -569,7 +575,7 @@ def _denormalize_latents(self, latents: jax.Array) -> jax.Array: def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: """Decodes latents to video frames and postprocesses.""" - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + with self.vae_mesh, nn_partitioning.axis_rules(self.vae_logical_axis_rules): video = self.vae.decode(latents, self.vae_cache)[0] video = jnp.transpose(video, (0, 4, 1, 2, 3)) @@ -581,22 +587,50 @@ def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: def _create_common_components(cls, config, vae_only=False, i2v=False): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) + + vae_spatial = getattr(config, "vae_spatial", -1) + total_devices = math.prod(devices_array.shape) + + if vae_spatial <= 0: + dp_size = mesh.shape.get("data", 1) + if dp_size == -1 or dp_size == 0: + dp_size = 1 + vae_spatial = (2 * total_devices) // dp_size + + assert total_devices % vae_spatial == 0, f"total devices ({total_devices}) must be a multiple of vae_spatial ({vae_spatial})" + + flat_devices = devices_array.flatten() + vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial) + + vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) + vae_mesh.vae_spatial_axis_name = "vae_spatial" + max_logging.log(f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of 
{vae_spatial}.") + + # logical axis rules for VAE encoding/decoding + vae_logical_axis_rules = ( + ("activation_batch", "redundant"), + ("activation_length", "vae_spatial"), + ("activation_heads", None), + ("activation_kv_length", None), + ("embed", None), + ("heads", None), + ("norm", None), + ("conv_batch", "redundant"), + ("out_channels", "vae_spatial"), + ("conv_out", "vae_spatial") + ) + rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + with vae_mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=vae_mesh, rngs=rngs, config=config, vae_logical_axis_rules=vae_logical_axis_rules) components = { - "vae": wan_vae, - "vae_cache": vae_cache, - "devices_array": devices_array, - "rngs": rngs, - "mesh": mesh, - "tokenizer": None, - "text_encoder": None, - "scheduler": None, - "scheduler_state": None, + "vae": wan_vae, "vae_cache": vae_cache, + "devices_array": devices_array, "rngs": rngs, "mesh": mesh, "vae_mesh": vae_mesh, + "vae_logical_axis_rules": vae_logical_axis_rules, + "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None, "image_processor": None, "image_encoder": None, } diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 976f0f04..d0aae14e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -47,16 +47,18 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t ) pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - transformer=transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], 
- mesh=common_components["mesh"], - config=config, + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + vae_mesh=common_components["vae_mesh"], + vae_logical_axis_rules=common_components["vae_logical_axis_rules"], + config=config, ) return pipeline, transformer diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index b8f818e3..2ff7019e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -73,6 +73,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t scheduler_state=common_components["scheduler_state"], devices_array=common_components["devices_array"], mesh=common_components["mesh"], + vae_mesh=common_components["vae_mesh"], + vae_logical_axis_rules=common_components["vae_logical_axis_rules"], config=config, ) return pipeline, low_noise_transformer, high_noise_transformer diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 0622ec79..8c89e3fa 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -61,6 +61,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t scheduler_state=common_components["scheduler_state"], devices_array=common_components["devices_array"], mesh=common_components["mesh"], + vae_mesh=common_components["vae_mesh"], + vae_logical_axis_rules=common_components["vae_logical_axis_rules"], config=config, ) return pipeline, transformer diff --git 
a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 65e78674..4ad8c514 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -79,6 +79,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t scheduler_state=common_components["scheduler_state"], devices_array=common_components["devices_array"], mesh=common_components["mesh"], + vae_mesh=common_components["vae_mesh"], + vae_logical_axis_rules=common_components["vae_logical_axis_rules"], config=config, ) return pipeline, low_noise_transformer, high_noise_transformer diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index ebcdeaea..6571ca37 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -50,7 +50,6 @@ def _validate_training_model_name(model_name: str | None): f"Invalid config.model_name '{model_name}' for training. Allowed values: {sorted(_ALLOWED_TRAINING_MODEL_NAMES)}" ) - def string_to_bool(s: str) -> bool: if s.lower() == "true": return True @@ -200,9 +199,9 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. - if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]: + if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]: max_logging.log( - f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set." + f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set." 
) logical_axis_rules = list(raw_keys["logical_axis_rules"]) max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") @@ -213,12 +212,12 @@ def user_init(raw_keys): logical_axis_rules.append(q_seq_sharding) if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) - if raw_keys["attention"] == "ring": + if "ring" in raw_keys["attention"]: for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: if ring_attention_axis_rule not in logical_axis_rules: max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") new_rules.append(ring_attention_axis_rule) - else: # attention =flash but sequence parallel sharding requested for both self and cross attention + else: # attention contains 'flash' but sequence parallel sharding requested for both self and cross attention for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES: if seq_parallel_axis_rule not in logical_axis_rules: max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}") @@ -256,6 +255,13 @@ def user_init(raw_keys): raw_keys["global_batch_size_to_train_on"], ) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) + if raw_keys.get("vae_spatial", -1) == -1: + total_device = len(jax.devices()) + dp = raw_keys.get("ici_data_parallelism", 1) * raw_keys.get("dcn_data_parallelism", 1) + if dp == -1 or dp == 0: + dp = 1 + raw_keys["vae_spatial"] = (total_device * 2) // dp + def get_num_slices(raw_keys): if int(raw_keys["compile_topology_num_slices"]) > 0: diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 4d54525d..554a5588 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -65,6 +65,7 @@ def setUp(self): devices_array = create_device_mesh(config) self.mesh = Mesh(devices_array, config.mesh_axes) + def test_rotary_pos_embed(self): batch_size = 1 channels = 16 diff 
--git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 7ef5c536..54aac946 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -386,6 +386,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera if self.config.enable_profiler and step == last_profiling_step: max_utils.deactivate_profiler(self.config) + max_utils.upload_profiler_traces(self.config) train_states[FLUX_STATE_KEY] = flux_state if len(times) > 0: