Skip to content

Commit b61d889

Browse files
committed
fix: cpu build
1 parent 0e41c3f commit b61d889

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "ead4414a40c594814a129adb54934720a0140c86"
7+
ENZYMEXLA_COMMIT = "245a66b57e9a3b7c23a0225d8eabfc2825761029"
88

99
ENZYMEXLA_SHA256 = ""
1010

deps/ReactantExtra/xla_ffi.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#include "absl/strings/str_format.h"
22

3-
#include "jaxlib/ffi_helpers.h"
4-
#include "jaxlib/gpu/blas_handle_pool.h"
53
#include "xla/ffi/api/c_api.h"
64
#include "xla/ffi/api/ffi.h"
75
#include "xla/ffi/ffi_api.h"
@@ -18,6 +16,9 @@ namespace cuda {
1816

1917
#if defined(REACTANT_CUDA)
2018

19+
#include "jaxlib/ffi_helpers.h"
20+
#include "jaxlib/gpu/blas_handle_pool.h"
21+
2122
#include "third_party/gpus/cuda/include/cuComplex.h"
2223
#include "third_party/gpus/cuda/include/cublas_v2.h"
2324
#include "third_party/gpus/cuda/include/cuda.h"
@@ -41,26 +42,32 @@ namespace cuda {
4142
}
4243

4344
template <typename T>
44-
T GetHostScalar(CUstream stream, bool use_attribute, double value_real,
45-
double value_imag, ffi::AnyBuffer buffer) {
46-
T host_value;
45+
ffi::Error GetHostScalar(CUstream stream, bool use_attribute, double value_real,
46+
double value_imag, ffi::AnyBuffer buffer,
47+
T *host_value) {
4748
if (use_attribute) {
4849
if constexpr (std::is_same<T, float>::value) {
49-
host_value = static_cast<float>(value_real);
50+
*host_value = static_cast<float>(value_real);
5051
} else if constexpr (std::is_same<T, double>::value) {
51-
host_value = value_real;
52+
*host_value = value_real;
5253
} else if constexpr (std::is_same<T, cuComplex>::value) {
53-
host_value = cuComplex{static_cast<float>(value_real),
54-
static_cast<float>(value_imag)};
54+
*host_value = cuComplex{static_cast<float>(value_real),
55+
static_cast<float>(value_imag)};
5556
} else if constexpr (std::is_same<T, cuDoubleComplex>::value) {
56-
host_value = cuDoubleComplex{value_real, value_imag};
57+
*host_value = cuDoubleComplex{value_real, value_imag};
5758
}
5859
} else {
60+
// Ensure buffer has exactly 1 element
61+
if (buffer.element_count() != 1) {
62+
return ffi::Error::InvalidArgument(
63+
absl::StrFormat("Expected scalar buffer with 1 element, got %d",
64+
buffer.element_count()));
65+
}
5966
// memcpy to host
60-
cudaMemcpyAsync(&host_value, buffer.untyped_data(), sizeof(T),
67+
cudaMemcpyAsync(host_value, buffer.untyped_data(), sizeof(T),
6168
cudaMemcpyDeviceToHost, stream);
6269
}
63-
return host_value;
70+
return ffi::Error::Success();
6471
}
6572

6673
inline ffi::Error CublasStatusToError(cublasStatus_t status,
@@ -185,10 +192,11 @@ ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
185192
double beta_real, double beta_imag, ffi::AnyBuffer a,
186193
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
187194
ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
188-
T host_alpha = GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
189-
alpha_imag, alpha_);
190-
T host_beta =
191-
GetHostScalar<T>(stream, use_beta_attribute, beta_real, beta_imag, beta_);
195+
T host_alpha, host_beta;
196+
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
197+
alpha_imag, alpha_, &host_alpha));
198+
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_beta_attribute, beta_real,
199+
beta_imag, beta_, &host_beta));
192200
return SyrkImpl<T>(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
193201
c_out);
194202
}

0 commit comments

Comments
 (0)