Skip to content

Commit f4cc2ff

Browse files
committed
WIP: mmvq local mem
1 parent 9012eb9 commit f4cc2ff

2 files changed

Lines changed: 132 additions & 3 deletions

File tree

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3163,20 +3163,25 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
31633163
// KQV single-batch
31643164
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
31653165
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
3166+
// std::cout << "batched sycl mulmat\n";
31663167
// KQ + KQV multi-batch
31673168
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
31683169
} else if (use_dequantize_mul_mat_vec) {
3170+
// std::cout << "dmmv\n";
31693171
constexpr bool convert_src1_to_q8_1 = false;
31703172
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
31713173
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
31723174
} else if (use_mul_mat_vec_q) {
3175+
// std::cout << "mmvq\n";
31733176
constexpr bool convert_src1_to_q8_1 = true;
31743177
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
31753178
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
31763179
} else if (use_mul_mat_q) {
3180+
// std::cout << "mul_mat_q\n";
31773181
constexpr bool convert_src1_to_q8_1 = true;
31783182
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
31793183
} else {
3184+
// std::cout << "fallback\n";
31803185
constexpr bool convert_src1_to_q8_1 = false;
31813186
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
31823187
}

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,54 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
5757
}
5858
}
5959

60+
template <typename reorder_vec_dot_q_sycl>
61+
static void mul_mat_vec_q_reorder_local_mem(const void * __restrict__ vx, const void * __restrict__ vy, sycl::local_accessor<block_q8_1, 1> vy_local, float * __restrict__ dst,
62+
const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
63+
using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
64+
using block_traits = typename block_type::traits;
65+
66+
const auto sg = nd_item.get_sub_group();
67+
const int sg_range = sg.get_group_linear_range();
68+
const int workgroup_id = nd_item.get_group_linear_id();
69+
const int sg_id = sg.get_group_linear_id();
70+
const int row = workgroup_id * sg_range + sg_id;
71+
72+
if (row >= nrows) return;
73+
74+
const int blocks_per_row = ncols / block_traits::qk;
75+
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
76+
constexpr int block_elements_per_sg = block_traits::qi / block_traits::vdr_mmvq;
77+
78+
const int total_y_blocks = blocks_per_row * block_type::block_to_q8_1_ratio();
79+
const int nblocks = nrows * blocks_per_row;
80+
81+
const block_q8_1 * y_global = static_cast<const block_q8_1 *>(vy);
82+
for (int i = nd_item.get_local_linear_id(); i < total_y_blocks; i += nd_item.get_local_range().size()) {
83+
vy_local[i] = y_global[i];
84+
}
85+
nd_item.barrier(sycl::access::fence_space::local_space);
86+
87+
float partial_sum = 0.0f;
88+
for (int i = sg.get_local_linear_id() / block_elements_per_sg; i < blocks_per_row; i += blocks_per_subgroup) {
89+
const int ibx = row * blocks_per_row + i;
90+
const int bx_offset = block_type::get_block_offset(ibx);
91+
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
92+
93+
const int iby = i * block_type::block_to_q8_1_ratio();
94+
95+
for (int elem = 0; elem < block_elements_per_sg; elem += WARP_SIZE) {
96+
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_sg);
97+
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &vy_local[iby], iqs, nblocks);
98+
}
99+
}
100+
101+
float sum = sycl::reduce_over_group(sg, partial_sum, std::plus<>());
102+
103+
if (sg.leader()) {
104+
dst[row] = sum;
105+
}
106+
}
107+
60108
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
61109
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
62110
const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
@@ -101,6 +149,58 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
101149
}
102150
}
103151

