1- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
44from __future__ import annotations
55
66import functools
77import glob
88import os
9+ from collections .abc import Callable
910from dataclasses import dataclass
11+ from typing import TYPE_CHECKING
1012
1113from cuda .pathfinder ._dynamic_libs .load_nvidia_dynamic_lib import (
1214 _resolve_system_loaded_abs_path_in_subprocess ,
1315)
1416from cuda .pathfinder ._dynamic_libs .search_steps import derive_ctk_root
15- from cuda .pathfinder ._headers import supported_nvidia_headers
17+ from cuda .pathfinder ._headers .header_descriptor import (
18+ HEADER_DESCRIPTORS ,
19+ platform_include_subdirs ,
20+ resolve_conda_anchor ,
21+ )
1622from cuda .pathfinder ._utils .env_vars import get_cuda_home_or_path
1723from cuda .pathfinder ._utils .find_sub_dirs import find_sub_dirs_all_sitepackages
18- from cuda .pathfinder ._utils .platform_aware import IS_WINDOWS
24+
25+ if TYPE_CHECKING :
26+ from cuda .pathfinder ._headers .header_descriptor import HeaderDescriptor
27+
28+ # ---------------------------------------------------------------------------
29+ # Data types
30+ # ---------------------------------------------------------------------------
1931
2032
2133@dataclass
@@ -27,6 +39,14 @@ def __post_init__(self) -> None:
2739 self .abs_path = _abs_norm (self .abs_path )
2840
2941
42+ #: Type alias for a header find step callable.
43+ HeaderFindStep = Callable [["HeaderDescriptor" ], LocatedHeaderDir | None ]
44+
45+ # ---------------------------------------------------------------------------
46+ # Helpers
47+ # ---------------------------------------------------------------------------
48+
49+
3050def _abs_norm (path : str | None ) -> str | None :
3151 if path :
3252 return os .path .normpath (os .path .abspath (path ))
@@ -37,102 +57,119 @@ def _joined_isfile(dirpath: str, basename: str) -> bool:
3757 return os .path .isfile (os .path .join (dirpath , basename ))
3858
3959
40- def _locate_under_site_packages (sub_dir : str , h_basename : str ) -> LocatedHeaderDir | None :
41- # Installed from a wheel
42- hdr_dir : str # help mypy
43- for hdr_dir in find_sub_dirs_all_sitepackages (tuple (sub_dir .split ("/" ))):
44- if _joined_isfile (hdr_dir , h_basename ):
45- return LocatedHeaderDir (abs_path = hdr_dir , found_via = "site-packages" )
60+ def _locate_in_anchor_layout (desc : HeaderDescriptor , anchor_point : str ) -> str | None :
61+ """Search for a header under *anchor_point* using the descriptor's layout fields."""
62+ h_basename = desc .header_basename
63+ for rel_dir in desc .anchor_include_rel_dirs :
64+ idir = os .path .join (anchor_point , rel_dir )
65+ for subdir in platform_include_subdirs (desc ):
66+ cdir = os .path .join (idir , subdir )
67+ if _joined_isfile (cdir , h_basename ):
68+ return cdir
69+ if _joined_isfile (idir , h_basename ):
70+ return idir
4671 return None
4772
4873
49- def _locate_based_on_ctk_layout (libname : str , h_basename : str , anchor_point : str ) -> str | None :
50- parts = [anchor_point ]
51- if libname == "nvvm" :
52- parts .append (libname )
53- parts .append ("include" )
54- idir = os .path .join (* parts )
55- if libname == "cccl" :
56- if IS_WINDOWS :
57- cdir_ctk12 = os .path .join (idir , "targets" , "x64" ) # conda has this anomaly
58- cdir_ctk13 = os .path .join (cdir_ctk12 , "cccl" )
59- if _joined_isfile (cdir_ctk13 , h_basename ):
60- return cdir_ctk13
61- if _joined_isfile (cdir_ctk12 , h_basename ):
62- return cdir_ctk12
63- cdir = os .path .join (idir , "cccl" ) # CTK 13
64- if _joined_isfile (cdir , h_basename ):
65- return cdir
66- if _joined_isfile (idir , h_basename ):
67- return idir
74+ # ---------------------------------------------------------------------------
75+ # Find steps
76+ # ---------------------------------------------------------------------------
77+
78+
79+ def find_in_site_packages (desc : HeaderDescriptor ) -> LocatedHeaderDir | None :
80+ """Search pip wheel install locations."""
81+ for sub_dir in desc .site_packages_dirs :
82+ hdr_dir : str # help mypy
83+ for hdr_dir in find_sub_dirs_all_sitepackages (tuple (sub_dir .split ("/" ))):
84+ if _joined_isfile (hdr_dir , desc .header_basename ):
85+ return LocatedHeaderDir (abs_path = hdr_dir , found_via = "site-packages" )
6886 return None
6987
7088
71- def _find_based_on_conda_layout (libname : str , h_basename : str , ctk_layout : bool ) -> LocatedHeaderDir | None :
89+ def find_in_conda (desc : HeaderDescriptor ) -> LocatedHeaderDir | None :
90+ """Search ``$CONDA_PREFIX``."""
7291 conda_prefix = os .environ .get ("CONDA_PREFIX" )
7392 if not conda_prefix :
7493 return None
75- if IS_WINDOWS :
76- anchor_point = os .path .join (conda_prefix , "Library" )
77- if not os .path .isdir (anchor_point ):
78- return None
79- else :
80- if ctk_layout :
81- targets_include_path = glob .glob (os .path .join (conda_prefix , "targets" , "*" , "include" ))
82- if not targets_include_path :
83- return None
84- if len (targets_include_path ) != 1 :
85- # Conda does not support multiple architectures.
86- # QUESTION(PR#956): Do we want to issue a warning?
87- return None
88- include_path = targets_include_path [0 ]
89- else :
90- include_path = os .path .join (conda_prefix , "include" )
91- anchor_point = os .path .dirname (include_path )
92- found_header_path = _locate_based_on_ctk_layout (libname , h_basename , anchor_point )
94+ anchor_point = resolve_conda_anchor (desc , conda_prefix )
95+ if anchor_point is None :
96+ return None
97+ found_header_path = _locate_in_anchor_layout (desc , anchor_point )
9398 if found_header_path :
9499 return LocatedHeaderDir (abs_path = found_header_path , found_via = "conda" )
95100 return None
96101
97102
98- def _find_ctk_header_directory_via_canary (libname : str , h_basename : str ) -> str | None :
103+ def find_in_cuda_home (desc : HeaderDescriptor ) -> LocatedHeaderDir | None :
104+ """Search ``$CUDA_HOME`` / ``$CUDA_PATH``."""
105+ cuda_home = get_cuda_home_or_path ()
106+ if cuda_home is None :
107+ return None
108+ result = _locate_in_anchor_layout (desc , cuda_home )
109+ if result is not None :
110+ return LocatedHeaderDir (abs_path = result , found_via = "CUDA_HOME" )
111+ return None
112+
113+
114+ def find_via_ctk_root_canary (desc : HeaderDescriptor ) -> LocatedHeaderDir | None :
99115 """Try CTK header lookup via CTK-root canary probing.
100116
101- Uses the same canary as dynamic-library CTK-root discovery: system-load
102- ``cudart`` in a spawned child process, derive CTK root from the resolved
103- absolute library path, then search the expected CTK include layout under
104- that root.
117+ Skips immediately if the descriptor does not opt in (``use_ctk_root_canary``).
118+ Otherwise, system-loads ``cudart`` in a spawned child process, derives
119+ CTK root from the resolved library path, and searches the expected include
120+ layout under that root.
105121 """
122+ if not desc .use_ctk_root_canary :
123+ return None
106124 canary_abs_path = _resolve_system_loaded_abs_path_in_subprocess ("cudart" )
107125 if canary_abs_path is None :
108126 return None
109127 ctk_root = derive_ctk_root (canary_abs_path )
110128 if ctk_root is None :
111129 return None
112- return _locate_based_on_ctk_layout (libname , h_basename , ctk_root )
130+ result = _locate_in_anchor_layout (desc , ctk_root )
131+ if result is not None :
132+ return LocatedHeaderDir (abs_path = result , found_via = "system-ctk-root" )
133+ return None
113134
114135
115- def _find_ctk_header_directory (libname : str ) -> LocatedHeaderDir | None :
116- h_basename = supported_nvidia_headers .SUPPORTED_HEADERS_CTK [libname ]
117- candidate_dirs = supported_nvidia_headers .SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK [libname ]
136+ def find_in_system_install_dirs (desc : HeaderDescriptor ) -> LocatedHeaderDir | None :
137+ """Search system install directories (glob patterns)."""
138+ for pattern in desc .system_install_dirs :
139+ for hdr_dir in sorted (glob .glob (pattern ), reverse = True ):
140+ if _joined_isfile (hdr_dir , desc .header_basename ):
141+ return LocatedHeaderDir (abs_path = hdr_dir , found_via = "supported_install_dir" )
142+ return None
118143
119- for cdir in candidate_dirs :
120- if hdr_dir := _locate_under_site_packages (cdir , h_basename ):
121- return hdr_dir
122144
123- if hdr_dir := _find_based_on_conda_layout (libname , h_basename , True ):
124- return hdr_dir
145+ # ---------------------------------------------------------------------------
146+ # Step sequence and cascade runner
147+ # ---------------------------------------------------------------------------
125148
126- cuda_home = get_cuda_home_or_path ()
127- if cuda_home and (result := _locate_based_on_ctk_layout (libname , h_basename , cuda_home )):
128- return LocatedHeaderDir (abs_path = result , found_via = "CUDA_HOME" )
149+ #: Unified find steps — each step self-gates based on descriptor fields.
150+ FIND_STEPS : tuple [HeaderFindStep , ...] = (
151+ find_in_site_packages ,
152+ find_in_conda ,
153+ find_in_cuda_home ,
154+ find_via_ctk_root_canary ,
155+ find_in_system_install_dirs ,
156+ )
129157
130- if result := _find_ctk_header_directory_via_canary (libname , h_basename ):
131- return LocatedHeaderDir (abs_path = result , found_via = "system-ctk-root" )
132158
159+ def run_find_steps (desc : HeaderDescriptor , steps : tuple [HeaderFindStep , ...]) -> LocatedHeaderDir | None :
160+ """Run find steps in order, returning the first hit."""
161+ for step in steps :
162+ result = step (desc )
163+ if result is not None :
164+ return result
133165 return None
134166
135167
168+ # ---------------------------------------------------------------------------
169+ # Public API
170+ # ---------------------------------------------------------------------------
171+
172+
136173@functools .cache
137174def locate_nvidia_header_directory (libname : str ) -> LocatedHeaderDir | None :
138175 """Locate the header directory for a supported NVIDIA library.
@@ -150,51 +187,17 @@ def locate_nvidia_header_directory(libname: str) -> LocatedHeaderDir | None:
150187 RuntimeError: If ``libname`` is not in the supported set.
151188
152189 Search order:
153- 1. **NVIDIA Python wheels**
154-
155- - Scan installed distributions (``site-packages``) for header layouts
156- shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).
157-
158- 2. **Conda environments**
159-
160- - Check Conda-style installation prefixes, which use platform-specific
161- include directory layouts.
162-
163- 3. **CUDA Toolkit environment variables**
164-
165- - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).
166-
167- 4. **CTK root canary probe**
168-
169- - Probe a system-loaded ``cudart`` in a spawned child process,
170- derive the CTK root from the resolved library path, then search
171- CTK include layout under that root.
190+ 1. **NVIDIA Python wheels** — site-packages directories from the descriptor.
191+ 2. **Conda environments** — platform-specific conda include layouts.
192+ 3. **CUDA Toolkit environment variables** — ``CUDA_HOME`` / ``CUDA_PATH``.
193+ 4. **CTK root canary probe** — subprocess canary (descriptors with
194+ ``use_ctk_root_canary=True`` only).
195+ 5. **System install directories** — glob patterns from the descriptor.
172196 """
173-
174- if libname in supported_nvidia_headers .SUPPORTED_HEADERS_CTK :
175- return _find_ctk_header_directory (libname )
176-
177- h_basename = supported_nvidia_headers .SUPPORTED_HEADERS_NON_CTK .get (libname )
178- if h_basename is None :
197+ desc = HEADER_DESCRIPTORS .get (libname )
198+ if desc is None :
179199 raise RuntimeError (f"UNKNOWN { libname = } " )
180-
181- candidate_dirs = supported_nvidia_headers .SUPPORTED_SITE_PACKAGE_HEADER_DIRS_NON_CTK .get (libname , [])
182-
183- for cdir in candidate_dirs :
184- if found_hdr := _locate_under_site_packages (cdir , h_basename ):
185- return found_hdr
186-
187- if found_hdr := _find_based_on_conda_layout (libname , h_basename , False ):
188- return found_hdr
189-
190- # Fall back to system install directories
191- candidate_dirs = supported_nvidia_headers .SUPPORTED_INSTALL_DIRS_NON_CTK .get (libname , [])
192- for cdir in candidate_dirs :
193- for hdr_dir in sorted (glob .glob (cdir ), reverse = True ):
194- if _joined_isfile (hdr_dir , h_basename ):
195- # For system installs, we don't have a clear found_via, so use "system"
196- return LocatedHeaderDir (abs_path = hdr_dir , found_via = "supported_install_dir" )
197- return None
200+ return run_find_steps (desc , FIND_STEPS )
198201
199202
200203def find_nvidia_header_directory (libname : str ) -> str | None :
@@ -212,25 +215,12 @@ def find_nvidia_header_directory(libname: str) -> str | None:
212215 RuntimeError: If ``libname`` is not in the supported set.
213216
214217 Search order:
215- 1. **NVIDIA Python wheels**
216-
217- - Scan installed distributions (``site-packages``) for header layouts
218- shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).
219-
220- 2. **Conda environments**
221-
222- - Check Conda-style installation prefixes, which use platform-specific
223- include directory layouts.
224-
225- 3. **CUDA Toolkit environment variables**
226-
227- - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).
228-
229- 4. **CTK root canary probe**
230-
231- - Probe a system-loaded ``cudart`` in a spawned child process,
232- derive the CTK root from the resolved library path, then search
233- CTK include layout under that root.
218+ 1. **NVIDIA Python wheels** — site-packages directories from the descriptor.
219+ 2. **Conda environments** — platform-specific conda include layouts.
220+ 3. **CUDA Toolkit environment variables** — ``CUDA_HOME`` / ``CUDA_PATH``.
221+ 4. **CTK root canary probe** — subprocess canary (descriptors with
222+ ``use_ctk_root_canary=True`` only).
223+ 5. **System install directories** — glob patterns from the descriptor.
234224 """
235225 found = locate_nvidia_header_directory (libname )
236226 return found .abs_path if found else None
0 commit comments