Skip to content

Commit 4e76d24

Browse files
authored
ggml : fix AMX and add batched support (ggml-org#19925)
llama-perplexity -hf ggml-org/Qwen3-0.6B-GGUF:Q4_0 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048 --chunks 2 before this commit: ``` perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1 perplexity: 2.31 seconds per pass - ETA 0.07 minutes [1]17.3868,[2]22.2199, Final estimate: PPL = 22.2199 +/- 1.59692 llama_perf_context_print: load time = 878.56 ms llama_perf_context_print: prompt eval time = 2037.82 ms / 4096 tokens ( 0.50 ms per token, 2009.99 tokens per second) llama_perf_context_print: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) llama_perf_context_print: total time = 6403.17 ms / 4097 tokens llama_perf_context_print: graphs reused = 0 llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | llama_memory_breakdown_print: | - Host | 845 = 318 + 224 + 302 | llama_memory_breakdown_print: | - CPU_REPACK | 288 = 288 + 0 + 0 | llama_memory_breakdown_print: | - AMX | 31 = 31 + 0 + 0 | ``` after this commit: ``` perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1 perplexity: 1.98 seconds per pass - ETA 0.05 minutes [1]17.2005,[2]21.8220, Final estimate: PPL = 21.8220 +/- 1.56485 llama_perf_context_print: load time = 719.23 ms llama_perf_context_print: prompt eval time = 1676.23 ms / 4096 tokens ( 0.41 ms per token, 2443.58 tokens per second) llama_perf_context_print: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) llama_perf_context_print: total time = 4258.74 ms / 4097 tokens llama_perf_context_print: graphs reused = 0 llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | llama_memory_breakdown_print: | - Host | 845 = 318 + 224 + 302 | llama_memory_breakdown_print: | - AMX | 319 = 319 + 0 + 0 | ``` (no more CPU_REPACK) after this commit, disabling amx: ``` perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1 
perplexity: 2.34 seconds per pass - ETA 0.07 minutes [1]17.2005,[2]21.8220, Final estimate: PPL = 21.8220 +/- 1.56485 llama_perf_context_print: load time = 841.91 ms llama_perf_context_print: prompt eval time = 2057.28 ms / 4096 tokens ( 0.50 ms per token, 1990.98 tokens per second) llama_perf_context_print: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) llama_perf_context_print: total time = 6454.51 ms / 4097 tokens llama_perf_context_print: graphs reused = 0 llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | llama_memory_breakdown_print: | - Host | 845 = 318 + 224 + 302 | llama_memory_breakdown_print: | - CPU_REPACK | 319 = 319 + 0 + 0 | ``` => same perplexity. Signed-off-by: Adrien Gallouët <angt@huggingface.co>
1 parent 723c710 commit 4e76d24

File tree

2 files changed

+124
-101
lines changed

2 files changed

+124
-101
lines changed

ggml/src/ggml-cpu/amx/amx.cpp

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
141141
namespace ggml::cpu::amx {
142142
class extra_buffer_type : ggml::cpu::extra_buffer_type {
143143
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
144-
// handle only 2d gemm for now
145-
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
146-
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
147-
};
148-
149-
if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
150-
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
151-
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
152-
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
153-
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
154-
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
155-
// src1 must be host buffer
156-
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
144+
if (op->op != GGML_OP_MUL_MAT) {
145+
return false;
146+
}
147+
auto * src0 = op->src[0];
148+
auto * src1 = op->src[1];
149+
150+
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
151+
return false;
152+
}
153+
if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
154+
return false;
155+
}
156+
if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
157+
return false;
158+
}
159+
if (op->ne[0] % (TILE_N * 2)) {
160+
return false;
161+
}
162+
int alignment;
163+
switch (src0->type) {
164+
case GGML_TYPE_Q4_0:
165+
case GGML_TYPE_Q4_1:
166+
case GGML_TYPE_Q8_0:
167+
alignment = TILE_K;
168+
break;
169+
case GGML_TYPE_Q4_K:
170+
case GGML_TYPE_Q5_K:
171+
case GGML_TYPE_Q6_K:
172+
case GGML_TYPE_IQ4_XS:
173+
alignment = 256; // QK_K
174+
break;
175+
case GGML_TYPE_F16:
176+
alignment = 16;
177+
break;
178+
default:
157179
return false;
158-
}
159-
// src1 must be float32
160-
if (op->src[1]->type == GGML_TYPE_F32) {
161-
return true;
162-
}
163180
}
164-
return false;
181+
if (src0->ne[0] % alignment) {
182+
return false;
183+
}
184+
if (src1->type != GGML_TYPE_F32) {
185+
return false;
186+
}
187+
return true;
165188
}
166189

