-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path: setup.py
More file actions
117 lines (99 loc) · 3.71 KB
/
setup.py
File metadata and controls
117 lines (99 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import subprocess
from setuptools import setup
def check_nvcc():
    """Return True when the ``nvcc`` CUDA compiler is on PATH and runs."""
    try:
        # Any successful invocation (exit code 0) means the toolkit is usable.
        subprocess.check_output(["nvcc", "--version"])
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False
    return True
def detect_cuda_arch_flags():
    """Determine the nvcc ``-gencode`` flags to compile with.

    Resolution order:
    1. ``TIDE_CUDA_ARCH`` env var (e.g. "8.6")
    2. ``TORCH_CUDA_ARCH_LIST`` env var (e.g. "8.0;9.0" or "8.0 9.0")
    3. Capability of the currently visible GPU via ``torch.cuda``
    4. Broad fallback covering sm_70 through sm_90 (plus newer, with PTX)
    """
    # Environment variables win; the TIDE-specific one takes precedence
    # over PyTorch's standard one. Both accept ';' or whitespace separators.
    for env_var in ("TIDE_CUDA_ARCH", "TORCH_CUDA_ARCH_LIST"):
        raw = os.environ.get(env_var)
        if raw:
            specs = [token.strip() for token in raw.replace(";", " ").split()]
            return _arches_to_gencode(specs)

    # No explicit hint: ask the local GPU, if torch and a device exist.
    try:
        import torch

        if torch.cuda.is_available():
            major, minor = torch.cuda.get_device_capability()
            return _arches_to_gencode([f"{major}.{minor}"])
    except (ImportError, RuntimeError):
        pass

    # Last resort: a wide range from V100 through Blackwell, ending in a
    # PTX target so future architectures can JIT-compile.
    fallback = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0", "10.0", "12.0+PTX"]
    return _arches_to_gencode(fallback)
def _arches_to_gencode(arches):
"""Convert architecture strings like '8.6' to -gencode flags."""
flags = []
for arch in arches:
# Handle formats: "8.6", "86", "8.6+PTX"
ptx = "+PTX" in arch
arch_clean = arch.replace("+PTX", "").strip()
if "." in arch_clean:
major, minor = arch_clean.split(".")
compute = f"{major}{minor}"
else:
compute = arch_clean
flags.append(f"-gencode=arch=compute_{compute},code=sm_{compute}")
if ptx:
flags.append(f"-gencode=arch=compute_{compute},code=compute_{compute}")
return flags
# Build the optional CUDA extension only when nvcc is present and the user
# has not opted out via TIDE_NO_CUDA=1. On any failure we fall back to a
# pure-Python install rather than aborting.
ext_modules = []
cmdclass = {}
if check_nvcc() and os.environ.get("TIDE_NO_CUDA", "0") != "1":
    try:
        from torch.utils.cpp_extension import BuildExtension, CUDAExtension

        cuda_arch_flags = detect_cuda_arch_flags()
        ext_modules = [
            CUDAExtension(
                name="TIDE._C",
                sources=[
                    "csrc/extensions/torch_bindings.cpp",
                    "csrc/kernels/fused_layernorm_route.cu",
                    "csrc/kernels/batch_compact.cu",
                    "csrc/kernels/exit_scatter.cu",
                    "csrc/kernels/exit_projection.cu",
                ],
                extra_compile_args={
                    "cxx": ["-O3"],
                    "nvcc": [
                        "-O3",
                        "--use_fast_math",
                        *cuda_arch_flags,
                        "-lineinfo",  # keep source-line info for profilers (nsight)
                        "--threads=4",  # compile the per-arch targets in parallel
                    ],
                },
            ),
        ]

        class OptionalBuildExt(BuildExtension):
            """BuildExtension that falls back gracefully if compilation fails."""

            def build_extensions(self):
                # Best-effort: a failed kernel build degrades to the pure
                # Python implementation instead of aborting the install.
                try:
                    super().build_extensions()
                except Exception as e:
                    print(f"\nWARNING: CUDA extension build failed: {e}")
                    print("Installing without CUDA kernels (pure Python fallback).\n")
                    self.extensions = []

        cmdclass = {"build_ext": OptionalBuildExt}
    except Exception as e:
        # Fix: this previously swallowed every error silently (`pass`),
        # leaving users with no clue why the CUDA kernels were skipped.
        # Still best-effort — report the cause, then continue CPU-only.
        print(f"WARNING: skipping CUDA extension setup: {e}")
        print("Installing without CUDA kernels (pure Python fallback).")

setup(
    ext_modules=ext_modules,
    cmdclass=cmdclass,
)