
Commit 7fd223f

feat: cuda syrk ffi
1 parent e055527

3 files changed: +262 -10 lines

deps/ReactantExtra/API.cpp

Lines changed: 11 additions & 10 deletions
@@ -1005,9 +1005,9 @@ REACTANT_ABI PjRtBuffer *ArrayFromHostBuffer(PjRtClient *client, void *data,
   return bres;
 }
 
-
 REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
-                               void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
+                               void *data, size_t offset, size_t size,
+                               PjRtBuffer **bufferP) {
   if (buffer->IsOnCpu()) {
     auto unsafe =
         (char *)MyValueOrThrow(buffer->client()->UnsafeBufferPointer(buffer));
@@ -1016,12 +1016,13 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
   // data, size);
     return;
   }
-
+
   auto pid = client->platform_id();
   if (pid == xla::TpuId()) {
     auto dims = buffer->on_device_shape().dimensions();
     // TODO: note this assume that we want to copy the entire buffer size.
-    auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(), dims.size(), dims.data(), buffer->device());
+    auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(),
+                                    dims.size(), dims.data(), buffer->device());
     *bufferP = buf2;
     PjRtBufferFree((PjRtBuffer *)buffer);
     return;
@@ -1075,9 +1076,9 @@ REACTANT_ABI void BufferToHost(PjRtBuffer *buffer, void *data) {
   }
 }
 
-
 REACTANT_ABI void CopyFromBuffer(PjRtClient *client, PjRtBuffer *buffer,
-                                 void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
+                                 void *data, size_t offset, size_t size,
+                                 PjRtBuffer **bufferP) {
 
   auto pid = client->platform_id();
   if (pid == xla::TpuId()) {
@@ -3069,7 +3070,7 @@ struct LinkableRuntime {
       executables;
 
   // Set of allocated pointers to size
-  std::set<void *, std::greater<void*>> allocations;
+  std::set<void *, std::greater<void *>> allocations;
 
   LinkableRuntime(const std::string &backend) : registry() {
     InitializeRegistry(wrap(&registry));
@@ -3217,7 +3218,8 @@ REACTANT_ABI void reactantXLAExec(LinkableRuntime **__restrict__ lrtP,
   for (int64_t i = 0; i < argcnt; i++) {
     auto &&[argB, argO, argP] = bufferAndOffset(lrt, args[i]);
     if (argO != 0) {
-      llvm::errs() << "only zero-offset execution supported, argument " << i << " had byte offset of " << argO << "\n";
+      llvm::errs() << "only zero-offset execution supported, argument " << i
+                   << " had byte offset of " << argO << "\n";
       exit(1);
     }
     baseArrays[i] = argB;
@@ -3443,8 +3445,7 @@ class GPUPerformanceModel {
       fusion_analysis_cache_(device_description_),
       gpu_hlo_cost_analysis_(hlo_cost_analysis_options_, device_description_),
       gpu_performance_model_(device_description_, fusion_analysis_cache_,
-                             gpu_performance_model_cache_,
-                             mlir_context_) {}
+                             gpu_performance_model_cache_, mlir_context_) {}
 
   void RunAnalysisOnHloModule(std::shared_ptr<xla::HloModule> hlo_module) {
     hlo_module->entry_computation()->Accept(&gpu_hlo_cost_analysis_);
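
The allocations member added to LinkableRuntime above is ordered with std::greater<void *>, so the set iterates from highest to lowest base address; lower_bound on such a set returns the largest base pointer that is not above the query pointer, which is what is needed to map an interior pointer back to its owning allocation (as bufferAndOffset does in reactantXLAExec). A minimal standalone sketch of that lookup pattern; the findAllocation helper is illustrative and not part of this commit:

#include <cstdint>
#include <functional>
#include <set>
#include <utility>

// Base pointers of live allocations, ordered from highest to lowest address.
std::set<void *, std::greater<void *>> allocations;

// Map any pointer inside an allocation to {base pointer, byte offset}.
// With std::greater, lower_bound(p) yields the first element e with e <= p,
// i.e. the largest base address not above p.
std::pair<void *, std::int64_t> findAllocation(void *p) {
  auto it = allocations.lower_bound(p);
  if (it == allocations.end())
    return {nullptr, 0};
  return {*it, static_cast<char *>(p) - static_cast<char *>(*it)};
}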

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
@@ -1059,6 +1059,7 @@ cc_library(
             "-Wl,-exported_symbol,_CreateGPUPerformanceModel",
             "-Wl,-exported_symbol,_RunAnalysisOnHloModule",
             "-Wl,-exported_symbol,_EstimateRunTimeForInstruction",
+            "-Wl,-exported_symbol,_registerReactantXLAFFI",
         ],
     }),
     linkstatic = True,
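
The extra exported_symbol entry makes registerReactantXLAFFI visible to whatever host runtime loads the shared library, so the handler can be registered before any program using the custom call is compiled. A hedged sketch of that consumer side via dlopen/dlsym; the library name below is an assumption, not necessarily the artifact this BUILD file produces:

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Hypothetical library name; substitute the actual ReactantExtra artifact.
  void *lib = dlopen("libReactantExtra.so", RTLD_NOW | RTLD_GLOBAL);
  if (!lib) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // The exported_symbol entries carry the Mach-O leading underscore; dlsym
  // takes the plain C name.
  auto registerFfi =
      reinterpret_cast<void (*)()>(dlsym(lib, "registerReactantXLAFFI"));
  if (registerFfi)
    registerFfi(); // registers "reactant_cublas_syrk_ffi" on CUDA builds
  return 0;
}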

deps/ReactantExtra/xla_ffi.cpp

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
#include "absl/strings/str_format.h"

#include "jaxlib/ffi_helpers.h"
#include "jaxlib/gpu/blas_handle_pool.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

#include "mlir/CAPI/IR.h"

#define REACTANT_ABI extern "C" MLIR_CAPI_EXPORTED

using namespace jax;
using namespace xla;

namespace reactant {
namespace cuda {

#if defined(REACTANT_CUDA)

#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_fp8.h"
#include "third_party/gpus/cuda/include/cufft.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusolver_common.h"

#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)                                   \
  switch (dataType) {                                                          \
  case ffi::F32:                                                               \
    return impl<float>(__VA_ARGS__);                                           \
  case ffi::F64:                                                               \
    return impl<double>(__VA_ARGS__);                                          \
  case ffi::C64:                                                               \
    return impl<cuComplex>(__VA_ARGS__);                                       \
  case ffi::C128:                                                              \
    return impl<cuDoubleComplex>(__VA_ARGS__);                                 \
  default:                                                                     \
    break;                                                                     \
  }

template <typename T>
T GetHostScalar(CUstream stream, bool use_attribute, double value_real,
                double value_imag, ffi::AnyBuffer buffer) {
  T host_value;
  if (use_attribute) {
    if constexpr (std::is_same<T, float>::value) {
      host_value = static_cast<float>(value_real);
    } else if constexpr (std::is_same<T, double>::value) {
      host_value = value_real;
    } else if constexpr (std::is_same<T, cuComplex>::value) {
      host_value = cuComplex{static_cast<float>(value_real),
                             static_cast<float>(value_imag)};
    } else if constexpr (std::is_same<T, cuDoubleComplex>::value) {
      host_value = cuDoubleComplex{value_real, value_imag};
    }
  } else {
    // memcpy to host
    cudaMemcpyAsync(&host_value, buffer.untyped_data(), sizeof(T),
                    cudaMemcpyDeviceToHost, stream);
  }
  return host_value;
}

inline ffi::Error CublasStatusToError(cublasStatus_t status,
                                      const char *op_name) {
  if (status == CUBLAS_STATUS_SUCCESS) {
    return ffi::Error::Success();
  }
  const char *error_name;
  switch (status) {
  case CUBLAS_STATUS_NOT_INITIALIZED:
    error_name = "CUBLAS_STATUS_NOT_INITIALIZED";
    break;
  case CUBLAS_STATUS_ALLOC_FAILED:
    error_name = "CUBLAS_STATUS_ALLOC_FAILED";
    break;
  case CUBLAS_STATUS_INVALID_VALUE:
    error_name = "CUBLAS_STATUS_INVALID_VALUE";
    break;
  case CUBLAS_STATUS_ARCH_MISMATCH:
    error_name = "CUBLAS_STATUS_ARCH_MISMATCH";
    break;
  case CUBLAS_STATUS_MAPPING_ERROR:
    error_name = "CUBLAS_STATUS_MAPPING_ERROR";
    break;
  case CUBLAS_STATUS_EXECUTION_FAILED:
    error_name = "CUBLAS_STATUS_EXECUTION_FAILED";
    break;
  case CUBLAS_STATUS_INTERNAL_ERROR:
    error_name = "CUBLAS_STATUS_INTERNAL_ERROR";
    break;
  case CUBLAS_STATUS_NOT_SUPPORTED:
    error_name = "CUBLAS_STATUS_NOT_SUPPORTED";
    break;
  default:
    error_name = "UNKNOWN";
    break;
  }
  return ffi::Error::InvalidArgument(
      absl::StrFormat("%s failed with status %s", op_name, error_name));
}

namespace blas {

template <typename T>
ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo,
                cublasOperation_t trans, int n, int k, const T *alpha,
                const T *a, int lda, const T *beta, T *c, int ldc) {
  return ffi::Error::InvalidArgument("Unsupported type for syrk");
}

#define SYRK_SPECIALIZATION(T, cublas_func)                                    \
  template <>                                                                  \
  ffi::Error Syrk<T>(cublasHandle_t handle, cublasFillMode_t uplo,             \
                     cublasOperation_t trans, int n, int k, const T *alpha,    \
                     const T *a, int lda, const T *beta, T *c, int ldc) {      \
    cublasStatus_t status =                                                    \
        cublas_func(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc);   \
    return CublasStatusToError(status, #cublas_func);                          \
  }

SYRK_SPECIALIZATION(float, cublasSsyrk)
SYRK_SPECIALIZATION(double, cublasDsyrk)
SYRK_SPECIALIZATION(cuComplex, cublasCsyrk)
SYRK_SPECIALIZATION(cuDoubleComplex, cublasZsyrk)

#undef SYRK_SPECIALIZATION

} // namespace blas

// Symmetric rank-k update: syrk

template <typename T>
ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo_,
                    ffi::AnyBuffer a, ffi::AnyBuffer c_in, const T *alpha,
                    const T *beta, ffi::Result<ffi::AnyBuffer> c_out) {
  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(a.dimensions()));
  auto size = transpose ? cols : rows;
  FFI_RETURN_IF_ERROR(
      CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk"));
  FFI_RETURN_IF_ERROR(
      CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk"));

  FFI_ASSIGN_OR_RETURN(auto n,
                       MaybeCastNoOverflow<int>(transpose ? cols : rows));
  FFI_ASSIGN_OR_RETURN(auto k,
                       MaybeCastNoOverflow<int>(transpose ? rows : cols));
  cublasFillMode_t uplo =
      uplo_ ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
  cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;

  const T *a_data = static_cast<const T *>(a.untyped_data());
  T *c_data = static_cast<T *>(c_in.untyped_data());
  T *c_out_data = static_cast<T *>(c_out->untyped_data());

  if (c_data != c_out_data) {
    cudaError_t err = cudaMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
                                      cudaMemcpyDeviceToDevice, stream);
    if (err != cudaSuccess) {
      return ffi::Error::InvalidArgument(absl::StrFormat(
          "cudaMemcpyAsync failed: %s", cudaGetErrorString(err)));
    }
  }
  FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
  // lda is the leading dimension of a, ldc is the leading dimension of c
  // For column-major (which cuBLAS uses), lda = number of rows of a, ldc = n
  int lda = transpose ? k : n;
  int ldc = n;
  for (int i = 0; i < batch; ++i) {
    FFI_RETURN_IF_ERROR(blas::Syrk<T>(handle.get(), uplo, trans, n, k, alpha,
                                      a_data, lda, beta, c_out_data, ldc));
    a_data += k * n;
    c_out_data += n * n;
  }
  return ffi::Error::Success();
}

template <typename T>
ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
                    bool use_alpha_attribute, double alpha_real,
                    double alpha_imag, bool use_beta_attribute,
                    double beta_real, double beta_imag, ffi::AnyBuffer a,
                    ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
                    ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
  T host_alpha = GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
                                  alpha_imag, alpha_);
  T host_beta =
      GetHostScalar<T>(stream, use_beta_attribute, beta_real, beta_imag, beta_);
  return SyrkImpl<T>(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
                     c_out);
}

ffi::Error SyrkDispatch(CUstream stream, bool transpose, bool uplo,
                        bool use_alpha_attribute, double alpha_real,
                        double alpha_imag, bool use_beta_attribute,
                        double beta_real, double beta_imag, ffi::AnyBuffer a,
                        ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
                        ffi::AnyBuffer beta_,
                        ffi::Result<ffi::AnyBuffer> c_out) {
  auto dataType = c_in.element_type();
  SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, uplo,
                            use_alpha_attribute, alpha_real, alpha_imag,
                            use_beta_attribute, beta_real, beta_imag, a, c_in,
                            alpha_, beta_, c_out);
  return ffi::Error::InvalidArgument(absl::StrFormat(
      "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType)));
}

XLA_FFI_DEFINE_HANDLER(
    SyrkFfi, SyrkDispatch,
    xla::ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<CUstream>>()
        .Attr<bool>("transpose")           // transpose
        .Attr<bool>("uplo")                // uplo
        .Attr<bool>("use_alpha_attribute") // use_alpha_attribute
        .Attr<double>("alpha_real")        // alpha_real
        .Attr<double>("alpha_imag")        // alpha_imag
        .Attr<bool>("use_beta_attribute")  // use_beta_attribute
        .Attr<double>("beta_real")         // beta_real
        .Attr<double>("beta_imag")         // beta_imag
        .Arg<ffi::AnyBuffer>()             // a
        .Arg<ffi::AnyBuffer>()             // c_in
        .Arg<ffi::AnyBuffer>()             // alpha
        .Arg<ffi::AnyBuffer>()             // beta
        .Ret<ffi::AnyBuffer>()             // c_out
);

void registerReactantXLACUDAFFI() {
  XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "reactant_cublas_syrk_ffi",
                           "CUDA", SyrkFfi);
}

#undef SOLVER_BLAS_DISPATCH_IMPL

#else

void registerReactantXLACUDAFFI() {}

#endif

} // namespace cuda
} // namespace reactant

REACTANT_ABI void registerReactantXLAFFI() {
  reactant::cuda::registerReactantXLACUDAFFI();
  return;
}
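
For reference, the cuBLAS routine the handler dispatches to computes C = alpha * op(A) * op(A)^T + beta * C on column-major data, where n is the order of C and k is the inner dimension of op(A), matching the n/k/lda/ldc choices in SyrkImpl. A self-contained sketch of one float syrk call outside the FFI machinery; sizes and values are illustrative only:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int n = 2, k = 3; // C is n x n, A is n x k (CUBLAS_OP_N case)
  std::vector<float> A = {1, 4, 2, 5, 3, 6}; // column-major 2x3
  std::vector<float> C(n * n, 0.0f);

  float *dA, *dC;
  cudaMalloc(&dA, A.size() * sizeof(float));
  cudaMalloc(&dC, C.size() * sizeof(float));
  cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dC, C.data(), C.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;
  // Lower triangle of C = alpha * A * A^T + beta * C; lda = n and ldc = n,
  // mirroring the non-transpose branch of SyrkImpl.
  cublasSsyrk(handle, CUBLAS_FILL_MODE_LOWER, CUBLAS_OP_N, n, k, &alpha, dA, n,
              &beta, dC, n);

  cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("C[0,0]=%g C[1,0]=%g C[1,1]=%g\n", C[0], C[1], C[3]);

  cublasDestroy(handle);
  cudaFree(dA);
  cudaFree(dC);
  return 0;
}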
