Skip to content

Commit 539f01b

Browse files
authored
Merge pull request #76 from ROCm/upstream_fix
update for hipblasVersionMajor >=3
2 parents 47ac97d + e119ff7 commit 539f01b

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

csrc/ops.hip

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,23 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
269269
const void * beta = &fbeta;
270270
hipblasStatus_t status;
271271

272+
#if hipblasVersionMajor >= 3
273+
status = hipblasGemmEx(context->m_handle,
274+
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
275+
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
276+
m, n, k,
277+
alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta,
278+
C, HIP_R_32I, ldc,
279+
HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
280+
#else
272281
status = hipblasGemmEx(context->m_handle,
273282
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
274283
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
275284
m, n, k,
276285
alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
277286
C, HIPBLAS_R_32I, ldc,
278287
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
288+
#endif
279289

280290
if (status != HIPBLAS_STATUS_SUCCESS)
281291
{
@@ -299,13 +309,23 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
299309
//printf("%i %i %i\n", strideA, strideB, strideC);
300310
//printf("%i\n", batchCount);
301311

312+
#if hipblasVersionMajor >= 3
313+
status = hipblasGemmStridedBatchedEx(context->m_handle,
314+
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
315+
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
316+
m, n, k,
317+
alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta,
318+
C, HIP_R_32I, ldc, (long long int)strideC, batchCount,
319+
HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
320+
#else
302321
status = hipblasGemmStridedBatchedEx(context->m_handle,
303322
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
304323
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
305324
m, n, k,
306325
alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta,
307326
C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount,
308327
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
328+
#endif
309329

310330
if (status != HIPBLAS_STATUS_SUCCESS)
311331
{

0 commit comments

Comments
 (0)