11name : ~CI, single-arch
2- run-name : CI-${{ inputs.ARCHITECTURE }}
2+ run-name : CI-${{ inputs.ARCHITECTURE }}-${{ inputs.TESTSUBSET }}
33on :
44 workflow_call :
55 inputs :
1616 description : Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch
1717 default : ' '
1818 required : false
19+ TEST_SUBSET :
20+ type : string
21+ description : |
22+ Subset of tests to run. Allowed values are one of:
23+ - base
24+ - jax
25+ - levanter
26+ - equinox
27+ - triton
28+ - upstream-t5x
29+ - rosetta-t5x
30+ - upstream-pax
31+ - rosetta-pax
32+ - maxtext
33+ - grok
34+
35+ Will run all downstream-connected nodes and leaves.
36+ default : ' base'
37+ required : false
1938 outputs :
2039 DOCKER_TAGS :
2140 description : JSON object containing tags of all docker images built
2241 value : ${{ jobs.collect-docker-tags.outputs.TAGS }}
2342
2443permissions :
25- contents : read # to fetch code
26- actions : write # to cancel previous workflows
44+ contents : read # to fetch code
45+ actions : write # to cancel previous workflows
2746 packages : write # to upload container
2847
2948jobs :
49+ pre-flight :
50+ runs-on : ubuntu-22.04
51+ steps :
52+ - name : Validate input `TEST_SUBSET`
53+ shell : bash
54+ run : |
55+ valid_inputs=("base" "core" "levanter" "equinox" "triton" "upstream-t5x" "rosetta-t5x" "upstream-pax" "rosetta-pax" "maxtext" "grok")
56+
57+ if [[ " ${valid_inputs[*]} " != *" ${{ inputs.TEST_SUBSET }} "* ]]; then
58+ echo "Invalid value for \`TEST_SUBSET\` provided. Expected one of: ($valid_inputs), Actual: ${{ inputs.TEST_SUBSET }}"
59+ exit 1
60+ fi
3061
62+ # Always
3163 build-base :
3264 uses : ./.github/workflows/_build_base.yaml
65+ needs : pre-flight
3366 with :
3467 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
3568 BUILD_DATE : ${{ inputs.BUILD_DATE }}
3669 MANIFEST_ARTIFACT_NAME : ${{ inputs.MANIFEST_ARTIFACT_NAME }}
3770 secrets : inherit
3871
72+ # Always
3973 build-jax :
4074 needs : build-base
4175 uses : ./.github/workflows/_build.yaml
5084 RUNNER_SIZE : large
5185 secrets : inherit
5286
87+ # base, jax, triton
5388 build-triton :
5489 needs : build-jax
55- if : inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
90+ if : contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
5691 uses : ./.github/workflows/_build.yaml
5792 with :
5893 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
6499 DOCKERFILE : .github/container/Dockerfile.triton
65100 secrets : inherit
66101
102+ # base, jax, equinox
67103 build-equinox :
68104 needs : build-jax
69105 uses : ./.github/workflows/_build.yaml
106+ if : contains(fromJSON('["base", "jax", "equinox"]'), inputs.TEST_SUBSET)
70107 with :
71108 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
72109 ARTIFACT_NAME : artifact-equinox-build
@@ -77,9 +114,10 @@ jobs:
77114 DOCKERFILE : .github/container/Dockerfile.equinox
78115 secrets : inherit
79116
117+ # base, jax, maxtext
80118 build-maxtext :
81119 needs : build-jax
82- if : inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
120+ if : contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64
83121 uses : ./.github/workflows/_build.yaml
84122 with :
85123 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
@@ -91,35 +129,41 @@ jobs:
91129 DOCKERFILE : .github/container/Dockerfile.maxtext.amd64
92130 secrets : inherit
93131
132+ # base, jax, levanter
94133 build-levanter :
95134 needs : [build-jax]
96135 uses : ./.github/workflows/_build.yaml
136+ if : contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET)
97137 with :
98138 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
99- ARTIFACT_NAME : " artifact-levanter-build"
100- BADGE_FILENAME : " badge-levanter-build"
139+ ARTIFACT_NAME : ' artifact-levanter-build'
140+ BADGE_FILENAME : ' badge-levanter-build'
101141 BUILD_DATE : ${{ inputs.BUILD_DATE }}
102142 BASE_IMAGE : ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
103143 CONTAINER_NAME : levanter
104144 DOCKERFILE : .github/container/Dockerfile.levanter
105145 secrets : inherit
106146
147+ # base, jax, upstream-t5x
107148 build-upstream-t5x :
108149 needs : build-jax
109150 uses : ./.github/workflows/_build.yaml
151+ if : contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET)
110152 with :
111153 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
112- ARTIFACT_NAME : " artifact-t5x-build"
113- BADGE_FILENAME : " badge-t5x-build"
154+ ARTIFACT_NAME : ' artifact-t5x-build'
155+ BADGE_FILENAME : ' badge-t5x-build'
114156 BUILD_DATE : ${{ inputs.BUILD_DATE }}
115157 BASE_IMAGE : ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
116158 CONTAINER_NAME : upstream-t5x
117159 DOCKERFILE : .github/container/Dockerfile.t5x.${{ inputs.ARCHITECTURE }}
118160 secrets : inherit
119161
162+ # base, jax, upstream-pax
120163 build-upstream-pax :
121164 needs : build-jax
122165 uses : ./.github/workflows/_build.yaml
166+ if : contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET)
123167 with :
124168 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
125169 ARTIFACT_NAME : artifact-pax-build
@@ -130,42 +174,48 @@ jobs:
130174 DOCKERFILE : .github/container/Dockerfile.pax.${{ inputs.ARCHITECTURE }}
131175 secrets : inherit
132176
177+ # base, jax, upstream-t5x, rosetta-t5x
133178 build-rosetta-t5x :
134179 needs : build-upstream-t5x
135180 uses : ./.github/workflows/_build_rosetta.yaml
181+ if : contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET)
136182 with :
137183 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
138184 BUILD_DATE : ${{ inputs.BUILD_DATE }}
139185 BASE_IMAGE : ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_MEALKIT }}
140186 BASE_LIBRARY : t5x
141187 secrets : inherit
142188
189+ # base, jax, upstream-pax, rosetta-pax
143190 build-rosetta-pax :
144191 needs : build-upstream-pax
145192 uses : ./.github/workflows/_build_rosetta.yaml
193+ if : contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET)
146194 with :
147195 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
148196 BUILD_DATE : ${{ inputs.BUILD_DATE }}
149197 BASE_IMAGE : ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_MEALKIT }}
150198 BASE_LIBRARY : pax
151199 secrets : inherit
152200
201+ # base, jax, grok
153202 build-grok :
154203 needs : [build-jax]
155204 uses : ./.github/workflows/_build.yaml
205+ if : contains(fromJSON('["base", "jax", "grok"]'), inputs.TEST_SUBSET)
156206 with :
157207 ARCHITECTURE : ${{ inputs.ARCHITECTURE }}
158- ARTIFACT_NAME : " artifact-grok-build"
159- BADGE_FILENAME : " badge-grok-build"
208+ ARTIFACT_NAME : ' artifact-grok-build'
209+ BADGE_FILENAME : ' badge-grok-build'
160210 BUILD_DATE : ${{ inputs.BUILD_DATE }}
161211 BASE_IMAGE : ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
162212 CONTAINER_NAME : grok
163213 DOCKERFILE : .github/container/Dockerfile.grok
164214 secrets : inherit
165-
215+
166216 collect-docker-tags :
167217 runs-on : ubuntu-22.04
168- if : " !cancelled()"
218+ if : ' !cancelled()'
169219 needs :
170220 - build-base
171221 - build-jax
@@ -236,9 +286,10 @@ jobs:
236286 - name : Run integration test ${{ matrix.TEST_SCRIPT }}
237287 run : bash rosetta/tests/${{ matrix.TEST_SCRIPT }}
238288
289+ # base, jax
239290 test-jax :
240291 needs : build-jax
241- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
292+ if : contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
242293 uses : ./.github/workflows/_test_unit.yaml
243294 with :
244295 TEST_NAME : jax
@@ -291,33 +342,37 @@ jobs:
291342 # test-equinox.log
292343 # secrets: inherit
293344
345+ # base, jax, upstream-pax
294346 test-te-multigpu :
295347 needs : build-upstream-pax
296- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
348+ if : contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
297349 uses : ./.github/workflows/_test_te.yaml
298350 with :
299351 TE_IMAGE : ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
300352 secrets : inherit
301353
354+ # base, jax, upstream-t5x
302355 test-upstream-t5x :
303356 needs : build-upstream-t5x
304- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
357+ if : contains(fromJSON('["base", "jax", "upstream-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
305358 uses : ./.github/workflows/_test_upstream_t5x.yaml
306359 with :
307360 T5X_IMAGE : ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_FINAL }}
308361 secrets : inherit
309362
363+ # base, jax, upstream-t5x, rosetta-t5x
310364 test-rosetta-t5x :
311365 needs : build-rosetta-t5x
312- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
366+ if : contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
313367 uses : ./.github/workflows/_test_t5x_rosetta.yaml
314368 with :
315369 T5X_IMAGE : ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_FINAL }}
316370 secrets : inherit
317371
372+ # base, jax
318373 test-pallas :
319374 needs : build-jax
320- if : inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?)
375+ if : contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?)
321376 uses : ./.github/workflows/_test_unit.yaml
322377 with :
323378 TEST_NAME : pallas
@@ -341,9 +396,10 @@ jobs:
341396 test-pallas.log
342397 secrets : inherit
343398
399+ # base, jax, triton
344400 test-triton :
345401 needs : build-triton
346- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
402+ if : contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
347403 uses : ./.github/workflows/_test_unit.yaml
348404 with :
349405 TEST_NAME : triton
@@ -367,9 +423,10 @@ jobs:
367423 test-triton.log
368424 secrets : inherit
369425
426+ # base, jax, levanter
370427 test-levanter :
371428 needs : build-levanter
372- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
429+ if : contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET) && inputs. ARCHITECTURE == 'amd64' # arm64 runners n/a
373430 uses : ./.github/workflows/_test_unit.yaml
374431 with :
375432 TEST_NAME : levanter
@@ -394,9 +451,10 @@ jobs:
394451 test-levanter.log
395452 secrets : inherit
396453
454+ # base, jax, upstream-pax
397455 test-te :
398456 needs : build-upstream-pax
399- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
457+ if : contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs. ARCHITECTURE == 'amd64' # arm64 runners n/a
400458 uses : ./.github/workflows/_test_unit.yaml
401459 with :
402460 TEST_NAME : te
@@ -422,25 +480,28 @@ jobs:
422480 pytest-report.jsonl
423481 secrets : inherit
424482
483+ # base, jax, upstream-pax
425484 test-upstream-pax :
426485 needs : build-upstream-pax
427- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
486+ if : contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
428487 uses : ./.github/workflows/_test_upstream_pax.yaml
429488 with :
430489 PAX_IMAGE : ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }}
431490 secrets : inherit
432491
492+ # base, jax, upstream-pax, rosetta-pax
433493 test-rosetta-pax :
434494 needs : build-rosetta-pax
435- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
495+ if : contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
436496 uses : ./.github/workflows/_test_pax_rosetta.yaml
437497 with :
438498 PAX_IMAGE : ${{ needs.build-rosetta-pax.outputs.DOCKER_TAG_FINAL }}
439499 secrets : inherit
440500
501+ # base, jax, maxtext
441502 test-maxtext :
442503 needs : build-maxtext
443- if : inputs.ARCHITECTURE == 'amd64' # no images for arm64
504+ if : contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64
444505 uses : ./.github/workflows/_test_maxtext.yaml
445506 with :
446507 MAXTEXT_IMAGE : ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }}
0 commit comments