Skip to content

(Performance) Optimized x86 and generic q1_0(_g128) dot#10

Open
pl752 wants to merge 3 commits intoPrismML-Eng:masterfrom
pl752:perf/q1_0_g128_no_nofma
Open

(Performance) Optimized x86 and generic q1_0(_g128) dot#10
pl752 wants to merge 3 commits intoPrismML-Eng:masterfrom
pl752:perf/q1_0_g128_no_nofma

Conversation

@pl752
Copy link
Copy Markdown

@pl752 pl752 commented Apr 3, 2026

Hello
This is yet another PR about the fix of the truncation and optimization of the cpu inference.

In this case I have:

  • Replaced a ton of bit-masking operations, removed redundant float multiplication and unrolled the hot inner loop with constant masks for accumulation with signs in arch-agnostic fallback
  • Introduced paths for filling the gap between default fallback and AVX-512 capable CPUs
  • Performed tests to make sure that optimizations don't have effect on precision/correctness
  • Performed various experinments (most yielded worse performance) including:
  • brancless variant of unroll
  • various register and superscalar pipeline pressure options (AVX 2 uses doubled accumulation flow)
  • AVX-512 VNNI
  • explicitly precomputed masks for SIMD

Note that this PR is built on top of the #3 by @jordankzf, who implemented AVX-512 workflow

Benchmarks were performed with:

  • CPU: AMD Ryzen 5 7640HS (at 65w)
  • WSL vm
  • LPDDR5 @ 6400MT JEDEC
  • Model: Bonsai-1.7B.gguf (Q1_0_g128)
  • Threads: 6
Flow pp 512 t/s tg 128 t/s Speedup Notes
Initial* 1.59 0.85 1.0x / 1.0x Slow
Scalar 9.57 7.06 6.0x / 8.3x Explicit byte-oriented unroll
SSSE3 26.13 19.51 16.5x / 22.9x 128-bit specialization
AVX 34.99 27.31 22.1x / 32.1x Mixed-width specialization
AVX2 + FMA 80.02 51.46 50.4x / 60.5x 256-bit specialization
AVX512BW 97.16 60.88 61.3x / 71.5x Leverages new SIMD extensions**
  • * extrapolated from pp 32 / tg 16: 1.659 t/s pp and 0.862 t/s tg, as I was impatient.
  • ** new SIMD instruction kinds improve performance even on AMD Zen4 implementation of AVX-512, which uses 256 bit pipeline twice instead of implementing full 512 bit one

I would appreciate your feedback

@github-actions github-actions bot added the ggml label Apr 3, 2026
@khosravipasha
Copy link
Copy Markdown
Collaborator

khosravipasha commented Apr 4, 2026

Thanks this looks great, nice write up.
I guess could you also run corretness checks, similar to what we did here: #8
see KL divergence between packed vs unpacked model (should be close to 0).

I am not too familiar with SIMD/AVX stuff, what CPUs does this support:
I know there is some different between AVX512BW, AVX2, SSSE3, AVX, is this for different CPU archetictures?

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 4, 2026

