Skip to content

Commit b0ec72a

Browse files
authored
CI: run MaxText tests on AWS with NGC release candidate images (#1237)
1 parent 6af770f commit b0ec72a

File tree

3 files changed

+237
-3
lines changed

3 files changed

+237
-3
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
apiVersion: v1
2+
kind: Service
3+
metadata:
4+
name: PLACEHOLDER
5+
spec:
6+
clusterIP: None # clusterIP must be None to create a headless service
7+
selector:
8+
job-name: PLACEHOLDER # must match Job name
9+
---
10+
apiVersion: batch/v1
11+
kind: Job
12+
metadata:
13+
name: PLACEHOLDER
14+
labels:
15+
kueue.x-k8s.io/queue-name: p5-queue
16+
spec:
17+
completions: 2 # number of nodes
18+
parallelism: 2 # number of nodes
19+
completionMode: Indexed
20+
backoffLimitPerIndex: 0 # max failures per index
21+
maxFailedIndexes: 0 # all indices must succeed
22+
template:
23+
spec:
24+
subdomain: PLACEHOLDER # has to match Service name
25+
restartPolicy: Never
26+
imagePullSecrets:
27+
- name: PLACEHOLDER
28+
containers:
29+
- name: maxtext
30+
image: PLACEHOLDER
31+
ports:
32+
- containerPort: 3389
33+
command:
34+
- bash
35+
- -c
36+
# The logging logic: stream stdout/stderr from the 0th process inside this pod,
37+
# record all of the processes' stdout/stderr + the INFO-level NCCL logs to file
38+
- |
39+
export SERVICE_NAME=$0
40+
export JOB_NAME=$1
41+
cat >each-process.sh <<'EOL'
42+
export JAX_COORDINATOR_IP=${JOB_NAME}-0.${SERVICE_NAME}
43+
export JAX_COORDINATOR_PORT=3389
44+
export NNODES=16 # actually #processes == #GPUs
45+
export NODE_RANK=$((JOB_COMPLETION_INDEX*8 + LOCAL_RANK))
46+
export JAX_LOCAL_DEVICE_IDS=$LOCAL_RANK
47+
export NCCL_DEBUG=INFO
48+
export NCCL_DEBUG_FILE=/opt/output/nccl.$NODE_RANK.log
49+
[[ $LOCAL_RANK == 0 ]] && console="/dev/stdout" || console="/dev/null"
50+
nsys-jax \
51+
--capture-range=cudaProfilerApi \
52+
--capture-range-end=stop \
53+
-o /opt/output/profile.$NODE_RANK.zip \
54+
-- \
55+
test-maxtext.sh \
56+
-n 2 \
57+
-b 2 \
58+
--model-name=llama2-7b \
59+
--attn-type=cudnn_flash_te \
60+
--remat-policy=minimal_flash \
61+
--steps=20 \
62+
--fsdp=16 \
63+
-a "scan_layers=false \
64+
max_target_length=4096 \
65+
use_iota_embed=true \
66+
logits_dot_in_fp32=false \
67+
profiler=nsys \
68+
skip_first_n_steps_for_profiler=3 \
69+
profiler_steps=8" \
70+
|& tee /opt/output/output.$NODE_RANK.log >"${console}"
71+
# The preceding pipeline ends in `tee`, so `$?` would report tee's status and
# hide a failing test run; take the exit status of the first pipeline stage.
code=${PIPESTATUS[0]}
72+
# Should run even on failure
73+
cat /opt/output/nccl.$NODE_RANK.log >"${console}"
74+
exit $code
75+
EOL
76+
# TODO: upgrade parallel-launch to return a failure code as soon as any
77+
# of its children do (it already does this eventually, but it could
78+
# be slow)
79+
parallel-launch LOCAL_RANK 8 bash each-process.sh
80+
code=$?
81+
# Should run even on failure
82+
touch /opt/output/.done
83+
exit $code
84+
- PLACEHOLDER
85+
- PLACEHOLDER
86+
resources:
87+
limits:
88+
nvidia.com/gpu: 8
89+
vpc.amazonaws.com/efa: 32
90+
volumeMounts:
91+
- mountPath: /dev/shm
92+
name: shmem
93+
- mountPath: /opt/output
94+
name: output
95+
- name: upload
96+
image: amazon/aws-cli
97+
command:
98+
- bash
99+
- -c
100+
- |
101+
JOB_NAME="$0"
102+
while [[ ! -f /opt/output/.done ]]; do
103+
sleep 1
104+
done
105+
rm /opt/output/.done
106+
aws s3 cp \
107+
--recursive \
108+
/opt/output \
109+
"s3://jax-toolbox-eks-output/${JOB_NAME}/"
110+
- PLACEHOLDER
111+
volumeMounts:
112+
- mountPath: /opt/output
113+
name: output
114+
volumes:
115+
- name: output
116+
emptyDir: {}
117+
- name: shmem
118+
emptyDir:
119+
medium: Memory
120+
sizeLimit: 16Gi
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
name: ~test MaxText functionality on Kubernetes
2+
3+
on:
4+
workflow_call:
5+
inputs:
6+
MAXTEXT_IMAGE:
7+
type: string
8+
description: MaxText container to test
9+
required: true
10+
11+
permissions:
12+
contents: read # to fetch code
13+
14+
jobs:
15+
maxtext:
16+
runs-on: eks
17+
env:
18+
CONTAINER_IMAGE: "${{ inputs.MAXTEXT_IMAGE }}"
19+
JOB_NAME: "maxtext-${{ github.run_id }}-${{ github.run_attempt }}"
20+
steps:
21+
- name: Check out the repository
22+
uses: actions/checkout@v4
23+
- name: Login to GitHub Container Registry
24+
uses: docker/login-action@v3
25+
with:
26+
registry: ghcr.io
27+
username: ${{ github.repository_owner }}
28+
password: ${{ secrets.GITHUB_TOKEN }}
29+
- name: Login to NVIDIA Container Registry
30+
uses: docker/login-action@v3
31+
with:
32+
registry: nvcr.io
33+
username: $oauthtoken
34+
password: ${{ secrets.NVCR_TOKEN }}
35+
- name: Store GitHub Container Registry token as Kubernetes secret
36+
run: |
37+
# Make this available to later steps
38+
TOKEN_NAME="${JOB_NAME}-token"
39+
echo "TOKEN_NAME=${TOKEN_NAME}" >> "$GITHUB_ENV"
40+
kubectl create secret generic \
41+
${TOKEN_NAME} \
42+
--from-file=.dockerconfigjson=$HOME/.docker/config.json \
43+
--type=kubernetes.io/dockerconfigjson
44+
- name: Configure Kubernetes job
45+
run: |
46+
export SERVICE_NAME="${JOB_NAME}-svc"
47+
yq -i ea 'select(di == 0).metadata.name = strenv(SERVICE_NAME)
48+
| select(di == 0).spec.selector.job-name = strenv(JOB_NAME)
49+
| select(di == 1).metadata.name = strenv(JOB_NAME)
50+
| select(di == 1).spec.template.spec.subdomain = strenv(SERVICE_NAME)
51+
| select(di == 1).spec.template.spec.imagePullSecrets[].name = strenv(TOKEN_NAME)
52+
| select(di == 1).spec.template.spec.containers[0].image = strenv(CONTAINER_IMAGE)
53+
| select(di == 1).spec.template.spec.containers[0].command[3] = strenv(SERVICE_NAME)
54+
| select(di == 1).spec.template.spec.containers[0].command[4] = strenv(JOB_NAME)
55+
| select(di == 1).spec.template.spec.containers[1].command[3] = strenv(JOB_NAME)' \
56+
.github/eks-workflow-files/maxtext-job.yaml
57+
git diff .github/eks-workflow-files/maxtext-job.yaml
58+
- name: Submit Kubernetes job
59+
run: kubectl apply -f .github/eks-workflow-files/maxtext-job.yaml
60+
- name: Wait for Kubernetes job to start
61+
run: |
62+
# Launcher job is created eagerly, but suspended. Kueue un-suspends it when
63+
# resources are available, but that is where there can be a long wait if the
64+
# cluster is busy executing other jobs.
65+
kubectl wait --for=create job/${JOB_NAME}
66+
kubectl wait --for=jsonpath='{.spec.suspend}=false' job/${JOB_NAME} --timeout=3600s
67+
- name: Stream Kubernetes job output
68+
run: |
69+
# Streaming logs will fail if the container/pod is still pending
70+
while [[ -n $(kubectl get pods --selector=batch.kubernetes.io/job-name=${JOB_NAME} --output=jsonpath='{.items[?(@.status.phase == "Pending")].metadata.name}') ]]; do
71+
sleep 1
72+
done
73+
kubectl logs --all-containers=true --all-pods=true --follow job/${JOB_NAME}
74+
- name: Retrieve Kubernetes job status
75+
shell: bash -exo pipefail {0}
76+
run: |
77+
while readarray -d : -t status < <(kubectl get job/${JOB_NAME} -o 'jsonpath={.status.failed}:{.status.succeeded}'); do
78+
failure=${status[0]:-0}
79+
success=${status[1]:-0}
80+
total=$((failure+success))
81+
# Numeric comparison: `<` inside [[ ]] compares strings lexicographically,
# which is only correct by accident for single-digit counts.
if [[ ${total} -lt 2 ]]; then
82+
sleep 1
83+
elif [[ ${total} == 2 ]]; then
84+
break
85+
else
86+
# FIXME
87+
exit 255
88+
fi
89+
done
90+
exit ${failure}
91+
# Provide more debug output in case of failure; note that some kinds of launch
92+
# failure do not produce any log output.
93+
- name: Debug failed Kubernetes job
94+
if: failure()
95+
run: |
96+
# Provide better debug in case of launch failures that will not produce log output
97+
pods=$(kubectl get pods --selector=batch.kubernetes.io/job-name=${JOB_NAME} -o name)
98+
if [[ -n "${pods}" ]]; then
99+
kubectl describe ${pods}
100+
fi
101+
# Clean up in case of errors as well as success
102+
- name: Delete Kubernetes job
103+
if: always()
104+
run: kubectl delete -f .github/eks-workflow-files/maxtext-job.yaml
105+
- name: Delete GitHub Container Registry token
106+
if: always()
107+
run: kubectl delete secret ${TOKEN_NAME}

.github/workflows/ngc-release-testing.yaml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
docker run -i --shm-size=1g --gpus all \
4646
${{ inputs.JAX_IMAGE }} \
4747
bash <<"EOF" |& tee test-backend-independent.log
48-
test-jax.sh -b backend-independent
48+
test-jax.sh -b backend-independent
4949
EOF
5050
docker run -i --shm-size=1g --gpus all \
5151
${{ inputs.JAX_IMAGE }} \
@@ -80,8 +80,15 @@ jobs:
8080
MAXTEXT_IMAGE: ${{ inputs.MAXTEXT_IMAGE }}
8181
secrets: inherit
8282

83+
test-maxtext-eks:
84+
if: inputs.MAXTEXT_IMAGE != ''
85+
uses: ./.github/workflows/_test_maxtext_k8s.yaml
86+
with:
87+
MAXTEXT_IMAGE: ${{ inputs.MAXTEXT_IMAGE }}
88+
secrets: inherit
89+
8390
finalize:
84-
needs: [ test-nccl, test-jax, test-rosetta-pax, test-maxtext ]
91+
needs: [ test-nccl, test-jax, test-rosetta-pax, test-maxtext, test-maxtext-eks ]
8592
if: "!cancelled()"
8693
uses: ./.github/workflows/_finalize.yaml
87-
secrets: inherit
94+
secrets: inherit

0 commit comments

Comments
 (0)