From 2616ef1cb714f87a2ab7ba9e8742ef52be1daf00 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 12 Dec 2025 12:50:22 -0800 Subject: [PATCH 1/2] more robust ptxas finding --- backends/cuda/cuda_backend.py | 64 +++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 7 deletions(-) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 01044d85f5f..a4cfcc207db 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -5,9 +5,11 @@ # LICENSE file in the root directory of this source tree. import os +import shutil +import subprocess import typing from importlib import resources -from typing import Any, Dict, final, List +from typing import Any, Dict, final, List, Optional import torch from executorch.backends.aoti.aoti_backend import AotiBackend @@ -36,6 +38,57 @@ class CudaBackend(AotiBackend, BackendDetails): def get_device_name(cls) -> str: return "cuda" + @staticmethod + def _find_ptxas_for_version(cuda_version: str) -> Optional[str]: + """ + Find ptxas binary that matches the expected CUDA version. + Returns the path to ptxas if found and version matches, None otherwise. + """ + expected_version_marker = f"/cuda-{cuda_version}/" + + def _validate_ptxas_version(path: str) -> bool: + """Check if ptxas at given path matches expected CUDA version.""" + if not os.path.exists(path): + return False + resolved = os.path.realpath(path) + return expected_version_marker in resolved + + # 1. Try PyTorch's CUDA_HOME + try: + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME: + ptxas_path = os.path.join(CUDA_HOME, "bin", "ptxas") + if _validate_ptxas_version(ptxas_path): + return ptxas_path + except ImportError: + pass + + # 2.
Try CUDA_HOME / CUDA_PATH environment variables + for env_var in ("CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"): + cuda_home = os.environ.get(env_var) + if cuda_home: + ptxas_path = os.path.join(cuda_home, "bin", "ptxas") + if _validate_ptxas_version(ptxas_path): + return ptxas_path + + # 3. Try versioned path directly + versioned_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas" + if os.path.exists(versioned_path): + return versioned_path + + # 4. Try system PATH via shutil.which + ptxas_in_path = shutil.which("ptxas") + if ptxas_in_path and _validate_ptxas_version(ptxas_in_path): + return ptxas_in_path + + # 5. Try default symlink path as last resort + default_path = "/usr/local/cuda/bin/ptxas" + if _validate_ptxas_version(default_path): + return default_path + + return None + @staticmethod def _setup_cuda_environment_for_fatbin() -> bool: """ @@ -57,12 +110,9 @@ def _setup_cuda_environment_for_fatbin() -> bool: # Set TRITON_PTXAS_PATH for CUDA 12.6+ if major == 12 and minor >= 6: - # Try versioned path first, fallback to symlinked path - ptxas_path = f"/usr/local/cuda-{cuda_version}/bin/ptxas" - if not os.path.exists(ptxas_path): - ptxas_path = "/usr/local/cuda/bin/ptxas" - if not os.path.exists(ptxas_path): - return False + ptxas_path = CudaBackend._find_ptxas_for_version(cuda_version) + if ptxas_path is None: + return False os.environ["TRITON_PTXAS_PATH"] = ptxas_path # Get compute capability of current CUDA device From 37d16578021492ff30ffe2cd2e448e24d98a9f34 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 12 Dec 2025 13:13:47 -0800 Subject: [PATCH 2/2] lint --- backends/cuda/cuda_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index a4cfcc207db..dbbd79f4881 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -6,7 +6,6 @@ import os import shutil -import subprocess import typing from importlib import resources from typing import Any, 
Dict, final, List, Optional @@ -39,7 +38,7 @@ def get_device_name(cls) -> str: return "cuda" @staticmethod - def _find_ptxas_for_version(cuda_version: str) -> Optional[str]: + def _find_ptxas_for_version(cuda_version: str) -> Optional[str]: # noqa: C901 """ Find ptxas binary that matches the expected CUDA version. Returns the path to ptxas if found and version matches, None otherwise.