@khosravipasha You are welcome :)
These are various generations of x86 simd instructions (mostly backward compatible, aka AVX can SSE, etc.):

  • SSE family is 128 bit simd introduced with Pentium III (fp32 only); then SSE2 followed up with (fp64 and int), then SSE3 with some additional utility instructions;
  • SSSE3 is most interesting in 128 bit case, as it provides instructions for shuffling/expansion of bit mask, sign assign ops and int dot product used in our and other implementations there (this set is available since Core 2 generation and AMD Bobcat, so it covers essentially all realistic x86 targets (except for some shenanigans);
  • AVX is 256 bit SIMD for fp32/fp64 introduced with Sandy bridge (eg core i7-2xxx) or AMD Bulldozer (eg fx-8100), lacks int ops though, still can be used for accum paired with SSSE3 instructions;
  • AVX2 adds most of the missing 256 bit instructions, plus it is paired with FMA(3) instruction introduction (except for VIA C4650), providing fused multiply-add; introduced in Haswell or AMD Zen, in our case used with most modern-ish processors completing the gap bridging between legacy and SoTA cpus;
  • AVX512 extends AVX2 ops to 512 bit SIMD, plus it has numerous extensions including AVX512BW with support for int8 and int16 ops (used in Q8_0 sign expansion to int16)
  • Scalar fallback is mostly for non-x86 cpus and comformity, there I got rid of bit mask calculations, int to float conversion with multiplication and unrolled inner loop with constant masks to minimize computations and increase pipeline pressure

As for perplexity, I have performed run for single 64 token wikitext-2-test chunk with 1.7B model

Flow PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p
Scalar 17.1988 ± 9.6330 -0.00000 ± -nan -0.00000 ± 0.00000 0.001 ± 0.000 % 100.000 ± 0.000 %
SSSE3 17.2739 ± 9.7094 0.00435 ± 0.00450 0.00024 ± 0.00004 0.218 ± 0.038 % 100.000 ± 0.000 %
AVX 17.2402 ± 9.6760 0.00240 ± 0.00352 0.00025 ± 0.00005 0.362 ± 0.062 % 90.323 ± 5.398 %
AVX2 17.2321 ± 9.6740 0.00193 ± 0.00398 0.00023 ± 0.00004 0.379 ± 0.067 % 96.774 ± 3.226 %
AVX-512 17.2463 ± 9.6895 0.00275 ± 0.00298 0.00023 ± 0.00005 0.279 ± 0.070 % 93.548 ± 4.485 %

I will perform more runs

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 4, 2026

I have run 5 chunks of 512 tokens, looks better, I think, will run 100 chunks:

Flow Mean PPL(Q) Mean ln(PPL(Q)/PPL(base)) Mean KLD RMS Δp Same top p
Scalar baseline 20.943033 +/- 2.071658 0.000000 0.000000 0.000 % 100.000 %
SSSE3 21.076136 +/- 2.102022 0.006335 +/- 0.004656 0.000267 +/- 0.000009 0.386 +/- 0.017 % 99.059 +/- 0.271 %
AVX 21.081167 +/- 2.102227 0.006574 +/- 0.004686 0.000285 +/- 0.000011 0.404 +/- 0.019 % 99.451 +/- 0.207 %
AVX2 21.087163 +/- 2.103328 0.006858 +/- 0.004650 0.000282 +/- 0.000012 0.418 +/- 0.027 % 99.529 +/- 0.192 %
AVX-512BW 21.095567 +/- 2.103673 0.007257 +/- 0.004635 0.000279 +/- 0.000010 0.399 +/- 0.019 % 99.294 +/- 0.235 %

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 4, 2026

I am somewhat in doubt now, it seems something around the effect of comparing cpu to cuda, or something inbetween fp32->fp16 and fp32->q8_0, maybe it is from using smaller model
Note: I am comparing between my implementations, I think I need to use fp16 as a baseline first

@khosravipasha
Copy link
Copy Markdown
Collaborator

khosravipasha commented Apr 4, 2026

@pl752 Awesome thanks for the explnations.
And for the KL's look pretty good, being close to 0 is good. The rest is numerical noise probably since also llama.cpp side they convert logits to fp16 to save time (I only run few chunks myself too), llama.cpp tool was designed to see how good their quantizations are, for us the weights are equivalent packed and unpacked so having KL close to zero for few chunks is good enough (this is mostly to test the kernels and not the quantization itself)

https://github.com/ggml-org/llama.cpp/tree/master/tools/perplexity

Yeah I used running the model in fp16 as the baeslines using these https://huggingface.co/collections/prism-ml/bonsai-auxiliary

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 4, 2026

Okay, don't forget to thank the user from which I've hijacked AVX-512 implementation

@khosravipasha
Copy link
Copy Markdown
Collaborator

khosravipasha commented Apr 4, 2026

@pl752 good idea, which one was it? We can tag them here,
Right now only sending PR to llama.cpp with generic cpu to finalize the naming, formatting, etc.

After that's merged, then can all send a PR together with everyone that contributed tagged in main llama.cpp maybe.

Note that there will be some naming changes (in summary Q1_0_g128 is renamed to Q1_0, and original Q1_0 will be deleted). Should not affect running the current models.

ggml-org#21273

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 4, 2026

Note that this PR is built on top of the #3 by @jordankzf, who implemented AVX-512 workflow

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 5, 2026

Performed additional 5x512 run against unpacked gguf

Flow Mean PPL(Q) Mean ln(PPL(Q)/PPL(base)) Mean KLD RMS Δp Same top p
Scalar 21.082185 +/- 2.102340 0.005412 +/- 0.004643 0.000213 +/- 0.000008 0.334 +/- 0.017 % 99.451 +/- 0.207 %
SSSE3 21.076136 +/- 2.102022 0.005125 +/- 0.004661 0.000220 +/- 0.000008 0.341 +/- 0.016 % 99.137 +/- 0.259 %
AVX 21.081167 +/- 2.102227 0.005364 +/- 0.004690 0.000235 +/- 0.000010 0.362 +/- 0.017 % 99.373 +/- 0.221 %
AVX2 21.087163 +/- 2.103328 0.005649 +/- 0.004643 0.000216 +/- 0.000009 0.377 +/- 0.023 % 99.608 +/- 0.175 %
AVX-512BW 21.095567 +/- 2.103673 0.006047 +/- 0.004636 0.000222 +/- 0.000009 0.365 +/- 0.020 % 99.059 +/- 0.271 %

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 5, 2026

UPD: I have reviewed how I was interleaving instructions when testing various register pressure options and found issues resulting in register spilling, so I just relied on the compiler doing its job properly and simply unrolled inner loop with individual accumulators for SSSE3 (as the compiler already did pretty well for other flows); I have also tried the same thing for AVX-512, but it did result in tiny performance regression. It had almost no effect on perplexity.

Effects on performance, (baseline has drifted due to using -t10 instead of -t6):

flow run baseline updated delta
SSSE3 pp512 33.38 t/s 39.18 t/s +17.36%
SSSE3 tg128 24.61 t/s 29.24 t/s +18.81%

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR focuses on improving CPU inference throughput by optimizing the q1_0 / q1_0_g128 dot-product kernels against q8_0, reducing bit-twiddling overhead in portable fallbacks and introducing additional optimized x86 SIMD execution paths.

Changes:

  • Reworked generic fallbacks to process packed sign bits in a byte-oriented way (4 × 8-value groups per 32-element sub-block), eliminating per-element bit index arithmetic.
  • Implemented x86-specialized kernels for ggml_vec_dot_q1_0_q8_0 and ggml_vec_dot_q1_0_g128_q8_0 with multiple SIMD paths (SSSE3 / AVX / AVX2 / AVX-512BW) plus scalar byte-oriented fallback.
  • Added small SSSE3 helpers to expand packed sign bits into byte masks and to reduce vector accumulators.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
ggml/src/ggml-cpu/quants.c Optimizes portable q1_0 and q1_0_g128 generic dot fallbacks by switching to explicit byte-oriented sign decoding and removing per-element bit math.
ggml/src/ggml-cpu/arch/x86/quants.c Replaces x86 dispatch to generic kernels with specialized SIMD implementations across AVX-512BW/AVX2/AVX/SSSE3, keeping a byte-oriented scalar fallback.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@zcattacz
Copy link
Copy Markdown

zcattacz commented Apr 6, 2026

I tested the AVX2 impl, slightly faster then #7 (see the est full test time) but slower than xor+sub. Maybe the reported 0.00022 KLD is arch related (Tigerlake and Broadwell are both Intel CPU). I have tried several impl on Broadwell, all hit the same KLD after the first few chunks, thus later there's little point to run the full test just to confirm the KLD.

PR10 with mul_sum_i8_pairs_float

system_info: n_threads = 2 (n_threads_batch = 2) / 4 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
kl_divergence: computing over 100 chunks, n_ctx=512, batch_size=2048, n_seq=4
kl_divergence: 150.86 seconds per pass - ETA 1 hours 2.85 minutes

chunk             PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p
   1      13.9557 ±    3.1807      -0.00019 ±    0.00239       0.00019 ±    0.00002     0.376 ±  0.047 %    99.608 ±  0.392 %
   2      20.1986 ±    3.4363       0.01428 ±    0.01146       0.00020 ±    0.00001     0.346 ±  0.030 %    99.608 ±  0.277 %
   3      20.8582 ±    2.7888       0.00944 ±    0.00766       0.00021 ±    0.00001     0.375 ±  0.025 %    99.216 ±  0.319 %
   4      21.2096 ±    2.3896       0.00684 ±    0.00577       0.00022 ±    0.00001     0.385 ±  0.026 %    99.412 ±  0.240 %
   5      21.0872 ±    2.1033       0.00566 ±    0.00464       0.00022 ±    0.00001     0.376 ±  0.023 %    99.529 ±  0.192 %
   6      21.2932 ±    1.9099       0.00549 ±    0.00390       0.00021 ±    0.00001     0.362 ±  0.020 %    99.477 ±  0.184 %
   7      21.4337 ±    1.7665       0.00508 ±    0.00335       0.00021 ±    0.00001     0.365 ±  0.020 %    99.440 ±  0.177 %
   8      23.1788 ±    1.8031       0.00527 ±    0.00297       0.00021 ±    0.00001     0.364 ±  0.018 %    99.412 ±  0.169 %
   9      24.6955 ±    1.8365       0.00752 ±    0.00337       0.00022 ±    0.00001     0.355 ±  0.017 %    99.390 ±  0.163 %
  10      25.4214 ±    1.7879       0.00672 ±    0.00303       0.00022 ±    0.00001     0.353 ±  0.015 %    99.294 ±  0.166 %
  11      26.0683 ±    1.7516       0.00617 ±    0.00276       0.00022 ±    0.00001     0.354 ±  0.014 %    99.287 ±  0.159 %
  12      26.5272 ±    1.7091       0.00582 ±    0.00254       0.00022 ±    0.00001     0.351 ±  0.013 %    99.346 ±  0.146 %
xor+sub (like PR4)

system_info: n_threads = 2 (n_threads_batch = 2) / 4 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
kl_divergence: computing over 100 chunks, n_ctx=512, batch_size=2048, n_seq=4
kl_divergence: 115.23 seconds per pass - ETA 48.00 minutes

chunk             PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p
   1      13.9528 ±    3.1791      -0.00040 ±    0.00223       0.00019 ±    0.00002     0.382 ±  0.053 %    99.608 ±  0.392 %
   2      20.1970 ±    3.4355       0.01420 ±    0.01145       0.00019 ±    0.00001     0.343 ±  0.033 %    99.608 ±  0.277 %
   3      20.8596 ±    2.7888       0.00950 ±    0.00765       0.00021 ±    0.00001     0.351 ±  0.026 %    99.346 ±  0.292 %
   4      21.2115 ±    2.3896       0.00693 ±    0.00576       0.00022 ±    0.00001     0.369 ±  0.025 %    99.510 ±  0.219 %
   5      21.0887 ±    2.1034       0.00573 ±    0.00463       0.00022 ±    0.00001     0.363 ±  0.022 %    99.608 ±  0.175 %
   6      21.2944 ±    1.9099       0.00555 ±    0.00389       0.00021 ±    0.00001     0.351 ±  0.019 %    99.542 ±  0.173 %
   7      21.4348 ±    1.7665       0.00513 ±    0.00334       0.00021 ±    0.00001     0.355 ±  0.020 %    99.496 ±  0.168 %
PR7 with _mm256_shuffle_epi8

system_info: n_threads = 2 (n_threads_batch = 2) / 4 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
kl_divergence: computing over 100 chunks, n_ctx=512, batch_size=2048, n_seq=4
kl_divergence: 186.99 seconds per pass - ETA 1 hours 17.90 minutes

chunk             PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p
   1      13.9733 ±    3.1846       0.00107 ±    0.00236       0.00020 ±    0.00002     0.402 ±  0.048 %    99.608 ±  0.392 %
   2      20.2038 ±    3.4373       0.01454 ±    0.01146       0.00022 ±    0.00002     0.375 ±  0.029 %    99.608 ±  0.277 %
   3      20.8431 ±    2.7865       0.00871 ±    0.00766       0.00023 ±    0.00001     0.387 ±  0.026 %    98.693 ±  0.411 %
   4      21.1827 ±    2.3859       0.00558 ±    0.00577       0.00023 ±    0.00001     0.378 ±  0.022 %    99.020 ±  0.309 %
   5      21.0675 ±    2.1012       0.00473 ±    0.00465       0.00022 ±    0.00001     0.379 ±  0.019 %    99.137 ±  0.259 %
   6      21.2662 ±    1.9072       0.00422 ±    0.00390       0.00022 ±    0.00001     0.381 ±  0.018 %    99.085 ±  0.244 %
   7      21.4126 ±    1.7643       0.00409 ±    0.00335       0.00022 ±    0.00001     0.374 ±  0.016 %    98.992 ±  0.237 %

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 6, 2026

@zcattacz Thank you for the hint, it worked at least for at least AVX2, I will revise my current kernels and post updates

@zcattacz
Copy link
Copy Markdown

zcattacz commented Apr 6, 2026

@pl752 , oh. my bad, I misread your KLD. Are they all tested on AMD. it's also around 0.00022. The xor+sub is adapted from PR4. If you are after speed, please give it a try. You can find the code I tested for AVX2 from my comment in #7. Even the shadowed variable gives it a 5%~10% boost. I also tested double accumulator impl, but it didn't give any edge. The compiler seems to be doing some magic here.
I did a similar SSSE3 test (code also in #7) for KLD, I didn't save the result, but since it's so slow on i5, not gonna to do it again. iirc, the KLD is also ~0.00022.

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 6, 2026

@zcattacz They all tested on AMD Ryzen 5 7640HS (Zen 4)

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 6, 2026

UPD2: Okay, I have applied the advice and it resulted in positive performance changes and no significant perplexity changes. However I removed the AVX512 branch now relying on compiler taking advantage of some of the register and instruction layout changes, as I failed to achieve meaningful performance increase past current AVX2 flow on my Zen4, moreover AVX2 is variant is slightly faster even. Somebody with Zen5 or modern Intel Xeon should take a look and experiment.

Performance changes (t=10)

flow run baseline updated delta
SSSE3 pp512 58.19 t/s 72.79 t/s +25.10%
SSSE3 tg128 37.19 t/s 45.07 t/s +21.19%
AVX pp512 59.72 t/s 75.17 t/s +25.87%
AVX tg128 38.06 t/s 46.55 t/s +22.30%
AVX2 pp512 95.60 t/s 119.14 t/s +24.64%
AVX2 tg128 56.61 t/s 69.32 t/s +22.45%
AVX512 pp512 121.25 t/s 124.16 t/s +2.40%
AVX512 tg128 67.66 t/s 70.73 t/s +4.52%

Also what has happened to SSSE3 performance, its baseline seemingly increased out of nowhere? Turns out rebooting sometimes significantly increases performance

@zcattacz
Copy link
Copy Markdown

zcattacz commented Apr 6, 2026

Hi @pl752, nice improvement. If you want to squeeze a bit more juice, pls try the code in #7 (comment) , it's simpler but the compiler makes the single accumulator impl even faster than the double accumulator (gives another 2~3tps for free on an i5).

loop + double accumulator (current)
bin/cpu/llama-bench -m models/gguf/1.7B/Bonsai-1.7B.gguf -p 512 -n 128
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3 1.7B Q1_0                | 231.13 MiB |     1.72 B | CPU        |       2 |           pp512 |         16.98 ± 0.23 |
| qwen3 1.7B Q1_0                | 231.13 MiB |     1.72 B | CPU        |       2 |           tg128 |         13.20 ± 0.60 |


no loop/if + single accumulator (from above comment)
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3 1.7B Q1_0                | 231.13 MiB |     1.72 B | CPU        |       2 |           pp512 |         18.40 ± 1.04 |
| qwen3 1.7B Q1_0                | 231.13 MiB |     1.72 B | CPU        |       2 |           tg128 |         13.79 ± 0.74 |

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 6, 2026

That's interesting, two accumulators gave better performance before the xor+sub change, now it's other way around; baselines has drifted once more though

flow run baseline updated delta
AVX2 pp512 122.04 t/s 130.11 t/s +6.61%
AVX2 tg128 70.84 t/s 73.29 t/s +3.46%
AVX-512 pp512 129.46 t/s 136.65 t/s +5.55%
AVX-512 tg128 74.25 t/s 77.21 t/s +3.99%

Couldn't confirm KLD changes thhough, @zcattacz , check that I haven't made a mistake there

@khosravipasha
Copy link
Copy Markdown
Collaborator

Good new our first CPU PR just got merged int llama.cpp master branch now, if you are still working on this please rebase with PrismML's master (just pulled the main llama.cpp)

Changes: Q1_0_g128 naming is gone now, the original Q1_0 with group size 32 was deleted and Q1_0_g128 was renamed to Q1_0 now by default has group size 128.

https://github.com/PrismML-Eng/llama.cpp/tree/master

This one only has generic cpu (slow), and ARM NEON path, planning to gather the best x86 kernels from here and to send a PR there (and tag all the contributers).

@zcattacz
Copy link
Copy Markdown

zcattacz commented Apr 7, 2026

@pl752, yeah, I also tested double accumulator with a full unrolled version, but the speed is still a marginal net loss. Looks like FMA is not the bottleneck.

Could you try this and see if it works better for you?
The benchmark shows no overhead from the inline func but I'm not sure if it's worth it. On Broadwell _mm256_sign_epi8 leads to 1tps loss, if later architectures are more efficient, we should expect more than 1tps gain, right?

// AVX2 optimization helper
static inline __m256i avx2_apply_sign_helper(const __m256i * y, const __m256i * sm, const __m256i * ones_8)
{
#if defined(__CLWB__) || defined(__SGX__) || defined(__AVX512F__) 
    // Skylake and Zen3 onwards
    // This is a bottle neck on older arch like Haswell / Broadwell, 1TPS drop.
    return _mm256_sign_epi8(*y, _mm256_or_si256(*sm, *ones_8));
#else
    return _mm256_sub_epi8(_mm256_xor_si256(*y, *sm), *sm);
#endif
}

void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK1_0_g128;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q1_0_g128 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__AVX2__)
    const __m256i ones_8 = _mm256_set1_epi8(1);
    const __m256i ones_16 = _mm256_set1_epi16(1);
    const __m256i byte_shuf = _mm256_setr_epi8(0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3);
    const __m256i bit_masks = _mm256_setr_epi8(1,2,4,8,16,32,64,(char)-128,1,2,4,8,16,32,64,(char)-128,1,2,4,8,16,32,64,(char)-128,1,2,4,8,16,32,64,(char)-128);
    const __m256i zero = _mm256_setzero_si256();
    __m256 acc = _mm256_setzero_ps();



    for (int ib = 0; ib < nb; ++ib) {
        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
        const uint32_t *qs32 = (const uint32_t *)x[ib].qs;
        const block_q8_0 *y_ptr = &y[ib * 4];
        // y is deliberately left shadowed for a measurable performance gain
        __m256 acc_block;
        {
            const __m256i y = _mm256_loadu_si256((const __m256i *)y_ptr[0].qs);
            const __m256i sm = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int)qs32[0]), byte_shuf), bit_masks), zero);
            const __m256i sy = avx2_apply_sign_helper(&y, &sm, &ones_8);
            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16);
            acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32));
        }
    #define Q1_AVX2_BLOCK(K) \
        { \
            const __m256i y = _mm256_loadu_si256((const __m256i *)y_ptr[K].qs); \
            const __m256i sm = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int)qs32[K]), byte_shuf), bit_masks), zero); \
            const __m256i sy = avx2_apply_sign_helper(&y, &sm, &ones_8); \
            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); \
            acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); \
        }
        Q1_AVX2_BLOCK(1) Q1_AVX2_BLOCK(2) Q1_AVX2_BLOCK(3)
    #undef Q1_AVX2_BLOCK
        acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc);
    }
    {
        const __m128 h = _mm_add_ps(_mm256_extractf128_ps(acc, 0), _mm256_extractf128_ps(acc, 1));
        const __m128 q = _mm_add_ps(h, _mm_movehl_ps(h, h));
        *s = _mm_cvtss_f32(_mm_add_ss(q, _mm_movehdup_ps(q)));
    }

