Skip to content
This repository was archived by the owner on Sep 23, 2024. It is now read-only.

Commit 4f3ba52

Browse files
committed
Optimize MHA for multi-query / grouped-query attention: index K/V by head group (i1 / h_each_group_len) instead of per-head, relax the K/V head-count check, and replace the per-object buffC/weiBuff members with stack-local buffers in the int8-weight Matmul path.
1 parent 581a3cf commit 4f3ba52

2 files changed

Lines changed: 17 additions & 17 deletions

File tree

src/mha_gpt_amx.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
116116
auto head_size = q.m_dims[3];
117117
auto key_seq_len = k.m_dims[2];
118118
bool is_bloom = k.m_strides[3] > k.m_strides[2];
119+
auto h_group_num = k.m_dims[1];
120+
size_t h_each_group_len = head_num / h_group_num;
119121

120122
uint8_t* out = output.data<uint8_t>();
121123

@@ -130,8 +132,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
130132
if (use_gemv) {
131133
parallel_for2d(batch, head_num, [&](size_t thread_id, size_t i0, size_t i1) {
132134
auto q_sub = &q.at<uint8_t>({i0, i1});
133-
auto k_sub = &k.at<uint8_t>({i0, i1});
134-
auto v_sub = &v.at<uint8_t>({i0, i1});
135+
auto k_sub = &k.at<uint8_t>({i0, i1 / h_each_group_len});
136+
auto v_sub = &v.at<uint8_t>({i0, i1 / h_each_group_len});
135137

136138
auto mat0_out = reinterpret_cast<uint8_t*>(_buffer_mat0_out + thread_id * _buffer_mat0_out_size);
137139
auto mat1_out = reinterpret_cast<uint8_t*>(_buffer_mat1_out + thread_id * _buffer_mat1_out_size);
@@ -177,8 +179,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
177179
// k: [batch, head_num, key_seq_len, head_size]
178180
// v: [batch, head_num, value_seq_len, head_size]
179181
auto q_sub = &q.at<ov::bfloat16>({i0, i1, seq_start});
180-
auto k_sub = &k.at<ov::bfloat16>({i0, i1});
181-
auto v_sub = &v.at<ov::bfloat16>({i0, i1});
182+
auto k_sub = &k.at<ov::bfloat16>({i0, i1 / h_each_group_len});
183+
auto v_sub = &v.at<ov::bfloat16>({i0, i1 / h_each_group_len});
182184

183185
auto mat0_out = reinterpret_cast<float*>(_buffer_mat0_out + thread_id * _buffer_mat0_out_size);
184186
auto mat1_out = reinterpret_cast<float*>(_buffer_mat1_out + thread_id * _buffer_mat1_out_size);
@@ -279,7 +281,7 @@ status_t mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor&
279281
auto key_seq_len = k.m_dims[2];
280282

281283
if (!(batch == k.m_dims[0] && batch == v.m_dims[0] &&
282-
head_num == k.m_dims[1] && head_num == v.m_dims[1] &&
284+
k.m_dims[1] == v.m_dims[1] &&
283285
key_seq_len == v.m_dims[2] &&
284286
head_size == k.m_dims[3] && head_size == v.m_dims[3])) {
285287
DEBUG_LOG << "dim of q,k,v is error.\n";

src/mm_kernel_common_amx.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,21 +2476,13 @@ template<>
24762476
struct Matmul<ov::bfloat16, uint8_t, float> {
24772477
tensor2D<uint8_t> internalBI8;
24782478

2479-
// wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly.
2480-
tensor2D<ov::bfloat16> weiBuff;
2481-
24822479
bool constB;
24832480
bool transposeB;
24842481

24852482
constexpr static int kStep = 32;
24862483

2487-
// 2x2 C tiles buffer
2488-
// most usecase requires post-processing with AVX, thus buffC
2489-
// is used to transfer data to AVX register
2490-
tensor2D<float> buffC;
2491-
24922484
Matmul(bool constB = false, bool transposeB = false) :
2493-
constB(constB), transposeB(transposeB), buffC(32, 32) {}
2485+
constB(constB), transposeB(transposeB) {}
24942486

24952487
float* dequant_scale_B;
24962488
float* zp;
@@ -2500,6 +2492,14 @@ struct Matmul<ov::bfloat16, uint8_t, float> {
25002492
tensor2D<ov::bfloat16> & _matB,
25012493
int n0, int n1,
25022494
PP ppkernel) {
2495+
alignas(64) float buff[32 * 32];
2496+
// wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly.
2497+
alignas(64) ov::bfloat16 weiBuff[32 * 2 * 32];
2498+
// 2x2 C tiles buffer
2499+
// most usecase requires post-processing with AVX, thus buffC
2500+
// is used to transfer data to AVX register
2501+
tensor2D<float> buffC(32, 32, buff, 32 * sizeof(float));
2502+
25032503
auto matB = getSubMatB(_matB, n0, n1, transposeB);
25042504
int M = matA.dims[0];
25052505
int K = matA.dims[1];
@@ -2523,9 +2523,7 @@ struct Matmul<ov::bfloat16, uint8_t, float> {
25232523
//constexpr int prefetch_ahead = 64*1024;
25242524
tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64);
25252525
auto * pBint = reinterpret_cast<int8_t*>(&internalBI8[0]);
2526-
auto & B2buff = weiBuff;
2527-
B2buff.resize(32*2, 32);
2528-
auto * const pB = &B2buff[0];
2526+
auto * const pB = weiBuff;
25292527
auto * pBsrc = pB + (32*32) * 0;
25302528
auto * pBdst = pB + (32*32) * 1;
25312529
auto * const pC0 = &buffC[0];

0 commit comments

Comments
 (0)