152+
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
153+
static void mul_mat_vec_q_local_mem(const void * __restrict__ vx, const void * __restrict__ vy, sycl::local_accessor<block_q8_1, 1> y_local, float * __restrict__ dst,
154+
const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) {
155+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
156+
157+
if (row >= nrows) {
158+
return;
159+
}
160+
161+
const int blocks_per_row = ncols / qk;
162+
constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; // Ensuring blocks_per_warp > 0
163+
164+
assert(blocks_per_warp > 0);
165+
166+
// partial sum for each thread
167+
float tmp = 0.0f;
168+
169+
const block_q_t * x = (const block_q_t *) vx;
170+
const block_q8_1 * y = (const block_q8_1 *) vy;
171+
172+
const int blocks_per_row_y = ncols / /* qk_vec */ QK8_1; // TODO:: hardcoded
173+
const int total_y_blocks = blocks_per_row_y;
174+
for (int iby = item_ct1.get_local_id(2); iby < total_y_blocks; iby += item_ct1.get_local_range(2)) {
175+
y_local[iby] = y[iby];
176+
}
177+
178+
item_ct1.barrier(sycl::access::fence_space::local_space);
179+
180+
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
181+
const int ibx = row * blocks_per_row + i; // x block index
182+
183+
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
184+
185+
for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
186+
const int iqs = elem + vdr * (item_ct1.get_local_id(2) %
187+
(qi / vdr)); // x block quant index when casting the quants to int
188+
189+
tmp += vec_dot_q_sycl(&x[ibx], &y_local[iby], iqs);
190+
}
191+
}
192+
193+
// sum up partial sums and write back result
194+
#pragma unroll
195+
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
196+
tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
197+
}
198+
199+
if (item_ct1.get_local_id(2) == 0) {
200+
dst[row] = tmp;
201+
}
202+
}
203+
104204
template <int qk, int qi, typename block_q_t, int vdr>
105205
static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
106206
const void *__restrict__ vy,
@@ -720,41 +820,65 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
720820
float *dst, const int ncols,
721821
const int nrows,
722822
dpct::queue_ptr stream) {
823+
// std::cout << ">>>>>>>>> THIS IS CALLED\n";
723824
GGML_ASSERT(ncols % QK_K == 0);
724825
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
725826
const sycl::range<3> block_nums(1, 1, block_num_y);
726827
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
828+
829+
using block_type = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
830+
const int blocks_per_row = ncols / block_type::traits::qk;
831+
const int total_y_blocks = blocks_per_row * block_type::block_to_q8_1_ratio();
832+
if(total_y_blocks * sizeof(block_q8_1) > stream->get_device().get_info<sycl::info::device::local_mem_size>()) {
833+
// TODO: add fallback
834+
GGML_ABORT("not enough local mem");
835+
}
836+
727837
{
728838

729839
stream->submit([&](sycl::handler &cgh) {
840+
sycl::local_accessor<block_q8_1, 1> vy_local(sycl::range<1>(total_y_blocks), cgh);
730841

731842
cgh.parallel_for(
732843
sycl::nd_range<3>(block_nums * block_dims, block_dims),
733844
[=](sycl::nd_item<3> item_ct1)
734845
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
735-
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
846+
mul_mat_vec_q_local_mem<QK_K, QI4_K, block_q4_K,
736847
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
737-
vx, vy, dst, ncols, nrows, item_ct1);
848+
vx, vy, vy_local, dst, ncols, nrows, item_ct1);
738849
});
739850
});
740851
}
741852
}
742853

743854
static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
744855
const int nrows, dpct::queue_ptr stream) {
856+
// std::cout << ">>>>>>>>> REORDER PATH\n";
857+
745858
GGML_ASSERT(ncols % QK_K == 0);
746859

747860
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
748861
constexpr size_t num_subgroups = 16;
749862
GGML_ASSERT(block_num_y % num_subgroups == 0);
863+
// std::cout << "block_num_y: " << block_num_y << ", num_subgroups: " << num_subgroups << ", nrows: " << nrows << ", ncols:" << ncols << "\n";
750864

751865
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
752866
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
753867

868+
using block_type = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
869+
const int blocks_per_row = ncols / block_type::traits::qk;
870+
const int total_y_blocks = blocks_per_row * block_type::block_to_q8_1_ratio();
871+
if(total_y_blocks * sizeof(block_q8_1) > stream->get_device().get_info<sycl::info::device::local_mem_size>()) {
872+
// TODO: add fallback
873+
GGML_ABORT("not enough local mem");
874+
}
875+
754876
stream->submit([&](sycl::handler & cgh) {
877+
sycl::local_accessor<block_q8_1, 1> vy_local(sycl::range<1>(total_y_blocks), cgh);
878+
755879
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
756880
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
757-
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
881+
mul_mat_vec_q_reorder_local_mem<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, vy_local, dst, ncols,
758882
nrows, nd_item);
759883
});
760884
});

0 commit comments

Comments
 (0)