Since it's slow, I only run the full perplexity test in the initial tuning. Can't find the numbers, but I recall my Max KLD was way higher than what was reported in PR7. It drops down after the initial FMA is replaced with simple MUL.
I'm not quite sure why. AI said AxB+0 might introduce more rounding error. Anyway it just made a difference.

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 7, 2026

That (sign alt) was few percent slower unfortunately in my case too

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 7, 2026

Brought the code to uniform structure, insignificant changes in ASM for AVX, no measurable changes in perplexity or performance, will prepare for rebase

@pl752 pl752 changed the base branch from prism to master April 7, 2026 06:51
@pl752 pl752 force-pushed the perf/q1_0_g128_no_nofma branch from b793ed1 to 195593b Compare April 7, 2026 06:53
@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 7, 2026

I think yes, we can write a draft and then send it to main tree. However I am going to sleep currently, so I will help later. I have been testing perplexity all the way through to avoid breaking things, but additional tests won't hurt. Then benchmarks for the scalar and SIMD branches need to be redone and cleaned up summary added. In my current implementation SSSE3 is for most of cpus, AVX helps with fp32 accum part, but difference in performance is questionable, AVX 2 handles modern-ish cpus as the most performant way, and AVX-512 specific branch was discarded (due to me failing to obtain any improvements over AVX2 with AVX-512 flag set on my Zen4), but it is still mentioned in my latest benchmarks, as compiler still produces more optimized code for AVX2 with AVX-512 enabled due to AVX-512 providing 32 SIMD registers instead of 16 (aside from fact that their max length extends to 512 bit, which isn't used there) allowing some additional freedom during applying O3 opts.

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 7, 2026

