Skip to content

Commit df554de

Browse files
committed
Add OpenMP compile/link flags to setup.py for source builds
Source builds of torchvision do not pass -fopenmp (compile) or -lomp/-lgomp (link) flags when building the _C extension. Since at::parallel_for is a header-only template whose #pragma omp directives are compiled into the calling translation unit (_C.so), the missing flags cause it to silently fall back to sequential execution.

This has had no observable effect so far because no existing torchvision C++ kernel directly uses at::parallel_for or #pragma omp. However, upcoming changes (e.g. pytorch#9442) introduce at::parallel_for, and without these flags source builds get 0% speedup from parallelization.

- macOS: -Xpreprocessor -fopenmp (compile) + -lomp from PyTorch's bundled libomp (link)
- Linux: -fopenmp (compile) + -lgomp (link)
- Windows: unchanged (uses /openmp via MSVC, already handled separately)

Fixes pytorch#2783

Signed-off-by: Yonghye Kwon <developer.0hye@gmail.com>
1 parent 8a5946e commit df554de

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

setup.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ def get_macros_and_flags():
131131
if sysconfig.get_config_var("Py_GIL_DISABLED"):
132132
extra_compile_args["cxx"].append("-DPy_GIL_DISABLED")
133133

134+
if sys.platform == "darwin":
135+
extra_compile_args["cxx"].append("-Xpreprocessor")
136+
extra_compile_args["cxx"].append("-fopenmp")
137+
elif sys.platform != "win32":
138+
extra_compile_args["cxx"].append("-fopenmp")
139+
134140
if DEBUG:
135141
extra_compile_args["cxx"].append("-g")
136142
extra_compile_args["cxx"].append("-O0")
@@ -182,12 +188,22 @@ def make_C_extension():
182188
sources += mps_sources
183189

184190
define_macros, extra_compile_args = get_macros_and_flags()
191+
192+
extra_link_args = []
193+
if sys.platform == "darwin":
194+
# Link against libomp shipped with PyTorch for at::parallel_for support
195+
torch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
196+
extra_link_args = [f"-L{torch_lib_dir}", "-lomp"]
197+
elif sys.platform != "win32":
198+
extra_link_args = ["-lgomp"]
199+
185200
return Extension(
186201
name="torchvision._C",
187202
sources=sorted(str(s) for s in sources),
188203
include_dirs=[CSRS_DIR],
189204
define_macros=define_macros,
190205
extra_compile_args=extra_compile_args,
206+
extra_link_args=extra_link_args,
191207
)
192208

193209

0 commit comments

Comments
 (0)