Skip to content

Commit 39366a5

Browse files
authored
[HIPIFY][test][infra] Enable Linux Rock CI PSDB for amd-staging (#2294)
1 parent 562192a commit 39366a5

27 files changed

Lines changed: 4399 additions & 0 deletions
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
name: Build Portable Linux JAX Wheels

# Triggers.
#   * workflow_call     - invoked as a reusable workflow by another workflow.
#   * workflow_dispatch - started manually; offers defaults for a "dev" build.
# `repository` and `ref` exist only for workflow_call; jobs that consume them
# fall back to the current repository / triggering ref when they are empty.
on:
  workflow_call:
    inputs:
      amdgpu_family:
        required: true
        type: string
      python_version:
        required: true
        type: string
      release_type:
        description: The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!
        required: true
        type: string
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        required: true
        type: string
      s3_staging_subdir:
        description: S3 staging subdirectory, not including the GPU-family
        required: true
        type: string
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
      cloudfront_url:
        description: CloudFront URL pointing to Python index
        required: true
        type: string
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        required: true
        type: string
      repository:
        description: "Repository to checkout. Defaults to `ROCm/TheRock`."
        type: string
        default: "ROCm/TheRock"
      ref:
        description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
        type: string
      # Mirrors the workflow_dispatch input of the same name. Without this
      # declaration `inputs.jax_ref` is empty for reusable-workflow callers,
      # yet the value is forwarded to the test workflow.
      jax_ref:
        description: rocm-jax repository ref/branch to check out
        type: string
        default: rocm-jaxlib-v0.8.0
  workflow_dispatch:
    inputs:
      amdgpu_family:
        type: choice
        options:
          - gfx101X-dgpu
          - gfx103X-dgpu
          - gfx110X-all
          - gfx1150
          - gfx1151
          - gfx120X-all
          - gfx90X-dcgpu
          - gfx94X-dcgpu
          - gfx950-dcgpu
        default: gfx94X-dcgpu
      python_version:
        required: true
        type: string
        # Quoted so the version stays a string (an unquoted 3.10 would be
        # read as the float 3.1).
        default: "3.12"
      release_type:
        type: choice
        description: Type of release to create. All developer-triggered jobs should use "dev"!
        options:
          - dev
          - nightly
          - prerelease
        default: dev
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      s3_staging_subdir:
        description: S3 staging subdirectory, not including the GPU-family
        type: string
        default: "v2-staging"
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
      cloudfront_url:
        description: CloudFront base URL pointing to Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2"
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2-staging"
      jax_ref:
        description: rocm-jax repository ref/branch to check out
        type: string
        default: rocm-jaxlib-v0.8.0
98+
99+
# Workflow-wide token permissions: read-only repository contents plus the
# ability to request an OIDC token (presumably consumed by the
# "Configure AWS Credentials" steps, which assume an IAM role - confirm).
permissions:
  contents: read
  id-token: write

# Title shown for each run in the Actions UI.
run-name: "Build Linux JAX Wheels (${{ inputs.amdgpu_family }}, ${{ inputs.python_version }}, ${{ inputs.release_type }})"
104+
105+
jobs:
106+
# Entry under the top-level `jobs:` mapping.
# Builds the JAX wheels from rocm/rocm-jax, uploads them to the *staging*
# S3 prefix and regenerates the staging package index.
build_jax_wheels:
  name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }}
  # Self-hosted scale set inside the ROCm org, GitHub-hosted runner elsewhere.
  runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
  strategy:
    matrix:
      # NOTE(review): the checkout below uses this matrix value, not
      # `inputs.jax_ref`; confirm that ignoring the input here is intended.
      jax_ref: [rocm-jaxlib-v0.8.0]
  env:
    PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
    S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
  outputs:
    # `cp_version` is read from the job environment; presumably exported via
    # GITHUB_ENV by python_to_cp_version.py in "Select Python version" - confirm.
    cp_version: ${{ env.cp_version }}
    jax_version: ${{ steps.extract_jax_version.outputs.jax_version }}
  steps:
    # NOTE(review): unlike the other jobs in this workflow, this checkout does
    # not pass `inputs.repository` / `inputs.ref` - confirm that is intended.
    - name: Checkout TheRock
      uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1

    - name: Checkout JAX
      uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
      with:
        path: jax
        repository: rocm/rocm-jax
        ref: ${{ matrix.jax_ref }}

    - name: Configure Git Identity
      run: |
        git config --global user.name "therockbot"
        git config --global user.email "therockbot@amd.com"

    - name: "Setting up Python"
      uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
      with:
        python-version: ${{ inputs.python_version }}

    - name: Select Python version
      env:
        # Inputs are passed through the environment instead of being expanded
        # into the script text, so their contents cannot alter the shell code.
        CI_PYTHON_VERSION: ${{ inputs.python_version }}
      run: |
        python build_tools/github_actions/python_to_cp_version.py \
          --python-version "${CI_PYTHON_VERSION}"

    - name: Build JAX Wheels
      env:
        CI_PYTHON_VERSION: ${{ inputs.python_version }}
        CI_THEROCK_TAR_URL: ${{ inputs.tar_url }}
        ROCM_VERSION: ${{ inputs.rocm_version }}
      run: |
        ls -lah
        pushd jax
        python3 build/ci_build \
          --compiler=clang \
          --python-versions="${CI_PYTHON_VERSION}" \
          --rocm-version="${ROCM_VERSION}" \
          --therock-path="${CI_THEROCK_TAR_URL}" \
          dist_wheels

    - name: Extract JAX version
      id: extract_jax_version
      run: |
        # Extract the JAX version from requirements.txt (e.g. "jax==0.8.0").
        #  1. Strip all spaces to simplify parsing.
        #  2. Keep only the matched text (-o) of lines starting with "jax=="
        #     or "jaxlib=="; [^#]+ ends the match before an inline comment.
        #     Without -o grep prints the whole line, comment included.
        #  3. Take the first hit and split on '=': the version is field 3.
        JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \
          | grep -oE '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3)
        echo "jax_version=$JAX_VERSION" >> "$GITHUB_OUTPUT"

    # The two AWS setup steps run even when an earlier step failed.
    - name: Install AWS CLI
      if: always()
      run: bash ./dockerfiles/install_awscli.sh

    - name: Configure AWS Credentials
      if: always()
      uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
      with:
        aws-region: us-east-2
        role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

    - name: Upload wheels to S3
      if: ${{ github.repository_owner == 'ROCm' }}
      env:
        S3_STAGING_PREFIX: ${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}
      run: |
        # Only *.whl files from the wheelhouse are copied to the staging prefix.
        aws s3 cp "${PACKAGE_DIST_DIR}/" "s3://${S3_BUCKET_PY}/${S3_STAGING_PREFIX}/" \
          --recursive --exclude "*" --include "*.whl"

    - name: (Re-)Generate Python package release index
      if: ${{ github.repository_owner == 'ROCm' }}
      env:
        S3_STAGING_PREFIX: ${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}
      run: |
        python3 -m venv .venv
        source .venv/bin/activate
        pip3 install boto3 packaging
        python3 ./build_tools/third_party/s3_management/manage.py "${S3_STAGING_PREFIX}"
