Skip to content

Commit 088220e

Browse files
authored
Add a test case for Linux aarch64 cu129 build (#7373)
#7364 missed this spot and failed to modify the right variable. Let's add a test case to avoid this. I can confirm that Linux aarch64 cu129 is showing up now https://github.com/pytorch/test-infra/actions/runs/18603342105 I also test this locally with: ``` OS=linux-aarch64 python3 tools/scripts/generate_binary_build_matrix.py | jq { "include": [ { "python_version": "3.10", "gpu_arch_type": "cpu-aarch64", "gpu_arch_version": "", "desired_cuda": "cpu", "container_image": "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64", "package_type": "wheel", "build_name": "wheel-py3_10-cpu-aarch64", "validation_runner": "linux.arm64.2xlarge", "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu", "channel": "nightly", "upload_to_base_bucket": "no", "stable_version": "2.9.0" }, { "python_version": "3.10", "gpu_arch_type": "cuda-aarch64", "gpu_arch_version": "12.6-aarch64", "desired_cuda": "cu126", "container_image": "pytorch/manylinuxaarch64-builder:cuda12.6", "package_type": "wheel", "build_name": "wheel-py3_10-cuda-aarch6412_6-aarch64", "validation_runner": "linux.arm64.m7g.4xlarge", "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126", "channel": "nightly", "upload_to_base_bucket": "no", "stable_version": "2.9.0" }, { "python_version": "3.10", "gpu_arch_type": "cuda-aarch64", "gpu_arch_version": "12.8-aarch64", "desired_cuda": "cu128", "container_image": "pytorch/manylinuxaarch64-builder:cuda12.8", "package_type": "wheel", "build_name": "wheel-py3_10-cuda-aarch6412_8-aarch64", "validation_runner": "linux.arm64.m7g.4xlarge", "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128", "channel": "nightly", "upload_to_base_bucket": "no", "stable_version": "2.9.0" 1 Update os check }, { "python_version": "3.10", "gpu_arch_type": "cuda-aarch64", "gpu_arch_version": "13.0-aarch64", "desired_cuda": "cu130", "container_image": "pytorch/manylinuxaarch64-builder:cuda13.0", "package_type": "wheel", "build_name": "wheel-py3_10-cuda-aarch6413_0-aarch64", "validation_runner": "linux.arm64.m7g.4xlarge", "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130", "channel": "nightly", "upload_to_base_bucket": "no", "stable_version": "2.9.0" }, { "python_version": "3.10", "gpu_arch_type": "cuda-aarch64", "gpu_arch_version": "12.9-aarch64", "desired_cuda": "cu129", "container_image": "pytorch/manylinuxaarch64-builder:cuda12.9", "package_type": "wheel", "build_name": "wheel-py3_10-cuda-aarch6412_9-aarch64", "validation_runner": "linux.arm64.m7g.4xlarge", "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129", "channel": "nightly", "upload_to_base_bucket": "no", "stable_version": "2.9.0" } ] } ``` --------- Signed-off-by: Huy Do <[email protected]>
1 parent a287d59 commit 088220e

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
name: Test Build Linux Aarch64 Wheels with CUDA
2+
3+
on:
4+
pull_request:
5+
paths:
6+
- .github/actions/setup-binary-builds/action.yml
7+
- .github/workflows/test_build_wheels_linux.yml
8+
- .github/workflows/build_wheels_linux.yml
9+
- .github/workflows/generate_binary_build_matrix.yml
10+
- .github/workflows/test_build_wheels_linux_aarch64_with_cuda.yml
11+
- tools/scripts/generate_binary_build_matrix.py
12+
workflow_dispatch:
13+
14+
permissions:
15+
id-token: write
16+
contents: read
17+
18+
jobs:
19+
generate-matrix:
20+
uses: ./.github/workflows/generate_binary_build_matrix.yml
21+
with:
22+
package-type: wheel
23+
os: linux-aarch64
24+
test-infra-repository: ${{ github.repository }}
25+
test-infra-ref: ${{ github.ref }}
26+
with-cuda: enable
27+
test:
28+
needs: generate-matrix
29+
strategy:
30+
fail-fast: false
31+
matrix:
32+
include:
33+
- repository: pytorch/vision
34+
pre-script: packaging/pre_build_script.sh
35+
post-script: packaging/post_build_script.sh
36+
smoke-test-script: test/smoke_test.py
37+
package-name: torchvision
38+
- repository: pytorch/audio
39+
smoke-test-script: test/smoke_test/smoke_test.py
40+
package-name: torchaudio
41+
uses: ./.github/workflows/build_wheels_linux.yml
42+
name: ${{ matrix.repository }}
43+
with:
44+
repository: ${{ matrix.repository }}
45+
ref: nightly
46+
test-infra-repository: ${{ github.repository }}
47+
test-infra-ref: ${{ github.ref }}
48+
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
49+
pre-script: ${{ matrix.pre-script }}
50+
post-script: ${{ matrix.post-script }}
51+
smoke-test-script: ${{ matrix.smoke-test-script }}
52+
package-name: ${{ matrix.package-name }}
53+
trigger-event: "${{ github.event_name }}"
54+
architecture: aarch64
55+
setup-miniconda: false

tools/scripts/generate_binary_build_matrix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,15 @@ def validation_runner(arch_type: str, os: str) -> str:
149149

150150

151151
def initialize_globals(channel: str, os: str, build_python_only: bool) -> None:
152-
global CURRENT_VERSION, CUDA_ARCHES, ROCM_ARCHES, PYTHON_ARCHES
152+
global CURRENT_VERSION, CUDA_ARCHES, CUDA_AARCH64_ARCHES, ROCM_ARCHES, PYTHON_ARCHES
153153
global WHEEL_CONTAINER_IMAGES, LIBTORCH_CONTAINER_IMAGES
154154
if channel == TEST:
155155
CURRENT_VERSION = CURRENT_CANDIDATE_VERSION
156156
else:
157157
CURRENT_VERSION = CURRENT_STABLE_VERSION
158158

159159
CUDA_ARCHES = CUDA_ARCHES_DICT[channel]
160-
if channel != "release" and os == LINUX:
160+
if channel != "release" and (os == LINUX or os == LINUX_AARCH64):
161161
# TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
162162
# in 2.10
163163
CUDA_ARCHES.append("12.9")

0 commit comments

Comments
 (0)