Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion .github/actions/gke-xpk/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ runs:

if [ $? -ne 0 ]; then
echo "The JobSet ${WORKLOAD_NAME} on ${{ inputs.GKE_CLUSTER }} did not complete as expected "
echo "XPK_EXIT_CODE=1" >> ${GITHUB_ENV}
exit 1
fi

Expand All @@ -262,11 +263,12 @@ runs:
ALL_EXIT_CODES=$(( ALL_EXIT_CODES + POD_EXIT_CODE ))
done

echo "XPK_EXIT_CODE=${ALL_EXIT_CODES}" >> ${GITHUB_ENV}
if [ ${ALL_EXIT_CODES} -gt 0 ]; then
exit 1
fi
exit 0

- name: Clean up JobSet from cluster
shell: bash -x -u {0}
if: ${{ always() }}
Expand All @@ -291,3 +293,38 @@ runs:
if: ${{ always() }}
run: |
sudo rm -rf ${WORKLOAD_NAME}

- name: Generate sitrep
  id: sitrep
  shell: bash -x -e {0}
  if: ${{ always() }}
  run: |
    # Build a small JSON status report (badge label/color + outcome) for this
    # workload run, using the XPK_EXIT_CODE exported by the JobSet-wait step.
    source .github/workflows/scripts/to_json.sh

    summary="${{ inputs.WORKLOAD_NAME_PREFIX }}"
    outcome=success
    badge_label="${{ inputs.WORKLOAD_NAME_PREFIX }}"
    badge_color=brightgreen

    # This step runs under `if: always()`, so XPK_EXIT_CODE may never have been
    # written to GITHUB_ENV if an earlier step failed before exporting it.
    # An unset value would make `[ "" -gt 0 ]` error out under `bash -e`;
    # default to 1 so a missing exit code is reported as a failure.
    if [ "${XPK_EXIT_CODE:-1}" -gt 0 ]; then
      badge_color=red
      outcome=failed
      summary+=": fail"
    else
      summary+=": pass"
    fi

    to_json summary \
      badge_label \
      badge_color \
      outcome | \
      tee sitrep.json

- name: Upload sitrep to GitHub Actions from runner
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: ${{ inputs.WORKLOAD_NAME_PREFIX }}-sitrep
path: |
sitrep.json
7 changes: 7 additions & 0 deletions .github/container/git-clone.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ pushd ${DESTINATION}
git checkout ${GIT_REF}
COMMIT_SHA=$(git rev-parse HEAD)
git submodule update --init --recursive
# For GitLab checkouts, scrub CI credentials from the cloned repository so
# they are not baked into the resulting container image:
#   - remove the origin remote (its URL presumably embeds the
#     gitlab-ci-token credential — TODO confirm against the clone URL)
#   - delete any files under .git that still mention gitlab-ci-token
#   - delete the local main branch
#     NOTE(review): rationale for deleting main is not visible here —
#     presumably its branch config references the removed remote; confirm.
if [[ "${GIT_REPO}" == *"gitlab"* ]]; then
  git remote remove origin
  # grep -q guard: `grep -r` exits non-zero when nothing matches, which
  # would otherwise feed an empty list to xargs.
  if grep -q -r gitlab-ci-token .git; then
    # Remove every .git file whose content matches; awk strips the
    # `path:match` suffix that grep -r prints.
    grep -r gitlab-ci-token .git | awk -F: '{print $1}' | xargs rm -f
  fi
  git branch -D main
fi
popd

## update the manifest file
Expand Down
95 changes: 52 additions & 43 deletions .github/container/pip-finalize.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,60 @@ set -eoux pipefail

pushd /opt/pip-tools.d

# First pip-compile gathers all reqs, but we care only about VCS installs
# It's possible there are 2nd degree transitive dependencies that are VCS, so
# this is more robust to gather VCS requirements at the cost of pip-compiling
# twice
pip-compile -o requirements.pre $(ls requirements-*.in)
# If requirements-pinned.txt exists, skip compilation
if [[ -f "requirements-pinned.txt" ]]; then
sed -E 's/#sha256=[a-f0-9]+//g' requirements-pinned.txt > requirements.txt
else
# First pip-compile gathers all reqs, but we care only about VCS installs
# It's possible there are 2nd degree transitive dependencies that are VCS, so
# this is more robust to gather VCS requirements at the cost of pip-compiling
# twice
pip-compile -o requirements.pre $(ls requirements-*.in)

