Skip to content

Commit ee051c1

Browse files
authored
hexagon: support for IQ4_NL and MXFP4 (ggml-org#21018)
* ggml-hexagon: add IQ4_NL and MXFP4 HMX matmul support - Add IQ4_NL quantization type support to Hexagon backend (buffer set/get tensor repack, mul_mat, mul_mat_id dispatch) - Implement HVX IQ4_NL vec_dot kernels (1x1, 2x1, 2x2) with LUT-based 4-bit index to int8 kvalue dequantization - Add MXFP4 HMX dequantization path with E8M0 scale conversion, including batch-4 fast path and single-tile fallback - Unify quantized row size / scale offset logic to handle Q4_0, Q8_0, IQ4_NL, and MXFP4 in the DMA fetch path * ggml-hexagon: fix SKIP_QUANTIZE src1 address mismatch in mixed-quant models * Fix the pragma indent
1 parent e6f6770 commit ee051c1

File tree

5 files changed

+619
-23
lines changed

5 files changed

+619
-23
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 36 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1406,6 +1406,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
14061406
repack_q8_0_q8x4x2(tensor, data, size);
14071407
break;
14081408

1409+
case GGML_TYPE_IQ4_NL:
1410+
GGML_ASSERT(offset == 0);
1411+
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1412+
// IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16])
1413+
repack_q4_0_q4x4x2(tensor, data, size);
1414+
break;
1415+
14091416
case GGML_TYPE_MXFP4:
14101417
GGML_ASSERT(offset == 0);
14111418
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
@@ -1442,6 +1449,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
14421449
repack_q8x4x2_q8_0(data, tensor, size);
14431450
break;
14441451

1452+
case GGML_TYPE_IQ4_NL:
1453+
GGML_ASSERT(offset == 0);
1454+
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1455+
repack_q4x4x2_q4_0(data, tensor, size);
1456+
break;
1457+
14451458
case GGML_TYPE_MXFP4:
14461459
GGML_ASSERT(offset == 0);
14471460
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
@@ -1819,6 +1832,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
18191832
switch (src0->type) {
18201833
case GGML_TYPE_Q4_0:
18211834
case GGML_TYPE_Q8_0:
1835+
case GGML_TYPE_IQ4_NL:
18221836
case GGML_TYPE_MXFP4:
18231837
if (src0->ne[0] % 32) {
18241838
return false;
@@ -1868,6 +1882,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
18681882
switch (src0->type) {
18691883
case GGML_TYPE_Q4_0:
18701884
case GGML_TYPE_Q8_0:
1885+
case GGML_TYPE_IQ4_NL:
18711886
case GGML_TYPE_MXFP4:
18721887
if ((src0->ne[0] % 32)) {
18731888
return false;
@@ -2596,8 +2611,26 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
25962611
delete backend;
25972612
}
25982613

2614+
// Map weight type to its activation quantization family.
2615+
// Types in the same family produce identical Q8 formats in VTCM and can
2616+
// safely share quantized activation data via SKIP_QUANTIZE.
2617+
// When adding a new quantized type, assign it the correct family here.
2618+
static inline int act_quant_family(enum ggml_type wtype) {
2619+
switch (wtype) {
2620+
case GGML_TYPE_Q4_0:
2621+
case GGML_TYPE_Q8_0:
2622+
case GGML_TYPE_IQ4_NL:
2623+
case GGML_TYPE_MXFP4:
2624+
return 1; // Q8x4x2
2625+
default:
2626+
return 0; // unknown / not quantized
2627+
}
2628+
}
2629+
25992630
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
2600-
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
2631+
return (op0 && op0->src[1] == op1->src[1] &&
2632+
act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) &&
2633+
act_quant_family(op0->src[0]->type) != 0);
26012634
}
26022635

26032636
static inline bool is_compute_op(ggml_tensor *node)
@@ -3364,6 +3397,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
33643397
"please update hexagon_type to match ggml_type");
33653398
static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
33663399
"please update hexagon_type to match ggml_type");
3400+
static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL,
3401+
"please update hexagon_type to match ggml_type");
33673402

33683403
const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
33693404
const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");

ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c

Lines changed: 193 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
3030
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
3131
};
3232

33+
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
34+
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
35+
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
36+
0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0,
37+
};
38+
3339
static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
3440
-127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
3541
1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0,
@@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned
4652

