Skip to content

Commit 2616ef1

Browse files
committed
more robust ptxas finding
1 parent 5d40a3a commit 2616ef1

File tree

1 file changed

+57
-7
lines changed

1 file changed

+57
-7
lines changed

backends/cuda/cuda_backend.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import os
8+
import shutil
9+
import subprocess
810
import typing
911
from importlib import resources
10-
from typing import Any, Dict, final, List
12+
from typing import Any, Dict, final, List, Optional
1113

1214
import torch
1315
from executorch.backends.aoti.aoti_backend import AotiBackend
@@ -36,6 +38,57 @@ class CudaBackend(AotiBackend, BackendDetails):
3638
def get_device_name(cls) -> str:
3739
return "cuda"
3840

41+
@staticmethod
def _find_ptxas_for_version(cuda_version: str) -> Optional[str]:
    """
    Find a ptxas binary that matches the expected CUDA version.

    A candidate is accepted when its symlink-resolved path contains a
    ``/cuda-<cuda_version>/`` component or, failing that, when running
    ``<candidate> --version`` reports ``release <cuda_version>``. The second
    check lets toolkits installed under a directory that is not named
    ``cuda-<version>`` be found as well.

    Candidates are probed in this order:
      1. ``torch.utils.cpp_extension.CUDA_HOME`` (if torch is importable)
      2. the ``CUDA_HOME``, ``CUDA_PATH`` and ``CUDA_ROOT`` environment variables
      3. ``/usr/local/cuda-<cuda_version>/bin/ptxas`` (accepted if it exists)
      4. ``ptxas`` on the system ``PATH``
      5. ``/usr/local/cuda/bin/ptxas``

    Args:
        cuda_version: CUDA version string used both in the versioned install
            path and for comparison with the reported release (presumably
            "<major>.<minor>", e.g. "12.6" -- confirm against the caller).

    Returns:
        The path to a matching ptxas, or None if no candidate matches.
    """
    expected_version_marker = f"/cuda-{cuda_version}/"
    # Remember verdicts so the same binary is never executed twice when
    # several sources point at the same toolkit.
    verdicts: Dict[str, bool] = {}

    def _reported_release(path: str) -> Optional[str]:
        """Return the release printed by `path --version`, or None."""
        try:
            result = subprocess.run(
                [path, "--version"],
                capture_output=True,
                text=True,
                errors="replace",
                timeout=10,
                check=False,
            )
        except (OSError, subprocess.SubprocessError):
            # Not executable, not a real binary, or it hung: treat as unknown.
            return None
        if result.returncode != 0:
            return None
        # Expected line: "Cuda compilation tools, release 12.6, V12.6.85"
        _, marker, tail = result.stdout.partition("release ")
        if not marker:
            return None
        return tail.split(",", 1)[0].strip()

    def _validate_ptxas_version(path: str) -> bool:
        """Check if ptxas at given path matches expected CUDA version."""
        if path in verdicts:
            return verdicts[path]
        if not os.path.exists(path):
            matches = False
        elif expected_version_marker in os.path.realpath(path):
            # Cheap check first: the install directory names the version.
            matches = True
        else:
            # Install layout is not conclusive; ask the binary itself.
            matches = _reported_release(path) == cuda_version
        verdicts[path] = matches
        return matches

    # 1. PyTorch's CUDA_HOME, then 2. CUDA_HOME / CUDA_PATH / CUDA_ROOT
    # environment variables, in that priority order.
    toolkit_roots: List[Optional[str]] = []
    try:
        from torch.utils.cpp_extension import CUDA_HOME

        toolkit_roots.append(CUDA_HOME)
    except ImportError:
        pass
    toolkit_roots.extend(
        os.environ.get(env_var)
        for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT")
    )

    for root in toolkit_roots:
        if not root:
            continue
        ptxas_path = os.path.join(root, "bin", "ptxas")
        if _validate_ptxas_version(ptxas_path):
            return ptxas_path

    # 3. Try versioned path directly; the path itself encodes the version,
    # so existence is sufficient here.
    versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
    if os.path.exists(versioned_path):
        return versioned_path

    # 4. Try system PATH via shutil.which
    ptxas_in_path = shutil.which("ptxas")
    if ptxas_in_path and _validate_ptxas_version(ptxas_in_path):
        return ptxas_in_path

    # 5. Try default symlink path as last resort
    default_path = "/usr/local/cuda/bin/ptxas"
    if _validate_ptxas_version(default_path):
        return default_path

    return None
91+
3992
@staticmethod
4093
def _setup_cuda_environment_for_fatbin() -> bool:
4194
"""
@@ -57,12 +110,9 @@ def _setup_cuda_environment_for_fatbin() -> bool:
57110

58111
# Set TRITON_PTXAS_PATH for CUDA 12.6+
59112
if major == 12 and minor >= 6:
60-
# Try versioned path first, fallback to symlinked path
61-
ptxas_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
62-
if not os.path.exists(ptxas_path):
63-
ptxas_path = "/usr/local/cuda/bin/ptxas"
64-
if not os.path.exists(ptxas_path):
65-
return False
113+
ptxas_path = CudaBackend._find_ptxas_for_version(cuda_version)
114+
if ptxas_path is None:
115+
return False
66116
os.environ["TRITON_PTXAS_PATH"] = ptxas_path
67117

68118
# Get compute capability of current CUDA device

0 commit comments

Comments
 (0)