Skip to content

Commit 0ae7d53

Browse files
committed
[SYCL] Add Q8_0 reorder optimization for Intel GPUs (~3x tg speedup)
Extend the existing SYCL reorder optimization to support Q8_0 quantization. The reorder separates scale factors from weight data for coalesced memory access — previously only Q4_0, Q4_K, and Q6_K were supported. On Intel Arc Pro B70 (Xe2/Battlemage), Q8_0 token generation improves from 4.88 t/s to 15.24 t/s (3.1x) on Qwen3.5-27B. Memory bandwidth utilization rises from 21% to 66% of theoretical maximum. Q8_0 is now faster than Q6_K (15.24 vs 13.83 t/s) with higher quality. Changes: - quants.hpp: Add block_q_t<GGML_TYPE_Q8_0> reorder traits - dequantize.hpp: Add dequantize_q8_0_reorder() for separated layout - dmmv.cpp: Add Q8_0 DMMV reorder kernel and dispatch - vecdotq.hpp: Add reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> - mmvq.cpp: Add Q8_0 MMVQ reorder kernel and dispatch - ggml-sycl.cpp: Add reorder_qw_q8_0(), update dispatch and extra allocation gate in ggml_backend_sycl_buffer_init_tensor() Fixes: #21517
1 parent 25eec6f commit 0ae7d53

6 files changed

Lines changed: 247 additions & 3 deletions

File tree

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,22 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
143143
#endif // GGML_SYCL_F16
144144
}
145145

146+
static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
147+
const int iqs, dfloat2 &v) {
148+
const dfloat d = (const dfloat)*((const sycl::half*)d_ptr + ib);
149+
150+
v.x() = ((const int8_t *)qs)[iqs + 0];
151+
v.y() = ((const int8_t *)qs)[iqs + 1];
152+
153+
#ifdef GGML_SYCL_F16
154+
v.s0() *= d;
155+
v.s1() *= d;
156+
#else
157+
v.x() *= d;
158+
v.y() *= d;
159+
#endif // GGML_SYCL_F16
160+
}
161+
146162
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
147163
const int iqs, dfloat2 &v) {
148164
const block_q8_0 * x = (const block_q8_0 *) vx;

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,103 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
972972
}
973973
}
974974

