@@ -224,45 +224,60 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
224224 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_0, GGML_TYPE_F16)
225225 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_1, GGML_TYPE_F16)
226226 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_F16)
227+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_F16)
227228
228229 FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_Q4_0)
229230 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
230231 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
231232 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
232233 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
233234 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
235+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_Q4_0)
234236
235237 FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_Q4_1)
236238 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
237239 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
238240 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
239241 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
240242 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
243+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_Q4_1)
241244
242245 FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_Q5_0)
243246 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
244247 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
245248 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
246249 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
247250 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
251+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_Q5_0)
248252
249253 FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_Q5_1)
250254 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
251255 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
252256 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
253257 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
254258 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
259+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_Q5_1)
255260
256261 FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_Q8_0)
257262 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
258263 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
259264 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
260265 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
261266 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
267+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_Q8_0)
268+
269+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_BF16)
270+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_0, GGML_TYPE_BF16)
271+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_1, GGML_TYPE_BF16)
272+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_0, GGML_TYPE_BF16)
273+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q5_1, GGML_TYPE_BF16)
274+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_BF16)
275+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_BF16)
262276#else
263277 FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_F16)
264278 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
265279 FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
280+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_BF16)
266281#endif // GGML_CUDA_FA_ALL_QUANTS
267282
268283 GGML_ABORT (" fatal error" );
@@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
355370#endif // GGML_CUDA_FA_ALL_QUANTS
356371 case GGML_TYPE_Q4_0:
357372 case GGML_TYPE_Q8_0:
373+ case GGML_TYPE_BF16:
358374 break ;
359375 default :
360376 return BEST_FATTN_KERNEL_NONE;
0 commit comments