
dependency error because of orbax-checkpoint version #1273

Open
@shota-inoue-lts

Description


Situation

I executed the following shell script to train a model with MaxText via xpk.

#!/bin/bash
# GCP Settings
PROJECT=XXXXXXX
ZONE=XXXXXXX
CLUSTER=XXXXXXX
TPU_TYPE=v6e-8
NUM_SLICES=1

# Storage path
BASE_OUTPUT_DIR=XXXXXXX
DATASET_PATH=XXXXXXX
DATASET_TYPE=tfds

# HyperParameters
PER_DEVICE_BATCH_SIZE=3
MODEL_NAME=llama3.1-8b
MAX_TARGET_LENGTH=4096
STEPS=35
BLOCK_SIZE=2048
REMAT_POLICY=full
TOKENIZER_PATH=assets/tokenizer_llama3.tiktoken
VMEM_LIMIT=114688
ENABLE_CHECKPOINTING=true
CHECKPOINT_PERIOD=30

# Parallelism
ICI_DATA_PARALLELISM=1
ICI_PIPELINE_PARALLELISM=4
ICI_FSDP_PARALLELISM=1
ICI_FSDP_TRANSPOSE_PARALLELISM=1
ICI_SEQUENCE_PARALLELISM=1
ICI_TENSOR_PARALLELISM=2
ICI_TENSOR_SEQUENCE_PARALLELISM=1
ICI_EXPERT_PARALLELISM=1
ICI_AUTOREGRESSIVE_PARALLELISM=1

# image settings
CLOUD_IMAGE_NAME=${USER}_runner
DOCKER_IMAGE=gcr.io/${PROJECT}/${CLOUD_IMAGE_NAME}:latest

# EXP settings
EXP_NAME=$(echo $MODEL_NAME | tr '.' '-')-bs${PER_DEVICE_BATCH_SIZE}-$(date +'%m-%d-%H-%M-%S') # --workload: Workload name must be less than 40 characters and match the pattern `[a-z]([-a-z0-9]*[a-z0-9])?`

# download dataset
cd ~/maxtext
bash download_dataset.sh ${PROJECT} ${DATASET_PATH}

# create and push image
cd ~/maxtext
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}

# create workload for model training
cd ~/xpk
python3 xpk.py workload create \
    --cluster ${CLUSTER} \
    --docker-image ${DOCKER_IMAGE} \
    --workload ${EXP_NAME} \
    --tpu-type ${TPU_TYPE} \
    --num-slices ${NUM_SLICES}  \
    --use-vertex-tensorboard \
    --experiment-name ${EXP_NAME} \
    --zone ${ZONE} \
    --on-demand \
    --enable-debug-logs \
    --project ${PROJECT} \
    --command "export LIBTPU_INIT_ARGS='--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=${VMEM_LIMIT} --xla_tpu_enable_async_collective_fusion=true --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true' && python3 MaxText/train.py MaxText/configs/base.yml model_name=${MODEL_NAME} base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} run_name=${EXP_NAME} tokenizer_path=${TOKENIZER_PATH} max_target_length=${MAX_TARGET_LENGTH} per_device_batch_size=${PER_DEVICE_BATCH_SIZE} remat_policy=${REMAT_POLICY} steps=${STEPS} enable_checkpointing=${ENABLE_CHECKPOINTING} checkpoint_period=${CHECKPOINT_PERIOD} use_iota_embed=true gcs_metrics=true dataset_type=${DATASET_TYPE} reuse_example_batch=1 profiler=xplane attention=flash sa_block_q=${BLOCK_SIZE} sa_block_q_dkv=${BLOCK_SIZE} sa_block_q_dq=${BLOCK_SIZE} ici_data_parallelism=${ICI_DATA_PARALLELISM} ici_pipeline_parallelism=${ICI_PIPELINE_PARALLELISM} ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} ici_fsdp_transpose_parallelism=${ICI_FSDP_TRANSPOSE_PARALLELISM} ici_sequence_parallelism=${ICI_SEQUENCE_PARALLELISM} ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} ici_tensor_sequence_parallelism=${ICI_TENSOR_SEQUENCE_PARALLELISM} ici_expert_parallelism=${ICI_EXPERT_PARALLELISM} ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM}"

Error Message

I got the following error while running MaxText/train.py. Specifically, the error occurs when checkpointing is enabled (ENABLE_CHECKPOINTING=true).

Traceback (most recent call last):
  File "/deps/MaxText/train.py", line 1031, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/deps/MaxText/train.py", line 1027, in main
    train_loop(config)
  File "/deps/MaxText/train.py", line 897, in train_loop
    if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
  File "/deps/MaxText/train.py", line 241, in save_checkpoint
    return checkpoint_manager.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1278, in save
    self._checkpointer.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 491, in save
    asyncio_utils.run_sync(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
  File "/usr/local/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 392, in _save
    await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 706, in async_save
    jax.tree.flatten(await asyncio.gather(*save_ops))[0] or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 583, in async_save
    return await self._handler_impl.async_save(directory, args=args)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 482, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 1127, in serialize
    future.CommitFutureAwaitingContractedSignals(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 367, in __init__
    receive_signals = get_awaitable_signals_from_contract()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 57, in get_awaitable_signals_from_contract
    values_str = str(client.key_value_try_get(barrier_key))
AttributeError: 'DistributedRuntimeClient' object has no attribute 'key_value_try_get'. Did you mean: 'key_value_dir_get'?
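The shape of the failure can be illustrated in isolation (a minimal sketch; OldDistributedClient is a hypothetical stand-in, not the real jaxlib class): newer orbax-checkpoint releases call key_value_try_get() on the JAX distributed runtime client, while the client shipped in this image only exposes key_value_dir_get(), so the attribute lookup itself fails.

```python
# Hypothetical stand-in for the distributed client bundled with older
# jaxlib builds: it provides key_value_dir_get but not key_value_try_get.
class OldDistributedClient:
    def key_value_dir_get(self, key):
        return []

def try_get(client, barrier_key):
    # Roughly the call newer orbax-checkpoint makes; on an old client
    # this raises AttributeError instead of returning a value.
    return client.key_value_try_get(barrier_key)

client = OldDistributedClient()
try:
    try_get(client, "orbax_barrier")
except AttributeError as e:
    print(e)  # 'OldDistributedClient' object has no attribute 'key_value_try_get'
```

Because the error is raised at attribute-lookup time, no amount of checkpoint configuration avoids it; the package versions themselves have to agree.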

Solution

We should pin orbax-checkpoint==0.10.3 when building the Docker image (without a version pin, orbax-checkpoint==0.11.5 is installed). We solved the problem by pinning the version in both requirements files (requirements_with_jax_stable_stack.txt and requirements.txt).

# maxtext/requirements_with_jax_stable_stack.txt
...
orbax-checkpoint==0.10.3
...
# maxtext/requirements.txt
...
orbax-checkpoint==0.10.3
...
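To confirm the pin actually took effect inside the rebuilt image, one quick check (a sketch; run it with the image's python3) is to query the installed distribution versions:

```python
import importlib.metadata as md

def pkg_version(name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return md.version(name)
    except md.PackageNotFoundError:
        return None

# Inside the training image this should report the pinned 0.10.3,
# not the 0.11.x that an unpinned build pulls in.
print("orbax-checkpoint:", pkg_version("orbax-checkpoint"))
print("jax:", pkg_version("jax"))
```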

Reference

I referred to the following URL when creating the shell script.

How to run MaxText with XPK?
https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Run_MaxText_via_xpk.md
