Skip to content

Commit 04f1d48

Browse files
authored
feat: cuda syrk ffi (#1944)
* feat: cuda syrk ffi * fix: cpu build * ci: run local jll test on cuda
1 parent bd43e10 commit 04f1d48

File tree

4 files changed

+276
-10
lines changed

4 files changed

+276
-10
lines changed

.github/workflows/CI-localjll.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ jobs:
3131
os:
3232
- linux-x86-n2-32
3333
- macOS-latest
34+
- linux-x86-a2-48-a100-4gpu
3435
exclude:
36+
- os: linux-x86-a2-48-a100-4gpu
37+
version: "1.10"
3538
- os: macOS-latest
3639
version: "1.10"
3740
uses: ./.github/workflows/CommonCI.yml

deps/ReactantExtra/API.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,9 +1005,9 @@ REACTANT_ABI PjRtBuffer *ArrayFromHostBuffer(PjRtClient *client, void *data,
10051005
return bres;
10061006
}
10071007

1008-
10091008
REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
1010-
void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
1009+
void *data, size_t offset, size_t size,
1010+
PjRtBuffer **bufferP) {
10111011
if (buffer->IsOnCpu()) {
10121012
auto unsafe =
10131013
(char *)MyValueOrThrow(buffer->client()->UnsafeBufferPointer(buffer));
@@ -1016,12 +1016,13 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
10161016
// data, size);
10171017
return;
10181018
}
1019-
1019+
10201020
auto pid = client->platform_id();
10211021
if (pid == xla::TpuId()) {
10221022
auto dims = buffer->on_device_shape().dimensions();
10231023
// TODO: note this assume that we want to copy the entire buffer size.
1024-
auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(), dims.size(), dims.data(), buffer->device());
1024+
auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(),
1025+
dims.size(), dims.data(), buffer->device());
10251026
*bufferP = buf2;
10261027
PjRtBufferFree((PjRtBuffer *)buffer);
10271028
return;
@@ -1075,9 +1076,9 @@ REACTANT_ABI void BufferToHost(PjRtBuffer *buffer, void *data) {
10751076
}
10761077
}
10771078

