11#include " absl/strings/str_format.h"
22
3- #include " jaxlib/ffi_helpers.h"
4- #include " jaxlib/gpu/blas_handle_pool.h"
53#include " xla/ffi/api/c_api.h"
64#include " xla/ffi/api/ffi.h"
75#include " xla/ffi/ffi_api.h"
@@ -18,6 +16,9 @@ namespace cuda {
1816
1917#if defined(REACTANT_CUDA)
2018
19+ #include " jaxlib/ffi_helpers.h"
20+ #include " jaxlib/gpu/blas_handle_pool.h"
21+
2122#include " third_party/gpus/cuda/include/cuComplex.h"
2223#include " third_party/gpus/cuda/include/cublas_v2.h"
2324#include " third_party/gpus/cuda/include/cuda.h"
@@ -41,26 +42,32 @@ namespace cuda {
4142 }
4243
4344template <typename T>
44- T GetHostScalar (CUstream stream, bool use_attribute, double value_real,
45- double value_imag, ffi::AnyBuffer buffer) {
46- T host_value;
45+ ffi::Error GetHostScalar (CUstream stream, bool use_attribute, double value_real,
46+ double value_imag, ffi::AnyBuffer buffer,
47+ T * host_value) {
4748 if (use_attribute) {
4849 if constexpr (std::is_same<T, float >::value) {
49- host_value = static_cast <float >(value_real);
50+ * host_value = static_cast <float >(value_real);
5051 } else if constexpr (std::is_same<T, double >::value) {
51- host_value = value_real;
52+ * host_value = value_real;
5253 } else if constexpr (std::is_same<T, cuComplex>::value) {
53- host_value = cuComplex{static_cast <float >(value_real),
54- static_cast <float >(value_imag)};
54+ * host_value = cuComplex{static_cast <float >(value_real),
55+ static_cast <float >(value_imag)};
5556 } else if constexpr (std::is_same<T, cuDoubleComplex>::value) {
56- host_value = cuDoubleComplex{value_real, value_imag};
57+ * host_value = cuDoubleComplex{value_real, value_imag};
5758 }
5859 } else {
60+ // Ensure buffer has exactly 1 element
61+ if (buffer.element_count () != 1 ) {
62+ return ffi::Error::InvalidArgument (
63+ absl::StrFormat (" Expected scalar buffer with 1 element, got %d" ,
64+ buffer.element_count ()));
65+ }
5966 // memcpy to host
60- cudaMemcpyAsync (& host_value, buffer.untyped_data (), sizeof (T),
67+ cudaMemcpyAsync (host_value, buffer.untyped_data (), sizeof (T),
6168 cudaMemcpyDeviceToHost, stream);
6269 }
63- return host_value ;
70+ return ffi::Error::Success () ;
6471}
6572
6673inline ffi::Error CublasStatusToError (cublasStatus_t status,
@@ -185,10 +192,11 @@ ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
185192 double beta_real, double beta_imag, ffi::AnyBuffer a,
186193 ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
187194 ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
188- T host_alpha = GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
189- alpha_imag, alpha_);
190- T host_beta =
191- GetHostScalar<T>(stream, use_beta_attribute, beta_real, beta_imag, beta_);
195+ T host_alpha, host_beta;
196+ FFI_RETURN_IF_ERROR (GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
197+ alpha_imag, alpha_, &host_alpha));
198+ FFI_RETURN_IF_ERROR (GetHostScalar<T>(stream, use_beta_attribute, beta_real,
199+ beta_imag, beta_, &host_beta));
192200 return SyrkImpl<T>(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
193201 c_out);
194202}
0 commit comments