@@ -269,13 +269,23 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
269269 const void * beta = &fbeta;
270270 hipblasStatus_t status;
271271
272+ #if hipblasVersionMajor >= 3
273+ status = hipblasGemmEx (context->m_handle ,
274+ transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
275+ transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
276+ m, n, k,
277+ alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta,
278+ C, HIP_R_32I, ldc,
279+ HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
280+ #else
272281 status = hipblasGemmEx (context->m_handle ,
273282 transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
274283 transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
275284 m, n, k,
276285 alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
277286 C, HIPBLAS_R_32I, ldc,
278287 HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
288+ #endif
279289
280290 if (status != HIPBLAS_STATUS_SUCCESS)
281291 {
@@ -299,13 +309,23 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
299309 // printf("%i %i %i\n", strideA, strideB, strideC);
300310 // printf("%i\n", batchCount);
301311
312+ #if hipblasVersionMajor >= 3
313+ status = hipblasGemmStridedBatchedEx (context->m_handle ,
314+ transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
315+ transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
316+ m, n, k,
317+ alpha, A, HIP_R_8I, lda, (long long int )strideA, B, HIP_R_8I, ldb, (long long int )strideB, beta,
318+ C, HIP_R_32I, ldc, (long long int )strideC, batchCount,
319+ HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
320+ #else
302321 status = hipblasGemmStridedBatchedEx (context->m_handle ,
303322 transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
304323 transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
305324 m, n, k,
306325 alpha, A, HIPBLAS_R_8I, lda, (long long int )strideA, B, HIPBLAS_R_8I, ldb, (long long int )strideB, beta,
307326 C, HIPBLAS_R_32I, ldc, (long long int )strideC, batchCount,
308327 HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
328+ #endif
309329
310330 if (status != HIPBLAS_STATUS_SUCCESS)
311331 {
0 commit comments