|
| 1 | +/* |
| 2 | + * cbits/eigsh.c |
| 3 | + * |
| 4 | + * GPU-resident symmetric eigendecomposition via cuSOLVER. |
| 5 | + * |
| 6 | + * Supports f32 (cusolverDnSsyevd) and f64 (cusolverDnDsyevd). |
| 7 | + * cuSOLVER is resolved at runtime through dlopen/dlsym — no link-time |
| 8 | + * dependency on CUDA toolkit. The only link-time requirements are |
| 9 | + * libaf (ArrayFire) and libdl. |
| 10 | + * |
| 11 | + * Returns AF_ERR_RUNTIME when CUDA backend is not active or cuSOLVER |
| 12 | + * cannot be found (graceful degradation on CPU/OpenCL builds). |
| 13 | + * |
| 14 | + * Ordering: cusolverDnDsyevd returns eigenvalues in ascending order, |
| 15 | + * matching hmatrix's eigSH convention. |
| 16 | + */ |
| 17 | + |
| 18 | +#define _GNU_SOURCE |
| 19 | +#define AF_DEFINE_CUDA_TYPES /* gives us cudaStream_t in af/cuda.h */ |
| 20 | +#include "arrayfire.h" |
| 21 | +#include "af/cuda.h" |
| 22 | +#include <dlfcn.h> |
| 23 | +#include <stddef.h> |
| 24 | + |
/* ── Minimal cuSOLVER surface, declared locally so that no CUDA toolkit
 *    headers are needed at build time. ── */

/* Opaque library handle; only ever passed back to cuSOLVER. */
typedef void *cusolverDnHandle_t;

/* Stand-in stream type under a distinct name so it cannot clash with the
 * cudaStream_t that af/cuda.h provides. */
typedef void *cudaStream_t_t;

/* cuSOLVER status codes are treated as plain integers. */
typedef int cusolverStatus_t;

/* Numeric values intended to mirror the corresponding cuSOLVER / cuBLAS
 * enumerators; they are passed to the library as plain ints. */