167190
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {

ggml/src/ggml-cpu/amx/mmq.cpp

Lines changed: 82 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#if defined(__GNUC__)
32
#pragma GCC diagnostic ignored "-Wpedantic"
43
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
@@ -202,35 +201,27 @@ struct tile_config_t{
202201
// advanced-matrix-extensions-intrinsics-functions.html
203202
//
204203

205-
#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
206-
void ggml_tile_config_init(void) {
207-
static thread_local bool is_first_time = true;
204+
inline void ggml_tile_config_init(void) {
205+
static thread_local bool done = false;
208206

209-
if (!is_first_time) {
207+
if (done) {
210208
return;
211209
}
212210

213-
static thread_local tile_config_t tc;
214-
tile_config_t current_tc;
215-
_tile_storeconfig(&current_tc);
216-
217-
// load only when config changes
218-
if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
219-
memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
220-
tc.palette_id = 1;
221-
tc.start_row = 0;
222-
TC_CONFIG_TILE(TMM0, 8, 64);
223-
TC_CONFIG_TILE(TMM1, 8, 64);
224-
TC_CONFIG_TILE(TMM2, 16, 32);
225-
TC_CONFIG_TILE(TMM3, 16, 32);
226-
TC_CONFIG_TILE(TMM4, 16, 64);
227-
TC_CONFIG_TILE(TMM5, 16, 64);
228-
TC_CONFIG_TILE(TMM6, 16, 64);
229-
TC_CONFIG_TILE(TMM7, 16, 64);
230-
_tile_loadconfig(&tc);
231-
}
232-
233-
is_first_time = false;
211+
alignas(64) tile_config_t tc = {};
212+
tc.palette_id = 1;
213+
tc.start_row = 0;
214+
tc.rows[0] = 8; tc.colsb[0] = 64;
215+
tc.rows[1] = 8; tc.colsb[1] = 64;
216+
tc.rows[2] = 16; tc.colsb[2] = 32;
217+
tc.rows[3] = 16; tc.colsb[3] = 32;
218+
tc.rows[4] = 16; tc.colsb[4] = 64;
219+
tc.rows[5] = 16; tc.colsb[5] = 64;
220+
tc.rows[6] = 16; tc.colsb[6] = 64;
221+
tc.rows[7] = 16; tc.colsb[7] = 64;
222+
223+
_tile_loadconfig(&tc);
224+
done = true;
234225
}
235226

236227
// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
@@ -268,33 +259,6 @@ int get_row_size(int K) {
268259
return row_size;
269260
}
270261

271-
// vectorized dtype conversion
272-
inline float FP16_TO_FP32(ggml_half val) {
273-
__m256i v = _mm256_setr_epi16(
274-
val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
275-
__m512 o = _mm512_cvtph_ps(v);
276-
return _mm512_cvtss_f32(o);
277-
}
278-
279-
inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
280-
__m256i v = _mm256_set1_epi16(val);
281-
return _mm512_cvtph_ps(v);
282-
}
283-
284-
// horizontal reduce
285-
inline float _mm512_reduce_max_ps(const __m512 x) {
286-
__m512 v = x;
287-
__m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
288-
v = _mm512_max_ps(v, v1);
289-
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
290-
v = _mm512_max_ps(v, v1);
291-
v1 = _mm512_shuffle_ps(v, v, 0x4E);
292-
v = _mm512_max_ps(v, v1);
293-
v1 = _mm512_shuffle_ps(v, v, 0xB1);
294-
v = _mm512_max_ps(v, v1);
295-
return _mm512_cvtss_f32(v);
296-
}
297-
298262
// transpose utils
299263
#define SHUFFLE_EPI32(a, b, mask) \
300264
_mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
@@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
13701334

13711335
#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
13721336
tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
1373-
K, (const float *)src1->data + mb_start * K, \
1374-
(const type *)src0->data + nb_start * K, \
1375-
(float *)dst->data + mb_start * ldc + nb_start, ldc);
1337+
K, (const float *)src1->data + src1_offset + mb_start * K, \
1338+
(const type *)src0->data + src0_offset + nb_start * K, \
1339+
(float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
13761340

13771341

13781342
// re-organize in the format {NB, KB, TILE_SIZE}:
@@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
20191983
}
20201984
};
20211985

2022-
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
2023-
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
2024-
KB, (const char *)wdata + 0 * row_size_A, \
2025-
(const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
2026-
(float *) dst->data + 0 * N + nb_start, ldc)
1986+
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
1987+
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
1988+
KB, wdata_batch, \
1989+
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
1990+
(float *) dst->data + dst_offset + nb_start, ldc)
20271991

