[SYCL] Q8_0 quantization ~4x slower than Q4_K_M on Intel Arc Pro B70 (Xe2/Battlemage) — kernel efficiency issue #21517

@PMZFX

Description


Summary

Q8_0 quantized dense models achieve only 21-24% of theoretical memory bandwidth on Intel Arc Pro B70 (Xe2/Battlemage) GPUs, while Q4_K_M achieves 53-64%. This leaves Q8_0 token generation 4-5x slower than Q4_K_M, despite Q8_0 reading only ~1.7x more data.

The issue is not VRAM pressure, PCIe bandwidth, or backend-specific — it affects both SYCL and Vulkan backends equally, and persists when splitting across two GPUs with abundant free VRAM.

Hardware

  • 2x Intel Arc Pro B70 (BMG-G31, 32 GB GDDR6 ECC each, 608 GB/s bandwidth per card)
  • PCIe Gen 4 x8 per card
  • AMD Ryzen 5 9600X, 60 GB DDR5
  • Driver: libze-intel-gpu1 26.09.37435.1, IGC 2.30.1 (also tested with 26.05.37020.3 / IGC 2.28.4 — same results)
  • llama.cpp commit 25eec6f, built with Intel oneAPI DPC++ 2025.3.3 (-DGGML_SYCL=ON -DGGML_SYCL_F16=ON)
  • Also tested: Vulkan via llama.cpp b8064 (Ubuntu apt package)

Benchmark Data

Full quantization sweep — Qwen3.5-27B (26.9B params, dense), single GPU

| Quant | Size (GiB) | pp512 t/s | tg128 t/s | Effective BW | % of 608 GB/s | Tier |
|---|---|---|---|---|---|---|
| Q4_0 | 14.63 | 243 | 23.67 | 346 GB/s | 57% | Fast |
| Q4_K_S | 14.68 | 309 | 23.05 | 339 GB/s | 56% | Fast |
| Q4_K_M | 15.58 | 302 | 20.56 | 321 GB/s | 53% | Fast |
| IQ4_XS | 13.94 | 267 | 17.52 | 244 GB/s | 40% | Medium |
| Q4_1 | 15.99 | 259 | 16.78 | 268 GB/s | 44% | Medium |
| Q6_K | 20.90 | 304 | 13.83 | 289 GB/s | 48% | Medium |
| Q5_K_M | 18.25 | 300 | 13.78 | 252 GB/s | 41% | Medium |
| Q5_K_S | 17.58 | 307 | 13.50 | 237 GB/s | 39% | Medium |
| IQ4_NL | 14.60 | 238 | 5.85 | 85 GB/s | 14% | Broken |
| Q8_0 | 26.62 | 295 | 4.88 | 130 GB/s | 21% | Broken |

Critical: IQ4_NL (14.6 GiB) and Q4_0 (14.6 GiB) are the same size but IQ4_NL is 4x slower. This proves the bottleneck is kernel efficiency, not data volume.
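The "Effective BW" column can be sanity-checked with a back-of-envelope calculation: during token generation each token must stream the full weight set from VRAM once, so effective bandwidth is roughly tg rate × model size. A minimal sketch (helper names are ours; it mirrors the table, which quotes GiB-based figures as GB/s):

```python
# Back-of-envelope check of the "Effective BW" column: token generation
# streams the full weight set from VRAM once per token, so effective
# bandwidth ~= tg rate * model size.

def effective_bw(tg_tokens_per_s: float, model_size_gib: float) -> float:
    """Approximate bandwidth in GB/s (treating GiB as GB, as the table does)."""
    return tg_tokens_per_s * model_size_gib

def bw_utilization(tg_tokens_per_s: float, model_size_gib: float,
                   peak_gbs: float = 608.0) -> float:
    """Fraction of the B70's 608 GB/s theoretical peak actually achieved."""
    return effective_bw(tg_tokens_per_s, model_size_gib) / peak_gbs

print(round(effective_bw(20.56, 15.58)))     # Q4_K_M -> 320 (GB/s, ~table's 321)
print(round(effective_bw(4.88, 26.62)))      # Q8_0   -> 130 (GB/s)
print(f"{bw_utilization(4.88, 26.62):.0%}")  # Q8_0   -> 21%
```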

Small Q8_0 model — rules out VRAM pressure

| Model | Quant | Size | VRAM Free | tg128 t/s | BW Util |
|---|---|---|---|---|---|
| Qwen 9B | Q8_0 | 8.86 GiB | ~22 GiB | 16.5 | 24% |
| Qwen 27B | Q8_0 | 26.62 GiB | ~4 GiB | 4.88 | 21% |

Both achieve only ~21-24% of peak bandwidth, with tg scaling inversely with model size — confirming a kernel-level inefficiency rather than VRAM pressure.

Dual GPU doesn't help Q8_0

| Model | Quant | GPUs | VRAM Free/card | tg128 t/s |
|---|---|---|---|---|
| Qwen 27B | Q8_0 | 1 | ~4 GiB | 4.88 |
| Qwen 27B | Q8_0 | 2 (split) | ~18 GiB each | 4.96 |

Both SYCL and Vulkan affected

| Backend | Q4_K_M tg128 | Q8_0 tg128 |
|---|---|---|
| SYCL | 20.56 | 4.97 |
| Vulkan | 10.71 | 5.37 |

Kernel Dispatch Analysis

We traced the SYCL dispatch and found:

  1. Q4_K_M uses MMVQ+reorder — quantized dot product with data layout optimization
  2. Q8_0 is stuck on DMMV — generic dequantize-mul-mat-vec (not in reorder support list)

We proved this matters by forcing Q4_K_M through DMMV:

| Config | tg128 t/s |
|---|---|
| Q4_K_M → MMVQ+reorder (default) | 20.56 |
| Q4_K_M → DMMV (forced via GGML_SYCL_PRIORITIZE_DMMV=1) | 12.38 |
| Q8_0 → DMMV (default, only option) | 4.97 |
| Q8_0 → MMVQ (patched ggml_sycl_supports_dmmv) | 4.33 |

DMMV is 40% slower than MMVQ for Q4_K_M, but switching Q8_0 to MMVQ doesn't help — both paths are slow for Q8_0. The Q8_0 kernel implementations themselves need optimization for Xe2.

Generic vs reorder DMMV

  • Generic DMMV (used by Q8_0): iter_stride = 2 * GGML_SYCL_DMMV_X = 64 → 2 values per thread per iteration
  • Reorder DMMV (used by Q4_0): iter_stride = 8 * 2 * GGML_SYCL_DMMV_X = 512 → 16 values per thread per iteration

The reorder path processes 8x more work per thread iteration and uses a data layout with separated scales for better coalesced memory access.
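The stride difference can be made concrete with a little arithmetic. A sketch assuming a thread-group width of 32 along the row (an assumption on our part; the actual work-group shape in dmmv.cpp may differ):

```python
# Strides from the issue text; THREADS_X = 32 is an assumed work-group width
# along the row, not taken from dmmv.cpp.
GGML_SYCL_DMMV_X = 32
THREADS_X = 32

def values_per_thread_iter(iter_stride: int, threads: int = THREADS_X) -> int:
    # Each iteration the group collectively advances `iter_stride` values
    # along the row; the per-thread share is stride / group width.
    return iter_stride // threads

generic_stride = 2 * GGML_SYCL_DMMV_X        # 64  (generic DMMV, Q8_0 path)
reorder_stride = 8 * 2 * GGML_SYCL_DMMV_X    # 512 (reorder DMMV, Q4_0 path)

print(values_per_thread_iter(generic_stride))  # 2 values/thread/iteration
print(values_per_thread_iter(reorder_stride))  # 16 values/thread/iteration

# Loop iterations to cover one row of a hypothetical 5120-wide weight matrix:
ncols = 5120
print(ncols // generic_stride, ncols // reorder_stride)  # 80 vs 10
```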

Environment Variables Tested (no effect on Q8_0 tg)

  • SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0/1
  • SYCL_PI_LEVEL_ZERO_BATCH_SIZE=1
  • ZE_FLAT_DEVICE_HIERARCHY=FLAT
  • IGC_EnableDPEmulation=1
  • GGML_VK_DISABLE_COOPMAT=1 (Vulkan — coopmat has no tg effect)

Driver Update Test

Updated from compute runtime 26.05.37020.3 (IGC 2.28.4) to 26.09.37435.1 (IGC 2.30.1). Clean rebuild of llama.cpp. Q8_0 tg unchanged (4.98 t/s). Q8_0 pp improved ~9%. Confirms the issue is in llama.cpp kernel code, not the Intel driver or compiler.

GGML_SYCL_F16 Finding

Building with -DGGML_SYCL_F16=ON gives 2.4x prompt processing speedup (302→725 t/s for Q4_K_M) but does not affect token generation at all. This is expected — pp uses GEMM (compute-bound, benefits from FP16 XMX), tg uses DMMV/MMVQ (memory-bandwidth-bound).
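The bandwidth-bound argument can be quantified with a rough roofline-style estimate, assuming ~2 FLOPs per weight for mat-vec and ~1 byte per weight at Q8_0 (our illustrative assumptions, not measurements):

```python
# Roofline-style estimate of why tg cannot benefit from FP16 XMX.
# flops_per_weight ~= 2 (multiply + add) and bytes_per_weight ~= 1 for Q8_0
# are assumptions for illustration only.
def compute_to_saturate_bw(peak_bw_bytes_s: float,
                           flops_per_weight: float = 2.0,
                           bytes_per_weight: float = 1.0) -> float:
    """FLOP/s needed to keep pace with memory at this arithmetic intensity."""
    return peak_bw_bytes_s * flops_per_weight / bytes_per_weight

needed = compute_to_saturate_bw(608e9)  # B70 peak bandwidth from the issue
print(f"{needed / 1e12:.2f} TFLOP/s")   # ~1.22 TFLOP/s
# Any modern GPU delivers far more FP16 compute than ~1.2 TFLOP/s, so the
# mat-vec (tg) kernels sit on the memory roof and extra XMX throughput is
# wasted, while GEMM-based prompt processing is compute-bound and benefits.
```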

Cross-Platform Context

Related issues: #19887 (inverse quant anomaly on Vulkan/A770), #19918 (SYCL vs Vulkan perf gap), #18808 (Battlemage user reports)

Suggested Fixes

  1. Add Q8_0 to the MMVQ reorder path — implement dequantize_block_q8_0_reorder, add Q8_0 to ggml_sycl_supports_reorder_mmvq(). The reorder data layout should significantly improve memory access patterns for Q8_0's 34-byte blocks (not power-of-2).
  2. Increase Q8_0 DMMV iter_stride — process more values per thread iteration, matching the 8x factor used in the Q4_0 reorder kernel.
  3. Profile on Intel GPU tools — use ze_tracer or Intel VTune to identify specific kernel bottleneck (memory latency hiding, cache miss rate, EU utilization).
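The layout argument behind fix #1 can be illustrated by modeling the block format. A Python sketch of the 34-byte block_q8_0 layout and a hypothetical reorder transform (the real kernel-side reorder layout in ggml-sycl may differ):

```python
import struct

# ggml's block_q8_0: one fp16 scale followed by 32 int8 quants = 34 bytes.
# 34 is not a power of two, so interleaved blocks straddle cache lines.
BLOCK_Q8_0 = "<e32b"  # little-endian: fp16 scale, then 32 signed bytes
assert struct.calcsize(BLOCK_Q8_0) == 34

def reorder_q8_0(raw_blocks):
    """Hypothetical reorder: split interleaved (scale, quants) blocks into a
    contiguous quant array plus a contiguous scale array, so each can be
    loaded with coalesced, power-of-two-strided accesses."""
    scales, quants = [], []
    for raw in raw_blocks:
        vals = struct.unpack(BLOCK_Q8_0, raw)
        scales.append(vals[0])
        quants.extend(vals[1:])
    return scales, quants

# Two example blocks: scales 0.5 and 2.0 (both exactly representable in fp16).
b0 = struct.pack(BLOCK_Q8_0, 0.5, *range(32))
b1 = struct.pack(BLOCK_Q8_0, 2.0, *range(31, -1, -1))
scales, quants = reorder_q8_0([b0, b1])
print(scales, len(quants))  # [0.5, 2.0] 64
```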

Files Involved

  • ggml/src/ggml-sycl/ggml-sycl.cpp — dispatch logic (line ~3526, ggml_sycl_mul_mat)
  • ggml/src/ggml-sycl/dmmv.cpp — DMMV kernels (line ~975, dequantize_mul_mat_vec_q8_0_sycl)
  • ggml/src/ggml-sycl/mmvq.cpp — MMVQ kernels (line ~682, mul_mat_vec_q8_0_q8_1_sycl)
  • ggml/src/ggml-sycl/dequantize.hpp — Q8_0 dequant function (line ~146)
  • ggml/src/ggml-sycl/vecdotq.hpp — vec_dot_q8_0_q8_1 (line ~844)
