Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ project(cudecomp LANGUAGES ${LANGS})

# Set up CUDA compute capabilities by CUDA version. Users can override defaults with CUDECOMP_CUDA_CC_LIST
if (NOT CUDECOMP_CUDA_CC_LIST)
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
set(CUDECOMP_CUDA_CC_LIST_DEFAULTS "80;90;100")
elseif (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set(CUDECOMP_CUDA_CC_LIST_DEFAULTS "70;80;90;100")
elseif (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 11.8)
set(CUDECOMP_CUDA_CC_LIST_DEFAULTS "70;80;90")
Expand Down
30 changes: 17 additions & 13 deletions include/internal/cuda_wrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,25 @@
#include <cudaTypedefs.h>
#endif

#define DECLARE_CUDA_PFN(symbol, version) PFN_##symbol##_v##version pfn_##symbol = nullptr

namespace cudecomp {

struct cuFunctionTable {
#if CUDART_VERSION >= 11030
PFN_cuDeviceGet pfn_cuDeviceGet = nullptr;
PFN_cuDeviceGetAttribute pfn_cuDeviceGetAttribute = nullptr;
PFN_cuGetErrorString pfn_cuGetErrorString = nullptr;
PFN_cuMemAddressFree pfn_cuMemAddressFree = nullptr;
PFN_cuMemAddressReserve pfn_cuMemAddressReserve = nullptr;
PFN_cuMemCreate pfn_cuMemCreate = nullptr;
PFN_cuMemGetAddressRange pfn_cuMemGetAddressRange = nullptr;
PFN_cuMemGetAllocationGranularity pfn_cuMemGetAllocationGranularity = nullptr;
PFN_cuMemMap pfn_cuMemMap = nullptr;
PFN_cuMemRetainAllocationHandle pfn_cuMemRetainAllocationHandle = nullptr;
PFN_cuMemRelease pfn_cuMemRelease = nullptr;
PFN_cuMemSetAccess pfn_cuMemSetAccess = nullptr;
PFN_cuMemUnmap pfn_cuMemUnmap = nullptr;
DECLARE_CUDA_PFN(cuDeviceGet, 2000);
DECLARE_CUDA_PFN(cuDeviceGetAttribute, 2000);
DECLARE_CUDA_PFN(cuGetErrorString, 6000);
DECLARE_CUDA_PFN(cuMemAddressFree, 10020);
DECLARE_CUDA_PFN(cuMemAddressReserve, 10020);
DECLARE_CUDA_PFN(cuMemCreate, 10020);
DECLARE_CUDA_PFN(cuMemGetAddressRange, 3020);
DECLARE_CUDA_PFN(cuMemGetAllocationGranularity, 10020);
DECLARE_CUDA_PFN(cuMemMap, 10020);
DECLARE_CUDA_PFN(cuMemRetainAllocationHandle, 11000);
DECLARE_CUDA_PFN(cuMemRelease, 10020);
DECLARE_CUDA_PFN(cuMemSetAccess, 10020);
DECLARE_CUDA_PFN(cuMemUnmap, 10020);
#endif
};

Expand All @@ -48,4 +50,6 @@ void initCuFunctionTable();

} // namespace cudecomp

#undef DECLARE_CUDA_PFN

#endif // CUDECOMP_CUDA_WRAP_H
40 changes: 24 additions & 16 deletions src/cuda_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,23 @@
#include "internal/cuda_wrap.h"
#include "internal/exceptions.h"

#if CUDART_VERSION >= 12000
#define LOAD_SYM(symbol) \
#if CUDART_VERSION >= 13000
#define LOAD_SYM(symbol, version) \
do { \
cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \
CHECK_CUDA(cudaGetDriverEntryPointByVersion(#symbol, (void**)(&cuFnTable.pfn_##symbol), version, \
cudaEnableDefault, &driverStatus)); \
if (driverStatus != cudaDriverEntryPointSuccess) { THROW_CUDA_ERROR("cudaGetDriverEntryPointByVersion failed."); } \
} while (false)
#elif CUDART_VERSION >= 12000
#define LOAD_SYM(symbol, version) \
do { \
cudaDriverEntryPointQueryResult driverStatus = cudaDriverEntryPointSymbolNotFound; \
CHECK_CUDA(cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault, &driverStatus)); \
if (driverStatus != cudaDriverEntryPointSuccess) { THROW_CUDA_ERROR("cudaGetDriverEntryPoint failed."); } \
} while (false)
#else
#define LOAD_SYM(symbol) \
#define LOAD_SYM(symbol, version) \
do { \
CHECK_CUDA(cudaGetDriverEntryPoint(#symbol, (void**)(&cuFnTable.pfn_##symbol), cudaEnableDefault)); \
} while (false)
Expand All @@ -41,19 +49,19 @@ cuFunctionTable cuFnTable; // global table of required CUDA driver functions

void initCuFunctionTable() {
#if CUDART_VERSION >= 11030
LOAD_SYM(cuDeviceGet);
LOAD_SYM(cuDeviceGetAttribute);
LOAD_SYM(cuGetErrorString);
LOAD_SYM(cuMemAddressFree);
LOAD_SYM(cuMemAddressReserve);
LOAD_SYM(cuMemCreate);
LOAD_SYM(cuMemGetAddressRange);
LOAD_SYM(cuMemGetAllocationGranularity);
LOAD_SYM(cuMemMap);
LOAD_SYM(cuMemRetainAllocationHandle);
LOAD_SYM(cuMemRelease);
LOAD_SYM(cuMemSetAccess);
LOAD_SYM(cuMemUnmap);
LOAD_SYM(cuDeviceGet, 2000);
LOAD_SYM(cuDeviceGetAttribute, 2000);
LOAD_SYM(cuGetErrorString, 6000);
LOAD_SYM(cuMemAddressFree, 10020);
LOAD_SYM(cuMemAddressReserve, 10020);
LOAD_SYM(cuMemCreate, 10020);
LOAD_SYM(cuMemGetAddressRange, 3020);
LOAD_SYM(cuMemGetAllocationGranularity, 10020);
LOAD_SYM(cuMemMap, 10020);
LOAD_SYM(cuMemRetainAllocationHandle, 11000);
LOAD_SYM(cuMemRelease, 10020);
LOAD_SYM(cuMemSetAccess, 10020);
LOAD_SYM(cuMemUnmap, 10020);
#endif
}

Expand Down