20281992
template <typename TA, typename TB, typename TC, int BLOCK_K,
20291993
typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
@@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
20792043
_tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
20802044

20812045
if (need_unpack) {
2082-
unpack_B<TB>(Tile1, B_blk0);
2046+
unpack_B<TB>(Tile1, B_blk1);
20832047
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
20842048
} else {
20852049
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
@@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
23362300
});
23372301
}
23382302

2303+
// Computes the offset of batch `batch_idx` inside tensor `t`.
// The flat batch index is split into i2 = batch_idx % ne2 and i3 = batch_idx / ne2,
// and the result is i3 * t->nb[3] + i2 * t->nb[2].
// The split uses the caller-supplied ne2, not t->ne[2].
// ne2 is passed explicitly to help the compiler optimize repeated calls.
// NOTE(review): the result is expressed in the units of t->nb[] (presumably bytes) —
// confirm each call site adds it to a byte pointer rather than to a typed pointer.
2304+
inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
2305+
const int64_t i2 = batch_idx % ne2;
2306+
const int64_t i3 = batch_idx / ne2;
2307+
return i3 * t->nb[3] + i2 * t->nb[2];
2308+
}
2309+
23392310
size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
23402311
struct ggml_tensor * src0 = dst->src[0];
23412312

@@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
23482319

23492320
const int M = dst->ne[1];
23502321
const int K = src0->ne[0];
2322+
const int64_t n_batch = dst->ne[2] * dst->ne[3];
23512323

23522324
size_t desired_wsize = 0;
23532325

23542326
GGML_DISPATCH_QTYPES(TYPE, [&] {
23552327
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
2356-
desired_wsize = M * row_size_A;
2328+
desired_wsize = n_batch * M * row_size_A;
23572329
});
23582330

23592331
return desired_wsize;
@@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
23652337
// src1: input in shape of {M, K}, float32
23662338
// dst: output in shape of {M, N}, float32
23672339
//
2368-
// the function performs: dst = src1 @ src0.T
2340+
// the function performs: dst = src1 @ src0.T for each batch
23692341
//
23702342
void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
23712343
struct ggml_tensor * src0 = dst->src[0];
@@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
23822354
const int K = src0->ne[0];
23832355
const int ldc = dst->nb[1] / dst->nb[0];
23842356