193+
194+
# Entry under the top-level `jobs:` mapping.
# Runs configure_target_run.py and republishes two outputs of that step
# (`test-runs-on`, `bypass_tests_for_releases`) as job outputs.
generate_target_to_run:
  name: Generate target_to_run
  runs-on: ubuntu-24.04
  outputs:
    bypass_tests_for_releases: ${{ steps.configure.outputs.bypass_tests_for_releases }}
    # Note the step output is spelled with hyphens, the job output with
    # underscores.
    test_runs_on: ${{ steps.configure.outputs.test-runs-on }}
  steps:
    - name: Checking out repository
      uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
      with:
        # Both fall back when the input is empty (e.g. on workflow_dispatch,
        # which does not declare these inputs).
        ref: ${{ inputs.ref || '' }}
        repository: ${{ inputs.repository || github.repository }}

    - name: Generating target to run
      id: configure
      run: python ./build_tools/github_actions/configure_target_run.py
      env:
        PLATFORM: "linux"
        TARGET: ${{ inputs.amdgpu_family }}
        LOAD_TEST_RUNNERS_FROM_VAR: false
        # Variable comes from ROCm organization variable 'ROCM_THEROCK_TEST_RUNNERS'
        ROCM_THEROCK_TEST_RUNNERS: ${{ vars.ROCM_THEROCK_TEST_RUNNERS }}