IFS=$'\n'
for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do
# VCS installs are of the form "PACKAGE @ git+..."
PACKAGE=$(echo "$line" | awk '{print $1}')
ref=$(yq e ".${PACKAGE}.latest_verified_commit" ${MANIFEST_FILE})
if [[ "$line" == *"#subdirectory="* ]]; then
# This is required b/c git-refs/commits cannot come after
# the subdirectory fragment.
# An example of an install that is of this form is:
# 'orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint'
echo "${line}" | sed "s/#subdirectory=/@${ref}#subdirectory=/"
else
echo "${line}@${ref}"
fi
done | tee requirements.vcs
unset IFS
IFS=$'\n'
for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do
# VCS installs are of the form "PACKAGE @ git+..."
PACKAGE=$(echo "$line" | awk '{print $1}')
ref=$(yq e ".${PACKAGE}.latest_verified_commit" ${MANIFEST_FILE})
if [[ "$line" == *"#subdirectory="* ]]; then
# This is required b/c git-refs/commits cannot come after
# the subdirectory fragment.
# An example of an install that is of this form is:
# 'orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint'
echo "${line}" | sed "s/#subdirectory=/@${ref}#subdirectory=/"
else
echo "${line}@${ref}"
fi
done | tee requirements.vcs
unset IFS

# Second pip-compile includes one more requirements file that pins all vcs installs
# Uses a special env var to let our custom pip impl know to treat the following as
# equivalent:
#
# fiddle @ git+https://github.com/google/fiddle
# fiddle @ git+https://github.com/google/fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f
#
# JAX_TOOLBOX_VCS_EQUIVALENCY is an environment variable enabling custom logic in pip
# that treats the above as equivalent and prefers the URI with the SHA
JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in)
# Second pip-compile includes one more requirements file that pins all vcs installs
# Uses a special env var to let our custom pip impl know to treat the following as
# equivalent:
#
# fiddle @ git+https://github.com/google/fiddle
# fiddle @ git+https://github.com/google/fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f
#
# JAX_TOOLBOX_VCS_EQUIVALENCY is an environment variable enabling custom logic in pip
# that treats the above as equivalent and prefers the URI with the SHA
JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in)

# If there are unpinned VCS dependencies, error since these should be included in the manifest
unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true)
if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then
echo "Unpinned VCS installs found in $(readlink -f requirements.txt):"
echo "$unpinned_vcs_dependencies"
exit 1
fi
# If there are unpinned VCS dependencies, error since these should be included in the manifest
unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true)
if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then
echo "Unpinned VCS installs found in $(readlink -f requirements.txt):"
echo "$unpinned_vcs_dependencies"
exit 1
fi

# Replace any tensorflow==X with tensorflow-cpu==X in requirements.txt only on amd64
if [ "$(uname -m)" = "x86_64" ]; then
sed -i 's/^tensorflow==\([0-9.*]\+\)$/tensorflow-cpu==\1/' requirements.txt
else
echo "Skipping TF on $(uname -m)"
# Replace any tensorflow==X with tensorflow-cpu==X in requirements.txt only on amd64
if [[ "$(uname -m)" = "x86_64" ]]; then
sed -i 's/^tensorflow==\([0-9.*]\+\)$/tensorflow-cpu==\1/' requirements.txt
else
echo "Skipping TF on $(uname -m)"
fi
fi

# --no-deps is required since conflicts can still appear during pip-sync
pip-sync --pip-args '--no-deps --src /opt' requirements.txt

