Skip to content

Commit c7cabd8

Browse files
committed
SUPPORTED_HEADERS_CTK_LINUX_ONLY etc. (for cufile)
1 parent 0629da2 commit c7cabd8

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

cuda_pathfinder/cuda/pathfinder/_headers/find_nvidia_headers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import shutil
88
from typing import Optional
99

10-
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import IS_WINDOWS
1110
from cuda.pathfinder._headers import supported_nvidia_headers
11+
from cuda.pathfinder._headers.supported_nvidia_headers import IS_WINDOWS
1212
from cuda.pathfinder._utils.env_vars import get_cuda_home_or_path
1313
from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages
1414

cuda_pathfinder/cuda/pathfinder/_headers/supported_nvidia_headers.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
SUPPORTED_HEADERS_CTK = {
4+
import sys
5+
6+
IS_WINDOWS = sys.platform == "win32"
7+
8+
SUPPORTED_HEADERS_CTK_COMMON = {
59
"cccl": "cuda/std/version",
610
"cublas": "cublas.h",
711
"cudart": "cuda_runtime.h",
812
"cufft": "cufft.h",
9-
"cufile": "cufile.h",
1013
"curand": "curand.h",
1114
"cusolver": "cusolver_common.h",
1215
"cusparse": "cusparse.h",
@@ -19,6 +22,19 @@
1922
"nvvm": "nvvm.h",
2023
}
2124

25+
SUPPORTED_HEADERS_CTK_LINUX_ONLY = {
26+
"cufile": "cufile.h",
27+
}
28+
SUPPORTED_HEADERS_CTK_LINUX = SUPPORTED_HEADERS_CTK_COMMON | SUPPORTED_HEADERS_CTK_LINUX_ONLY
29+
30+
SUPPORTED_HEADERS_CTK_WINDOWS_ONLY: dict[str, str] = {}
31+
SUPPORTED_HEADERS_CTK_WINDOWS = SUPPORTED_HEADERS_CTK_COMMON | SUPPORTED_HEADERS_CTK_WINDOWS_ONLY
32+
33+
SUPPORTED_HEADERS_CTK_ALL = (
34+
SUPPORTED_HEADERS_CTK_COMMON | SUPPORTED_HEADERS_CTK_LINUX_ONLY | SUPPORTED_HEADERS_CTK_WINDOWS_ONLY
35+
)
36+
SUPPORTED_HEADERS_CTK = SUPPORTED_HEADERS_CTK_WINDOWS if IS_WINDOWS else SUPPORTED_HEADERS_CTK_LINUX
37+
2238
SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK = {
2339
"cccl": (
2440
"cuda/cccl/headers/include", # cuda-cccl

cuda_pathfinder/tests/test_find_nvidia_headers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import pytest
2222

2323
from cuda.pathfinder import _find_nvidia_header_directory as find_nvidia_header_directory
24-
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import IS_WINDOWS
2524
from cuda.pathfinder._headers.supported_nvidia_headers import (
25+
IS_WINDOWS,
2626
SUPPORTED_HEADERS_CTK,
27+
SUPPORTED_HEADERS_CTK_ALL,
2728
SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK,
2829
)
2930

@@ -65,7 +66,7 @@ def test_find_libname_nvshmem(info_summary_append):
6566

6667

6768
def test_supported_headers_site_packages_ctk_consistency():
68-
assert tuple(sorted(SUPPORTED_HEADERS_CTK)) == tuple(sorted(SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK.keys()))
69+
assert tuple(sorted(SUPPORTED_HEADERS_CTK_ALL)) == tuple(sorted(SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK.keys()))
6970

7071

7172
@pytest.mark.parametrize("libname", SUPPORTED_HEADERS_CTK.keys())

0 commit comments

Comments
 (0)