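// Reactant XLA FFI handlers for CUDA. Registers a cuBLAS symmetric rank-k
// update (syrk) custom call, "reactant_cublas_syrk_ffi", with the XLA FFI
// registry when built with CUDA support; otherwise registration is a no-op.
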
#include <cstdint>
#include <type_traits>

#include "absl/strings/str_format.h"

#include "jaxlib/ffi_helpers.h"
#include "jaxlib/gpu/blas_handle_pool.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

#include "mlir/CAPI/IR.h"

#define REACTANT_ABI extern "C" MLIR_CAPI_EXPORTED

using namespace jax;
using namespace xla;

#if defined(REACTANT_CUDA)

// CUDA headers must be included at file scope, outside any namespace, so
// their declarations land in the global namespace as intended.
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_fp8.h"
#include "third_party/gpus/cuda/include/cufft.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusolver_common.h"

#endif

namespace reactant {
namespace cuda {

#if defined(REACTANT_CUDA)

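// Expands to a switch over the FFI element type, forwarding to a templated
// implementation instantiated with the matching cuBLAS scalar type. Falls
// through on unsupported dtypes so the caller can report an error.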
#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)                                   \
  switch (dataType) {                                                          \
  case ffi::F32:                                                               \
    return impl<float>(__VA_ARGS__);                                           \
  case ffi::F64:                                                               \
    return impl<double>(__VA_ARGS__);                                          \
  case ffi::C64:                                                               \
    return impl<cuComplex>(__VA_ARGS__);                                       \
  case ffi::C128:                                                              \
    return impl<cuDoubleComplex>(__VA_ARGS__);                                 \
  default:                                                                     \
    break;                                                                     \
  }

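// Materializes the scalar alpha/beta on the host, either from the attribute
// values (value_real/value_imag) or by copying one element of `buffer` from
// the device.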
template <typename T>
T GetHostScalar(CUstream stream, bool use_attribute, double value_real,
                double value_imag, ffi::AnyBuffer buffer) {
  T host_value{};
  if (use_attribute) {
    if constexpr (std::is_same_v<T, float>) {
      host_value = static_cast<float>(value_real);
    } else if constexpr (std::is_same_v<T, double>) {
      host_value = value_real;
    } else if constexpr (std::is_same_v<T, cuComplex>) {
      host_value = cuComplex{static_cast<float>(value_real),
                             static_cast<float>(value_imag)};
    } else if constexpr (std::is_same_v<T, cuDoubleComplex>) {
      host_value = cuDoubleComplex{value_real, value_imag};
    }
  } else {
    // Copy the scalar from device to host. The copy is enqueued on `stream`,
    // so synchronize before reading host_value; otherwise the value may not
    // have arrived yet.
    cudaMemcpyAsync(&host_value, buffer.untyped_data(), sizeof(T),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
  }
  return host_value;
}

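// Translates a cublasStatus_t into an ffi::Error that names both the failing
// cuBLAS call and the status code.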
inline ffi::Error CublasStatusToError(cublasStatus_t status,
                                      const char *op_name) {
  if (status == CUBLAS_STATUS_SUCCESS) {
    return ffi::Error::Success();
  }
  const char *error_name;
  switch (status) {
  case CUBLAS_STATUS_NOT_INITIALIZED:
    error_name = "CUBLAS_STATUS_NOT_INITIALIZED";
    break;
  case CUBLAS_STATUS_ALLOC_FAILED:
    error_name = "CUBLAS_STATUS_ALLOC_FAILED";
    break;
  case CUBLAS_STATUS_INVALID_VALUE:
    error_name = "CUBLAS_STATUS_INVALID_VALUE";
    break;
  case CUBLAS_STATUS_ARCH_MISMATCH:
    error_name = "CUBLAS_STATUS_ARCH_MISMATCH";
    break;
  case CUBLAS_STATUS_MAPPING_ERROR:
    error_name = "CUBLAS_STATUS_MAPPING_ERROR";
    break;
  case CUBLAS_STATUS_EXECUTION_FAILED:
    error_name = "CUBLAS_STATUS_EXECUTION_FAILED";
    break;
  case CUBLAS_STATUS_INTERNAL_ERROR:
    error_name = "CUBLAS_STATUS_INTERNAL_ERROR";
    break;
  case CUBLAS_STATUS_NOT_SUPPORTED:
    error_name = "CUBLAS_STATUS_NOT_SUPPORTED";
    break;
  default:
    error_name = "UNKNOWN";
    break;
  }
  return ffi::Error::InvalidArgument(
      absl::StrFormat("%s failed with status %s", op_name, error_name));
}

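// Thin typed wrappers over the cuBLAS syrk family. The primary template
// rejects unsupported types; explicit specializations forward to the
// corresponding cublas{S,D,C,Z}syrk call.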
namespace blas {

template <typename T>
ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo,
                cublasOperation_t trans, int n, int k, const T *alpha,
                const T *a, int lda, const T *beta, T *c, int ldc) {
  return ffi::Error::InvalidArgument("Unsupported type for syrk");
}

#define SYRK_SPECIALIZATION(T, cublas_func)                                    \
  template <>                                                                  \
  ffi::Error Syrk<T>(cublasHandle_t handle, cublasFillMode_t uplo,             \
                     cublasOperation_t trans, int n, int k, const T *alpha,    \
                     const T *a, int lda, const T *beta, T *c, int ldc) {      \
    cublasStatus_t status =                                                    \
        cublas_func(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc);   \
    return CublasStatusToError(status, #cublas_func);                          \
  }

SYRK_SPECIALIZATION(float, cublasSsyrk)
SYRK_SPECIALIZATION(double, cublasDsyrk)
SYRK_SPECIALIZATION(cuComplex, cublasCsyrk)
SYRK_SPECIALIZATION(cuDoubleComplex, cublasZsyrk)

#undef SYRK_SPECIALIZATION

} // namespace blas

// Symmetric rank-k update (syrk):
//   C = alpha * A * A^T + beta * C   when transpose == false
//   C = alpha * A^T * A + beta * C   when transpose == true
// Only the triangle of C selected by `uplo` is referenced and updated.
template <typename T>
ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo_,
                    ffi::AnyBuffer a, ffi::AnyBuffer c_in, const T *alpha,
                    const T *beta, ffi::Result<ffi::AnyBuffer> c_out) {
  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
                       SplitBatch2D(a.dimensions()));
  auto size = transpose ? cols : rows;
  FFI_RETURN_IF_ERROR(
      CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk"));
  FFI_RETURN_IF_ERROR(
      CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk"));

  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(size));
  FFI_ASSIGN_OR_RETURN(auto k,
                       MaybeCastNoOverflow<int>(transpose ? rows : cols));
  cublasFillMode_t uplo =
      uplo_ ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
  cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;

  const T *a_data = static_cast<const T *>(a.untyped_data());
  T *c_data = static_cast<T *>(c_in.untyped_data());
  T *c_out_data = static_cast<T *>(c_out->untyped_data());

  // syrk updates C in place, so copy c_in into c_out first unless the two
  // buffers are already aliased (e.g. when the input was donated).
  if (c_data != c_out_data) {
    cudaError_t err = cudaMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
                                      cudaMemcpyDeviceToDevice, stream);
    if (err != cudaSuccess) {
      return ffi::Error::InvalidArgument(absl::StrFormat(
          "cudaMemcpyAsync failed: %s", cudaGetErrorString(err)));
    }
  }
  FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
  // cuBLAS is column-major, so the leading dimension of A is its number of
  // storage rows: k when op == T (A is k x n) and n when op == N (A is n x k).
  int lda = transpose ? k : n;
  int ldc = n;
  for (int i = 0; i < batch; ++i) {
    FFI_RETURN_IF_ERROR(blas::Syrk<T>(handle.get(), uplo, trans, n, k, alpha,
                                      a_data, lda, beta, c_out_data, ldc));
    // Advance to the next matrix in the batch; widen before multiplying so
    // the element offset cannot overflow int.
    a_data += static_cast<int64_t>(k) * n;
    c_out_data += static_cast<int64_t>(n) * n;
  }
  return ffi::Error::Success();
}

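// Overload that first resolves alpha/beta to host values (from attributes or
// device buffers) and then calls the pointer-based implementation above.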
template <typename T>
ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
                    bool use_alpha_attribute, double alpha_real,
                    double alpha_imag, bool use_beta_attribute,
                    double beta_real, double beta_imag, ffi::AnyBuffer a,
                    ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
                    ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
  T host_alpha = GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
                                  alpha_imag, alpha_);
  T host_beta =
      GetHostScalar<T>(stream, use_beta_attribute, beta_real, beta_imag, beta_);
  return SyrkImpl<T>(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
                     c_out);
}

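// FFI entry point: dispatches on the element type of c_in to the typed
// implementation.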
ffi::Error SyrkDispatch(CUstream stream, bool transpose, bool uplo,
                        bool use_alpha_attribute, double alpha_real,
                        double alpha_imag, bool use_beta_attribute,
                        double beta_real, double beta_imag, ffi::AnyBuffer a,
                        ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
                        ffi::AnyBuffer beta_,
                        ffi::Result<ffi::AnyBuffer> c_out) {
  auto dataType = c_in.element_type();
  SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, uplo,
                            use_alpha_attribute, alpha_real, alpha_imag,
                            use_beta_attribute, beta_real, beta_imag, a, c_in,
                            alpha_, beta_, c_out);
  return ffi::Error::InvalidArgument(absl::StrFormat(
      "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType)));
}

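// Binds SyrkDispatch to the XLA FFI calling convention. The alpha and beta
// buffers are always passed as arguments; the use_*_attribute flags choose
// between them and the attribute values. As an illustrative sketch only
// (attribute syntax approximate, not necessarily the exact lowering Reactant
// emits), a StableHLO call site targeting this handler could look like:
//
//   stablehlo.custom_call @reactant_cublas_syrk_ffi(%a, %c, %alpha, %beta) {
//     api_version = 4 : i32,
//     mhlo.backend_config = {transpose = false, uplo = true,
//                            use_alpha_attribute = true, alpha_real = 1.0,
//                            alpha_imag = 0.0, use_beta_attribute = true,
//                            beta_real = 0.0, beta_imag = 0.0}
//   } : (...) -> tensor<...>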
XLA_FFI_DEFINE_HANDLER(
    SyrkFfi, SyrkDispatch,
    xla::ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<CUstream>>()
        .Attr<bool>("transpose")
        .Attr<bool>("uplo")
        .Attr<bool>("use_alpha_attribute")
        .Attr<double>("alpha_real")
        .Attr<double>("alpha_imag")
        .Attr<bool>("use_beta_attribute")
        .Attr<double>("beta_real")
        .Attr<double>("beta_imag")
        .Arg<ffi::AnyBuffer>()  // a
        .Arg<ffi::AnyBuffer>()  // c_in
        .Arg<ffi::AnyBuffer>()  // alpha
        .Arg<ffi::AnyBuffer>()  // beta
        .Ret<ffi::AnyBuffer>()  // c_out
);

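// Registers the handler under the custom-call target name
// "reactant_cublas_syrk_ffi" for the CUDA platform.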
void registerReactantXLACUDAFFI() {
  XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(),
                           "reactant_cublas_syrk_ffi", "CUDA", SyrkFfi);
}

#undef SOLVER_BLAS_DISPATCH_IMPL

#else

// Built without CUDA support: registration is a no-op.
void registerReactantXLACUDAFFI() {}

#endif

} // namespace cuda
} // namespace reactant

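// Exported C entry point: called by the embedding runtime to register all
// Reactant FFI handlers.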
REACTANT_ABI void registerReactantXLAFFI() {
  reactant::cuda::registerReactantXLACUDAFFI();
}