diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ba3c3427517..1043d9a49f3 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -411,6 +411,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + const char * name = "kernel_ssm_conv_f32_f32_b256"; + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); @@ -427,7 +444,37 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res.smem = 32*sizeof(float)*nsg; + // Shared memory layout: + // - sgptg * NW floats for partial sums (nsg * 32) + // - sgptg floats for shared_x_dt (nsg) + // - sgptg floats for shared_dA (nsg) + // Total: nsg * (32 + 2) floats + res.smem = (32 + 2)*sizeof(float)*nsg; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + + char base[256]; + char name[256]; + + const int nsg = (ne00 + 31)/32; + + snprintf(base, 256, "kernel_ssm_scan_ssd_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + // Shared memory layout for the SSD kernel (must match kernel_ssm_scan_ssd_f32): + // - sgptg * NW floats for partial sums (nsg * 32) + // - sgptg floats for shared_x_dt (nsg) + // - sgptg floats for shared_dA (nsg) + // Total: nsg * (32 + 2) floats + res.smem = (32 + 2)*sizeof(float)*nsg; return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 77f2e98cfe8..42e5746e025 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -117,7 +117,9 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9efd51abbae..ad840a184f0 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1365,15 +1365,34 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead + constexpr int BATCH_SIZE = 256; + const bool use_batched = (ne1 > 1); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + if (use_batched) { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences + // Each threadgroup has BATCH_SIZE threads, each handling one token + const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1); + } else { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + } return 1; } @@ -1452,7 +1471,13 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.nb0 =*/ nb0, }; - auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + auto pipeline = (n_seq_tokens > 1) + ? 
ggml_metal_library_get_pipeline_ssm_scan_ssd(lib, op) + : ggml_metal_library_get_pipeline_ssm_scan(lib, op); GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4b78d5a2bad..3c9bc91d39e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2312,6 +2312,56 @@ kernel void kernel_ssm_conv_f32_f32( x[0] = sumf; } +// Batched version: each threadgroup processes multiple tokens for better efficiency +// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens +template <int BATCH_SIZE> +kernel void kernel_ssm_conv_f32_f32_batched( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +typedef decltype(kernel_ssm_conv_f32_f32_batched<1>) kernel_ssm_conv_batched_t; +template [[host_name("kernel_ssm_conv_f32_f32_b256")]] kernel kernel_ssm_conv_batched_t kernel_ssm_conv_f32_f32_batched<256>; + kernel void kernel_ssm_conv_f32_f32_4( constant ggml_metal_kargs_ssm_conv & args, device const void * src0, @@ -2344,6 +2394,7 @@ kernel void kernel_ssm_conv_f32_f32_4( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// Optimized version: reduces redundant memory loads by having the first sgptg threads precompute shared per-token values kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2363,7 +2414,15 @@ uint3 tgpg[[threadgroups_per_grid]]) { constexpr short NW = N_SIMDWIDTH; - shared[tpitg.x] = 0.0f; + // Shared memory layout: + // [0..sgptg*NW-1]: partial sums for reduction (existing) + // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch + // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch + threadgroup float * shared_sums = shared; + threadgroup float * shared_x_dt = shared + sgptg * NW; + threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
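+ // Illustrative sizing (example values, not from the patch): with d_state = 128 the kernel
+ // runs with nsg = 4 simdgroups, so ggml_metal_library_get_pipeline_ssm_scan reserves
+ // (32 + 2)*sizeof(float)*4 = 544 bytes of threadgroup memory for this layout.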
+ + shared_sums[tpitg.x] = 0.0f; const int32_t i0 = tpitg.x; const int32_t i1 = tgpig.x; @@ -2403,32 +2462,57 @@ for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - for (int t = 0; t < sgptg && i2 + t < n_t; t++) { - const float dt0 = dt[0]; + // Pre-compute x_dt and dA for this batch of tokens + // Only first sgptg threads do the loads and expensive math + if (i0 < sgptg && i2 + i0 < n_t) { + // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20) + device const float * x_t = x + i0 * args.ns12; + device const float * dt_t = dt + i0 * args.ns21; + + const float dt0 = dt_t[0]; const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; - const float x_dt = x[0] * dtsp; - const float dA = exp(dtsp * A0); + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies + } - s = (s0 * dA) + (B[i0] * x_dt); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Phase 1: Compute states and s*C products for all tokens in batch + // Per-simdgroup sums are stored here; the cross-simdgroup reduction is done in Phase 2 + const int batch_len = min((int)sgptg, n_t - i2); + device const float * B_t = B; + device const float * C_t = C; + + for (int t = 0; t < batch_len; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); + + s = (s0 * dA) + (B_t[i0] * x_dt); - const float sumf = simd_sum(s * C[i0]); + // Compute s * C and do SIMD reduction + const float sumf = simd_sum(s * C_t[i0]); if (tiisg == 0) { - shared[t*NW + sgitg] = sumf; + shared_sums[t*NW + sgitg] = sumf; } - // recurse s0 = s; - - x += args.ns12; - dt += args.ns21; - B += args.ns42; - C += args.ns52; + B_t += args.ns42; + C_t += args.ns52; } + // Advance B, C pointers for next batch + B += batch_len * args.ns42; + C += batch_len * args.ns52; + + // Advance x, dt pointers for next batch + x += sgptg * args.ns12; + dt += sgptg * args.ns21; + threadgroup_barrier(mem_flags::mem_threadgroup); - const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + // Phase 2: Final reduction and output + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); if (tiisg == 0 && i2 + sgitg < n_t) { y[sgitg*nh*nr] = sumf; @@ -2440,6 +2524,140 @@ s_buff[i] = s; } +// SSD kernel for multi-token (prefill) processing +// +// The SSM state update s[t] = dA[t] * s[t-1] + B[t] * x[t] * dt[t] forms an +// associative scan with operator: (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2) +// +// This would allow O(log n) parallel prefix computation instead of O(n) sequential. +// The current implementation still walks each batch of sgptg tokens sequentially, with +// x_dt and dA precomputed in threadgroup memory; a work-efficient Blelloch scan within +// each threadgroup is a possible follow-up optimization.
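+//
+// Illustrative two-token example (not from the patch): composing (dA1, B1*x1*dt1) with
+// (dA2, B2*x2*dt2) under the operator above gives (dA2*dA1, dA2*B1*x1*dt1 + B2*x2*dt2),
+// and applying that pair to the initial state s0 reproduces
+// s[2] = dA2*(dA1*s0 + B1*x1*dt1) + B2*x2*dt2, i.e. two steps of the recurrence.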
+// +// Dispatch: one threadgroup per (head_dim_idx, head, seq) +// Threads: must be power of 2, >= n_seq_tokens +kernel void kernel_ssm_scan_ssd_f32( + constant ggml_metal_kargs_ssm_scan & args, + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device float * dst, + threadgroup float * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i0 = tpitg.x; // state index within d_state + const int32_t i1 = tgpig.x; // head_dim index + const int32_t ir = tgpig.y; // head index + const int32_t i3 = tgpig.z; // sequence index + + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; // head_dim + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; + + const int32_t s_off = args.s_off; + const int32_t g = ir / (nh / ng); // group index for B, C + + device const int32_t * ids = (device const int32_t *) src6; + + // State buffers + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + + const int32_t state_idx = i0 + i1*nc; + + // Load initial state + float s0 = s0_buff[state_idx]; + + // A coefficient + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + const float A0 = A[i0 % args.ne30]; + + // Input pointers + device const float * x_base = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_base = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); + device const float * B_base = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); + device const float * C_base = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); + + // Output pointer + device float * y_base = dst + (i1 + ir*nr + i3*(n_t*nh*nr)); + + // Shared memory layout: + // - sgptg * NW floats for partial sums + // - sgptg floats for shared_x_dt + // - sgptg floats for shared_dA + threadgroup float * shared_sums = shared; + threadgroup float * shared_x_dt = shared + sgptg * NW; + threadgroup float * shared_dA = shared + sgptg * NW + sgptg; + + shared_sums[tpitg.x] = 0.0f; + + float s = 0.0f; + + // Process tokens in batches of sgptg + for (int i2 = 0; i2 < n_t; i2 += sgptg) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Pre-compute x_dt and dA for this batch of tokens + if (i0 < sgptg && i2 + i0 < n_t) { + device const float * x_t = x_base + i0 * args.ns12; + device const float * dt_t = dt_base + i0 * args.ns21; + + const float dt0 = dt_t[0]; + const float dtsp = dt0 <= 20.0f ? 
log(1.0f + exp(dt0)) : dt0; + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Process tokens in batch sequentially (standard approach) + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); + + s = (s0 * dA) + (B_base[i0] * x_dt); + + const float sumf = simd_sum(s * C_base[i0]); + + if (tiisg == 0) { + shared_sums[t*NW + sgitg] = sumf; + } + + s0 = s; + + B_base += args.ns42; + C_base += args.ns52; + } + + // Advance pointers for next batch + x_base += sgptg * args.ns12; + dt_base += sgptg * args.ns21; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y_base[sgitg*nh*nr] = sumf; + } + + y_base += sgptg*nh*nr; + } + + s_buff[state_idx] = s; +} + kernel void kernel_rwkv_wkv6_f32( device const float * k, device const float * v, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9f5cdc1398d..cb1d69d09b0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -124,6 +124,13 @@ static void ggml_print_backtrace_symbols(void) { int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); } +#elif defined(__APPLE__) +#include <execinfo.h> +static void ggml_print_backtrace_symbols(void) { + void * trace[100]; + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); +} #else static void ggml_print_backtrace_symbols(void) { // platform not supported @@ -135,6 +142,14 @@ void ggml_print_backtrace(void) { if (GGML_NO_BACKTRACE) { return; } +#if defined(__APPLE__) + // On macOS, fork+debugger attachment is problematic due to: + // 1. libdispatch "poisons" forked child processes + // 2.
lldb has issues attaching to the parent from a forked child + // Use simple backtrace() instead to avoid Terminal.app crashes + ggml_print_backtrace_symbols(); + return; +#endif #if defined(__linux__) FILE * f = fopen("/proc/self/status", "r"); size_t size = 0; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4171400713d..05f1713dae0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1387,7 +1387,11 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || + model.arch == LLM_ARCH_GRANITE_HYBRID || + model.arch == LLM_ARCH_MAMBA2 || + model.arch == LLM_ARCH_FALCON_H1 || + model.arch == LLM_ARCH_NEMOTRON_H) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } return std::max(1024u, 8u*model.n_tensors()); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 42ccb5b76aa..c1a22f75459 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -244,6 +244,12 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { const int64_t n_rs = mctx->get_n_rs(); if (s_copy) { + // s_copy may not have been allocated a backing buffer yet; nothing to fill in that case + if (!s_copy->buffer) { + return; + } + GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); int32_t * data = (int32_t *) s_copy->data; diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index b9a363b32b6..af67163b52d 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-impl.h" + llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, @@ -241,9 +247,233 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + // Use the SSM_SCAN op for all cases - the Metal backend selects between the + // sequential scan (single token) and the batched SSD kernel (multi-token) internally + if (true) { + // if (n_seq_tokens == 1) { + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): using ggml_ssm_scan\n", il); + ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + } else { + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); + + // ============================================================ + // OPTIMIZED SSD (State Space Duality) Implementation + // Key optimizations: + // 1. Pre-permute B and C once at the start to avoid repeated permutes + // 2. Minimize ggml_cont() calls by choosing layouts carefully + // 3.
Reuse intermediate tensors where possible + // ============================================================ + + // extract the state(s) for the sequences identified by ids + if (ssm->ne[3] != ids->ne[0]) { + ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); + ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, + ids->ne[0], ssm->ne[1], ssm->ne[2], 1); + ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); + ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); + GGML_ASSERT(ssm->ne[3] == ids->ne[0]); + } + // ssm -> {d_state, head_dim, n_head, n_seqs} + + // step 1: compute dt softplus + ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} + dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); + cb(dt_softplus, "dt_softplus", il); + + // step 2: compute dtA and dtX + // dtA: {n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); + cb(dtA, "dtA", il); + + // dtX: {head_dim, n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, + 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); + cb(dtX, "dtX", il); + + // Pre-permute dtX once for the attention-like matmul: {head_dim, n_head, n_seq_tokens, n_seqs} -> {n_head, n_seq_tokens, head_dim, n_seqs} + // This layout is what we need for: y = SAM @ dtX^T + ggml_tensor * dtX_perm = ggml_cont(ctx, ggml_permute(ctx, dtX, 1, 2, 0, 3)); // {n_head, n_seq_tokens, head_dim, n_seqs} + cb(dtX_perm, "dtX_perm", il); + + // Pre-permute B and C for mul_mat: {d_state, n_group, n_seq_tokens, n_seqs} -> {d_state, n_seq_tokens, n_group, n_seqs} + // These permuted versions will be used throughout + ggml_tensor * B_perm_full = ggml_permute(ctx, B, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * C_perm_full = ggml_permute(ctx, C, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + + uint32_t repeats = n_head / n_group; + + // For the state update, we need B in a different layout + // Pre-compute the expanded B for state updates: {d_state, n_seq_tokens, n_head, n_seqs} + ggml_tensor * B_for_state = ggml_cont(ctx, B_perm_full); + B_for_state = ggml_repeat_4d(ctx, B_for_state, + B_for_state->ne[0], B_for_state->ne[1], B_for_state->ne[2] * repeats, B_for_state->ne[3]); + // Permute for mul_mat: {n_seq_tokens, d_state, n_head, n_seqs} + ggml_tensor * B_for_state_perm = ggml_cont(ctx, ggml_permute(ctx, B_for_state, 1, 0, 2, 3)); + cb(B_for_state_perm, "B_for_state_perm", il); + + // Empty y that will be extended with each chunk of tokens + ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); + + const uint32_t chunk_size = 512; + for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { + const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); + + // Create chunk views + ggml_tensor * dtA_chunk = (chunk_size_i == n_seq_tokens) ? dtA : + ggml_view_3d(ctx, dtA, dtA->ne[0], chunk_size_i, dtA->ne[2], + dtA->nb[1], dtA->nb[2], chunk_i * dtA->nb[1]); + + ggml_tensor * dtX_chunk_perm = (chunk_size_i == n_seq_tokens) ? dtX_perm : + ggml_view_4d(ctx, dtX_perm, + dtX_perm->ne[0], chunk_size_i, dtX_perm->ne[2], dtX_perm->ne[3], + dtX_perm->nb[1], dtX_perm->nb[2], dtX_perm->nb[3], + chunk_i * dtX_perm->nb[1]); + + ggml_tensor * dtX_chunk = (chunk_size_i == n_seq_tokens) ? 
dtX : + ggml_view_4d(ctx, dtX, dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], + dtX->nb[1], dtX->nb[2], dtX->nb[3], chunk_i * dtX->nb[2]); + + // Use pre-permuted B and C chunks + ggml_tensor * B_perm_chunk = (chunk_size_i == n_seq_tokens) ? B_perm_full : + ggml_view_4d(ctx, B_perm_full, + B_perm_full->ne[0], chunk_size_i, B_perm_full->ne[2], B_perm_full->ne[3], + B_perm_full->nb[1], B_perm_full->nb[2], B_perm_full->nb[3], + chunk_i * B_perm_full->nb[1]); + + ggml_tensor * C_perm_chunk = (chunk_size_i == n_seq_tokens) ? C_perm_full : + ggml_view_4d(ctx, C_perm_full, + C_perm_full->ne[0], chunk_size_i, C_perm_full->ne[2], C_perm_full->ne[3], + C_perm_full->nb[1], C_perm_full->nb[2], C_perm_full->nb[3], + chunk_i * C_perm_full->nb[1]); + + ggml_tensor * B_state_chunk = (chunk_size_i == n_seq_tokens) ? B_for_state_perm : + ggml_view_4d(ctx, B_for_state_perm, + chunk_size_i, B_for_state_perm->ne[1], B_for_state_perm->ne[2], B_for_state_perm->ne[3], + B_for_state_perm->nb[1], B_for_state_perm->nb[2], B_for_state_perm->nb[3], + chunk_i * B_for_state_perm->nb[0]); + + cb(dtA_chunk, "dtA_chunk", il); + cb(dtX_chunk_perm, "dtX_chunk_perm", il); + + // step 3: compute CB = C @ B^T + // B_perm_chunk, C_perm_chunk: {d_state, chunk_size_i, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, B_perm_chunk, C_perm_chunk); // {chunk_size_i, chunk_size_i, n_group, n_seqs} + CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); + cb(CB, "CB", il); + + // step 4: compute decay matrix + // dtA_chunk: {n_head, chunk_size_i, n_seqs} + // We need to build the lower-triangular cumsum matrix + ggml_tensor * dtA_for_decay = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs} + ggml_tensor * dtA_expanded = ggml_repeat_4d(ctx, dtA_for_decay, + dtA_for_decay->ne[0] * chunk_size_i, dtA_for_decay->ne[1], + dtA_for_decay->ne[2], dtA_for_decay->ne[3]); + ggml_tensor * dtA_tri = ggml_tri(ctx, dtA_expanded, GGML_TRI_TYPE_LOWER); + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tri); + segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // Need cont for transpose + ggml_tensor * decay = ggml_exp(ctx, segsum); + cb(decay, "decay", il); + + // step 5: compute surrogate attention matrix + ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); + ggml_tensor * SAM = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + cb(SAM, "SAM", il); + + // step 6: compute y = SAM @ dtX^T + // SAM: {chunk_size_i, chunk_size_i, n_head, n_seqs} + // dtX_chunk_perm: {n_head, chunk_size_i, head_dim, n_seqs} + ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, SAM); + // Result: {head_dim, chunk_size_i, n_head, n_seqs} + // We need: {head_dim, n_head, chunk_size_i, n_seqs} + y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); + cb(y_chunk, "y_chunk", il); + + // step 7: compute state update contribution + // decay_last: last row of decay matrix + ggml_tensor * decay_last = ggml_view_4d(ctx, decay, + decay->ne[0], 1, decay->ne[2], decay->ne[3], + decay->nb[1], decay->nb[2], decay->nb[3], + (decay->ne[1] - 1) * decay->nb[1]); + // decay_last: {chunk_size_i, 1, n_head, n_seqs} -> need {1, n_head, chunk_size_i, n_seqs} for broadcast + ggml_tensor * decay_last_bc = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + + // dtxdecay = dtX * decay_last (broadcast) + // dtX_chunk: {head_dim, n_head, chunk_size_i, n_seqs} + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last_bc); + // Permute for mul_mat: {n_head, chunk_size_i, head_dim, n_seqs} + ggml_tensor * 
dtxdecay_perm = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + + // step 8: compute next_state = B^T @ dtxdecay + // B_state_chunk: {chunk_size_i, d_state, n_head, n_seqs} + // dtxdecay_perm: {n_head, chunk_size_i, head_dim, n_seqs} + ggml_tensor * next_state = ggml_mul_mat(ctx, B_state_chunk, dtxdecay_perm); + // Result: {d_state, head_dim, n_head, n_seqs} + if (next_state->type != ssm->type) { + next_state = ggml_cast(ctx, next_state, ssm->type); + } + cb(next_state, "next_state", il); + + // step 9: update state from previous state + // Compute exp(cumsum(dtA)) for state decay + ggml_tensor * dtA_for_state = ggml_cont(ctx, dtA_for_decay); + ggml_tensor * dtA_flat = ggml_view_3d(ctx, dtA_for_state, + dtA_for_state->ne[1], dtA_for_state->ne[2], dtA_for_state->ne[3], + dtA_for_state->nb[2], dtA_for_state->nb[3], 0); + ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_flat)); + exp_dtA_cumsum = ggml_view_4d(ctx, exp_dtA_cumsum, + 1, dtA_for_state->ne[1], dtA_for_state->ne[2], dtA_for_state->ne[3], + dtA_for_state->nb[1], dtA_for_state->nb[2], dtA_for_state->nb[3], 0); + + // Get last value for state update + ggml_tensor * state_decay = ggml_view_4d(ctx, exp_dtA_cumsum, + exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], + exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], + (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); + + next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, state_decay))); + cb(next_state, "next_state_updated", il); + + // step 10: update y from previous state + // y_prev = C @ ssm (project state through C) + // C_perm_chunk: {d_state, chunk_size_i, n_group, n_seqs} + // ssm: {d_state, head_dim, n_head, n_seqs} + ggml_tensor * y_prev = ggml_mul_mat(ctx, C_perm_chunk, ssm); + // Result: {chunk_size_i, head_dim, n_head, n_seqs} + // Need: {head_dim, n_head, chunk_size_i, n_seqs} + y_prev = ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)); + + // Scale by cumulative decay + // exp_dtA_cumsum: {1, chunk_size_i, n_head, n_seqs} + // Need: {1, n_head, chunk_size_i, n_seqs} for broadcast + ggml_tensor * y_decay = ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)); + y_prev = ggml_mul(ctx, y_prev, y_decay); + + y_chunk = ggml_add(ctx, y_chunk, y_prev); + cb(y_chunk, "y_chunk_final", il); + + // step 11: accumulate results + if (chunk_size_i == n_seq_tokens) { + y = y_chunk; + } else { + y = ggml_concat(ctx, y, y_chunk, 2); + } + ssm = next_state; + } + + // Concat the output y and state + if (ssm->type != y->type) { + ssm = ggml_cast(ctx, ssm, y->type); + } + ggml_tensor * out = ggml_concat(ctx, + ggml_view_1d(ctx, y, ggml_nelements(y), 0), + ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), + 0); + return out; + } }; ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
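Reviewer note (illustrative, not part of the patch): the associativity claim in the kernel_ssm_scan_ssd_f32 header comment can be sanity-checked on the host in a few lines of C++. The sketch below uses arbitrary made-up coefficients and no ggml/Metal APIs; it folds (dA, B*x*dt) pairs with the operator (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2) and checks that the result matches the sequential recurrence s[t] = dA[t]*s[t-1] + B[t]*x[t]*dt[t], which is the property a future parallel prefix-scan implementation would rely on.

#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // per-token coefficients for one state element: c[t] = dA[t], v[t] = B[t]*x[t]*dt[t]
    // (arbitrary example values)
    const std::vector<float> c = {0.90f, 0.80f, 0.95f, 0.70f};
    const std::vector<float> v = {0.10f, -0.20f, 0.05f, 0.30f};
    const float s0 = 0.5f;

    // sequential reference: s[t] = c[t]*s[t-1] + v[t]
    float s_seq = s0;
    for (size_t t = 0; t < c.size(); ++t) {
        s_seq = c[t]*s_seq + v[t];
    }

    // associative operator: (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2)
    const auto combine = [](std::pair<float, float> a, std::pair<float, float> b) {
        return std::make_pair(b.first*a.first, b.first*a.second + b.second);
    };

    // fold all per-token pairs (any grouping gives the same result, which is what
    // makes a Blelloch-style parallel prefix scan applicable)
    std::pair<float, float> acc = {1.0f, 0.0f}; // identity element
    for (size_t t = 0; t < c.size(); ++t) {
        acc = combine(acc, {c[t], v[t]});
    }
    const float s_scan = acc.first*s0 + acc.second;

    std::printf("sequential = %f, folded = %f\n", s_seq, s_scan);
    return 0;
}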