49 changes: 48 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -411,6 +411,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));

const char * name = "kernel_ssm_conv_f32_f32_b256";

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr);
}

return res;
}
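
Both new getters reuse the library's existing look-up-then-compile-on-miss pattern. A condensed sketch of that pattern, assuming the `ggml_metal_library_*` API already used in this file; the helper name `get_or_compile` is illustrative and not part of the ggml API:

```cpp
// illustrative helper, not part of the ggml API: look up a cached pipeline by its
// specialized name and compile it from the base kernel name on a cache miss
static ggml_metal_pipeline_with_params get_or_compile(
        ggml_metal_library_t lib, const char * base, const char * name) {
    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }
    return res;
}

// the batched conv kernel has no per-shape specialization, so base == name;
// the scan getters below append "_nsg=N" to cache one variant per simdgroup count
```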

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);

@@ -427,7 +444,37 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

res.smem = 32*sizeof(float)*nsg;
// Shared memory layout:
// - sgptg * NW floats for partial sums (nsg * 32)
// - sgptg floats for shared_x_dt (nsg)
// - sgptg floats for shared_dA (nsg)
// Total: nsg * (32 + 2) floats
res.smem = (32 + 2)*sizeof(float)*nsg;

return res;
}
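
The sizing above can be sanity-checked with a small standalone calculation. In the sketch below, `NW = 32` (threads per simdgroup) and the helper are illustrative, not part of the library; it only reproduces the arithmetic in the comment:

```cpp
// standalone check of the shared-memory sizing above: nsg simdgroups of NW = 32
// threads need nsg*NW floats for partial sums, plus one shared_x_dt float and
// one shared_dA float per simdgroup
#include <cstddef>
#include <cstdio>

constexpr int NW = 32; // threads per simdgroup on Apple GPUs

static size_t ssm_scan_smem_bytes(int nsg) {
    const size_t partial_sums = (size_t) nsg * NW; // sgptg * NW floats
    const size_t shared_x_dt  = (size_t) nsg;      // one float per simdgroup
    const size_t shared_dA    = (size_t) nsg;      // one float per simdgroup
    return (partial_sums + shared_x_dt + shared_dA) * sizeof(float); // (32 + 2)*nsg*4
}

int main() {
    printf("%zu\n", ssm_scan_smem_bytes(4)); // (32 + 2) * 4 * 4 = 544 bytes
}
```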

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);

char base[256];
char name[256];

const int nsg = (ne00 + 31)/32;

snprintf(base, 256, "kernel_ssm_scan_ssd_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

// Shared memory layout for SSD kernel:
// - BATCH_SIZE * sgptg floats for partial sums
// BATCH_SIZE = 8, so 8 * nsg floats
constexpr int BATCH_SIZE = 8;
res.smem = BATCH_SIZE * nsg * sizeof(float);

return res;
}
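
For the SSD variant, both the specialized pipeline name and the threadgroup-memory budget follow from `nsg`. A minimal reproduction of that arithmetic; the value of `ne00` is an illustrative example, not taken from a real model:

```cpp
// reproduce the name specialization and smem sizing above for an example shape
#include <cstddef>
#include <cstdio>

int main() {
    const int ne00 = 128;              // illustrative inner dimension of src[0]
    const int nsg  = (ne00 + 31) / 32; // round up to whole simdgroups -> 4

    char base[256], name[256];
    snprintf(base, sizeof(base), "kernel_ssm_scan_ssd_%s", "f32"); // ggml_type_name(GGML_TYPE_F32)
    snprintf(name, sizeof(name), "%s_nsg=%d", base, nsg);

    const size_t smem = 8 * (size_t) nsg * sizeof(float); // BATCH_SIZE = 8 floats per simdgroup

    printf("%s -> %zu bytes of threadgroup memory\n", name, smem);
    // kernel_ssm_scan_ssd_f32_nsg=4 -> 128 bytes of threadgroup memory
}
```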
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -117,7 +117,9 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
41 changes: 33 additions & 8 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1365,15 +1365,34 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
/*.nb2 =*/ nb2,
};

auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
// Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
constexpr int BATCH_SIZE = 256;
const bool use_batched = (ne1 > 1);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
if (use_batched) {
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
// Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
// Each threadgroup has BATCH_SIZE threads, each handling one token
const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
} else {
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
}

return 1;
}
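
The batched path changes the dispatch geometry rather than the per-thread work: one 256-thread threadgroup covers up to 256 tokens of a given row/sequence pair. A small sketch of that ceiling division; the example tensor sizes are illustrative only:

```cpp
// compute the threadgroup grid used by the batched conv dispatch above
#include <cstdio>

struct grid3d { int x, y, z; };

static grid3d ssm_conv_batched_grid(int ne01, int ne1, int ne02) {
    constexpr int BATCH_SIZE = 256;                                  // threads (= tokens) per threadgroup
    const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; // ceil(ne1 / 256)
    return { ne01, n_token_batches, ne02 };
}

int main() {
    // e.g. 2048 conv rows, a 512-token prompt, 1 sequence
    const grid3d g = ssm_conv_batched_grid(2048, 512, 1);
    printf("(%d, %d, %d) threadgroups x 256 threads\n", g.x, g.y, g.z); // (2048, 2, 1) x 256
}
```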
@@ -1452,7 +1471,13 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.nb0 =*/ nb0,
};

auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
// prefill (n_seq_tokens > 1) uses the SSD kernel, single-token decode keeps the
// sequential scan; the SSD kernel is still being optimized, and the sequential
// path can be selected unconditionally here if it turns out to be faster
auto pipeline = (n_seq_tokens > 1)
? ggml_metal_library_get_pipeline_ssm_scan_ssd(lib, op)
: ggml_metal_library_get_pipeline_ssm_scan(lib, op);

GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
