Commit b41721c

feat: First-pass at porting SSD impl from previous work

It builds but doesn't run yet.

Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

1 parent: bde188d

File tree

1 file changed: +191 −3 lines changed


src/models/graph-context-mamba.cpp

Lines changed: 191 additions & 3 deletions
@@ -1,5 +1,7 @@
 #include "models.h"

+#include "llama-impl.h"
+
 llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}

 ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
@@ -241,9 +243,195 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
     auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
         ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

-        // TODO: use semistructured matrices to implement state-space duality
-        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        if (n_seq_tokens == 1) {
+        // if (true) {
+            //DEBUG
+            LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il);
+            // If single-token, use ssm_scan op
+            ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32);
+            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        } else {
+            //DEBUG
+            LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): multi-token chunk scan\n", il);
+
+            // otherwise, use the SSD formulation
+
+            // extract the state(s) for the sequences identified by ids
+            if (ssm->ne[3] != ids->ne[0]) {
+                ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1
+                ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids,
+                    ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape
+                ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows
+                ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape
+                GGML_ASSERT(ssm->ne[3] == ids->ne[0]);
+            }
+            // ssm -> {d_state, head_dim, n_head, n_seqs}
+
+            // step 1: compute dt softplus
+            // NOTE: In other implementations, the bias is added after
+            //       the softplus. This shouldn't be a problem, but it's a
+            //       difference.
+            ggml_tensor * dt_softplus = ggml_softplus(ctx, dt); // {n_head, n_seq_tokens, n_seqs}
+            dt_softplus = ggml_clamp(ctx, dt_softplus, 0.001, 100.0);
+            cb(dt_softplus, "dt_softplus", il);
+
+            // step 2: compute dtA and dtX
+            ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs}
+            cb(dtA, "dtA", il);
+            ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs}
+            cb(dtX, "dtX", il);
+
+            // loop over all chunks
+            uint32_t repeats = n_head / n_group;
+
+            // Empty y that will be extended with each chunk of tokens
+            ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]);
+            // TODO: make this configurable
+            const uint32_t chunk_size = 512; // default ubatch size
+            for (auto chunk_i = 0; chunk_i < n_seq_tokens; chunk_i += chunk_size) {
+                ggml_tensor * dtA_chunk;
+                ggml_tensor * dtX_chunk;
+                ggml_tensor * B_chunk;
+                ggml_tensor * C_chunk;
+                const auto chunk_size_i = std::min(chunk_size, uint32_t(n_seq_tokens - chunk_i));
+                if (chunk_size_i == n_seq_tokens) {
+                    dtA_chunk = dtA;
+                    dtX_chunk = dtX;
+                    B_chunk   = B;
+                    C_chunk   = C;
+                } else {
+                    // chunk views
+                    // slice dtA on dim 1
+                    dtA_chunk = ggml_view_3d(ctx, dtA,
+                        dtA->ne[0], chunk_size_i, dtA->ne[2],
+                        dtA->nb[1], dtA->nb[2],
+                        chunk_i * dtA->nb[1]);
+                    // slice dtX on dim 2
+                    dtX_chunk = ggml_view_4d(ctx, dtX,
+                        dtX->ne[0], dtX->ne[1], chunk_size_i, dtX->ne[3],
+                        dtX->nb[1], dtX->nb[2], dtX->nb[3],
+                        chunk_i * dtX->nb[2]);
+                    // slice B on dim 2
+                    B_chunk = ggml_view_4d(ctx, B,
+                        B->ne[0], B->ne[1], chunk_size_i, B->ne[3],
+                        B->nb[1], B->nb[2], B->nb[3],
+                        chunk_i * B->nb[2]);
+                    // slice C on dim 2
+                    C_chunk = ggml_view_4d(ctx, C,
+                        C->ne[0], C->ne[1], chunk_size_i, C->ne[3],
+                        C->nb[1], C->nb[2], C->nb[3],
+                        chunk_i * C->nb[2]);
+                }
+                cb(dtA_chunk, "dtA_chunk", il); // {n_head, chunk_size_i, n_seqs}
+                cb(dtX_chunk, "dtX_chunk", il); // {head_dim, n_head, chunk_size_i, n_seqs}
+                cb(B_chunk, "B_chunk", il);     // {d_state, n_group, chunk_size_i, n_seqs}
+                cb(C_chunk, "C_chunk", il);     // {d_state, n_group, chunk_size_i, n_seqs}
+
+                // step 3: compute CB
+                ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs}
+                ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, chunk_size_i, n_group, n_seqs}
+                ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {chunk_size_i, chunk_size_i, n_group, n_seqs}
+                CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {chunk_size_i, chunk_size_i, n_head (repeats * n_group), n_seqs}
+                cb(CB, "CB", il);
+
+                // step 4: compute decay
+                dtA_chunk = ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0); // {1, chunk_size_i, n_head, n_seqs}
+                ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk,
+                    dtA_chunk->ne[0] * chunk_size_i, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3]); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs}
+                ggml_tensor * dtA_tmp1 = ggml_tri(ctx, dtA_tmp0, GGML_TRI_TYPE_LOWER); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs}
+                ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1); // {chunk_size_i_0, chunk_size_i_1, n_head, n_seqs}
+                segsum = ggml_cont(ctx, ggml_transpose(ctx, segsum)); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs}
+                cb(segsum, "segsum", il);
+                ggml_tensor * decay = ggml_exp(ctx, segsum); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs}
+                cb(decay, "decay", il);
+
+                // step 5: compute surrogate_attention_matrix
+                ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay);
+                ggml_tensor * surrogate_attention_matrix = ggml_tri(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG);
+                cb(surrogate_attention_matrix, "surrogate_attention_matrix", il);
+
+                // step 6: compute y
+                ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3));
+                ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
+                y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3));
+                cb(y_chunk, "y_chunk", il); // {n_head, chunk_size_i, n_seqs}
+
+                // step 7: compute dtxdecay
+                ggml_tensor * decay_last = ggml_view_4d(ctx, decay,
+                    decay->ne[0], 1, decay->ne[2], decay->ne[3],
+                    decay->nb[1], decay->nb[2], decay->nb[3],
+                    (decay->ne[1] - 1) * decay->nb[1]);
+                decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3));
+                cb(decay_last, "decay_last", il);
+                B_perm = ggml_cont(ctx, B_perm);
+                B_perm = ggml_repeat_4d(ctx, B_perm,
+                    B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]);
+                ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last);
+                dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3));
+                cb(dtxdecay, "dtxdecay", il);
+
+                // step 8: compute next_state
+                ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
+                if (next_state->type != ssm->type) {
+                    next_state = ggml_cast(ctx, next_state, ssm->type);
+                }
+                cb(next_state, "next_state", il);
+
+                // TODO: Skip y and state updates if no previous state
+
+                // step 9: update from previous state
+                dtA_chunk = ggml_cont(ctx, dtA_chunk);
+                ggml_tensor * dtA_chunk_flat = ggml_view_3d(ctx,
+                    dtA_chunk, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3],
+                    dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {chunk_size_i, n_head, n_seqs, 1}
+                ggml_tensor * exp_dtA_cumsum = ggml_view_4d(ctx,
+                    ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk_flat)),
+                    1, dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3],
+                    dtA_chunk->nb[1], dtA_chunk->nb[2], dtA_chunk->nb[3], 0); // {1, chunk_size_i, n_head, n_seqs}
+                cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
+                ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
+                    exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3],
+                    exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
+                    (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {1, 1, n_head, n_seqs}
+                cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
+                // ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs}
+                next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_last)));
+                cb(next_state, "next_state_updated", il);
+
+                // step 10: update from previous y
+                ggml_tensor * y_prev = ggml_mul_mat(ctx,
+                    C_perm, // {d_state, chunk_size_i, n_group, n_seqs}
+                    ssm     // {d_state, head_dim, n_head, n_seqs}
+                ); // {chunk_size_i, head_dim, n_head, n_seqs}
+                cb(y_prev, "y_prev", il);
+                y_prev = ggml_mul(ctx,
+                    ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)),        // {head_dim, n_head, chunk_size_i, n_seqs}
+                    ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 0, 2, 1, 3)) // {1, n_head, chunk_size_i, n_seqs}
+                ); // {head_dim, chunk_size_i, n_head, n_seqs}
+                cb(y_prev, "y_prev_mul", il);
+                y_chunk = ggml_add(ctx, y_chunk, y_prev);
+                cb(y_chunk, "y_chunk_updated", il);
+
+                // step 11: recurse
+                if (chunk_size_i == n_seq_tokens) {
+                    y = y_chunk;
+                } else {
+                    y = ggml_concat(ctx, y, y_chunk, 2);
+                }
+                cb(y, "y", il);
+                ssm = next_state;
+            }
+
+            // Concat the output y and state
+            if (ssm->type != y->type) {
+                ssm = ggml_cast(ctx, ssm, y->type);
+            }
+            ggml_tensor * out = ggml_concat(ctx,
+                ggml_view_1d(ctx, y, ggml_nelements(y), 0),
+                ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0),
+                0);
+            return out;
+        }
     };

     ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
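For reference, the multi-token path above follows the state-space duality (SSD) formulation, which rewrites the per-token SSM recurrence over a chunk as a few dense matrix products. The sketch below uses notation introduced here rather than names from the diff: a_t = softplus(dt_t) * A and x̄_t = softplus(dt_t) * x_t are the discretized quantities of steps 1-2, h_prev is the state carried in from the previous chunk, and T is the chunk length. The quantities built in steps 3-10 then roughly correspond to:

\begin{aligned}
  L_{ij} &= \begin{cases} \exp\!\big(\textstyle\sum_{k=j+1}^{i} a_k\big) & i \ge j \\ 0 & i < j \end{cases}
      && \text{decay via segment sum (steps 4--5)} \\
  M_{ij} &= \big(C_i^{\top} B_j\big)\, L_{ij}
      && \text{surrogate attention matrix (steps 3, 5)} \\
  y_i &= \sum_{j \le i} M_{ij}\, \bar{x}_j \;+\; \exp\!\Big(\textstyle\sum_{k \le i} a_k\Big)\, h_{\text{prev}}^{\top} C_i
      && \text{intra-chunk output plus carry-in (steps 6, 10)} \\
  h_{\text{next}} &= \exp\!\Big(\textstyle\sum_{k=1}^{T} a_k\Big)\, h_{\text{prev}} \;+\; \sum_{j=1}^{T} \exp\!\Big(\textstyle\sum_{k=j+1}^{T} a_k\Big)\, B_j\, \bar{x}_j^{\top}
      && \text{next chunk state (steps 7--9)}
\end{aligned}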

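Equivalently, the chunked result is expected to match a naive per-token recurrence. A minimal scalar reference for one head, in plain C++ with hypothetical names (not part of the commit), assuming dt has already been passed through softplus and clamp as in step 1:

#include <cmath>
#include <cstddef>
#include <vector>

// Sequential SSM reference for a single head: h is the [d_state][head_dim]
// recurrent state carried across chunks, dt/B/C/x are per-token inputs and
// y is the per-token output. Names and layout are illustrative only.
static void ssm_chunk_reference(
        std::vector<std::vector<float>>       & h,   // [d_state][head_dim]
        const std::vector<float>              & dt,  // [T], already softplus+clamp
        float                                   A,   // scalar decay for this head
        const std::vector<std::vector<float>> & B,   // [T][d_state]
        const std::vector<std::vector<float>> & C,   // [T][d_state]
        const std::vector<std::vector<float>> & x,   // [T][head_dim]
        std::vector<std::vector<float>>       & y) { // [T][head_dim]
    const std::size_t T        = dt.size();
    const std::size_t d_state  = h.size();
    const std::size_t head_dim = h[0].size();
    for (std::size_t t = 0; t < T; ++t) {
        const float decay = std::exp(dt[t] * A); // exp(dtA), cf. steps 2 and 4
        for (std::size_t s = 0; s < d_state; ++s) {
            for (std::size_t d = 0; d < head_dim; ++d) {
                // state update: h = exp(dt*A) * h + B_t * (dt * x_t)
                h[s][d] = decay * h[s][d] + B[t][s] * dt[t] * x[t][d];
            }
        }
        for (std::size_t d = 0; d < head_dim; ++d) {
            // output: y_t = C_t . h
            float acc = 0.0f;
            for (std::size_t s = 0; s < d_state; ++s) {
                acc += C[t][s] * h[s][d];
            }
            y[t][d] = acc;
        }
    }
}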