Qinwen/add sdxl #10

Open · wants to merge 2 commits into base: main
52 changes: 52 additions & 0 deletions training/trillium/DIffusion-XL-MaxDiffusion/README.md
# Instructions for training MaxDiffusion SDXL on TPU Trillium

## XPK setup
Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK.

## Prep for MaxDiffusion
Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXDIFFUSION_README.md) to install MaxDiffusion and build the docker image.

Download the pretrained SDXL base model (stable-diffusion-xl-base-1.0) from [huggingface](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main).
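One way to fetch the checkpoint is with `huggingface-cli` (a sketch, assuming `huggingface_hub` is installed; note that the training scripts below reference the HF-cache-style folder name `models--stabilityai--stable-diffusion-xl-base-1.0`, which is what a cached download produces):

```
pip install -U "huggingface_hub[cli]"
# Downloads the full SDXL base snapshot into the local HF cache,
# under ~/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0
```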
Copy the downloaded checkpoint folder to your GCS bucket:

`gsutil -m cp -r stable-diffusion-xl-base-1.0 ${OUTPUT_DIR}/checkpoints`

Prepare the dataset and store it in a local folder:

`python -m unittest input_pipeline_interface_test.InputPipelineInterface.test_make_pokemon_iterator_sdxl_cache`

Upload the prepared dataset to a GCS location:

`gsutil -m cp -r /tmp/pokemon-gpt4-captions_xl ${OUTPUT_DIR}/dataset`
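To confirm both uploads landed where the training scripts expect them:

```
gsutil ls ${OUTPUT_DIR}/checkpoints
gsutil ls ${OUTPUT_DIR}/dataset
```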
## Run MaxDiffusion SDXL workloads on GKE

### Test Env
- jaxlib: 0.4.35
- [maxdiffusion](https://github.com/AI-Hypercomputer/maxdiffusion.git) at commit `269b6216ac65adb9e7044ec454879dc99856d5e9`
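To reproduce this environment, check out the pinned commit after cloning (a sketch; the jaxlib version comes from the JAX stable-stack base image used in the docker build step):

```
cd maxdiffusion
git checkout 269b6216ac65adb9e7044ec454879dc99856d5e9
# Inside the built image, the pinned jaxlib should report:
python -c "import jaxlib; print(jaxlib.__version__)"  # 0.4.35
```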

### Starting workload

From the maxdiffusion root directory, start your SDXL workload on a single v6e-256 slice:

```
python3 ~/xpk/xpk.py workload create --cluster $CLUSTER_NAME --workload $USER-maxdiffusion --command "bash sdxl-v6e-256-pbds-1.sh $OUT_DIR" \
--base-docker-image=maxdiffusion_base_image \
--tpu-type=v6e-256 --num-slices=1 --zone=$ZONE --project=$PROJECT_ID
```

From your workload logs, you should start seeing step time logs like the following:
```
completed step: 254, seconds: 0.164, TFLOP/s/device: 123.764, loss: 0.055
```
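Workload names are unique per cluster, so list and delete a finished run before resubmitting under the same name (a sketch using xpk's `workload list` and `workload delete` commands):

```
python3 ~/xpk/xpk.py workload list --cluster $CLUSTER_NAME
python3 ~/xpk/xpk.py workload delete --cluster $CLUSTER_NAME --workload $USER-maxdiffusion \
  --zone=$ZONE --project=$PROJECT_ID
```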

To run on multiple slices, start your SDXL workload on 2x v6e-256:

```
python3 ~/xpk/xpk.py workload create --cluster $CLUSTER_NAME --workload $USER-maxdiffusion --command "bash sdxl-2xv6e-256-pbds-1.sh $OUT_DIR" \
--base-docker-image=maxdiffusion_base_image \
--tpu-type=v6e-256 --num-slices=2 --zone=$ZONE --project=$PROJECT_ID
```

From your workload logs, you should start seeing step time logs like the following:
```
completed step: 92, seconds: 0.228, TFLOP/s/device: 89.120, loss: 0.057
```
7 changes: 7 additions & 0 deletions training/trillium/DIffusion-XL-MaxDiffusion/sdxl-2xv6e-256-pbds-1.sh
export JAX_PLATFORMS="tpu,cpu"

# The workload command passes the results GCS folder as the first argument;
# OUTPUT_DIR (checkpoints and dataset location) must already be set in the environment.
OUT_DIR=${1:-${OUT_DIR}}

checkpoints=${OUTPUT_DIR}/checkpoints
dataset_path=${OUTPUT_DIR}/dataset

ENABLE_PJRT_COMPATIBILITY=true TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml pretrained_model_name_or_path=${checkpoints}/models--stabilityai--stable-diffusion-xl-base-1.0 \
revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 dataset_name=${dataset_path}/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=100 attention=flash run_name=trillium-sdxl enable_profiler=True output_dir=${OUT_DIR}
10 changes: 10 additions & 0 deletions training/trillium/DIffusion-XL-MaxDiffusion/sdxl-v6e-256-pbds-1.sh
export JAX_PLATFORMS="tpu,cpu"

# libtpu/XLA flags enabling async collective fusion and all-gather pipelining.
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_spmd_threshold_for_allgather_cse=1000000 --xla_jf_spmd_threshold_for_windowed_einsum_mib=1000000'
# Additional SparseCore flags for offloading the all-reduce collective.
LIBTPU_INIT_ARGS+=' --xla_sc_disable_megacore_partitioning=true --xla_tpu_use_tc_device_shape_on_sc=true --tpu_use_continuations=true --xla_sc_enable_instruction_fusion=false --xla_sc_disjoint_spmem=false --2a886c8_chip_config_name=megachip_tccontrol --xla_jf_crs_combiner_threshold_count=10 --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true'

# The workload command passes the results GCS folder as the first argument;
# OUTPUT_DIR (checkpoints and dataset location) must already be set in the environment.
OUT_DIR=${1:-${OUT_DIR}}

checkpoints=${OUTPUT_DIR}/checkpoints
dataset_path=${OUTPUT_DIR}/dataset

ENABLE_PJRT_COMPATIBILITY=true TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml pretrained_model_name_or_path=${checkpoints}/models--stabilityai--stable-diffusion-xl-base-1.0 \
revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 dataset_name=${dataset_path}/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1 jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=100 attention=flash run_name=trillium-sdxl enable_profiler=True output_dir=${OUT_DIR}
31 changes: 31 additions & 0 deletions training/trillium/MAXDIFFUSION_README.md
# Prep for MaxDiffusion workloads on GKE
1. Clone the [maxdiffusion](https://github.com/AI-Hypercomputer/maxdiffusion.git) repo, change into its directory, and check out the pinned commit
```
git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
cd maxdiffusion
git checkout ${maxdiffusion_HASH}
```

2. Run the following command to build the docker image
```
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
```

3. Upload your docker image to Container Registry
```
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
```
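To confirm the upload succeeded (a quick check; the exact image path depends on how `docker_upload_runner.sh` names and tags the image):

```
gcloud container images list --repository=gcr.io/${PROJECT} | grep ${USER}_runner
```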

4. Create your GCS bucket
```
OUTPUT_DIR=gs://v6e-demo-run #<your_GCS_folder_for_results>
gcloud storage buckets create ${OUTPUT_DIR} --project ${PROJECT}
```

5. Specify your workload configs
```
export PROJECT=#<your_compute_project>
export ZONE=#<your_compute_zone>
export CLUSTER_NAME=v6e-demo #<your_cluster_name>
export OUTPUT_DIR=gs://v6e-demo-run #<your_GCS_folder_for_results>
```
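As a final sanity check before submitting workloads, verify the cluster and output bucket are reachable with these settings:

```
gcloud container clusters list --project=${PROJECT} --filter="name=${CLUSTER_NAME}"
gsutil ls ${OUTPUT_DIR}
```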