Expand All @@ -63,3 +69,6 @@ for post_install in $(ls /opt/pip-tools-post-install.d/*); do
"${post_install}"
fi
done

echo "######## Frozen requirements ########"
pip freeze
120 changes: 120 additions & 0 deletions .github/eks-workflow-files/maxtext-job.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Headless Service + Indexed Job pair for a 2-node distributed MaxText run.
# All `PLACEHOLDER` values are substituted by the workflow before `kubectl apply`.
apiVersion: v1
kind: Service
metadata:
  name: PLACEHOLDER
spec:
  clusterIP: None # clusterIP must be None to create a headless service
  selector:
    job-name: PLACEHOLDER # must match Job name
---
apiVersion: batch/v1
kind: Job
metadata:
  name: PLACEHOLDER
  labels:
    kueue.x-k8s.io/queue-name: p5-queue
spec:
  completions: 2 # number of nodes
  parallelism: 2 # number of nodes
  completionMode: Indexed
  backoffLimitPerIndex: 0 # max failures per index
  maxFailedIndexes: 0 # all indices must succeed
  template:
    spec:
      subdomain: PLACEHOLDER # has to match Service name
      restartPolicy: Never
      imagePullSecrets:
        - name: PLACEHOLDER
      containers:
        - name: maxtext
          image: PLACEHOLDER
          ports:
            - containerPort: 3389 # JAX coordinator port (see JAX_COORDINATOR_PORT below)
          command:
            - bash
            - -c
            # The logging logic: stream stdout/stderr from the 0th process inside this pod,
            # record all of the processes' stdout/stderr + the INFO-level NCCL logs to file
            - |
              export SERVICE_NAME=$0
              export JOB_NAME=$1
              # Per-process launch script; quoted 'EOL' delimiter keeps $vars
              # literal so they expand at run time, not here.
              cat >each-process.sh <<'EOL'
              export JAX_COORDINATOR_IP=${JOB_NAME}-0.${SERVICE_NAME}
              export JAX_COORDINATOR_PORT=3389
              export NNODES=16 # actually #processes == #GPUs
              export NODE_RANK=$((JOB_COMPLETION_INDEX*8 + LOCAL_RANK))
              export JAX_LOCAL_DEVICE_IDS=$LOCAL_RANK
              export NCCL_DEBUG=INFO
              export NCCL_DEBUG_FILE=/opt/output/nccl.$NODE_RANK.log
              # Only the pod-local rank-0 process streams to the pod log.
              [[ $LOCAL_RANK == 0 ]] && console="/dev/stdout" || console="/dev/null"
              nsys-jax \
                --capture-range=cudaProfilerApi \
                --capture-range-end=stop \
                -o /opt/output/profile.$NODE_RANK.zip \
                -- \
                test-maxtext.sh \
                  -n 2 \
                  -b 2 \
                  --model-name=llama2-7b \
                  --attn-type=cudnn_flash_te \
                  --remat-policy=minimal_flash \
                  --steps=20 \
                  --fsdp=16 \
                  -a "scan_layers=false \
                      max_target_length=4096 \
                      use_iota_embed=true \
                      logits_dot_in_fp32=false \
                      profiler=nsys \
                      skip_first_n_steps_for_profiler=3 \
                      profiler_steps=8" \
                |& tee /opt/output/output.$NODE_RANK.log >"${console}"
              code=$?
              # Should run even on failure
              cat /opt/output/nccl.$NODE_RANK.log >"${console}"
              exit $code
              EOL
              # TODO: upgrade parallel-launch to return a failure code as soon as any
              # of its children do (it already does this eventually, but it could
              # be slow)
              parallel-launch LOCAL_RANK 8 bash each-process.sh
              code=$?
              # Should run even on failure; .done signals the upload sidecar.
              touch /opt/output/.done
              exit $code
            - PLACEHOLDER # $0 -> SERVICE_NAME
            - PLACEHOLDER # $1 -> JOB_NAME
          resources:
            limits:
              nvidia.com/gpu: 8
              vpc.amazonaws.com/efa: 32
          volumeMounts:
            - mountPath: /dev/shm
              name: shmem
            - mountPath: /opt/output
              name: output
        # Sidecar: waits for the .done sentinel, then uploads all outputs
        # (logs, NCCL debug files, nsys profiles) to S3.
        - name: upload
          image: amazon/aws-cli
          command:
            - bash
            - -c
            - |
              JOB_NAME="$0"
              while [[ ! -f /opt/output/.done ]]; do
                sleep 1
              done
              rm /opt/output/.done
              aws s3 cp \
                --recursive \
                /opt/output \
                "s3://jax-toolbox-eks-output/${JOB_NAME}/"
            - PLACEHOLDER # $0 -> JOB_NAME
          volumeMounts:
            - mountPath: /opt/output
              name: output
      volumes:
        - name: output
          emptyDir: {}
        - name: shmem
          emptyDir:
            medium: Memory
            sizeLimit: 16Gi
Loading
Loading