From 9d715f47c729bcf3aabfd25f954f5692e9c25307 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 14:09:08 -0700 Subject: [PATCH 1/9] feat: First-pass at porting SSD impl from previous work It builds but doesn't run yet Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/models/graph-context-mamba.cpp | 194 ++++++++++++++++++++++++++++- 1 file changed, 191 insertions(+), 3 deletions(-) diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index b9a363b32b..726e068133 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-impl.h" + llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, @@ -241,9 +243,195 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + if (n_seq_tokens == 1) { + // if (true) { + //DEBUG + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); + // If single-token, use ssm_scan op + ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + } else { + //DEBUG + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); + + // otherwise, use the SSD formulation + + // extract the state(s) for the sequences identified by ids + if (ssm->ne[3] != ids->ne[0]) { + ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 + ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, + ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape + ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows + ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape + GGML_ASSERT(ssm->ne[3] == ids->ne[0]); + } + // ssm -> {d_state, head_dim, n_head, n_seqs} + + // step 1: compute dt softplus + // NOTE: In other implementations, the bias is added after + // the softplus. This shouldn't be a problem, but it's a + // difference. 
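+                    // softplus(v) = log(1 + exp(v)) maps dt to positive values; the clamp below bounds it to [0.001, 100.0]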
+ ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} + dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); + cb(dt_softplus, "dt_softplus", il); + + // step 2: compute dtA and dtX + ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + cb(dtA, "dtA", il); + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + cb(dtX, "dtX", il); + + // loop over all chunks + uint32_t repeats = n_head / n_group; + + // Empty y that will be extended with each chunk of tokens + ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); + // TODO: make this configurable + const uint32_t chunk_size = 512; // default ubatch size + for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { + ggml_tensor * dtA_chunk; + ggml_tensor * dtX_chunk; + ggml_tensor * B_chunk; + ggml_tensor * C_chunk; + const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); + if (chunk_size_i == n_seq_tokens) { + dtA_chunk = dtA; + dtX_chunk = dtX; + B_chunk = B; + C_chunk = C; + } else { + // chunk views + // slice dtA on dim 1 + dtA_chunk = ggml_view_3d(ctx, dtA, + dtA->ne[0], chunk_size_i, dtA->ne[2], + dtA->nb[1], dtA->nb[2], + chunk_i * dtA->nb[1]); + // slice dtX on dim 2 + dtX_chunk = ggml_view_4d(ctx, dtX, + dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], + dtX->nb[1], dtX->nb[2], dtX->nb[3], + chunk_i * dtX->nb[2]); + // slice B on dim 2 + B_chunk = ggml_view_4d(ctx, B, + B->ne[0], B->ne[1], chunk_size_i, B->ne[3], + B->nb[1], B->nb[2], B->nb[3], + chunk_i * B->nb[2]); + // slice C on dim 2 + C_chunk = ggml_view_4d(ctx, C, + C->ne[0], C->ne[1], chunk_size_i, C->ne[3], + C->nb[1], C->nb[2], C->nb[3], + chunk_i * C->nb[2]); + } + cb(dtA_chunk, "dtA_chunk", il); // {n_head, chunk_size_i, n_seqs} + cb(dtX_chunk, "dtX_chunk", il); // {head_dim, n_head, chunk_size_i, n_seqs} + cb(B_chunk, "B_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} + cb(C_chunk, "C_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} + + // step 3: compute CB + ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} + ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {chunk_size_i, chunk_size_i, n_group, n_seqs} + CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {chunk_size_i, chunk_size_i, n_head (repeats * n_group), n_seqs} + cb(CB, "CB", il); + + // step 4: compute decay + dtA_chunk = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs} + ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, + dtA_chunk->ne[0] * chunk_size_i, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3]); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + ggml_tensor * dtA_tmp1 = ggml_tri(ctx, dtA_tmp0, GGML_TRI_TYPE_LOWER); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} + cb(segsum, "segsum", il); + ggml_tensor * decay = ggml_exp(ctx, segsum); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} + cb(decay, "decay", il); + + // 
step 5: compute surrogate_attention_matrix + ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); + ggml_tensor * surrogate_attention_matrix = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); + + // step 6: compute y + ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); + ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); + cb(y_chunk, "y_chunk", il); // {n_head, chunk_size_i, n_seqs} + + // step 7: compute dtxdecay + ggml_tensor * decay_last = ggml_view_4d(ctx, decay, + decay->ne[0], 1, decay->ne[2], decay->ne[3], + decay->nb[1], decay->nb[2], decay->nb[3], + (decay->ne[1] - 1) * decay->nb[1]); + decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + cb(decay_last, "decay_last", il); + B_perm = ggml_cont(ctx, B_perm); + B_perm = ggml_repeat_4d(ctx, B_perm, + B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); + dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + cb(dtxdecay, "dtxdecay", il); + + // step 8: compute next_state + ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + if (next_state->type != ssm->type) { + next_state = ggml_cast(ctx, next_state, ssm->type); + } + cb(next_state, "next_state", il); + + // TODO: Skip y and state updates if no previous state + + // step 9: update from previous state + dtA_chunk = ggml_cont(ctx, dtA_chunk); + ggml_tensor * dtA_chunk_flat = ggml_view_3d(ctx, + dtA_chunk, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], + dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {chunk_size_i, n_head, n_seqs, 1} + ggml_tensor * exp_dtA_cumsum = ggml_view_4d(ctx, + ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk_flat)), + 1, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], + dtA_chunk->nb[1], dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {1, chunk_size_i, n_head, n_seqs} + cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); + ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, + exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], + exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], + (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {1, 1, n_head, n_seqs} + cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); + // ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} + next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_last))); + cb(next_state, "next_state_updated", il); + + // step 10: update from previous y + ggml_tensor * y_prev = ggml_mul_mat(ctx, + C_perm, // {d_state, chunk_size_i, n_group, n_seqs} + ssm // {d_state, head_dim, n_head, n_seqs} + ); // {chunk_size_i, head_dim, n_head, n_seqs} + cb(y_prev, "y_prev", il); + y_prev = ggml_mul(ctx, + ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), // {head_dim, n_head, chunk_size_i, n_seqs} + ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)) // {1, n_head, chunk_size_i, n_seqs} + ); // {head_dim, chunk_size_i, n_head, n_seqs} + cb(y_prev, "y_prev_mul", il); + y_chunk = ggml_add(ctx, y_chunk, y_prev); + cb(y_chunk, "y_chunk_updated", il); + + // step 11: recurse + if (chunk_size_i == n_seq_tokens) { + y = y_chunk; + } else { + y = ggml_concat(ctx, y, y_chunk, 2); + } + 
cb(y, "y", il); + ssm = next_state; + } + + // Concat the output y and state + if (ssm->type != y->type) { + ssm = ggml_cast(ctx, ssm, y->type); + } + ggml_tensor * out = ggml_concat(ctx, + ggml_view_1d(ctx, y, ggml_nelements(y), 0), + ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), + 0); + return out; + } }; ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); From 960bb52b9d89b60e6975150e55b9756cc0653747 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 3 Dec 2025 14:09:29 -0700 Subject: [PATCH 2/9] fix: Increase max nodes for models known to use mamba2 Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-context.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4171400713..05f1713dae 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1387,7 +1387,11 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || + model.arch == LLM_ARCH_GRANITE_HYBRID || + model.arch == LLM_ARCH_MAMBA2 || + model.arch == LLM_ARCH_FALCON_H1 || + model.arch == LLM_ARCH_NEMOTRON_H) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } return std::max(1024u, 8u*model.n_tensors()); From 0dee5b15000858760adcf825ed31fde5b78bf0bc Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 5 Dec 2025 20:21:39 -0700 Subject: [PATCH 3/9] TEMP: Disable everything after ssm_conv in mamba2 Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 6 + src/models/graph-context-mamba.cpp | 488 +++++++++++++++-------------- 2 files changed, 252 insertions(+), 242 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 42ccb5b76a..c1a22f7545 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -244,6 +244,12 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { const int64_t n_rs = mctx->get_n_rs(); if (s_copy) { + //DEBUG + if (!s_copy->buffer) { + return; + } + //DEBUG + GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); int32_t * data = (int32_t *) s_copy->data; diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index 726e068133..7c613a0ce7 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -216,252 +216,256 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i // For simultaneous sequences, all sequences need to have the same length. 
xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); - // bias - xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + // // bias + // xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); - xBC = ggml_silu(ctx0, xBC); + // xBC = ggml_silu(ctx0, xBC); } - // ssm - { - // These correspond to V K Q in SSM/attention duality - ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0], - xBC->nb[1], xBC->nb[2], 0); - ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], - xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC)); - ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], - xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC)); - - // {n_head, n_seq_tokens, n_seqs} - dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - - ggml_tensor * A = model.layers[il].ssm_a; - - // use the states and the indices provided by build_recurrent_state - // (this is necessary in order to properly use the states before they are overwritten, - // while avoiding to make unnecessary copies of the states) - auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { - ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - - if (n_seq_tokens == 1) { - // if (true) { - //DEBUG - LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); - // If single-token, use ssm_scan op - ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); - } else { - //DEBUG - LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); - - // otherwise, use the SSD formulation - - // extract the state(s) for the sequences identified by ids - if (ssm->ne[3] != ids->ne[0]) { - ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 - ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, - ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape - ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows - ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape - GGML_ASSERT(ssm->ne[3] == ids->ne[0]); - } - // ssm -> {d_state, head_dim, n_head, n_seqs} - - // step 1: compute dt softplus - // NOTE: In other implementations, the bias is added after - // the softplus. This shouldn't be a problem, but it's a - // difference. 
- ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} - dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); - cb(dt_softplus, "dt_softplus", il); - - // step 2: compute dtA and dtX - ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} - cb(dtA, "dtA", il); - ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} - cb(dtX, "dtX", il); - - // loop over all chunks - uint32_t repeats = n_head / n_group; - - // Empty y that will be extended with each chunk of tokens - ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); - // TODO: make this configurable - const uint32_t chunk_size = 512; // default ubatch size - for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { - ggml_tensor * dtA_chunk; - ggml_tensor * dtX_chunk; - ggml_tensor * B_chunk; - ggml_tensor * C_chunk; - const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); - if (chunk_size_i == n_seq_tokens) { - dtA_chunk = dtA; - dtX_chunk = dtX; - B_chunk = B; - C_chunk = C; - } else { - // chunk views - // slice dtA on dim 1 - dtA_chunk = ggml_view_3d(ctx, dtA, - dtA->ne[0], chunk_size_i, dtA->ne[2], - dtA->nb[1], dtA->nb[2], - chunk_i * dtA->nb[1]); - // slice dtX on dim 2 - dtX_chunk = ggml_view_4d(ctx, dtX, - dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], - dtX->nb[1], dtX->nb[2], dtX->nb[3], - chunk_i * dtX->nb[2]); - // slice B on dim 2 - B_chunk = ggml_view_4d(ctx, B, - B->ne[0], B->ne[1], chunk_size_i, B->ne[3], - B->nb[1], B->nb[2], B->nb[3], - chunk_i * B->nb[2]); - // slice C on dim 2 - C_chunk = ggml_view_4d(ctx, C, - C->ne[0], C->ne[1], chunk_size_i, C->ne[3], - C->nb[1], C->nb[2], C->nb[3], - chunk_i * C->nb[2]); - } - cb(dtA_chunk, "dtA_chunk", il); // {n_head, chunk_size_i, n_seqs} - cb(dtX_chunk, "dtX_chunk", il); // {head_dim, n_head, chunk_size_i, n_seqs} - cb(B_chunk, "B_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} - cb(C_chunk, "C_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} - - // step 3: compute CB - ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} - ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} - ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {chunk_size_i, chunk_size_i, n_group, n_seqs} - CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {chunk_size_i, chunk_size_i, n_head (repeats * n_group), n_seqs} - cb(CB, "CB", il); - - // step 4: compute decay - dtA_chunk = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs} - ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, - dtA_chunk->ne[0] * chunk_size_i, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3]); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - ggml_tensor * dtA_tmp1 = ggml_tri(ctx, dtA_tmp0, GGML_TRI_TYPE_LOWER); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} - cb(segsum, "segsum", il); - ggml_tensor * decay = ggml_exp(ctx, segsum); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} - cb(decay, "decay", il); - - // 
step 5: compute surrogate_attention_matrix - ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); - ggml_tensor * surrogate_attention_matrix = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); - cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); - - // step 6: compute y - ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); - ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); - y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); - cb(y_chunk, "y_chunk", il); // {n_head, chunk_size_i, n_seqs} - - // step 7: compute dtxdecay - ggml_tensor * decay_last = ggml_view_4d(ctx, decay, - decay->ne[0], 1, decay->ne[2], decay->ne[3], - decay->nb[1], decay->nb[2], decay->nb[3], - (decay->ne[1] - 1) * decay->nb[1]); - decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); - cb(decay_last, "decay_last", il); - B_perm = ggml_cont(ctx, B_perm); - B_perm = ggml_repeat_4d(ctx, B_perm, - B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); - ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); - dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); - cb(dtxdecay, "dtxdecay", il); - - // step 8: compute next_state - ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); - if (next_state->type != ssm->type) { - next_state = ggml_cast(ctx, next_state, ssm->type); - } - cb(next_state, "next_state", il); - - // TODO: Skip y and state updates if no previous state - - // step 9: update from previous state - dtA_chunk = ggml_cont(ctx, dtA_chunk); - ggml_tensor * dtA_chunk_flat = ggml_view_3d(ctx, - dtA_chunk, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], - dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {chunk_size_i, n_head, n_seqs, 1} - ggml_tensor * exp_dtA_cumsum = ggml_view_4d(ctx, - ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk_flat)), - 1, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], - dtA_chunk->nb[1], dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {1, chunk_size_i, n_head, n_seqs} - cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); - ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, - exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], - exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], - (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {1, 1, n_head, n_seqs} - cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); - // ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} - next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_last))); - cb(next_state, "next_state_updated", il); - - // step 10: update from previous y - ggml_tensor * y_prev = ggml_mul_mat(ctx, - C_perm, // {d_state, chunk_size_i, n_group, n_seqs} - ssm // {d_state, head_dim, n_head, n_seqs} - ); // {chunk_size_i, head_dim, n_head, n_seqs} - cb(y_prev, "y_prev", il); - y_prev = ggml_mul(ctx, - ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), // {head_dim, n_head, chunk_size_i, n_seqs} - ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)) // {1, n_head, chunk_size_i, n_seqs} - ); // {head_dim, chunk_size_i, n_head, n_seqs} - cb(y_prev, "y_prev_mul", il); - y_chunk = ggml_add(ctx, y_chunk, y_prev); - cb(y_chunk, "y_chunk_updated", il); - - // step 11: recurse - if (chunk_size_i == n_seq_tokens) { - y = y_chunk; - } else { - y = ggml_concat(ctx, y, y_chunk, 2); - } - 
cb(y, "y", il); - ssm = next_state; - } - - // Concat the output y and state - if (ssm->type != y->type) { - ssm = ggml_cast(ctx, ssm, y->type); - } - ggml_tensor * out = ggml_concat(ctx, - ggml_view_1d(ctx, y, ggml_nelements(y), 0), - ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), - 0); - return out; - } - }; - - ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + //DEBUG + cur = ggml_view_4d(ctx0, xBC, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], + cur->nb[1], cur->nb[2], cur->nb[3], 0); - // store last states - ggml_build_forward_expand( - gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]), - ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs, - kv_head * d_state * d_inner * ggml_element_size(ssm_states_all)))); - - ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1], - n_seq_tokens * n_head * x->nb[1], 0); - - // TODO: skip computing output earlier for unused tokens - - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - cb(y, "mamba2_y_add_d", il); - y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); - - // grouped RMS norm - if (model.layers[il].ssm_norm) { - y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); - } - - y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); - - // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); - } + // ssm + // { + // // These correspond to V K Q in SSM/attention duality + // ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0], + // xBC->nb[1], xBC->nb[2], 0); + // ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], + // xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC)); + // ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], + // xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC)); + + // // {n_head, n_seq_tokens, n_seqs} + // dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + // ggml_tensor * A = model.layers[il].ssm_a; + + // // use the states and the indices provided by build_recurrent_state + // // (this is necessary in order to properly use the states before they are overwritten, + // // while avoiding to make unnecessary copies of the states) + // auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + // ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + + // // if (n_seq_tokens == 1) { + // if (true) { + // //DEBUG + // LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); + // // If single-token, use ssm_scan op + // ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); + // return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + // } else { + // //DEBUG + // LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); + + // // otherwise, use the SSD formulation + + // // extract the state(s) for the sequences identified by ids + // if (ssm->ne[3] != ids->ne[0]) { + // ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 + // ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, + // ids->ne[0], 
ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape + // ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows + // ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape + // GGML_ASSERT(ssm->ne[3] == ids->ne[0]); + // } + // // ssm -> {d_state, head_dim, n_head, n_seqs} + + // // step 1: compute dt softplus + // // NOTE: In other implementations, the bias is added after + // // the softplus. This shouldn't be a problem, but it's a + // // difference. + // ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} + // dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); + // cb(dt_softplus, "dt_softplus", il); + + // // step 2: compute dtA and dtX + // ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + // cb(dtA, "dtA", il); + // ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + // cb(dtX, "dtX", il); + + // // loop over all chunks + // uint32_t repeats = n_head / n_group; + + // // Empty y that will be extended with each chunk of tokens + // ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); + // // TODO: make this configurable + // const uint32_t chunk_size = 512; // default ubatch size + // for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { + // ggml_tensor * dtA_chunk; + // ggml_tensor * dtX_chunk; + // ggml_tensor * B_chunk; + // ggml_tensor * C_chunk; + // const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); + // if (chunk_size_i == n_seq_tokens) { + // dtA_chunk = dtA; + // dtX_chunk = dtX; + // B_chunk = B; + // C_chunk = C; + // } else { + // // chunk views + // // slice dtA on dim 1 + // dtA_chunk = ggml_view_3d(ctx, dtA, + // dtA->ne[0], chunk_size_i, dtA->ne[2], + // dtA->nb[1], dtA->nb[2], + // chunk_i * dtA->nb[1]); + // // slice dtX on dim 2 + // dtX_chunk = ggml_view_4d(ctx, dtX, + // dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], + // dtX->nb[1], dtX->nb[2], dtX->nb[3], + // chunk_i * dtX->nb[2]); + // // slice B on dim 2 + // B_chunk = ggml_view_4d(ctx, B, + // B->ne[0], B->ne[1], chunk_size_i, B->ne[3], + // B->nb[1], B->nb[2], B->nb[3], + // chunk_i * B->nb[2]); + // // slice C on dim 2 + // C_chunk = ggml_view_4d(ctx, C, + // C->ne[0], C->ne[1], chunk_size_i, C->ne[3], + // C->nb[1], C->nb[2], C->nb[3], + // chunk_i * C->nb[2]); + // } + // cb(dtA_chunk, "dtA_chunk", il); // {n_head, chunk_size_i, n_seqs} + // cb(dtX_chunk, "dtX_chunk", il); // {head_dim, n_head, chunk_size_i, n_seqs} + // cb(B_chunk, "B_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} + // cb(C_chunk, "C_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} + + // // step 3: compute CB + // ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} + // ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} + // ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {chunk_size_i, chunk_size_i, n_group, n_seqs} + // CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {chunk_size_i, chunk_size_i, n_head (repeats * n_group), n_seqs} + // cb(CB, "CB", il); + + // // step 4: compute decay + // dtA_chunk = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, 
chunk_size_i, n_head, n_seqs} + // ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, + // dtA_chunk->ne[0] * chunk_size_i, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3]); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + // ggml_tensor * dtA_tmp1 = ggml_tri(ctx, dtA_tmp0, GGML_TRI_TYPE_LOWER); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + // ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + // segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} + // cb(segsum, "segsum", il); + // ggml_tensor * decay = ggml_exp(ctx, segsum); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} + // cb(decay, "decay", il); + + // // step 5: compute surrogate_attention_matrix + // ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); + // ggml_tensor * surrogate_attention_matrix = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + // cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); + + // // step 6: compute y + // ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); + // ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + // y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); + // cb(y_chunk, "y_chunk", il); // {n_head, chunk_size_i, n_seqs} + + // // step 7: compute dtxdecay + // ggml_tensor * decay_last = ggml_view_4d(ctx, decay, + // decay->ne[0], 1, decay->ne[2], decay->ne[3], + // decay->nb[1], decay->nb[2], decay->nb[3], + // (decay->ne[1] - 1) * decay->nb[1]); + // decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + // cb(decay_last, "decay_last", il); + // B_perm = ggml_cont(ctx, B_perm); + // B_perm = ggml_repeat_4d(ctx, B_perm, + // B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); + // ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); + // dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + // cb(dtxdecay, "dtxdecay", il); + + // // step 8: compute next_state + // ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + // if (next_state->type != ssm->type) { + // next_state = ggml_cast(ctx, next_state, ssm->type); + // } + // cb(next_state, "next_state", il); + + // // TODO: Skip y and state updates if no previous state + + // // step 9: update from previous state + // dtA_chunk = ggml_cont(ctx, dtA_chunk); + // ggml_tensor * dtA_chunk_flat = ggml_view_3d(ctx, + // dtA_chunk, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], + // dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {chunk_size_i, n_head, n_seqs, 1} + // ggml_tensor * exp_dtA_cumsum = ggml_view_4d(ctx, + // ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk_flat)), + // 1, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], + // dtA_chunk->nb[1], dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {1, chunk_size_i, n_head, n_seqs} + // cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); + // ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, + // exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], + // exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], + // (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {1, 1, n_head, n_seqs} + // cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); + // // ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} + // next_state = ggml_add(ctx, 
next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_last))); + // cb(next_state, "next_state_updated", il); + + // // step 10: update from previous y + // ggml_tensor * y_prev = ggml_mul_mat(ctx, + // C_perm, // {d_state, chunk_size_i, n_group, n_seqs} + // ssm // {d_state, head_dim, n_head, n_seqs} + // ); // {chunk_size_i, head_dim, n_head, n_seqs} + // cb(y_prev, "y_prev", il); + // y_prev = ggml_mul(ctx, + // ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), // {head_dim, n_head, chunk_size_i, n_seqs} + // ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)) // {1, n_head, chunk_size_i, n_seqs} + // ); // {head_dim, chunk_size_i, n_head, n_seqs} + // cb(y_prev, "y_prev_mul", il); + // y_chunk = ggml_add(ctx, y_chunk, y_prev); + // cb(y_chunk, "y_chunk_updated", il); + + // // step 11: recurse + // if (chunk_size_i == n_seq_tokens) { + // y = y_chunk; + // } else { + // y = ggml_concat(ctx, y, y_chunk, 2); + // } + // cb(y, "y", il); + // ssm = next_state; + // } + + // // Concat the output y and state + // if (ssm->type != y->type) { + // ssm = ggml_cast(ctx, ssm, y->type); + // } + // ggml_tensor * out = ggml_concat(ctx, + // ggml_view_1d(ctx, y, ggml_nelements(y), 0), + // ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), + // 0); + // return out; + // } + // }; + + // ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + + // // store last states + // ggml_build_forward_expand( + // gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]), + // ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs, + // kv_head * d_state * d_inner * ggml_element_size(ssm_states_all)))); + + // ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1], + // n_seq_tokens * n_head * x->nb[1], 0); + + // // TODO: skip computing output earlier for unused tokens + + // y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + // cb(y, "mamba2_y_add_d", il); + // y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); + + // // grouped RMS norm + // if (model.layers[il].ssm_norm) { + // y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + // y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + // } + + // y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + + // // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + // cur = build_lora_mm(model.layers[il].ssm_out, y); + // } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); From 4b8f1d59f12b6b843c88727545b2ceacecc875eb Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 5 Dec 2025 20:45:21 -0700 Subject: [PATCH 4/9] feat: Add a batched version of ssm_conv This was done using Claude Code. It found a number of optimizations around how the threads were organized, resulting in a huge performance boost! 
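The batched kernel is dispatched as ceil(ne1 / 256) threadgroups of 256 threads per (row, sequence) instead of one single-thread threadgroup per (row, token, sequence), so a 512-token prefill needs only 2 threadgroup dispatches per row and sequence instead of 512.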
Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 17 ++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 33 +++++++++++---- ggml/src/ggml-metal/ggml-metal.metal | 50 +++++++++++++++++++++++ 4 files changed, 94 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ba3c342751..d291a1ac00 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -411,6 +411,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + const char * name = "kernel_ssm_conv_f32_f32_b256"; + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 77f2e98cfe..05202a738b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -117,6 +117,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9efd51abba..54972d88c6 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1365,15 +1365,34 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead + constexpr int BATCH_SIZE = 256; + const bool use_batched = (ne1 > 1); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); - 
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + if (use_batched) { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences + // Each threadgroup has BATCH_SIZE threads, each handling one token + const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1); + } else { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + } return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4b78d5a2ba..2a193e3208 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2312,6 +2312,56 @@ kernel void kernel_ssm_conv_f32_f32( x[0] = sumf; } +// Batched version: each threadgroup processes multiple tokens for better efficiency +// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens +template +kernel void kernel_ssm_conv_f32_f32_batched( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +typedef decltype(kernel_ssm_conv_f32_f32_batched<1>) kernel_ssm_conv_batched_t; +template [[host_name("kernel_ssm_conv_f32_f32_b256")]] kernel 
kernel_ssm_conv_batched_t kernel_ssm_conv_f32_f32_batched<256>; + kernel void kernel_ssm_conv_f32_f32_4( constant ggml_metal_kargs_ssm_conv & args, device const void * src0, From 953bb624b9c73f49cb5e03b592245e18b11955b0 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 5 Dec 2025 20:45:36 -0700 Subject: [PATCH 5/9] un-TEMP: Re-enable the SSM portion Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/models/graph-context-mamba.cpp | 488 ++++++++++++++--------------- 1 file changed, 244 insertions(+), 244 deletions(-) diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index 7c613a0ce7..6caba05500 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -216,256 +216,256 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i // For simultaneous sequences, all sequences need to have the same length. xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); - // // bias - // xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); - // xBC = ggml_silu(ctx0, xBC); + xBC = ggml_silu(ctx0, xBC); } //DEBUG - cur = ggml_view_4d(ctx0, xBC, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], - cur->nb[1], cur->nb[2], cur->nb[3], 0); + // cur = ggml_view_4d(ctx0, xBC, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], + // cur->nb[1], cur->nb[2], cur->nb[3], 0); // ssm - // { - // // These correspond to V K Q in SSM/attention duality - // ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0], - // xBC->nb[1], xBC->nb[2], 0); - // ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], - // xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC)); - // ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], - // xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC)); - - // // {n_head, n_seq_tokens, n_seqs} - // dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - - // ggml_tensor * A = model.layers[il].ssm_a; - - // // use the states and the indices provided by build_recurrent_state - // // (this is necessary in order to properly use the states before they are overwritten, - // // while avoiding to make unnecessary copies of the states) - // auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { - // ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - - // // if (n_seq_tokens == 1) { - // if (true) { - // //DEBUG - // LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); - // // If single-token, use ssm_scan op - // ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); - // return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); - // } else { - // //DEBUG - // LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); - - // // otherwise, use the SSD formulation - - // // extract the state(s) for the sequences identified by ids - // if (ssm->ne[3] != ids->ne[0]) { - // ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 - // ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, - // ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape - // ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows - // ssm = ggml_cont(ctx, ggml_permute(ctx, 
ssm_ids, 0, 3, 1, 2)); // permute back to original shape - // GGML_ASSERT(ssm->ne[3] == ids->ne[0]); - // } - // // ssm -> {d_state, head_dim, n_head, n_seqs} - - // // step 1: compute dt softplus - // // NOTE: In other implementations, the bias is added after - // // the softplus. This shouldn't be a problem, but it's a - // // difference. - // ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} - // dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); - // cb(dt_softplus, "dt_softplus", il); - - // // step 2: compute dtA and dtX - // ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} - // cb(dtA, "dtA", il); - // ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} - // cb(dtX, "dtX", il); - - // // loop over all chunks - // uint32_t repeats = n_head / n_group; - - // // Empty y that will be extended with each chunk of tokens - // ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); - // // TODO: make this configurable - // const uint32_t chunk_size = 512; // default ubatch size - // for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { - // ggml_tensor * dtA_chunk; - // ggml_tensor * dtX_chunk; - // ggml_tensor * B_chunk; - // ggml_tensor * C_chunk; - // const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); - // if (chunk_size_i == n_seq_tokens) { - // dtA_chunk = dtA; - // dtX_chunk = dtX; - // B_chunk = B; - // C_chunk = C; - // } else { - // // chunk views - // // slice dtA on dim 1 - // dtA_chunk = ggml_view_3d(ctx, dtA, - // dtA->ne[0], chunk_size_i, dtA->ne[2], - // dtA->nb[1], dtA->nb[2], - // chunk_i * dtA->nb[1]); - // // slice dtX on dim 2 - // dtX_chunk = ggml_view_4d(ctx, dtX, - // dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], - // dtX->nb[1], dtX->nb[2], dtX->nb[3], - // chunk_i * dtX->nb[2]); - // // slice B on dim 2 - // B_chunk = ggml_view_4d(ctx, B, - // B->ne[0], B->ne[1], chunk_size_i, B->ne[3], - // B->nb[1], B->nb[2], B->nb[3], - // chunk_i * B->nb[2]); - // // slice C on dim 2 - // C_chunk = ggml_view_4d(ctx, C, - // C->ne[0], C->ne[1], chunk_size_i, C->ne[3], - // C->nb[1], C->nb[2], C->nb[3], - // chunk_i * C->nb[2]); - // } - // cb(dtA_chunk, "dtA_chunk", il); // {n_head, chunk_size_i, n_seqs} - // cb(dtX_chunk, "dtX_chunk", il); // {head_dim, n_head, chunk_size_i, n_seqs} - // cb(B_chunk, "B_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} - // cb(C_chunk, "C_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} - - // // step 3: compute CB - // ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} - // ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} - // ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {chunk_size_i, chunk_size_i, n_group, n_seqs} - // CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {chunk_size_i, chunk_size_i, n_head (repeats * n_group), n_seqs} - // cb(CB, "CB", il); - - // // step 4: compute decay - // dtA_chunk = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs} - // ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, - // dtA_chunk->ne[0] * chunk_size_i, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3]); // {chunk_size_i_0, 
chunk_size_i_1, n_head, n_seqs} - // ggml_tensor * dtA_tmp1 = ggml_tri(ctx, dtA_tmp0, GGML_TRI_TYPE_LOWER); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - // ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - // segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} - // cb(segsum, "segsum", il); - // ggml_tensor * decay = ggml_exp(ctx, segsum); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} - // cb(decay, "decay", il); - - // // step 5: compute surrogate_attention_matrix - // ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); - // ggml_tensor * surrogate_attention_matrix = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); - // cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); - - // // step 6: compute y - // ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); - // ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); - // y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); - // cb(y_chunk, "y_chunk", il); // {n_head, chunk_size_i, n_seqs} - - // // step 7: compute dtxdecay - // ggml_tensor * decay_last = ggml_view_4d(ctx, decay, - // decay->ne[0], 1, decay->ne[2], decay->ne[3], - // decay->nb[1], decay->nb[2], decay->nb[3], - // (decay->ne[1] - 1) * decay->nb[1]); - // decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); - // cb(decay_last, "decay_last", il); - // B_perm = ggml_cont(ctx, B_perm); - // B_perm = ggml_repeat_4d(ctx, B_perm, - // B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); - // ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); - // dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); - // cb(dtxdecay, "dtxdecay", il); - - // // step 8: compute next_state - // ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); - // if (next_state->type != ssm->type) { - // next_state = ggml_cast(ctx, next_state, ssm->type); - // } - // cb(next_state, "next_state", il); - - // // TODO: Skip y and state updates if no previous state - - // // step 9: update from previous state - // dtA_chunk = ggml_cont(ctx, dtA_chunk); - // ggml_tensor * dtA_chunk_flat = ggml_view_3d(ctx, - // dtA_chunk, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], - // dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {chunk_size_i, n_head, n_seqs, 1} - // ggml_tensor * exp_dtA_cumsum = ggml_view_4d(ctx, - // ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk_flat)), - // 1, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], - // dtA_chunk->nb[1], dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {1, chunk_size_i, n_head, n_seqs} - // cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); - // ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, - // exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], - // exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], - // (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {1, 1, n_head, n_seqs} - // cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); - // // ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} - // next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_last))); - // cb(next_state, "next_state_updated", il); - - // // step 10: update from previous y - // ggml_tensor * y_prev = ggml_mul_mat(ctx, - // 
C_perm, // {d_state, chunk_size_i, n_group, n_seqs} - // ssm // {d_state, head_dim, n_head, n_seqs} - // ); // {chunk_size_i, head_dim, n_head, n_seqs} - // cb(y_prev, "y_prev", il); - // y_prev = ggml_mul(ctx, - // ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), // {head_dim, n_head, chunk_size_i, n_seqs} - // ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)) // {1, n_head, chunk_size_i, n_seqs} - // ); // {head_dim, chunk_size_i, n_head, n_seqs} - // cb(y_prev, "y_prev_mul", il); - // y_chunk = ggml_add(ctx, y_chunk, y_prev); - // cb(y_chunk, "y_chunk_updated", il); - - // // step 11: recurse - // if (chunk_size_i == n_seq_tokens) { - // y = y_chunk; - // } else { - // y = ggml_concat(ctx, y, y_chunk, 2); - // } - // cb(y, "y", il); - // ssm = next_state; - // } - - // // Concat the output y and state - // if (ssm->type != y->type) { - // ssm = ggml_cast(ctx, ssm, y->type); - // } - // ggml_tensor * out = ggml_concat(ctx, - // ggml_view_1d(ctx, y, ggml_nelements(y), 0), - // ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), - // 0); - // return out; - // } - // }; - - // ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); - - // // store last states - // ggml_build_forward_expand( - // gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]), - // ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs, - // kv_head * d_state * d_inner * ggml_element_size(ssm_states_all)))); - - // ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1], - // n_seq_tokens * n_head * x->nb[1], 0); - - // // TODO: skip computing output earlier for unused tokens - - // y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - // cb(y, "mamba2_y_add_d", il); - // y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); - - // // grouped RMS norm - // if (model.layers[il].ssm_norm) { - // y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - // y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); - // } - - // y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); - - // // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - // cur = build_lora_mm(model.layers[il].ssm_out, y); - // } + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0], + xBC->nb[1], xBC->nb[2], 0); + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], + xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC)); + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0], + xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + + // if (n_seq_tokens == 1) { + if (true) { + //DEBUG + 
LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); + // If single-token, use ssm_scan op + ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32); + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + } else { + //DEBUG + LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); + + // otherwise, use the SSD formulation + + // extract the state(s) for the sequences identified by ids + if (ssm->ne[3] != ids->ne[0]) { + ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 + ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, + ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape + ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows + ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape + GGML_ASSERT(ssm->ne[3] == ids->ne[0]); + } + // ssm -> {d_state, head_dim, n_head, n_seqs} + + // step 1: compute dt softplus + // NOTE: In other implementations, the bias is added after + // the softplus. This shouldn't be a problem, but it's a + // difference. + ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} + dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); + cb(dt_softplus, "dt_softplus", il); + + // step 2: compute dtA and dtX + ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + cb(dtA, "dtA", il); + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + cb(dtX, "dtX", il); + + // loop over all chunks + uint32_t repeats = n_head / n_group; + + // Empty y that will be extended with each chunk of tokens + ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); + // TODO: make this configurable + const uint32_t chunk_size = 512; // default ubatch size + for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { + ggml_tensor * dtA_chunk; + ggml_tensor * dtX_chunk; + ggml_tensor * B_chunk; + ggml_tensor * C_chunk; + const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); + if (chunk_size_i == n_seq_tokens) { + dtA_chunk = dtA; + dtX_chunk = dtX; + B_chunk = B; + C_chunk = C; + } else { + // chunk views + // slice dtA on dim 1 + dtA_chunk = ggml_view_3d(ctx, dtA, + dtA->ne[0], chunk_size_i, dtA->ne[2], + dtA->nb[1], dtA->nb[2], + chunk_i * dtA->nb[1]); + // slice dtX on dim 2 + dtX_chunk = ggml_view_4d(ctx, dtX, + dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], + dtX->nb[1], dtX->nb[2], dtX->nb[3], + chunk_i * dtX->nb[2]); + // slice B on dim 2 + B_chunk = ggml_view_4d(ctx, B, + B->ne[0], B->ne[1], chunk_size_i, B->ne[3], + B->nb[1], B->nb[2], B->nb[3], + chunk_i * B->nb[2]); + // slice C on dim 2 + C_chunk = ggml_view_4d(ctx, C, + C->ne[0], C->ne[1], chunk_size_i, C->ne[3], + C->nb[1], C->nb[2], C->nb[3], + chunk_i * C->nb[2]); + } + cb(dtA_chunk, "dtA_chunk", il); // {n_head, chunk_size_i, n_seqs} + cb(dtX_chunk, "dtX_chunk", il); // {head_dim, n_head, chunk_size_i, n_seqs} + cb(B_chunk, "B_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} + cb(C_chunk, "C_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} + + // step 3: compute CB + ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} + ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // 
{d_state, chunk_size_i, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {chunk_size_i, chunk_size_i, n_group, n_seqs} + CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {chunk_size_i, chunk_size_i, n_head (repeats * n_group), n_seqs} + cb(CB, "CB", il); + + // step 4: compute decay + dtA_chunk = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs} + ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, + dtA_chunk->ne[0] * chunk_size_i, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3]); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + ggml_tensor * dtA_tmp1 = ggml_tri(ctx, dtA_tmp0, GGML_TRI_TYPE_LOWER); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} + segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} + cb(segsum, "segsum", il); + ggml_tensor * decay = ggml_exp(ctx, segsum); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} + cb(decay, "decay", il); + + // step 5: compute surrogate_attention_matrix + ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); + ggml_tensor * surrogate_attention_matrix = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); + + // step 6: compute y + ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); + ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); + cb(y_chunk, "y_chunk", il); // {n_head, chunk_size_i, n_seqs} + + // step 7: compute dtxdecay + ggml_tensor * decay_last = ggml_view_4d(ctx, decay, + decay->ne[0], 1, decay->ne[2], decay->ne[3], + decay->nb[1], decay->nb[2], decay->nb[3], + (decay->ne[1] - 1) * decay->nb[1]); + decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + cb(decay_last, "decay_last", il); + B_perm = ggml_cont(ctx, B_perm); + B_perm = ggml_repeat_4d(ctx, B_perm, + B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); + dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + cb(dtxdecay, "dtxdecay", il); + + // step 8: compute next_state + ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + if (next_state->type != ssm->type) { + next_state = ggml_cast(ctx, next_state, ssm->type); + } + cb(next_state, "next_state", il); + + // TODO: Skip y and state updates if no previous state + + // step 9: update from previous state + dtA_chunk = ggml_cont(ctx, dtA_chunk); + ggml_tensor * dtA_chunk_flat = ggml_view_3d(ctx, + dtA_chunk, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], + dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {chunk_size_i, n_head, n_seqs, 1} + ggml_tensor * exp_dtA_cumsum = ggml_view_4d(ctx, + ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk_flat)), + 1, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], + dtA_chunk->nb[1], dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {1, chunk_size_i, n_head, n_seqs} + cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); + ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, + exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], + exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], + (exp_dtA_cumsum->ne[1] - 1) * 
exp_dtA_cumsum->nb[1]); // {1, 1, n_head, n_seqs} + cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); + // ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} + next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_last))); + cb(next_state, "next_state_updated", il); + + // step 10: update from previous y + ggml_tensor * y_prev = ggml_mul_mat(ctx, + C_perm, // {d_state, chunk_size_i, n_group, n_seqs} + ssm // {d_state, head_dim, n_head, n_seqs} + ); // {chunk_size_i, head_dim, n_head, n_seqs} + cb(y_prev, "y_prev", il); + y_prev = ggml_mul(ctx, + ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), // {head_dim, n_head, chunk_size_i, n_seqs} + ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)) // {1, n_head, chunk_size_i, n_seqs} + ); // {head_dim, chunk_size_i, n_head, n_seqs} + cb(y_prev, "y_prev_mul", il); + y_chunk = ggml_add(ctx, y_chunk, y_prev); + cb(y_chunk, "y_chunk_updated", il); + + // step 11: recurse + if (chunk_size_i == n_seq_tokens) { + y = y_chunk; + } else { + y = ggml_concat(ctx, y, y_chunk, 2); + } + cb(y, "y", il); + ssm = next_state; + } + + // Concat the output y and state + if (ssm->type != y->type) { + ssm = ggml_cast(ctx, ssm, y->type); + } + ggml_tensor * out = ggml_concat(ctx, + ggml_view_1d(ctx, y, ggml_nelements(y), 0), + ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0), + 0); + return out; + } + }; + + ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + + // store last states + ggml_build_forward_expand( + gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]), + ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs, + kv_head * d_state * d_inner * ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1], + n_seq_tokens * n_head * x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + cb(y, "mamba2_y_add_d", il); + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); + + // grouped RMS norm + if (model.layers[il].ssm_norm) { + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + } + + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); From 65feb1c66ba15ea8bbb354851ea59271b303c0e5 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 5 Dec 2025 20:57:31 -0700 Subject: [PATCH 6/9] feat: Optimized SSM_SCAN kernel for metal This used Claude Code and resulted in a modest performance improvement while maintaining correctness. 
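The core change is that, for each batch of up to sgptg tokens, the per-token x*softplus(dt) products and softplus(dt) values are staged in threadgroup memory by the first sgptg threads instead of being reloaded by every thread, so the threadgroup allocation grows from 32 floats per simdgroup to 32 + 2. A rough host-side illustration of the resulting sizing (a sketch only; ssm_scan_smem_bytes is a hypothetical helper, assuming NW == 32 and that nsg matches the simdgroups-per-threadgroup used at dispatch):

    // Layout within the threadgroup buffer, in floats:
    //   [0,             nsg*NW)        partial sums for the per-token reductions
    //   [nsg*NW,        nsg*NW + nsg)  shared_x_dt, one slot per token in the batch
    //   [nsg*NW + nsg,  nsg*(NW + 2))  shared_dA,   one slot per token in the batch
    static size_t ssm_scan_smem_bytes(int nsg) {
        constexpr int NW = 32; // N_SIMDWIDTH
        return (size_t)(NW + 2) * sizeof(float) * nsg;
    }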
Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 7 +++- ggml/src/ggml-metal/ggml-metal.metal | 42 ++++++++++++++++++----- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d291a1ac00..b6c4dc9559 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -444,7 +444,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res.smem = 32*sizeof(float)*nsg; + // Shared memory layout: + // - sgptg * NW floats for partial sums (nsg * 32) + // - sgptg floats for shared_x_dt (nsg) + // - sgptg floats for shared_dA (nsg) + // Total: nsg * (32 + 2) floats + res.smem = (32 + 2)*sizeof(float)*nsg; return res; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2a193e3208..6a8f4ed57f 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2394,6 +2394,7 @@ kernel void kernel_ssm_conv_f32_f32_4( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// Optimized version: reduces redundant memory loads by having one thread load shared values kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2413,7 +2414,15 @@ kernel void kernel_ssm_scan_f32( uint3 tgpg[[threadgroups_per_grid]]) { constexpr short NW = N_SIMDWIDTH; - shared[tpitg.x] = 0.0f; + // Shared memory layout: + // [0..sgptg*NW-1]: partial sums for reduction (existing) + // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch + // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch + threadgroup float * shared_sums = shared; + threadgroup float * shared_x_dt = shared + sgptg * NW; + threadgroup float * shared_dA = shared + sgptg * NW + sgptg; + + shared_sums[tpitg.x] = 0.0f; const int32_t i0 = tpitg.x; const int32_t i1 = tgpig.x; @@ -2453,32 +2462,47 @@ kernel void kernel_ssm_scan_f32( for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - for (int t = 0; t < sgptg && i2 + t < n_t; t++) { - const float dt0 = dt[0]; + // Pre-compute x_dt and dA for this batch of tokens + // Only first sgptg threads do the loads and expensive math + if (i0 < sgptg && i2 + i0 < n_t) { + // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20) + device const float * x_t = x + i0 * args.ns12; + device const float * dt_t = dt + i0 * args.ns21; + + const float dt0 = dt_t[0]; const float dtsp = dt0 <= 20.0f ? 
log(1.0f + exp(dt0)) : dt0; - const float x_dt = x[0] * dtsp; - const float dA = exp(dtsp * A0); + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); s = (s0 * dA) + (B[i0] * x_dt); const float sumf = simd_sum(s * C[i0]); if (tiisg == 0) { - shared[t*NW + sgitg] = sumf; + shared_sums[t*NW + sgitg] = sumf; } // recurse s0 = s; - x += args.ns12; - dt += args.ns21; B += args.ns42; C += args.ns52; } + // Advance pointers for next batch + x += sgptg * args.ns12; + dt += sgptg * args.ns21; + threadgroup_barrier(mem_flags::mem_threadgroup); - const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); if (tiisg == 0 && i2 + sgitg < n_t) { y[sgitg*nh*nr] = sumf; From 51adb32a0dd917437a9672fd8b2f11851ac5e393 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 5 Dec 2025 21:23:53 -0700 Subject: [PATCH 7/9] feat: Optimization pass over SSD implementation More Claude optimization, but this is still significantly slower than without! Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- src/models/graph-context-mamba.cpp | 270 ++++++++++++++++------------- 1 file changed, 153 insertions(+), 117 deletions(-) diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index 6caba05500..a2607f93c8 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -247,8 +247,8 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - // if (n_seq_tokens == 1) { - if (true) { + if (n_seq_tokens == 1) { + // if (true) { //DEBUG LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); // If single-token, use ssm_scan op @@ -258,171 +258,207 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i //DEBUG LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il); - // otherwise, use the SSD formulation + // ============================================================ + // OPTIMIZED SSD (State Space Duality) Implementation + // Key optimizations: + // 1. Pre-permute B and C once at the start to avoid repeated permutes + // 2. Minimize ggml_cont() calls by choosing layouts carefully + // 3. 
Reuse intermediate tensors where possible + // ============================================================ // extract the state(s) for the sequences identified by ids if (ssm->ne[3] != ids->ne[0]) { - ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1 + ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids, - ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape - ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows - ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape + ids->ne[0], ssm->ne[1], ssm->ne[2], 1); + ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); + ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); GGML_ASSERT(ssm->ne[3] == ids->ne[0]); } // ssm -> {d_state, head_dim, n_head, n_seqs} // step 1: compute dt softplus - // NOTE: In other implementations, the bias is added after - // the softplus. This shouldn't be a problem, but it's a - // difference. ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs} dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0); cb(dt_softplus, "dt_softplus", il); // step 2: compute dtA and dtX - ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs} + // dtA: {n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); cb(dtA, "dtA", il); - ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs} + + // dtX: {head_dim, n_head, n_seq_tokens, n_seqs} + ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, + 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); cb(dtX, "dtX", il); - // loop over all chunks + // Pre-permute dtX once for the attention-like matmul: {head_dim, n_head, n_seq_tokens, n_seqs} -> {n_head, n_seq_tokens, head_dim, n_seqs} + // This layout is what we need for: y = SAM @ dtX^T + ggml_tensor * dtX_perm = ggml_cont(ctx, ggml_permute(ctx, dtX, 1, 2, 0, 3)); // {n_head, n_seq_tokens, head_dim, n_seqs} + cb(dtX_perm, "dtX_perm", il); + + // Pre-permute B and C for mul_mat: {d_state, n_group, n_seq_tokens, n_seqs} -> {d_state, n_seq_tokens, n_group, n_seqs} + // These permuted versions will be used throughout + ggml_tensor * B_perm_full = ggml_permute(ctx, B, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + ggml_tensor * C_perm_full = ggml_permute(ctx, C, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs} + uint32_t repeats = n_head / n_group; + // For the state update, we need B in a different layout + // Pre-compute the expanded B for state updates: {d_state, n_seq_tokens, n_head, n_seqs} + ggml_tensor * B_for_state = ggml_cont(ctx, B_perm_full); + B_for_state = ggml_repeat_4d(ctx, B_for_state, + B_for_state->ne[0], B_for_state->ne[1], B_for_state->ne[2] * repeats, B_for_state->ne[3]); + // Permute for mul_mat: {n_seq_tokens, d_state, n_head, n_seqs} + ggml_tensor * B_for_state_perm = ggml_cont(ctx, ggml_permute(ctx, B_for_state, 1, 0, 2, 3)); + cb(B_for_state_perm, "B_for_state_perm", il); + // Empty y that will be extended with each chunk of tokens ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]); - // TODO: make this configurable - const uint32_t 
chunk_size = 512; // default ubatch size + + const uint32_t chunk_size = 512; for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) { - ggml_tensor * dtA_chunk; - ggml_tensor * dtX_chunk; - ggml_tensor * B_chunk; - ggml_tensor * C_chunk; const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i)); - if (chunk_size_i == n_seq_tokens) { - dtA_chunk = dtA; - dtX_chunk = dtX; - B_chunk = B; - C_chunk = C; - } else { - // chunk views - // slice dtA on dim 1 - dtA_chunk = ggml_view_3d(ctx, dtA, - dtA->ne[0], chunk_size_i, dtA->ne[2], - dtA->nb[1], dtA->nb[2], - chunk_i * dtA->nb[1]); - // slice dtX on dim 2 - dtX_chunk = ggml_view_4d(ctx, dtX, - dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], - dtX->nb[1], dtX->nb[2], dtX->nb[3], - chunk_i * dtX->nb[2]); - // slice B on dim 2 - B_chunk = ggml_view_4d(ctx, B, - B->ne[0], B->ne[1], chunk_size_i, B->ne[3], - B->nb[1], B->nb[2], B->nb[3], - chunk_i * B->nb[2]); - // slice C on dim 2 - C_chunk = ggml_view_4d(ctx, C, - C->ne[0], C->ne[1], chunk_size_i, C->ne[3], - C->nb[1], C->nb[2], C->nb[3], - chunk_i * C->nb[2]); - } - cb(dtA_chunk, "dtA_chunk", il); // {n_head, chunk_size_i, n_seqs} - cb(dtX_chunk, "dtX_chunk", il); // {head_dim, n_head, chunk_size_i, n_seqs} - cb(B_chunk, "B_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} - cb(C_chunk, "C_chunk", il); // {d_state, n_group, chunk_size_i, n_seqs} - - // step 3: compute CB - ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} - ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs} - ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {chunk_size_i, chunk_size_i, n_group, n_seqs} - CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {chunk_size_i, chunk_size_i, n_head (repeats * n_group), n_seqs} + + // Create chunk views + ggml_tensor * dtA_chunk = (chunk_size_i == n_seq_tokens) ? dtA : + ggml_view_3d(ctx, dtA, dtA->ne[0], chunk_size_i, dtA->ne[2], + dtA->nb[1], dtA->nb[2], chunk_i * dtA->nb[1]); + + ggml_tensor * dtX_chunk_perm = (chunk_size_i == n_seq_tokens) ? dtX_perm : + ggml_view_4d(ctx, dtX_perm, + dtX_perm->ne[0], chunk_size_i, dtX_perm->ne[2], dtX_perm->ne[3], + dtX_perm->nb[1], dtX_perm->nb[2], dtX_perm->nb[3], + chunk_i * dtX_perm->nb[1]); + + ggml_tensor * dtX_chunk = (chunk_size_i == n_seq_tokens) ? dtX : + ggml_view_4d(ctx, dtX, dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3], + dtX->nb[1], dtX->nb[2], dtX->nb[3], chunk_i * dtX->nb[2]); + + // Use pre-permuted B and C chunks + ggml_tensor * B_perm_chunk = (chunk_size_i == n_seq_tokens) ? B_perm_full : + ggml_view_4d(ctx, B_perm_full, + B_perm_full->ne[0], chunk_size_i, B_perm_full->ne[2], B_perm_full->ne[3], + B_perm_full->nb[1], B_perm_full->nb[2], B_perm_full->nb[3], + chunk_i * B_perm_full->nb[1]); + + ggml_tensor * C_perm_chunk = (chunk_size_i == n_seq_tokens) ? C_perm_full : + ggml_view_4d(ctx, C_perm_full, + C_perm_full->ne[0], chunk_size_i, C_perm_full->ne[2], C_perm_full->ne[3], + C_perm_full->nb[1], C_perm_full->nb[2], C_perm_full->nb[3], + chunk_i * C_perm_full->nb[1]); + + ggml_tensor * B_state_chunk = (chunk_size_i == n_seq_tokens) ? 
B_for_state_perm : + ggml_view_4d(ctx, B_for_state_perm, + chunk_size_i, B_for_state_perm->ne[1], B_for_state_perm->ne[2], B_for_state_perm->ne[3], + B_for_state_perm->nb[1], B_for_state_perm->nb[2], B_for_state_perm->nb[3], + chunk_i * B_for_state_perm->nb[0]); + + cb(dtA_chunk, "dtA_chunk", il); + cb(dtX_chunk_perm, "dtX_chunk_perm", il); + + // step 3: compute CB = C @ B^T + // B_perm_chunk, C_perm_chunk: {d_state, chunk_size_i, n_group, n_seqs} + ggml_tensor * CB = ggml_mul_mat(ctx, B_perm_chunk, C_perm_chunk); // {chunk_size_i, chunk_size_i, n_group, n_seqs} + CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); cb(CB, "CB", il); - // step 4: compute decay - dtA_chunk = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs} - ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk, - dtA_chunk->ne[0] * chunk_size_i, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3]); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - ggml_tensor * dtA_tmp1 = ggml_tri(ctx, dtA_tmp0, GGML_TRI_TYPE_LOWER); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs} - segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} - cb(segsum, "segsum", il); - ggml_tensor * decay = ggml_exp(ctx, segsum); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs} + // step 4: compute decay matrix + // dtA_chunk: {n_head, chunk_size_i, n_seqs} + // We need to build the lower-triangular cumsum matrix + ggml_tensor * dtA_for_decay = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs} + ggml_tensor * dtA_expanded = ggml_repeat_4d(ctx, dtA_for_decay, + dtA_for_decay->ne[0] * chunk_size_i, dtA_for_decay->ne[1], + dtA_for_decay->ne[2], dtA_for_decay->ne[3]); + ggml_tensor * dtA_tri = ggml_tri(ctx, dtA_expanded, GGML_TRI_TYPE_LOWER); + ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tri); + segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // Need cont for transpose + ggml_tensor * decay = ggml_exp(ctx, segsum); cb(decay, "decay", il); - // step 5: compute surrogate_attention_matrix + // step 5: compute surrogate attention matrix ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay); - ggml_tensor * surrogate_attention_matrix = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); - cb(surrogate_attention_matrix, "surrogate_attention_matrix", il); - - // step 6: compute y - ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); - ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix); + ggml_tensor * SAM = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG); + cb(SAM, "SAM", il); + + // step 6: compute y = SAM @ dtX^T + // SAM: {chunk_size_i, chunk_size_i, n_head, n_seqs} + // dtX_chunk_perm: {n_head, chunk_size_i, head_dim, n_seqs} + ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, SAM); + // Result: {head_dim, chunk_size_i, n_head, n_seqs} + // We need: {head_dim, n_head, chunk_size_i, n_seqs} y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3)); - cb(y_chunk, "y_chunk", il); // {n_head, chunk_size_i, n_seqs} + cb(y_chunk, "y_chunk", il); - // step 7: compute dtxdecay + // step 7: compute state update contribution + // decay_last: last row of decay matrix ggml_tensor * decay_last = ggml_view_4d(ctx, decay, decay->ne[0], 1, decay->ne[2], decay->ne[3], decay->nb[1], decay->nb[2], decay->nb[3], (decay->ne[1] - 1) * 
decay->nb[1]); - decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); - cb(decay_last, "decay_last", il); - B_perm = ggml_cont(ctx, B_perm); - B_perm = ggml_repeat_4d(ctx, B_perm, - B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]); - ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last); - dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); - cb(dtxdecay, "dtxdecay", il); - - // step 8: compute next_state - ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay); + // decay_last: {chunk_size_i, 1, n_head, n_seqs} -> need {1, n_head, chunk_size_i, n_seqs} for broadcast + ggml_tensor * decay_last_bc = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3)); + + // dtxdecay = dtX * decay_last (broadcast) + // dtX_chunk: {head_dim, n_head, chunk_size_i, n_seqs} + ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last_bc); + // Permute for mul_mat: {n_head, chunk_size_i, head_dim, n_seqs} + ggml_tensor * dtxdecay_perm = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3)); + + // step 8: compute next_state = B^T @ dtxdecay + // B_state_chunk: {chunk_size_i, d_state, n_head, n_seqs} + // dtxdecay_perm: {n_head, chunk_size_i, head_dim, n_seqs} + ggml_tensor * next_state = ggml_mul_mat(ctx, B_state_chunk, dtxdecay_perm); + // Result: {d_state, head_dim, n_head, n_seqs} if (next_state->type != ssm->type) { next_state = ggml_cast(ctx, next_state, ssm->type); } cb(next_state, "next_state", il); - // TODO: Skip y and state updates if no previous state - - // step 9: update from previous state - dtA_chunk = ggml_cont(ctx, dtA_chunk); - ggml_tensor * dtA_chunk_flat = ggml_view_3d(ctx, - dtA_chunk, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], - dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {chunk_size_i, n_head, n_seqs, 1} - ggml_tensor * exp_dtA_cumsum = ggml_view_4d(ctx, - ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk_flat)), - 1, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3], - dtA_chunk->nb[1], dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {1, chunk_size_i, n_head, n_seqs} - cb(exp_dtA_cumsum, "exp_dtA_cumsum", il); - ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum, + // step 9: update state from previous state + // Compute exp(cumsum(dtA)) for state decay + ggml_tensor * dtA_for_state = ggml_cont(ctx, dtA_for_decay); + ggml_tensor * dtA_flat = ggml_view_3d(ctx, dtA_for_state, + dtA_for_state->ne[1], dtA_for_state->ne[2], dtA_for_state->ne[3], + dtA_for_state->nb[2], dtA_for_state->nb[3], 0); + ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_flat)); + exp_dtA_cumsum = ggml_view_4d(ctx, exp_dtA_cumsum, + 1, dtA_for_state->ne[1], dtA_for_state->ne[2], dtA_for_state->ne[3], + dtA_for_state->nb[1], dtA_for_state->nb[2], dtA_for_state->nb[3], 0); + + // Get last value for state update + ggml_tensor * state_decay = ggml_view_4d(ctx, exp_dtA_cumsum, exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3], exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3], - (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {1, 1, n_head, n_seqs} - cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il); - // ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs} - next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_last))); + (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); + + next_state = ggml_add(ctx, next_state, 
ggml_mul(ctx, ssm, ggml_cont(ctx, state_decay))); cb(next_state, "next_state_updated", il); - // step 10: update from previous y - ggml_tensor * y_prev = ggml_mul_mat(ctx, - C_perm, // {d_state, chunk_size_i, n_group, n_seqs} - ssm // {d_state, head_dim, n_head, n_seqs} - ); // {chunk_size_i, head_dim, n_head, n_seqs} - cb(y_prev, "y_prev", il); - y_prev = ggml_mul(ctx, - ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)), // {head_dim, n_head, chunk_size_i, n_seqs} - ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)) // {1, n_head, chunk_size_i, n_seqs} - ); // {head_dim, chunk_size_i, n_head, n_seqs} - cb(y_prev, "y_prev_mul", il); + // step 10: update y from previous state + // y_prev = C @ ssm (project state through C) + // C_perm_chunk: {d_state, chunk_size_i, n_group, n_seqs} + // ssm: {d_state, head_dim, n_head, n_seqs} + ggml_tensor * y_prev = ggml_mul_mat(ctx, C_perm_chunk, ssm); + // Result: {chunk_size_i, head_dim, n_head, n_seqs} + // Need: {head_dim, n_head, chunk_size_i, n_seqs} + y_prev = ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)); + + // Scale by cumulative decay + // exp_dtA_cumsum: {1, chunk_size_i, n_head, n_seqs} + // Need: {1, n_head, chunk_size_i, n_seqs} for broadcast + ggml_tensor * y_decay = ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)); + y_prev = ggml_mul(ctx, y_prev, y_decay); + y_chunk = ggml_add(ctx, y_chunk, y_prev); - cb(y_chunk, "y_chunk_updated", il); + cb(y_chunk, "y_chunk_final", il); - // step 11: recurse + // step 11: accumulate results if (chunk_size_i == n_seq_tokens) { y = y_chunk; } else { y = ggml_concat(ctx, y, y_chunk, 2); } - cb(y, "y", il); ssm = next_state; } From 6340ab1bc8a127d68353e6c0ebe360dd948dd4e3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 8 Dec 2025 09:12:12 -0700 Subject: [PATCH 8/9] TEMP (wip): Partial work towards a unified kernel implementation of SSD Branch: Mamba2SSD Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal-device.cpp | 25 ++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 8 +- ggml/src/ggml-metal/ggml-metal.metal | 160 ++++++++++++++++++++-- src/models/graph-context-mamba.cpp | 6 +- 5 files changed, 189 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index b6c4dc9559..1043d9a49f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -454,6 +454,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + + char base[256]; + char name[256]; + + const int nsg = (ne00 + 31)/32; + + snprintf(base, 256, "kernel_ssm_scan_ssd_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + // Shared memory layout for SSD kernel: + // - BATCH_SIZE * sgptg floats for partial sums + // BATCH_SIZE = 8, so 8 * nsg floats + constexpr int BATCH_SIZE = 8; + res.smem = BATCH_SIZE * nsg * sizeof(float); + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char 
name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 05202a738b..42e5746e02 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -119,6 +119,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 54972d88c6..ad840a184f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1471,7 +1471,13 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.nb0 =*/ nb0, }; - auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + auto pipeline = (n_seq_tokens > 1) + ? 
ggml_metal_library_get_pipeline_ssm_scan_ssd(lib, op) + : ggml_metal_library_get_pipeline_ssm_scan(lib, op); + + // // Use sequential scan for now - the SSD kernel needs further optimization + // // to be competitive with the efficient sequential implementation + // auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6a8f4ed57f..3c9bc91d39 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2477,31 +2477,41 @@ kernel void kernel_ssm_scan_f32( threadgroup_barrier(mem_flags::mem_threadgroup); - for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + // Phase 1: Compute states and s*C products for all tokens in batch + // Store partial products, delay reduction + const int batch_len = min((int)sgptg, n_t - i2); + device const float * B_t = B; + device const float * C_t = C; + + for (int t = 0; t < batch_len; t++) { const float x_dt = shared_x_dt[t]; const float dA = exp(shared_dA[t] * A0); - s = (s0 * dA) + (B[i0] * x_dt); + s = (s0 * dA) + (B_t[i0] * x_dt); - const float sumf = simd_sum(s * C[i0]); + // Compute s * C and do SIMD reduction + const float sumf = simd_sum(s * C_t[i0]); if (tiisg == 0) { shared_sums[t*NW + sgitg] = sumf; } - // recurse s0 = s; - - B += args.ns42; - C += args.ns52; + B_t += args.ns42; + C_t += args.ns52; } - // Advance pointers for next batch + // Advance B, C pointers for next batch + B += batch_len * args.ns42; + C += batch_len * args.ns52; + + // Advance x, dt pointers for next batch x += sgptg * args.ns12; dt += sgptg * args.ns21; threadgroup_barrier(mem_flags::mem_threadgroup); + // Phase 2: Final reduction and output const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); if (tiisg == 0 && i2 + sgitg < n_t) { @@ -2514,6 +2524,140 @@ kernel void kernel_ssm_scan_f32( s_buff[i] = s; } +// SSD kernel using parallel prefix scan for efficient multi-token processing +// +// The SSM state update s[t] = dA[t] * s[t-1] + B[t] * x[t] * dt[t] forms an +// associative scan with operator: (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2) +// +// This allows O(log n) parallel prefix computation instead of O(n) sequential. +// We use a work-efficient Blelloch scan within each threadgroup. 
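// As a sanity check on the operator (an illustration only, not part of the kernel):
// write each token's contribution as the pair (c, v) = (dA[t], B[t]*x[t]*dt[t]).
// Two consecutive tokens then combine as (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2),
// and applying the combined pair to an incoming state s gives
//   c2*(c1*s + v1) + v2 = (c2*c1)*s + (c2*v1 + v2)
// i.e. the same result as two sequential updates, which is what makes the
// prefix-scan (Blelloch) formulation equivalent to the recurrence above.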
+// +// Dispatch: one threadgroup per (head_dim_idx, head, seq) +// Threads: must be power of 2, >= n_seq_tokens +kernel void kernel_ssm_scan_ssd_f32( + constant ggml_metal_kargs_ssm_scan & args, + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device float * dst, + threadgroup float * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i0 = tpitg.x; // state index within d_state + const int32_t i1 = tgpig.x; // head_dim index + const int32_t ir = tgpig.y; // head index + const int32_t i3 = tgpig.z; // sequence index + + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; // head_dim + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; + + const int32_t s_off = args.s_off; + const int32_t g = ir / (nh / ng); // group index for B, C + + device const int32_t * ids = (device const int32_t *) src6; + + // State buffers + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + + const int32_t state_idx = i0 + i1*nc; + + // Load initial state + float s0 = s0_buff[state_idx]; + + // A coefficient + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + const float A0 = A[i0 % args.ne30]; + + // Input pointers + device const float * x_base = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_base = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); + device const float * B_base = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); + device const float * C_base = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); + + // Output pointer + device float * y_base = dst + (i1 + ir*nr + i3*(n_t*nh*nr)); + + // Shared memory layout: + // - sgptg * NW floats for partial sums + // - sgptg floats for shared_x_dt + // - sgptg floats for shared_dA + threadgroup float * shared_sums = shared; + threadgroup float * shared_x_dt = shared + sgptg * NW; + threadgroup float * shared_dA = shared + sgptg * NW + sgptg; + + shared_sums[tpitg.x] = 0.0f; + + float s = 0.0f; + + // Process tokens in batches of sgptg + for (int i2 = 0; i2 < n_t; i2 += sgptg) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Pre-compute x_dt and dA for this batch of tokens + if (i0 < sgptg && i2 + i0 < n_t) { + device const float * x_t = x_base + i0 * args.ns12; + device const float * dt_t = dt_base + i0 * args.ns21; + + const float dt0 = dt_t[0]; + const float dtsp = dt0 <= 20.0f ? 
log(1.0f + exp(dt0)) : dt0; + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Process tokens in batch sequentially (standard approach) + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); + + s = (s0 * dA) + (B_base[i0] * x_dt); + + const float sumf = simd_sum(s * C_base[i0]); + + if (tiisg == 0) { + shared_sums[t*NW + sgitg] = sumf; + } + + s0 = s; + + B_base += args.ns42; + C_base += args.ns52; + } + + // Advance pointers for next batch + x_base += sgptg * args.ns12; + dt_base += sgptg * args.ns21; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y_base[sgitg*nh*nr] = sumf; + } + + y_base += sgptg*nh*nr; + } + + s_buff[state_idx] = s; +} + kernel void kernel_rwkv_wkv6_f32( device const float * k, device const float * v, diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp index a2607f93c8..af67163b52 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/graph-context-mamba.cpp @@ -247,8 +247,10 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - if (n_seq_tokens == 1) { - // if (true) { + // Use SSM_SCAN op for all cases - the Metal kernel handles both + // single-token (sequential scan) and multi-token (SSD formulation) internally + if (true) { + // if (n_seq_tokens == 1) { //DEBUG LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il); // If single-token, use ssm_scan op From 11067a70756b1728f93063e7b2d7c284938c9ff5 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 8 Dec 2025 12:34:24 -0700 Subject: [PATCH 9/9] fix: Provide macos-specific backtrace printing to avoid terminal death Branch: MacOSSafeBacktrace Signed-off-by: Gabe Goodhart --- ggml/src/ggml.c | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9f5cdc1398..cb1d69d09b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -124,6 +124,13 @@ static void ggml_print_backtrace_symbols(void) { int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); } +#elif defined(__APPLE__) +#include +static void ggml_print_backtrace_symbols(void) { + void * trace[100]; + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); +} #else static void ggml_print_backtrace_symbols(void) { // platform not supported @@ -135,6 +142,14 @@ void ggml_print_backtrace(void) { if (GGML_NO_BACKTRACE) { return; } +#if defined(__APPLE__) + // On macOS, fork+debugger attachment is problematic due to: + // 1. libdispatch "poisons" forked child processes + // 2. lldb has issues attaching to parent from forked child + // Use simple backtrace() instead to avoid Terminal.app crashes + ggml_print_backtrace_symbols(); + return; +#endif #if defined(__linux__) FILE * f = fopen("/proc/self/status", "r"); size_t size = 0;
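For the __APPLE__ branch above, the stack dump relies only on the execinfo API (backtrace / backtrace_symbols_fd), with no fork and no debugger attach. A minimal standalone program exercising the same mechanism (a sketch independent of ggml; dump_backtrace and the buffer size are illustrative choices, not taken from the patch):

    #include <execinfo.h>
    #include <unistd.h>

    // Capture raw return addresses and write one symbol/address line per frame
    // straight to stderr; backtrace_symbols_fd does not call malloc, unlike
    // backtrace_symbols.
    static void dump_backtrace(void) {
        void * trace[100];
        const int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
        backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
    }

    int main(void) {
        dump_backtrace(); // prints the current call stack on stderr
        return 0;
    }

This sidesteps the fork + lldb attach path entirely, which is the point of the change above.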