4753
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
4854
#define HMX_X4X2_SCALES_PER_BLK 8
49-
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes
55+
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL)
56+
#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4)
5057

5158
static inline void swap_ptr(void **p1, void **p2) {
5259
void *t = *p1;
@@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
7885
switch (weight_type) {
7986
case HTP_TYPE_Q4_0:
8087
case HTP_TYPE_IQ4_NL:
81-
return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
88+
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
8289
case HTP_TYPE_Q8_0:
83-
return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
90+
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
91+
case HTP_TYPE_MXFP4:
92+
return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb
8493
default:
8594
return 0;
8695
}
@@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
284293
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
285294
}
286295

296+
// --- MXFP4 E8M0 scale conversion and dequantization ---
297+
//
298+
// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack.
299+
// Scalar loads from the stack array execute on the scalar pipeline, in parallel
300+
// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop.
301+
// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10
302+
// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15.
303+
304+
typedef struct {
305+
__fp16 v[8] __attribute__((aligned(16)));
306+
} mxfp4_scales_t;
307+
308+
static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) {
309+
mxfp4_scales_t s;
310+
HVX_Vector v = hvx_vmemu(e8m0_8);
311+
HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v));
312+
vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112));
313+
vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero());
314+
vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30));
315+
vh = Q6_Vh_vasl_VhR(vh, 10);
316+
hvx_vec_store_u(s.v, 16, vh);
317+
return s;
318+
}
319+
320+
static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) {
321+
return hvx_vec_splat_f16(scales.v[idx]);
322+
}
323+
324+
// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16.
325+
static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32,
326+
bool upper_nibbles,
327+
int sub_blk,
328+
const HVX_Vector vlut_cvt,
329+
mxfp4_scales_t scales) {
330+
HVX_Vector vq = hvx_vmemu(packed_32);
331+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
332+
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
333+
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
334+
335+
HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk);
336+
337+
v_quants = Q6_Vb_vshuff_Vb(v_quants);
338+
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
339+
HVX_Vector v_hf = Q6_V_lo_W(vp);
340+
341+
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc));
342+
}
343+
344+
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
345+
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
346+
bool upper_nibbles,
347+
int sub_blk_base,
348+
const HVX_Vector vlut_cvt,
349+
mxfp4_scales_t scales,
350+
HVX_Vector out[4]) {
351+
HVX_Vector vq = hvx_vmemu(packed_128);
352+
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
353+
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
354+
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
355+
356+
v_quants = Q6_Vb_vshuff_Vb(v_quants);
357+
358+
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
359+
HVX_Vector v_lo = Q6_V_lo_W(vp);
360+
HVX_Vector v_hi = Q6_V_hi_W(vp);
361+
362+
HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
363+
HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0),
364+
mxfp4_extract_splat(scales, sub_blk_base + 1));
365+
HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2),
366+
mxfp4_extract_splat(scales, sub_blk_base + 3));
367+
368+
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
369+
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
370+
371+
out[0] = v_lo;
372+
out[1] = Q6_V_vror_VR(v_lo, 64);
373+
out[2] = v_hi;
374+
out[3] = Q6_V_vror_VR(v_hi, 64);
375+
}
376+
287377
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
288378
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
289379
// Output: vtcm_dst in tile-major FP16 layout.
@@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
295385
int start_tile, int end_tile) {
296386

297387
const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
298-
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
299-
const int qrow_size = is_q4 ? (k_block / 2) : k_block;
388+
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
300389

301-
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL)
302-
? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut);
390+
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
391+
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
392+
hvx_vmem(q4_0_to_fp16_lut);
303393

304394
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
305395
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
@@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
312402
int ct = t / n_k_tiles; // column tile index
313403
int kt = t % n_k_tiles; // K tile index
314404

315-
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
316-
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
405+
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
406+
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
407+
((t + 3) / n_k_tiles == ct)) {
317408
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
318409
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
319410
bool upper = (sub_blk_base >= 4);
@@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
351442
continue;
352443
}
353444

