From 067af161145fbee64edac9ccff5847eda4a607ff Mon Sep 17 00:00:00 2001 From: adavyas Date: Mon, 4 May 2026 20:12:33 -0700 Subject: [PATCH] metal: fix flash attention nsg8 dv64 path --- ggml/src/ggml-metal/ggml-metal.metal | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaedeae..ddb800cc285 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6239,17 +6239,24 @@ void kernel_flash_attn_ext_impl( s8x8_t vs; simdgroup_load(vs, ss + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { - v8x8_t mv[2]; + if (NO == 1) { + v8x8_t mv; + + simdgroup_load(mv, pv, NS20, 0, false); + simdgroup_multiply_accumulate(lo[0], vs, mv, lo[0]); + } else { + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; - simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); - simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); - simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } } - pv += 8*NS20; + pv += 8*NS20; } } else { constexpr short NC = (C/8)/2;