55# LICENSE file in the root directory of this source tree.
66
77import os
8+ import shutil
9+ import subprocess
810import typing
911from importlib import resources
10- from typing import Any , Dict , final , List
12+ from typing import Any , Dict , final , List , Optional
1113
1214import torch
1315from executorch .backends .aoti .aoti_backend import AotiBackend
@@ -36,6 +38,57 @@ class CudaBackend(AotiBackend, BackendDetails):
3638 def get_device_name (cls ) -> str :
3739 return "cuda"
3840
@staticmethod
def _find_ptxas_for_version(cuda_version: str) -> Optional[str]:
    """
    Find a ptxas binary that matches the expected CUDA version.

    The version is checked from the path only (ptxas is never executed): a
    candidate matches when its symlink-resolved path contains
    ``/cuda-<cuda_version>/``.

    Candidates are probed in this order and the first match wins:
      1. ``<CUDA_HOME>/bin/ptxas`` using PyTorch's ``CUDA_HOME``
      2. ``<dir>/bin/ptxas`` for the CUDA_HOME, CUDA_PATH and CUDA_ROOT
         environment variables
      3. ``/usr/local/cuda-<cuda_version>/bin/ptxas``
      4. the first ``ptxas`` on PATH
      5. ``/usr/local/cuda/bin/ptxas``

    Args:
        cuda_version: CUDA version string as it appears in the toolkit
            directory name (the ``X`` in ``cuda-X``).

    Returns:
        The path to ptxas if found and version matches, None otherwise.
        The returned path is the candidate as probed, not its resolved form.
    """
    expected_version_marker = f"/cuda-{cuda_version}/"

    def _validate_ptxas_version(path: str) -> bool:
        """Check if ptxas at given path matches expected CUDA version."""
        # isfile() rather than exists(): a directory that happens to be named
        # "ptxas" is not a usable binary and must not shadow later candidates.
        if not os.path.isfile(path):
            return False
        # Resolve symlinks so that e.g. /usr/local/cuda -> /usr/local/cuda-X
        # is judged by its real, versioned location.
        resolved = os.path.realpath(path)
        return expected_version_marker in resolved

    # 1. Try PyTorch's CUDA_HOME
    try:
        from torch.utils.cpp_extension import CUDA_HOME

        if CUDA_HOME:
            ptxas_path = os.path.join(CUDA_HOME, "bin", "ptxas")
            if _validate_ptxas_version(ptxas_path):
                return ptxas_path
    except ImportError:
        # Best effort: without torch's cpp_extension, fall through to the
        # remaining lookup strategies.
        pass

    # 2. Try CUDA_HOME / CUDA_PATH / CUDA_ROOT environment variables
    for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"):
        cuda_home = os.environ.get(env_var)
        if cuda_home:
            ptxas_path = os.path.join(cuda_home, "bin", "ptxas")
            if _validate_ptxas_version(ptxas_path):
                return ptxas_path

    # 3. Try versioned path directly. The version is already part of the
    #    path itself, so only the presence of the file is checked here.
    versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas"
    if os.path.isfile(versioned_path):
        return versioned_path

    # 4. Try system PATH via shutil.which
    ptxas_in_path = shutil.which("ptxas")
    if ptxas_in_path and _validate_ptxas_version(ptxas_in_path):
        return ptxas_in_path

    # 5. Try default symlink path as last resort
    default_path = "/usr/local/cuda/bin/ptxas"
    if _validate_ptxas_version(default_path):
        return default_path

    return None
91+
3992 @staticmethod
4093 def _setup_cuda_environment_for_fatbin () -> bool :
4194 """
@@ -57,12 +110,9 @@ def _setup_cuda_environment_for_fatbin() -> bool:
57110
58111 # Set TRITON_PTXAS_PATH for CUDA 12.6+
59112 if major == 12 and minor >= 6 :
60- # Try versioned path first, fallback to symlinked path
61- ptxas_path = f"/usr/local/cuda-{ cuda_version } /bin/ptxas"
62- if not os .path .exists (ptxas_path ):
63- ptxas_path = "/usr/local/cuda/bin/ptxas"
64- if not os .path .exists (ptxas_path ):
65- return False
113+ ptxas_path = CudaBackend ._find_ptxas_for_version (cuda_version )
114+ if ptxas_path is None :
115+ return False
66116 os .environ ["TRITON_PTXAS_PATH" ] = ptxas_path
67117
68118 # Get compute capability of current CUDA device
0 commit comments