So I can try creating PR draft myself tommorow and tag everybody from discussion and remained not my own code if any is left, then we will look into the next steps, the code itself seems to be pretty clean.

@khosravipasha
Copy link
Copy Markdown
Collaborator

khosravipasha commented Apr 8, 2026

Thanks @pl752 of course take your time, sleep is more important :D

Closed the other CPU PRs. We can collectively send a PR to main llama.cpp when this branch is ready.

Tagging people that helped in other PRs (let me know if I missed any)

@Marxist-Leninist
Copy link
Copy Markdown

Comet Lake data point from my i7-10510U (4C/8T, laptop, Windows MinGW-w64 UCRT gcc 15.2), same Bonsai-1.7B.gguf, llama-bench -p 512 -n 128 -r 3 on this branch at e29cd48. Fills the gap between @zcattacz's Broadwell / Tigerlake runs and the Zen4 numbers above.

Config pp512 (t/s) tg128 (t/s)
Scalar (branch generic), -t 4 2.89 ± 0.13 2.39 ± 0.12
Scalar (branch generic), -t 8 3.99 ± 0.02 2.62 ± 0.16
AVX2 + FMA + F16C, -t 4 26.65 ± 3.47 13.61 ± 1.72
AVX2 + FMA + F16C, -t 8 33.70 ± 1.87 11.51 ± 0.80