975+
static void dequantize_mul_mat_vec_q8_0_sycl_reorder(const void *vx, const dfloat *y,
976+
float *dst, const int ncols,
977+
const int nrows,
978+
dpct::queue_ptr stream) {
979+
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
980+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
981+
const sycl::range<3> block_nums(1, 1, block_num_y);
982+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
983+
{
984+
dpct::has_capability_or_fail(stream->get_device(),
985+
{sycl::aspect::fp16});
986+
987+
stream->parallel_for(
988+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
989+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
990+
// Q8_0 reorder layout: [all qs (ncols*nrows bytes)][all d values]
991+
// Cannot reuse dequantize_mul_mat_vec_reorder template because it has
992+
// Q4_0-specific constants hardcoded (d_ptr offset and qs stride).
993+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
994+
item_ct1.get_local_id(1);
995+
if (row >= nrows) return;
996+
997+
const int tid = item_ct1.get_local_id(2);
998+
const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
999+
const int vals_per_iter = iter_stride / WARP_SIZE;
1000+
const int ncols_left = ncols % (QK8_0*WARP_SIZE);
1001+
const int ncols_align = ncols - ncols_left;
1002+
1003+
#ifdef GGML_SYCL_F16
1004+
sycl::half2 tmp = {0.0f, 0.0f};
1005+
#else
1006+
float tmp = 0.0f;
1007+
#endif
1008+
const char *d_ptr = (const char*)vx + ncols*nrows; // d after all qs
1009+
1010+
int i = 0;
1011+
for (i = 0; i < ncols_align; i += iter_stride) {
1012+
const int col = i + vals_per_iter*tid;
1013+
const int ib = (row*ncols + col)/QK8_0;
1014+
const int iqs = col % QK8_0;
1015+
1016+
#pragma unroll
1017+
for (int j = 0; j < vals_per_iter; j += 2) {
1018+
dfloat2 v;
1019+
dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx,
1020+
ib * QK8_0 + iqs + j, v);
1021+
1022+
#ifdef GGML_SYCL_F16
1023+
dfloat2 t1{y[col + j + 0], y[col + j + 1]};
1024+
tmp += v * t1;
1025+
#else
1026+
tmp += v.x() * y[col + j + 0];
1027+
tmp += v.y() * y[col + j + 1];
1028+
#endif
1029+
}
1030+
}
1031+
1032+
// handle remaining columns
1033+
for (; i < ncols; i += iter_stride) {
1034+
if (tid >= ncols_left/QK8_0) continue;
1035+
const int col = i + vals_per_iter*tid;
1036+
const int ib = (row*ncols + col)/QK8_0;
1037+
const int iqs = col % QK8_0;
1038+
1039+
#pragma unroll
1040+
for (int j = 0; j < vals_per_iter; j += 2) {
1041+
dfloat2 v;
1042+
dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx,
1043+
ib * QK8_0 + iqs + j, v);
1044+
1045+
#ifdef GGML_SYCL_F16
1046+
dfloat2 t1{y[col + j + 0], y[col + j + 1]};
1047+
tmp += v * t1;
1048+
#else
1049+
tmp += v.x() * y[col + j + 0];
1050+
tmp += v.y() * y[col + j + 1];
1051+
#endif
1052+
}
1053+
}
1054+
1055+
// reduce
1056+
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
1057+
for (int mask = mask_start; mask > 0; mask >>= 1) {
1058+
tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
1059+
}
1060+
1061+
if (tid == 0) {
1062+
#ifdef GGML_SYCL_F16
1063+
dst[row] = tmp.x() + tmp.y();
1064+
#else
1065+
dst[row] = tmp;
1066+
#endif
1067+
}
1068+
});
1069+
}
1070+
}
1071+
9751072
static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
9761073
float *dst, const int ncols,
9771074
const int nrows,
@@ -1122,7 +1219,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
11221219
dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
11231220
break;
11241221
case GGML_TYPE_Q8_0:
1125-
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1222+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1223+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1224+
dequantize_mul_mat_vec_q8_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1225+
} else {
1226+
dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1227+
}
11261228
break;
11271229
case GGML_TYPE_Q2_K:
11281230
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
411411
assert(tensor->view_src->buffer->buft == buffer->buft);
412412
return GGML_STATUS_SUCCESS;
413413
}
414-
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
414+
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
415415
!g_ggml_sycl_disable_optimize) {
416416
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
417417
tensor->extra = extra;
@@ -3254,6 +3254,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
32543254
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
32553255
switch (type) {
32563256
case GGML_TYPE_Q4_0:
3257+
case GGML_TYPE_Q8_0:
32573258
return true;
32583259
case GGML_TYPE_Q4_K:
32593260
case GGML_TYPE_Q6_K:
@@ -3266,6 +3267,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
32663267
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
32673268
switch (type) {
32683269
case GGML_TYPE_Q4_0:
3270+
case GGML_TYPE_Q8_0:
32693271
return true;
32703272
default:
32713273
return false;
@@ -3275,6 +3277,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
32753277
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
32763278
switch (type) {
32773279
case GGML_TYPE_Q4_0:
3280+
case GGML_TYPE_Q8_0:
32783281
case GGML_TYPE_Q4_K:
32793282
case GGML_TYPE_Q6_K:
32803283
return true;
@@ -3364,6 +3367,40 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
33643367
sycl_ext_free(stream, tmp_buf);
33653368
}
33663369

3370+
static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
3371+
dpct::queue_ptr stream) {
3372+
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
3373+
3374+
sycl::event copy_event;
3375+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3376+
if (!g_ggml_sycl_use_async_mem_op) {
3377+
copy_event.wait();
3378+
}
3379+
3380+
GGML_ASSERT((size % sizeof(block_q8_0) == 0));
3381+
GGML_ASSERT((offset % sizeof(block_q8_0) == 0));
3382+
int offset_blks = offset / sizeof(block_q8_0);
3383+
auto qs_ptr = data_device + offset_blks * QK8_0;
3384+
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks;
3385+
3386+
auto reorder_event = stream->parallel_for(
3387+
size / sizeof(block_q8_0),
3388+
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
3389+
const block_q8_0* x = (const block_q8_0*)tmp_buf;
3390+
const int ib = i;
3391+
3392+
for (int j = 0; j < QK8_0; j++)
3393+
{
3394+
*((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j];
3395+
}
3396+
*(d_ptr + ib) = x[ib].d;
3397+
});
3398+
if (!g_ggml_sycl_use_async_mem_op) {
3399+
reorder_event.wait_and_throw();
3400+
}
3401+
sycl_ext_free(stream, tmp_buf);
3402+
}
3403+
33673404
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
33683405
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
33693406
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
@@ -3460,6 +3497,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
34603497
case GGML_TYPE_Q4_0:
34613498
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
34623499
break;
3500+
case GGML_TYPE_Q8_0:
3501+
reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
3502+
break;
34633503
case GGML_TYPE_Q4_K:
34643504
reorder_qw_q4_k(data_device, size, 0, stream);
34653505
break;

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,25 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
679679
}
680680
}
681681

