@@ -52,11 +52,36 @@ NDArray_FMatmul(NDArray *a, NDArray *b) {
5252 if (NDArray_DEVICE (a ) == NDARRAY_DEVICE_GPU ) {
5353 // Perform GPU matrix multiplication
5454#ifdef HAVE_CUBLAS
55- NDArray * result_gpu = NDArray_ToGPU (result );
56- NDArray_FREE (result );
57- cuda_matmul_float (NDArray_NUMELEMENTS (a ), NDArray_FDATA (a ), NDArray_FDATA (b ), NDArray_FDATA (result_gpu ),
58- NDArray_SHAPE (a )[1 ], NDArray_SHAPE (a )[0 ], NDArray_SHAPE (b )[1 ]);
59- return result_gpu ;
55+ cublasHandle_t handle ;
56+ cublasCreate (& handle );
57+
58+ float * d_A ;
59+ float * d_B ;
60+ float * d_C ;
61+ size_t size_A = NDArray_NUMELEMENTS (a ) * sizeof (float );
62+ size_t size_B = NDArray_NUMELEMENTS (b ) * sizeof (float );
63+ size_t size_C = NDArray_NUMELEMENTS (result ) * sizeof (float );
64+
65+ cudaMalloc ((void * * )& d_A , size_A );
66+ cudaMalloc ((void * * )& d_B , size_B );
67+ cudaMalloc ((void * * )& d_C , size_C );
68+
69+ cudaMemcpy (d_A , NDArray_FDATA (a ), size_A , cudaMemcpyHostToDevice );
70+ cudaMemcpy (d_B , NDArray_FDATA (b ), size_B , cudaMemcpyHostToDevice );
71+
72+ int m = NDArray_SHAPE (a )[0 ];
73+ int n = NDArray_SHAPE (b )[1 ];
74+ int k = NDArray_SHAPE (a )[1 ];
75+ float alpha = 1.0f ;
76+ float beta = 0.0f ;
77+
78+ cublasSgemm (handle , CUBLAS_OP_N , CUBLAS_OP_N , n , m , k , & alpha , d_B , n , d_A , k , & beta , d_C , n );
79+ cudaMemcpy (NDArray_FDATA (result ), d_C , size_C , cudaMemcpyDeviceToHost );
80+
81+ cudaFree (d_A );
82+ cudaFree (d_B );
83+ cudaFree (d_C );
84+ cublasDestroy (handle );
6085#endif
6186 } else {
6287 // Perform CPU matrix multiplication
@@ -222,6 +247,7 @@ NDArray_Matmul(NDArray *a, NDArray *b) {
222247
223248 if (NDArray_SHAPE (a )[NDArray_NDIM (a ) - 1 ] != NDArray_SHAPE (b )[NDArray_NDIM (b ) - 2 ]) {
224249 zend_throw_error (NULL , "Shape mismatch for matmul. cols(a) != rows(b)" );
250+ return NULL ;
225251 }
226252
227253 if (NDArray_NDIM (a ) > 2 && NDArray_NDIM (b ) > 2 ) {