2357+
const int64_t ne2 = dst->ne[2];
2358+
const int64_t n_batch = ne2 * dst->ne[3];
2359+
23852360
if (is_floating_type) {
23862361
constexpr int BLOCK_M = 4;
23872362
constexpr int BLOCK_N = 6;
23882363
const int MB = div_up(M, BLOCK_M);
23892364
const int NB = div_up(N, BLOCK_N);
23902365

2391-
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
2366+
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
23922367
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
23932368
for (int i = begin; i < end; ++i) {
2394-
int mb = i / NB;
2395-
int nb = i % NB;
2369+
int batch_idx = i / (MB * NB);
2370+
int remaining = i % (MB * NB);
2371+
int mb = remaining / NB;
2372+
int nb = remaining % NB;
2373+
2374+
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
2375+
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
2376+
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
23962377

23972378
int mb_start = mb * BLOCK_M;
23982379
int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
24242405
void * wdata = params->wdata;
24252406

24262407
//TODO: performance improvement: merge quant A
2427-
if (params->ith == 0) {
2408+
// if (params->ith == 0) {
24282409
GGML_DISPATCH_QTYPES(TYPE, [&] {
24292410
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
2430-
const size_t desired_wsize = M * row_size_A;
2411+
const size_t desired_wsize = n_batch * M * row_size_A;
24312412
if (params->wsize < desired_wsize) {
24322413
GGML_ABORT("insufficient work space size");
24332414
}
@@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
24362417
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
24372418
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
24382419

2439-
const float * A_data = static_cast<const float *>(src1->data);
2440-
for (int m = 0; m < M; ++m) {
2441-
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
2442-
}
2420+
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
2421+
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
2422+
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
2423+
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
2424+
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
2425+
2426+
for (int m = 0; m < M; ++m) {
2427+
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
2428+
}
2429+
}
2430+
});
24432431
});
2444-
}
2432+
// }
24452433

24462434
ggml_barrier(params->threadpool);
24472435

@@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
24512439
constexpr int BLOCK_N = TILE_N * kTilesN;
24522440
const int NB = div_up(N, BLOCK_N);
24532441

2454-
parallel_for_ggml(params, NB, [&](int begin, int end) {
2442+
parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
24552443
GGML_DISPATCH_QTYPES(TYPE, [&] {
24562444
const int KB = K / blck_size;
24572445
const int TILE_SIZE = get_tile_size<type>();
24582446
const int row_size_A = KB * sizeof(vec_dot_type);
24592447
for (int i = begin; i < end; ++i) {
2460-
int nb = i;
2448+
int batch_idx = i / NB;
2449+
int nb = i % NB;
2450+
2451+
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
2452+
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
2453+
const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
2454+
24612455
int nb_start = nb * BLOCK_N;
24622456
int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
24632457

@@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
24812475
const int MB = div_up(M, BLOCK_M);
24822476
const int NB = div_up(N, BLOCK_N);
24832477

2484-
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
2478+
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
24852479
// init tile config for each thread
24862480
ggml_tile_config_init();
24872481

@@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
24912485
const int row_size_A = KB * sizeof(vec_dot_type);
24922486

24932487
for (int i = begin; i < end; ++i) {
2494-
int mb = i / NB;
2495-
int nb = i % NB;
2488+
int batch_idx = i / (MB * NB);
2489+
int remaining = i % (MB * NB);
2490+
int mb = remaining / NB;
2491+
int nb = remaining % NB;
2492+
2493+
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
2494+
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
2495+
const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
24962496

24972497
int mb_start = mb * BLOCK_M;
24982498
int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
25012501

25022502
tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
25032503
mb_size, nb_size, KB,
2504-
(const char *)wdata + mb_start * row_size_A,
2505-
(const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
2506-
(float *) dst->data + mb_start * N + nb_start, ldc);
2504+
wdata_batch + mb_start * row_size_A,
2505+
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
2506+
(float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
25072507
}
25082508
});
25092509
});

0 commit comments

Comments
 (0)