Skip to content

Commit

Permalink
Refactor JAX tests to separate setup from test commands (#967)
Browse files Browse the repository at this point in the history
* Refactor JAX tests to separate setup from test commands

* formatting

* undo some small changes

* Remove another duplicate

* Remove some unnecessary flags

* Move some lines between setup and run

* fixes

* Make bash_logout common

* fix working dir in compilation cache test
  • Loading branch information
will-cromar authored Aug 9, 2023
1 parent 4858762 commit 00d9941
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 62 deletions.
31 changes: 14 additions & 17 deletions tests/jax/common.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,20 @@ local tpus = import 'templates/tpus.libsonnet';
tpuVmCreateSleepSeconds: 60,
},

// JAX tests are structured as bash scripts that run directly on the Cloud
// TPU VM instead of using docker images
testScript:: error 'Must define `testScript`',
setup: error 'Must define `setup`',
runTest: error 'Must define `runTest`',

testScript:: |||
set -x
set -u
set -e
# .bash_logout sometimes causes a spurious bad exit code, remove it.
rm .bash_logout
%(setup)s
%(runTest)s
||| % self,
command: [
'bash',
'-c',
Expand Down Expand Up @@ -171,13 +182,6 @@ local tpus = import 'templates/tpus.libsonnet';
huggingFaceTransformer:: {
scriptConfig+: {
installPackages: |||
set -x
set -u
set -e
# .bash_logout sometimes causes a spurious bad exit code, remove it.
rm .bash_logout
pip install --upgrade pip
git clone https://github.com/huggingface/transformers.git
cd transformers && pip install .
Expand All @@ -200,13 +204,6 @@ local tpus = import 'templates/tpus.libsonnet';
huggingFaceDiffuser:: {
scriptConfig+: {
installPackages: |||
set -x
set -u
set -e
# .bash_logout sometimes causes a spurious bad exit code, remove it.
rm .bash_logout
pip install --upgrade pip
git clone https://github.com/huggingface/diffusers.git
cd diffusers && pip install .
Expand Down
15 changes: 5 additions & 10 deletions tests/jax/compilation-cache.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,7 @@ local mixins = import 'templates/mixins.libsonnet';
compilationCacheTest:: common.JaxTest + common.tpuVmBaseImage + mixins.Functional {
modelName: 'compilation-cache-test',

testScript:: |||
set -x
set -u
set -e
# .bash_logout sometimes causes a spurious bad exit code, remove it.
rm .bash_logout
setup: |||
pip install --upgrade pip
%(installLocalJax)s
Expand All @@ -40,6 +33,8 @@ local mixins = import 'templates/mixins.libsonnet';
exit 1
fi
cd ~
mkdir "/tmp/compilation_cache_integration_test"
cat >integration.py <<'END_SCRIPT'
import jax
Expand All @@ -59,12 +54,12 @@ local mixins = import 'templates/mixins.libsonnet';
num_of_files = sum(1 for f in os.listdir("/tmp/compilation_cache_integration_test"))
assert num_of_files == 1, f"The number of files in the cache should be 1 but is {num_of_files}"
END_SCRIPT
||| % self.scriptConfig,
runTest: |||
python3 integration.py
python3 directory_size.py
python3 integration.py
python3 directory_size.py
||| % self.scriptConfig,
},

Expand Down
22 changes: 8 additions & 14 deletions tests/jax/latest/common.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,7 @@ local tpus = import 'templates/tpus.libsonnet';
extraDeps:: [],
extraFlags:: [],

testScript:: |||
set -x
set -u
set -e
# .bash_logout sometimes causes a spurious bad exit code, remove it.
rm .bash_logout
setup: |||
pip install --upgrade pip
pip install --upgrade clu %(extraDeps)s
Expand All @@ -47,16 +40,17 @@ local tpus = import 'templates/tpus.libsonnet';
git clone https://github.com/google/flax
cd flax
pip install --upgrade git+https://github.com/google/flax.git
||| % (self.scriptConfig {
extraDeps: std.join(' ', config.extraDeps),
}),
runTest: |||
cd examples/%(folderName)s
export GCS_BUCKET=$(MODEL_DIR)
export TFDS_DATA_DIR=$(TFDS_DIR)
python3 main.py --workdir=$(MODEL_DIR) --config=configs/%(extraConfig)s %(extraFlags)s
||| % (self.scriptConfig {
folderName: config.folderName,
modelName: config.modelName,
extraDeps: std.join(' ', config.extraDeps),
extraConfig: config.extraConfig,
extraFlags: std.join(' ', config.extraFlags),
}),
Expand All @@ -67,14 +61,14 @@ local tpus = import 'templates/tpus.libsonnet';
frameworkPrefix: 'flax.latest',
modelName:: 'bert-glue',
extraFlags:: [],
testScript:: |||
setup: |||
%(installPackages)s
pip install -r examples/flax/text-classification/requirements.txt
%(verifySetup)s
||| % self.scriptConfig,
runTest: |||
export GCS_BUCKET=$(MODEL_DIR)
export OUTPUT_DIR='./bert-glue'
python3 examples/flax/text-classification/run_flax_glue.py --model_name_or_path bert-base-cased \
--output_dir ${OUTPUT_DIR} \
--logging_dir ${OUTPUT_DIR} \
Expand Down
5 changes: 3 additions & 2 deletions tests/jax/latest/flax-bart-wiki_summary.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ local tpus = import 'templates/tpus.libsonnet';
frameworkPrefix: 'flax.latest',
modelName:: 'bart-wiki.summary',
extraFlags:: [],
testScript:: |||
setup: |||
%(installPackages)s
pip install -r examples/flax/summarization/requirements.txt
%(verifySetup)s
||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }),
runTest: |||
export GCS_BUCKET=$(MODEL_DIR)
python3 examples/flax/summarization/run_summarization_flax.py \
--output_dir './bart-base-wiki' \
Expand Down
6 changes: 3 additions & 3 deletions tests/jax/latest/flax-gpt2-oscar.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ local tpus = import 'templates/tpus.libsonnet';
frameworkPrefix: 'flax.latest',
modelName:: 'gpt2-oscar',
extraFlags:: [],
testScript:: |||
setup: |||
%(installPackages)s
pip install -r examples/flax/language-modeling/requirements.txt
%(verifySetup)s
cd examples/flax/language-modeling
gsutil cp -r gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/jax/gpt2 .
||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }),
runTest: |||
python3 run_clm_flax.py \
--output_dir=./gpt2 \
--model_type=gpt2 \
Expand All @@ -52,7 +53,6 @@ local tpus = import 'templates/tpus.libsonnet';
--logging_steps=500 \
--eval_steps=2500 \
%(extraFlags)s
||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }),
},

