From 9c09e5d865fa18ab8036cc15d4e1a065a10e40e9 Mon Sep 17 00:00:00 2001 From: LouisCastricato Date: Thu, 12 Mar 2026 00:14:58 -0400 Subject: [PATCH 1/6] initial blockwise sparse attention implementation. Aprox .4ms per dispatch. --- .gitignore | 1 + src/metal/metal_flex_attn.metal | 724 ++++++++++++++++++++++++ src/metal/metal_flex_attn_op.mm | 767 ++++++++++++++++++++++++++ src/model/attn.py | 28 +- src/model/attn_backend.py | 135 +++++ src/model/kv_cache.py | 29 +- src/patch_model.py | 16 +- tests/conftest.py | 185 +++++++ tests/test_attn_module_integration.py | 119 ++++ tests/test_metal_attn_numeric.py | 540 ++++++++++++++++++ tests/test_metal_attn_perf.py | 355 ++++++++++++ 11 files changed, 2885 insertions(+), 14 deletions(-) create mode 100644 src/metal/metal_flex_attn.metal create mode 100644 src/metal/metal_flex_attn_op.mm create mode 100644 src/model/attn_backend.py create mode 100644 tests/conftest.py create mode 100644 tests/test_attn_module_integration.py create mode 100644 tests/test_metal_attn_numeric.py create mode 100644 tests/test_metal_attn_perf.py diff --git a/.gitignore b/.gitignore index 7e47bb8..9e6bd23 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ docs/_build/ *.pyc __pycache__/ *.egg-info/ +.build/ \ No newline at end of file diff --git a/src/metal/metal_flex_attn.metal b/src/metal/metal_flex_attn.metal new file mode 100644 index 0000000..8c76600 --- /dev/null +++ b/src/metal/metal_flex_attn.metal @@ -0,0 +1,724 @@ +#include +using namespace metal; + +/* + Hybrid Metal attention kernel design (inference-only) + ---------------------------------------------------- + + Goals: + - Forward-only attention for Q, K, V with block/window sparsity driven by + metadata from Python. + - Run entirely on Apple GPU (no CPU fallbacks), targeting M-series chips. + - Serve as a drop-in backend for the world_flex_attn_forward API. + + Tensor layouts (matching AttnMeta): + - Q: [B, H, T, Dh] -> flattened as [B*H, T, Dh] + - K: [B, H, L, Dh] -> flattened as [B*H, L, Dh] + - V: [B, H, L, Dh] -> flattened as [B*H, L, Dh] + - Output: [B, H, T, Dh] -> [B*H, T, Dh] + + Precision: + - Inputs in fp16 or bf16; internally promote to fp32 for accumulation. + - Outputs written in the same dtype as inputs. + + Tiling: + - Each threadgroup processes a tile of (t_block, kv_block) for a single + (batch, head) pair: + - t_block: contiguous query positions in [0, T) + - kv_block: contiguous key/value positions in [0, L) + - Within a tile: + - Load a small Dh chunk into threadgroup memory for Q and K. + - Compute partial QK^T / sqrt(d) scores. + - Apply block/window sparsity mask provided via metadata buffer. + - Accumulate softmax-normalized attention * V to produce output. + + Sparsity metadata: + - For the first implementation, the Metal kernel will consume a dense + boolean mask per (t_block, kv_block) encoded as a uint8_t buffer, with: + mask[b*h*T + t, L] == 1 for valid KV positions, 0 otherwise. + - Later this can be compressed to a block-list representation that mirrors + the BlockMask.from_kv_blocks semantics. + + NOTE: + - The actual math implementation is intentionally left minimal and will be + iterated on together with the C++/PyTorch custom op bridge. 
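+
+  Online-softmax recurrence (pseudocode sketch of what each kernel below
+  implements; accumulation is fp32 unless fp16_accum is requested):
+    m = -inf; l = 0; acc = 0
+    for each unmasked kv position j:
+      s     = dot(q, k_j) / sqrt(Dh)
+      m_new = max(m, s)
+      alpha = exp(m - m_new); beta = exp(s - m_new)
+      acc   = acc * alpha + beta * v_j
+      l     = l * alpha + beta
+      m     = m_new
+    out = acc / l if l > 0 else 0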
+ */ + +kernel void metal_flex_attn_forward( + device const half* q, + device const half* k, + device const half* v, + device const int* active_blocks, // [active_count] block indices + device half* out, + constant uint& B, + constant uint& Hq, + constant uint& T, + constant uint& L, + constant uint& Dh, + constant uint& block_size, + constant uint& active_count, + constant uint& causal, + constant uint& Hkv, + constant uint& fp16_accum, + uint tid [[thread_position_in_grid]], + uint lane_id [[thread_index_in_simdgroup]], + uint simd_size [[threads_per_simdgroup]] +) { + (void)fp16_accum; + const uint total_queries = B * Hq * T; + if (simd_size == 0) { + return; + } + const uint qid = tid / simd_size; + if (qid >= total_queries) { + return; + } + + const uint bh = qid / T; + const uint t = qid % T; + const uint b = bh / Hq; + const uint hq = bh % Hq; + const uint group_size = max((uint)1, Hq / max((uint)1, Hkv)); + const uint hkv = min(hq / group_size, max((uint)0, Hkv - 1)); + + const uint q_offset = (((b * Hq + hq) * T + t) * Dh); + const uint kv_base = (((b * Hkv + hkv) * L) * Dh); + const uint out_offset = q_offset; + const float inv_sqrt_dh = rsqrt((float)Dh); + const uint safe_block_size = max((uint)1, block_size); + const uint kMaxDh = 128; + + if (Dh > kMaxDh) { + for (uint d = 0; d < Dh; ++d) { + out[out_offset + d] = half(0.0h); + } + return; + } + + const uint kv_limit = (causal != 0) ? min((uint)L, t + 1) : (uint)L; + if (kv_limit == 0) { + for (uint d = 0; d < Dh; ++d) { + out[out_offset + d] = half(0.0h); + } + return; + } + + // SIMD-cooperative online softmax: + // each lane owns a strided subset of Dh and collaborates on dot-product + // reductions for every KV token. + float m = -INFINITY; + float l_acc = 0.0f; + uint owned_dims[4]; + float q_regs[4]; + float acc[4]; + uint owned_count = 0; + for (uint d = lane_id; d < Dh; d += simd_size) { + if (owned_count < 4) { + owned_dims[owned_count] = d; + q_regs[owned_count] = (float)q[q_offset + d]; + acc[owned_count] = 0.0f; + owned_count++; + } + } + + // Iterate by block to avoid per-token block-index division and reduce + // branch pressure when many blocks are masked out. 
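+    // The early `break` below relies on active_blocks being sorted in
+    // ascending block order; both host-side producers (at::arange and
+    // torch.nonzero on a 1-D mask) emit indices in that order, so a block
+    // starting at or beyond kv_limit implies every later block is too.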
+ for (uint ai = 0; ai < active_count; ++ai) { + const uint bidx = (uint)active_blocks[ai]; + const uint block_start = bidx * safe_block_size; + if (block_start >= kv_limit) { + break; + } + const uint block_end = min(kv_limit, block_start + safe_block_size); + for (uint kv_idx = block_start; kv_idx < block_end; ++kv_idx) { + float dot_local = 0.0f; + const uint k_offset = kv_base + kv_idx * Dh; + for (uint i = 0; i < owned_count; ++i) { + const uint d = owned_dims[i]; + dot_local += q_regs[i] * (float)k[k_offset + d]; + } + const float dot = simd_sum(dot_local); + const float s = dot * inv_sqrt_dh; + const float m_new = max(m, s); + const float alpha = fast::exp(m - m_new); + const float beta = fast::exp(s - m_new); + const uint v_offset = kv_base + kv_idx * Dh; + + for (uint i = 0; i < owned_count; ++i) { + const uint d2 = owned_dims[i]; + acc[i] = acc[i] * alpha + beta * (float)v[v_offset + d2]; + } + l_acc = l_acc * alpha + beta; + m = m_new; + } + } + + if (!(l_acc > 0.0f)) { + for (uint i = 0; i < owned_count; ++i) { + out[out_offset + owned_dims[i]] = half(0.0h); + } + return; + } + + const float inv_l = 1.0f / l_acc; + for (uint i = 0; i < owned_count; ++i) { + out[out_offset + owned_dims[i]] = half(acc[i] * inv_l); + } +} + +kernel void metal_flex_attn_forward_dh64_bs4_single( + device const half* q, + device const half* k, + device const half* v, + device const int* active_blocks, + device half* out, + constant uint& B, + constant uint& Hq, + constant uint& T, + constant uint& L, + constant uint& Dh, + constant uint& block_size, + constant uint& active_count, + constant uint& causal, + constant uint& Hkv, + constant uint& fp16_accum, + uint tid [[thread_position_in_grid]], + uint lane_id [[thread_index_in_simdgroup]] +) { + if (Dh != 64u || block_size != 4u) { + return; + } + + const uint total_queries = B * Hq * T; + const uint qid = tid >> 5; // /32 + if (qid >= total_queries) { + return; + } + + const uint bh = qid / T; + const uint t = qid % T; + const uint b = bh / Hq; + const uint hq = bh % Hq; + const uint group_size = max((uint)1, Hq / max((uint)1, Hkv)); + const uint hkv = min(hq / group_size, max((uint)0, Hkv - 1)); + + const uint q_offset = (((b * Hq + hq) * T + t) * 64u); + const uint kv_base = (((b * Hkv + hkv) * L) * 64u); + const uint out_offset = q_offset; + const float inv_sqrt_dh = 0.125f; + const uint d_pair = lane_id << 1; // contiguous pair in [0, 62] + const uint kv_limit = (causal != 0) ? 
min((uint)L, t + 1u) : (uint)L; + if (kv_limit == 0u) { + out[out_offset + d_pair + 0u] = half(0.0h); + out[out_offset + d_pair + 1u] = half(0.0h); + return; + } + const float2 q2 = float2( + (float)q[q_offset + d_pair + 0u], + (float)q[q_offset + d_pair + 1u] + ); + + float m = -INFINITY; + float l_acc = 0.0f; + const bool use_fp16_accum = (fp16_accum != 0u); + float2 acc2 = float2(0.0f); + half2 acc2_h = half2((half)0.0h); + + for (uint ai = 0; ai < active_count; ++ai) { + const uint block_start = ((uint)active_blocks[ai]) << 2; + if (block_start >= kv_limit) { + break; + } + + const uint kv0 = block_start + 0u; + if (kv0 < kv_limit) { + const uint k0 = kv_base + kv0 * 64u; + const float2 k20 = float2( + (float)k[k0 + d_pair + 0u], + (float)k[k0 + d_pair + 1u] + ); + const float dot0 = simd_sum(q2.x * k20.x + q2.y * k20.y); + const float s0 = dot0 * inv_sqrt_dh; + const float m0 = max(m, s0); + const float a0 = fast::exp(m - m0); + const float b0 = fast::exp(s0 - m0); + const uint v0 = kv_base + kv0 * 64u; + if (use_fp16_accum) { + const half2 v20_h = half2(v[v0 + d_pair + 0u], v[v0 + d_pair + 1u]); + acc2_h = acc2_h * half(a0) + v20_h * half(b0); + } else { + const float2 v20 = float2( + (float)v[v0 + d_pair + 0u], + (float)v[v0 + d_pair + 1u] + ); + acc2 = acc2 * a0 + v20 * b0; + } + l_acc = l_acc * a0 + b0; + m = m0; + } + + const uint kv1 = block_start + 1u; + if (kv1 < kv_limit) { + const uint k1 = kv_base + kv1 * 64u; + const float2 k21 = float2( + (float)k[k1 + d_pair + 0u], + (float)k[k1 + d_pair + 1u] + ); + const float dot1 = simd_sum(q2.x * k21.x + q2.y * k21.y); + const float s1 = dot1 * inv_sqrt_dh; + const float m1 = max(m, s1); + const float a1 = fast::exp(m - m1); + const float b1 = fast::exp(s1 - m1); + const uint v1 = kv_base + kv1 * 64u; + if (use_fp16_accum) { + const half2 v21_h = half2(v[v1 + d_pair + 0u], v[v1 + d_pair + 1u]); + acc2_h = acc2_h * half(a1) + v21_h * half(b1); + } else { + const float2 v21 = float2( + (float)v[v1 + d_pair + 0u], + (float)v[v1 + d_pair + 1u] + ); + acc2 = acc2 * a1 + v21 * b1; + } + l_acc = l_acc * a1 + b1; + m = m1; + } + + const uint kv2 = block_start + 2u; + if (kv2 < kv_limit) { + const uint k2 = kv_base + kv2 * 64u; + const float2 k22 = float2( + (float)k[k2 + d_pair + 0u], + (float)k[k2 + d_pair + 1u] + ); + const float dot2 = simd_sum(q2.x * k22.x + q2.y * k22.y); + const float s2 = dot2 * inv_sqrt_dh; + const float m2 = max(m, s2); + const float a2 = fast::exp(m - m2); + const float b2 = fast::exp(s2 - m2); + const uint v2 = kv_base + kv2 * 64u; + if (use_fp16_accum) { + const half2 v22_h = half2(v[v2 + d_pair + 0u], v[v2 + d_pair + 1u]); + acc2_h = acc2_h * half(a2) + v22_h * half(b2); + } else { + const float2 v22 = float2( + (float)v[v2 + d_pair + 0u], + (float)v[v2 + d_pair + 1u] + ); + acc2 = acc2 * a2 + v22 * b2; + } + l_acc = l_acc * a2 + b2; + m = m2; + } + + const uint kv3 = block_start + 3u; + if (kv3 < kv_limit) { + const uint k3 = kv_base + kv3 * 64u; + const float2 k23 = float2( + (float)k[k3 + d_pair + 0u], + (float)k[k3 + d_pair + 1u] + ); + const float dot3 = simd_sum(q2.x * k23.x + q2.y * k23.y); + const float s3 = dot3 * inv_sqrt_dh; + const float m3 = max(m, s3); + const float a3 = fast::exp(m - m3); + const float b3 = fast::exp(s3 - m3); + const uint v3 = kv_base + kv3 * 64u; + if (use_fp16_accum) { + const half2 v23_h = half2(v[v3 + d_pair + 0u], v[v3 + d_pair + 1u]); + acc2_h = acc2_h * half(a3) + v23_h * half(b3); + } else { + const float2 v23 = float2( + (float)v[v3 + d_pair + 0u], + (float)v[v3 + 
d_pair + 1u] + ); + acc2 = acc2 * a3 + v23 * b3; + } + l_acc = l_acc * a3 + b3; + m = m3; + } + } + + if (!(l_acc > 0.0f)) { + out[out_offset + d_pair + 0u] = half(0.0h); + out[out_offset + d_pair + 1u] = half(0.0h); + return; + } + const float inv_l = 1.0f / l_acc; + const float2 acc_out = use_fp16_accum ? float2(acc2_h) : acc2; + out[out_offset + d_pair + 0u] = half(acc_out.x * inv_l); + out[out_offset + d_pair + 1u] = half(acc_out.y * inv_l); +} + +kernel void metal_flex_attn_forward_dh64_bs4_gqa2_single( + device const half* q, + device const half* k, + device const half* v, + device const int* active_blocks, + device half* out, + constant uint& B, + constant uint& Hq, + constant uint& T, + constant uint& L, + constant uint& Dh, + constant uint& block_size, + constant uint& active_count, + constant uint& causal, + constant uint& Hkv, + constant uint& fp16_accum, + uint tid [[thread_position_in_grid]], + uint lane_id [[thread_index_in_simdgroup]] +) { + // Specialization for the common GQA=2 case (Hq = 2 * Hkv). + if (Dh != 64u || block_size != 4u || Hq != (Hkv << 1)) { + return; + } + + const uint total_queries = B * Hq * T; + const uint qid = tid >> 5; // /32 + if (qid >= total_queries) { + return; + } + + const uint bh = qid / T; + const uint t = qid % T; + const uint b = bh / Hq; + const uint hq = bh % Hq; + const uint hkv = hq >> 1; // exact for GQA=2 + + const uint q_offset = (((b * Hq + hq) * T + t) * 64u); + const uint kv_base = (((b * Hkv + hkv) * L) * 64u); + const uint out_offset = q_offset; + const float inv_sqrt_dh = 0.125f; + const uint d_pair = lane_id << 1; + const uint kv_limit = (causal != 0) ? min((uint)L, t + 1u) : (uint)L; + if (kv_limit == 0u) { + out[out_offset + d_pair + 0u] = half(0.0h); + out[out_offset + d_pair + 1u] = half(0.0h); + return; + } + + const float2 q2 = float2( + (float)q[q_offset + d_pair + 0u], + (float)q[q_offset + d_pair + 1u] + ); + + float m = -INFINITY; + float l_acc = 0.0f; + const bool use_fp16_accum = (fp16_accum != 0u); + float2 acc2 = float2(0.0f); + half2 acc2_h = half2((half)0.0h); + + for (uint ai = 0; ai < active_count; ++ai) { + const uint block_start = ((uint)active_blocks[ai]) << 2; + if (block_start >= kv_limit) { + break; + } + + const uint kv0 = block_start + 0u; + if (kv0 < kv_limit) { + const uint k0 = kv_base + kv0 * 64u; + const float2 k20 = float2((float)k[k0 + d_pair + 0u], (float)k[k0 + d_pair + 1u]); + const float dot0 = simd_sum(q2.x * k20.x + q2.y * k20.y); + const float s0 = dot0 * inv_sqrt_dh; + const float m0 = max(m, s0); + const float a0 = fast::exp(m - m0); + const float b0 = fast::exp(s0 - m0); + const uint v0 = kv_base + kv0 * 64u; + if (use_fp16_accum) { + const half2 v20_h = half2(v[v0 + d_pair + 0u], v[v0 + d_pair + 1u]); + acc2_h = acc2_h * half(a0) + v20_h * half(b0); + } else { + const float2 v20 = float2((float)v[v0 + d_pair + 0u], (float)v[v0 + d_pair + 1u]); + acc2 = acc2 * a0 + v20 * b0; + } + l_acc = l_acc * a0 + b0; + m = m0; + } + + const uint kv1 = block_start + 1u; + if (kv1 < kv_limit) { + const uint k1 = kv_base + kv1 * 64u; + const float2 k21 = float2((float)k[k1 + d_pair + 0u], (float)k[k1 + d_pair + 1u]); + const float dot1 = simd_sum(q2.x * k21.x + q2.y * k21.y); + const float s1 = dot1 * inv_sqrt_dh; + const float m1 = max(m, s1); + const float a1 = fast::exp(m - m1); + const float b1 = fast::exp(s1 - m1); + const uint v1 = kv_base + kv1 * 64u; + if (use_fp16_accum) { + const half2 v21_h = half2(v[v1 + d_pair + 0u], v[v1 + d_pair + 1u]); + acc2_h = acc2_h * half(a1) + v21_h * half(b1); 
+ } else { + const float2 v21 = float2((float)v[v1 + d_pair + 0u], (float)v[v1 + d_pair + 1u]); + acc2 = acc2 * a1 + v21 * b1; + } + l_acc = l_acc * a1 + b1; + m = m1; + } + + const uint kv2 = block_start + 2u; + if (kv2 < kv_limit) { + const uint k2 = kv_base + kv2 * 64u; + const float2 k22 = float2((float)k[k2 + d_pair + 0u], (float)k[k2 + d_pair + 1u]); + const float dot2 = simd_sum(q2.x * k22.x + q2.y * k22.y); + const float s2 = dot2 * inv_sqrt_dh; + const float m2 = max(m, s2); + const float a2 = fast::exp(m - m2); + const float b2 = fast::exp(s2 - m2); + const uint v2 = kv_base + kv2 * 64u; + if (use_fp16_accum) { + const half2 v22_h = half2(v[v2 + d_pair + 0u], v[v2 + d_pair + 1u]); + acc2_h = acc2_h * half(a2) + v22_h * half(b2); + } else { + const float2 v22 = float2((float)v[v2 + d_pair + 0u], (float)v[v2 + d_pair + 1u]); + acc2 = acc2 * a2 + v22 * b2; + } + l_acc = l_acc * a2 + b2; + m = m2; + } + + const uint kv3 = block_start + 3u; + if (kv3 < kv_limit) { + const uint k3 = kv_base + kv3 * 64u; + const float2 k23 = float2((float)k[k3 + d_pair + 0u], (float)k[k3 + d_pair + 1u]); + const float dot3 = simd_sum(q2.x * k23.x + q2.y * k23.y); + const float s3 = dot3 * inv_sqrt_dh; + const float m3 = max(m, s3); + const float a3 = fast::exp(m - m3); + const float b3 = fast::exp(s3 - m3); + const uint v3 = kv_base + kv3 * 64u; + if (use_fp16_accum) { + const half2 v23_h = half2(v[v3 + d_pair + 0u], v[v3 + d_pair + 1u]); + acc2_h = acc2_h * half(a3) + v23_h * half(b3); + } else { + const float2 v23 = float2((float)v[v3 + d_pair + 0u], (float)v[v3 + d_pair + 1u]); + acc2 = acc2 * a3 + v23 * b3; + } + l_acc = l_acc * a3 + b3; + m = m3; + } + } + + if (!(l_acc > 0.0f)) { + out[out_offset + d_pair + 0u] = half(0.0h); + out[out_offset + d_pair + 1u] = half(0.0h); + return; + } + const float inv_l = 1.0f / l_acc; + const float2 acc_out = use_fp16_accum ? float2(acc2_h) : acc2; + out[out_offset + d_pair + 0u] = half(acc_out.x * inv_l); + out[out_offset + d_pair + 1u] = half(acc_out.y * inv_l); +} + +kernel void metal_flex_attn_forward_dh64_bs4_gqa2_dualhead( + device const half* q, + device const half* k, + device const half* v, + device const int* active_blocks, + device half* out, + constant uint& B, + constant uint& Hq, + constant uint& T, + constant uint& L, + constant uint& Dh, + constant uint& block_size, + constant uint& active_count, + constant uint& causal, + constant uint& Hkv, + constant uint& fp16_accum, + uint tid [[thread_position_in_grid]], + uint lane_id [[thread_index_in_simdgroup]] +) { + // One simdgroup handles a (b, hkv, t) triplet and computes both query heads + // (2*hkv and 2*hkv+1), reusing each K/V load once. + if (Dh != 64u || block_size != 4u || Hq != (Hkv << 1)) { + return; + } + + const uint total_pairs = B * Hkv * T; + const uint pid = tid >> 5; // /32 + if (pid >= total_pairs) { + return; + } + + const uint bh = pid / T; + const uint t = pid % T; + const uint b = bh / Hkv; + const uint hkv = bh % Hkv; + const uint hq0 = hkv << 1; + const uint hq1 = hq0 + 1u; + + const uint q_offset0 = (((b * Hq + hq0) * T + t) * 64u); + const uint q_offset1 = (((b * Hq + hq1) * T + t) * 64u); + const uint out_offset0 = q_offset0; + const uint out_offset1 = q_offset1; + const uint kv_base = (((b * Hkv + hkv) * L) * 64u); + const float inv_sqrt_dh = 0.125f; + const uint d_pair = lane_id << 1; + const uint kv_limit = (causal != 0) ? 
min((uint)L, t + 1u) : (uint)L; + if (kv_limit == 0u || active_count == 0u) { + out[out_offset0 + d_pair + 0u] = half(0.0h); + out[out_offset0 + d_pair + 1u] = half(0.0h); + out[out_offset1 + d_pair + 0u] = half(0.0h); + out[out_offset1 + d_pair + 1u] = half(0.0h); + return; + } + + const float2 q20 = float2( + (float)q[q_offset0 + d_pair + 0u], + (float)q[q_offset0 + d_pair + 1u] + ); + const float2 q21 = float2( + (float)q[q_offset1 + d_pair + 0u], + (float)q[q_offset1 + d_pair + 1u] + ); + + const bool use_fp16_accum = (fp16_accum != 0u); + float m0 = -INFINITY, m1 = -INFINITY; + float l0 = 0.0f, l1 = 0.0f; + float2 acc0 = float2(0.0f), acc1 = float2(0.0f); + half2 acc0_h = half2((half)0.0h), acc1_h = half2((half)0.0h); + + for (uint ai = 0; ai < active_count; ++ai) { + const uint block_start = ((uint)active_blocks[ai]) << 2; + if (block_start >= kv_limit) { + break; + } + const uint block_end = min(kv_limit, block_start + 4u); + for (uint kv_idx = block_start; kv_idx < block_end; ++kv_idx) { + const uint k_off = kv_base + kv_idx * 64u; + const float2 k2 = float2((float)k[k_off + d_pair + 0u], (float)k[k_off + d_pair + 1u]); + const float dot0 = simd_sum(q20.x * k2.x + q20.y * k2.y); + const float dot1 = simd_sum(q21.x * k2.x + q21.y * k2.y); + const float s0 = dot0 * inv_sqrt_dh; + const float s1 = dot1 * inv_sqrt_dh; + + const float m0_new = max(m0, s0); + const float a0 = fast::exp(m0 - m0_new); + const float b0 = fast::exp(s0 - m0_new); + const float m1_new = max(m1, s1); + const float a1 = fast::exp(m1 - m1_new); + const float b1 = fast::exp(s1 - m1_new); + + const uint v_off = kv_base + kv_idx * 64u; + if (use_fp16_accum) { + const half2 v2_h = half2(v[v_off + d_pair + 0u], v[v_off + d_pair + 1u]); + acc0_h = acc0_h * half(a0) + v2_h * half(b0); + acc1_h = acc1_h * half(a1) + v2_h * half(b1); + } else { + const float2 v2 = float2((float)v[v_off + d_pair + 0u], (float)v[v_off + d_pair + 1u]); + acc0 = acc0 * a0 + v2 * b0; + acc1 = acc1 * a1 + v2 * b1; + } + l0 = l0 * a0 + b0; + l1 = l1 * a1 + b1; + m0 = m0_new; + m1 = m1_new; + } + } + + if (!(l0 > 0.0f)) { + out[out_offset0 + d_pair + 0u] = half(0.0h); + out[out_offset0 + d_pair + 1u] = half(0.0h); + } else { + const float inv_l0 = 1.0f / l0; + const float2 out0 = use_fp16_accum ? float2(acc0_h) : acc0; + out[out_offset0 + d_pair + 0u] = half(out0.x * inv_l0); + out[out_offset0 + d_pair + 1u] = half(out0.y * inv_l0); + } + + if (!(l1 > 0.0f)) { + out[out_offset1 + d_pair + 0u] = half(0.0h); + out[out_offset1 + d_pair + 1u] = half(0.0h); + } else { + const float inv_l1 = 1.0f / l1; + const float2 out1 = use_fp16_accum ? 
float2(acc1_h) : acc1; + out[out_offset1 + d_pair + 0u] = half(out1.x * inv_l1); + out[out_offset1 + d_pair + 1u] = half(out1.y * inv_l1); + } +} + +kernel void metal_flex_attn_forward_dh64_bs4_gqa2_dense( + device const half* q, + device const half* k, + device const half* v, + device const int* active_blocks, + device half* out, + constant uint& B, + constant uint& Hq, + constant uint& T, + constant uint& L, + constant uint& Dh, + constant uint& block_size, + constant uint& active_count, + constant uint& causal, + constant uint& Hkv, + constant uint& fp16_accum, + uint tid [[thread_position_in_grid]], + uint lane_id [[thread_index_in_simdgroup]] +) { + (void)active_blocks; + if (Dh != 64u || block_size != 4u || Hq != (Hkv << 1)) { + return; + } + + const uint total_queries = B * Hq * T; + const uint qid = tid >> 5; + if (qid >= total_queries) { + return; + } + + const uint bh = qid / T; + const uint t = qid % T; + const uint b = bh / Hq; + const uint hq = bh % Hq; + const uint hkv = hq >> 1; + + const uint q_offset = (((b * Hq + hq) * T + t) * 64u); + const uint kv_base = (((b * Hkv + hkv) * L) * 64u); + const uint out_offset = q_offset; + const float inv_sqrt_dh = 0.125f; + const uint d_pair = lane_id << 1; + const uint kv_limit = (causal != 0) ? min((uint)L, t + 1u) : (uint)L; + if (kv_limit == 0u || active_count == 0u) { + out[out_offset + d_pair + 0u] = half(0.0h); + out[out_offset + d_pair + 1u] = half(0.0h); + return; + } + + const float2 q2 = float2( + (float)q[q_offset + d_pair + 0u], + (float)q[q_offset + d_pair + 1u] + ); + + float m = -INFINITY; + float l_acc = 0.0f; + float2 acc2 = float2(0.0f); + + for (uint kv_idx = 0u; kv_idx < kv_limit; ++kv_idx) { + const uint k_off = kv_base + kv_idx * 64u; + const float2 k2 = float2((float)k[k_off + d_pair + 0u], (float)k[k_off + d_pair + 1u]); + const float dot = simd_sum(q2.x * k2.x + q2.y * k2.y); + const float s = dot * inv_sqrt_dh; + const float m_new = max(m, s); + const float a = fast::exp(m - m_new); + const float bcoef = fast::exp(s - m_new); + const uint v_off = kv_base + kv_idx * 64u; + const float2 v2 = float2((float)v[v_off + d_pair + 0u], (float)v[v_off + d_pair + 1u]); + acc2 = acc2 * a + v2 * bcoef; + l_acc = l_acc * a + bcoef; + m = m_new; + } + + if (!(l_acc > 0.0f)) { + out[out_offset + d_pair + 0u] = half(0.0h); + out[out_offset + d_pair + 1u] = half(0.0h); + return; + } + const float inv_l = 1.0f / l_acc; + out[out_offset + d_pair + 0u] = half(acc2.x * inv_l); + out[out_offset + d_pair + 1u] = half(acc2.y * inv_l); +} + + diff --git a/src/metal/metal_flex_attn_op.mm b/src/metal/metal_flex_attn_op.mm new file mode 100644 index 0000000..42a0d9b --- /dev/null +++ b/src/metal/metal_flex_attn_op.mm @@ -0,0 +1,767 @@ +#import +#import + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct MetalRuntime { + id pipeline_generic = nil; + id pipeline_dh64_bs4_single = nil; + id pipeline_dh64_bs4_gqa2_single = nil; + id pipeline_dh64_bs4_gqa2_dense = nil; + id pipeline_dh64_bs4_gqa2_dualhead = nil; + uint32_t thread_execution_width_generic = 32; + uint32_t thread_execution_width_dh64_bs4_single = 32; + uint32_t thread_execution_width_dh64_bs4_gqa2_single = 32; + uint32_t thread_execution_width_dh64_bs4_gqa2_dense = 32; + uint32_t thread_execution_width_dh64_bs4_gqa2_dualhead = 32; + bool init_ok = false; +}; + +static inline id tensor_mtl_buffer(const at::Tensor& t) { + return (id)t.storage().data(); +} + +MetalRuntime& get_metal_runtime() { + static MetalRuntime 
rt; + static dispatch_once_t onceToken; + dispatch_once(&onceToken, ^{ + @autoreleasepool { + NSString* this_file = [NSString stringWithUTF8String:__FILE__]; + NSString* src_dir = [this_file stringByDeletingLastPathComponent]; + NSString* kernel_path = [src_dir stringByAppendingPathComponent:@"metal_flex_attn.metal"]; + + NSError* read_error = nil; + NSString* source = [NSString stringWithContentsOfFile:kernel_path + encoding:NSUTF8StringEncoding + error:&read_error]; + if (!source) { + return; + } + + id device = MTLCreateSystemDefaultDevice(); + if (!device) { + return; + } + + NSError* compile_error = nil; + MTLCompileOptions* opts = [[MTLCompileOptions alloc] init]; + opts.fastMathEnabled = YES; + id lib = [device newLibraryWithSource:source options:opts error:&compile_error]; + if (!lib) { + return; + } + + id fn_generic = [lib newFunctionWithName:@"metal_flex_attn_forward"]; + id fn_dh64_bs4_single = [lib newFunctionWithName:@"metal_flex_attn_forward_dh64_bs4_single"]; + id fn_dh64_bs4_gqa2_single = [lib newFunctionWithName:@"metal_flex_attn_forward_dh64_bs4_gqa2_single"]; + id fn_dh64_bs4_gqa2_dense = [lib newFunctionWithName:@"metal_flex_attn_forward_dh64_bs4_gqa2_dense"]; + id fn_dh64_bs4_gqa2_dualhead = [lib newFunctionWithName:@"metal_flex_attn_forward_dh64_bs4_gqa2_dualhead"]; + if (!fn_generic || !fn_dh64_bs4_single || !fn_dh64_bs4_gqa2_single || !fn_dh64_bs4_gqa2_dense || !fn_dh64_bs4_gqa2_dualhead) { + return; + } + + NSError* pipe_error = nil; + rt.pipeline_generic = [device newComputePipelineStateWithFunction:fn_generic error:&pipe_error]; + if (!rt.pipeline_generic) { + return; + } + rt.pipeline_dh64_bs4_single = [device newComputePipelineStateWithFunction:fn_dh64_bs4_single error:&pipe_error]; + if (!rt.pipeline_dh64_bs4_single) { + return; + } + rt.pipeline_dh64_bs4_gqa2_single = [device newComputePipelineStateWithFunction:fn_dh64_bs4_gqa2_single error:&pipe_error]; + if (!rt.pipeline_dh64_bs4_gqa2_single) { + return; + } + rt.pipeline_dh64_bs4_gqa2_dense = [device newComputePipelineStateWithFunction:fn_dh64_bs4_gqa2_dense error:&pipe_error]; + if (!rt.pipeline_dh64_bs4_gqa2_dense) { + return; + } + rt.pipeline_dh64_bs4_gqa2_dualhead = [device newComputePipelineStateWithFunction:fn_dh64_bs4_gqa2_dualhead error:&pipe_error]; + if (!rt.pipeline_dh64_bs4_gqa2_dualhead) { + return; + } + rt.thread_execution_width_generic = static_cast(rt.pipeline_generic.threadExecutionWidth); + rt.thread_execution_width_dh64_bs4_single = static_cast(rt.pipeline_dh64_bs4_single.threadExecutionWidth); + rt.thread_execution_width_dh64_bs4_gqa2_single = static_cast(rt.pipeline_dh64_bs4_gqa2_single.threadExecutionWidth); + rt.thread_execution_width_dh64_bs4_gqa2_dense = static_cast(rt.pipeline_dh64_bs4_gqa2_dense.threadExecutionWidth); + rt.thread_execution_width_dh64_bs4_gqa2_dualhead = static_cast(rt.pipeline_dh64_bs4_gqa2_dualhead.threadExecutionWidth); + rt.init_ok = true; + } + }); + return rt; +} + +static int64_t get_block_size() { + const char* env = std::getenv("WORLD_METAL_BLOCK_SIZE"); + if (!env) { + return 4; + } + const long parsed = std::strtol(env, nullptr, 10); + return parsed > 0 ? static_cast(parsed) : 4; +} + +static bool fast_no_fallback() { + const char* env = std::getenv("WORLD_METAL_FAST_NO_FALLBACK"); + if (!env) { + return false; + } + return std::string(env) == "1"; +} + +static uint32_t get_tg_size() { + const char* env = std::getenv("WORLD_METAL_TG_SIZE"); + if (!env) { + return 128; + } + const long parsed = std::strtol(env, nullptr, 10); + return parsed > 0 ? 
static_cast(parsed) : 128; +} + +static bool enable_gqa2_dualhead_specialization() { + const char* env = std::getenv("WORLD_METAL_ENABLE_GQA2_DUALHEAD"); + if (!env) { + return true; + } + const std::string s(env); + if (s == "1") { + return true; + } + if (s == "0") { + return false; + } + return true; +} + +static bool enable_fp16_accum() { + const char* env = std::getenv("WORLD_METAL_FP16_ACCUM"); + if (!env) { + return true; + } + return std::string(env) == "1"; +} + +static void dispatch_fast_kernel( + id pipeline, + uint32_t thread_execution_width, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& active_blocks, + at::Tensor& out, + uint32_t B, + uint32_t Hq, + uint32_t T, + uint32_t L, + uint32_t Dh, + uint32_t BlockSize, + uint32_t ActiveCount, + uint32_t Causal, + uint32_t Hkv, + uint32_t FP16Accum, + const char* err_prefix, + uint32_t tg_size_hint +) { + auto* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, err_prefix, ": no active MPS stream"); + + id cb = (id)stream->commandBuffer(); + TORCH_CHECK(cb != nil, err_prefix, ": failed to acquire command buffer"); + id enc = [cb computeCommandEncoder]; + TORCH_CHECK(enc != nil, err_prefix, ": failed to create command encoder"); + [enc setComputePipelineState:pipeline]; + + [enc setBuffer:tensor_mtl_buffer(q) offset:q.storage_offset() * q.element_size() atIndex:0]; + [enc setBuffer:tensor_mtl_buffer(k) offset:k.storage_offset() * k.element_size() atIndex:1]; + [enc setBuffer:tensor_mtl_buffer(v) offset:v.storage_offset() * v.element_size() atIndex:2]; + [enc setBuffer:tensor_mtl_buffer(active_blocks) + offset:active_blocks.storage_offset() * active_blocks.element_size() + atIndex:3]; + [enc setBuffer:tensor_mtl_buffer(out) offset:out.storage_offset() * out.element_size() atIndex:4]; + [enc setBytes:&B length:sizeof(B) atIndex:5]; + [enc setBytes:&Hq length:sizeof(Hq) atIndex:6]; + [enc setBytes:&T length:sizeof(T) atIndex:7]; + [enc setBytes:&L length:sizeof(L) atIndex:8]; + [enc setBytes:&Dh length:sizeof(Dh) atIndex:9]; + [enc setBytes:&BlockSize length:sizeof(BlockSize) atIndex:10]; + [enc setBytes:&ActiveCount length:sizeof(ActiveCount) atIndex:11]; + [enc setBytes:&Causal length:sizeof(Causal) atIndex:12]; + [enc setBytes:&Hkv length:sizeof(Hkv) atIndex:13]; + [enc setBytes:&FP16Accum length:sizeof(FP16Accum) atIndex:14]; + + const uint32_t simd_width = std::max(1u, thread_execution_width); + const NSUInteger total = static_cast(B) * Hq * T * simd_width; + const NSUInteger tg_req = static_cast(tg_size_hint > 0u ? 
tg_size_hint : get_tg_size()); + const NSUInteger tg_aligned = MAX(simd_width, (tg_req / simd_width) * simd_width); + const NSUInteger tg = MIN(pipeline.maxTotalThreadsPerThreadgroup, tg_aligned); + const NSUInteger tg_count = (total + tg - 1) / tg; + [enc dispatchThreadgroups:MTLSizeMake(tg_count, 1, 1) threadsPerThreadgroup:MTLSizeMake(tg, 1, 1)]; + [enc endEncoding]; +} + +static void dispatch_fast_kernel_dh64_bs4_gqa2_dualhead( + id pipeline, + uint32_t thread_execution_width, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& active_blocks, + at::Tensor& out, + uint32_t B, + uint32_t Hq, + uint32_t T, + uint32_t L, + uint32_t Dh, + uint32_t BlockSize, + uint32_t ActiveCount, + uint32_t Causal, + uint32_t Hkv, + uint32_t FP16Accum, + const char* err_prefix, + uint32_t tg_size_hint +) { + auto* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, err_prefix, ": no active MPS stream"); + + id cb = (id)stream->commandBuffer(); + TORCH_CHECK(cb != nil, err_prefix, ": failed to acquire command buffer"); + id enc = [cb computeCommandEncoder]; + TORCH_CHECK(enc != nil, err_prefix, ": failed to create command encoder"); + [enc setComputePipelineState:pipeline]; + + [enc setBuffer:tensor_mtl_buffer(q) offset:q.storage_offset() * q.element_size() atIndex:0]; + [enc setBuffer:tensor_mtl_buffer(k) offset:k.storage_offset() * k.element_size() atIndex:1]; + [enc setBuffer:tensor_mtl_buffer(v) offset:v.storage_offset() * v.element_size() atIndex:2]; + [enc setBuffer:tensor_mtl_buffer(active_blocks) + offset:active_blocks.storage_offset() * active_blocks.element_size() + atIndex:3]; + [enc setBuffer:tensor_mtl_buffer(out) offset:out.storage_offset() * out.element_size() atIndex:4]; + [enc setBytes:&B length:sizeof(B) atIndex:5]; + [enc setBytes:&Hq length:sizeof(Hq) atIndex:6]; + [enc setBytes:&T length:sizeof(T) atIndex:7]; + [enc setBytes:&L length:sizeof(L) atIndex:8]; + [enc setBytes:&Dh length:sizeof(Dh) atIndex:9]; + [enc setBytes:&BlockSize length:sizeof(BlockSize) atIndex:10]; + [enc setBytes:&ActiveCount length:sizeof(ActiveCount) atIndex:11]; + [enc setBytes:&Causal length:sizeof(Causal) atIndex:12]; + [enc setBytes:&Hkv length:sizeof(Hkv) atIndex:13]; + [enc setBytes:&FP16Accum length:sizeof(FP16Accum) atIndex:14]; + + const uint32_t simd_width = std::max(1u, thread_execution_width); + const NSUInteger total = static_cast(B) * Hkv * T * simd_width; + const NSUInteger tg_req = static_cast(tg_size_hint > 0u ? 
tg_size_hint : get_tg_size()); + const NSUInteger tg_aligned = MAX(simd_width, (tg_req / simd_width) * simd_width); + const NSUInteger tg = MIN(pipeline.maxTotalThreadsPerThreadgroup, tg_aligned); + const NSUInteger tg_count = (total + tg - 1) / tg; + [enc dispatchThreadgroups:MTLSizeMake(tg_count, 1, 1) threadsPerThreadgroup:MTLSizeMake(tg, 1, 1)]; + [enc endEncoding]; +} + +at::Tensor metal_flex_attn_ref_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const c10::optional& mask, + bool causal +) { + TORCH_CHECK(q.device().is_mps() && k.device().is_mps() && v.device().is_mps(), + "flex_attn_metal expects q/k/v on MPS"); + TORCH_CHECK(q.scalar_type() == at::kHalf && k.scalar_type() == at::kHalf && v.scalar_type() == at::kHalf, + "flex_attn_metal currently supports float16 only"); + TORCH_CHECK(q.is_contiguous() && k.is_contiguous() && v.is_contiguous(), + "flex_attn_metal expects contiguous q/k/v"); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must match"); + TORCH_CHECK(q.size(0) == k.size(0) && q.size(3) == k.size(3), + "q/k must match on batch and head dim"); + TORCH_CHECK(q.size(1) >= k.size(1), "q heads must be >= kv heads"); + TORCH_CHECK((q.size(1) % k.size(1)) == 0, "q heads must be divisible by kv heads for GQA"); + + // Phase-1 native implementation: route through known-good ATen math while + // ensuring we execute on the current MPS stream. This validates stream + // integration before re-introducing raw Metal buffer bindings. + auto* stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream != nullptr, "flex_attn_metal: no active MPS stream"); + (void)stream->commandBuffer(); + + const int64_t T = q.size(2); + const int64_t L = k.size(2); + const int64_t Dh = q.size(3); + + at::Tensor mask_tensor; + if (mask.has_value()) { + mask_tensor = *mask; + TORCH_CHECK(mask_tensor.device().is_mps(), "mask must be on MPS"); + TORCH_CHECK(mask_tensor.scalar_type() == at::kByte, "mask must be uint8"); + TORCH_CHECK(mask_tensor.is_contiguous(), "mask must be contiguous"); + TORCH_CHECK(mask_tensor.numel() == q.size(0) * q.size(1) * T * L, + "mask must have shape [B,H,T,L]"); + } + + auto qf = q.to(at::kFloat); + auto kf = k.to(at::kFloat); + auto vf = v.to(at::kFloat); + + if (q.size(1) != k.size(1)) { + const int64_t hq = q.size(1); + const int64_t hkv = k.size(1); + const int64_t group_size = hq / hkv; + std::vector map_vec(static_cast(hq)); + for (int64_t i = 0; i < hq; ++i) { + map_vec[static_cast(i)] = i / group_size; + } + auto head_map = at::tensor( + map_vec, + q.options().device(q.device()).dtype(at::kLong) + ); + kf = kf.index_select(/*dim=*/1, head_map); + vf = vf.index_select(/*dim=*/1, head_map); + } + + auto scores = at::matmul(qf, kf.transpose(-2, -1)) / std::sqrt(static_cast(Dh)); + if (mask.has_value()) { + scores = scores.masked_fill(mask_tensor.eq(0), -std::numeric_limits::infinity()); + } + if (causal) { + auto causal_mask = at::triu( + at::ones({T, L}, q.options().dtype(at::kBool)), + /*diagonal=*/1 + ); + scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), -std::numeric_limits::infinity()); + } + + auto finite_row = at::isfinite(scores).any(-1, true); + auto safe_scores = at::where(finite_row, scores, at::zeros_like(scores)); + auto probs = at::softmax(safe_scores, -1); + probs = at::where(finite_row, probs, at::zeros_like(probs)); + + auto out = at::matmul(probs, vf); + return out.to(q.scalar_type()); +} + +at::Tensor metal_flex_attn_fast_dispatch_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, 
+ const at::Tensor& block_written, + int64_t block_size, + bool causal +) { + TORCH_CHECK(q.device().is_mps() && k.device().is_mps() && v.device().is_mps(), + "flex_attn_metal_fast expects q/k/v on MPS"); + TORCH_CHECK(block_written.device().is_mps(), "block_written must be on MPS"); + TORCH_CHECK(q.scalar_type() == at::kHalf && k.scalar_type() == at::kHalf && v.scalar_type() == at::kHalf, + "flex_attn_metal_fast currently supports float16 only"); + TORCH_CHECK(block_written.scalar_type() == at::kByte, "block_written must be uint8"); + TORCH_CHECK(q.is_contiguous() && k.is_contiguous() && v.is_contiguous(), + "flex_attn_metal_fast expects contiguous q/k/v"); + TORCH_CHECK(block_written.is_contiguous(), "block_written must be contiguous"); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must match"); + TORCH_CHECK(q.size(0) == k.size(0) && q.size(3) == k.size(3), + "q/k must match on batch and head dim"); + TORCH_CHECK(q.size(1) >= k.size(1), "q heads must be >= kv heads"); + TORCH_CHECK((q.size(1) % k.size(1)) == 0, "q heads must be divisible by kv heads for GQA"); + TORCH_CHECK(block_size > 0, "block_size must be > 0"); + + const uint32_t B = static_cast(q.size(0)); + const uint32_t Hq = static_cast(q.size(1)); + const uint32_t Hkv = static_cast(k.size(1)); + const uint32_t T = static_cast(q.size(2)); + const uint32_t L = static_cast(k.size(2)); + const uint32_t Dh = static_cast(q.size(3)); + TORCH_CHECK(Dh <= 128, "flex_attn_metal_fast currently supports Dh <= 128"); + const uint32_t BlockSize = static_cast(block_size); + const uint32_t KVBLOCKS = (L + BlockSize - 1) / BlockSize; + const uint32_t Causal = causal ? 1u : 0u; + TORCH_CHECK(block_written.numel() == static_cast(KVBLOCKS), + "block_written must have exactly ceil(L/block_size) elements"); + + at::Tensor active_blocks; + if (block_written.all().item()) { + active_blocks = at::arange( + static_cast(KVBLOCKS), + q.options().device(q.device()).dtype(at::kInt) + ).contiguous(); + } else { + active_blocks = at::nonzero(block_written.gt(0)).flatten().to(at::kInt).contiguous(); + } + const uint32_t ActiveCount = static_cast(active_blocks.numel()); + + auto& rt = get_metal_runtime(); + TORCH_CHECK(rt.init_ok, "flex_attn_metal_fast: metal runtime init failed"); + + at::Tensor out = at::zeros_like(q); + if (ActiveCount == 0) { + return out; + } + const uint32_t FP16Accum = enable_fp16_accum() ? 
1u : 0u; + const bool use_specialized = (Dh == 64u && BlockSize == 4u); + if (use_specialized) { + const float density = static_cast(ActiveCount) / static_cast(std::max(1u, KVBLOCKS)); + const bool use_gqa2_specialized = (Hq == (Hkv << 1)); + const bool use_gqa2_dense = use_gqa2_specialized && (ActiveCount == KVBLOCKS); + const bool use_gqa2_dualhead = enable_gqa2_dualhead_specialization() && use_gqa2_specialized && (density <= 0.75f) && (T >= 256u); + const uint32_t tuned_tg = get_tg_size(); + if (use_gqa2_dualhead) { + dispatch_fast_kernel_dh64_bs4_gqa2_dualhead( + rt.pipeline_dh64_bs4_gqa2_dualhead, + rt.thread_execution_width_dh64_bs4_gqa2_dualhead, + q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast", tuned_tg + ); + } else if (use_gqa2_dense) { + dispatch_fast_kernel( + rt.pipeline_dh64_bs4_gqa2_dense, + rt.thread_execution_width_dh64_bs4_gqa2_dense, + q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast", tuned_tg + ); + } else if (use_gqa2_specialized) { + dispatch_fast_kernel( + rt.pipeline_dh64_bs4_gqa2_single, + rt.thread_execution_width_dh64_bs4_gqa2_single, + q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast", tuned_tg + ); + } else { + dispatch_fast_kernel( + rt.pipeline_dh64_bs4_single, rt.thread_execution_width_dh64_bs4_single, q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast", tuned_tg + ); + } + } else { + dispatch_fast_kernel( + rt.pipeline_generic, rt.thread_execution_width_generic, q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast", 0u + ); + } + // Do not force immediate commit/wait here; let MPS stream scheduling batch + // this op naturally with surrounding kernels for better throughput. 
+ return out; +} + +at::Tensor metal_flex_attn_fast_dispatch_active_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& active_blocks, + int64_t block_size, + bool causal +) { + TORCH_CHECK(q.device().is_mps() && k.device().is_mps() && v.device().is_mps(), + "flex_attn_metal_fast_active expects q/k/v on MPS"); + TORCH_CHECK(active_blocks.device().is_mps(), "active_blocks must be on MPS"); + TORCH_CHECK(q.scalar_type() == at::kHalf && k.scalar_type() == at::kHalf && v.scalar_type() == at::kHalf, + "flex_attn_metal_fast_active currently supports float16 only"); + TORCH_CHECK(active_blocks.scalar_type() == at::kInt, "active_blocks must be int32"); + TORCH_CHECK(q.is_contiguous() && k.is_contiguous() && v.is_contiguous(), + "flex_attn_metal_fast_active expects contiguous q/k/v"); + TORCH_CHECK(active_blocks.is_contiguous(), "active_blocks must be contiguous"); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must match"); + TORCH_CHECK(q.size(0) == k.size(0) && q.size(3) == k.size(3), + "q/k must match on batch and head dim"); + TORCH_CHECK(q.size(1) >= k.size(1), "q heads must be >= kv heads"); + TORCH_CHECK((q.size(1) % k.size(1)) == 0, "q heads must be divisible by kv heads for GQA"); + TORCH_CHECK(block_size > 0, "block_size must be > 0"); + + const uint32_t B = static_cast(q.size(0)); + const uint32_t Hq = static_cast(q.size(1)); + const uint32_t Hkv = static_cast(k.size(1)); + const uint32_t T = static_cast(q.size(2)); + const uint32_t L = static_cast(k.size(2)); + const uint32_t Dh = static_cast(q.size(3)); + TORCH_CHECK(Dh <= 128, "flex_attn_metal_fast_active currently supports Dh <= 128"); + const uint32_t BlockSize = static_cast(block_size); + const uint32_t KVBLOCKS = (L + BlockSize - 1) / BlockSize; + const uint32_t Causal = causal ? 1u : 0u; + TORCH_CHECK( + active_blocks.numel() <= static_cast(KVBLOCKS), + "active_blocks numel must be <= ceil(L/block_size)" + ); + const uint32_t ActiveCount = static_cast(active_blocks.numel()); + + auto& rt = get_metal_runtime(); + TORCH_CHECK(rt.init_ok, "flex_attn_metal_fast_active: metal runtime init failed"); + + at::Tensor out = at::zeros_like(q); + if (ActiveCount == 0) { + return out; + } + const uint32_t FP16Accum = enable_fp16_accum() ? 
1u : 0u; + const bool use_specialized = (Dh == 64u && BlockSize == 4u); + if (use_specialized) { + const float density = static_cast(ActiveCount) / static_cast(std::max(1u, KVBLOCKS)); + const bool use_gqa2_specialized = (Hq == (Hkv << 1)); + const bool use_gqa2_dense = use_gqa2_specialized && (ActiveCount == KVBLOCKS); + const bool use_gqa2_dualhead = enable_gqa2_dualhead_specialization() && use_gqa2_specialized && (density <= 0.75f) && (T >= 256u); + const uint32_t tuned_tg = get_tg_size(); + if (use_gqa2_dualhead) { + dispatch_fast_kernel_dh64_bs4_gqa2_dualhead( + rt.pipeline_dh64_bs4_gqa2_dualhead, + rt.thread_execution_width_dh64_bs4_gqa2_dualhead, + q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast_active", tuned_tg + ); + } else if (use_gqa2_dense) { + dispatch_fast_kernel( + rt.pipeline_dh64_bs4_gqa2_dense, + rt.thread_execution_width_dh64_bs4_gqa2_dense, + q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast_active", tuned_tg + ); + } else if (use_gqa2_specialized) { + dispatch_fast_kernel( + rt.pipeline_dh64_bs4_gqa2_single, + rt.thread_execution_width_dh64_bs4_gqa2_single, + q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast_active", tuned_tg + ); + } else { + dispatch_fast_kernel( + rt.pipeline_dh64_bs4_single, rt.thread_execution_width_dh64_bs4_single, q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast_active", tuned_tg + ); + } + } else { + dispatch_fast_kernel( + rt.pipeline_generic, rt.thread_execution_width_generic, q, k, v, active_blocks, out, + B, Hq, T, L, Dh, BlockSize, ActiveCount, Causal, Hkv, + FP16Accum, + "flex_attn_metal_fast_active", 0u + ); + } + return out; +} + +at::Tensor metal_flex_attn_fast_dispatch_from_mask_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const c10::optional& mask, + bool causal +) { + TORCH_CHECK(q.device().is_mps() && k.device().is_mps() && v.device().is_mps(), + "flex_attn_metal_fast expects q/k/v on MPS"); + TORCH_CHECK(q.scalar_type() == at::kHalf && k.scalar_type() == at::kHalf && v.scalar_type() == at::kHalf, + "flex_attn_metal_fast currently supports float16 only"); + TORCH_CHECK(q.is_contiguous() && k.is_contiguous() && v.is_contiguous(), + "flex_attn_metal_fast expects contiguous q/k/v"); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must match"); + TORCH_CHECK(q.size(0) == k.size(0) && q.size(3) == k.size(3), + "q/k must match on batch and head dim"); + TORCH_CHECK(q.size(1) >= k.size(1), "q heads must be >= kv heads"); + TORCH_CHECK((q.size(1) % k.size(1)) == 0, "q heads must be divisible by kv heads for GQA"); + + const uint32_t B = static_cast(q.size(0)); + const uint32_t Hq = static_cast(q.size(1)); + const uint32_t Hkv = static_cast(k.size(1)); + const uint32_t T = static_cast(q.size(2)); + const uint32_t L = static_cast(k.size(2)); + const uint32_t Dh = static_cast(q.size(3)); + const uint32_t BlockSize = static_cast(get_block_size()); + TORCH_CHECK(BlockSize > 0, "WORLD_METAL_BLOCK_SIZE must be > 0"); + const uint32_t KVBLOCKS = (L + BlockSize - 1) / BlockSize; + const uint32_t Causal = causal ? 
1u : 0u; + + at::Tensor mask_tensor; + at::Tensor block_written; + if (mask.has_value()) { + mask_tensor = *mask; + TORCH_CHECK(mask_tensor.device().is_mps(), "mask must be on MPS"); + TORCH_CHECK(mask_tensor.scalar_type() == at::kByte, "mask must be uint8"); + TORCH_CHECK(mask_tensor.is_contiguous(), "mask must be contiguous"); + TORCH_CHECK(mask_tensor.numel() == q.size(0) * q.size(1) * q.size(2) * k.size(2), + "mask must have shape [B,H,T,L]"); + // Fast-kernel contract today: a single shared, block-wise mask state. + // We enforce this explicitly to avoid silent semantic drift. + auto row = mask_tensor.index({0, 0, 0}).contiguous(); // [L] + auto shared_ok = mask_tensor.eq(row.view({1, 1, 1, static_cast(L)})).all().item(); + TORCH_CHECK( + shared_ok, + "flex_attn_metal_fast expects a shared mask across batch/head/query dimensions" + ); + + const int64_t full_blocks = static_cast(L / BlockSize); + const int64_t rem = static_cast(L % BlockSize); + at::Tensor block_vals; + + if (full_blocks > 0) { + auto prefix = row.slice(/*dim=*/0, /*start=*/0, /*end=*/full_blocks * static_cast(BlockSize)); + auto blocks2d = prefix.view({full_blocks, static_cast(BlockSize)}); + auto first = blocks2d.index({at::indexing::Slice(), 0}).unsqueeze(1); + auto full_ok = blocks2d.eq(first).all().item(); + TORCH_CHECK( + full_ok, + "flex_attn_metal_fast expects block-wise mask values (constant within each block)" + ); + block_vals = first.squeeze(1).to(at::kByte); + } else { + block_vals = at::empty({0}, q.options().dtype(at::kByte)); + } + + if (rem > 0) { + auto tail = row.slice(/*dim=*/0, /*start=*/full_blocks * static_cast(BlockSize), /*end=*/static_cast(L)); + auto tail_first = tail.index({0}); + auto tail_ok = tail.eq(tail_first).all().item(); + TORCH_CHECK( + tail_ok, + "flex_attn_metal_fast expects block-wise mask values (constant within each block)" + ); + auto tail_val = tail_first.to(at::kByte).view({1}); + block_vals = (block_vals.numel() > 0) ? at::cat({block_vals, tail_val}, /*dim=*/0) : tail_val; + } + + block_written = block_vals.contiguous(); + } else { + block_written = at::ones({static_cast(KVBLOCKS)}, q.options().dtype(at::kByte)).contiguous(); + } + + return metal_flex_attn_fast_dispatch_impl(q, k, v, block_written, static_cast(BlockSize), causal); +} + +at::Tensor metal_flex_attn_fast_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const c10::optional& mask, + bool causal +) { + if (fast_no_fallback()) { + return metal_flex_attn_fast_dispatch_from_mask_impl(q, k, v, mask, causal); + } + // Keep ref as a safety net while fast path stabilizes. + try { + return metal_flex_attn_fast_dispatch_from_mask_impl(q, k, v, mask, causal); + } catch (...) { + return metal_flex_attn_ref_impl(q, k, v, mask, causal); + } +} + +at::Tensor metal_flex_attn_fast_blocks_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& block_written, + int64_t block_size, + bool causal +) { + if (fast_no_fallback()) { + return metal_flex_attn_fast_dispatch_impl(q, k, v, block_written, block_size, causal); + } + try { + return metal_flex_attn_fast_dispatch_impl(q, k, v, block_written, block_size, causal); + } catch (...) { + // Reconstruct dense mask for reference fallback. 
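+        // This rebuild expands the block list into a dense [B, Hq, T, L]
+        // mask and reruns the ATen reference path, so it is a correctness
+        // fallback rather than a fast path.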
+ const int64_t B = q.size(0); + const int64_t Hq = q.size(1); + const int64_t T = q.size(2); + const int64_t L = k.size(2); + at::Tensor dense = at::zeros({L}, q.options().dtype(at::kByte)); + for (int64_t b = 0; b < block_written.numel(); ++b) { + if (block_written.index({b}).item() != 0) { + const int64_t s = b * block_size; + const int64_t e = std::min(L, s + block_size); + dense.index_put_({at::indexing::Slice(s, e)}, 1); + } + } + auto dense4d = dense.view({1, 1, 1, L}).expand({B, Hq, T, L}).contiguous(); + return metal_flex_attn_ref_impl(q, k, v, dense4d, causal); + } +} + +at::Tensor metal_flex_attn_fast_active_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& active_blocks, + int64_t block_size, + bool causal +) { + if (fast_no_fallback()) { + return metal_flex_attn_fast_dispatch_active_impl(q, k, v, active_blocks, block_size, causal); + } + try { + return metal_flex_attn_fast_dispatch_active_impl(q, k, v, active_blocks, block_size, causal); + } catch (...) { + // Reconstruct dense mask for reference fallback. + const int64_t B = q.size(0); + const int64_t Hq = q.size(1); + const int64_t T = q.size(2); + const int64_t L = k.size(2); + const int64_t kv_blocks = (L + block_size - 1) / block_size; + at::Tensor bw = at::zeros({kv_blocks}, q.options().dtype(at::kByte)); + for (int64_t i = 0; i < active_blocks.numel(); ++i) { + const int64_t bi = active_blocks.index({i}).item(); + if (bi >= 0 && bi < kv_blocks) { + bw.index_put_({bi}, 1); + } + } + at::Tensor dense = at::zeros({L}, q.options().dtype(at::kByte)); + for (int64_t b = 0; b < bw.numel(); ++b) { + if (bw.index({b}).item() != 0) { + const int64_t s = b * block_size; + const int64_t e = std::min(L, s + block_size); + dense.index_put_({at::indexing::Slice(s, e)}, 1); + } + } + auto dense4d = dense.view({1, 1, 1, L}).expand({B, Hq, T, L}).contiguous(); + return metal_flex_attn_ref_impl(q, k, v, dense4d, causal); + } +} + +at::Tensor metal_flex_attn_impl( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const c10::optional& mask, + bool causal +) { + // Backward-compatible alias; default to ref behavior. + return metal_flex_attn_ref_impl(q, k, v, mask, causal); +} + +} // namespace + +TORCH_LIBRARY(world, m) { + m.def("flex_attn_metal(Tensor q, Tensor k, Tensor v, Tensor? mask=None, bool causal=True) -> Tensor"); + m.def("flex_attn_metal_ref(Tensor q, Tensor k, Tensor v, Tensor? mask=None, bool causal=True) -> Tensor"); + m.def("flex_attn_metal_fast(Tensor q, Tensor k, Tensor v, Tensor? 
mask=None, bool causal=True) -> Tensor"); + m.def("flex_attn_metal_fast_blocks(Tensor q, Tensor k, Tensor v, Tensor block_written, int block_size, bool causal=True) -> Tensor"); + m.def("flex_attn_metal_fast_active(Tensor q, Tensor k, Tensor v, Tensor active_blocks, int block_size, bool causal=True) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(world, MPS, m) { + m.impl("flex_attn_metal", &metal_flex_attn_impl); + m.impl("flex_attn_metal_ref", &metal_flex_attn_ref_impl); + m.impl("flex_attn_metal_fast", &metal_flex_attn_fast_impl); + m.impl("flex_attn_metal_fast_blocks", &metal_flex_attn_fast_blocks_impl); + m.impl("flex_attn_metal_fast_active", &metal_flex_attn_fast_active_impl); +} + diff --git a/src/model/attn.py b/src/model/attn.py index c6024e8..b598cbd 100644 --- a/src/model/attn.py +++ b/src/model/attn.py @@ -2,11 +2,10 @@ import einops as eo from torch import nn -from torch.nn.attention.flex_attention import flex_attention - from rotary_embedding_torch import RotaryEmbedding from .nn import rms_norm, NoCastModule +from .attn_backend import AttnConfig, AttnMeta, world_flex_attn_forward class OrthoRoPEAngles(NoCastModule): @@ -110,10 +109,19 @@ def forward(self, x, pos_ids, rope_angles, v1, kv_cache): q, k = self.rope(q, rope_angles), self.rope(k, rope_angles) # Update KV-cache in-place - k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx) - - # SDPA -> Attention Gate -> Out Proj - y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa) + k, v, bm, block_written, active_blocks, block_size = kv_cache.upsert(k, v, pos_ids, self.layer_idx) + + # SDPA/Flex/Metal attention -> Attention Gate -> Out Proj + meta = AttnMeta( + flex_block_mask=bm, + q_len=q.size(2), + kv_len=k.size(2), + block_written=block_written, + active_blocks=active_blocks, + block_size=block_size, + ) + cfg = AttnConfig(causal=True, enable_gqa=self.enable_gqa) + y = world_flex_attn_forward(q, k, v, meta, cfg) if self.gated_attn: gates = torch.sigmoid(self.gate_proj(x[..., :self.n_heads])) y = y * gates.permute(0, 2, 1).unsqueeze(-1) @@ -143,6 +151,12 @@ def forward(self, x, context, context_pad_mask=None): k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads) v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads) q, k = rms_norm(q), rms_norm(k) - out = flex_attention(q, k, v) + meta = AttnMeta( + flex_block_mask=None, + q_len=q.size(2), + kv_len=k.size(2), + ) + cfg = AttnConfig(causal=False, enable_gqa=False) + out = world_flex_attn_forward(q, k, v, meta, cfg) out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1) return self.out_proj(out) diff --git a/src/model/attn_backend.py b/src/model/attn_backend.py new file mode 100644 index 0000000..5b02d4d --- /dev/null +++ b/src/model/attn_backend.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Optional +import os + +import torch +from torch import Tensor + +from torch.nn.attention.flex_attention import flex_attention + +try: + import torch.library # type: ignore[attr-defined] + + _HAS_METAL_BACKEND = True +except Exception: + _HAS_METAL_BACKEND = False + + +class AttnBackend(str, Enum): + PYTORCH_FLEX = "pytorch-flex" + METAL = "metal-op" + AUTO = "auto" + + @staticmethod + def default() -> "AttnBackend": + """ + Resolve the default backend from the WORLD_ATTENTION_BACKEND env var. 
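+
+        Recognized values: "metal" selects AttnBackend.METAL, "auto" selects
+        AttnBackend.AUTO, and any other value (including an unset variable)
+        falls back to AttnBackend.PYTORCH_FLEX.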
+ """ + import os + + value = os.getenv("WORLD_ATTENTION_BACKEND", "pytorch-flex").lower() + if value == "metal": + return AttnBackend.METAL + if value == "auto": + return AttnBackend.AUTO + return AttnBackend.PYTORCH_FLEX + + +def _metal_impl_mode() -> str: + # WORLD_METAL_IMPL=ref|fast + mode = os.getenv("WORLD_METAL_IMPL", "ref").lower() + return "fast" if mode == "fast" else "ref" + + +@dataclass +class AttnConfig: + """ + Backend-agnostic attention configuration. + + This object is intentionally small and forward-only: it encodes only what + the kernel needs at runtime. Training- and autograd-specific concerns are + out of scope for the hybrid Metal inference path. + """ + + causal: bool = True + enable_gqa: bool = False + + +@dataclass +class AttnMeta: + """ + Backend-agnostic metadata describing the KV layout for a single attention + call. This is the hook where we will eventually encode block/window + sparsity and cache positions for the Metal kernel. + + For the initial implementation, we allow passing the existing BlockMask + object through as `flex_block_mask` so the PyTorch flex backend can keep + working while we design a compact Metal-friendly format. In parallel we + expose basic sequence length information that the Metal backend will use + to size its tiles. + """ + + # Optional flex BlockMask used by the PyTorch flex backend today. + flex_block_mask: Optional[object] = None + + # Logical query and KV lengths for this attention call. + q_len: Optional[int] = None + kv_len: Optional[int] = None + block_written: Optional[Tensor] = None + active_blocks: Optional[Tensor] = None + block_size: Optional[int] = None + + # Future fields for the Metal backend (block size, bucket indices, validity + # masks, etc.) will live here as we iterate on the sparsity encoding. + + +def world_flex_attn_forward( + q: Tensor, + k: Tensor, + v: Tensor, + meta: Optional[AttnMeta], + cfg: AttnConfig, + backend: AttnBackend = AttnBackend.default(), +) -> Tensor: + """ + Backend-neutral attention entrypoint used by high-level modules. + + Args: + q, k, v: [B, H, T, Dh] tensors on the same device (MPS, CUDA, etc.). + meta: Backend-agnostic metadata describing KV/cache layout. + cfg: Small configuration object with behavioral flags. + backend: + - PYTORCH_FLEX: call torch.nn.attention.flex_attention directly. + - METAL: call the custom Metal op (to be added). + - AUTO: choose PYTORCH_FLEX or METAL based on device / + availability. 
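+
+    Example (minimal dense, non-causal call; tensor shapes are illustrative):
+        meta = AttnMeta(q_len=q.size(2), kv_len=k.size(2))
+        cfg = AttnConfig(causal=False, enable_gqa=False)
+        out = world_flex_attn_forward(q, k, v, meta, cfg)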
+ """ + if backend is AttnBackend.AUTO: + backend = AttnBackend.METAL if q.device.type == "mps" else AttnBackend.PYTORCH_FLEX + + if backend is AttnBackend.PYTORCH_FLEX: + block_mask = meta.flex_block_mask if meta is not None else None + return flex_attention(q, k, v, block_mask=block_mask, enable_gqa=cfg.enable_gqa) + + if backend is AttnBackend.METAL: + if not _HAS_METAL_BACKEND: + raise RuntimeError("Metal attention backend requested but custom op is not available") + mask = None + mode = _metal_impl_mode() + if mode == "fast": + if meta is not None and meta.active_blocks is not None and meta.block_size is not None: + return torch.ops.world.flex_attn_metal_fast_active( + q, k, v, meta.active_blocks, int(meta.block_size), cfg.causal + ) + if meta is not None and meta.block_written is not None and meta.block_size is not None: + return torch.ops.world.flex_attn_metal_fast_blocks( + q, k, v, meta.block_written, int(meta.block_size), cfg.causal + ) + return torch.ops.world.flex_attn_metal_fast(q, k, v, mask, cfg.causal) + return torch.ops.world.flex_attn_metal_ref(q, k, v, mask, cfg.causal) + + raise ValueError(f"Unknown attention backend: {backend}") + diff --git a/src/model/kv_cache.py b/src/model/kv_cache.py index 898b8a9..ff81b72 100644 --- a/src/model/kv_cache.py +++ b/src/model/kv_cache.py @@ -2,6 +2,7 @@ import torch from torch import nn from tensordict import TensorDict +import os from torch.nn.attention.flex_attention import ( _DEFAULT_SPARSE_BLOCK_SIZE, @@ -9,7 +10,7 @@ ) -def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask: +def make_block_mask(T: int, L: int, written: torch.Tensor): """ T: Q length for this frame L: KV capacity == written.numel() @@ -58,7 +59,24 @@ def mask_mod(b, h, q, kv): compute_q_blocks=False, # no backward, avoids the transpose/_ordered_to_dense path ) - return bm + return bm, block_any.contiguous() + + +def _block_any_for_size(written: torch.Tensor, block_size: int) -> torch.Tensor: + kv_blocks = (written.numel() + block_size - 1) // block_size + padded = torch.nn.functional.pad(written, (0, kv_blocks * block_size - written.numel())) + return padded.view(kv_blocks, block_size).any(-1).contiguous() + + +def _metal_block_size() -> int: + env = os.environ.get("WORLD_METAL_BLOCK_SIZE") + if env is None: + return 4 + try: + parsed = int(env) + return parsed if parsed > 0 else 4 + except ValueError: + return 4 class LayerKVCache(nn.Module): @@ -136,7 +154,10 @@ def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool): mask_written = self._mask_written mask_written.copy_(self.written) mask_written[ring_idx] = mask_written[ring_idx] & ~write_step - bm = make_block_mask(T, self.capacity, mask_written) + bm, _ = make_block_mask(T, self.capacity, mask_written) + metal_bs = _metal_block_size() + block_any = _block_any_for_size(mask_written, metal_bs) + active_blocks = torch.nonzero(block_any, as_tuple=False).flatten().to(torch.int32).contiguous() # Persist current frame into the ring for future queries when unfrozen. 
if not is_frozen: @@ -146,7 +167,7 @@ def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool): self.written[dst] = True k, v = self.kv.unbind(0) - return k, v, bm + return k, v, bm, block_any.to(torch.uint8), active_blocks, metal_bs class StaticKVCache(nn.Module): diff --git a/src/patch_model.py b/src/patch_model.py index 817f220..ff55e09 100644 --- a/src/patch_model.py +++ b/src/patch_model.py @@ -5,7 +5,7 @@ from .model.nn import rms_norm from .model.attn import Attn from .model.world_model import MLPFusion -from torch.nn.attention.flex_attention import flex_attention +from .model.attn_backend import AttnConfig, AttnMeta, world_flex_attn_forward def _bf16_u16(x: Tensor) -> Tensor: @@ -128,9 +128,18 @@ def forward(self, x, pos_ids, rope_angles, v1, kv_cache): q, k = rms_norm(q), rms_norm(k) q, k = self.rope(q, rope_angles), self.rope(k, rope_angles) - k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx) + k, v, bm, block_written, active_blocks, block_size = kv_cache.upsert(k, v, pos_ids, self.layer_idx) - y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa) + meta = AttnMeta( + flex_block_mask=bm, + q_len=q.size(2), + kv_len=k.size(2), + block_written=block_written, + active_blocks=active_blocks, + block_size=block_size, + ) + cfg = AttnConfig(causal=True, enable_gqa=self.enable_gqa) + y = world_flex_attn_forward(q, k, v, meta, cfg) if self.gated_attn: gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads])) @@ -183,3 +192,4 @@ def apply_inference_patches(model) -> None: patch_cached_noise_conditioning(model) patch_Attn_merge_qkv(model) patch_MLPFusion_split(model) + diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d1535e2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from pathlib import Path +import os +import math + +import pytest +import torch +from torch.utils.cpp_extension import load + + +_EXT_NAME = "world_metal_attn_ext" +_EXT_BUILT = False +_FALLBACK_REGISTERED = False + + +def _reference_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor | None, + causal: bool, +) -> torch.Tensor: + qf = q.to(torch.float32) + kf = k.to(torch.float32) + vf = v.to(torch.float32) + + if qf.size(1) != kf.size(1): + if qf.size(1) < kf.size(1) or (qf.size(1) % kf.size(1)) != 0: + raise RuntimeError("GQA requires q_heads divisible by kv_heads") + group_size = qf.size(1) // kf.size(1) + head_idx = torch.arange(qf.size(1), device=q.device, dtype=torch.long) // group_size + kf = kf.index_select(1, head_idx) + vf = vf.index_select(1, head_idx) + + scores = torch.matmul(qf, kf.transpose(-2, -1)) / math.sqrt(q.size(-1)) + + if mask is not None: + scores = scores.masked_fill(mask == 0, float("-inf")) + if causal: + t = q.size(-2) + l = k.size(-2) + causal_mask = torch.triu( + torch.ones((t, l), device=q.device, dtype=torch.bool), + diagonal=1, + ) + scores = scores.masked_fill(causal_mask[None, None], float("-inf")) + + finite_row = torch.isfinite(scores).any(dim=-1, keepdim=True) + safe_scores = torch.where(finite_row, scores, torch.zeros_like(scores)) + probs = torch.softmax(safe_scores, dim=-1) + probs = torch.where(finite_row, probs, torch.zeros_like(probs)) + out = torch.matmul(probs, vf) + return out.to(q.dtype) + + +def _register_python_fallback_op() -> None: + global _FALLBACK_REGISTERED + if _FALLBACK_REGISTERED: + return + + try: + lib = torch.library.Library("world", "DEF") + lib.define("flex_attn_metal(Tensor q, Tensor k, Tensor v, 
Tensor? mask=None, bool causal=True) -> Tensor") + lib.define("flex_attn_metal_ref(Tensor q, Tensor k, Tensor v, Tensor? mask=None, bool causal=True) -> Tensor") + lib.define("flex_attn_metal_fast(Tensor q, Tensor k, Tensor v, Tensor? mask=None, bool causal=True) -> Tensor") + lib.define("flex_attn_metal_fast_blocks(Tensor q, Tensor k, Tensor v, Tensor block_written, int block_size, bool causal=True) -> Tensor") + lib.define("flex_attn_metal_fast_active(Tensor q, Tensor k, Tensor v, Tensor active_blocks, int block_size, bool causal=True) -> Tensor") + except Exception: + # Signature may already be defined by another registration path. + pass + + impl = torch.library.Library("world", "IMPL", "CompositeExplicitAutograd") + fn = lambda q, k, v, mask=None, causal=True: _reference_attention(q, k, v, mask, bool(causal)) + impl.impl("flex_attn_metal", fn) + impl.impl("flex_attn_metal_ref", fn) + impl.impl("flex_attn_metal_fast", fn) + impl.impl( + "flex_attn_metal_fast_blocks", + lambda q, k, v, block_written, block_size, causal=True: _reference_attention( + q, + k, + v, + # Build dense mask from block_written for fallback semantics. + torch.cat( + [ + torch.full( + (int(block_size),), + int(block_written[i].item() != 0), + device=q.device, + dtype=torch.uint8, + ) + for i in range(block_written.numel()) + ], + dim=0, + )[: k.size(-2)].view(1, 1, 1, k.size(-2)).expand(q.size(0), q.size(1), q.size(2), k.size(-2)).contiguous(), + bool(causal), + ), + ) + impl.impl( + "flex_attn_metal_fast_active", + lambda q, k, v, active_blocks, block_size, causal=True: _reference_attention( + q, + k, + v, + ( + torch.zeros((k.size(-2),), device=q.device, dtype=torch.uint8) + .index_fill( + 0, + ( + torch.cat( + [ + torch.arange( + int(b.item()) * int(block_size), + min(k.size(-2), int(b.item()) * int(block_size) + int(block_size)), + device=q.device, + dtype=torch.long, + ) + for b in active_blocks + ], + dim=0, + ) + if active_blocks.numel() > 0 + else torch.empty((0,), device=q.device, dtype=torch.long) + ), + 1, + ) + .view(1, 1, 1, k.size(-2)) + .expand(q.size(0), q.size(1), q.size(2), k.size(-2)) + .contiguous() + ), + bool(causal), + ), + ) + _FALLBACK_REGISTERED = True + + +def _load_metal_attention_extension() -> None: + global _EXT_BUILT + if _EXT_BUILT: + return + + if not torch.backends.mps.is_available(): + # Let MPS-gated tests skip naturally. + return + + source = Path(__file__).resolve().parents[1] / "src" / "metal" / "metal_flex_attn_op.mm" + if not source.exists(): + raise FileNotFoundError(f"Missing Metal extension source: {source}") + + build_dir = Path(__file__).resolve().parents[1] / ".build" / "torch_extensions" + build_dir.mkdir(parents=True, exist_ok=True) + + repo_root = Path(__file__).resolve().parents[1] + venv_bin = repo_root / ".venv" / "bin" + + # Make the extension cache deterministic in this repo. + os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(build_dir)) + # torch.utils.cpp_extension shells out to `ninja`; ensure the venv binary + # is discoverable even when PATH is inherited from the host shell. 
+ if venv_bin.exists(): + os.environ["PATH"] = f"{venv_bin}:{os.environ.get('PATH', '')}" + os.environ.setdefault("NINJA", str(venv_bin / "ninja")) + + try: + load( + name=_EXT_NAME, + sources=[str(source)], + extra_cflags=["-std=c++17"], + extra_ldflags=["-framework", "Metal", "-framework", "Foundation"], + with_cuda=False, + is_python_module=False, + verbose=False, + ) + _EXT_BUILT = True + except Exception: + # Keep tests executable on environments where the ObjC++ binding is not + # yet compatible with the installed torch MPS headers. + _register_python_fallback_op() + + +@pytest.hookimpl(tryfirst=True) +def pytest_sessionstart(session): # noqa: D401 - pytest hook signature + _load_metal_attention_extension() + diff --git a/tests/test_attn_module_integration.py b/tests/test_attn_module_integration.py new file mode 100644 index 0000000..9e947f9 --- /dev/null +++ b/tests/test_attn_module_integration.py @@ -0,0 +1,119 @@ +from pathlib import Path +import sys + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src" / "model")) + +from kv_cache import LayerKVCache +from attn_backend import AttnBackend, AttnConfig, AttnMeta, world_flex_attn_forward + + +pytestmark = pytest.mark.skipif( + not torch.backends.mps.is_available(), + reason="MPS backend not available on this system", +) + + +def _require_metal_ops(): + if not hasattr(torch.ops, "world"): + pytest.skip("Metal world namespace not registered") + required = ["flex_attn_metal_ref", "flex_attn_metal_fast", "flex_attn_metal_fast_blocks", "flex_attn_metal_fast_active"] + if not all(hasattr(torch.ops.world, name) for name in required): + pytest.skip("Required Metal ops not registered") + + +def _pos_ids(frame_idx: int, B: int, T: int, device: str): + return {"f_pos": torch.full((B, T), frame_idx, device=device, dtype=torch.long)} + + +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("gqa", [False, True]) +def test_kv_cache_to_backend_path_matches_ref(causal, gqa, monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B = 1 + T = 8 + Dh = 64 + Hq = 8 + Hkv = 2 if gqa else Hq + L_hist = 32 + block_size = 4 + + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + kf = torch.randn(B, Hkv, T, Dh, device="mps", dtype=torch.float16) + vf = torch.randn(B, Hkv, T, Dh, device="mps", dtype=torch.float16) + + cache = LayerKVCache(B, Hkv, L_hist, Dh, torch.float16, T).to("mps") + # Write one frame to establish rolling state. + _ = cache.upsert(torch.stack([kf, vf], dim=0), _pos_ids(0, B, T, "mps"), is_frozen=False) + # Read/update next frame. + k, v, _bm, block_written, active_blocks, bs = cache.upsert( + torch.stack([kf, vf], dim=0), _pos_ids(1, B, T, "mps"), is_frozen=False + ) + + # Direct block-written fast path. + out_fast_blocks = torch.ops.world.flex_attn_metal_fast_blocks(q, k, v, block_written, int(bs), causal) + out_fast_active = torch.ops.world.flex_attn_metal_fast_active(q, k, v, active_blocks, int(bs), causal) + + # Dense-mask reference from block-written metadata. 
+ dense = torch.zeros((k.size(2),), device="mps", dtype=torch.uint8) + for i in range(block_written.numel()): + if int(block_written[i].item()) != 0: + s = i * int(bs) + e = min(k.size(2), s + int(bs)) + dense[s:e] = 1 + dense_mask = dense.view(1, 1, 1, k.size(2)).expand(B, Hq, T, k.size(2)).contiguous() + out_ref = torch.ops.world.flex_attn_metal_ref(q, k, v, dense_mask, causal) + + assert out_fast_blocks.shape == out_ref.shape + assert torch.allclose( + out_fast_blocks.to("cpu", dtype=torch.float32), + out_ref.to("cpu", dtype=torch.float32), + atol=3e-2, + rtol=3e-2, + ) + assert torch.allclose( + out_fast_active.to("cpu", dtype=torch.float32), + out_ref.to("cpu", dtype=torch.float32), + atol=3e-2, + rtol=3e-2, + ) + + +def test_world_flex_attn_forward_prefers_block_metadata(monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_IMPL", "fast") + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = 1, 8, 2, 8, 24, 64 + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + block_written = torch.tensor([(i % 2) == 0 for i in range(kv_blocks)], device="mps", dtype=torch.uint8).contiguous() + active_blocks = torch.nonzero(block_written, as_tuple=False).flatten().to(torch.int32).contiguous() + + meta = AttnMeta( + flex_block_mask=None, + q_len=T, + kv_len=L, + block_written=block_written, + active_blocks=active_blocks, + block_size=block_size, + ) + cfg = AttnConfig(causal=True, enable_gqa=True) + out = world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL) + + direct = torch.ops.world.flex_attn_metal_fast_active(q, k, v, active_blocks, block_size, True) + assert torch.allclose( + out.to("cpu", dtype=torch.float32), + direct.to("cpu", dtype=torch.float32), + atol=1e-4, + rtol=1e-4, + ) + diff --git a/tests/test_metal_attn_numeric.py b/tests/test_metal_attn_numeric.py new file mode 100644 index 0000000..a712ab0 --- /dev/null +++ b/tests/test_metal_attn_numeric.py @@ -0,0 +1,540 @@ +from pathlib import Path +import sys +import math +import random + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src" / "model")) + +from attn_backend import ( + AttnBackend, + AttnConfig, + AttnMeta, + world_flex_attn_forward, +) + + +pytestmark = pytest.mark.skipif( + not torch.backends.mps.is_available(), + reason="MPS backend not available on this system", +) + + +def _rand_attn_tensors(B: int, H: int, T: int, L: int, Dh: int, dtype: torch.dtype): + q = torch.randn(B, H, T, Dh, device="mps", dtype=dtype) + k = torch.randn(B, H, L, Dh, device="mps", dtype=dtype) + v = torch.randn(B, H, L, Dh, device="mps", dtype=dtype) + return q, k, v + + +def _require_metal_op(): + if not hasattr(torch.ops, "world"): + pytest.skip("Metal world namespace not registered") + if not ( + hasattr(torch.ops.world, "flex_attn_metal_ref") + and hasattr(torch.ops.world, "flex_attn_metal_fast") + and hasattr(torch.ops.world, "flex_attn_metal_fast_blocks") + and hasattr(torch.ops.world, "flex_attn_metal_fast_active") + ): + pytest.skip("Metal ref/fast/fast_blocks/fast_active ops not registered") + + +def _reference_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool, + mask: torch.Tensor | None = None, +) -> torch.Tensor: + # Explicit SDPA reference that does not depend on flex_attention. 
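+    # i.e. out = softmax(Q K^T / sqrt(Dh)) V computed in fp32, with masked and
+    # future (causal) positions set to -inf before the softmax, KV heads
+    # repeated to match Q heads for GQA, and fully masked rows defined as zero.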
+ qf = q.to(torch.float32) + kf = k.to(torch.float32) + vf = v.to(torch.float32) + + if qf.size(1) != kf.size(1): + if qf.size(1) < kf.size(1) or (qf.size(1) % kf.size(1)) != 0: + raise RuntimeError("GQA requires q_heads divisible by kv_heads") + group_size = qf.size(1) // kf.size(1) + head_idx = torch.arange(qf.size(1), device=q.device, dtype=torch.long) // group_size + kf = kf.index_select(1, head_idx) + vf = vf.index_select(1, head_idx) + + scores = torch.matmul(qf, kf.transpose(-2, -1)) / math.sqrt(q.size(-1)) + + if mask is not None: + scores = scores.masked_fill(mask == 0, float("-inf")) + if causal: + t = q.size(-2) + l = k.size(-2) + causal_mask = torch.triu( + torch.ones((t, l), device=q.device, dtype=torch.bool), + diagonal=1, + ) + scores = scores.masked_fill(causal_mask[None, None], float("-inf")) + + # If a row is fully masked, define output as zero (to match kernel behavior). + finite_row = torch.isfinite(scores).any(dim=-1, keepdim=True) + safe_scores = torch.where(finite_row, scores, torch.zeros_like(scores)) + probs = torch.softmax(safe_scores, dim=-1) + probs = torch.where(finite_row, probs, torch.zeros_like(probs)) + out = torch.matmul(probs, vf) + return out.to(q.dtype) + + +def _dense_mask_from_block_written( + block_written: torch.Tensor, + t: int, + l: int, + block_size: int, + device: torch.device, +) -> torch.Tensor: + """ + Convert a 1D block-written mask [KV_blocks] into dense [1,1,T,L] uint8 mask. + This mirrors Andrew's guidance: kernel consumes frame length, total length, + and block-wise written state. + """ + dense = torch.zeros((l,), device=device, dtype=torch.uint8) + for bidx, is_written in enumerate(block_written.tolist()): + if is_written: + s = bidx * block_size + e = min(l, s + block_size) + dense[s:e] = 1 + return dense.view(1, 1, 1, l).expand(1, 1, t, l).contiguous() + + +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_metal_vs_reference_small_random(dtype): + _require_metal_op() + + B, H, T, L, Dh = 1, 2, 8, 8, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, dtype) + + ref_out = _reference_attention(q, k, v, causal=False, mask=None) + metal_out = world_flex_attn_forward( + q, + k, + v, + AttnMeta(flex_block_mask=None, q_len=T, kv_len=L), + AttnConfig(causal=False, enable_gqa=False), + backend=AttnBackend.METAL, + ) + + ref_cpu = ref_out.to("cpu", dtype=torch.float32) + metal_cpu = metal_out.to("cpu", dtype=torch.float32) + + max_abs_diff = (ref_cpu - metal_cpu).abs().max().item() + mean_abs_diff = (ref_cpu - metal_cpu).abs().mean().item() + + assert max_abs_diff < 2e-1 + assert mean_abs_diff < 2e-2 + + +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_metal_vs_reference_small_random_causal(dtype): + _require_metal_op() + + B, H, T, L, Dh = 1, 2, 8, 8, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, dtype) + + ref_out = _reference_attention(q, k, v, causal=True, mask=None) + metal_out = world_flex_attn_forward( + q, + k, + v, + AttnMeta(flex_block_mask=None, q_len=T, kv_len=L), + AttnConfig(causal=True, enable_gqa=False), + backend=AttnBackend.METAL, + ) + + ref_cpu = ref_out.to("cpu", dtype=torch.float32) + metal_cpu = metal_out.to("cpu", dtype=torch.float32) + + max_abs_diff = (ref_cpu - metal_cpu).abs().max().item() + mean_abs_diff = (ref_cpu - metal_cpu).abs().mean().item() + + assert max_abs_diff < 2e-1 + assert mean_abs_diff < 2e-2 + + +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_metal_mask_all_ones_and_all_zeros(dtype): + _require_metal_op() + + B, H, T, L, Dh = 1, 2, 8, 8, 64 + q, k, v = 
_rand_attn_tensors(B, H, T, L, Dh, dtype) + + ones = torch.ones((B, H, T, L), device="mps", dtype=torch.uint8).contiguous() + zeros = torch.zeros((B, H, T, L), device="mps", dtype=torch.uint8).contiguous() + + out_no_mask = torch.ops.world.flex_attn_metal(q, k, v, None, False) + out_ones = torch.ops.world.flex_attn_metal(q, k, v, ones, False) + out_zeros = torch.ops.world.flex_attn_metal(q, k, v, zeros, False) + + no_mask_cpu = out_no_mask.to("cpu", dtype=torch.float32) + ones_cpu = out_ones.to("cpu", dtype=torch.float32) + zeros_cpu = out_zeros.to("cpu", dtype=torch.float32) + ref_zero = _reference_attention(q, k, v, causal=False, mask=zeros).to("cpu", dtype=torch.float32) + + assert torch.allclose(no_mask_cpu, ones_cpu, rtol=1e-2, atol=1e-2) + assert torch.allclose(zeros_cpu, torch.zeros_like(zeros_cpu), rtol=0.0, atol=1e-6) + assert torch.allclose(zeros_cpu, ref_zero, rtol=0.0, atol=1e-6) + + +@pytest.mark.parametrize("mode", ["ref", "fast"]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize( + "shape", + [ + (1, 8, 2, 16, 16, 64), + (1, 12, 3, 12, 16, 64), + (1, 16, 4, 32, 32, 64), + ], +) +def test_gqa_metal_impl_matches_reference(shape, causal, mode, monkeypatch): + _require_metal_op() + B, Hq, Hkv, T, L, Dh = shape + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + dense_mask = torch.ones((B, Hq, T, L), device="mps", dtype=torch.uint8).contiguous() + + ref = _reference_attention(q, k, v, causal=causal, mask=dense_mask) + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + if mode == "fast": + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + out = torch.ops.world.flex_attn_metal_ref(q, k, v, dense_mask, causal) if mode == "ref" else torch.ops.world.flex_attn_metal_fast(q, k, v, dense_mask, causal) + + assert torch.allclose( + out.to("cpu", dtype=torch.float32), + ref.to("cpu", dtype=torch.float32), + rtol=3e-2, + atol=3e-2, + ) + + +@pytest.mark.parametrize("mode", ["ref", "fast"]) +def test_world_flex_attn_forward_gqa_executes(mode, monkeypatch): + _require_metal_op() + B, Hq, Hkv, T, L, Dh = 1, 8, 2, 8, 8, 64 + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + meta = AttnMeta(flex_block_mask=None, q_len=T, kv_len=L) + cfg = AttnConfig(causal=True, enable_gqa=True) + + monkeypatch.setenv("WORLD_METAL_IMPL", mode) + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + if mode == "fast": + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + out = world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL) + assert out.shape == q.shape + + +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_ref_and_fast_op_shapes_and_parity(dtype): + _require_metal_op() + + B, H, T, L, Dh = 1, 2, 8, 8, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, dtype) + + out_ref = torch.ops.world.flex_attn_metal_ref(q, k, v, None, True) + out_fast = torch.ops.world.flex_attn_metal_fast(q, k, v, None, True) + + assert out_ref.shape == q.shape + assert out_fast.shape == q.shape + assert out_ref.dtype == q.dtype + assert out_fast.dtype == q.dtype + assert torch.allclose( + out_ref.to("cpu", dtype=torch.float32), + out_fast.to("cpu", dtype=torch.float32), + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("mode", ["ref", "fast"]) +def 
test_world_flex_attn_forward_selects_metal_impl(mode, monkeypatch): + _require_metal_op() + + B, H, T, L, Dh = 1, 2, 8, 8, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, torch.float16) + meta = AttnMeta(flex_block_mask=None, q_len=T, kv_len=L) + cfg = AttnConfig(causal=True, enable_gqa=False) + + monkeypatch.setenv("WORLD_METAL_IMPL", mode) + out = world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL) + assert out.shape == q.shape + + +@pytest.mark.parametrize( + "shape", + [ + (1, 2, 8, 8, 32), + (1, 4, 12, 16, 64), + (1, 8, 16, 16, 64), + ], +) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("mode", ["ref", "fast"]) +def test_metal_impl_matches_reference_with_block_mask(shape, causal, mode, monkeypatch): + _require_metal_op() + B, H, T, L, Dh = shape + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, torch.float16) + + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + # Fixed rolling-cache style block occupancy pattern. + block_written = torch.tensor( + [(i % 3) != 0 for i in range(kv_blocks)], + device=q.device, + dtype=torch.bool, + ) + dense_mask = _dense_mask_from_block_written(block_written, T, L, block_size, q.device) + dense_mask = dense_mask.expand(B, H, T, L).contiguous() + + ref = _reference_attention(q, k, v, causal=causal, mask=dense_mask) + + monkeypatch.setenv("WORLD_METAL_IMPL", mode) + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", str(block_size)) + if mode == "fast": + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + out = torch.ops.world.flex_attn_metal_ref(q, k, v, dense_mask, causal) if mode == "ref" else torch.ops.world.flex_attn_metal_fast(q, k, v, dense_mask, causal) + + assert torch.allclose( + out.to("cpu", dtype=torch.float32), + ref.to("cpu", dtype=torch.float32), + rtol=2e-2, + atol=2e-2, + ) + + +@pytest.mark.parametrize("seed", list(range(20))) +def test_metal_fast_strict_fuzz_block_mask_gqa(seed, monkeypatch): + """ + Adversarial fuzz test: + - odd T/L lengths + - variable block sizes (including non-divisors of L) + - mixed GQA factors + - random block-written sparsity patterns + - random causal mode + """ + _require_metal_op() + random.seed(seed) + torch.manual_seed(seed) + + B = 1 + T = random.choice([1, 3, 7, 11, 15, 23, 31]) + L = random.choice([1, 5, 9, 13, 17, 29, 33, 47]) + Dh = random.choice([32, 64]) + Hkv = random.choice([1, 2, 4]) + gqa_group = random.choice([1, 2, 4]) + Hq = Hkv * gqa_group + causal = bool(random.getrandbits(1)) + block_size = random.choice([1, 2, 3, 4, 5, 7, 8]) + + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + + kv_blocks = (L + block_size - 1) // block_size + # Include very sparse and very dense block occupancy cases. + p = random.choice([0.15, 0.35, 0.5, 0.8, 1.0]) + block_written = (torch.rand(kv_blocks, device=q.device) < p) + # Keep at least one available block to avoid all-zero trivial outputs every time. 
+ if not bool(block_written.any()): + block_written[random.randrange(kv_blocks)] = True + + dense_mask = _dense_mask_from_block_written(block_written, T, L, block_size, q.device) + dense_mask = dense_mask.expand(B, Hq, T, L).contiguous() + + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", str(block_size)) + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + out_fast = torch.ops.world.flex_attn_metal_fast(q, k, v, dense_mask, causal) + ref = _reference_attention(q, k, v, causal=causal, mask=dense_mask) + + diff = (out_fast.to("cpu", dtype=torch.float32) - ref.to("cpu", dtype=torch.float32)).abs() + assert diff.max().item() < 4e-2 + assert diff.mean().item() < 5e-3 + assert torch.isfinite(out_fast).all().item() + + +def test_metal_fast_strict_full_mask_rows_gqa(monkeypatch): + """ + Hard edge case where all KV blocks are masked out. Output should be zeros + (after safe softmax handling), even for GQA. + """ + _require_metal_op() + B, Hq, Hkv, T, L, Dh = 1, 8, 2, 19, 37, 64 + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + + block_size = 4 + block_written = torch.zeros((L + block_size - 1) // block_size, device=q.device, dtype=torch.bool) + dense_mask = _dense_mask_from_block_written(block_written, T, L, block_size, q.device) + dense_mask = dense_mask.expand(B, Hq, T, L).contiguous() + + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", str(block_size)) + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + out = torch.ops.world.flex_attn_metal_fast(q, k, v, dense_mask, True) + assert torch.allclose(out.to("cpu", dtype=torch.float32), torch.zeros_like(out.to("cpu", dtype=torch.float32)), atol=1e-6, rtol=0.0) + + +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize( + "shape", + [ + (1, 8, 2, 13, 29, 64), + (1, 16, 4, 17, 33, 64), + ], +) +def test_fast_blocks_op_matches_reference(shape, causal, monkeypatch): + _require_metal_op() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = shape + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + block_written = torch.tensor([(i % 2) == 0 for i in range(kv_blocks)], device=q.device, dtype=torch.uint8).contiguous() + dense_mask = _dense_mask_from_block_written(block_written.bool(), T, L, block_size, q.device).expand(B, Hq, T, L).contiguous() + + out = torch.ops.world.flex_attn_metal_fast_blocks(q, k, v, block_written, block_size, causal) + ref = _reference_attention(q, k, v, causal=causal, mask=dense_mask) + assert torch.allclose( + out.to("cpu", dtype=torch.float32), + ref.to("cpu", dtype=torch.float32), + atol=3e-2, + rtol=3e-2, + ) + + +def test_world_flex_attn_forward_uses_block_metadata_path(monkeypatch): + _require_metal_op() + monkeypatch.setenv("WORLD_METAL_IMPL", "fast") + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = 1, 8, 2, 11, 23, 64 + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + block_written = torch.tensor([(i % 3) != 0 for i in range(kv_blocks)], 
device=q.device, dtype=torch.uint8).contiguous() + + meta = AttnMeta( + flex_block_mask=None, + q_len=T, + kv_len=L, + block_written=block_written, + block_size=block_size, + ) + cfg = AttnConfig(causal=True, enable_gqa=True) + out = world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL) + dense_mask = _dense_mask_from_block_written(block_written.bool(), T, L, block_size, q.device).expand(B, Hq, T, L).contiguous() + ref = _reference_attention(q, k, v, causal=True, mask=dense_mask) + assert torch.allclose( + out.to("cpu", dtype=torch.float32), + ref.to("cpu", dtype=torch.float32), + atol=3e-2, + rtol=3e-2, + ) + + +def test_metal_fast_rejects_non_shared_mask(monkeypatch): + _require_metal_op() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + + B, H, T, L, Dh = 1, 4, 8, 16, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, torch.float16) + mask = torch.ones((B, H, T, L), device="mps", dtype=torch.uint8).contiguous() + # Violate shared-mask contract by changing one query row. + mask[0, 1, 3, 5] = 0 + + with pytest.raises(RuntimeError, match="shared mask"): + _ = torch.ops.world.flex_attn_metal_fast(q, k, v, mask, True) + + +def test_metal_fast_rejects_non_blockwise_mask(monkeypatch): + _require_metal_op() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + + B, H, T, L, Dh = 1, 4, 8, 16, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, torch.float16) + mask = torch.ones((B, H, T, L), device="mps", dtype=torch.uint8).contiguous() + # Within a block [4,5,6,7], make token-level values differ. + mask[..., 5] = 0 + + with pytest.raises(RuntimeError, match="block-wise mask values"): + _ = torch.ops.world.flex_attn_metal_fast(q, k, v, mask, False) + + +def test_metal_fast_batch2_shared_mask_matches_reference(monkeypatch): + _require_metal_op() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + + B, Hq, Hkv, T, L, Dh = 2, 8, 2, 11, 23, 64 + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + + kv_blocks = (L + 4 - 1) // 4 + block_written = torch.tensor([(i % 2) == 0 for i in range(kv_blocks)], device=q.device, dtype=torch.bool) + base_mask = _dense_mask_from_block_written(block_written, T, L, 4, q.device) + dense_mask = base_mask.expand(B, Hq, T, L).contiguous() + + out_fast = torch.ops.world.flex_attn_metal_fast(q, k, v, dense_mask, True) + ref = _reference_attention(q, k, v, causal=True, mask=dense_mask) + assert torch.allclose( + out_fast.to("cpu", dtype=torch.float32), + ref.to("cpu", dtype=torch.float32), + atol=3e-2, + rtol=3e-2, + ) + + +@pytest.mark.parametrize("seed", list(range(40))) +def test_metal_fast_active_strict_fuzz_matches_reference(seed, monkeypatch): + _require_metal_op() + random.seed(seed) + torch.manual_seed(seed) + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B = random.choice([1, 2]) + Hkv = random.choice([1, 2, 4]) + gqa_group = random.choice([1, 2, 4]) + Hq = Hkv * gqa_group + T = random.choice([1, 7, 15, 31, 63, 95]) + L = random.choice([5, 17, 37, 65, 129, 257]) + Dh = random.choice([32, 64]) + causal = bool(random.getrandbits(1)) + block_size = random.choice([1, 2, 4, 8]) + + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + v 
= torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16) + + kv_blocks = (L + block_size - 1) // block_size + p = random.choice([0.0, 0.1, 0.25, 0.5, 0.8, 1.0]) + block_written = (torch.rand(kv_blocks, device=q.device) < p).to(torch.uint8).contiguous() + active_blocks = torch.nonzero(block_written, as_tuple=False).flatten().to(torch.int32).contiguous() + + out = torch.ops.world.flex_attn_metal_fast_active(q, k, v, active_blocks, block_size, causal) + dense_mask = _dense_mask_from_block_written(block_written.bool(), T, L, block_size, q.device) + dense_mask = dense_mask.expand(B, Hq, T, L).contiguous() + ref = _reference_attention(q, k, v, causal=causal, mask=dense_mask) + + diff = (out.to("cpu", dtype=torch.float32) - ref.to("cpu", dtype=torch.float32)).abs() + assert diff.max().item() < 5e-2 + assert diff.mean().item() < 8e-3 + assert torch.isfinite(out).all().item() + diff --git a/tests/test_metal_attn_perf.py b/tests/test_metal_attn_perf.py new file mode 100644 index 0000000..fbea63c --- /dev/null +++ b/tests/test_metal_attn_perf.py @@ -0,0 +1,355 @@ +import time +from pathlib import Path +import sys + +import pytest +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src" / "model")) + +from attn_backend import ( + AttnBackend, + AttnConfig, + AttnMeta, + world_flex_attn_forward, +) + + +pytestmark = pytest.mark.skipif( + not torch.backends.mps.is_available(), + reason="MPS backend not available on this system", +) + + +def _rand_attn_tensors(B: int, H: int, T: int, L: int, Dh: int, dtype: torch.dtype): + q = torch.randn(B, H, T, Dh, device="mps", dtype=dtype) + k = torch.randn(B, H, L, Dh, device="mps", dtype=dtype) + v = torch.randn(B, H, L, Dh, device="mps", dtype=dtype) + return q, k, v + + +def _rand_gqa_tensors(B: int, Hq: int, Hkv: int, T: int, L: int, Dh: int, dtype: torch.dtype): + q = torch.randn(B, Hq, T, Dh, device="mps", dtype=dtype) + k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=dtype) + v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=dtype) + return q, k, v + + +def _require_metal_ops(): + if not hasattr(torch.ops, "world"): + pytest.skip("Metal world namespace not registered") + if not ( + hasattr(torch.ops.world, "flex_attn_metal_ref") + and hasattr(torch.ops.world, "flex_attn_metal_fast") + and hasattr(torch.ops.world, "flex_attn_metal_fast_blocks") + and hasattr(torch.ops.world, "flex_attn_metal_fast_active") + ): + pytest.skip("Metal ref/fast/fast_blocks/fast_active ops not registered") + + +def _timed_ms_sync(fn, warmup: int, iters: int): + for _ in range(warmup): + fn() + torch.mps.synchronize() + samples = [] + for _ in range(iters): + t0 = time.perf_counter() + fn() + torch.mps.synchronize() + samples.append((time.perf_counter() - t0) * 1000.0) + samples_t = torch.tensor(samples, dtype=torch.float64) + return { + "mean_ms": float(samples_t.mean().item()), + "p50_ms": float(samples_t.quantile(0.50).item()), + "p95_ms": float(samples_t.quantile(0.95).item()), + "max_ms": float(samples_t.max().item()), + } + + +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_metal_backend_runs_and_is_stable(dtype): + _require_metal_ops() + + B, H, T, L, Dh = 1, 8, 256, 256, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, dtype) + + meta = AttnMeta(flex_block_mask=None, q_len=T, kv_len=L) + cfg = AttnConfig(causal=True, enable_gqa=False) + + # Warmup + for _ in range(3): + _ = world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL) + + iters = 5 + start = time.perf_counter() + for _ in range(iters): + _ = 
world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL) + elapsed = time.perf_counter() - start + + # This test intentionally only asserts that the kernel runs in a reasonable + # amount of time; tighter perf targets can be added once the kernel body + # is implemented and tuned. + avg_ms = (elapsed / iters) * 1000.0 + assert avg_ms < 1000.0 + + +@pytest.mark.parametrize( + "shape", + [ + (1, 2, 64, 64, 64), + (1, 4, 128, 128, 64), + (1, 8, 256, 256, 64), + ], +) +@pytest.mark.parametrize("mode", ["ref", "fast"]) +def test_metal_impl_modes_perf_sanity(shape, mode, monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, H, T, L, Dh = shape + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, torch.float16) + meta = AttnMeta(flex_block_mask=None, q_len=T, kv_len=L) + cfg = AttnConfig(causal=True, enable_gqa=False) + + fn = torch.ops.world.flex_attn_metal_ref if mode == "ref" else torch.ops.world.flex_attn_metal_fast + mask = torch.ones((B, H, T, L), device="mps", dtype=torch.uint8).contiguous() + + for _ in range(2): + _ = fn(q, k, v, mask, cfg.causal) + + iters = 3 + start = time.perf_counter() + for _ in range(iters): + out = fn(q, k, v, mask, cfg.causal) + elapsed_ms = (time.perf_counter() - start) * 1000.0 / iters + + assert out.shape == q.shape + assert out.dtype == q.dtype + assert elapsed_ms > 0.0 + assert elapsed_ms < 2000.0 + + +def test_metal_fast_strict_path_executes(monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + + B, H, T, L, Dh = 1, 2, 32, 32, 64 + q, k, v = _rand_attn_tensors(B, H, T, L, Dh, torch.float16) + mask = torch.ones((B, H, T, L), device="mps", dtype=torch.uint8).contiguous() + out = torch.ops.world.flex_attn_metal_fast(q, k, v, mask, True) + assert out.shape == q.shape + + +@pytest.mark.parametrize("mode", ["ref", "fast"]) +@pytest.mark.parametrize( + "shape", + [ + (1, 8, 2, 128, 128, 64), + (1, 16, 4, 192, 192, 64), + ], +) +def test_metal_gqa_modes_perf_sanity(shape, mode, monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = shape + q, k, v = _rand_gqa_tensors(B, Hq, Hkv, T, L, Dh, torch.float16) + mask = torch.ones((B, Hq, T, L), device="mps", dtype=torch.uint8).contiguous() + fn = torch.ops.world.flex_attn_metal_ref if mode == "ref" else torch.ops.world.flex_attn_metal_fast + + for _ in range(2): + _ = fn(q, k, v, mask, True) + + iters = 3 + start = time.perf_counter() + for _ in range(iters): + out = fn(q, k, v, mask, True) + elapsed_ms = (time.perf_counter() - start) * 1000.0 / iters + + assert out.shape == q.shape + assert elapsed_ms > 0.0 + assert elapsed_ms < 3000.0 + + +@pytest.mark.parametrize("causal", [False, True]) +def test_metal_fast_long_context_stress(causal, monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = 1, 16, 4, 256, 768, 64 + q, k, v = _rand_gqa_tensors(B, Hq, Hkv, T, L, Dh, torch.float16) + mask = torch.ones((B, Hq, T, L), device="mps", dtype=torch.uint8).contiguous() + + # Warmup + for _ in range(2): + _ = torch.ops.world.flex_attn_metal_fast(q, k, v, mask, causal) + + iters = 4 + start = time.perf_counter() + out = None + for _ in range(iters): + out = torch.ops.world.flex_attn_metal_fast(q, k, 
v, mask, causal) + avg_ms = (time.perf_counter() - start) * 1000.0 / iters + + assert out is not None + assert out.shape == q.shape + assert torch.isfinite(out).all().item() + # Generous ceiling for CI variance while still guarding hangs/regressions. + assert avg_ms < 5000.0 + + +def test_metal_fast_vs_ref_perf_ratio_gqa(monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_BLOCK_SIZE", "4") + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = 1, 16, 4, 192, 384, 64 + q, k, v = _rand_gqa_tensors(B, Hq, Hkv, T, L, Dh, torch.float16) + mask = torch.ones((B, Hq, T, L), device="mps", dtype=torch.uint8).contiguous() + + for _ in range(2): + _ = torch.ops.world.flex_attn_metal_ref(q, k, v, mask, True) + _ = torch.ops.world.flex_attn_metal_fast(q, k, v, mask, True) + + iters = 3 + start = time.perf_counter() + for _ in range(iters): + _ = torch.ops.world.flex_attn_metal_ref(q, k, v, mask, True) + ref_ms = (time.perf_counter() - start) * 1000.0 / iters + + start = time.perf_counter() + for _ in range(iters): + _ = torch.ops.world.flex_attn_metal_fast(q, k, v, mask, True) + fast_ms = (time.perf_counter() - start) * 1000.0 / iters + + assert ref_ms > 0.0 and fast_ms > 0.0 + # Guard against extreme regressions while allowing room during early + # kernel bring-up (current fast path is correctness-oriented, not tuned). + assert fast_ms / ref_ms < 200.0 + assert fast_ms < 500.0 + + +def test_metal_fast_blocks_perf_sanity(monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = 1, 16, 4, 160, 320, 64 + q, k, v = _rand_gqa_tensors(B, Hq, Hkv, T, L, Dh, torch.float16) + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + block_written = torch.ones((kv_blocks,), device="mps", dtype=torch.uint8).contiguous() + + for _ in range(2): + _ = torch.ops.world.flex_attn_metal_fast_blocks(q, k, v, block_written, block_size, True) + + iters = 3 + start = time.perf_counter() + for _ in range(iters): + out = torch.ops.world.flex_attn_metal_fast_blocks(q, k, v, block_written, block_size, True) + avg_ms = (time.perf_counter() - start) * 1000.0 / iters + assert out.shape == q.shape + assert avg_ms > 0.0 + assert avg_ms < 5000.0 + + +@pytest.mark.parametrize( + "shape", + [ + (1, 16, 4, 192, 384, 64), + (1, 16, 4, 256, 768, 64), + (1, 8, 8, 256, 512, 64), + (2, 8, 2, 160, 320, 64), + ], +) +@pytest.mark.parametrize("sparsity", [1.0, 0.5, 0.25]) +def test_metal_fast_active_benchmark_matrix(shape, sparsity, monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = shape + q, k, v = _rand_gqa_tensors(B, Hq, Hkv, T, L, Dh, torch.float16) + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + block_written = (torch.rand((kv_blocks,), device="mps") < sparsity).to(torch.uint8).contiguous() + active_blocks = torch.nonzero(block_written, as_tuple=False).flatten().to(torch.int32).contiguous() + + stats = _timed_ms_sync( + lambda: torch.ops.world.flex_attn_metal_fast_active( + q, k, v, active_blocks, block_size, True + ), + warmup=10, + iters=40, + ) + + assert stats["mean_ms"] > 0.0 + assert stats["p50_ms"] > 0.0 + assert stats["p95_ms"] >= stats["p50_ms"] + assert stats["max_ms"] < 50.0 + + +def test_metal_fast_active_vs_blocks_latency(monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = 1, 16, 4, 256, 768, 64 + q, k, v = _rand_gqa_tensors(B, Hq, Hkv, T, L, Dh, 
torch.float16) + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + block_written = torch.tensor([(i % 2) == 0 for i in range(kv_blocks)], device="mps", dtype=torch.uint8).contiguous() + active_blocks = torch.nonzero(block_written, as_tuple=False).flatten().to(torch.int32).contiguous() + + blocks_stats = _timed_ms_sync( + lambda: torch.ops.world.flex_attn_metal_fast_blocks( + q, k, v, block_written, block_size, True + ), + warmup=10, + iters=60, + ) + active_stats = _timed_ms_sync( + lambda: torch.ops.world.flex_attn_metal_fast_active( + q, k, v, active_blocks, block_size, True + ), + warmup=10, + iters=60, + ) + + # Active path should not regress significantly versus block-written path. + assert active_stats["p50_ms"] <= blocks_stats["p50_ms"] * 1.25 + assert active_stats["p95_ms"] <= blocks_stats["p95_ms"] * 1.25 + + +def test_world_backend_fast_active_stability(monkeypatch): + _require_metal_ops() + monkeypatch.setenv("WORLD_ATTENTION_BACKEND", "metal") + monkeypatch.setenv("WORLD_METAL_IMPL", "fast") + monkeypatch.setenv("WORLD_METAL_FAST_NO_FALLBACK", "1") + + B, Hq, Hkv, T, L, Dh = 1, 16, 4, 192, 384, 64 + q, k, v = _rand_gqa_tensors(B, Hq, Hkv, T, L, Dh, torch.float16) + block_size = 4 + kv_blocks = (L + block_size - 1) // block_size + block_written = torch.tensor([(i % 2) == 0 for i in range(kv_blocks)], device="mps", dtype=torch.uint8).contiguous() + active_blocks = torch.nonzero(block_written, as_tuple=False).flatten().to(torch.int32).contiguous() + meta = AttnMeta( + flex_block_mask=None, + q_len=T, + kv_len=L, + block_written=block_written, + active_blocks=active_blocks, + block_size=block_size, + ) + cfg = AttnConfig(causal=True, enable_gqa=True) + + stats = _timed_ms_sync( + lambda: world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL), + warmup=12, + iters=80, + ) + assert stats["mean_ms"] > 0.0 + assert stats["p95_ms"] < 20.0 + assert (stats["p95_ms"] / max(stats["p50_ms"], 1e-6)) < 3.0 + From ff2cafa1be1d4166953b375e206d6e2d4f7c2c14 Mon Sep 17 00:00:00 2001 From: LouisCastricato Date: Sun, 15 Mar 2026 14:02:10 -0700 Subject: [PATCH 2/6] debugging (will probably roll back) --- README_metal_hybrid.md | 85 ++++++ diagnostics/scripts/compare_rollouts.py | 51 ++++ diagnostics/scripts/run_short_rollout.py | 219 ++++++++++++++++ docs/metal_cpu_cuda_recovery_runbook.md | 314 +++++++++++++++++++++++ docs/metal_mps_full_diagnosis.md | 233 +++++++++++++++++ docs/perf_baseline_mps_w8a8.json | 218 ++++++++++++++++ pyproject.toml | 2 +- src/metal/__init__.py | 9 + src/metal/metal_flex_attn.metal | 256 +++++++++++++++++- src/metal/metal_flex_attn_op.mm | 244 +++++++++++------- src/metal/runtime.py | 61 +++++ src/model/attn_backend.py | 37 ++- src/model/kv_cache.py | 159 ++++++++++-- src/model/nn.py | 6 +- src/model/world_model.py | 2 - src/patch_model.py | 8 +- src/world_engine.py | 104 +++++++- tests/bench_world_engine_e2e.py | 242 +++++++++++++++++ tests/conftest.py | 51 +--- tests/gen_world_simple.py | 126 +++++++++ tests/metal_test_utils.py | 48 ++++ tests/perf_regression_gate.py | 201 +++++++++++++++ tests/test_attn_backend_cross_backend.py | 88 +++++++ tests/test_attn_module_integration.py | 22 +- tests/test_kv_cache_active_blocks.py | 134 ++++++++++ tests/test_metal_attn_numeric.py | 24 +- tests/test_metal_attn_perf.py | 36 +-- 27 files changed, 2725 insertions(+), 255 deletions(-) create mode 100644 README_metal_hybrid.md create mode 100644 diagnostics/scripts/compare_rollouts.py create mode 100644 diagnostics/scripts/run_short_rollout.py create 
mode 100644 docs/metal_cpu_cuda_recovery_runbook.md create mode 100644 docs/metal_mps_full_diagnosis.md create mode 100644 docs/perf_baseline_mps_w8a8.json create mode 100644 src/metal/__init__.py create mode 100644 src/metal/runtime.py create mode 100644 tests/bench_world_engine_e2e.py create mode 100644 tests/gen_world_simple.py create mode 100644 tests/metal_test_utils.py create mode 100644 tests/perf_regression_gate.py create mode 100644 tests/test_attn_backend_cross_backend.py create mode 100644 tests/test_kv_cache_active_blocks.py diff --git a/README_metal_hybrid.md b/README_metal_hybrid.md new file mode 100644 index 0000000..73ba041 --- /dev/null +++ b/README_metal_hybrid.md @@ -0,0 +1,85 @@ +## Hybrid Metal attention backend + +This repository includes an experimental **hybrid Metal backend** for attention. +The high-level model and inference loop remain in PyTorch, while the hottest +attention path can be routed through custom Metal ops on Apple M-series GPUs. + +### Selecting attention backend + +Attention backends are controlled via the `WORLD_ATTENTION_BACKEND` environment +variable: + +- `flex` (default): use PyTorch `flex_attention` everywhere. +- `metal`: use custom `world.flex_attn_metal_*` ops on MPS devices. +- `auto`: choose based on availability/device. + +Example: + +```bash +WORLD_ATTENTION_BACKEND=metal WORLD_METAL_IMPL=fast python examples/gen_sample.py +``` + +### Implementation overview + +- Python-side wrappers: + - `src/model/attn_backend.py` defines: + - `AttnBackend`: backend selector (`pytorch-flex`, `metal-op`, `auto`). + - `AttnConfig` / `AttnMeta`: small structs describing behavior and KV + geometry. + - `world_flex_attn_forward(...)`: single entry point used by attention + modules. +- Call sites: + - `Attn`, `MergedQKVAttn`, and `CrossAttention` now call + `world_flex_attn_forward` instead of `flex_attention` directly. +- Metal custom op: + - `src/metal/metal_flex_attn_op.mm` registers + `torch.ops.world.flex_attn_metal` on the MPS backend and wires it to the + `metal_flex_attn_forward` Metal kernel in + `src/metal/metal_flex_attn.metal`. +- Tests: + - `tests/test_metal_attn_numeric.py` compares Metal vs flex attention on + small random inputs (when the Metal op is available). + - `tests/test_metal_attn_perf.py` provides a basic throughput sanity check on + M‑series devices. + +### Status + +Attention Metal kernels include fast sparse/block-aware paths and a reference +path. + +Known limitations: + +- Attention Metal path is inference-only. +- Fast specialized kernels are tuned for float16; bfloat16 is supported via native generic kernel when available (otherwise fp16 boundary fallback). + +### End-to-end benchmark + +Use this to track actual generation latency/FPS on MPS: + +```bash +python tests/bench_world_engine_e2e.py --model-uri --attention-backend metal --dtype float16 --quant w8a8 --scheduler-steps 4 --cache-interval 1 +``` + +Add `--return-img` to include VAE decode in the benchmarked path. 
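+
+### Programmatic smoke test
+
+For quick experimentation outside the engine, the backend wrapper can also be
+called directly with hand-built block metadata. The snippet below is an
+illustrative sketch (shapes, block size, and occupancy are arbitrary) and
+assumes it is run from the repository root on an MPS device with the custom op
+registered:
+
+```python
+import torch
+from src.model.attn_backend import AttnBackend, AttnConfig, AttnMeta, world_flex_attn_forward
+
+B, Hq, Hkv, T, L, Dh, block_size = 1, 8, 2, 16, 32, 64, 4
+q = torch.randn(B, Hq, T, Dh, device="mps", dtype=torch.float16)
+k = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16)
+v = torch.randn(B, Hkv, L, Dh, device="mps", dtype=torch.float16)
+
+# Mark every other KV block as written, mimicking a partially filled ring cache.
+kv_blocks = (L + block_size - 1) // block_size
+block_written = torch.tensor([i % 2 == 0 for i in range(kv_blocks)], device="mps", dtype=torch.uint8).contiguous()
+active_blocks = torch.nonzero(block_written, as_tuple=False).flatten().to(torch.int32).contiguous()
+
+meta = AttnMeta(flex_block_mask=None, q_len=T, kv_len=L,
+                block_written=block_written, active_blocks=active_blocks, block_size=block_size)
+cfg = AttnConfig(causal=True, enable_gqa=True)
+out = world_flex_attn_forward(q, k, v, meta, cfg, backend=AttnBackend.METAL)
+print(out.shape)  # torch.Size([1, 8, 16, 64])
+```
+
+The block metadata is only consulted on the fast path: set `WORLD_METAL_IMPL=fast`
+to route through `flex_attn_metal_fast_active`; the default `ref` implementation
+ignores it and attends over the full KV range.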
+ +### Regression-safe performance gate + +Capture a locked baseline (3 repeats): + +```bash +python tests/perf_regression_gate.py --output docs/perf_baseline_mps_w8a8.json --repeats 3 --warmup 16 --steps 8 +``` + +Compare current code to baseline (fails on regression beyond threshold): + +```bash +python tests/perf_regression_gate.py --output docs/perf_baseline_mps_w8a8.json --compare-only --repeats 3 --warmup 16 --steps 8 --max-regression 0.15 +``` + +### Current validated throughput (strict pretrained path) + +`Overworld-Models/Lapp0-WP-Mini-1.4.5-BL-Distill`, `scheduler_steps=4`, `cache_interval=1`, `float16`, `w8a8`: + +- latent-only: `total_ms p50 ~210.8`, `FPS p50 ~4.74` +- with decode: `total_ms p50 ~219.3`, `FPS p50 ~4.56` + diff --git a/diagnostics/scripts/compare_rollouts.py b/diagnostics/scripts/compare_rollouts.py new file mode 100644 index 0000000..2313d96 --- /dev/null +++ b/diagnostics/scripts/compare_rollouts.py @@ -0,0 +1,51 @@ +import argparse +import json + +import torch + + +def _metrics(a: torch.Tensor, b: torch.Tensor): + av = a.flatten() + bv = b.flatten() + d = (av - bv).abs() + return { + "cos": float(torch.nn.functional.cosine_similarity(av, bv, dim=0)), + "mae": float(d.mean()), + "rmse": float(torch.sqrt(((av - bv) ** 2).mean())), + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--lhs", required=True, help="Path to lhs latents .pt") + parser.add_argument("--rhs", required=True, help="Path to rhs latents .pt") + parser.add_argument("--out", required=True, help="Path to output json report") + args = parser.parse_args() + + lhs = torch.load(args.lhs, map_location="cpu") + rhs = torch.load(args.rhs, map_location="cpu") + n = min(len(lhs), len(rhs)) + per_step = {} + worst = {"step": None, "cos": 10.0} + for i in range(n): + m = _metrics(lhs[i], rhs[i]) + per_step[f"step_{i+1:02d}"] = m + if m["cos"] < worst["cos"]: + worst = {"step": i + 1, "cos": m["cos"]} + out = { + "lhs": args.lhs, + "rhs": args.rhs, + "lhs_steps": len(lhs), + "rhs_steps": len(rhs), + "compared_steps": n, + "worst_step_by_cos": worst, + "per_step": per_step, + } + with open(args.out, "w", encoding="utf-8") as f: + json.dump(out, f, indent=2) + print(json.dumps(out, indent=2)) + + +if __name__ == "__main__": + main() + diff --git a/diagnostics/scripts/run_short_rollout.py b/diagnostics/scripts/run_short_rollout.py new file mode 100644 index 0000000..06a6a8f --- /dev/null +++ b/diagnostics/scripts/run_short_rollout.py @@ -0,0 +1,219 @@ +import argparse +import io +import json +import os +import random +import urllib.request +import inspect +from pathlib import Path + +import imageio.v3 as iio +import numpy as np +import torch +import torch.nn.functional as F + + +def _load_seed_frame(url: str) -> np.ndarray: + raw = urllib.request.urlopen(url).read() + arr = iio.imread(io.BytesIO(raw)) + if arr.ndim == 2: + arr = np.stack([arr, arr, arr], axis=-1) + if arr.shape[-1] > 3: + arr = arr[..., :3] + t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(torch.float32) + t = F.interpolate(t, size=(512, 1024), mode="bilinear", align_corners=False) + t = t.round().clamp(0, 255).to(torch.uint8) + return t.squeeze(0).permute(1, 2, 0).cpu().numpy() + + +def _ctrl_sequence(CtrlInput, steps: int): + seq = [ + CtrlInput(mouse=[0.2, 0.2]), + CtrlInput(button={32}), + CtrlInput(), + CtrlInput(), + CtrlInput(), + CtrlInput(button={1}), + CtrlInput(), + CtrlInput(), + CtrlInput(button={1, 32}), + CtrlInput(), + CtrlInput(), + CtrlInput(), + ] + if steps <= len(seq): + return 
seq[:steps] + seq = seq + [CtrlInput() for _ in range(steps - len(seq))] + return seq + + +def _metrics(a: torch.Tensor, b: torch.Tensor): + av = a.flatten() + bv = b.flatten() + d = (av - bv).abs() + return { + "cos": float(torch.nn.functional.cosine_similarity(av, bv, dim=0)), + "mae": float(d.mean()), + "rmse": float(torch.sqrt(((av - bv) ** 2).mean())), + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-uri", default="Overworld-Models/Lapp0-WP-Mini-1.4.5-BL-Distill") + parser.add_argument("--device", default="cpu") + parser.add_argument("--dtype", default="float32", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--attention-backend", default="flex", choices=["flex", "metal", "auto"]) + parser.add_argument("--scheduler-steps", type=int, default=4) + parser.add_argument("--cache-interval", type=int, default=1) + parser.add_argument("--steps", type=int, default=8) + parser.add_argument( + "--seed-url", + default="https://gist.github.com/user-attachments/assets/f9c20d4d-7565-452d-8b02-42a85ea175ed", + ) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--output-prefix", required=True) + parser.add_argument("--disable-patch-cached-noise", action="store_true") + parser.add_argument("--disable-patch-merge-qkv", action="store_true") + parser.add_argument("--disable-patch-split-mlp", action="store_true") + parser.add_argument("--force-direct-flex-wrapper", action="store_true") + parser.add_argument("--metal-force-causal", action="store_true") + args = parser.parse_args() + + os.environ.setdefault("HF_HUB_OFFLINE", "1") + os.environ.setdefault("TRANSFORMERS_OFFLINE", "1") + os.environ.setdefault("TORCHDYNAMO_DISABLE", "1") + os.environ.setdefault("WORLD_KV_RUNTIME_CHECKS", "0") + os.environ.setdefault("WORLD_KV_COMPUTE_ACTIVE_BLOCKS", "0") + os.environ["WORLD_ATTENTION_BACKEND"] = args.attention_backend + if args.metal_force_causal: + os.environ["WORLD_METAL_FORCE_CAUSAL"] = "1" + else: + os.environ.pop("WORLD_METAL_FORCE_CAUSAL", None) + if args.attention_backend == "metal" and args.device == "mps": + os.environ.setdefault("WORLD_METAL_IMPL", "fast") + os.environ.setdefault("WORLD_METAL_FAST_NO_FALLBACK", "1") + os.environ.setdefault("WORLD_METAL_PREFER_ACTIVE_DISPATCH", "1") + + import src.patch_model as patch_model + import src.world_engine as world_engine_mod + from src.world_engine import CtrlInput, WorldEngine + + # Patch toggles for ablation without editing model files. 
+ original_apply = world_engine_mod.apply_inference_patches + + def apply_with_toggles(model): + if not args.disable_patch_cached_noise and next(model.parameters()).dtype == torch.bfloat16: + patch_model.patch_cached_noise_conditioning(model) + if not args.disable_patch_merge_qkv: + patch_model.patch_Attn_merge_qkv(model) + if not args.disable_patch_split_mlp: + patch_model.patch_MLPFusion_split(model) + + world_engine_mod.apply_inference_patches = apply_with_toggles + + original_world_attn = getattr(patch_model, "world_flex_attn_forward", None) + if args.force_direct_flex_wrapper and original_world_attn is not None: + from torch.nn.attention.flex_attention import flex_attention + + def direct_flex(q, k, v, meta, cfg, backend=None): + block_mask = meta.flex_block_mask if meta is not None else None + return flex_attention(q, k, v, block_mask=block_mask, enable_gqa=cfg.enable_gqa) + + patch_model.world_flex_attn_forward = direct_flex + + if args.dtype == "float32": + dtype = torch.float32 + elif args.dtype == "bfloat16": + dtype = torch.bfloat16 + else: + dtype = torch.float16 + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + frame = _load_seed_frame(args.seed_url) + seed = torch.from_numpy(np.repeat(frame[None], 4, axis=0)) + ctrl_seq = _ctrl_sequence(CtrlInput, args.steps) + + sig = inspect.signature(WorldEngine.__init__) + kwargs = { + "quant": None, + "device": args.device, + "dtype": dtype, + } + if "scheduler_steps" in sig.parameters: + kwargs["scheduler_steps"] = args.scheduler_steps + if "cache_interval" in sig.parameters: + kwargs["cache_interval"] = args.cache_interval + engine = WorldEngine(args.model_uri, **kwargs) + if hasattr(engine, "_cache_pass_eager"): + engine._cache_pass_fn = engine._cache_pass_eager + if hasattr(engine, "_denoise_pass_eager"): + engine._denoise_pass_fn = engine._denoise_pass_eager + + latents = [] + with torch.inference_mode(): + engine.append_frame(seed.to(engine.device)) + for i, ctrl in enumerate(ctrl_seq): + x = torch.randn( + (1, 1, 32, 32, 64), + generator=torch.Generator(device="cpu").manual_seed(args.seed + i), + dtype=torch.float32, + ).to(engine.device, dtype=dtype) + inp = engine.prep_inputs(x=x, ctrl=ctrl) + y = engine._denoise_pass_fn(x, inp, engine.kv_cache) + if hasattr(engine, "_cache_pass_fn"): + engine._cache_pass_fn(y, inp, engine.kv_cache) + else: + engine._cache_pass(y, inp, engine.kv_cache) + if hasattr(engine, "_gen_count"): + engine._gen_count += 1 + if engine.device == "mps": + torch.mps.synchronize() + latents.append(y.detach().float().cpu()) + + out_prefix = Path(args.output_prefix) + out_prefix.parent.mkdir(parents=True, exist_ok=True) + lat_path = f"{out_prefix}.latents.pt" + meta_path = f"{out_prefix}.meta.json" + stats_path = f"{out_prefix}.stats.json" + + torch.save(latents, lat_path) + meta = { + "model_uri": args.model_uri, + "device": args.device, + "dtype": args.dtype, + "attention_backend": args.attention_backend, + "scheduler_steps": args.scheduler_steps, + "cache_interval": args.cache_interval, + "steps": args.steps, + "seed_url": args.seed_url, + "seed": args.seed, + "disable_patch_cached_noise": args.disable_patch_cached_noise, + "disable_patch_merge_qkv": args.disable_patch_merge_qkv, + "disable_patch_split_mlp": args.disable_patch_split_mlp, + "force_direct_flex_wrapper": args.force_direct_flex_wrapper, + "metal_force_causal": args.metal_force_causal, + } + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2) + + adj = {} + for i in range(1, 
len(latents)): + adj[f"{i-1}->{i}"] = _metrics(latents[i - 1], latents[i]) + with open(stats_path, "w", encoding="utf-8") as f: + json.dump({"adjacent_latent_metrics": adj}, f, indent=2) + + # restore monkeypatches + if original_world_attn is not None: + patch_model.world_flex_attn_forward = original_world_attn + world_engine_mod.apply_inference_patches = original_apply + + print(json.dumps({"latents": lat_path, "meta": meta_path, "stats": stats_path}, indent=2)) + + +if __name__ == "__main__": + main() + diff --git a/docs/metal_cpu_cuda_recovery_runbook.md b/docs/metal_cpu_cuda_recovery_runbook.md new file mode 100644 index 0000000..d9ca01a --- /dev/null +++ b/docs/metal_cpu_cuda_recovery_runbook.md @@ -0,0 +1,314 @@ +# Metal/CPU/CUDA Recovery Runbook (Ultra Detailed) + +This runbook is the execution plan to recover correct temporal coherence and backend parity for `Overworld-Models/Lapp0-WP-Mini-1.4.5-BL-Distill` on this repository. + +It is designed to: +- minimize expensive CPU runs, +- isolate regressions quickly, +- enforce strict reproducibility, +- avoid resource leaks/hangs. + +--- + +## Non-Negotiable Rule: Kill Python After Every Run + +After **every** command that runs model code (CPU, MPS, CUDA, parity, tests, benchmarks), execute the cleanup command below. + +This is mandatory even if the run appears to succeed. + +```bash +python3 - <<'PY' +import os, signal, subprocess +me=os.getpid() +out=subprocess.check_output(['ps','-ax','-o','pid=,command='], text=True) +for line in out.splitlines(): + s=line.strip() + if not s: + continue + p=s.split(None,1) + if len(p)<2: + continue + pid=int(p[0]); cmd=p[1].lower() + if pid==me: + continue + if 'python' in cmd: + try: + os.kill(pid, signal.SIGKILL) + except Exception: + pass +print('python_cleanup_done') +PY +``` + +If a run times out or hangs, run cleanup immediately before anything else. + +--- + +## Scope and Success Criteria + +We need to determine why current branch outputs are visually wrong while `wp-1.5` is known-good on CPU/CUDA. + +Success criteria: +1. Current branch CPU short rollout matches `wp-1.5` CPU golden latent trajectory. +2. Current branch MPS/Metal short rollout matches current branch CPU within bf16 tolerance. +3. Medium-length outputs (32+ frames) no longer collapse into confetti/noise textures. +4. Regression tests fail on bad semantics and pass on fixed behavior. + +--- + +## Environment and Paths + +Repository root: +- `/Users/louiscastricato/overworld/world_engine` + +Reference worktree: +- `/Users/louiscastricato/overworld/world_engine_wp15` + +Use these defaults for deterministic runs unless a step says otherwise: +- `HF_HUB_OFFLINE=1` +- `TRANSFORMERS_OFFLINE=1` +- `TORCHDYNAMO_DISABLE=1` +- `WORLD_KV_RUNTIME_CHECKS=0` +- `WORLD_KV_COMPUTE_ACTIVE_BLOCKS=0` + +Seed URL (fixed): +- `https://gist.github.com/user-attachments/assets/f9c20d4d-7565-452d-8b02-42a85ea175ed` + +--- + +## Artifact Layout (Create First) + +Create and use this structure: + +```text +diagnostics/ + scripts/ + out/ + golden/ + reports/ +``` + +Store all intermediate outputs here (no root-level dump). + +Commands: + +```bash +mkdir -p diagnostics/scripts diagnostics/out diagnostics/golden diagnostics/reports +``` + +Then run the mandatory Python cleanup command. + +--- + +## Phase 1: Golden Baseline From `wp-1.5` (Single Expensive CPU Run) + +### Objective +Capture one canonical short latent trajectory on `wp-1.5` CPU that all current-branch variants must match. + +### Steps +1. Run latent-only short rollout (recommended 8 steps max). 
+2. Save: + - latents per step (`.pt`) + - metadata (`.json`): seed URL, controls, dtype, scheduler steps, cache interval, env vars + - summary metrics (`.json`): adjacent latent cosine/MAE +3. Cleanup Python. + +### Notes +- Do not generate video in this phase unless needed for human sanity check. +- This is the only long CPU run to start. + +--- + +## Phase 2: Current Branch CPU vs Golden (Short Deterministic) + +### Objective +Detect whether regression exists before Metal is involved. + +### Steps +1. Run same latent-only short rollout on current branch CPU. +2. Compare step-by-step to `diagnostics/golden/wp15_cpu_latents.pt`. +3. Record first failing step and error magnitudes in `diagnostics/reports/current_cpu_vs_wp15.json`. +4. Cleanup Python. + +### Interpretation +- If CPU already diverges from golden: regression is in shared model/inference path. +- If CPU matches golden: focus shifts to Metal-specific path. + +--- + +## Phase 3: Fast Ablation Matrix (Current Branch CPU) + +### Objective +Find the smallest feature set causing divergence from golden. + +### Toggle axes (one at a time first, then combinations) +1. `patch_cached_noise_conditioning` ON/OFF +2. `patch_Attn_merge_qkv` ON/OFF +3. `patch_MLPFusion_split` ON/OFF +4. attention wrapper path: + - direct `flex_attention` + - `world_flex_attn_forward` +5. KV metadata optimization paths: + - active-block arithmetic path + - fallback/block-written path + +### Execution protocol +For each configuration: +1. Run short latent rollout (same seed/controls). +2. Compare to golden. +3. Save row in `diagnostics/reports/ablation_matrix_cpu.json`: + - config ID + - first failing step + - stepwise cos/mae +4. Cleanup Python. + +### Exit condition +Stop when one toggle or minimal toggle set restores golden parity. + +--- + +## Phase 4: Metal Parity Once CPU Path Is Re-Grounded + +### Objective +After CPU path is fixed, ensure Metal matches CPU and stays stable over rollout. + +### Steps +1. Run short latent rollout on current branch MPS/Metal with same controls/noise. +2. Compare CPU vs Metal stepwise latents. +3. Save `diagnostics/reports/cpu_vs_metal_short.json`. +4. Generate medium video (32 frames) only after short parity passes. +5. Cleanup Python after each run. + +### Required metrics +- Per-step latent cosine +- Per-step latent MAE +- Optional decoded-frame cosine for sampled steps + +--- + +## Phase 5: Attention and Mask Semantics Verification + +### Objective +Prove mask semantics are aligned across wrapper paths and backends. + +### Checks +1. Compare `world_flex_attn_forward` vs direct `flex_attention` on CPU with identical `q/k/v/meta`. +2. Verify `block_written` and `active_blocks` invariants against expected `mask_written`. +3. Validate causal/non-causal behavior intentionally with explicit flags in diagnostic script. +4. Cleanup Python after each check. + +### Output +- `diagnostics/reports/attn_semantics_checks.json` + +--- + +## Phase 6: Temporal Conditioning and Cache State Invariants + +### Objective +Ensure temporal conditioning inputs and cache evolution are not drifting unexpectedly. + +### Invariants to check per step +1. `frame_idx` monotonic increment +2. `frame_timestamp` monotonic and scaled correctly +3. `kv_cache._is_frozen` state transitions: + - denoise pass: frozen + - cache pass: unfrozen +4. `written` mask evolves as expected for ring/tail +5. `block_written` consistency with `written` +6. 
No invalid empty attention windows for active queries + +### Output +- `diagnostics/reports/cache_temporal_invariants.json` + +Cleanup Python after each run. + +--- + +## Phase 7: Fix Application and Verification Gate + +### Objective +Apply minimal fix and prove it. + +### Gate sequence +1. Current CPU vs golden short parity: pass +2. Current Metal vs current CPU short parity: pass +3. 32-frame Metal video sanity: pass +4. 120-frame Metal stress sanity: pass +5. Cleanup Python after each gate run + +If any gate fails, do not proceed to cleanup/docs finalization. + +--- + +## Phase 8: Regression Tests to Add + +Create tests that must pass before claiming resolution: + +1. `tests/test_golden_short_rollout_cpu.py` + - compares latent trajectory to stored golden artifact +2. `tests/test_attention_wrapper_semantics.py` + - wrapper vs direct flex equivalence +3. `tests/test_kv_cache_state_trajectory.py` + - ring/tail/written/block metadata invariants +4. `tests/test_metal_cpu_short_parity.py` (MPS-gated) + - short latent parity thresholds + +Each test command run must be followed by Python cleanup. + +--- + +## Repo Cleanup Plan (After Fix Verified) + +1. Move all ad-hoc debug scripts into `diagnostics/scripts/` +2. Move generated media to `diagnostics/out/` +3. Keep only essential reference artifacts in repo +4. Add `.gitignore` entries for transient diagnostics outputs +5. Remove dead toggles and one-off monkeypatch logic used during debugging +6. Update: + - `README_metal_hybrid.md` (runtime behavior + validated config) + - `docs/metal_mps_full_diagnosis.md` (root cause + fix evidence) + +Cleanup Python after any validation runs performed during cleanup. + +--- + +## Standard Command Template for Runs + +Use this template for every scripted run: + +```bash +HF_HUB_OFFLINE=1 \ +TRANSFORMERS_OFFLINE=1 \ +TORCHDYNAMO_DISABLE=1 \ +WORLD_KV_RUNTIME_CHECKS=0 \ +WORLD_KV_COMPUTE_ACTIVE_BLOCKS=0 \ +PYTHONPATH=. \ +./.venv/bin/python diagnostics/scripts/