Description
Situation
I execute the following shell script to train a model via MaxText with xpk.
#!/bin/bash
# GCP Settings
PROJECT=XXXXXXX
ZONE=XXXXXXX
CLUSTER=XXXXXXX
TPU_TYPE=v6e-8
NUM_SLICES=1
# Storage path
BASE_OUTPUT_DIR=XXXXXXX
DATASET_PATH=XXXXXXX
DATASET_TYPE=tfds
# HyperParameters
PER_DEVICE_BATCH_SIZE=3
MODEL_NAME=llama3.1-8b
MAX_TARGET_LENGTH=4096
STEPS=35
BLOCK_SIZE=2048
REMAT_POLICY=full
TOKENIZER_PATH=assets/tokenizer_llama3.tiktoken
VMEM_LIMIT=114688
ENABLE_CHECKPOINTING=true
CHECKPOINT_PERIOD=30
# Parallelism
ICI_DATA_PARALLELISM=1
ICI_PIPELINE_PARALLELISM=4
ICI_FSDP_PARALLELISM=1
ICI_FSDP_TRANSPOSE_PARALLELISM=1
ICI_SEQUENCE_PARALLELISM=1
ICI_TENSOR_PARALLELISM=2
ICI_TENSOR_SEQUENCE_PARALLELISM=1
ICI_EXPERT_PARALLELISM=1
ICI_AUTOREGRESSIVE_PARALLELISM=1
# image settings
CLOUD_IMAGE_NAME=${USER}_runner
DOCKER_IMAGE=gcr.io/${PROJECT}/${CLOUD_IMAGE_NAME}:latest
# EXP settings
EXP_NAME=$(echo $MODEL_NAME | tr '.' '-')-bs${PER_DEVICE_BATCH_SIZE}-$(date +'%m-%d-%H-%M-%S') # --workload: Workload name must be less than 40 characters and match the pattern `[a-z]([-a-z0-9]*[a-z0-9])?`
# download dataset
cd ~/maxtext
bash download_dataset.sh ${PROJECT} ${DATASET_PATH}
# create and push image
cd ~/maxtext
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
# create workload for model training
cd ~/xpk
python3 xpk.py workload create \
--cluster ${CLUSTER} \
--docker-image ${DOCKER_IMAGE} \
--workload ${EXP_NAME} \
--tpu-type ${TPU_TYPE} \
--num-slices ${NUM_SLICES} \
--use-vertex-tensorboard \
--experiment-name ${EXP_NAME} \
--zone ${ZONE} \
--on-demand \
--enable-debug-logs \
--project ${PROJECT} \
--command "export LIBTPU_INIT_ARGS='--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=${VMEM_LIMIT} --xla_tpu_enable_async_collective_fusion=true --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true' && python3 MaxText/train.py MaxText/configs/base.yml model_name=${MODEL_NAME} base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} run_name=${EXP_NAME} tokenizer_path=${TOKENIZER_PATH} max_target_length=${MAX_TARGET_LENGTH} per_device_batch_size=${PER_DEVICE_BATCH_SIZE} remat_policy=${REMAT_POLICY} steps=${STEPS} enable_checkpointing=${ENABLE_CHECKPOINTING} checkpoint_period=${CHECKPOINT_PERIOD} use_iota_embed=true gcs_metrics=true dataset_type=${DATASET_TYPE} reuse_example_batch=1 profiler=xplane attention=flash sa_block_q=${BLOCK_SIZE} sa_block_q_dkv=${BLOCK_SIZE} sa_block_q_dq=${BLOCK_SIZE} ici_data_parallelism=${ICI_DATA_PARALLELISM} ici_pipeline_parallelism=${ICI_PIPELINE_PARALLELISM} ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} ici_fsdp_transpose_parallelism=${ICI_FSDP_TRANSPOSE_PARALLELISM} ici_sequence_parallelism=${ICI_SEQUENCE_PARALLELISM} ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} ici_tensor_sequence_parallelism=${ICI_TENSOR_SEQUENCE_PARALLELISM} ici_expert_parallelism=${ICI_EXPERT_PARALLELISM} ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM}"
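As the comment on EXP_NAME notes, xpk rejects workload names that are 40 or more characters long or that fail the pattern `[a-z]([-a-z0-9]*[a-z0-9])?`. A pre-flight check like the following can catch a bad name before the workload is submitted (a sketch; the example value stands in for the generated EXP_NAME):

```shell
#!/bin/bash
# Sketch: validate a candidate workload name against xpk's documented
# constraints before calling `xpk.py workload create`.
# The example value stands in for the EXP_NAME computed in the script above.
EXP_NAME="llama3-1-8b-bs3-01-02-03-04-05"
if [[ ${#EXP_NAME} -lt 40 && "$EXP_NAME" =~ ^[a-z]([-a-z0-9]*[a-z0-9])?$ ]]; then
  echo "workload name OK: ${EXP_NAME}"
else
  echo "invalid workload name: ${EXP_NAME}" >&2
  exit 1
fi
```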
Error Message
I got the following error while running MaxText/train.py. The error occurs only when checkpointing is enabled (ENABLE_CHECKPOINTING=true).
Traceback (most recent call last):
  File "/deps/MaxText/train.py", line 1031, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/deps/MaxText/train.py", line 1027, in main
    train_loop(config)
  File "/deps/MaxText/train.py", line 897, in train_loop
    if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
  File "/deps/MaxText/train.py", line 241, in save_checkpoint
    return checkpoint_manager.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1278, in save
    self._checkpointer.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 491, in save
    asyncio_utils.run_sync(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
  File "/usr/local/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 392, in _save
    await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 706, in async_save
    jax.tree.flatten(await asyncio.gather(*save_ops))[0] or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 583, in async_save
    return await self._handler_impl.async_save(directory, args=args)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 482, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 1127, in serialize
    future.CommitFutureAwaitingContractedSignals(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 367, in __init__
    receive_signals = get_awaitable_signals_from_contract()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 57, in get_awaitable_signals_from_contract
    values_str = str(client.key_value_try_get(barrier_key))
AttributeError: 'DistributedRuntimeClient' object has no attribute 'key_value_try_get'. Did you mean: 'key_value_dir_get'?
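The traceback points to a version mismatch: orbax-checkpoint 0.11.x calls `key_value_try_get` on JAX's `DistributedRuntimeClient`, a method the jaxlib bundled in this stable-stack image does not provide. The shape of the mismatch can be sketched like this (the helper and dummy client are hypothetical, for illustration only):

```python
# Sketch of the API mismatch: newer orbax-checkpoint expects the
# distributed runtime client to expose key_value_try_get, while the
# older jaxlib client only has methods like key_value_dir_get.
def supports_key_value_try_get(client) -> bool:
    """True if `client` provides the method orbax-checkpoint >= 0.11 calls."""
    return callable(getattr(client, "key_value_try_get", None))

class OldClient:  # stands in for the older jaxlib DistributedRuntimeClient
    def key_value_dir_get(self, key):
        return []

print(supports_key_value_try_get(OldClient()))  # False -> the AttributeError above
```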
Solution
We should pin the package version orbax-checkpoint==0.10.3 when creating the Docker image (without a version pin, orbax-checkpoint==0.11.5 is currently installed, which triggers the error above). We solved the problem by adding the pin to both requirements files (requirements_with_jax_stable_stack.txt and requirements.txt).
# maxtext/requirements_with_jax_stable_stack.txt
...
orbax-checkpoint==0.10.3
...
# maxtext/requirements.txt
...
orbax-checkpoint==0.10.3
...
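To confirm the pin is actually present in both files before rebuilding the image, a quick check from the maxtext checkout (a sketch, not part of the original script):

```shell
# Print every orbax-checkpoint pin with its filename; run from ~/maxtext.
grep -H "orbax-checkpoint==" requirements.txt requirements_with_jax_stable_stack.txt
```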
Reference
I referred to the following URL when creating the shell script.
How to run MaxText with XPK?
https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Run_MaxText_via_xpk.md