445+
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
446+
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
447+
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
448+
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
449+
bool upper = (sub_blk_base >= 4);
450+
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
451+
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
452+
453+
__fp16 * tile_bases[4];
454+
for (int g = 0; g < 4; g++) {
455+
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
456+
}
457+
458+
HVX_Vector v_off = v_scat_base;
459+
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
460+
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
461+
int row1 = row0 + 1;
462+
const uint8_t * r0 = vtcm_src + row0 * row_stride;
463+
const uint8_t * r1 = vtcm_src + row1 * row_stride;
464+
465+
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
466+
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
467+
468+
HVX_Vector v0[4], v1[4];
469+
dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
470+
if (row1 < n_cols) {
471+
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
472+
dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
473+
} else {
474+
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
475+
}
476+
477+
for (int g = 0; g < 4; g++) {
478+
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
479+
}
480+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
481+
for (int g = 0; g < 4; g++) {
482+
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
483+
}
484+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
485+
}
486+
487+
for (int g = 0; g < 4; g++) {
488+
(void) *(volatile HVX_Vector *) (tile_bases[g]);
489+
}
490+
491+
t += 4;
492+
continue;
493+
}
494+
354495
// --- Single-tile fallback ---
355496
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
356497

357-
if (is_q4) {
498+
if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
358499
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
359500
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
360501
bool upper = (sub_blk >= 4);
@@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
382523
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
383524
}
384525
(void) *(volatile HVX_Vector *)(tile_base);
526+
} else if (weight_type == HTP_TYPE_MXFP4) {
527+
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
528+
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
529+
bool upper = (sub_blk >= 4);
530+
int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
531+
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
532+
533+
HVX_Vector v_off = v_scat_base;
534+
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
535+
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
536+
int row1 = row0 + 1;
537+
538+
const uint8_t * r0 = vtcm_src + row0 * row_stride;
539+
const uint8_t * r1 = vtcm_src + row1 * row_stride;
540+
541+
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
542+
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
543+
544+
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
545+
HVX_Vector v1;
546+
if (row1 < n_cols) {
547+
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
548+
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
549+
} else {
550+
v1 = Q6_V_vzero();
551+
}
552+
553+
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
554+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
555+
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
556+
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
557+
}
558+
(void) *(volatile HVX_Vector *) (tile_base);
385559
} else {
386560
// Q8_0
387561
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
@@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
14551629
{
14561630
qweight_fetch_task_state_t s;
14571631

1458-
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
14591632
const int blk_start = kk / QK_Q4_0x4x2;
14601633
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
1461-
const int full_qrow = is_q4 ? (k / 2) : k;
1634+
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
14621635
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
1636+
const int scale_blk_size =
1637+
(weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
14631638

14641639
s.dst = vtcm_scratch0;
14651640
s.src = w + nc * row_stride;
14661641
s.n_rows = n_blk_sz;
14671642
s.src_stride = row_stride;
14681643
s.dst_stride = sub_row_stride;
1469-
s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2);
1470-
s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2);
1471-
s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE;
1472-
s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE;
1644+
s.quant_off =
1645+
(weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
1646+
s.quant_width =
1647+
(weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
1648+
s.scale_off = full_qrow + blk_start * scale_blk_size;
1649+
s.scale_width = nb_sub * scale_blk_size;
14731650

14741651
// 2D DMA: quants sub-range
14751652
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),

ggml/src/ggml-hexagon/htp/htp-ctx.h

Lines changed: 6 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -31,6 +31,12 @@ struct htp_context {
3131

3232
uint32_t opmask;
3333

34+
// Cached src1 spad position from the last quantize pass.
35+
// When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM
36+
// at this address; the matmul must read from here instead of recomputing
37+
// the offset (which depends on the current op's src0 size).
38+
uint8_t * prev_src1_spad;
39+
3440
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
3541
#ifdef HTP_HAS_HMX
3642
int hmx_enabled; // Runtime flag: HMX initialisation succeeded

ggml/src/ggml-hexagon/htp/main.c

Lines changed: 4 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -1114,14 +1114,12 @@ static void proc_hmx_matmul_req(struct htp_context * ctx,
11141114
return;
11151115
}
11161116

1117-
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
1118-
// Other types (e.g. MXFP4) fall back to HVX.
1117+
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
1118+
// Other types fall back to HVX.
11191119
{
11201120
uint32_t wtype = req->src0.type;
1121-
if (wtype != HTP_TYPE_F16 &&
1122-
wtype != HTP_TYPE_Q4_0 &&
1123-
wtype != HTP_TYPE_Q8_0 &&
1124-
wtype != HTP_TYPE_IQ4_NL) {
1121+
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL &&
1122+
wtype != HTP_TYPE_MXFP4) {
11251123
proc_matmul_req(ctx, req, bufs, n_bufs);
11261124
return;
11271125
}

0 commit comments

Comments (0)