Speedups vs generic scalar on the same branch: pp512 9.2× at t=4 / 8.4× at t=8; tg128 5.7× at t=4 / 4.4× at t=8.

One pattern worth noting: on this chip tg128 actually peaks at physical-core count (13.61 at t=4) and drops at t=8 (11.51). The AVX2 kernel is tight enough that memory bandwidth becomes the bottleneck and SMT contention hurts generation, while pp512 still benefits from SMT because it's compute-bound. Might be worth mentioning for users picking thread counts on small-L3 mobile Intel parts.

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

@Marxist-Leninist Thank you for an additional insight. Using SMT is known to significantly increase memory pressure, this is the reason I have used physical core count on initial benchmarks to avoid memory bottleneck then through benchmarks I found out that for my system 10 threads (logical thr count - 2) was yielding max tg and near max pp, while setting to 12 threads (all threads) didn't significantly increase pp, while tg has slightly reduced, so that's pretty common thing for systems with many threads (also there was/is recommendation to use nproc - 2, or even ncore - 2 in case system has many cores, as usually memory in system designs is scaled worse than compute power (aka 16 core ryzens get same bandwidth (ddr4/5 dual channel) as lower core counts and systems like threadrippers can have even higher core to memory bandwidth ratio.

@Marxist-Leninist
Copy link
Copy Markdown

Thanks @pl752 — that matches exactly. Just ran the same branch on Bonsai-8B.gguf on the same i7-10510U for comparison:

Config pp512 (t/s) tg128 (t/s)
AVX2 + FMA + F16C, -t 4 6.11 ± 0.12 4.55 ± 0.16
AVX2 + FMA + F16C, -t 8 7.06 ± 0.39 3.65 ± 0.23

Same pattern holds at 8B scale: tg128 peaks at physical-core count (4.55 at t=4, drops to 3.65 at t=8), pp512 still gains from SMT (6.11 → 7.06). So the ncore - 1 / ncore - 2 heuristic for tg on memory-bound consumer chips definitely applies here — worth a note in the final PR description.

Continuity note with my closed #4: my original d603bf4 AVX2 kernel on the same chip + same 8B model measured 4.7 pp / 3.1 tg at -t 4. This branch is 6.11 pp / 4.55 tg at -t 4+30% pp, +47% tg over that baseline on the same hardware. So the single-accumulator tuning and the mul-instead-of-fma-on-first-sub-block refinement (from @zcattacz's testing advice) are a measurable improvement even on Comet Lake, not just Zen4. Nice iteration chain.

@Marxist-Leninist
Copy link
Copy Markdown

Two more data points on the same i7-10510U, this branch at e29cd48, chasing faster 8B generation on a memory-pressured laptop.

Bonsai-4B, completing the 1.7B → 4B → 8B series:

Config pp512 tg128
CPU AVX2 + FMA, -t 4 6.09 3.81
Vulkan iGPU (Intel UHD), -ngl 99 4.11 2.06

Consistent with the 1.7B and 8B results I posted earlier — the Intel UHD has int dot: 0 (no integer dot product) so our maddubs/madd kernel advantage disappears and the iGPU falls back to fp16 emulation. ~2× slower than CPU across all three sizes. Probably worth a note for anyone considering integrated-GPU deployment: for Q1_0 specifically, CPU with AVX2 is the right backend on Intel UHD class hardware.

CPU-side micro-optimizations on 8B:

Config tg128 (t/s) Δ vs baseline
Baseline (-fa 0 -ctk f16 -ctv f16) 4.73 (best) / 4.16 (avg)
--mlock -b 2048 4.55 avg over 3 runs +9.3%
-fa 1 -ctk q8_0 -ctv q8_0 1.38 −71% ❌
Speculative (-md Bonsai-1.7B --draft-max 8) 4.39 +5.5% (noise)

Two things worth flagging for the upstream PR:

  1. --mlock is a real win under memory pressure — on my system (32 GB physical / 73 GB committed) it reliably recovered ~10% on generation and ~15% on prompt. Weights were being paged out mid-inference without it. Might be worth mentioning in the release notes / README as a recommended flag for low-RAM machines.

  2. -fa 1 -ctk q8_0 -ctv q8_0 is disastrously bad on CPU for Q1_0 (3.4× slower). The CPU flash attention path + quantized KV dequant overhead absolutely tanks throughput on Q1_0 workloads. Users with large contexts on GPU-quant KV recipes should not blindly copy those flags to CPU. Not a regression in this PR — pre-existing llama.cpp behavior — but worth documenting since the Q1_0 CPU audience is going to be RAM-constrained and tempted to try it.

  3. Speculative decoding (1.7B as draft for 8B) gives only +5.5% which is within noise. Draft-to-target speed ratio is ~3×, well below the 10-50× needed for spec-decode to meaningfully win. A much smaller draft (~0.5B) might help, but probably not worth the complexity for this model family.

The measurable hard ceiling on this chip is ~4.7 t/s on 8B tg128, bandwidth-bound by single-channel DDR4-2667. Notably, 4B and 8B run at identical ~3.8 t/s on this chip because both are already hitting the RAM bandwidth wall — zero throughput penalty for picking 8B over 4B on this hardware class.

@Marxist-Leninist
Copy link
Copy Markdown

Runtime tuning notes from testing this branch (Windows / MinGW, Comet Lake)

Kicked the tires on this branch against 8B Q1_0_g128 on a laptop CPU and a Gen9.5 iGPU — posting the numbers in case they help triage follow-ups or adjacent PRs (#9 Vulkan, #18 CUDA).

Results (8B Q1_0_g128, -t 4, 512-prompt / 128-gen, llama-server /v1/chat)

Config pp t/s gen t/s vs baseline
Baseline (-fa 0 -ctk f16 -ctv f16) 5.71 4.16 —
-fa 1 -ctk q8_0 -ctv q8_0 (short ctx) 3.51 3.06 −26% gen
-fa 1 -ctk q8_0 -ctv q8_0 (4K ctx) 3.09 1.38 −67% gen
Vulkan iGPU (Intel UHD, -ngl 99) 4.52 1.82 −56% gen
Speculative decoding (8B + 1.7B draft, --draft-max 8) 6.17 4.39 +5.5% (≈noise)
--mlock + -b 2048 6.50 4.55 +9.3% gen, +14% pp

Observations worth flagging

  1. -fa 1 -ctk q8_0 -ctv q8_0 is a net loss for Q1_0 on CPU at any context. The KV dequant overhead per read dominates even though the cache "shrinks" from f16 to q8_0. At 4K context the combo is 3.4× slower on gen than baseline. Easy trap — might be worth a one-liner warning in a Q1_0 README.

  2. Vulkan iGPU is a dead-end for Q1_0 on Gen9.x UHD. Device reports int dot: 0 (Comet Lake UHD is Gen9.5, no dp4a). This branch's whole advantage is AVX2 maddubs/madd int8 ops; on this iGPU those fall back to fp16 MADs and run ~2× slower despite unified memory. Might be relevant context for vulkan: add Q1_0_g128 (1-bit ternary) shader support #9 — Q1_0 Vulkan support should probably be advertised as Gen11+ / Xe only (i.e. hardware with dp4a / XMX / DPAS). On pre-Gen11 the iGPU is strictly worse than CPU.

  3. Speculative decoding (1.7B draft → 8B target) is neutral, not a win. Draft-to-target speed ratio is only ~3× (13.6 / 4.16 t/s) while the spec-decode rule-of-thumb needs ≥10× to pay off. A much smaller (~0.5B) same-family draft would probably flip this, but with only 1.7B as the smallest option it's a wash.

  4. Memory bandwidth is the hard ceiling, not compute. 4B tg (3.81 t/s) ≈ 8B tg (3.73 t/s) — both saturate DDR4-2667 single-channel at ~4 t/s. Doubling the model size costs ~zero wall-clock on this CPU. The corollary: on bandwidth-limited boxes, future wins for this kernel are mostly in access patterns (prefetch, packing, cache-line locality) rather than more ALU tricks. The existing COM6-style 64-row block + software prefetch path is therefore doing exactly the right thing.

  5. --mlock gives a real +9.3% gen / +14% pp under memory pressure. Not new, but specifically notable for Q1_0 because the weights are dense-random-access; any page-out during inference tanks throughput. Worth mentioning in user-facing docs for anyone running Q1_0 on a memory-constrained box.

Environment

  • CPU: Intel i7-10510U (Comet Lake, 4c/8t, DDR4-2667 single-channel — note the single-channel caveat)
  • GPU tested: Intel UHD Graphics (Gen9.5, int dot: 0, fp16: 1, bf16: 0, no DP4A/DPAS)
  • OS: Windows 11 Pro, MinGW-w64 (GCC 15.2.0)
  • Build: this branch at b100-e29cd48
  • Models: Bonsai-8B / 4B / 1.7B Q1_0_g128

Happy to rerun with specific flags if any of these numbers look off — the --mlock vs no-mlock gap is the most reproducible of the bunch.

@Marxist-Leninist
Copy link
Copy Markdown

One more round of data on this PR — trying to find speedup on 8B tg and hitting a ceiling that isn't in the kernel.

tl;dr

The kernel in this PR isn't the bottleneck on either of the CPUs I tested. On a power-limited laptop it's PL1 throttling, and on a Skylake-SP Xeon VM the kernel is still fine — a hand-written AVX-512BW variant was within noise on 8B and slower on 1.7B. Signal seems to be that future wins live at the framework/scheduling layer, not in the SIMD inner loop.

i7-10510U (Comet Lake, 4C/8T, 15W, DDR4-2667 single-channel)

Thread sweep (Bonsai-8B Q1_0, tg128, --mlock):

-t 1 2 3 4 5 6 7 8
t/s 2.50 3.67 4.35 4.62 4.70 4.72 4.59 4.58

-t 6 beats -t 4 by ~2% on this 4C/8T chip and is the new optimum for the router. Parallel efficiency at -t 6 is ~31% (Amdahl sequential fraction ≈ 38.8%), so there's a lot of serial work between parallel sections.

Memory bandwidth vs achieved throughput (AVX2 aligned-load microbench on the 1.07 GB Q1_0 file):

-t raw stream GB/s llama 8B tg128 GB/s eq utilization
1 14.80 2.68 18%
2 16.91 3.93 23%
3 16.76 4.65 28%
4 15.49 4.94 32%
6 14.51 5.05 35%

One thread already saturates ~88% of the practical DDR4-2667 single-channel ceiling (~16.9 GB/s). But llama.cpp Q1_0 tops out at ~35% of that. There's ~2.8× theoretical headroom that isn't in the kernel — it's in the per-layer orchestration (OpenMP barriers between the ~32 layers, softmax/layernorm serialization, non-weight traffic, scale loads, KV reads).

CPU frequency under load:

  • Idle: 3.41 GHz
  • Sustained llama-bench: 1.95 GHz (−43%)
  • 1 second after load ends: 3.37 GHz (instant recovery)

Instant recovery rules out thermal — this is the i7-10510U firmware PL1=15W clamp kicking in after the ~28s Tau window. Theoretical unlocked ceiling is 3.4 / 1.95 ≈ 1.74×, which would put 8B at ~8.2 t/s. Needs admin + ThrottleStop/XTU MSR writes; can't be fixed from llama.cpp.

Null-result tweaks (all within noise of 4.72 t/s baseline):

  • OMP_WAIT_POLICY=ACTIVE — 4.83 (spin-wait burns power budget the chip needs for compute)
  • OMP_PROC_BIND=true — 4.78
  • GOMP_CPU_AFFINITY="0 2 4 6" physical cores only — 4.74
  • --numa distribute — 4.83
  • --numa isolate — 4.24 (−10% regression). Worth knowing about.

Skylake-SP Xeon VM (AVX2 + AVX-512F/BW/CD/DQ/VL, no VNNI, 16 cores)

Baseline on -t 4 -n 64 -r 5: 4.96 ± 0.09 t/s on 8B, 18.81 ± 0.43 t/s on 1.7B.

Tried a hand-written AVX-512BW variant of ggml_vec_dot_q1_0_q8_0 that widens the existing maddubs + madd chain from 256-bit to 512-bit by processing 2 Q8_0 blocks (K=0/1 then K=2/3) per iteration, assembled via _mm512_inserti32x8 from 256-bit loads. Compiled with GGML_AVX512=ON, verified ZMM opcodes in objdump (vpmaddubsw zmm, vpmaddwd zmm), verified identical text output vs the AVX2 binary on a fixed-seed prompt.

model AVX2 baseline AVX-512BW Δ
8B Q1_0 tg64 4.96 ± 0.09 4.98 ± 0.05 +0.4% (noise)
1.7B Q1_0 tg128 18.81 ± 0.43 17.66 ± 0.53 −6.1% (significant)

The 1.7B regression is statistically significant. Two suspected causes:

  1. Skylake-SP AVX-512 license downclock — heavy 512-bit ops (vpmaddubsw, vfmadd231ps) step the core from license 0 to license 2, dropping clock 10-20%. At 2× width with 15-20% downclock, the per-cycle win shrinks to ~1.6×.
  2. _mm512_inserti32x8 cross-lane overhead — 3-cycle latency port 5 serialisation, done twice per iteration (qy01, sm01). Eats the throughput gain.

A VNNI (vpdpbusd) path would probably unlock real speedup on Ice Lake+ by collapsing maddubs + madd into one instruction and avoiding the inserts, but Skylake-SP doesn't have it and the Comet Lake laptop doesn't either. I'm posting this as a negative result specifically so nobody else spends time trying the obvious AVX-512 widening on pre-VNNI hardware.

What I think the PR is actually done

Scott's xor+sub kernel (now the hot path here) is extracting everything the clock will give. On both test machines the bottleneck is outside the vec_dot — PL1 on laptop, parallel-sync overhead + per-layer serial work on the Xeon. Worth keeping in mind if anyone else shows up with "let me try AVX-512 / NEON SVE2 / etc" — the 2.8× theoretical ceiling from membw is there, but you have to go looking for it in the layer loop, not in the dot product.

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

Okay, @khosravipasha , so what do we need to do next? Do I just create PR for main llama.cpp or something else needs to be done? I think the code is pretty much ready and also I have acquired final benchmark numbers and tested the perplexity and test-quantize-fns once more

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

I have opened a draft ggml-org#21636, so it is visible on main repo

@khosravipasha
Copy link
Copy Markdown
Collaborator

Sounds good thanks for putting it all together.

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

So, what are the next steps?

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

I also think (due to things @Marxist-Leninist is describing) that some alternative geometry options can be explored (nrows > 1, or even more ambitious things like repack and specialized kernels) to try to aleviate suspected memory pressure at least for pp

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

I have vibe-coded small experiment (to essentially do what is already done for ARM NEON NVM, Idk where have I seen that, and add path for nrc = 2), as a result it definitely has some potential for developing further

Patch
diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c
index d31b454..7a798b2 100644
--- a/ggml/src/ggml-cpu/arch/x86/quants.c
+++ b/ggml/src/ggml-cpu/arch/x86/quants.c
@@ -557,16 +557,12 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const int nb = n / qk;
 
     assert(n % qk == 0);
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
 
     const block_q1_0 * GGML_RESTRICT x = vx;
     const block_q8_0 * GGML_RESTRICT y = vy;
 
 #if defined(__AVX2__)
+    assert((nrc == 2) || (nrc == 1));
     const __m256i ones_8 = _mm256_set1_epi8(1);
     const __m256i ones_16 = _mm256_set1_epi16(1);
     const __m256i byte_shuf = _mm256_setr_epi8(
@@ -576,6 +572,71 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
             1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128,
             1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128);
     const __m256i zero = _mm256_setzero_si256();
+
+    if (nrc == 2) {
+        const block_q1_0 * GGML_RESTRICT x0 = vx;
+        const block_q1_0 * GGML_RESTRICT x1 = (const block_q1_0 *) ((const uint8_t *) vx + bx);
+        const block_q8_0 * GGML_RESTRICT y0 = vy;
+        const block_q8_0 * GGML_RESTRICT y1 = (const block_q8_0 *) ((const uint8_t *) vy + by);
+
+        __m256 acc_00 = _mm256_setzero_ps();
+        __m256 acc_01 = _mm256_setzero_ps();
+        __m256 acc_10 = _mm256_setzero_ps();
+        __m256 acc_11 = _mm256_setzero_ps();
+
+        for (int ib = 0; ib < nb; ++ib) {
+            const float d00 = GGML_CPU_FP16_TO_FP32(x0[ib].d);
+            const float d10 = GGML_CPU_FP16_TO_FP32(x1[ib].d);
+            const uint32_t * GGML_RESTRICT qs0 = (const uint32_t *) x0[ib].qs;
+            const uint32_t * GGML_RESTRICT qs1 = (const uint32_t *) x1[ib].qs;
+            const block_q8_0 * GGML_RESTRICT y0_ptr = &y0[ib * 4];
+            const block_q8_0 * GGML_RESTRICT y1_ptr = &y1[ib * 4];
+
+            __m256 acc_block_00 = _mm256_setzero_ps();
+            __m256 acc_block_01 = _mm256_setzero_ps();
+            __m256 acc_block_10 = _mm256_setzero_ps();
+            __m256 acc_block_11 = _mm256_setzero_ps();
+
+#define Q1_AVX2_BLOCK_PAIR(K) \
+            { \
+                const __m256i sm0 = _mm256_cmpeq_epi8( \
+                        _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs0[K]), byte_shuf), bit_masks), zero); \
+                const __m256i sm1 = _mm256_cmpeq_epi8( \
+                        _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs1[K]), byte_shuf), bit_masks), zero); \
+                const __m256i qy0 = _mm256_loadu_si256((const __m256i *) y0_ptr[K].qs); \
+                const __m256i qy1 = _mm256_loadu_si256((const __m256i *) y1_ptr[K].qs); \
+                const __m256i sy00 = _mm256_sub_epi8(_mm256_xor_si256(qy0, sm0), sm0); \
+                const __m256i sy01 = _mm256_sub_epi8(_mm256_xor_si256(qy1, sm0), sm0); \
+                const __m256i sy10 = _mm256_sub_epi8(_mm256_xor_si256(qy0, sm1), sm1); \
+                const __m256i sy11 = _mm256_sub_epi8(_mm256_xor_si256(qy1, sm1), sm1); \
+                const __m256i s32_00 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy00), ones_16); \
+                const __m256i s32_01 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy01), ones_16); \
+                const __m256i s32_10 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy10), ones_16); \
+                const __m256i s32_11 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy11), ones_16); \
+                acc_block_00 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[K].d)), _mm256_cvtepi32_ps(s32_00), acc_block_00); \
+                acc_block_01 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[K].d)), _mm256_cvtepi32_ps(s32_01), acc_block_01); \
+                acc_block_10 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y0_ptr[K].d)), _mm256_cvtepi32_ps(s32_10), acc_block_10); \
+                acc_block_11 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y1_ptr[K].d)), _mm256_cvtepi32_ps(s32_11), acc_block_11); \
+            }
+            Q1_AVX2_BLOCK_PAIR(0)
+            Q1_AVX2_BLOCK_PAIR(1)
+            Q1_AVX2_BLOCK_PAIR(2)
+            Q1_AVX2_BLOCK_PAIR(3)
+#undef Q1_AVX2_BLOCK_PAIR
+
+            acc_00 = _mm256_fmadd_ps(_mm256_set1_ps(d00), acc_block_00, acc_00);
+            acc_01 = _mm256_fmadd_ps(_mm256_set1_ps(d00), acc_block_01, acc_01);
+            acc_10 = _mm256_fmadd_ps(_mm256_set1_ps(d10), acc_block_10, acc_10);
+            acc_11 = _mm256_fmadd_ps(_mm256_set1_ps(d10), acc_block_11, acc_11);
+        }
+
+        s[0] = hsum_float_8(acc_00);
+        s[1] = hsum_float_8(acc_10);
+        s[bs] = hsum_float_8(acc_01);
+        s[bs + 1] = hsum_float_8(acc_11);
+        return;
+    }
+
     __m256 acc = _mm256_setzero_ps();
 
     for (int ib = 0; ib < nb; ++ib) {
@@ -610,6 +671,10 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     *s = hsum_float_8(acc);
 #elif defined(__AVX__)
+    assert(nrc == 1);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
     const __m128i ones_8 = _mm_set1_epi8(1);
     const __m128i ones_16 = _mm_set1_epi16(1);
     const __m128i zero = _mm_setzero_si128();
@@ -648,6 +713,10 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     *s = hsum_float_8(acc);
 #elif defined(__SSSE3__)
+    assert(nrc == 1);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
     const __m128i ones_8 = _mm_set1_epi8(1);
     const __m128i ones_16 = _mm_set1_epi16(1);
     const __m128i zero = _mm_setzero_si128();
@@ -684,6 +753,10 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #else
+    assert(nrc == 1);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
     float sumf = 0.0f;
 
     for (int ib = 0; ib < nb; ++ib) {
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2b3eb5b..b3b0297 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -221,7 +221,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .from_float               = quantize_row_q1_0,
         .vec_dot                  = ggml_vec_dot_q1_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
+#if defined(__AVX2__)
+        .nrows                    = 2,
+#else
         .nrows                    = 1,
+#endif
     },
     [GGML_TYPE_Q4_0] = {
         .from_float               = quantize_row_q4_0,

Benchmark
flow run nrc=1 nrc<=2 delta
AVX2 pp512 131.03 t/s 177.07 t/s +35.14%
AVX2 tg128 73.85 t/s 73.81 t/s -0.06%
AVX-512 pp512 137.75 t/s 172.88 t/s +25.50%
AVX-512 tg128 76.91 t/s 77.13 t/s +0.28%

Main purpose there was to use activation matrix twice (which is heavier on bandwidth)

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

@khosravipasha, I would really like to know what else (if anything) should we do before we push the PR for review (remove draft status), as I feel a little bit awkward for some reason (or maybe I am just hurrying too much)?

@khosravipasha
Copy link
Copy Markdown
Collaborator

This is a good start I feel and much better than falling back to generic. I guess people can do separate PRs if they get massive improvements.

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 8, 2026

I think that since it performs significantly better than current options, I will undraft PR and then maybe open second one, as I want to add second round of branches with nrc == 2

@pl752
Copy link
Copy Markdown
Author

pl752 commented Apr 9, 2026

Have opened draft PR #21 (mostly for demonstration and discussion) related to nrows and geometry optimizations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants