Skip to content

Commit 7c5e8d4

Browse files
zdevito authored and meta-codesync[bot] committed
No torch requirement (#2114)
Summary: Pull Request resolved: #2114. Torch is now required only to build monarch; it no longer needs to be installed to run it (though you still need it if you want to use tensors in the tensor worker). ghstack-source-id: 328614677 Reviewed By: mariusae Differential Revision: D88912346 fbshipit-source-id: bd33d5ed1df8552f04cb2f2238982c107d17d4b1
1 parent 95cbbe8 commit 7c5e8d4

File tree

2 files changed

+113
-35
lines changed

2 files changed

+113
-35
lines changed

requirements.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
torch
2-
torchshow
31
pyzmq
42
requests
53
numpy

setup.py

Lines changed: 113 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,54 +4,134 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import importlib.util
78
import os
8-
99
import shutil
1010
import subprocess
1111
import sys
1212
import sysconfig
1313

14-
import torch
15-
1614
from setuptools import Command, find_packages, setup
15+
from setuptools.command.build_ext import build_ext
16+
from setuptools.extension import Extension
1717

1818
from setuptools_rust import Binding, RustBin, RustExtension
19-
from torch.utils.cpp_extension import (
20-
BuildExtension,
21-
CppExtension,
22-
CUDA_HOME,
23-
include_paths as torch_include_paths,
24-
TORCH_LIB_PATH,
25-
)
19+
20+
21+
# Helper functions to find torch without importing it
22+
def find_torch_paths():
    """Locate the torch package's lib and include directories.

    The package is located with ``importlib.util.find_spec`` so that torch
    itself is never imported.

    Returns:
        dict: ``{"lib_path": str, "include_paths": list[str]}``.
        ``lib_path`` is ``<torch>/lib``. ``include_paths`` always starts with
        ``<torch>/include`` and additionally contains
        ``<torch>/include/torch/csrc/api/include`` when that directory exists.
        Neither ``lib_path`` nor the first include path is checked for
        existence.

    Raises:
        RuntimeError: if torch cannot be located.
    """
    try:
        spec = importlib.util.find_spec("torch")
    except (ImportError, ValueError) as err:
        # find_spec raises ValueError when sys.modules already holds a
        # "torch" entry whose __spec__ is None; report it like "not found".
        raise RuntimeError(
            "torch not found - please install PyTorch first"
        ) from err

    if not spec or not spec.origin:
        raise RuntimeError("torch not found - please install PyTorch first")

    base = os.path.dirname(spec.origin)
    lib_path = os.path.join(base, "lib")
    include_path = os.path.join(base, "include")

    # Get all include paths (similar to
    # torch.utils.cpp_extension.include_paths())
    include_paths = [include_path]

    # Add torch/csrc includes if available
    torch_csrc_include = os.path.join(
        include_path, "torch", "csrc", "api", "include"
    )
    if os.path.exists(torch_csrc_include):
        include_paths.append(torch_csrc_include)

    return {"lib_path": lib_path, "include_paths": include_paths}
41+
42+
43+
def find_cuda_home():
    """Find the CUDA installation directory without importing torch.

    Lookup order:
      1. ``CUDA_HOME`` or ``CUDA_PATH`` from the environment, if the path
         exists.
      2. The grandparent directory of the ``nvcc`` executable found on
         ``PATH`` (i.e. the directory containing ``bin/nvcc``).
      3. The first of ``/usr/local/cuda`` and ``/usr/cuda`` that exists.

    Returns:
        str | None: the CUDA root directory, or ``None`` when none is found.
    """
    # Check environment variable first
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home and os.path.exists(cuda_home):
        return cuda_home

    # Try to find nvcc. shutil.which searches PATH in-process, so this does
    # not depend on an external `which` binary being installed.
    nvcc = shutil.which("nvcc")
    if nvcc:
        # Get directory containing bin/nvcc
        return os.path.dirname(os.path.dirname(nvcc))

    # Check common locations
    for path in ("/usr/local/cuda", "/usr/cuda"):
        if os.path.exists(path):
            return path

    return None
70+
71+
72+
def detect_cxx11_abi(lib_path=None):
    """Guess whether torch was built with the C++11 ABI.

    Runs ``nm -D`` on the first usable torch shared library and looks for
    the ``__cxx11`` substring in its dynamic symbol table. This is a
    text-match heuristic, not a definitive ABI probe.

    Args:
        lib_path: directory holding the torch shared libraries. When
            ``None`` (the default) it is taken from ``find_torch_paths()``.

    Returns:
        int: ``1`` if ``__cxx11`` symbols were found (new ABI), ``0`` if a
        library was inspected and none were found (old ABI), and ``1`` as
        the default when no library could be inspected.
    """
    if lib_path is None:
        lib_path = find_torch_paths()["lib_path"]

    # Try to find a torch library to check
    for lib_name in ("libtorch_cpu.so", "libtorch.so", "libc10.so"):
        lib_file = os.path.join(lib_path, lib_name)
        if not os.path.exists(lib_file):
            continue
        try:
            result = subprocess.run(
                ["nm", "-D", lib_file],
                capture_output=True,
                text=True,
                timeout=10,
            )
        except (subprocess.TimeoutExpired, OSError):
            # nm missing, not executable, or too slow: try the next library
            # rather than aborting the build.
            continue
        if result.returncode == 0:
            # Check for __cxx11 namespace which indicates new ABI
            return 1 if "__cxx11" in result.stdout else 0

    # Default to new ABI if we can't determine
    return 1
96+
97+
98+
# Get torch paths and settings
99+
torch_paths = find_torch_paths()
100+
TORCH_LIB_PATH = torch_paths["lib_path"]
101+
torch_include_paths = torch_paths["include_paths"]
102+
CUDA_HOME = find_cuda_home()
103+
cxx11_abi = detect_cxx11_abi()
26104

27105
USE_CUDA = CUDA_HOME is not None
28106
USE_TENSOR_ENGINE = os.environ.get("USE_TENSOR_ENGINE", "1") == "1"
29107

30-
monarch_cpp_src = ["python/monarch/common/init.cpp"]
31108

32-
if USE_CUDA:
33-
monarch_cpp_src.append("python/monarch/common/mock_cuda.cpp")
109+
def create_torch_extension(name, sources):
    """Create a setuptools C++ ``Extension`` that links against torch.

    Args:
        name: dotted module name of the extension
            (e.g. ``"monarch.common._C"``).
        sources: list of C++ source file paths.

    Returns:
        setuptools.extension.Extension: configured with the module-level
        ``TORCH_LIB_PATH`` as library (and, off Windows, runtime library)
        directory and ``torch_include_paths`` appended to the include
        directories.
    """
    include_dirs = [os.path.dirname(os.path.abspath(__file__))]

    # get_config_var returns None when the variable is not defined for this
    # interpreter; a None entry in include_dirs would break the compile step.
    python_include = sysconfig.get_config_var("INCLUDEDIR")
    if python_include:
        include_dirs.append(python_include)

    include_dirs.extend(torch_include_paths)

    return Extension(
        name,
        sources,
        # torch's BuildExtension used to pass -std=c++17 implicitly; with
        # plain build_ext it has to be requested explicitly.
        extra_compile_args=["-g", "-O3", "-std=c++17"],
        libraries=["dl", "c10", "torch", "torch_cpu", "torch_python"],
        library_dirs=[TORCH_LIB_PATH],
        include_dirs=include_dirs,
        runtime_library_dirs=(
            [TORCH_LIB_PATH] if sys.platform != "win32" else []
        ),
    )
34124

35-
common_C = CppExtension(
36-
"monarch.common._C",
37-
monarch_cpp_src,
38-
extra_compile_args=["-g", "-O3"],
39-
libraries=["dl"],
40-
include_dirs=[
41-
os.path.dirname(os.path.abspath(__file__)),
42-
sysconfig.get_config_var("INCLUDEDIR"),
43-
],
44-
)
45125

126+
monarch_cpp_src = ["python/monarch/common/init.cpp"]
127+
if USE_CUDA:
128+
monarch_cpp_src.append("python/monarch/common/mock_cuda.cpp")
46129

47-
controller_C = CppExtension(
130+
# Create C++ extensions using standard Extension instead of CppExtension
131+
common_C = create_torch_extension("monarch.common._C", monarch_cpp_src)
132+
controller_C = create_torch_extension(
48133
"monarch.gradient._gradient_generator",
49134
["python/monarch/gradient/_gradient_generator.cpp"],
50-
extra_compile_args=["-g", "-O3"],
51-
include_dirs=[
52-
os.path.dirname(os.path.abspath(__file__)),
53-
sysconfig.get_config_var("INCLUDEDIR"),
54-
],
55135
)
56136

57137
ENABLE_MSG_LOGGING = (
@@ -64,13 +144,13 @@
64144

65145
os.environ.update(
66146
{
67-
"CXXFLAGS": f"-D_GLIBCXX_USE_CXX11_ABI={int(torch._C._GLIBCXX_USE_CXX11_ABI)}",
147+
"CXXFLAGS": f"-D_GLIBCXX_USE_CXX11_ABI={cxx11_abi}",
68148
"RUSTFLAGS": " ".join(
69149
["-Zthreads=16", ENABLE_MSG_LOGGING, ENABLE_TRACING_UNSTABLE]
70150
),
71151
"LIBTORCH_LIB": TORCH_LIB_PATH,
72-
"LIBTORCH_INCLUDE": ":".join(torch_include_paths()),
73-
"_GLIBCXX_USE_CXX11_ABI": str(int(torch._C._GLIBCXX_USE_CXX11_ABI)),
152+
"LIBTORCH_INCLUDE": ":".join(torch_include_paths),
153+
"_GLIBCXX_USE_CXX11_ABI": str(cxx11_abi),
74154
"TORCH_SYS_USE_PYTORCH_APIS": "0",
75155
}
76156
)
@@ -229,7 +309,7 @@ def run(self):
229309
},
230310
rust_extensions=rust_extensions,
231311
cmdclass={
232-
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
312+
"build_ext": build_ext,
233313
"clean": Clean,
234314
},
235315
)

0 commit comments

Comments
 (0)