Skip to content

Commit 8161eab

Browse files
hexfa: fixed test-backend-ops failurs due to leftover element handling
1 parent 04db2cf commit 8161eab

1 file changed

Lines changed: 13 additions & 17 deletions

File tree

ggml/src/ggml-hexagon/htp/flash-attn-ops.c

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
2525
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
2626
uint32_t nloe = n % VLEN_FP16; // leftover elements
2727

28-
const HVX_Vector zero = Q6_V_vsplat_R(0);
29-
HVX_Vector rsum = Q6_V_vsplat_R(0);
28+
HVX_Vector rsum = Q6_V_vsplat_R(0);
3029

3130
uint32_t i = 0;
3231

@@ -41,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
4140
}
4241

4342
if (nloe) {
44-
HVX_Vector y_hf = vy[i];
45-
4643
// Load x (fp16) and zero-out unused elements
4744
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
48-
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
45+
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
46+
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
4947

5048
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
5149

@@ -66,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
6664
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
6765
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
6866

69-
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
70-
uint32_t nloe = n % VLEN_FP16; // leftover elements
67+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
68+
uint32_t nloe = n % VLEN_FP16; // leftover elements
7169

72-
const HVX_Vector zero = Q6_V_vsplat_R(0);
73-
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
74-
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
70+
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
71+
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
7572

7673
uint32_t i = 0;
7774

@@ -89,11 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
8986
}
9087

9188
if (nloe) {
92-
// Load x (fp16) and zero-out unused y elements
89+
// Load x (fp16) and zero-out unused elements
9390
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
94-
HVX_Vector x0_hf = vx0[i];
95-
HVX_Vector x1_hf = vx1[i];
96-
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
91+
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
92+
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
93+
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
9794

9895
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
9996
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
@@ -180,12 +177,11 @@ static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
180177

181178
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
182179
HVX_Vector xs = xs_p_lo;
183-
i = 2 * i; // index for ptr_y
180+
i = 2 * i; // index for ptr_y
184181

185182
if (nloe >= 32) {
186183
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
187-
nloe -= 32;
188-
++i;
184+
nloe -= 32; ++i;
189185
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
190186
}
191187

0 commit comments

Comments
 (0)