Skip to content

Commit eccc69a

Browse files
committed
Add _get_cuda_version_from_cuda_h() into _get_proper_cuda_bindings_major_version()
1 parent b09d7ed commit eccc69a

File tree

1 file changed

+46
-10
lines changed

1 file changed

+46
-10
lines changed

cuda_core/build_hooks.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import functools
1313
import glob
1414
import os
15+
import pathlib
1516
import re
1617
import subprocess
1718

@@ -24,6 +25,46 @@
2425
get_requires_for_build_sdist = _build_meta.get_requires_for_build_sdist
2526

2627

28+
@functools.cache
29+
def _get_cuda_paths():
30+
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
31+
if not CUDA_PATH:
32+
raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
33+
CUDA_PATH = CUDA_PATH.split(os.pathsep)
34+
print("CUDA paths:", CUDA_PATH, flush=True)
35+
return CUDA_PATH
36+
37+
38+
@functools.cache
39+
def _get_cuda_version_from_cuda_h(cuda_home=None):
40+
"""
41+
Given CUDA_HOME, try to extract the CUDA_VERSION macro from include/cuda.h.
42+
43+
Example line in cuda.h:
44+
#define CUDA_VERSION 13000
45+
46+
Returns the integer (e.g. 13000) or None if not found / on error.
47+
"""
48+
if cuda_home is None:
49+
cuda_home = _get_cuda_paths()[0]
50+
51+
cuda_h = pathlib.Path(cuda_home) / "include" / "cuda.h"
52+
if not cuda_h.is_file():
53+
return None
54+
55+
try:
56+
text = cuda_h.read_text(encoding="utf-8", errors="ignore")
57+
except OSError:
58+
# Permissions issue, unreadable file, etc.
59+
return None
60+
61+
m = re.search(r"^\s*#define\s+CUDA_VERSION\s+(\d+)", text, re.MULTILINE)
62+
if not m:
63+
return None
64+
print(f"CUDA_VERSION from {cuda_h}:", m.group(1), flush=True)
65+
return m.group(1)
66+
67+
2768
@functools.cache
2869
def _get_proper_cuda_bindings_major_version() -> str:
2970
# for local development (with/without build isolation)
@@ -39,6 +80,10 @@ def _get_proper_cuda_bindings_major_version() -> str:
3980
if cuda_major is not None:
4081
return cuda_major
4182

83+
cuda_version = _get_cuda_version_from_cuda_h()
84+
if cuda_version and len(cuda_version) > 3:
85+
return cuda_version[:-3]
86+
4287
# also for local development
4388
try:
4489
out = subprocess.run("nvidia-smi", env=os.environ, capture_output=True, check=True) # noqa: S603, S607
@@ -73,20 +118,11 @@ def strip_prefix_suffix(filename):
73118

74119
module_names = (strip_prefix_suffix(f) for f in ext_files)
75120

76-
@functools.cache
77-
def get_cuda_paths():
78-
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
79-
if not CUDA_PATH:
80-
raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
81-
CUDA_PATH = CUDA_PATH.split(os.pathsep)
82-
print("CUDA paths:", CUDA_PATH)
83-
return CUDA_PATH
84-
85121
ext_modules = tuple(
86122
Extension(
87123
f"cuda.core.experimental.{mod.replace(os.path.sep, '.')}",
88124
sources=[f"cuda/core/experimental/{mod}.pyx"],
89-
include_dirs=list(os.path.join(root, "include") for root in get_cuda_paths()),
125+
include_dirs=list(os.path.join(root, "include") for root in _get_cuda_paths()),
90126
language="c++",
91127
)
92128
for mod in module_names

0 commit comments

Comments
 (0)