diff --git a/setup.py b/setup.py index 96745f3..d270768 100644 --- a/setup.py +++ b/setup.py @@ -202,6 +202,8 @@ "-libverbs", ] ) + arch_env = os.environ["PYTORCH_ROCM_ARCH"] + extra_link_args.extend([f"--offload-arch={arch}" for arch in arch_env.split(";")]) if enable_mpi: extra_link_args.extend( [