Skip to content

Commit 4b70da6

Browse files
allenwang28 authored and facebook-github-bot committed
Add ARM workflow for GB200 support (#2029)
Summary: As title - currently pip install from a GB200 machine won't work as we don't publish an ARM build to PyPI. Differential Revision: D88202637
1 parent fb0cf11 commit 4b70da6

File tree

3 files changed

+181
-18
lines changed

3 files changed

+181
-18
lines changed

.github/workflows/publish_release.yml

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,38 @@ concurrency:
1212
cancel-in-progress: true
1313
jobs:
1414
build:
15-
name: cuda12.6-py${{ matrix.python-version }}-${{ matrix.name }}
15+
name: ${{ matrix.name }}-py${{ matrix.python-version }}
1616
strategy:
17-
fail-fast: false # Changed to false to see results from all Python versions
17+
fail-fast: false
1818
matrix:
1919
# TODO add 3.14 once we figure out py03 issue
2020
python-version: ["3.10", "3.11", "3.12", "3.13"]
2121
include:
22-
- name: 4xlarge
23-
runs-on: linux.g5.4xlarge.nvidia.gpu
22+
# x86_64 CUDA builds
23+
- name: cuda12.8-x86_64
24+
runner: linux.g5.4xlarge.nvidia.gpu
2425
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu128'
2526
gpu-arch-type: "cuda"
2627
gpu-arch-version: "12.8"
28+
docker-image: "pytorch/almalinux-builder" # Uses default, becomes pytorch/almalinux-builder:cuda12.8
29+
platform-tag: "manylinux2014_x86_64"
30+
# aarch64 CUDA builds
31+
- name: cuda12.8-aarch64
32+
runner: linux.arm64.r7g.12xlarge.memory # GPU-enabled ARM runner like PyTorch uses
33+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu128'
34+
gpu-arch-type: "cpu" # Use "cpu" to skip nvidia driver install, CUDA libs are in Docker image
35+
gpu-arch-version: ""
36+
docker-image: "pytorch/manylinuxaarch64-builder:cuda12.8" # ARM-specific image with CUDA
37+
platform-tag: "manylinux2014_aarch64"
2738
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
2839
with:
2940
timeout: 60
30-
runner: ${{ matrix.runs-on }}
41+
runner: ${{ matrix.runner }}
3142
gpu-arch-type: ${{ matrix.gpu-arch-type }}
3243
gpu-arch-version: ${{ matrix.gpu-arch-version }}
44+
docker-image: ${{ matrix.docker-image }}
3345
submodules: recursive
34-
upload-artifact: monarch-${{ matrix.python-version }}-${{ matrix.gpu-arch-type }}${{ matrix.gpu-arch-version }}
46+
upload-artifact: monarch-${{ matrix.python-version }}-${{ matrix.name }}
3547
script: |
3648
source scripts/common-setup.sh
3749
setup_build_environment ${{ matrix.python-version }}
@@ -52,9 +64,8 @@ jobs:
5264
export MONARCH_VERSION="${{ github.event.inputs.version }}"
5365
python setup.py bdist_wheel
5466
55-
# hacky until the right distribution wheel can be made...
56-
find dist -name "*linux_x86_64.whl" -type f -exec bash -c 'mv "$1" "${1/linux_x86_64.whl/manylinux2014_x86_64.whl}"' _ {} \;
57-
ls -la dist/
67+
# Properly retag wheel with manylinux platform tag
68+
retag_wheel_platform "${{ matrix.platform-tag }}"
5869
5970
# Run tests
6071
install_python_test_dependencies

.github/workflows/wheels.yml

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,37 @@ concurrency:
1313
cancel-in-progress: true
1414
jobs:
1515
build:
16-
name: cuda12.6-py${{ matrix.python-version }}-${{ matrix.name }}
16+
name: ${{ matrix.name }}-py${{ matrix.python-version }}
1717
strategy:
18-
fail-fast: false # Changed to false to see results from all Python versions
18+
fail-fast: false
1919
matrix:
2020
python-version: ["3.10", "3.11", "3.12", "3.13"]
2121
include:
22-
- name: 4xlarge
23-
runs-on: linux.g5.4xlarge.nvidia.gpu
22+
# x86_64 CUDA builds
23+
- name: cuda12.6-x86_64
24+
runner: linux.g5.4xlarge.nvidia.gpu
2425
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
2526
gpu-arch-type: "cuda"
2627
gpu-arch-version: "12.6"
28+
docker-image: "pytorch/almalinux-builder" # Uses default, becomes pytorch/almalinux-builder:cuda12.6
29+
platform-tag: "manylinux2014_x86_64"
30+
# aarch64 CUDA builds
31+
- name: cuda12.6-aarch64
32+
runner: linux.arm64.r7g.12xlarge.memory # GPU-enabled ARM runner like PyTorch uses
33+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126'
34+
gpu-arch-type: "cpu" # Use "cpu" to skip nvidia driver install, CUDA libs are in Docker image
35+
gpu-arch-version: ""
36+
docker-image: "pytorch/manylinuxaarch64-builder:cuda12.6" # ARM-specific image with CUDA
37+
platform-tag: "manylinux2014_aarch64"
2738
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
2839
with:
2940
timeout: 60
30-
runner: ${{ matrix.runs-on }}
41+
runner: ${{ matrix.runner }}
3142
gpu-arch-type: ${{ matrix.gpu-arch-type }}
3243
gpu-arch-version: ${{ matrix.gpu-arch-version }}
44+
docker-image: ${{ matrix.docker-image }}
3345
submodules: recursive
34-
upload-artifact: monarch-${{ matrix.python-version }}-${{ matrix.gpu-arch-type }}${{ matrix.gpu-arch-version }}
46+
upload-artifact: monarch-${{ matrix.python-version }}-${{ matrix.name }}
3547
script: |
3648
source scripts/common-setup.sh
3749
setup_build_environment ${{ matrix.python-version }}
@@ -53,9 +65,8 @@ jobs:
5365
5466
python setup.py bdist_wheel
5567
56-
# hacky until the right distribution wheel can be made...
57-
find dist -name "*linux_x86_64.whl" -type f -exec bash -c 'mv "$1" "${1/linux_x86_64.whl/manylinux2014_x86_64.whl}"' _ {} \;
58-
ls -la dist/
68+
# Properly retag wheel with manylinux platform tag
69+
retag_wheel_platform "${{ matrix.platform-tag }}"
5970
6071
# Run tests
6172
install_python_test_dependencies

scripts/common-setup.sh

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ set -ex
1212
# Setup conda environment. Defaults to Python 3.10.
#
# Inside a manylinux container (detected via /opt/python and a missing
# conda) this delegates to setup_manylinux_python instead.
#
# Arguments:
#   $1 - Python version, e.g. "3.12" (default: 3.10)
setup_conda_environment() {
    local python_version=${1:-3.10}

    # Check if we're in a manylinux container (no conda).
    # Use `command -v` directly as the existence test: piping its output
    # into `[[ -x … ]]` breaks when conda is a shell function or alias,
    # since `command -v` then prints a bare name, not an executable path.
    if [[ -d /opt/python ]] && ! command -v conda >/dev/null 2>&1; then
        echo "Detected manylinux environment, using system Python instead of conda..."
        setup_manylinux_python "${python_version}"
        return
    fi

    echo "Setting up conda environment with Python ${python_version}..."
    conda create -n venv python="${python_version}" -y
    conda activate venv
    python -m pip install --upgrade pip
}
2230

31+
# Setup Python in manylinux container (no conda available)
#
# Maps a "3.Y" CPython version to the interpreter bundled under
# /opt/python (e.g. 3.12 -> /opt/python/cp312-cp312) and puts it first
# on PATH.
#
# Globals:
#   PYTHON_DIR (written) - selected interpreter prefix
#   PYTHON_BIN (written, exported) - path to the python binary
#   PATH       (written, exported)
# Arguments:
#   $1 - Python version, e.g. "3.12" (default: 3.10)
# Returns:
#   0 on success; 1 on a malformed version or a missing interpreter dir
setup_manylinux_python() {
    local python_version=${1:-3.10}
    echo "Setting up manylinux Python ${python_version}..."

    # Validate the version, then derive the cpXY directory from it.
    # This generalizes the previous hard-coded 3.10-3.13 table so newer
    # CPython releases (e.g. 3.14) work without further edits.
    if [[ ! "${python_version}" =~ ^3\.[0-9]+$ ]]; then
        echo "Unsupported Python version: ${python_version}"
        return 1
    fi
    local cp_tag="cp${python_version//./}"
    PYTHON_DIR="/opt/python/${cp_tag}-${cp_tag}"

    if [[ ! -d "${PYTHON_DIR}" ]]; then
        echo "ERROR: Python directory ${PYTHON_DIR} not found"
        return 1
    fi

    export PATH="${PYTHON_DIR}/bin:${PATH}"
    export PYTHON_BIN="${PYTHON_DIR}/bin/python"

    echo "Using Python from: ${PYTHON_DIR}"
    python --version
    python -m pip install --upgrade pip
}
57+
2358
# Install system-level dependencies
2459
install_system_dependencies() {
2560
echo "Installing system dependencies..."
@@ -227,3 +262,109 @@ run_test_groups() {
227262
fi
228263
set -e
229264
}
265+
266+
# Retag wheels with manylinux platform tags
#
# When building wheels with `setup.py bdist_wheel`, the wheel is tagged with the
# generic platform tag (e.g., "linux_x86_64", "linux_aarch64"). However, PyPI
# requires proper manylinux tags (e.g., "manylinux2014_x86_64") to indicate
# compatibility with the manylinux standard. Simply renaming the .whl file is
# insufficient because the platform tag is also stored in the WHEEL metadata file
# inside the wheel archive.
#
# This function properly retags wheels by:
#   1. Unpacking the wheel archive
#   2. Rewriting only the platform portion of the "Tag:" lines in the
#      .dist-info/WHEEL metadata file (interpreter/ABI tags are preserved)
#   3. Repacking the wheel with the updated metadata and correct filename
#
# The `wheel pack` command automatically regenerates the RECORD file with updated
# hashes, ensuring wheel integrity. This is similar to how PyTorch does it in their
# manywheel build scripts (see pytorch/.ci/manywheel/build_common.sh).
#
# Usage: retag_wheel_platform <platform_tag> [wheel_dir]
#   platform_tag: Target platform (e.g., "manylinux2014_x86_64", "manylinux2014_aarch64")
#   wheel_dir:    Directory containing wheels (defaults to "dist")
#
# Example:
#   retag_wheel_platform "manylinux2014_x86_64"
#   retag_wheel_platform "manylinux2014_aarch64" "build/wheels"
retag_wheel_platform() {
    local platform_tag="${1}"
    local wheel_dir="${2:-dist}"

    if [[ -z "$platform_tag" ]]; then
        echo "Error: platform_tag is required"
        echo "Usage: retag_wheel_platform <platform_tag> [wheel_dir]"
        return 1
    fi

    if [[ ! -d "$wheel_dir" ]]; then
        echo "Error: wheel directory '$wheel_dir' does not exist"
        return 1
    fi

    echo "Retagging wheels in '$wheel_dir' with platform tag: $platform_tag"

    # Install wheel tool if not present
    pip install -q wheel

    local wheel_count=0
    local whl base name_ver whl_dir wheel_file
    for whl in "$wheel_dir"/*.whl; do
        # An unmatched glob leaves the literal pattern in $whl; skip it.
        [[ -f "$whl" ]] || continue

        wheel_count=$((wheel_count + 1))
        echo "  Processing: $(basename "$whl")"

        # `wheel unpack` extracts into <name>-<version>/, NOT a directory
        # named after the full wheel basename. Per the wheel filename
        # spec the distribution name is normalized (dash-free), so the
        # first two dash-separated fields are exactly name and version.
        # (Searching for the full basename — as before — never matched,
        # so the metadata was never edited and the wheel was deleted.)
        base=$(basename "$whl" .whl)
        name_ver=$(printf '%s' "$base" | cut -d- -f1-2)
        wheel unpack "$whl" -d "$wheel_dir"
        whl_dir="$wheel_dir/$name_ver"

        if [[ ! -d "$whl_dir" ]]; then
            echo "Error: expected unpack directory '$whl_dir' not found" >&2
            return 1
        fi

        wheel_file=$(find "$whl_dir" -name "WHEEL" -type f -print -quit)
        if [[ ! -f "$wheel_file" ]]; then
            echo "Error: WHEEL metadata file not found in unpacked wheel" >&2
            rm -rf -- "$whl_dir"
            return 1
        fi

        # Already platform-tagged (no generic linux_* tag)? Leave it alone
        # so we never delete a wheel and replace it with itself.
        if ! grep -q -- '-linux_' "$wheel_file"; then
            echo "  Tag already platform-specific; skipping $(basename "$whl")"
            rm -rf -- "$whl_dir"
            continue
        fi

        # Rewrite only the platform portion of each Tag line, keeping the
        # interpreter/ABI tags (e.g. cp310-cp310) intact. Forcing a
        # py3-none tag here would wrongly claim a compiled wheel works on
        # every Python and make the per-version wheel filenames collide.
        echo "  Updating WHEEL metadata: $wheel_file"
        sed -i 's/^\(Tag: .*\)-linux_[^-]*$/\1-'"$platform_tag"'/' "$wheel_file"

        # Repack: `wheel pack` derives the new filename from the updated
        # tags and regenerates RECORD hashes.
        echo "  Repacking wheel..."
        wheel pack "$whl_dir" -d "$wheel_dir" >/dev/null

        # Clean up, and remove the original (mis-tagged) wheel only after
        # a successful repack so a failure never leaves zero wheels.
        rm -rf -- "$whl_dir"
        rm -- "$whl"
        echo "  ✓ Retagged: $(basename "$whl")"
    done

    if [[ $wheel_count -eq 0 ]]; then
        echo "Warning: No wheels found in '$wheel_dir'"
        return 1
    fi

    echo "✓ Successfully retagged $wheel_count wheel(s)"
    echo "Final wheels:"
    ls -lh "$wheel_dir"/*.whl
}
370+

0 commit comments

Comments
 (0)