Commit 230e8ce
No torch requirement

We now require torch only to build monarch, not to have it installed at runtime. (You still need it if you want to use tensors in the tensor worker.)

Differential Revision: [D88912346](https://our.internmc.facebook.com/intern/diff/D88912346/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D88912346/)!

ghstack-source-id: 328596707
Pull Request resolved: #2114
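For context, making torch optional at runtime usually amounts to deferring the import to the code paths that need it. A minimal sketch of that pattern, with a hypothetical helper name (this is an illustration, not monarch's actual code):

```python
# Hypothetical sketch of a lazy torch import; only tensor-worker code paths would call it.
def _require_torch():
    try:
        import torch
    except ImportError as exc:
        raise RuntimeError(
            "torch is not installed; it is only needed to use tensors in the tensor worker"
        ) from exc
    return torch
```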
1 parent 29e08e3 commit 230e8ce

File tree

2 files changed: +113 -35
requirements.txt

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-torch
-torchshow
 pyzmq
 requests
 numpy

setup.py

Lines changed: 113 additions & 33 deletions
@@ -4,54 +4,134 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import importlib.util
 import os
-
 import shutil
 import subprocess
 import sys
 import sysconfig
 
-import torch
-
 from setuptools import Command, find_packages, setup
+from setuptools.command.build_ext import build_ext
+from setuptools.extension import Extension
 
 from setuptools_rust import Binding, RustBin, RustExtension
-from torch.utils.cpp_extension import (
-    BuildExtension,
-    CppExtension,
-    CUDA_HOME,
-    include_paths as torch_include_paths,
-    TORCH_LIB_PATH,
-)
+
+
+# Helper functions to find torch without importing it
+def find_torch_paths():
+    """Find torch installation paths without importing torch"""
+    spec = importlib.util.find_spec("torch")
+    if not spec or not spec.origin:
+        raise RuntimeError("torch not found - please install PyTorch first")
+
+    base = os.path.dirname(spec.origin)
+    lib_path = os.path.join(base, "lib")
+    include_path = os.path.join(base, "include")
+
+    # Get all include paths (similar to torch.utils.cpp_extension.include_paths())
+    include_paths = [include_path]
+
+    # Add torch/csrc includes if available
+    torch_csrc_include = os.path.join(include_path, "torch", "csrc", "api", "include")
+    if os.path.exists(torch_csrc_include):
+        include_paths.append(torch_csrc_include)
+
+    return {"lib_path": lib_path, "include_paths": include_paths}
+
+
+def find_cuda_home():
+    """Find CUDA installation without importing torch"""
+    # Check environment variable first
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+
+    if cuda_home and os.path.exists(cuda_home):
+        return cuda_home
+
+    # Try to find nvcc
+    try:
+        nvcc_path = subprocess.run(
+            ["which", "nvcc"], capture_output=True, text=True, timeout=5
+        )
+        if nvcc_path.returncode == 0:
+            # Get directory containing bin/nvcc
+            nvcc = nvcc_path.stdout.strip()
+            cuda_home = os.path.dirname(os.path.dirname(nvcc))
+            return cuda_home
+    except (subprocess.TimeoutExpired, FileNotFoundError):
+        pass
+
+    # Check common locations
+    for path in ["/usr/local/cuda", "/usr/cuda"]:
+        if os.path.exists(path):
+            return path
+
+    return None
+
+
+def detect_cxx11_abi():
+    """Detect if torch uses C++11 ABI by examining library symbols"""
+    paths = find_torch_paths()
+    lib_path = paths["lib_path"]
+
+    # Try to find a torch library to check
+    for lib_name in ["libtorch_cpu.so", "libtorch.so", "libc10.so"]:
+        lib_file = os.path.join(lib_path, lib_name)
+        if os.path.exists(lib_file):
+            try:
+                result = subprocess.run(
+                    ["nm", "-D", lib_file], capture_output=True, text=True, timeout=10
+                )
+                if result.returncode == 0:
+                    # Check for __cxx11 namespace which indicates new ABI
+                    if "__cxx11" in result.stdout:
+                        return 1  # New ABI
+                    else:
+                        return 0  # Old ABI
+            except (subprocess.TimeoutExpired, FileNotFoundError):
+                pass
+
+    # Default to new ABI if we can't determine
+    return 1
+
+
+# Get torch paths and settings
+torch_paths = find_torch_paths()
+TORCH_LIB_PATH = torch_paths["lib_path"]
+torch_include_paths = torch_paths["include_paths"]
+CUDA_HOME = find_cuda_home()
+cxx11_abi = detect_cxx11_abi()
 
 USE_CUDA = CUDA_HOME is not None
 USE_TENSOR_ENGINE = os.environ.get("USE_TENSOR_ENGINE", "1") == "1"
 
-monarch_cpp_src = ["python/monarch/common/init.cpp"]
 
-if USE_CUDA:
-    monarch_cpp_src.append("python/monarch/common/mock_cuda.cpp")
+def create_torch_extension(name, sources):
+    """Helper to create a C++ extension with torch dependencies"""
+    return Extension(
+        name,
+        sources,
+        extra_compile_args=["-g", "-O3"],
+        libraries=["dl", "c10", "torch", "torch_cpu", "torch_python"],
+        library_dirs=[TORCH_LIB_PATH],
+        include_dirs=[
+            os.path.dirname(os.path.abspath(__file__)),
+            sysconfig.get_config_var("INCLUDEDIR"),
+        ]
+        + torch_include_paths,
+        runtime_library_dirs=[TORCH_LIB_PATH] if sys.platform != "win32" else [],
+    )
 
-common_C = CppExtension(
-    "monarch.common._C",
-    monarch_cpp_src,
-    extra_compile_args=["-g", "-O3"],
-    libraries=["dl"],
-    include_dirs=[
-        os.path.dirname(os.path.abspath(__file__)),
-        sysconfig.get_config_var("INCLUDEDIR"),
-    ],
-)
 
+monarch_cpp_src = ["python/monarch/common/init.cpp"]
+if USE_CUDA:
+    monarch_cpp_src.append("python/monarch/common/mock_cuda.cpp")
 
-controller_C = CppExtension(
+# Create C++ extensions using standard Extension instead of CppExtension
+common_C = create_torch_extension("monarch.common._C", monarch_cpp_src)
+controller_C = create_torch_extension(
     "monarch.gradient._gradient_generator",
     ["python/monarch/gradient/_gradient_generator.cpp"],
-    extra_compile_args=["-g", "-O3"],
-    include_dirs=[
-        os.path.dirname(os.path.abspath(__file__)),
-        sysconfig.get_config_var("INCLUDEDIR"),
-    ],
 )
 
 ENABLE_MSG_LOGGING = (
@@ -64,13 +144,13 @@
 
 os.environ.update(
     {
-        "CXXFLAGS": f"-D_GLIBCXX_USE_CXX11_ABI={int(torch._C._GLIBCXX_USE_CXX11_ABI)}",
+        "CXXFLAGS": f"-D_GLIBCXX_USE_CXX11_ABI={cxx11_abi}",
         "RUSTFLAGS": " ".join(
             ["-Zthreads=16", ENABLE_MSG_LOGGING, ENABLE_TRACING_UNSTABLE]
         ),
         "LIBTORCH_LIB": TORCH_LIB_PATH,
-        "LIBTORCH_INCLUDE": ":".join(torch_include_paths()),
-        "_GLIBCXX_USE_CXX11_ABI": str(int(torch._C._GLIBCXX_USE_CXX11_ABI)),
+        "LIBTORCH_INCLUDE": ":".join(torch_include_paths),
+        "_GLIBCXX_USE_CXX11_ABI": str(cxx11_abi),
         "TORCH_SYS_USE_PYTORCH_APIS": "0",
     }
)
@@ -229,7 +309,7 @@ def run(self):
     },
     rust_extensions=rust_extensions,
     cmdclass={
-        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
+        "build_ext": build_ext,
         "clean": Clean,
     },
 )
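A quick way to sanity-check the new import-free detection (not part of the commit) is to compare it against torch's own report on a machine where torch is installed. The sketch below only uses names that appear in this diff (`TORCH_LIB_PATH`, `torch._C._GLIBCXX_USE_CXX11_ABI`) plus the standard library:

```python
# Illustrative check: compare the import-free detection in setup.py against
# what torch itself reports, on a machine where torch is installed.
import importlib.util
import os

import torch
from torch.utils.cpp_extension import TORCH_LIB_PATH

spec = importlib.util.find_spec("torch")
torch_dir = os.path.dirname(spec.origin)

# The lib dir derived without importing torch should be torch's own lib dir.
assert os.path.samefile(os.path.join(torch_dir, "lib"), TORCH_LIB_PATH)

# The nm-based probe in detect_cxx11_abi() should agree with torch's reported flag.
print("torch reports _GLIBCXX_USE_CXX11_ABI =", int(torch._C._GLIBCXX_USE_CXX11_ABI))
```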
