1212import functools
1313import glob
1414import os
15+ import pathlib
1516import re
1617import subprocess
1718
2425get_requires_for_build_sdist = _build_meta .get_requires_for_build_sdist
2526
2627
28+ @functools .cache
29+ def _get_cuda_paths ():
30+ CUDA_PATH = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" , None ))
31+ if not CUDA_PATH :
32+ raise RuntimeError ("Environment variable CUDA_PATH or CUDA_HOME is not set" )
33+ CUDA_PATH = CUDA_PATH .split (os .pathsep )
34+ print ("CUDA paths:" , CUDA_PATH , flush = True )
35+ return CUDA_PATH
36+
37+
38+ @functools .cache
39+ def _get_cuda_version_from_cuda_h (cuda_home = None ):
40+ """
41+ Given CUDA_HOME, try to extract the CUDA_VERSION macro from include/cuda.h.
42+
43+ Example line in cuda.h:
44+ #define CUDA_VERSION 13000
45+
46+ Returns the integer (e.g. 13000) or None if not found / on error.
47+ """
48+ if cuda_home is None :
49+ cuda_home = _get_cuda_paths ()[0 ]
50+
51+ cuda_h = pathlib .Path (cuda_home ) / "include" / "cuda.h"
52+ if not cuda_h .is_file ():
53+ return None
54+
55+ try :
56+ text = cuda_h .read_text (encoding = "utf-8" , errors = "ignore" )
57+ except OSError :
58+ # Permissions issue, unreadable file, etc.
59+ return None
60+
61+ m = re .search (r"^\s*#define\s+CUDA_VERSION\s+(\d+)" , text , re .MULTILINE )
62+ if not m :
63+ return None
64+ print (f"CUDA_VERSION from { cuda_h } :" , m .group (1 ), flush = True )
65+ return m .group (1 )
66+
67+
2768@functools .cache
2869def _get_proper_cuda_bindings_major_version () -> str :
2970 # for local development (with/without build isolation)
@@ -39,6 +80,10 @@ def _get_proper_cuda_bindings_major_version() -> str:
3980 if cuda_major is not None :
4081 return cuda_major
4182
83+ cuda_version = _get_cuda_version_from_cuda_h ()
84+ if cuda_version and len (cuda_version ) > 3 :
85+ return cuda_version [:- 3 ]
86+
4287 # also for local development
4388 try :
4489 out = subprocess .run ("nvidia-smi" , env = os .environ , capture_output = True , check = True ) # noqa: S603, S607
@@ -73,20 +118,11 @@ def strip_prefix_suffix(filename):
73118
74119 module_names = (strip_prefix_suffix (f ) for f in ext_files )
75120
76- @functools .cache
77- def get_cuda_paths ():
78- CUDA_PATH = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" , None ))
79- if not CUDA_PATH :
80- raise RuntimeError ("Environment variable CUDA_PATH or CUDA_HOME is not set" )
81- CUDA_PATH = CUDA_PATH .split (os .pathsep )
82- print ("CUDA paths:" , CUDA_PATH )
83- return CUDA_PATH
84-
85121 ext_modules = tuple (
86122 Extension (
87123 f"cuda.core.experimental.{ mod .replace (os .path .sep , '.' )} " ,
88124 sources = [f"cuda/core/experimental/{ mod } .pyx" ],
89- include_dirs = list (os .path .join (root , "include" ) for root in get_cuda_paths ()),
125+ include_dirs = list (os .path .join (root , "include" ) for root in _get_cuda_paths ()),
90126 language = "c++" ,
91127 )
92128 for mod in module_names
0 commit comments