1078-
10791079
REACTANT_ABI void CopyFromBuffer(PjRtClient *client, PjRtBuffer *buffer,
1080-
void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
1080+
void *data, size_t offset, size_t size,
1081+
PjRtBuffer **bufferP) {
10811082

10821083
auto pid = client->platform_id();
10831084
if (pid == xla::TpuId()) {
@@ -3069,7 +3070,7 @@ struct LinkableRuntime {
30693070
executables;
30703071

30713072
// Set of allocated pointers to size
3072-
std::set<void *, std::greater<void*>> allocations;
3073+
std::set<void *, std::greater<void *>> allocations;
30733074

30743075
LinkableRuntime(const std::string &backend) : registry() {
30753076
InitializeRegistry(wrap(&registry));
@@ -3217,7 +3218,8 @@ REACTANT_ABI void reactantXLAExec(LinkableRuntime **__restrict__ lrtP,
32173218
for (int64_t i = 0; i < argcnt; i++) {
32183219
auto &&[argB, argO, argP] = bufferAndOffset(lrt, args[i]);
32193220
if (argO != 0) {
3220-
llvm::errs() << "only zero-offset execution supported, argument " << i << " had byte offset of " << argO << "\n";
3221+
llvm::errs() << "only zero-offset execution supported, argument " << i
3222+
<< " had byte offset of " << argO << "\n";
32213223
exit(1);
32223224
}
32233225
baseArrays[i] = argB;
@@ -3443,8 +3445,7 @@ class GPUPerformanceModel {
34433445
fusion_analysis_cache_(device_description_),
34443446
gpu_hlo_cost_analysis_(hlo_cost_analysis_options_, device_description_),
34453447
gpu_performance_model_(device_description_, fusion_analysis_cache_,
3446-
gpu_performance_model_cache_,
3447-
mlir_context_) {}
3448+
gpu_performance_model_cache_, mlir_context_) {}
34483449

34493450
void RunAnalysisOnHloModule(std::shared_ptr<xla::HloModule> hlo_module) {
34503451
hlo_module->entry_computation()->Accept(&gpu_hlo_cost_analysis_);

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,7 @@ cc_library(
10591059
"-Wl,-exported_symbol,_CreateGPUPerformanceModel",
10601060
"-Wl,-exported_symbol,_RunAnalysisOnHloModule",
10611061
"-Wl,-exported_symbol,_EstimateRunTimeForInstruction",
1062+
"-Wl,-exported_symbol,_registerReactantXLAFFI",
10621063
],
10631064
}),
10641065
linkstatic = True,

deps/ReactantExtra/xla_ffi.cpp

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
#include "absl/strings/str_format.h"
2+
3+
#include "xla/ffi/api/c_api.h"
4+
#include "xla/ffi/api/ffi.h"
5+
#include "xla/ffi/ffi_api.h"
6+
7+
#include "mlir/CAPI/IR.h"
8+
9+
#if defined(REACTANT_CUDA)
10+
#include "jaxlib/ffi_helpers.h"
11+
#include "jaxlib/gpu/blas_handle_pool.h"
12+
#endif
13+
14+
#define REACTANT_ABI extern "C" MLIR_CAPI_EXPORTED
15+
16+
using namespace xla;
17+
18+
namespace reactant {
19+
namespace cuda {
20+
21+
#if defined(REACTANT_CUDA)
22+
23+
#include "third_party/gpus/cuda/include/cuComplex.h"
24+
#include "third_party/gpus/cuda/include/cublas_v2.h"
25+
#include "third_party/gpus/cuda/include/cuda.h"
26+
#include "third_party/gpus/cuda/include/cuda_fp8.h"
27+
#include "third_party/gpus/cuda/include/cufft.h"
28+
#include "third_party/gpus/cuda/include/cusolverDn.h"
29+
#include "third_party/gpus/cuda/include/cusolver_common.h"
30+
31+
using namespace jax;
32+
33+
// Dispatch a templated BLAS/solver implementation on the runtime element
// type. Expects a variable named `dataType` (an ffi element type) to be in
// scope at the expansion site. Falls through (no return) for unsupported
// dtypes so the caller can emit its own error afterwards.
#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)                                   \
  switch (dataType) {                                                          \
  case ffi::F32:                                                               \
    return impl<float>(__VA_ARGS__);                                           \
  case ffi::F64:                                                               \
    return impl<double>(__VA_ARGS__);                                          \
  case ffi::C64:                                                               \
    return impl<cuComplex>(__VA_ARGS__);                                       \
  case ffi::C128:                                                              \
    return impl<cuDoubleComplex>(__VA_ARGS__);                                 \
  default:                                                                     \
    break;                                                                     \
  }
46+
47+
template <typename T>
48+
ffi::Error GetHostScalar(CUstream stream, bool use_attribute, double value_real,
49+
double value_imag, ffi::AnyBuffer buffer,
50+
T *host_value) {
51+
if (use_attribute) {
52+
if constexpr (std::is_same<T, float>::value) {
53+
*host_value = static_cast<float>(value_real);
54+
} else if constexpr (std::is_same<T, double>::value) {
55+
*host_value = value_real;
56+
} else if constexpr (std::is_same<T, cuComplex>::value) {
57+
*host_value = cuComplex{static_cast<float>(value_real),
58+
static_cast<float>(value_imag)};
59+
} else if constexpr (std::is_same<T, cuDoubleComplex>::value) {
60+
*host_value = cuDoubleComplex{value_real, value_imag};
61+
}
62+
} else {
63+
// Ensure buffer has exactly 1 element
64+
if (buffer.element_count() != 1) {
65+
return ffi::Error::InvalidArgument(
66+
absl::StrFormat("Expected scalar buffer with 1 element, got %d",
67+
buffer.element_count()));
68+
}
69+
// memcpy to host
70+
cudaMemcpyAsync(host_value, buffer.untyped_data(), sizeof(T),
71+
cudaMemcpyDeviceToHost, stream);
72+
}
73+
return ffi::Error::Success();
74+
}
75+
76+
// Translate a cuBLAS status code into an ffi::Error that names the failing
// operation; CUBLAS_STATUS_SUCCESS maps to ffi::Error::Success().
inline ffi::Error CublasStatusToError(cublasStatus_t status,
                                      const char *op_name) {
  if (status == CUBLAS_STATUS_SUCCESS) {
    return ffi::Error::Success();
  }
  // Resolve the symbolic name of the status for the error message.
  const char *error_name = [status]() -> const char * {
    switch (status) {
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "CUBLAS_STATUS_NOT_SUPPORTED";
    default:
      return "UNKNOWN";
    }
  }();
  return ffi::Error::InvalidArgument(
      absl::StrFormat("%s failed with status %s", op_name, error_name));
}
114+
115+
namespace blas {
116+
117+
template <typename T>
118+
ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo,
119+
cublasOperation_t trans, int n, int k, const T *alpha,
120+
const T *a, int lda, const T *beta, T *c, int ldc) {
121+
return ffi::Error::InvalidArgument("Unsupported type for syrk");
122+
}
123+
124+
#define SYRK_SPECIALIZATION(T, cublas_func) \
125+
template <> \
126+
ffi::Error Syrk<T>(cublasHandle_t handle, cublasFillMode_t uplo, \
127+
cublasOperation_t trans, int n, int k, const T *alpha, \
128+
const T *a, int lda, const T *beta, T *c, int ldc) { \
129+
cublasStatus_t status = \
130+
cublas_func(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc); \
131+
return CublasStatusToError(status, #cublas_func); \
132+
}
133+
134+
SYRK_SPECIALIZATION(float, cublasSsyrk)
135+
SYRK_SPECIALIZATION(double, cublasDsyrk)
136+
SYRK_SPECIALIZATION(cuComplex, cublasCsyrk)
137+
SYRK_SPECIALIZATION(cuDoubleComplex, cublasZsyrk)
138+
139+
#undef SYRK_SPECIALIZATION
140+
141+
} // namespace blas
142+
143+
// Symmetric rank-k update: syrk
144+
145+
template <typename T>
146+
ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo_,
147+
ffi::AnyBuffer a, ffi::AnyBuffer c_in, const T *alpha,
148+
const T *beta, ffi::Result<ffi::AnyBuffer> c_out) {
149+
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
150+
SplitBatch2D(a.dimensions()));
151+
auto size = transpose ? cols : rows;
152+
FFI_RETURN_IF_ERROR(
153+
CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk"));
154+
FFI_RETURN_IF_ERROR(
155+
CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk"));
156+
157+
FFI_ASSIGN_OR_RETURN(auto n,
158+
MaybeCastNoOverflow<int>(transpose ? cols : rows));
159+
FFI_ASSIGN_OR_RETURN(auto k,
160+
MaybeCastNoOverflow<int>(transpose ? rows : cols));
161+
cublasFillMode_t uplo =
162+
uplo_ ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
163+
cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
164+
165+
const T *a_data = static_cast<const T *>(a.untyped_data());
166+
T *c_data = static_cast<T *>(c_in.untyped_data());
167+
T *c_out_data = static_cast<T *>(c_out->untyped_data());
168+
169+
if (c_data != c_out_data) {
170+
cudaError_t err = cudaMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
171+
cudaMemcpyDeviceToDevice, stream);
172+
if (err != cudaSuccess) {
173+
return ffi::Error::InvalidArgument(absl::StrFormat(
174+
"cudaMemcpyAsync failed: %s", cudaGetErrorString(err)));
175+
}
176+
}
177+
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
178+
// lda is the leading dimension of a, ldc is the leading dimension of c
179+
// For column-major (which cuBLAS uses), lda = number of rows of a, ldc = n
180+
int lda = transpose ? k : n;
181+
int ldc = n;
182+
for (int i = 0; i < batch; ++i) {
183+
FFI_RETURN_IF_ERROR(blas::Syrk<T>(handle.get(), uplo, trans, n, k, alpha,
184+
a_data, lda, beta, c_out_data, ldc));
185+
a_data += k * n;
186+
c_out_data += n * n;
187+
}
188+
return ffi::Error::Success();
189+
}
190+
191+
template <typename T>
192+
ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
193+
bool use_alpha_attribute, double alpha_real,
194+
double alpha_imag, bool use_beta_attribute,
195+
double beta_real, double beta_imag, ffi::AnyBuffer a,
196+
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
197+
ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
198+
T host_alpha, host_beta;
199+
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
200+
alpha_imag, alpha_, &host_alpha));
201+
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_beta_attribute, beta_real,
202+
beta_imag, beta_, &host_beta));
203+
return SyrkImpl<T>(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
204+
c_out);
205+
}
206+
207+
// Runtime-dtype entry point for the syrk handler: selects the SyrkImpl
// instantiation matching c_in's element type, or reports an unsupported
// dtype.
ffi::Error SyrkDispatch(CUstream stream, bool transpose, bool uplo,
                        bool use_alpha_attribute, double alpha_real,
                        double alpha_imag, bool use_beta_attribute,
                        double beta_real, double beta_imag, ffi::AnyBuffer a,
                        ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
                        ffi::AnyBuffer beta_,
                        ffi::Result<ffi::AnyBuffer> c_out) {
  auto dataType = c_in.element_type();
  switch (dataType) {
  case ffi::F32:
    return SyrkImpl<float>(stream, transpose, uplo, use_alpha_attribute,
                           alpha_real, alpha_imag, use_beta_attribute,
                           beta_real, beta_imag, a, c_in, alpha_, beta_, c_out);
  case ffi::F64:
    return SyrkImpl<double>(stream, transpose, uplo, use_alpha_attribute,
                            alpha_real, alpha_imag, use_beta_attribute,
                            beta_real, beta_imag, a, c_in, alpha_, beta_,
                            c_out);
  case ffi::C64:
    return SyrkImpl<cuComplex>(stream, transpose, uplo, use_alpha_attribute,
                               alpha_real, alpha_imag, use_beta_attribute,
                               beta_real, beta_imag, a, c_in, alpha_, beta_,
                               c_out);
  case ffi::C128:
    return SyrkImpl<cuDoubleComplex>(stream, transpose, uplo,
                                     use_alpha_attribute, alpha_real,
                                     alpha_imag, use_beta_attribute, beta_real,
                                     beta_imag, a, c_in, alpha_, beta_, c_out);
  default:
    break;
  }
  return ffi::Error::InvalidArgument(absl::StrFormat(
      "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType)));
}
222+
223+
// FFI handler binding for SyrkDispatch: declares the platform stream context,
// the scalar attributes, the four operand buffers, and the single result so
// XLA can type-check calls against this signature.
XLA_FFI_DEFINE_HANDLER(SyrkFfi, SyrkDispatch,
                       xla::ffi::Ffi::Bind()
                           .Ctx<ffi::PlatformStream<CUstream>>()
                           .Attr<bool>("transpose")
                           .Attr<bool>("uplo")
                           .Attr<bool>("use_alpha_attribute")
                           .Attr<double>("alpha_real")
                           .Attr<double>("alpha_imag")
                           .Attr<bool>("use_beta_attribute")
                           .Attr<double>("beta_real")
                           .Attr<double>("beta_imag")
                           .Arg<ffi::AnyBuffer>()  // a
                           .Arg<ffi::AnyBuffer>()  // c_in
                           .Arg<ffi::AnyBuffer>()  // alpha
                           .Arg<ffi::AnyBuffer>()  // beta
                           .Ret<ffi::AnyBuffer>()  // c_out
);
241+
242+
// Register the cuBLAS syrk handler with the XLA FFI registry under the name
// "reactant_cublas_syrk_ffi" for the CUDA platform.
void registerReactantXLACUDAFFI() {
  XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(),
                           "reactant_cublas_syrk_ffi", "CUDA", SyrkFfi);
}
246+
247+
#undef SOLVER_BLAS_DISPATCH_IMPL
248+
249+
#else
250+
251+
// Built without REACTANT_CUDA: nothing to register.
void registerReactantXLACUDAFFI() {}
252+
253+
#endif
254+
255+
} // namespace cuda
256+
} // namespace reactant
257+
258+
// Exported entry point: registers all of Reactant's custom XLA FFI handlers
// (currently only the CUDA cuBLAS syrk handler; a no-op in CPU-only builds).
REACTANT_ABI void registerReactantXLAFFI() {
  reactant::cuda::registerReactantXLACUDAFFI();
}

0 commit comments

Comments
 (0)