@@ -116,6 +116,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
116116 auto head_size = q.m_dims [3 ];
117117 auto key_seq_len = k.m_dims [2 ];
118118 bool is_bloom = k.m_strides [3 ] > k.m_strides [2 ];
119+ auto h_group_num = k.m_dims [1 ];
120+ size_t h_each_group_len = head_num / h_group_num;
119121
120122 uint8_t * out = output.data <uint8_t >();
121123
@@ -130,8 +132,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
130132 if (use_gemv) {
131133 parallel_for2d (batch, head_num, [&](size_t thread_id, size_t i0, size_t i1) {
132134 auto q_sub = &q.at <uint8_t >({i0, i1});
133- auto k_sub = &k.at <uint8_t >({i0, i1});
134- auto v_sub = &v.at <uint8_t >({i0, i1});
135+ auto k_sub = &k.at <uint8_t >({i0, i1 / h_each_group_len });
136+ auto v_sub = &v.at <uint8_t >({i0, i1 / h_each_group_len });
135137
136138 auto mat0_out = reinterpret_cast <uint8_t *>(_buffer_mat0_out + thread_id * _buffer_mat0_out_size);
137139 auto mat1_out = reinterpret_cast <uint8_t *>(_buffer_mat1_out + thread_id * _buffer_mat1_out_size);
@@ -177,8 +179,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor&
177179 // k: [batch, head_num, key_seq_len, head_size]
178180 // v: [batch, head_num, value_seq_len, head_size]
179181 auto q_sub = &q.at <ov::bfloat16>({i0, i1, seq_start});
180- auto k_sub = &k.at <ov::bfloat16>({i0, i1});
181- auto v_sub = &v.at <ov::bfloat16>({i0, i1});
182+ auto k_sub = &k.at <ov::bfloat16>({i0, i1 / h_each_group_len });
183+ auto v_sub = &v.at <ov::bfloat16>({i0, i1 / h_each_group_len });
182184
183185 auto mat0_out = reinterpret_cast <float *>(_buffer_mat0_out + thread_id * _buffer_mat0_out_size);
184186 auto mat1_out = reinterpret_cast <float *>(_buffer_mat1_out + thread_id * _buffer_mat1_out_size);
@@ -279,7 +281,7 @@ status_t mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor&
279281 auto key_seq_len = k.m_dims [2 ];
280282
281283 if (!(batch == k.m_dims [0 ] && batch == v.m_dims [0 ] &&
282- head_num == k.m_dims [1 ] && head_num == v.m_dims [1 ] &&
284+ k.m_dims [1 ] == v.m_dims [1 ] &&
283285 key_seq_len == v.m_dims [2 ] &&
284286 head_size == k.m_dims [3 ] && head_size == v.m_dims [3 ])) {
285287 DEBUG_LOG << " dim of q,k,v is error.\n " ;
0 commit comments