682+
static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
683+
const int nrows, dpct::queue_ptr stream) {
684+
GGML_ASSERT(ncols % QK8_0 == 0);
685+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
686+
constexpr size_t num_subgroups = 16;
687+
GGML_ASSERT(block_num_y % num_subgroups == 0);
688+
689+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
690+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
691+
692+
stream->submit([&](sycl::handler & cgh) {
693+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
694+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
695+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows,
696+
nd_item);
697+
});
698+
});
699+
}
700+
682701
static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
683702
float *dst, const int ncols,
684703
const int nrows,
@@ -1101,7 +1120,13 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
11011120
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
11021121
break;
11031122
case GGML_TYPE_Q8_0:
1104-
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1123+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1124+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1125+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
1126+
reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1127+
} else {
1128+
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1129+
}
11051130
break;
11061131
case GGML_TYPE_Q2_K:
11071132
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);

ggml/src/ggml-sycl/quants.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,27 @@ template <> struct block_q_t<GGML_TYPE_Q6_K> {
105105
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
106106
};
107107

108+
template <> struct block_q_t<GGML_TYPE_Q8_0> {
109+
struct traits {
110+
static constexpr uint32_t qk = QK8_0; // 32
111+
static constexpr uint32_t qi = QI8_0; // 8
112+
static constexpr uint32_t qr = QR8_0; // 1
113+
static constexpr uint32_t vdr_mmvq = 4;
114+
};
115+
116+
// Q8_0 reorder layout: [qs0|qs1|...|qsN][d0|d1|...|dN]
117+
// Each block has 32 int8 weights (32 bytes) followed by all scales
118+
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
119+
return { block_index * QK8_0, 0 };
120+
}
121+
122+
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
123+
return { (ncols * nrows) + block_index * sizeof(ggml_half), 0 };
124+
}
125+
126+
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } // 1
127+
};
128+
108129
} // namespace ggml_sycl_reordered
109130

110131
#endif // GGML_SYCL_QUANTS_HPP

ggml/src/ggml-sycl/vecdotq.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,46 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
351351
};
352352
};
353353

354+
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> {
355+
static constexpr ggml_type gtype = GGML_TYPE_Q8_0;
356+
357+
using q8_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q8_0>;
358+
using q8_0_traits = typename q8_0_block::traits;
359+
360+
__dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) {
361+
int sumi = 0;
362+
363+
#pragma unroll
364+
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
365+
// Q8_0 values are signed int8, no nibble extraction needed
366+
// Direct dp4a: each int packs 4 int8 values
367+
sumi = dpct::dp4a(v[i], u[i], sumi);
368+
}
369+
370+
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
371+
372+
// Q8_0 has no bias term (values are signed), so just scale
373+
return d8_0 * sumi * ds8f.x();
374+
}
375+
376+
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
377+
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
378+
const sycl::half2 * q8_1_ds, const int & iqs) {
379+
const int8_t * bq8_0 = static_cast<const int8_t *>(vbq) + ibx_offset.first;
380+
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
381+
int v[q8_0_traits::vdr_mmvq];
382+
int u[q8_0_traits::vdr_mmvq];
383+
384+
#pragma unroll
385+
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
386+
v[i] = get_int_from_int8(bq8_0, iqs + i);
387+
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
388+
}
389+
390+
return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds);
391+
};
392+
};
393+
354394
static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
355395
const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
356396
const int & iqs) {

0 commit comments

Comments
 (0)