216+
217+
# Entry under the top-level `jobs:` mapping.
# Delegates to the reusable test workflow once the wheels are built and the
# test runner label is known. The package index passed in is the *staging*
# CloudFront URL.
test_jax_wheels:
  name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }}
  needs:
    - build_jax_wheels
    - generate_target_to_run
  permissions:
    contents: read
    packages: read
  uses: ./.github/workflows/test_linux_jax_wheels.yml
  with:
    # What to test.
    amdgpu_family: ${{ inputs.amdgpu_family }}
    python_version: ${{ inputs.python_version }}
    rocm_version: ${{ inputs.rocm_version }}
    release_type: ${{ inputs.release_type }}
    jax_ref: ${{ inputs.jax_ref }}
    # Where the artifacts come from.
    package_index_url: ${{ inputs.cloudfront_staging_url }}
    s3_subdir: ${{ inputs.s3_subdir }}
    tar_url: ${{ inputs.tar_url }}
    # Which sources and runner to use.
    repository: ${{ inputs.repository || github.repository }}
    ref: ${{ inputs.ref || '' }}
    test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }}
236+
237+
# Entry under the top-level `jobs:` mapping.
# Promotes the staged wheels to the release S3 prefix and regenerates the
# release package index. Both promotion steps are gated on `env.upload`,
# presumably exported by promote_wheels_based_on_policy.py - confirm.
upload_jax_wheels:
  name: Release JAX Wheels to S3
  needs: [build_jax_wheels, generate_target_to_run, test_jax_wheels]
  # Run unless the workflow was cancelled, so the policy script can inspect
  # failed or skipped upstream results.
  if: ${{ !cancelled() }}
  runs-on: ubuntu-24.04
  env:
    S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
    JAX_VERSION: "${{ needs.build_jax_wheels.outputs.jax_version }}"
    ROCM_VERSION: "${{ inputs.rocm_version }}"
    CP_VERSION: "${{ needs.build_jax_wheels.outputs.cp_version }}"

  steps:
    # Action pins match the ones used by the other jobs in this workflow.
    - name: Checkout
      uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
      with:
        repository: ${{ inputs.repository || github.repository }}
        ref: ${{ inputs.ref || '' }}

    - name: Configure AWS Credentials
      if: always()
      uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
      with:
        aws-region: us-east-2
        role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

    - name: Determine upload flag
      env:
        BUILD_RESULT: ${{ needs.build_jax_wheels.result }}
        TEST_RESULT: ${{ needs.test_jax_wheels.result }}
        TEST_RUNS_ON: ${{ needs.generate_target_to_run.outputs.test_runs_on }}
        BYPASS_TESTS_FOR_RELEASES: ${{ needs.generate_target_to_run.outputs.bypass_tests_for_releases }}
      run: python ./build_tools/github_actions/promote_wheels_based_on_policy.py

    - name: Copy JAX wheels from staging to release S3
      if: ${{ env.upload == 'true' }}
      env:
        # Inputs are passed through the environment instead of being expanded
        # into the script text, so their contents cannot alter the shell code.
        STAGING_PREFIX: ${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}
        RELEASE_PREFIX: ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}
      run: |
        echo "Copying exact tested wheels to release S3 bucket..."
        # Exclude everything, then re-include only the three wheel file names
        # built from this run's JAX / ROCm / CPython versions.
        aws s3 cp \
          "s3://${S3_BUCKET_PY}/${STAGING_PREFIX}/" \
          "s3://${S3_BUCKET_PY}/${RELEASE_PREFIX}/" \
          --recursive \
          --exclude "*" \
          --include "jaxlib-${JAX_VERSION}+rocm${ROCM_VERSION}-${CP_VERSION}-manylinux_2_27_x86_64.whl" \
          --include "jax_rocm7_plugin-${JAX_VERSION}+rocm${ROCM_VERSION}-${CP_VERSION}-manylinux_2_28_x86_64.whl" \
          --include "jax_rocm7_pjrt-${JAX_VERSION}+rocm${ROCM_VERSION}-py3-none-manylinux_2_28_x86_64.whl"

    - name: (Re-)Generate Python package release index
      if: ${{ env.upload == 'true' }}
      env:
        # Environment variables to be set for `manage.py`
        CUSTOM_PREFIX: "${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}"
      run: |
        pip install boto3 packaging
        python ./build_tools/third_party/s3_management/manage.py "${CUSTOM_PREFIX}"

0 commit comments

Comments
 (0)