Expand Down
6 changes: 3 additions & 3 deletions tests/jax/latest/flax-sd-pokemon.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ local tpus = import 'templates/tpus.libsonnet';
frameworkPrefix: 'flax.latest',
modelName:: 'sd-pokemon',
extraFlags:: [],
testScript:: |||
setup: |||
%(installPackages)s
pip install -U -r examples/text_to_image/requirements_flax.txt
%(verifySetup)s
||| % self.scriptConfig,
runTest: |||
export GCS_BUCKET=$(MODEL_DIR)
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export dataset_name="lambdalabs/pokemon-blip-captions"
Expand All @@ -42,7 +43,6 @@ local tpus = import 'templates/tpus.libsonnet';
--output_dir="./sd-pokemon-model" \
--cache_dir /tmp \
%(extraFlags)s
||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }),
},

Expand Down
5 changes: 3 additions & 2 deletions tests/jax/latest/flax-vit-imagenette.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ local tpus = import 'templates/tpus.libsonnet';
frameworkPrefix: 'flax.latest',
modelName:: 'vit-imagenette',
extraFlags:: [],
testScript:: |||
setup: |||
%(installPackages)s
pip install -r examples/flax/vision/requirements.txt
%(verifySetup)s
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
tar -xvzf imagenette2.tgz
||| % self.scriptConfig,
runTest: |||
export GCS_BUCKET=$(MODEL_DIR)
python3 examples/flax/vision/run_image_classification.py \
--output_dir './vit-imagenette' \
Expand Down
4 changes: 2 additions & 2 deletions tests/jax/latest/flax-wmt-wmt17_translate.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ local tpus = import 'templates/tpus.libsonnet';
local wmt_profiling = self.wmt_profiling,
wmt_profiling:: wmt {
local config = self,
testScript+:: |||
runTest+: |||
gsutil -q stat $(MODEL_DIR)/plugins/profile/*/*.xplane.pb
gsutil cp -r $(MODEL_DIR)/plugins /tmp/
python3 -m pip uninstall tensorboard_plugin_profile
python3 -m pip install tbp-nightly
python3 ~/.local/lib/python3.*/site-packages/tensorboard_plugin_profile/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.* --log_directory=/tmp/
||| % (self.scriptConfig {}),
||| % (self.scriptConfig),
},
configs: [
wmt + functional + v2_8,
Expand Down
9 changes: 3 additions & 6 deletions tests/jax/pod-test.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@ local tpus = import 'templates/tpus.libsonnet';
podTest:: common.JaxTest + mixins.Functional {
modelName: 'pod-%s-%s' % [self.jaxlibVersion, self.tpuSettings.softwareVersion],

testScript:: |||
set -x
set -u
set -e
setup: |||
%(installLocalJax)s
%(maybeBuildJaxlib)s
%(printDiagnostics)s
||| % self.scriptConfig,
runTest: |||
# Very basic smoke test
python3 -c "import jax; assert jax.device_count() == 32, jax.device_count()"
Expand Down
6 changes: 3 additions & 3 deletions tests/jax/tpu-embedding.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ local tpus = import 'templates/tpus.libsonnet';
frameworkPrefix: 'jax-tpu-embedding',
extraFlags:: [],
testCommand:: error 'Must define `testCommand`',
testScript:: |||
setup: |||
pip install --upgrade 'jax[tpu]==0.4.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax==0.6.7
gsutil cp gs://cloud-tpu-tpuvm-artifacts/tensorflow/20230214/tf_nightly-2.13.0-cp38-cp38-linux_x86_64.whl .
pip install tf_nightly-2.13.0-cp38-cp38-linux_x86_64.whl
git clone https://github.com/jax-ml/jax-tpu-embedding.git
%(testCommand)s
||| % config.testCommand,
|||,
runTest: config.testCommand,
},
local tpu_embedding_pjit = self.tpu_embedding_pjit,
tpu_embedding_pjit:: tpu_embedding {
Expand Down

0 comments on commit 00d9941

Please sign in to comment.