#define CUSOLVER_STATUS_SUCCESS  0
#define CUSOLVER_EIG_MODE_VECTOR 1
#define CUBLAS_FILL_MODE_LOWER   0
| 33 | + |
| 34 | +/* ── function pointer typedefs ── */ |
| 35 | +typedef cusolverStatus_t (*pfn_Create) (cusolverDnHandle_t *); |
| 36 | +typedef cusolverStatus_t (*pfn_SetStream) (cusolverDnHandle_t, cudaStream_t); |
| 37 | + |
| 38 | +typedef cusolverStatus_t (*pfn_DsyevdBuf)(cusolverDnHandle_t, int, int, |
| 39 | + int, const double *, int, const double *, int *); |
| 40 | +typedef cusolverStatus_t (*pfn_Dsyevd) (cusolverDnHandle_t, int, int, |
| 41 | + int, double *, int, double *, double *, int, int *); |
| 42 | + |
| 43 | +typedef cusolverStatus_t (*pfn_SsyevdBuf)(cusolverDnHandle_t, int, int, |
| 44 | + int, const float *, int, const float *, int *); |
| 45 | +typedef cusolverStatus_t (*pfn_Ssyevd) (cusolverDnHandle_t, int, int, |
| 46 | + int, float *, int, float *, float *, int, int *); |
| 47 | + |
| 48 | +/* ── module-level state ── */ |
| 49 | +static cusolverDnHandle_t g_handle = NULL; |
| 50 | +static pfn_Create fn_Create = NULL; |
| 51 | +static pfn_SetStream fn_SetStr = NULL; |
| 52 | +static pfn_DsyevdBuf fn_DsyBuf = NULL; |
| 53 | +static pfn_Dsyevd fn_Dsyevd = NULL; |
| 54 | +static pfn_SsyevdBuf fn_SsyBuf = NULL; |
| 55 | +static pfn_Ssyevd fn_Ssyevd = NULL; |
| 56 | +static int g_init = 0; /* 0 = uninitialised */ |
| 57 | + |
| 58 | +static af_err load_and_init(void) |
| 59 | +{ |
| 60 | + /* Try the exact versioned name first (already loaded by AF CUDA backend), |
| 61 | + * then fall back to an unversioned symlink if present. */ |
| 62 | + void *lib = dlopen("libcusolver.so.11", RTLD_NOW | RTLD_NOLOAD); |
| 63 | + if (!lib) lib = dlopen("libcusolver.so.11", RTLD_NOW | RTLD_GLOBAL); |
| 64 | + if (!lib) lib = dlopen("libcusolver.so", RTLD_NOW | RTLD_GLOBAL); |
| 65 | + if (!lib) return AF_ERR_RUNTIME; |
| 66 | + |
| 67 | + fn_Create = (pfn_Create) dlsym(lib, "cusolverDnCreate"); |
| 68 | + fn_SetStr = (pfn_SetStream) dlsym(lib, "cusolverDnSetStream"); |
| 69 | + fn_DsyBuf = (pfn_DsyevdBuf) dlsym(lib, "cusolverDnDsyevd_bufferSize"); |
| 70 | + fn_Dsyevd = (pfn_Dsyevd) dlsym(lib, "cusolverDnDsyevd"); |
| 71 | + fn_SsyBuf = (pfn_SsyevdBuf) dlsym(lib, "cusolverDnSsyevd_bufferSize"); |
| 72 | + fn_Ssyevd = (pfn_Ssyevd) dlsym(lib, "cusolverDnSsyevd"); |
| 73 | + |
| 74 | + if (!fn_Create || !fn_SetStr || !fn_DsyBuf || !fn_Dsyevd || |
| 75 | + !fn_SsyBuf || !fn_Ssyevd) |
| 76 | + return AF_ERR_RUNTIME; |
| 77 | + |
| 78 | + if (fn_Create(&g_handle) != CUSOLVER_STATUS_SUCCESS) |
| 79 | + return AF_ERR_INTERNAL; |
| 80 | + |
| 81 | + /* Bind cuSOLVER to ArrayFire's CUDA stream (device 0) so that |
| 82 | + * cuSOLVER kernels are sequenced correctly with AF operations. */ |
| 83 | + cudaStream_t stream = NULL; |
| 84 | + if (afcu_get_stream(&stream, 0) == AF_SUCCESS && stream) |
| 85 | + fn_SetStr(g_handle, stream); |
| 86 | + |
| 87 | + return AF_SUCCESS; |
| 88 | +} |
| 89 | + |
| 90 | +static af_err ensure_init(void) |
| 91 | +{ |
| 92 | + if (g_init) return g_handle ? AF_SUCCESS : AF_ERR_RUNTIME; |
| 93 | + g_init = 1; |
| 94 | + return load_and_init(); |
| 95 | +} |
| 96 | + |
| 97 | +/* ── core eigensolver: writes eigenvectors into d_A, eigenvalues into d_W ── */ |
| 98 | +static af_err run_syevd(int is_double, int n, void *d_A, void *d_W) |
| 99 | +{ |
| 100 | + int lwork; |
| 101 | + cusolverStatus_t st; |
| 102 | + |
| 103 | + if (is_double) { |
| 104 | + st = fn_DsyBuf(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, |
| 105 | + n, (const double *)d_A, n, (const double *)d_W, &lwork); |
| 106 | + } else { |
| 107 | + st = fn_SsyBuf(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, |
| 108 | + n, (const float *)d_A, n, (const float *)d_W, &lwork); |
| 109 | + } |
| 110 | + if (st != CUSOLVER_STATUS_SUCCESS) return AF_ERR_INTERNAL; |
| 111 | + |
| 112 | + dim_t wsz = (dim_t)lwork * (is_double ? sizeof(double) : sizeof(float)); |
| 113 | + |
| 114 | + void *d_work = NULL, *d_info = NULL; |
| 115 | + af_err err; |
| 116 | + if ((err = af_alloc_device_v2(&d_work, wsz)) != AF_SUCCESS) return err; |
| 117 | + if ((err = af_alloc_device_v2(&d_info, sizeof(int))) != AF_SUCCESS) { |
| 118 | + af_free_device_v2(d_work); |
| 119 | + return err; |
| 120 | + } |
| 121 | + |
| 122 | + if (is_double) { |
| 123 | + st = fn_Dsyevd(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, |
| 124 | + n, (double *)d_A, n, (double *)d_W, |
| 125 | + (double *)d_work, lwork, (int *)d_info); |
| 126 | + } else { |
| 127 | + st = fn_Ssyevd(g_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, |
| 128 | + n, (float *)d_A, n, (float *)d_W, |
| 129 | + (float *)d_work, lwork, (int *)d_info); |
| 130 | + } |
| 131 | + |
| 132 | + af_free_device_v2(d_work); |
| 133 | + af_free_device_v2(d_info); |
| 134 | + return (st == CUSOLVER_STATUS_SUCCESS) ? AF_SUCCESS : AF_ERR_INTERNAL; |
| 135 | +} |
| 136 | + |
| 137 | +/* ── public entry point exposed to Haskell ── */ |
| 138 | +af_err af_eigsh(af_array *evals_out, af_array *evecs_out, const af_array input) |
| 139 | +{ |
| 140 | + af_err err; |
| 141 | + |
| 142 | + if ((err = ensure_init()) != AF_SUCCESS) return err; |
| 143 | + |
| 144 | + af_dtype dtype; |
| 145 | + if ((err = af_get_type(&dtype, input)) != AF_SUCCESS) return err; |
| 146 | + if (dtype != f64 && dtype != f32) return AF_ERR_TYPE; |
| 147 | + |
| 148 | + dim_t d0, d1, d2, d3; |
| 149 | + if ((err = af_get_dims(&d0, &d1, &d2, &d3, input)) != AF_SUCCESS) return err; |
| 150 | + int n = (int)d0; |
| 151 | + |
| 152 | + /* Working copy: cuSOLVER overwrites A in-place with eigenvectors */ |
| 153 | + af_array evecs; |
| 154 | + if ((err = af_copy_array(&evecs, input)) != AF_SUCCESS) return err; |
| 155 | + |
| 156 | + /* Eigenvalue output: n-element array owned and managed by ArrayFire */ |
| 157 | + af_array evals; |
| 158 | + dim_t n_dim = (dim_t)n; |
| 159 | + if ((err = af_constant(&evals, 0.0, 1, &n_dim, dtype)) != AF_SUCCESS) { |
| 160 | + af_release_array(evecs); |
| 161 | + return err; |
| 162 | + } |
| 163 | + |
| 164 | + /* Lock both arrays and obtain raw device pointers for cuSOLVER */ |
| 165 | + void *d_A = NULL, *d_W = NULL; |
| 166 | + if ((err = af_get_device_ptr(&d_A, evecs)) != AF_SUCCESS) { |
| 167 | + af_release_array(evecs); af_release_array(evals); |
| 168 | + return err; |
| 169 | + } |
| 170 | + if ((err = af_get_device_ptr(&d_W, evals)) != AF_SUCCESS) { |
| 171 | + af_unlock_array(evecs); |
| 172 | + af_release_array(evecs); af_release_array(evals); |
| 173 | + return err; |
| 174 | + } |
| 175 | + |
| 176 | + err = run_syevd(dtype == f64, n, d_A, d_W); |
| 177 | + |
| 178 | + /* Unlock: ArrayFire resumes ownership and sees the in-place modifications */ |
| 179 | + af_unlock_array(evecs); |
| 180 | + af_unlock_array(evals); |
| 181 | + |
| 182 | + if (err != AF_SUCCESS) { |
| 183 | + af_release_array(evecs); af_release_array(evals); |
| 184 | + return err; |
| 185 | + } |
| 186 | + |
| 187 | + *evals_out = evals; |
| 188 | + *evecs_out = evecs; |
| 189 | + return AF_SUCCESS; |
| 190 | +} |
0 commit comments