Skip to content
This repository was archived by the owner on Sep 23, 2024. It is now read-only.

Commit 3c7ac52

Browse files
committed
Optimize for multi-query attention
1 parent 581a3cf commit 3c7ac52

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

src/mha_gpt_amx.cpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
116116
auto head_size = q.m_dims[3];
117117
auto key_seq_len = k.m_dims[2];
118118
bool is_bloom = k.m_strides[3] > k.m_strides[2];
119+
auto h_group_num = k.m_dims[1];
120+
size_t h_each_group_len = head_num / h_group_num;
119121

120122
uint8_t* out = output.data<uint8_t>();
121123

@@ -130,8 +132,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
130132
if (use_gemv) {
131133
parallel_for2d(batch, head_num, [&](size_t thread_id, size_t i0, size_t i1) {
132134
auto q_sub = &q.at<uint8_t>({i0, i1});
133-
auto k_sub = &k.at<uint8_t>({i0, i1});
134-
auto v_sub = &v.at<uint8_t>({i0, i1});
135+
auto k_sub = &k.at<uint8_t>({i0, i1 / h_each_group_len});
136+
auto v_sub = &v.at<uint8_t>({i0, i1 / h_each_group_len});
135137

136138
auto mat0_out = reinterpret_cast<uint8_t*>(_buffer_mat0_out + thread_id * _buffer_mat0_out_size);
137139
auto mat1_out = reinterpret_cast<uint8_t*>(_buffer_mat1_out + thread_id * _buffer_mat1_out_size);
@@ -177,8 +179,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
177179
// k: [batch, h_group_num, key_seq_len, head_size]; each group of h_each_group_len query heads shares one k head
178180
// v: [batch, h_group_num, value_seq_len, head_size]; indexed with the same group index as k
179181
auto q_sub = &q.at<ov::bfloat16>({i0, i1, seq_start});
180-
auto k_sub = &k.at<ov::bfloat16>({i0, i1});
181-
auto v_sub = &v.at<ov::bfloat16>({i0, i1});
182+
auto k_sub = &k.at<ov::bfloat16>({i0, i1 / h_each_group_len});
183+
auto v_sub = &v.at<ov::bfloat16>({i0, i1 / h_each_group_len});
182184

183185
auto mat0_out = reinterpret_cast<float*>(_buffer_mat0_out + thread_id * _buffer_mat0_out_size);
184186
auto mat1_out = reinterpret_cast<float*>(_buffer_mat1_out + thread_id * _buffer_mat1_out_size);
@@ -279,7 +281,7 @@ status_t mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor&
279281
auto key_seq_len = k.m_dims[2];
280282

281283
if (!(batch == k.m_dims[0] && batch == v.m_dims[0] &&
282-
head_num == k.m_dims[1] && head_num == v.m_dims[1] &&
284+
k.m_dims[1] == v.m_dims[1] &&
283285
key_seq_len == v.m_dims[2] &&
284286
head_size == k.m_dims[3] && head_size == v.m_dims[3])) {
285287
DEBUG_LOG << "dim of q,k,v is error.\n";

0 commit comments

Comments
 (0)