@@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
3030 -8 , 0 , -7 , 0 , -6 , 0 , -5 , 0 , -4 , 0 , -3 , 0 , -2 , 0 , -1 , 0 , 0 , 0 , 1 , 0 , 2 , 0 , 3 , 0 , 4 , 0 , 5 , 0 , 6 , 0 , 7 , 0 ,
3131};
3232
33+ // MXFP4 dequantization LUT: maps a 4-bit quant index to its fp16 kvalue
33+ // (the per-block E8M0 scale is applied separately after the lookup)
34+ // kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
34+ // Each value is padded with a zero halfword so entries occupy the even lane of
34+ // a halfword pair, matching the vlut16 layout of the q4_0/iq4_nl LUTs above.
35+ static const __fp16 mxfp4_to_fp16_lut [64 ] __attribute__((aligned (VLEN ))) = {
36+ 0 , 0 , 0.5 , 0 , 1 , 0 , 1.5 , 0 , 2 , 0 , 3 , 0 , 4 , 0 , 6 , 0 , 0 , 0 , -0.5 , 0 , -1 , 0 , -1.5 , 0 , -2 , 0 , -3 , 0 , -4 , 0 , -6 , 0 ,
37+ };
38+
3339static const __fp16 iq4_nl_to_fp16_lut [64 ] __attribute__((aligned (VLEN ))) = {
3440 -127 , 0 , -104 , 0 , -83 , 0 , -65 , 0 , -49 , 0 , -35 , 0 , -22 , 0 , -10 , 0 ,
3541 1 , 0 , 13 , 0 , 25 , 0 , 38 , 0 , 53 , 0 , 69 , 0 , 89 , 0 , 113 , 0 ,
@@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned
4652
4753// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
4854#define HMX_X4X2_SCALES_PER_BLK 8
49- #define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes
55+ #define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL)
56+ #define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4)
5057
5158static inline void swap_ptr (void * * p1 , void * * p2 ) {
5259 void * t = * p1 ;
@@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
7885 switch (weight_type ) {
7986 case HTP_TYPE_Q4_0 :
8087 case HTP_TYPE_IQ4_NL :
81- return (size_t )nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE ); // 144 * nb
88+ return (size_t ) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE ); // 144 * nb
8289 case HTP_TYPE_Q8_0 :
83- return (size_t )nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE ); // 272 * nb
90+ return (size_t ) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE ); // 272 * nb
91+ case HTP_TYPE_MXFP4 :
92+ return (size_t ) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE ); // 136 * nb
8493 default :
8594 return 0 ;
8695 }
@@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
284293 return Q6_Vhf_equals_Vqf16 (Q6_Vqf16_vmpy_VhfVhf (v_hf , v_scales ));
285294}
286295
296+ // --- MXFP4 E8M0 scale conversion and dequantization ---
297+ //
298+ // HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack.
299+ // Scalar loads from the stack array execute on the scalar pipeline, in parallel
300+ // with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop.
301+ // Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10
302+ // e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15.
303+
304+ // Eight converted fp16 scales for one x4x2 MXFP4 block. Returned by value so
304+ // the compiler can keep it in registers / on the stack frame of the caller;
304+ // 16-byte alignment allows the 16-byte vector store in mxfp4_convert_scales().
304+ typedef struct {
305+ __fp16 v [8 ] __attribute__((aligned (16 )));
306+ } mxfp4_scales_t ;
307+
308+ // Convert 8 E8M0 scale bytes (one x4x2 block) to fp16, per the formula above:
308+ // fp16_bits = clamp(e - 112, 0, 30) << 10. 112 rebiases the E8M0 exponent
308+ // (bias 127) into the fp16 exponent field (bias 15); the clamp keeps the
308+ // result a finite fp16 (exponent field 0..30), flushing e <= 112 to zero and
308+ // saturating e >= 143 at 2^15.
308+ static inline mxfp4_scales_t mxfp4_convert_scales (const uint8_t * e8m0_8 ) {
309+ mxfp4_scales_t s ;
310+ HVX_Vector v = hvx_vmemu (e8m0_8 );
310+ // Zero-extend bytes to halfwords; only the low 8 lanes are meaningful.
311+ HVX_Vector vh = Q6_V_lo_W (Q6_Wuh_vunpack_Vub (v ));
311+ // e - 112, then clamp to [0, 30].
312+ vh = Q6_Vh_vsub_VhVh (vh , Q6_Vh_vsplat_R (112 ));
313+ vh = Q6_Vh_vmax_VhVh (vh , Q6_V_vzero ());
314+ vh = Q6_Vh_vmin_VhVh (vh , Q6_Vh_vsplat_R (30 ));
314+ // Shift into the fp16 exponent field (bits 14..10); sign/mantissa stay 0,
314+ // so the result is exactly the power of two 2^(e-127).
315+ vh = Q6_Vh_vasl_VhR (vh , 10 );
315+ // Store just the 16 bytes holding the 8 fp16 scales.
316+ hvx_vec_store_u (s .v , 16 , vh );
317+ return s ;
318+ }
319+
320+ static inline HVX_Vector mxfp4_extract_splat (mxfp4_scales_t scales , int idx ) {
321+ return hvx_vec_splat_f16 (scales .v [idx ]);
322+ }
323+
324+ // Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16.
325+ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx (const uint8_t * packed_32 ,
326+ bool upper_nibbles ,
327+ int sub_blk ,
328+ const HVX_Vector vlut_cvt ,
329+ mxfp4_scales_t scales ) {
330+ HVX_Vector vq = hvx_vmemu (packed_32 );
331+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R (0x0F );
332+ HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR (vq , 4 ) : vq ;
333+ v_quants = Q6_V_vand_VV (v_quants , mask_h4 );
334+
335+ HVX_Vector v_sc = mxfp4_extract_splat (scales , sub_blk );
336+
337+ v_quants = Q6_Vb_vshuff_Vb (v_quants );
338+ HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR (v_quants , vlut_cvt , 0 );
339+ HVX_Vector v_hf = Q6_V_lo_W (vp );
340+
341+ return Q6_Vhf_equals_Vqf16 (Q6_Vqf16_vmpy_VhfVhf (v_hf , v_sc ));
342+ }
343+
344+ // Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
344+ // Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes),
344+ // producing 4 output vectors whose low halves hold groups 0..3 respectively.
344+ // A single vlut16 covers all four groups; the four per-group scales are
344+ // applied via two predicated mux vectors (two scales per 128-byte vector).
345+ static inline void dequantize_x4x2_mxfp4_x4groups_hvx (const uint8_t * packed_128 ,
346+ bool upper_nibbles ,
347+ int sub_blk_base ,
348+ const HVX_Vector vlut_cvt ,
349+ mxfp4_scales_t scales ,
350+ HVX_Vector out [4 ]) {
351+ HVX_Vector vq = hvx_vmemu (packed_128 );
352+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R (0x0F );
352+ // Isolate the requested nibble of each byte as a 4-bit LUT index.
353+ HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR (vq , 4 ) : vq ;
354+ v_quants = Q6_V_vand_VV (v_quants , mask_h4 );
355+
356+ v_quants = Q6_Vb_vshuff_Vb (v_quants );
357+
357+ // One halfword lookup for all 128 indices: lo holds groups 0/1,
357+ // hi holds groups 2/3 (64 fp16 values each).
358+ HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR (v_quants , vlut_cvt , 0 );
359+ HVX_Vector v_lo = Q6_V_lo_W (vp );
360+ HVX_Vector v_hi = Q6_V_hi_W (vp );
361+
361+ // Predicate true for the first 64 bytes: build split-scale vectors so the
361+ // two groups sharing each vector get their own scale.
362+ HVX_VectorPred q64 = Q6_Q_vsetq_R (64 );
363+ HVX_Vector v_sc01 = Q6_V_vmux_QVV (q64 , mxfp4_extract_splat (scales , sub_blk_base + 0 ),
364+ mxfp4_extract_splat (scales , sub_blk_base + 1 ));
365+ HVX_Vector v_sc23 = Q6_V_vmux_QVV (q64 , mxfp4_extract_splat (scales , sub_blk_base + 2 ),
366+ mxfp4_extract_splat (scales , sub_blk_base + 3 ));
367+
368+ v_lo = Q6_Vhf_equals_Vqf16 (Q6_Vqf16_vmpy_VhfVhf (v_lo , v_sc01 ));
369+ v_hi = Q6_Vhf_equals_Vqf16 (Q6_Vqf16_vmpy_VhfVhf (v_hi , v_sc23 ));
370+
370+ // Rotate by 64 bytes so the odd groups land in the low half of their
370+ // output vectors (callers scatter from the low half).
371+ out [0 ] = v_lo ;
372+ out [1 ] = Q6_V_vror_VR (v_lo , 64 );
373+ out [2 ] = v_hi ;
374+ out [3 ] = Q6_V_vror_VR (v_hi , 64 );
375+ }
376+
287377// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
288378// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
289379// Output: vtcm_dst in tile-major FP16 layout.
@@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
295385 int start_tile , int end_tile ) {
296386
297387 const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS ;
298- const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL );
299- const int qrow_size = is_q4 ? (k_block / 2 ) : k_block ;
388+ const int qrow_size = (weight_type == HTP_TYPE_Q8_0 ) ? k_block : (k_block / 2 );
300389
301- const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL )
302- ? hvx_vmem (iq4_nl_to_fp16_lut ) : hvx_vmem (q4_0_to_fp16_lut );
390+ const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL ) ? hvx_vmem (iq4_nl_to_fp16_lut ) :
391+ (weight_type == HTP_TYPE_MXFP4 ) ? hvx_vmem (mxfp4_to_fp16_lut ) :
392+ hvx_vmem (q4_0_to_fp16_lut );
303393
304394 // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
305395 // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
@@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
312402 int ct = t / n_k_tiles ; // column tile index
313403 int kt = t % n_k_tiles ; // K tile index
314404
315- // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
316- if (is_q4 && (kt % 4 == 0 ) && (t + 4 <= end_tile ) && ((t + 3 ) / n_k_tiles == ct )) {
405+ // --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
406+ if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL ) && (kt % 4 == 0 ) && (t + 4 <= end_tile ) &&
407+ ((t + 3 ) / n_k_tiles == ct )) {
317408 int blk_idx = (kt * 32 ) / QK_Q4_0x4x2 ;
318409 int sub_blk_base = ((kt * 32 ) % QK_Q4_0x4x2 ) / 32 ; // 0 or 4
319410 bool upper = (sub_blk_base >= 4 );
@@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
351442 continue ;
352443 }
353444
445+ // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
446+ if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0 ) && (t + 4 <= end_tile ) && ((t + 3 ) / n_k_tiles == ct )) {
447+ int blk_idx = (kt * 32 ) / QK_MXFP4x4x2 ;
448+ int sub_blk_base = ((kt * 32 ) % QK_MXFP4x4x2 ) / 32 ; // 0 or 4
449+ bool upper = (sub_blk_base >= 4 );
450+ int packed_off = blk_idx * (QK_MXFP4x4x2 / 2 ); // 128 contiguous packed bytes
451+ int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE ; // all 8 E8M0 scales
452+
453+ __fp16 * tile_bases [4 ];
454+ for (int g = 0 ; g < 4 ; g ++ ) {
455+ tile_bases [g ] = vtcm_dst + (t + g ) * HMX_FP16_TILE_N_ELMS ;
456+ }
457+
458+ HVX_Vector v_off = v_scat_base ;
459+ for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 ) {
460+ int row0 = ct * HMX_FP16_TILE_N_COLS + r ;
461+ int row1 = row0 + 1 ;
462+ const uint8_t * r0 = vtcm_src + row0 * row_stride ;
463+ const uint8_t * r1 = vtcm_src + row1 * row_stride ;
464+
465+ // Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
466+ mxfp4_scales_t r0_e8 = mxfp4_convert_scales (r0 + e8m0_blk_off );
467+
468+ HVX_Vector v0 [4 ], v1 [4 ];
469+ dequantize_x4x2_mxfp4_x4groups_hvx (r0 + packed_off , upper , sub_blk_base , vlut_cvt , r0_e8 , v0 );
470+ if (row1 < n_cols ) {
471+ mxfp4_scales_t r1_e8 = mxfp4_convert_scales (r1 + e8m0_blk_off );
472+ dequantize_x4x2_mxfp4_x4groups_hvx (r1 + packed_off , upper , sub_blk_base , vlut_cvt , r1_e8 , v1 );
473+ } else {
474+ v1 [0 ] = v1 [1 ] = v1 [2 ] = v1 [3 ] = Q6_V_vzero ();
475+ }
476+
477+ for (int g = 0 ; g < 4 ; g ++ ) {
478+ Q6_vscatter_QRMVwV (q_mask64 , (size_t ) tile_bases [g ], HMX_FP16_TILE_SIZE - 1 , v_off , v0 [g ]);
479+ }
480+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
481+ for (int g = 0 ; g < 4 ; g ++ ) {
482+ Q6_vscatter_QRMVwV (q_mask64 , (size_t ) tile_bases [g ], HMX_FP16_TILE_SIZE - 1 , v_off , v1 [g ]);
483+ }
484+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
485+ }
486+
487+ for (int g = 0 ; g < 4 ; g ++ ) {
488+ (void ) * (volatile HVX_Vector * ) (tile_bases [g ]);
489+ }
490+
491+ t += 4 ;
492+ continue ;
493+ }
494+
354495 // --- Single-tile fallback ---
355496 __fp16 * tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS ;
356497
357- if (is_q4 ) {
498+ if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL ) {
358499 int blk_idx = (kt * 32 ) / QK_Q4_0x4x2 ;
359500 int sub_blk = ((kt * 32 ) % QK_Q4_0x4x2 ) / 32 ;
360501 bool upper = (sub_blk >= 4 );
@@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
382523 v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
383524 }
384525 (void ) * (volatile HVX_Vector * )(tile_base );
526+ } else if (weight_type == HTP_TYPE_MXFP4 ) {
527+ int blk_idx = (kt * 32 ) / QK_MXFP4x4x2 ;
528+ int sub_blk = ((kt * 32 ) % QK_MXFP4x4x2 ) / 32 ;
529+ bool upper = (sub_blk >= 4 );
530+ int byte_off = blk_idx * (QK_MXFP4x4x2 / 2 ) + (upper ? (sub_blk - 4 ) : sub_blk ) * 32 ;
531+ int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE ;
532+
533+ HVX_Vector v_off = v_scat_base ;
534+ for (int r = 0 ; r < HMX_FP16_TILE_N_ROWS ; r += 2 ) {
535+ int row0 = ct * HMX_FP16_TILE_N_COLS + r ;
536+ int row1 = row0 + 1 ;
537+
538+ const uint8_t * r0 = vtcm_src + row0 * row_stride ;
539+ const uint8_t * r1 = vtcm_src + row1 * row_stride ;
540+
541+ // Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
542+ mxfp4_scales_t r0_e8 = mxfp4_convert_scales (r0 + e8m0_blk_off );
543+
544+ HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx (r0 + byte_off , upper , sub_blk , vlut_cvt , r0_e8 );
545+ HVX_Vector v1 ;
546+ if (row1 < n_cols ) {
547+ mxfp4_scales_t r1_e8 = mxfp4_convert_scales (r1 + e8m0_blk_off );
548+ v1 = dequantize_x4x2_mxfp4_group_hvx (r1 + byte_off , upper , sub_blk , vlut_cvt , r1_e8 );
549+ } else {
550+ v1 = Q6_V_vzero ();
551+ }
552+
553+ Q6_vscatter_QRMVwV (q_mask64 , (size_t ) tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v0 );
554+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
555+ Q6_vscatter_QRMVwV (q_mask64 , (size_t ) tile_base , HMX_FP16_TILE_SIZE - 1 , v_off , v1 );
556+ v_off = Q6_Vw_vadd_VwVw (v_off , v_scat_step );
557+ }
558+ (void ) * (volatile HVX_Vector * ) (tile_base );
385559 } else {
386560 // Q8_0
387561 int blk_idx = (kt * 32 ) / QK_Q8_0x4x2 ;
@@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
14551629 {
14561630 qweight_fetch_task_state_t s ;
14571631
1458- const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL );
14591632 const int blk_start = kk / QK_Q4_0x4x2 ;
14601633 const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1 ) / QK_Q4_0x4x2 ;
1461- const int full_qrow = is_q4 ? (k / 2 ) : k ;
1634+ const int full_qrow = ( weight_type == HTP_TYPE_Q8_0 ) ? k : (k / 2 );
14621635 const size_t sub_row_stride = get_x4x2_row_stride (weight_type , k_blk_sz );
1636+ const int scale_blk_size =
1637+ (weight_type == HTP_TYPE_MXFP4 ) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE ;
14631638
14641639 s .dst = vtcm_scratch0 ;
14651640 s .src = w + nc * row_stride ;
14661641 s .n_rows = n_blk_sz ;
14671642 s .src_stride = row_stride ;
14681643 s .dst_stride = sub_row_stride ;
1469- s .quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2 )) : (blk_start * QK_Q8_0x4x2 );
1470- s .quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2 )) : (nb_sub * QK_Q8_0x4x2 );
1471- s .scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE ;
1472- s .scale_width = nb_sub * HMX_X4X2_DBLK_SIZE ;
1644+ s .quant_off =
1645+ (weight_type == HTP_TYPE_Q8_0 ) ? (blk_start * QK_Q8_0x4x2 ) : (blk_start * (QK_Q4_0x4x2 / 2 ));
1646+ s .quant_width =
1647+ (weight_type == HTP_TYPE_Q8_0 ) ? (nb_sub * QK_Q8_0x4x2 ) : (nb_sub * (QK_Q4_0x4x2 / 2 ));
1648+ s .scale_off = full_qrow + blk_start * scale_blk_size ;
1649+ s .scale_width = nb_sub * scale_blk_size ;
14731650
14741651 // 2D DMA: quants sub-range
14751652 dma_queue_push (ctx -> dma [0 ], dma_make_ptr (s .dst , s .src + s .quant_off ),
0 commit comments