From 00d99415fba229ff693eb30153070318be0e6f27 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 8 Aug 2023 17:08:37 -0700 Subject: [PATCH] Refactor JAX tests to separate setup from test commands (#967) * Refactor JAX tests to separate setup from test commands * formatting * undo some small changes * Remove another duplicate * Remove some unecessary flags * Move some lines between setup and run * fixes * Make bash_logout common * fix working dir in compilation cache test --- tests/jax/common.libsonnet | 31 +++++++++---------- tests/jax/compilation-cache.libsonnet | 15 +++------ tests/jax/latest/common.libsonnet | 22 +++++-------- .../latest/flax-bart-wiki_summary.libsonnet | 5 +-- tests/jax/latest/flax-gpt2-oscar.libsonnet | 6 ++-- tests/jax/latest/flax-sd-pokemon.libsonnet | 6 ++-- .../jax/latest/flax-vit-imagenette.libsonnet | 5 +-- .../latest/flax-wmt-wmt17_translate.libsonnet | 4 +-- tests/jax/pod-test.libsonnet | 9 ++---- tests/jax/tpu-embedding.libsonnet | 6 ++-- 10 files changed, 47 insertions(+), 62 deletions(-) diff --git a/tests/jax/common.libsonnet b/tests/jax/common.libsonnet index e1e4ff30f..e718d7e0c 100644 --- a/tests/jax/common.libsonnet +++ b/tests/jax/common.libsonnet @@ -58,9 +58,20 @@ local tpus = import 'templates/tpus.libsonnet'; tpuVmCreateSleepSeconds: 60, }, - // JAX tests are structured as bash scripts that run directly on the Cloud - // TPU VM instead of using docker images - testScript:: error 'Must define `testScript`', + setup: error 'Must define `setup`', + runTest: error 'Must define `runTest`', + + testScript:: ||| + set -x + set -u + set -e + + # .bash_logout sometimes causes a spurious bad exit code, remove it. + rm .bash_logout + + %(setup)s + %(runTest)s + ||| % self, command: [ 'bash', '-c', @@ -171,13 +182,6 @@ local tpus = import 'templates/tpus.libsonnet'; huggingFaceTransformer:: { scriptConfig+: { installPackages: ||| - set -x - set -u - set -e - - # .bash_logout sometimes causes a spurious bad exit code, remove it. - rm .bash_logout - pip install --upgrade pip git clone https://github.com/huggingface/transformers.git cd transformers && pip install . @@ -200,13 +204,6 @@ local tpus = import 'templates/tpus.libsonnet'; huggingFaceDiffuser:: { scriptConfig+: { installPackages: ||| - set -x - set -u - set -e - - # .bash_logout sometimes causes a spurious bad exit code, remove it. - rm .bash_logout - pip install --upgrade pip git clone https://github.com/huggingface/diffusers.git cd diffusers && pip install . diff --git a/tests/jax/compilation-cache.libsonnet b/tests/jax/compilation-cache.libsonnet index 26965cf46..7870b2442 100644 --- a/tests/jax/compilation-cache.libsonnet +++ b/tests/jax/compilation-cache.libsonnet @@ -20,14 +20,7 @@ local mixins = import 'templates/mixins.libsonnet'; compilationCacheTest:: common.JaxTest + common.tpuVmBaseImage + mixins.Functional { modelName: 'compilation-cache-test', - testScript:: ||| - set -x - set -u - set -e - - # .bash_logout sometimes causes a spurious bad exit code, remove it. - rm .bash_logout - + setup: ||| pip install --upgrade pip %(installLocalJax)s @@ -40,6 +33,8 @@ local mixins = import 'templates/mixins.libsonnet'; exit 1 fi + cd ~ + mkdir "/tmp/compilation_cache_integration_test" cat >integration.py <<'END_SCRIPT' import jax @@ -59,12 +54,12 @@ local mixins = import 'templates/mixins.libsonnet'; num_of_files = sum(1 for f in os.listdir("/tmp/compilation_cache_integration_test")) assert num_of_files == 1, f"The number of files in the cache should be 1 but is {num_of_files}" END_SCRIPT - + ||| % self.scriptConfig, + runTest: ||| python3 integration.py python3 directory_size.py python3 integration.py python3 directory_size.py - ||| % self.scriptConfig, }, diff --git a/tests/jax/latest/common.libsonnet b/tests/jax/latest/common.libsonnet index 5b04b2630..f9fc37bd2 100644 --- a/tests/jax/latest/common.libsonnet +++ b/tests/jax/latest/common.libsonnet @@ -23,14 +23,7 @@ local tpus = import 'templates/tpus.libsonnet'; extraDeps:: [], extraFlags:: [], - testScript:: ||| - set -x - set -u - set -e - - # .bash_logout sometimes causes a spurious bad exit code, remove it. - rm .bash_logout - + setup: ||| pip install --upgrade pip pip install --upgrade clu %(extraDeps)s @@ -47,16 +40,17 @@ local tpus = import 'templates/tpus.libsonnet'; git clone https://github.com/google/flax cd flax pip install --upgrade git+https://github.com/google/flax.git + ||| % (self.scriptConfig { + extraDeps: std.join(' ', config.extraDeps), + }), + runTest: ||| cd examples/%(folderName)s - export GCS_BUCKET=$(MODEL_DIR) export TFDS_DATA_DIR=$(TFDS_DIR) python3 main.py --workdir=$(MODEL_DIR) --config=configs/%(extraConfig)s %(extraFlags)s ||| % (self.scriptConfig { folderName: config.folderName, - modelName: config.modelName, - extraDeps: std.join(' ', config.extraDeps), extraConfig: config.extraConfig, extraFlags: std.join(' ', config.extraFlags), }), @@ -67,14 +61,14 @@ local tpus = import 'templates/tpus.libsonnet'; frameworkPrefix: 'flax.latest', modelName:: 'bert-glue', extraFlags:: [], - testScript:: ||| + setup: ||| %(installPackages)s pip install -r examples/flax/text-classification/requirements.txt %(verifySetup)s - + ||| % self.scriptConfig, + runTest: ||| export GCS_BUCKET=$(MODEL_DIR) export OUTPUT_DIR='./bert-glue' - python3 examples/flax/text-classification/run_flax_glue.py --model_name_or_path bert-base-cased \ --output_dir ${OUTPUT_DIR} \ --logging_dir ${OUTPUT_DIR} \ diff --git a/tests/jax/latest/flax-bart-wiki_summary.libsonnet b/tests/jax/latest/flax-bart-wiki_summary.libsonnet index d1aca3ff1..d7e6d7e97 100644 --- a/tests/jax/latest/flax-bart-wiki_summary.libsonnet +++ b/tests/jax/latest/flax-bart-wiki_summary.libsonnet @@ -23,11 +23,12 @@ local tpus = import 'templates/tpus.libsonnet'; frameworkPrefix: 'flax.latest', modelName:: 'bart-wiki.summary', extraFlags:: [], - testScript:: ||| + setup: ||| %(installPackages)s pip install -r examples/flax/summarization/requirements.txt %(verifySetup)s - + ||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }), + runTest: ||| export GCS_BUCKET=$(MODEL_DIR) python3 examples/flax/summarization/run_summarization_flax.py \ --output_dir './bart-base-wiki' \ diff --git a/tests/jax/latest/flax-gpt2-oscar.libsonnet b/tests/jax/latest/flax-gpt2-oscar.libsonnet index 6f793f10c..87a8085f4 100644 --- a/tests/jax/latest/flax-gpt2-oscar.libsonnet +++ b/tests/jax/latest/flax-gpt2-oscar.libsonnet @@ -24,14 +24,15 @@ local tpus = import 'templates/tpus.libsonnet'; frameworkPrefix: 'flax.latest', modelName:: 'gpt2-oscar', extraFlags:: [], - testScript:: ||| + setup: ||| %(installPackages)s pip install -r examples/flax/language-modeling/requirements.txt %(verifySetup)s cd examples/flax/language-modeling gsutil cp -r gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/jax/gpt2 . - + ||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }), + runTest: ||| python3 run_clm_flax.py \ --output_dir=./gpt2 \ --model_type=gpt2 \ @@ -52,7 +53,6 @@ local tpus = import 'templates/tpus.libsonnet'; --logging_steps=500 \ --eval_steps=2500 \ %(extraFlags)s - ||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }), }, diff --git a/tests/jax/latest/flax-sd-pokemon.libsonnet b/tests/jax/latest/flax-sd-pokemon.libsonnet index a64d1ddd9..4ac95dfd5 100644 --- a/tests/jax/latest/flax-sd-pokemon.libsonnet +++ b/tests/jax/latest/flax-sd-pokemon.libsonnet @@ -23,11 +23,12 @@ local tpus = import 'templates/tpus.libsonnet'; frameworkPrefix: 'flax.latest', modelName:: 'sd-pokemon', extraFlags:: [], - testScript:: ||| + setup: ||| %(installPackages)s pip install -U -r examples/text_to_image/requirements_flax.txt %(verifySetup)s - + ||| % self.scriptConfig, + runTest: ||| export GCS_BUCKET=$(MODEL_DIR) export MODEL_NAME="duongna/stable-diffusion-v1-4-flax" export dataset_name="lambdalabs/pokemon-blip-captions" @@ -42,7 +43,6 @@ local tpus = import 'templates/tpus.libsonnet'; --output_dir="./sd-pokemon-model" \ --cache_dir /tmp \ %(extraFlags)s - ||| % (self.scriptConfig { extraFlags: std.join(' ', config.extraFlags) }), }, diff --git a/tests/jax/latest/flax-vit-imagenette.libsonnet b/tests/jax/latest/flax-vit-imagenette.libsonnet index 5d211325b..64c13607d 100644 --- a/tests/jax/latest/flax-vit-imagenette.libsonnet +++ b/tests/jax/latest/flax-vit-imagenette.libsonnet @@ -23,14 +23,15 @@ local tpus = import 'templates/tpus.libsonnet'; frameworkPrefix: 'flax.latest', modelName:: 'vit-imagenette', extraFlags:: [], - testScript:: ||| + setup: ||| %(installPackages)s pip install -r examples/flax/vision/requirements.txt %(verifySetup)s wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz tar -xvzf imagenette2.tgz - + ||| % self.scriptConfig, + runTest: ||| export GCS_BUCKET=$(MODEL_DIR) python3 examples/flax/vision/run_image_classification.py \ --output_dir './vit-imagenette' \ diff --git a/tests/jax/latest/flax-wmt-wmt17_translate.libsonnet b/tests/jax/latest/flax-wmt-wmt17_translate.libsonnet index 8548af2f6..e819f99ae 100644 --- a/tests/jax/latest/flax-wmt-wmt17_translate.libsonnet +++ b/tests/jax/latest/flax-wmt-wmt17_translate.libsonnet @@ -55,13 +55,13 @@ local tpus = import 'templates/tpus.libsonnet'; local wmt_profiling = self.wmt_profiling, wmt_profiling:: wmt { local config = self, - testScript+:: ||| + runTest+: ||| gsutil -q stat $(MODEL_DIR)/plugins/profile/*/*.xplane.pb gsutil cp -r $(MODEL_DIR)/plugins /tmp/ python3 -m pip uninstall tensorboard_plugin_profile python3 -m pip install tbp-nightly python3 ~/.local/lib/python3.*/site-packages/tensorboard_plugin_profile/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.* --log_directory=/tmp/ - ||| % (self.scriptConfig {}), + ||| % (self.scriptConfig), }, configs: [ wmt + functional + v2_8, diff --git a/tests/jax/pod-test.libsonnet b/tests/jax/pod-test.libsonnet index f31ad4edd..255a27f35 100644 --- a/tests/jax/pod-test.libsonnet +++ b/tests/jax/pod-test.libsonnet @@ -21,15 +21,12 @@ local tpus = import 'templates/tpus.libsonnet'; podTest:: common.JaxTest + mixins.Functional { modelName: 'pod-%s-%s' % [self.jaxlibVersion, self.tpuSettings.softwareVersion], - testScript:: ||| - set -x - set -u - set -e - + setup: ||| %(installLocalJax)s %(maybeBuildJaxlib)s %(printDiagnostics)s - + ||| % self.scriptConfig, + runTest: ||| # Very basic smoke test python3 -c "import jax; assert jax.device_count() == 32, jax.device_count()" diff --git a/tests/jax/tpu-embedding.libsonnet b/tests/jax/tpu-embedding.libsonnet index 9f9021abe..1d5d6caf1 100644 --- a/tests/jax/tpu-embedding.libsonnet +++ b/tests/jax/tpu-embedding.libsonnet @@ -23,14 +23,14 @@ local tpus = import 'templates/tpus.libsonnet'; frameworkPrefix: 'jax-tpu-embedding', extraFlags:: [], testCommand:: error 'Must define `testCommand`', - testScript:: ||| + setup: ||| pip install --upgrade 'jax[tpu]==0.4.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install flax==0.6.7 gsutil cp gs://cloud-tpu-tpuvm-artifacts/tensorflow/20230214/tf_nightly-2.13.0-cp38-cp38-linux_x86_64.whl . pip install tf_nightly-2.13.0-cp38-cp38-linux_x86_64.whl git clone https://github.com/jax-ml/jax-tpu-embedding.git - %(testCommand)s - ||| % config.testCommand, + |||, + runTest: config.testCommand, }, local tpu_embedding_pjit = self.tpu_embedding_pjit, tpu_embedding_pjit:: tpu_embedding {