
dependency error because of orbax-checkpoint version #1273

Open
shota-inoue-lts opened this issue Feb 14, 2025 · 1 comment
shota-inoue-lts commented Feb 14, 2025

Situation

I executed the following shell script to train a model via MaxText with xpk.

#!/bin/bash
# GCP Settings
PROJECT=XXXXXXX
ZONE=XXXXXXX
CLUSTER=XXXXXXX
TPU_TYPE=v6e-8
NUM_SLICES=1

# Storage path
BASE_OUTPUT_DIR=XXXXXXX
DATASET_PATH=XXXXXXX
DATASET_TYPE=tfds

# HyperParameters
PER_DEVICE_BATCH_SIZE=3
MODEL_NAME=llama3.1-8b
MAX_TARGET_LENGTH=4096
STEPS=35
BLOCK_SIZE=2048
REMAT_POLICY=full
TOKENIZER_PATH=assets/tokenizer_llama3.tiktoken
VMEM_LIMIT=114688
ENABLE_CHECKPOINTING=true
CHECKPOINT_PERIOD=30

# Parallelism
ICI_DATA_PARALLELISM=1
ICI_PIPELINE_PARALLELISM=4
ICI_FSDP_PARALLELISM=1
ICI_FSDP_TRANSPOSE_PARALLELISM=1
ICI_SEQUENCE_PARALLELISM=1
ICI_TENSOR_PARALLELISM=2
ICI_TENSOR_SEQUENCE_PARALLELISM=1
ICI_EXPERT_PARALLELISM=1
ICI_AUTOREGRESSIVE_PARALLELISM=1

# image settings
CLOUD_IMAGE_NAME=${USER}_runner
DOCKER_IMAGE=gcr.io/${PROJECT}/${CLOUD_IMAGE_NAME}:latest

# EXP settings
EXP_NAME=$(echo $MODEL_NAME | tr '.' '-')-bs${PER_DEVICE_BATCH_SIZE}-$(date +'%m-%d-%H-%M-%S') # --workload: Workload name must be less than 40 characters and match the pattern `[a-z]([-a-z0-9]*[a-z0-9])?`

# download dataset
cd ~/maxtext
bash download_dataset.sh ${PROJECT} ${DATASET_PATH}

# create and push image
cd ~/maxtext
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}

# create workload for model training
cd ~/xpk
python3 xpk.py workload create \
    --cluster ${CLUSTER} \
    --docker-image ${DOCKER_IMAGE} \
    --workload ${EXP_NAME} \
    --tpu-type ${TPU_TYPE} \
    --num-slices ${NUM_SLICES}  \
    --use-vertex-tensorboard \
    --experiment-name ${EXP_NAME} \
    --zone ${ZONE} \
    --on-demand \
    --enable-debug-logs \
    --project ${PROJECT} \
    --command "export LIBTPU_INIT_ARGS='--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=${VMEM_LIMIT} --xla_tpu_enable_async_collective_fusion=true --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true' && python3 MaxText/train.py MaxText/configs/base.yml model_name=${MODEL_NAME} base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} run_name=${EXP_NAME} tokenizer_path=${TOKENIZER_PATH} max_target_length=${MAX_TARGET_LENGTH} per_device_batch_size=${PER_DEVICE_BATCH_SIZE} remat_policy=${REMAT_POLICY} steps=${STEPS} enable_checkpointing=${ENABLE_CHECKPOINTING} checkpoint_period=${CHECKPOINT_PERIOD} use_iota_embed=true gcs_metrics=true dataset_type=${DATASET_TYPE} reuse_example_batch=1 profiler=xplane attention=flash sa_block_q=${BLOCK_SIZE} sa_block_q_dkv=${BLOCK_SIZE} sa_block_q_dq=${BLOCK_SIZE} ici_data_parallelism=${ICI_DATA_PARALLELISM} ici_pipeline_parallelism=${ICI_PIPELINE_PARALLELISM} ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} ici_fsdp_transpose_parallelism=${ICI_FSDP_TRANSPOSE_PARALLELISM} ici_sequence_parallelism=${ICI_SEQUENCE_PARALLELISM} ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} ici_tensor_sequence_parallelism=${ICI_TENSOR_SEQUENCE_PARALLELISM} ici_expert_parallelism=${ICI_EXPERT_PARALLELISM} ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM}"
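The inline comment on EXP_NAME notes xpk's workload-name constraints (under 40 characters, matching `[a-z]([-a-z0-9]*[a-z0-9])?`). As a sketch, not part of the original script, a pre-flight check like the following could catch an invalid name before the workload is submitted; the EXP_NAME value shown is a hypothetical example of what the script generates:

```shell
# Hypothetical pre-flight check for the generated workload name.
# Constraints are xpk's documented ones: length < 40 and pattern
# [a-z]([-a-z0-9]*[a-z0-9])?
EXP_NAME="llama3-1-8b-bs3-02-14-12-00-00"   # example value

if [ "${#EXP_NAME}" -ge 40 ]; then
  echo "workload name too long: ${EXP_NAME}" >&2
  exit 1
fi
if ! printf '%s\n' "${EXP_NAME}" | grep -Eq '^[a-z]([-a-z0-9]*[a-z0-9])?$'; then
  echo "workload name does not match pattern: ${EXP_NAME}" >&2
  exit 1
fi
echo "workload name OK: ${EXP_NAME}"
```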

Error Message

I got the following error while running MaxText/train.py. In particular, the error occurs only when checkpointing is enabled (ENABLE_CHECKPOINTING=true).

Traceback (most recent call last):
  File "/deps/MaxText/train.py", line 1031, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/deps/MaxText/train.py", line 1027, in main
    train_loop(config)
  File "/deps/MaxText/train.py", line 897, in train_loop
    if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
  File "/deps/MaxText/train.py", line 241, in save_checkpoint
    return checkpoint_manager.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1278, in save
    self._checkpointer.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 491, in save
    asyncio_utils.run_sync(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
  File "/usr/local/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 392, in _save
    await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 706, in async_save
    jax.tree.flatten(await asyncio.gather(*save_ops))[0] or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 583, in async_save
    return await self._handler_impl.async_save(directory, args=args)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 482, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 1127, in serialize
    future.CommitFutureAwaitingContractedSignals(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 367, in __init__
    receive_signals = get_awaitable_signals_from_contract()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 57, in get_awaitable_signals_from_contract
    values_str = str(client.key_value_try_get(barrier_key))
AttributeError: 'DistributedRuntimeClient' object has no attribute 'key_value_try_get'. Did you mean: 'key_value_dir_get'?
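The traceback shows that orbax-checkpoint 0.11.5 calls key_value_try_get on jax's DistributedRuntimeClient, a method the jax 0.4.37 stable-stack image apparently does not provide. As a hypothetical startup guard (not part of MaxText or orbax), one could fail fast when the installed orbax-checkpoint is newer than the known-good pin; parse_version and is_compatible below are illustrative helpers, not library APIs:

```python
# Hypothetical startup guard: refuse to train if orbax-checkpoint is
# newer than the last version known to work with this jax build.
from importlib import metadata

KNOWN_GOOD = (0, 10, 3)  # orbax-checkpoint==0.10.3, per this issue

def parse_version(v: str) -> tuple[int, ...]:
    """Parse a plain 'X.Y.Z' version string into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))

def is_compatible(installed: str, known_good: tuple[int, ...] = KNOWN_GOOD) -> bool:
    """True if the installed version is at or below the known-good pin."""
    return parse_version(installed) <= known_good

if __name__ == "__main__":
    try:
        installed = metadata.version("orbax-checkpoint")
    except metadata.PackageNotFoundError:
        installed = None  # package absent; nothing to check
    if installed and not is_compatible(installed):
        raise SystemExit(
            f"orbax-checkpoint=={installed} may call DistributedRuntimeClient "
            "methods missing from this jax build; pin orbax-checkpoint==0.10.3."
        )
```

A simple tuple comparison suffices here because orbax-checkpoint uses plain X.Y.Z versions; a real guard would use a proper version parser.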

Solution

We should pin the package to orbax-checkpoint==0.10.3 when creating the Docker image (without a version specification, orbax-checkpoint==0.11.5 is currently installed). We solved the problem by editing both requirements files (requirements.txt and requirements_with_jax_stable_stack.txt).

# maxtext/requirements_with_jax_stable_stack.txt
...
orbax-checkpoint==0.10.3
...
# maxtext/requirements.txt
...
orbax-checkpoint==0.10.3
...
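Before rebuilding the image, a quick sanity check (illustrative only; the file paths assume the maxtext checkout layout described above) can confirm that both requirements files carry the pin:

```shell
# Illustrative check: confirm the orbax-checkpoint pin is present in both
# requirements files before rebuilding the Docker image.
PIN='orbax-checkpoint==0.10.3'
for f in requirements.txt requirements_with_jax_stable_stack.txt; do
  [ -f "${f}" ] || { echo "skip: ${f} not found"; continue; }
  if grep -qxF "${PIN}" "${f}"; then
    echo "OK: ${f} pins ${PIN}"
  else
    echo "MISSING pin in ${f}" >&2
  fi
done
```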

Reference

I referred to the following URL when creating the shell script.

How to run MaxText with XPK?
https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Run_MaxText_via_xpk.md

shota-inoue-lts added a commit to shota-inoue-lts/maxtext that referenced this issue Feb 14, 2025
…Hypercomputer#1273)

- Update requirements.txt and requirements_with_jax_stable_stack.txt to specify orbax-checkpoint==0.10.3.
- Prevent AttributeError in MaxText/train.py related to key_value_try_get.

shota-inoue-lts commented Feb 14, 2025

I fixed it in PR #1274
