Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
torch
torchshow
pyzmq
requests
numpy
Expand Down
146 changes: 113 additions & 33 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,134 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os

import shutil
import subprocess
import sys
import sysconfig

import torch

from setuptools import Command, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.extension import Extension

from setuptools_rust import Binding, RustBin, RustExtension
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
CUDA_HOME,
include_paths as torch_include_paths,
TORCH_LIB_PATH,
)


# Helper functions to find torch without importing it
def find_torch_paths():
"""Find torch installation paths without importing torch"""
spec = importlib.util.find_spec("torch")
if not spec or not spec.origin:
raise RuntimeError("torch not found - please install PyTorch first")

base = os.path.dirname(spec.origin)
lib_path = os.path.join(base, "lib")
include_path = os.path.join(base, "include")

# Get all include paths (similar to torch.utils.cpp_extension.include_paths())
include_paths = [include_path]

# Add torch/csrc includes if available
torch_csrc_include = os.path.join(include_path, "torch", "csrc", "api", "include")
if os.path.exists(torch_csrc_include):
include_paths.append(torch_csrc_include)

return {"lib_path": lib_path, "include_paths": include_paths}


def find_cuda_home():
"""Find CUDA installation without importing torch"""
# Check environment variable first
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")

if cuda_home and os.path.exists(cuda_home):
return cuda_home

# Try to find nvcc
try:
nvcc_path = subprocess.run(
["which", "nvcc"], capture_output=True, text=True, timeout=5
)
if nvcc_path.returncode == 0:
# Get directory containing bin/nvcc
nvcc = nvcc_path.stdout.strip()
cuda_home = os.path.dirname(os.path.dirname(nvcc))
return cuda_home
except (subprocess.TimeoutExpired, FileNotFoundError):
pass

# Check common locations
for path in ["/usr/local/cuda", "/usr/cuda"]:
if os.path.exists(path):
return path

return None


def detect_cxx11_abi():
"""Detect if torch uses C++11 ABI by examining library symbols"""
paths = find_torch_paths()
lib_path = paths["lib_path"]

# Try to find a torch library to check
for lib_name in ["libtorch_cpu.so", "libtorch.so", "libc10.so"]:
lib_file = os.path.join(lib_path, lib_name)
if os.path.exists(lib_file):
try:
result = subprocess.run(
["nm", "-D", lib_file], capture_output=True, text=True, timeout=10
)
if result.returncode == 0:
# Check for __cxx11 namespace which indicates new ABI
if "__cxx11" in result.stdout:
return 1 # New ABI
else:
return 0 # Old ABI
except (subprocess.TimeoutExpired, FileNotFoundError):
pass

# Default to new ABI if we can't determine
return 1


# Get torch paths and settings
torch_paths = find_torch_paths()
TORCH_LIB_PATH = torch_paths["lib_path"]
torch_include_paths = torch_paths["include_paths"]
CUDA_HOME = find_cuda_home()
cxx11_abi = detect_cxx11_abi()

USE_CUDA = CUDA_HOME is not None
USE_TENSOR_ENGINE = os.environ.get("USE_TENSOR_ENGINE", "1") == "1"

monarch_cpp_src = ["python/monarch/common/init.cpp"]

if USE_CUDA:
monarch_cpp_src.append("python/monarch/common/mock_cuda.cpp")
def create_torch_extension(name, sources):
"""Helper to create a C++ extension with torch dependencies"""
return Extension(
name,
sources,
extra_compile_args=["-g", "-O3"],
libraries=["dl", "c10", "torch", "torch_cpu", "torch_python"],
library_dirs=[TORCH_LIB_PATH],
include_dirs=[
os.path.dirname(os.path.abspath(__file__)),
sysconfig.get_config_var("INCLUDEDIR"),
]
+ torch_include_paths,
runtime_library_dirs=[TORCH_LIB_PATH] if sys.platform != "win32" else [],
)

common_C = CppExtension(
"monarch.common._C",
monarch_cpp_src,
extra_compile_args=["-g", "-O3"],
libraries=["dl"],
include_dirs=[
os.path.dirname(os.path.abspath(__file__)),
sysconfig.get_config_var("INCLUDEDIR"),
],
)

monarch_cpp_src = ["python/monarch/common/init.cpp"]
if USE_CUDA:
monarch_cpp_src.append("python/monarch/common/mock_cuda.cpp")

controller_C = CppExtension(
# Create C++ extensions using standard Extension instead of CppExtension
common_C = create_torch_extension("monarch.common._C", monarch_cpp_src)
controller_C = create_torch_extension(
"monarch.gradient._gradient_generator",
["python/monarch/gradient/_gradient_generator.cpp"],
extra_compile_args=["-g", "-O3"],
include_dirs=[
os.path.dirname(os.path.abspath(__file__)),
sysconfig.get_config_var("INCLUDEDIR"),
],
)

ENABLE_MSG_LOGGING = (
Expand All @@ -64,13 +144,13 @@

os.environ.update(
{
"CXXFLAGS": f"-D_GLIBCXX_USE_CXX11_ABI={int(torch._C._GLIBCXX_USE_CXX11_ABI)}",
"CXXFLAGS": f"-D_GLIBCXX_USE_CXX11_ABI={cxx11_abi}",
"RUSTFLAGS": " ".join(
["-Zthreads=16", ENABLE_MSG_LOGGING, ENABLE_TRACING_UNSTABLE]
),
"LIBTORCH_LIB": TORCH_LIB_PATH,
"LIBTORCH_INCLUDE": ":".join(torch_include_paths()),
"_GLIBCXX_USE_CXX11_ABI": str(int(torch._C._GLIBCXX_USE_CXX11_ABI)),
"LIBTORCH_INCLUDE": ":".join(torch_include_paths),
"_GLIBCXX_USE_CXX11_ABI": str(cxx11_abi),
"TORCH_SYS_USE_PYTORCH_APIS": "0",
}
)
Expand Down Expand Up @@ -229,7 +309,7 @@ def run(self):
},
rust_extensions=rust_extensions,
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"build_ext": build_ext,
"clean": Clean,
},
)