diff --git a/.github/README.md b/.github/README.md new file mode 120000 index 0000000000..525f9bef8c --- /dev/null +++ b/.github/README.md @@ -0,0 +1 @@ +../jaxpp.README.md \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index c0d04607f2..6666df3920 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -92,6 +92,6 @@ "quantization=int8", "quantize_kvcache=True" ] - } + }, ] } \ No newline at end of file diff --git a/jaxpp.Dockerfile b/jaxpp.Dockerfile new file mode 100644 index 0000000000..451e51059a --- /dev/null +++ b/jaxpp.Dockerfile @@ -0,0 +1,26 @@ +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG BASE_IMAGE +FROM $BASE_IMAGE AS base +ARG JAX_INSTALL_URL + +COPY requirements.txt /tmp/requirements.txt +RUN uv pip install -U pip && uv pip install --no-cache-dir -U -r /tmp/requirements.txt + +COPY --chown=$USER_UID:$USER_GID . maxtext + +RUN uv pip install --no-cache-dir -e '/workdir/jaxpp[dev]' +RUN uv pip install --no-cache-dir -e /workdir/maxtext[cuda_12] --resolution=lowest && \ + if [[ -n "$JAX_INSTALL_URL" ]]; then uv pip install $JAX_INSTALL_URL; fi diff --git a/jaxpp.README.md b/jaxpp.README.md new file mode 100644 index 0000000000..e852cd5614 --- /dev/null +++ b/jaxpp.README.md @@ -0,0 +1,78 @@ +# Overview + +This repository is a fork of [MaxText](https://github.com/AI-Hypercomputer/maxtext) created for training with [JaxPP](https://github.com/NVIDIA/jaxpp). + +# Notable changes + +The changes between this repo and the upstream MaxText is kept minimal in general. +Some of the notable changes are listed below. + +* The `__call__` method of the `Decoder` class in [src/MaxText/layers/decoders.py](src/MaxText/layers/decoders.py) + calls `jaxpp.pipeline_enter_stage` to mark stage boundaries for pipeline parallelism. +* The `maybe_initialize_jax_distributed_system` function in [src/MaxText/max_utils.py](src/MaxText/max_utils.py) + creates `RemoteMpmdMesh` to be used by JaxPP. +* [src/MaxText/train.py](src/MaxText/train.py) contains changes to + * Enable pipeline parallelism for the train step, and + * Mark the pipeline loop in the train step with `jaxpp.treduce`. + +# Docker image + +For ease of use, we provide a docker image with this fork under `/workdir/maxtext`. +The docker image has all the dependencies that are needed to use MaxText with JaxPP installed. + +## Building and Testing Docker Container + +The build process uses the JaxPP base image as a starting point. Follow the instructions at [JaxPP's Building the Base Image](https://github.com/NVIDIA/jaxpp#building-the-base-image) to build the `jaxpp-base` image first. + +### Prerequisites +- Docker installed and configured +- NVIDIA Container Toolkit installed +- JaxPP base image built and available locally + +### Building the Main Image + +After building the base image, you can build the main image: + +```bash +# Check if jaxpp-base image exists +if [ -z "$(docker images -q jaxpp-base)" ]; then + echo "Error: jaxpp-base image not found. Please build it first using the instructions at https://github.com/NVIDIA/jaxpp#building-the-base-image." +else + docker build --force-rm=true \ + -f jaxpp.Dockerfile \ + --build-arg BASE_IMAGE=jaxpp-base \ + -t maxtext-jaxpp . +fi +``` + +### Running Tests + +The container includes several test suites for different models: + +1. **Tiny Llama4 Model Tests**: +```bash +docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \ + -e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \ + "nvidia-smi && CONFIG_FILE=./scripts/llama4_proxy_config.sh bash scripts/test_1gpu_config.sh" +``` + +2. **Tiny Mixtral Model Tests**: +```bash +docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \ + -e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \ + "nvidia-smi && MODEL_CONFIG='model_name=mixtral-8x7b override_model_config=True base_num_decoder_layers=2 base_emb_dim=512 base_mlp_dim=1792' bash scripts/test_1gpu_config.sh" +``` + +3. **Tiny Mistral Model Tests**: +```bash +docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \ + -e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \ + "nvidia-smi && bash MODEL_CONFIG='model_name=mistral-7b override_model_config=True base_num_decoder_layers=2' bash scripts/test_1gpu_config.sh" +``` + +Note: The tests require GPU access and sufficient GPU memory. + +# Profiling + +Profiling is enabled by default in the 6th step, and the first 7 steps are ignored in the performance statistics. +It allows the performance statstics to be collected without the profiling overhead while producing the profiling data while running the benchmarks. \ No newline at end of file diff --git a/scripts/deepseek3_proxy_config.sh b/scripts/deepseek3_proxy_config.sh new file mode 100644 index 0000000000..eda5d6f234 --- /dev/null +++ b/scripts/deepseek3_proxy_config.sh @@ -0,0 +1,8 @@ +export MODEL_CONFIG=" + model_name=deepseek3-671b + override_model_config=True + base_num_decoder_layers=4 + base_emb_dim=896 + base_mlp_dim=2304 + base_moe_mlp_dim=256 +" diff --git a/scripts/llama3.3_proxy_config.sh b/scripts/llama3.3_proxy_config.sh new file mode 100644 index 0000000000..8efc81e6ff --- /dev/null +++ b/scripts/llama3.3_proxy_config.sh @@ -0,0 +1,5 @@ +export MODEL_CONFIG=" + model_name=llama3.3-70b + override_model_config=True + base_num_decoder_layers=2 +" diff --git a/scripts/llama4_proxy_config.sh b/scripts/llama4_proxy_config.sh new file mode 100644 index 0000000000..139746b2be --- /dev/null +++ b/scripts/llama4_proxy_config.sh @@ -0,0 +1,8 @@ +export MODEL_CONFIG=" + model_name=llama4-17b-16e + override_model_config=True + base_num_decoder_layers=2 + base_emb_dim=640 + base_mlp_dim=2048 + base_moe_mlp_dim=2048 +" diff --git a/scripts/local_mc.sh b/scripts/local_mc.sh new file mode 100644 index 0000000000..187968b293 --- /dev/null +++ b/scripts/local_mc.sh @@ -0,0 +1,25 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ -z "$N_PROCS" ] || [ -z "$N_GPUS" ] || [ -z "$COMMAND" ]; then + echo "N_PROCS, N_GPUS, and COMMAND must be set" + exit 1 +fi + +seq 0 $(($N_PROCS - 1)) | xargs -P $N_PROCS -I {} bash -c ' \ +n_gpus=$2; \ +start=$(({} * n_gpus)); \ +end=$((start + n_gpus - 1)); \ +JAX_COORDINATOR_IP="localhost" JAX_COORDINATOR_PORT=1234 NNODES=$1 NODE_RANK={} \ +CUDA_VISIBLE_DEVICES=$(seq -s, $start $end) $3' _ $N_PROCS $N_GPUS "$COMMAND" diff --git a/scripts/run_local_mc.sh b/scripts/run_local_mc.sh new file mode 100644 index 0000000000..4f0cd40813 --- /dev/null +++ b/scripts/run_local_mc.sh @@ -0,0 +1,18 @@ +export TEST_CONFIG=" + override_model_config=True dataset_type=synthetic steps=10 +" + +export COMMAND="python3 -u -m MaxText.train src/MaxText/configs/base.yml \ + base_output_directory=run_local_mc_outputs \ + run_name=run_$(date +%Y-%m-%d-%H:%M:%S) \ + enable_checkpointing=false \ + async_checkpointing=false \ + dtype=bfloat16 \ + weight_dtype=bfloat16 \ + hardware=gpu \ + $MODEL_CONFIG \ + $TEST_CONFIG \ + $PARALLELISM_CONFIG + $JAXPP_CONFIG" + +bash ./scripts/local_mc.sh diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh new file mode 100644 index 0000000000..2be4249b48 --- /dev/null +++ b/scripts/run_tests.sh @@ -0,0 +1,24 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +bash scripts/test_8gpu_llama4_proxy.sh + +RAY_ADDRESS=local python -m MaxText.train src/MaxText/configs/base.yml run_name=runner_jaxpp_$(date +%Y-%m-%d-%H-%M) base_output_directory=/tmp/log hardware=gpu dataset_type=synthetic model_name=gpt3-52k steps=20 dtype=bfloat16 max_target_length=2048 per_device_batch_size=4 dcn_data_parallelism=1 ici_data_parallelism=2 ici_tensor_parallelism=2 ici_pipeline_parallelism=2 num_pipeline_repeats=1 num_pipeline_microbatches=8 enable_checkpointing=false use_jaxpp=True schedule=interleaved_1f1b + +tests=$(python3 -m pytest --co -q -W ignore::DeprecationWarning tests/train_compile_jaxpp_test.py | awk '/^[[:space:]]*$/{exit} {print}') + +for t in $tests; do + echo $t + python3 -m pytest --log-cli-level=INFO -s "$t" +done diff --git a/scripts/test_1gpu_config.sh b/scripts/test_1gpu_config.sh new file mode 100644 index 0000000000..083e1199ae --- /dev/null +++ b/scripts/test_1gpu_config.sh @@ -0,0 +1,28 @@ +if [ -n "$MODEL_CONFIG" ] && [ -n "$CONFIG_FILE" ]; then + echo "Error: both MODEL_CONFIG and CONFIG_FILE are set" + exit 1 +fi + +if [ -n "$CONFIG_FILE" ]; then + source $CONFIG_FILE +fi + +export N_PROCS=1 +export N_GPUS=1 + +# Run plain JAX config +bash ./scripts/run_local_mc.sh + +# Run JaxPP config +export PARALLELISM_CONFIG="ici_pipeline_parallelism=1" + +export JAXPP_CONFIG=" + scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 + max_target_length=64 +" + +bash ./scripts/run_local_mc.sh diff --git a/scripts/test_8gpu_deepseek3_proxy.sh b/scripts/test_8gpu_deepseek3_proxy.sh new file mode 100644 index 0000000000..5b42422d49 --- /dev/null +++ b/scripts/test_8gpu_deepseek3_proxy.sh @@ -0,0 +1,22 @@ +source scripts/deepseek3_proxy_config.sh + +export PARALLELISM_CONFIG=" + dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2 + ici_data_parallelism=1 + ici_tensor_parallelism=2 + ici_expert_parallelism=2 + ici_fsdp_parallelism=1 +" + +export JAXPP_CONFIG=" + scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 +" + +export N_PROCS=2 +export N_GPUS=4 + +bash ./scripts/run_local_mc.sh diff --git a/scripts/test_8gpu_llama3.3_proxy.sh b/scripts/test_8gpu_llama3.3_proxy.sh new file mode 100644 index 0000000000..16eb980c05 --- /dev/null +++ b/scripts/test_8gpu_llama3.3_proxy.sh @@ -0,0 +1,30 @@ +export JAX_USE_SHARDY_PARTITIONER=0 +export JAXPP_ENABLE_LICM=1 +export NVTE_FUSED_ATTN=1 +# --xla_dump_hlo_pass_re=.* +export XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_hlo_as_text --xla_dump_to='./llama3-hlos-pp2' --xla_gpu_enable_latency_hiding_scheduler=true" +source scripts/llama3.3_proxy_config.sh + +export PARALLELISM_CONFIG=" + dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2 + ici_data_parallelism=1 + ici_context_parallelism=2 + ici_tensor_parallelism=2 + ici_fsdp_parallelism=1 + per_device_batch_size=1 + max_target_length=8192 +" + +export JAXPP_CONFIG=" + scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 + profiler=xplane +" + +export N_PROCS=8 +export N_GPUS=1 + +bash ./scripts/run_local_mc.sh diff --git a/scripts/test_8gpu_llama4_proxy.sh b/scripts/test_8gpu_llama4_proxy.sh new file mode 100644 index 0000000000..73d7c8749a --- /dev/null +++ b/scripts/test_8gpu_llama4_proxy.sh @@ -0,0 +1,22 @@ +source scripts/llama4_proxy_config.sh + +export PARALLELISM_CONFIG=" + dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2 + ici_data_parallelism=1 + ici_tensor_parallelism=2 + ici_expert_parallelism=2 + ici_fsdp_parallelism=1 +" + +export JAXPP_CONFIG=" + scan_layers=False + use_jaxpp=True + schedule=interleaved_1f1b + num_pipeline_microbatches=4 + num_pipeline_repeats=1 +" + +export N_PROCS=2 +export N_GPUS=4 + +bash ./scripts/run_local_mc.sh diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 3a5d40ac69..89e885ef1e 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -84,7 +85,7 @@ gcs_metrics: False save_config_to_gcs: False # Gradient dtype -grad_dtype: "float32" +grad_dtype: "bfloat16" # Activation dtypes. dtype: "bfloat16" @@ -165,8 +166,8 @@ mtp_eval_target_module: 0 # mixture of experts (moe) num_experts: 1 num_experts_per_tok: 1 -megablox: True -sparse_matmul: True +megablox: False # Only used when sparse_matmul=True +sparse_matmul: False capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.01 # weight for the load balance loss use_random_routing: False # whether to use random routing for debug/test purpose @@ -216,7 +217,7 @@ inhomogeneous_layer_cycle_interval: 1 # but a smaller size per microbatch which may hurt per-stage performance. Additionally, note when microbatches > num_stages we have the opportunity to # perform the circular transfer (last stage to first) asynchronously. # The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1) -num_layers_per_pipeline_stage: 1 +num_layers_per_pipeline_stage: 1 # NOTE(jaxpp) unused in JaxPP # The number of repeats will be set to num_decoder_layers / (num_pipeline_stages * num_layers_per_pipeline_stage) num_pipeline_repeats: -1 pipeline_parallel_layers: -1 # Pipeline only this number of layers - for the remaining layers the "stage" mesh axes will act like data parallelism. @@ -259,7 +260,7 @@ set_remat_policy_on_layers_per_stage: False # Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', # 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'. # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) -remat_policy: 'full' +remat_policy: 'minimal' # If "custom" remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory. # Pick one of these options for following tensors: ['remat','device','offload'] decoder_layer_input: 'device' # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points @@ -276,12 +277,12 @@ out_proj: 'remat' optimizer_memory_host_offload: False parameter_memory_host_offload: False -scan_layers: True # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. -param_scan_axis: 1 +scan_layers: False # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. +param_scan_axis: 0 # NOTE(jaxpp) we set to 0 instead of 1 to avoid flax transposes # The attention parameter dictates the specific algorithm/methodology used to compute the attention scores # The attention_type parameter determines the variants of attention, e.g. global or local_sliding -attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te +attention: 'cudnn_flash_te' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla attention_bias: False # If True, adds a learnable bias to the query, key, and value projections attention_sink: False @@ -355,7 +356,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' # Parallelism mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch', ['fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], @@ -375,8 +376,8 @@ logical_axis_rules: [ ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_kv_batch', ['fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_batch_no_exp', ['fsdp', 'fsdp_transpose']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], @@ -454,7 +455,7 @@ dcn_pipeline_parallelism: 1 dcn_expert_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_fsdp_transpose_parallelism: 1 ici_sequence_parallelism: 1 ici_context_parallelism: 1 @@ -581,12 +582,12 @@ autoregressive_decode_assert: "" # e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} profiler: "" # Supported profiler: '', xplane, nsys # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. -upload_all_profiler_results: False +upload_all_profiler_results: True # Skip first n steps for profiling, to omit things like compilation and to give # the iteration time a chance to stabilize. -skip_first_n_steps_for_profiler: 1 +skip_first_n_steps_for_profiler: 5 # Profile for a small number of steps to avoid a large profile file size. -profiler_steps: 5 +profiler_steps: 1 profile_cleanly: True # If set to true, adds a block_until_ready on train state which aligns the profile for each step. profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps. # This is useful to debug scenarios where performance is changing. @@ -676,6 +677,12 @@ eval_interval: -1 # the specific number of train step between eval_step eval_steps: -1 # run this number of steps for eval, recommend setting this to prevent error due to running out of evel data target_eval_loss: 0. # early stop once reaching target eval_loss +# NOTE(jaxpp): begin parameters +use_jaxpp: False +schedule: "eager_1f1b" +fuse_steady_state: False +# NOTE(jaxpp): end parameters + # Goodput parameters enable_goodput_recording: False monitor_goodput: False @@ -760,7 +767,7 @@ allow_split_physical_axes: False # Apply transformations to the mesh to optimize for TPU v6e optimize_mesh_for_tpu_v6e: False -shardy: True # Whether to use shardy XLA backend (default in Jax starting 0.7.0), or GSPMD (to be fully deprecated ~2026) +shardy: False # Whether to use shardy XLA backend (default in Jax starting 0.7.0), or GSPMD (to be fully deprecated ~2026) use_ragged_attention: False ragged_block_size: 256 @@ -859,7 +866,7 @@ subslice_shape: "" # NNX enable_nnx: false -shard_optimizer_over_data: False +shard_optimizer_over_data: True ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/MaxText/configs/models/gpt3-175b.yml b/src/MaxText/configs/models/gpt3-175b.yml index 5e24cf4268..af515e66b2 100644 --- a/src/MaxText/configs/models/gpt3-175b.yml +++ b/src/MaxText/configs/models/gpt3-175b.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,13 +29,13 @@ logits_via_embedding: True normalize_embedding_logits: False logits_dot_in_fp32: False normalization_layer_epsilon: 1.e-05 -use_iota_embed: True +use_iota_embed: False fused_qkv: True opt_type: "adam_pax" decoder_block: "gpt3" -dataset_path: "gs://mlperf-llm-public2" -dataset_name: "c4/en:3.0.4" -eval_dataset_name: "c4/en:3.0.5" +#dataset_path: "gs://mlperf-llm-public2" +#dataset_name: "c4/en:3.0.4" +#eval_dataset_name: "c4/en:3.0.5" gradient_clipping_threshold: 1. adam_b1: 0.9 adam_b2: 0.95 diff --git a/src/MaxText/configs/models/gpt3-52k.yml b/src/MaxText/configs/models/gpt3-52k.yml index 5513663f82..d41f94076c 100644 --- a/src/MaxText/configs/models/gpt3-52k.yml +++ b/src/MaxText/configs/models/gpt3-52k.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +19,7 @@ base_emb_dim: 16 base_num_query_heads: 2 base_num_kv_heads: 2 base_mlp_dim: 64 -base_num_decoder_layers: 1 +base_num_decoder_layers: 8 head_dim: 8 trainable_position_size: 2048 mlp_activations: ["gelu"] @@ -28,7 +29,7 @@ logits_via_embedding: True normalize_embedding_logits: False logits_dot_in_fp32: False normalization_layer_epsilon: 1.e-05 -use_iota_embed: True +use_iota_embed: False fused_qkv: True opt_type: "adam_pax" decoder_block: "gpt3" diff --git a/src/MaxText/configs/models/llama2-70b.yml b/src/MaxText/configs/models/llama2-70b.yml index 67dd87f68f..ceecd6819d 100644 --- a/src/MaxText/configs/models/llama2-70b.yml +++ b/src/MaxText/configs/models/llama2-70b.yml @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# model config for llama2-7b +# model config for llama2-70b base_emb_dim: 8192 base_num_query_heads: 64 @@ -25,4 +26,3 @@ vocab_size: 32000 logits_via_embedding: False normalization_layer_epsilon: 1.0e-5 decoder_block: "llama2" -logical_axis_rules: [['norm', 'fsdp']] diff --git a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index 361cd3ea75..bd1179423f 100644 --- a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -323,7 +324,7 @@ def make_c4_mlperf_train_iterator( """Make train iterator of customized C4 dataset for mlperf gpt3 training.""" train_ds = get_dataset( dataset_name=config.dataset_name, - split="train2", + split="train" if config.dataset_name == "c4/en:3.1.0" else "train2", dataloading_host_index=process_indices.index(jax.process_index()), dataloading_host_count=len(process_indices), enable_data_shuffling=config.enable_data_shuffling, @@ -334,10 +335,14 @@ def make_c4_mlperf_train_iterator( sp_tokenizer = get_tokenizer( config.tokenizer_path, config.tokenizer_type, config.add_bos, config.add_eos, config.hf_access_token ) + # A hack that (global_batch_size_to_load * num_process) when using jaxpp + # because preprocess_train_dataset divides batch sizes with num_process + # natively, and jaxpp uses only one process + global_batch_size_to_load = config.global_batch_size_to_load if not config.use_jaxpp else config.global_batch_size_to_load * jax.process_count() train_ds = preprocess_train_dataset( train_ds, sp_tokenizer=sp_tokenizer, - train_global_batch_size_to_load=config.global_batch_size_to_load, + train_global_batch_size_to_load=global_batch_size_to_load, max_target_length=config.max_target_length, shuffle_buffer_size=128, data_shuffle_seed=config.data_shuffle_seed, diff --git a/src/MaxText/input_pipeline/input_pipeline_interface.py b/src/MaxText/input_pipeline/input_pipeline_interface.py index 34f92946fc..7debaaf28d 100644 --- a/src/MaxText/input_pipeline/input_pipeline_interface.py +++ b/src/MaxText/input_pipeline/input_pipeline_interface.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -59,7 +60,7 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): # Return synthetic dataset if selected if config.dataset_type == "synthetic": - return SyntheticDataIterator(config, mesh), None + return SyntheticDataIterator(config, mesh), SyntheticDataIterator(config, mesh) dataset_type_to_train_eval_iterator = { "tfds": (make_tfds_train_iterator, make_tfds_eval_iterator), "grain": (make_grain_train_iterator, make_grain_eval_iterator), @@ -94,7 +95,7 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): # Generate output eval iterator output_eval_iterator = None - if config.eval_interval > 0: + if config.eval_interval > 0 and not config.use_jaxpp: process_indices_eval = get_process_loading_real_data( config.data_sharding, config.global_batch_size_to_load_eval, diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 92cd3749e5..6d5eba8f1b 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -58,6 +59,9 @@ simple_layer, ) +import jaxpp +from packaging.version import Version + # ------------------------------------------------------------------------------ # The network: Decoder Definitions # ------------------------------------------------------------------------------ @@ -292,6 +296,8 @@ def get_remat_policy(self): elif cfg.remat_policy == "minimal": # save all except context policy = self.minimal_policy() + elif cfg.remat_policy == "save_dot_only": + policy = jax.checkpoint_policies.checkpoint_dots elif cfg.remat_policy == "save_dot_with_context_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( "query_proj", @@ -674,7 +680,7 @@ def __call__( deterministic, model_mode, ) - if cfg.using_pipeline_parallelism: + if cfg.using_pipeline_parallelism and not cfg.use_jaxpp: if cfg.pipeline_fsdp_ag_once: partition_spec = self.pipeline_module.get_weight_sharding( y, decoder_segment_ids, decoder_positions, deterministic, model_mode @@ -727,6 +733,7 @@ def __call__( )(y, *broadcast_args) else: if cfg.scan_layers: + assert not cfg.use_jaxpp, "Layer scanning is not supported with JaxPP" if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." layer_call_kwargs = { @@ -789,30 +796,57 @@ def __call__( **layer_kwargs, )(y, *broadcast_args) else: + num_logical_stages = 1 + layers_per_stage = cfg.num_decoder_layers + cutoffs = [cfg.num_decoder_layers] + if cfg.use_jaxpp: + num_logical_stages = cfg.dcn_pipeline_parallelism * cfg.ici_pipeline_parallelism * cfg.num_pipeline_repeats + layers_per_stage, rem = divmod(cfg.num_decoder_layers, num_logical_stages) + assert layers_per_stage > 0, (cfg.num_decoder_layers, num_logical_stages) + cutoffs = [] + tot = 0 + for stage in range(num_logical_stages): + num_layers_in_pipeline_stage = layers_per_stage + (1 if stage < rem else 0) + tot += num_layers_in_pipeline_stage + cutoffs.append(tot - 1) + + stage_id = 0 + add_last_enter_stage = Version(jaxpp.__version__) > Version("0.6.1") if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." - dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] - layers = [dense_layer, moe_layer] - layer_prefixes = ["dense_layers", "moe_layers"] - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_layers_list = [cfg.first_num_dense_layers, num_moe_layers] - # Iterate over the two layer groups (dense and MoE) and apply layer transformation - for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes): - for index in range(num_layers): - y = layer( - config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode - )( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - ) + for index in range(cfg.first_num_dense_layers): + dense_layer = self.decoder_layer[0] if stage_id != num_logical_stages - 1 else RemattedBlockLayers[0] + y = dense_layer(config=cfg, mesh=mesh, name=f"dense_layers_{index}", quant=self.quant, model_mode=model_mode)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + ) + if index != cfg.num_decoder_layers - 1 and cutoffs[stage_id] == index: + y = jaxpp.api.pipeline_enter_stage(y, f"stage_{stage_id}") + stage_id += 1 + + for index in range(cfg.first_num_dense_layers, cfg.num_decoder_layers): + moe_layer = RemattedBlockLayers[1] if stage_id != num_logical_stages - 1 else self.decoder_layer[1] + y = moe_layer(config=cfg, mesh=mesh, name=f"moe_layers_{index - cfg.first_num_dense_layers}", quant=self.quant, model_mode=model_mode)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + ) + if index != cfg.num_decoder_layers - 1 and cutoffs[stage_id] == index: + y = jaxpp.api.pipeline_enter_stage(y, f"stage_{stage_id}") + stage_id += 1 + else: for lyr in range(cfg.num_decoder_layers): RemattedBlockLayer = RemattedBlockLayers[0] @@ -831,7 +865,8 @@ def __call__( layer_kwargs = {"layer_idx": lyr} if cfg.decoder_block == DecoderBlockType.GPT_OSS: layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} - layer = RemattedBlockLayer( + layer_ctor = RemattedBlockLayer if stage_id != num_logical_stages - 1 else self.decoder_layer[0] + layer = layer_ctor( config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs ) y = layer( @@ -845,6 +880,9 @@ def __call__( slot=slot, **layer_call_kwargs, ) + if lyr != cfg.num_decoder_layers - 1 and cutoffs[stage_id] == lyr: + y = jaxpp.api.pipeline_enter_stage(y, f"stage_{stage_id}") + stage_id += 1 assert isinstance(y, jax.Array) @@ -859,6 +897,9 @@ def __call__( else: logits = self._apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + if add_last_enter_stage: + logits = jaxpp.api.pipeline_enter_stage(logits, f"stage_{stage_id}") + # The API of the Decoder is now a tuple, providing both the main output # and the raw hidden state needed for auxiliary tasks. return logits, hidden_state diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index 2fe00341ca..11b34810b3 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,6 +41,13 @@ from MaxText import max_logging from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN +import tensorflow as tf +import jaxpp.api as jaxpp +import optax + +# jaxpp related imports +import jaxpp.api as jaxpp + initialize_multi_tier_checkpointing = initialization.initialize_multi_tier_checkpointing HYBRID_RING_64X4 = "hybrid_ring_64x4" HYBRID_RING_32X8 = "hybrid_ring_32x8" @@ -47,6 +55,12 @@ # pylint: disable=too-many-positional-arguments +def maybe_unwrap(a: jaxpp.MpmdArray | jax.Array): + if isinstance(a, jaxpp.MpmdArray): + return v if (v := a.first_mpmd_replica) is not None else 0 + return a + + def with_memory_kind(t, memory_kind): return jax.tree_util.tree_map(lambda x: x.with_memory_kind(kind=memory_kind), t) @@ -66,7 +80,8 @@ def finder(x): def l2norm_pytree(x): """L2 norm of a pytree of arrays.""" - return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0)) + per_param_sum = [jnp.sum(jnp.square(x)) for x in jax.tree.leaves(x)] + return jnp.sqrt(jaxpp.cross_mpmd_all_reduce(*(e.astype(jnp.float32) for e in per_param_sum))) def calculate_num_params_from_pytree(params): @@ -596,6 +611,23 @@ def _cross_entropy_with_logits_bwd( cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd) +def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): + prev_params_shardings = state_mesh_shardings.params + if config.shard_optimizer_over_data: + if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState): + sharded_fp32_params = state_mesh_shardings.opt_state.mu + elif isinstance(state_mesh_shardings.opt_state, tuple) and isinstance(state_mesh_shardings.opt_state[0], optax.ScaleByAdamState): + sharded_fp32_params = state_mesh_shardings.opt_state[0].mu + else: + raise NotImplementedError(f"Could not find optimizer state shardings from optimizer of type {type(state_mesh_shardings.opt_state)}") + if "params" not in sharded_fp32_params.keys(): + # When quantization=fp8 is enabled the sharded_fp32_params + # are not wrapped in `params`. Here we wrap them back. + sharded_fp32_params = {"params": sharded_fp32_params} + state_mesh_shardings = state_mesh_shardings.replace(params=dict(prev_params_shardings, **sharded_fp32_params)) + return prev_params_shardings, state_mesh_shardings + + def print_pytree_shape(print_str, ptree): print("\n") print(print_str) @@ -656,7 +688,8 @@ def print_mem_stats(label: str): stats = d.memory_stats() used = round(stats["bytes_in_use"] / 2**30, 2) limit = round(stats["bytes_limit"] / 2**30, 2) - max_logging.log(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + peak_size = round(stats["peak_bytes_in_use"] / 2**30, 2) + max_logging.log(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) ({peak_size=} GiB) on {d}") except (RuntimeError, KeyError, TypeError) as ex: max_logging.log(f"\tMemstats unavailable, error: {ex}") diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 3b2eed03b0..3b9fcd30dc 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +46,10 @@ from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from MaxText.inference.page_manager import PageState +import chex +from optax._src import base +import jaxpp.api as jaxpp + OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -832,7 +837,24 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance): ) -def apply_gradient_clipping(raw_grads, state, clipping_threshold): +def clip_by_global_norm(g_norm: jax.Array, max_norm: float) -> optax.GradientTransformation: + # Adaptation of `optax.transforms._clipping.clip_by_global_norm` + # that takes `g_norm` as argument + def update_fn(updates, state, params=None): + del params + trigger = jnp.squeeze(g_norm < max_norm) + chex.assert_shape(trigger, ()) # A scalar. + + def clip_fn(t): + return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm) + + updates = jax.tree.map(clip_fn, updates) + return updates, state + + return optax.GradientTransformation(base.init_empty_state, update_fn) + + +def apply_gradient_clipping(raw_grads, state, g_norm, clipping_threshold): """Applies gradient clipping to raw gradients, with special handing for FLAX fp8 stats. Args: @@ -843,7 +865,7 @@ def apply_gradient_clipping(raw_grads, state, clipping_threshold): Returns: A pytree of clipped gradients. """ - gradient_clip_transformation = optax.clip_by_global_norm(clipping_threshold) + gradient_clip_transformation = clip_by_global_norm(g_norm, clipping_threshold) if OVERWRITE_WITH_GRADIENT in raw_grads: # Scales + Amax History for Delayed Tensor Scaling SHOULD NOT be clipped or affect clipping fp8_stats = raw_grads.pop(OVERWRITE_WITH_GRADIENT) @@ -1067,7 +1089,7 @@ def setup_initial_state( tx, config, rng, - mesh, + maybe_mpmd_mesh, checkpoint_manager, is_training=True, ): @@ -1088,12 +1110,19 @@ def setup_initial_state( state_mesh_annotations: the mesh annotations for the train state """ + mesh = maybe_mpmd_mesh + if isinstance(maybe_mpmd_mesh, jaxpp.MpmdMesh): + mesh = maybe_mpmd_mesh.lowering_mesh() + unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( model, tx, config, rng, mesh, is_training ) # Initialization with nn_partitioning.axis_rules(config.logical_axis_rules): + if checkpoint_manager is not None: + assert not config.use_jaxpp + restored, raw_params = checkpointing.load_state_if_possible( checkpoint_manager, data_iterator, @@ -1126,20 +1155,91 @@ def setup_initial_state( else: init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )(rng) - if raw_params: # If we loaded a partial state, we need to merge it. - state = state.replace(params=raw_params) + if config.use_jaxpp: + # First infer placement based on loop usage + # Imported here to avoid circular import errors + from MaxText import maxtext_utils + from MaxText.train import train_step + params_shardings, _state_mesh_shardings = max_utils.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + data_sharding = maxtext_utils.get_input_data_sharding(config, mesh) + ( + functional_train, + in_shard_train, + out_shard_train, + static_argnums_train, + donate_argnums_train, + ) = maxtext_utils.get_functional_train_with_signature(train_step, data_sharding, _state_mesh_shardings, model, config, params_shardings=params_shardings) + + p_train_step = jaxpp.mpmd_jit_with_loop( + functional_train, + mpmd_mesh=maybe_mpmd_mesh, + donate_argnums=donate_argnums_train, + in_shardings=in_shard_train, + out_shardings=out_shard_train, + ) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + global_mpmd_train_step = p_train_step.trace_and_place( + unboxed_abstract_state, next(data_iterator), rng + ) + unboxed_abstract_state_placements = global_mpmd_train_step.in_shardings[0][0] + def attach_right_mesh(shaped: jax.ShapeDtypeStruct, dist_sharding): + sharding: jax.sharding.NamedSharding = shaped.sharding + jax_mesh = maybe_mpmd_mesh.mpmd_submesh(sorted(dist_sharding.mesh_ids)).jax_mesh + mpmd_sharding = jax.sharding.NamedSharding(jax_mesh, sharding.spec) + return jax.ShapeDtypeStruct(shaped.shape, shaped.dtype, sharding=mpmd_sharding, weak_type=shaped.weak_type) + + unboxed_mpmd_abstract_state = jax.tree.map(attach_right_mesh, unboxed_abstract_state, unboxed_abstract_state_placements) + replicated_sharding = jax.sharding.NamedSharding( + maybe_mpmd_mesh.lowering_mesh(), jax.sharding.PartitionSpec() + ) + state = jaxpp.mpmd_jit_rev( + lambda rng: jax.tree.map(jax._src.numpy.lax_numpy._array_copy, max_utils.unbox_logicallypartioned(init_state_partial(rng))), + out_refs=jax.tree.map(lambda s: s.mesh_ids, unboxed_abstract_state_placements), + mpmd_mesh=maybe_mpmd_mesh, + in_shardings=replicated_sharding, + out_shardings=in_shard_train[0], + )(rng) + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )(rng) + if raw_params: # If we loaded a partial state, we need to merge it. + state = state.replace(params=raw_params) - state = max_utils.unbox_logicallypartioned(state) + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator +def add_data_to_sharding(mesh, path, aval, sharding): + if not isinstance(sharding, jax.sharding.NamedSharding): + raise AssertionError(f"Expected NamedSharding, found {sharding} of {type(sharding)=} at {jax.tree_util.keystr(path)}") + try: + sharded_shape = sharding.shard_shape(aval.shape) + except Exception as e: + raise AssertionError(f"Could not shard value {jax.tree_util.keystr(path)} of shape={aval.shape} with {sharding=}") from e + pspec = sharding.spec + + if 'data' in jax.tree.leaves(pspec): + return sharding + + for idx, (size, partition) in enumerate(zip(sharded_shape, pspec)): + if partition is None: + partition = () + + if isinstance(partition, str): + partition = (partition,) + + if size % mesh.shape['data'] == 0 and (partition is None or 'tensor' not in partition): + added_component = ('data',) + partition + new_pspec = jax.sharding.PartitionSpec(*(pspec[:idx] + (added_component,) + pspec[idx+1:])) + return sharding.update(spec=new_pspec) + return sharding + + def get_abstract_state(model, tx, config, rng, mesh, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) diff --git a/src/MaxText/metric_logger.py b/src/MaxText/metric_logger.py index 010f769f33..111798c9b1 100644 --- a/src/MaxText/metric_logger.py +++ b/src/MaxText/metric_logger.py @@ -90,6 +90,7 @@ def reset_eval_metrics(self): def write_metrics(self, metrics, step, is_training=True): """Entry point for all metrics writing in Train's Main.""" if metrics: + metrics = jax.tree.map(max_utils.maybe_unwrap, metrics) self.log_metrics(metrics, step, is_training) if self.config.enable_tensorboard: @@ -229,21 +230,21 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None): """Records eval metrics and writes the metrics to GCS and/or to TensorBoard.""" if metrics: self.cumulative_eval_metrics["scalar"]["eval/total_loss"] += float( - metrics["scalar"].get("evaluation/total_loss", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/total_loss", 0.0)) ) self.cumulative_eval_metrics["scalar"]["eval/total_weights"] += float( - metrics["scalar"].get("evaluation/total_weights", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/total_weights", 0.0)) ) self.cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float( - metrics["scalar"].get("evaluation/moe_lb_loss", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/moe_lb_loss", 0.0)) ) - self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(metrics["scalar"].get("evaluation/mtp_loss", 0.0)) + self.cumulative_eval_metrics["scalar"]["eval/mtp_loss"] += float(max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/mtp_loss", 0.0))) self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float( - metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0)) ) if self.config.use_dpo: self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float( - metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0) + max_utils.maybe_unwrap(metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)) ) if eval_step_count: diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index cf2ae6da42..03baed3664 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +32,9 @@ from functools import partial from etils import epath +# jaxpp +import jaxpp.api as jaxpp + @overload def from_config( @@ -76,9 +80,16 @@ def from_config( model = from_config(config) """ devices_array = maxtext_utils.create_device_mesh(config, devices) - mesh = Mesh(devices_array, config.mesh_axes) - model = create_model(config, mesh, model_mode=model_mode, rngs=rngs) + if not config.use_jaxpp: + mesh = Mesh(devices_array, config.mesh_axes) + else: + mesh = jaxpp.MpmdMesh(Mesh(devices_array, config.mesh_axes), 'stage') + model = create_model(config, mesh.lowering_mesh() if config.use_jaxpp else mesh, model_mode=model_mode, rngs=rngs) + if config.use_jaxpp: + # At this point, model.mesh has mesh.lowering_mesh() as its value, but we need to set it to the original mesh + # so that the caller can have access to the original mesh. + model.mesh = mesh # Return only the model return model diff --git a/src/MaxText/profiler.py b/src/MaxText/profiler.py index e32e49ff2a..1131a4342c 100644 --- a/src/MaxText/profiler.py +++ b/src/MaxText/profiler.py @@ -40,8 +40,9 @@ def __init__(self, config, offset_step=0): self.finished_initial_profile_step = self._set_last_profiler_step(config.profiler_steps, config.steps) if config.profiler != "" and self.start_initial_profile_step >= config.steps: raise ValueError("Profiling requested but initial profiling step set past training final step") + self.use_jaxpp = config.use_jaxpp - def maybe_activate_profiler(self, step, state): + def maybe_activate_profiler(self, step, state, maybe_mpmd_mesh=None, profiling_process_ids=None): """Conditionally activates the profiler based on the current step. This method checks if the current training step matches the step designated for starting an initial profile, or if it meets the criteria for @@ -49,14 +50,21 @@ def maybe_activate_profiler(self, step, state): """ if self.mode != "" and (step == self.start_initial_profile_step or self.should_activate_periodic_profile(step)): optional_postfix = f"step_{step}" if self.profile_period > 0 else "" - self.activate(blocking_object=state, optional_postfix=optional_postfix) - - def activate(self, blocking_object=None, optional_postfix=""): + if self.use_jaxpp: + assert maybe_mpmd_mesh is not None + assert profiling_process_ids is not None + if maybe_mpmd_mesh.jax_mesh.is_multi_process and jax.process_index() in profiling_process_ids: + optional_postfix = f"mpmd_{maybe_mpmd_mesh.my_mpmd_axis_index:02}_gpu_{profiling_process_ids[jax.process_index()].id:06}_{optional_postfix}" + optional_postfix = f"proc_{jax.process_index():06}_{optional_postfix}" + profile = profiling_process_ids is None or jax.process_index() in profiling_process_ids + self.activate(blocking_object=state, optional_postfix=optional_postfix, profile=profile) + + def activate(self, blocking_object=None, optional_postfix="", profile=True): """Start the profiler. nsys profiler becomes no-op when libcudart.so is not available on the system.""" if self.profile_cleanly and blocking_object is not None: jax.block_until_ready(blocking_object) - if not (self.upload_all_profiler_results or jax.process_index() == 0): + if not (self.upload_all_profiler_results or jax.process_index() == 0) or not profile: return if self.mode != "": self.output_path = os.path.join(self.base_output_dir, optional_postfix) @@ -70,21 +78,22 @@ def activate(self, blocking_object=None, optional_postfix=""): elif self.mode == "xplane": jax.profiler.start_trace(self.output_path) - def maybe_deactivate_profiler(self, step, state): + def maybe_deactivate_profiler(self, step, state, profiling_process_ids=None): """Conditionally deactivates the profiler based on the current step. This method checks if the current training step matches the step designated for finishing the initial profile, or if it meets the criteria for deactivating a periodic profile. """ if self.mode != "" and (step == self.finished_initial_profile_step or self.should_deactivate_periodic_profile(step)): - self.deactivate(blocking_object=state) + profile = profiling_process_ids is None or jax.process_index() in profiling_process_ids + self.deactivate(blocking_object=state, profile=profile) - def deactivate(self, blocking_object=None): + def deactivate(self, blocking_object=None, profile=True): """End the profiler. The result is uploaded to the output bucket.""" if self.profile_cleanly and blocking_object is not None: jax.block_until_ready(blocking_object) - if not (self.upload_all_profiler_results or jax.process_index() == 0): + if not (self.upload_all_profiler_results or jax.process_index() == 0) or not profile: return if self.mode == "nsys": if self.libcudart is not None: diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index 9c93a902b4..ad23d613e6 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -902,7 +903,8 @@ def modify_activation_embed_and_logits_batch(logical_axis_rules): # The "stage" needs to be listed first since the microbatch dimension is first before the reshape. logical_axis_rules[idx] = [ "activation_embed_and_logits_batch", - ["stage", "data", "fsdp", "fsdp_transpose", "expert"], + ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if not raw_keys["use_jaxpp"] else + ["stage", "fsdp", "fsdp_transpose", "expert"], ] break # Exit the loop after modifying the list return logical_axis_rules @@ -978,6 +980,12 @@ def pipeline_first_axis(raw_keys): raw_keys["logical_axis_rules"] = modify_activation_embed_and_logits_batch(raw_keys["logical_axis_rules"]) raw_keys = pipeline_first_axis(raw_keys) num_stages = int(raw_keys["ici_pipeline_parallelism"] * raw_keys["dcn_pipeline_parallelism"]) + if raw_keys["use_jaxpp"]: + assert raw_keys["pipeline_delay_activation_forwarding"] is False + assert raw_keys["num_pipeline_repeats"] >= 1 + assert raw_keys["num_pipeline_microbatches"] >= 1 + return raw_keys + if raw_keys["pipeline_parallel_layers"] == -1: if raw_keys["decoder_block"] == "deepseek": moe_layers = raw_keys["num_decoder_layers"] - raw_keys["first_num_dense_layers"] @@ -1019,7 +1027,7 @@ def pipeline_first_axis(raw_keys): else: raw_keys["num_pipeline_microbatches"] = num_stages assert ( - raw_keys["num_pipeline_microbatches"] % num_stages == 0 + raw_keys["num_pipeline_microbatches"] % num_stages == 0 or raw_keys["use_jaxpp"] ), f"The number of microbatches ({raw_keys['num_pipeline_microbatches']}) must be divisible by the number of stages ({num_stages})" assert ( raw_keys["micro_batch_size_to_train_on"] % raw_keys["num_pipeline_microbatches"] == 0 @@ -1067,7 +1075,7 @@ def validate_gpt_oss_moe(raw_keys): def validate_sparse_matmul_parallelism(raw_keys): # TODO: remove once b/434699033 resolved - if raw_keys["sparse_matmul"] and (using_expert_parallelism(raw_keys) and using_pipeline_parallelism(raw_keys)): + if raw_keys["sparse_matmul"] and (using_expert_parallelism(raw_keys) and (not raw_keys["use_jaxpp"] and using_pipeline_parallelism(raw_keys))): raise ValueError("Sparse matmul doesn't support using expert and pipeline parallelism together.") # TODO: remove once b/435539039 resolved @@ -1237,8 +1245,7 @@ def get_context_parallel_size(raw_keys): def using_pipeline_parallelism(raw_keys) -> bool: - return int(raw_keys["ici_pipeline_parallelism"]) > 1 or int(raw_keys["dcn_pipeline_parallelism"]) > 1 - + return raw_keys["use_jaxpp"] or int(raw_keys["ici_pipeline_parallelism"]) > 1 or int(raw_keys["dcn_pipeline_parallelism"]) > 1 def using_tensor_parallelism(raw_keys) -> bool: return ( diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 5f60001a5e..45b33ac258 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,6 +43,8 @@ from cloud_tpu_diagnostics.configuration import diagnostic_configuration from cloud_tpu_diagnostics.configuration import stack_trace_configuration +from packaging.version import Version + from MaxText import checkpointing from MaxText import exceptions from MaxText import max_logging @@ -69,11 +72,25 @@ from MaxText.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn from MaxText.train_utils import validate_train_config from MaxText.metric_logger import record_activation_metrics + +""" +JaxPP related imports +""" +# system +import subprocess + +from statistics import mean + +# jaxpp +from jaxpp import __version__ as jaxpp_version +from packaging.version import Version +import jaxpp.api as jaxpp + # pylint: disable=too-many-positional-arguments def get_first_step(state): - return int(state.step) + return int(max_utils.maybe_unwrap(state.step)) # ----------------------------------------------------------------------------- @@ -97,7 +114,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): aux: a dictionary including intermediate_outputs, total_loss, and total_weights """ # decimate proportion of data when per_device_batch_size<1 - if is_train: + if is_train and not config.use_jaxpp: for k, v in data.items(): data[k] = v[: config.micro_batch_size_to_train_on, :] else: @@ -206,6 +223,44 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): return loss, aux +def load_schedule(config): + pipeline_parallel_dim = config.dcn_pipeline_parallelism * config.ici_pipeline_parallelism + num_logical_stages = config.num_pipeline_repeats * pipeline_parallel_dim + schedule = None + if config.schedule == "1f1b": + assert num_logical_stages <= pipeline_parallel_dim + schedule = jaxpp.Std1F1B(num_logical_stages) + elif config.schedule == "eager_1f1b": + assert num_logical_stages <= pipeline_parallel_dim + schedule = jaxpp.Eager1F1B(num_logical_stages) + elif config.schedule == "interleaved_1f1b": + if Version(jaxpp_version) > Version("0.6.1"): + schedule = jaxpp.Interleaved1F1B(num_logical_stages, pipeline_parallel_dim, config.fuse_steady_state) + else: + schedule = jaxpp.Interleaved1F1B(num_logical_stages, pipeline_parallel_dim) + elif config.schedule == "zero_bubble": + assert num_logical_stages <= pipeline_parallel_dim + schedule = jaxpp.ZeroBubble(num_logical_stages) + elif config.schedule == "dualpipev": + schedule = jaxpp.DualPipeV(num_logical_stages, pipeline_parallel_dim) + else: + raise NotImplementedError(f"Unknown schedule {config.schedule}") + return schedule + + +def add_leading_axis( + axis_name: str, path: jax.tree_util.KeyPath, s: jax.sharding.NamedSharding +): + assert isinstance(s, jax.sharding.NamedSharding) + used = {n for ns in s.spec for n in (ns if isinstance(ns, tuple) else (ns,))} + if axis_name in used: + raise ValueError( + f"mesh axis name {axis_name} cannot appear in " + f"out_shardings. Found out_shardings{jax.tree_util.keystr(path)}={s.spec}" + ) + return jax.sharding.NamedSharding(s.mesh, jax.sharding.PartitionSpec(axis_name, *s.spec), memory_kind=s.memory_kind) + + def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): """ @@ -248,21 +303,85 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat reference_params, max_utils.with_memory_kind(reference_params_sharding, "device") ) extra_dpo_args = [reference_params] + if config.shard_optimizer_over_data: params = jax.tree.map(jax.lax.with_sharding_constraint, params, params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + def compute_grads(data): + grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + def cast(p, a): + sp = jax.tree_util.keystr(p) + if 'token_embedder' in sp or 'position_embedder' in sp: + return a + return a.astype(jnp.dtype(config.grad_dtype)) + raw_grads['params'] = jax.tree_util.tree_map_with_path(cast, raw_grads['params']) + return ((loss, aux), raw_grads) + + if not config.use_jaxpp: + (loss, aux), raw_grads = compute_grads(data) + else: + def microbatched(a): + shape = ( + state_mesh_shardings.step.mesh.shape["data"], + config.num_pipeline_microbatches, + -1, + config.max_target_length, + ) + if shape[0] == 1: + shape = shape[1:] + return a.reshape(*shape) + data = jax.tree.map(microbatched, data) + + # Perform data parallelism manually through `vmap` + vmapped_compute_grads = compute_grads + if state_mesh_shardings.step.mesh.shape["data"] > 1: + vmapped_compute_grads = jax.vmap(compute_grads, spmd_axis_name="data") + + loss_aux_sharding = jax.sharding.NamedSharding(state_mesh_shardings.step.mesh, jax.sharding.PartitionSpec()) + param_operation = {'params': jaxpp.Add} + if nn.fp8_ops.OVERWRITE_WITH_GRADIENT in params: + param_operation[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = jaxpp.Max + assert all(k in param_operation for k in params.keys()) + + axis = 1 if state_mesh_shardings.step.mesh.shape["data"] > 1 else 0 + (loss, aux), raw_grads = jaxpp.treduce( + vmapped_compute_grads, + data, + axis=axis, + schedule=load_schedule(config), + operation=(jaxpp.Concat(axis=axis), param_operation) + ) + if state_mesh_shardings.step.mesh.shape["data"] > 1: + (loss, aux), raw_grads = jax.lax.with_sharding_constraint( + ((loss, aux), raw_grads), + jax.tree.map_with_path( + functools.partial(add_leading_axis, "data"), + (loss_aux_sharding, params_shardings) + ), + ) + # reduce-scatter gradients across "data" + owg = raw_grads.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + raw_grads = jax.tree.map(functools.partial(jax.numpy.sum, axis=0), raw_grads) + if owg is not None: + owg = jax.tree.map(functools.partial(jax.numpy.max, axis=0), owg) + raw_grads[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg + + raw_grads = jax.lax.with_sharding_constraint(raw_grads, state_mesh_shardings.params) raw_grads = jax.tree_util.tree_map(lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, raw_grads) + owg = raw_grads.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + raw_grad_norm = max_utils.l2norm_pytree(raw_grads) intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] mtp_loss = aux["mtp_loss"] if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, raw_grad_norm, config.gradient_clipping_threshold) else: grads = raw_grads + if owg is not None: + grads[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg if config.optimizer_memory_host_offload: state = state.replace( opt_state=jax.device_put( @@ -286,16 +405,36 @@ def move(path, value): ) new_state = state.apply_gradients(grads=grads) - scalar_metrics = { - "learning/loss": loss, - "learning/moe_lb_loss": moe_lb_loss, - "learning/mtp_loss": mtp_loss, - "learning/total_weights": total_weights, - } + if config.use_jaxpp: + # TODO: refine logic to match the one in MaxText's gradient accumulation + # or use that altogether (add support for scan instead of + # treduce in JaxPP) + scalar_metrics = { + "learning/loss": loss.sum() / total_weights.sum(), + "learning/moe_lb_loss": moe_lb_loss.sum(), + "learning/mtp_loss": mtp_loss.sum(), + "learning/total_weights": total_weights.sum(), + } + else: + scalar_metrics = { + "learning/loss": loss, + "learning/moe_lb_loss": moe_lb_loss, + "learning/mtp_loss": mtp_loss, + "learning/total_weights": total_weights, + } if not config.optimizer_memory_host_offload: - scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) - scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + owg = grads.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads["params"]) + if owg is not None: + grads[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg + scalar_metrics["learning/raw_grad_norm"] = raw_grad_norm + + new_params = new_state.params + owg = new_params.pop(nn.fp8_ops.OVERWRITE_WITH_GRADIENT, None) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_params) + if owg is not None: + new_params[nn.fp8_ops.OVERWRITE_WITH_GRADIENT] = owg + new_state = new_state.replace(params=new_params) if config.use_dpo: scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] metrics = { @@ -345,7 +484,7 @@ def eval_step(model, config, state, data, dropout_rng): if config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] - return metrics + return jax.tree.map(jax._src.numpy.lax_numpy._array_copy, metrics) def train_loop(config, recorder, state=None): @@ -355,7 +494,7 @@ def train_loop(config, recorder, state=None): checkpoint_manager, state_mesh_shardings, model, - mesh, + maybe_mpmd_mesh, learning_rate_schedule, data_iterator, eval_data_iterator, @@ -368,24 +507,23 @@ def train_loop(config, recorder, state=None): state = _merge_dpo_state(state, reference_params) state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) - params_shardings, state_mesh_shardings = maxtext_utils.maybe_update_params_sharding_with_opt( - config, state_mesh_shardings - ) + mesh = maybe_mpmd_mesh.lowering_mesh() if config.use_jaxpp else maybe_mpmd_mesh params_shardings, state_mesh_shardings = maxtext_utils.maybe_update_params_sharding_with_opt( config, state_mesh_shardings ) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings + config, model, maybe_mpmd_mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings ) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - shaped_batch = maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: - state = jax.lax.with_sharding_constraint(state, state_mesh_shardings) - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() - compiled_stats = compiled.memory_analysis() - max_utils.print_compiled_memory_stats(compiled_stats) + if not config.use_jaxpp: + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + shaped_batch = maxtext_utils.get_shaped_batch(config) + if config.shard_optimizer_over_data: + state = jax.lax.with_sharding_constraint(state, state_mesh_shardings) + compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled_stats = compiled.memory_analysis() + max_utils.print_compiled_memory_stats(compiled_stats) start_step = get_first_step(state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) @@ -396,9 +534,18 @@ def train_loop(config, recorder, state=None): metric_logger.write_setup_info_to_tensorboard(state.params) try: + step_time = [] + step_tflops = [] + # NOTE: The dict values are unused when use_jaxpp is False. + profiling_process_ids = {pid: "" for pid in jax.process_indices()} + if config.use_jaxpp: + idx = tuple(slice(None) if i == maybe_mpmd_mesh.mpmd_axis else 0 for i in range(len(maybe_mpmd_mesh.jax_mesh.shape))) + first_device_per_mpmd_rank = maybe_mpmd_mesh.jax_mesh.devices[idx] + profiling_process_ids = {d.process_index: d for d in first_device_per_mpmd_rank} + last_step_completion = datetime.datetime.now() for step in np.arange(start_step, config.steps): - prof.maybe_activate_profiler(step, state) + prof.maybe_activate_profiler(step, state, maybe_mpmd_mesh=maybe_mpmd_mesh, profiling_process_ids=profiling_process_ids) with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch() @@ -406,7 +553,7 @@ def train_loop(config, recorder, state=None): nextrng = jax.jit(jax.random.fold_in)(init_rng, step) with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and not config.use_jaxpp: state = jax.lax.with_sharding_constraint(state, state_mesh_shardings) state, metrics = p_train_step(state, example_batch, nextrng) @@ -438,6 +585,7 @@ def train_loop(config, recorder, state=None): if config.eval_steps > 0 and eval_step_count >= config.eval_steps: break with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + eval_batch = jax.tree_util.tree_map(lambda a: a[:2], eval_batch) eval_metrics = p_eval_step(state, eval_batch, nextrng) metric_logger.record_eval_metrics(step, metrics=eval_metrics) max_logging.log(f"Completed eval step {eval_step_count}") @@ -447,12 +595,21 @@ def train_loop(config, recorder, state=None): prof.deactivate() raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - prof.maybe_deactivate_profiler(step, state) + prof.maybe_deactivate_profiler(step, state, profiling_process_ids=profiling_process_ids) if step == start_step: max_utils.print_mem_stats("After params initialized") metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + step_time.append(metrics['scalar']['perf/step_time_seconds']) + step_tflops.append(metrics['scalar']['perf/per_device_tflops_per_sec']) + + if config.use_jaxpp and prof.mode != "": + command = """find . -wholename '*proc_*_mpmd*/*.xplane.pb' | sort | awk '{line=$0; sub(/.*mpmd_/, "", line); sub(/_.*/, "", line); printf "%d:%s:0 ", line, $0}'""" + subprocess.run( + [f"merge_multihost_xplanes $({command})"], + shell=True, cwd=config.tensorboard_dir, check=True + ) if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] @@ -462,6 +619,13 @@ def train_loop(config, recorder, state=None): finally: metric_logger.flush_metrics_and_cleanup() + # last_profiling_step + 2 as (1) we count steps from 0, and (2) the execution time for merge_multihost_xplanes is + # counted toward the execution time for the step right after the last profiling step. + num_warmup_steps = (prof.finished_initial_profile_step + 2) if prof.mode != "" else 6 + max_logging.log( + f"excluding the first {num_warmup_steps} steps: avg time per step {mean(step_time[num_warmup_steps:])}, avg tflops per step {mean(step_tflops[num_warmup_steps:])}" + ) + return state diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py index 7f9ce2d20f..ae2f73e693 100644 --- a/src/MaxText/train_compile.py +++ b/src/MaxText/train_compile.py @@ -1,4 +1,5 @@ # Copyright 2023–2025 Google LLC +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +46,9 @@ from MaxText.layers import quantizations from MaxText.utils import gcs_utils +import jaxpp.api as jaxpp +from jaxpp.api import MpmdMesh + # pylint: disable=too-many-positional-arguments Transformer = models.transformer_as_linen @@ -117,18 +121,32 @@ def jit_and_compile( static_argnums, donate_argnums, logical_axis_rules, + mpmd_mesh=None, ): """Jit, lower, and compile func.""" with mesh, logical_axis_rules: - jitted = jax.jit( + if mpmd_mesh is not None: + p_train_step = jaxpp.mpmd_jit_with_loop( func, + mpmd_mesh=mpmd_mesh, + donate_argnums=donate_argnums, in_shardings=in_shardings, out_shardings=out_shardings, - static_argnums=static_argnums, - donate_argnums=donate_argnums, - ) - lowered = jitted.lower(*func_input_args, **func_input_kwargs) - compiled = lowered.compile() + ) + assert len(func_input_kwargs) == 0 + compiled = p_train_step.compile(*func_input_args) + else: + jitted = jax.jit( + func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + ) + lowered = jitted.lower(*func_input_args, **func_input_kwargs) + + if mpmd_mesh is None: + compiled = lowered.compile() return compiled @@ -152,21 +170,28 @@ def main(argv: Sequence[str]) -> None: # Create target mesh topology_mesh = get_topology_mesh(config) + if config.use_jaxpp: + mpmd_mesh = MpmdMesh(topology_mesh, 'stage') + mesh = mpmd_mesh.lowering_mesh() + else: + mpmd_mesh = None + mesh = topology_mesh # Print system information after building the compile topology to avoid # prematurely initializing the backend. max_utils.print_system_information() # Get shaped inputs - shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(topology_mesh, config) + shaped_train_args, shaped_train_kwargs, state_mesh_shardings, model = get_shaped_inputs(mesh, config) + params_shardings, state_mesh_shardings = max_utils.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) # Get data sharding - data_sharding = maxtext_utils.get_input_data_sharding(config, topology_mesh) + data_sharding = maxtext_utils.get_input_data_sharding(config, mesh) # Get function to compile and shardings func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = ( maxtext_utils.get_functional_train_with_signature( - train.train_step, data_sharding, state_mesh_shardings, model, config + train.train_step, data_sharding, state_mesh_shardings, model, config, params_shardings=params_shardings ) ) @@ -176,12 +201,13 @@ def main(argv: Sequence[str]) -> None: func_to_compile, shaped_train_args, shaped_train_kwargs, - topology_mesh, + mesh, in_shard, out_shard, static_argnums, donate_argnums, nn_partitioning.axis_rules(config.logical_axis_rules), + mpmd_mesh=mpmd_mesh ) print("Jitting and compilation complete!", flush=True) @@ -190,9 +216,11 @@ def main(argv: Sequence[str]) -> None: print("Saving compiled object...") save_compiled(compiled, config.compiled_trainstep_file) print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") - print("Finished train_compile.py successfully!", flush=True) - print(f"Cost analysis: {compiled.cost_analysis()}") - print(f"Memory analysis: {compiled.memory_analysis()}") + + if not config.use_jaxpp: + print("Finished train_compile.py successfully!", flush=True) + print(f"Cost analysis: {compiled.cost_analysis()}") + print(f"Memory analysis: {compiled.memory_analysis()}") # Dump HLO if requested if config.dump_hlo: diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index 651349e5ea..3b15e5283b 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -28,6 +28,7 @@ from MaxText.utils.goodput_utils import maybe_record_goodput from MaxText import model_creation_utils +import jaxpp.api as jaxpp def create_training_tools(config, model, mesh): """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" @@ -76,7 +77,7 @@ def create_training_tools(config, model, mesh): return init_rng, checkpoint_manager, learning_rate_schedule, tx -def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings): +def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, maybe_mpmd_mesh, params_shardings): """Returns a JIT-compiled train step function, which is loaded from a file if specified in the config.""" ( functional_train, @@ -95,18 +96,27 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr p_train_step = maxtext_utils.load_compiled(config, functional_train, state) print("Loaded compiled function!", flush=True) else: - p_train_step = jax.jit( + if not config.use_jaxpp: + p_train_step = jax.jit( functional_train, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, + donate_argnums=donate_argnums) + else: + max_logging.log("Running with jaxpp") + p_train_step = jaxpp.mpmd_jit_with_loop( + functional_train, + mpmd_mesh=maybe_mpmd_mesh, donate_argnums=donate_argnums, - ) + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return p_train_step -def jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step): +def jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step, maybe_mpmd_mesh): """Returns a JIT-compiled eval step function.""" ( functional_eval, @@ -118,13 +128,23 @@ def jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) p_eval_step = None if config.compiled_trainstep_file == "": - p_eval_step = jax.jit( + if not config.use_jaxpp: + p_eval_step = jax.jit( + functional_eval, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + donate_argnums=donate_argnums, + ) + else: + p_eval_step = jaxpp.mpmd_jit_by_yield( functional_eval, + mpmd_mesh=maybe_mpmd_mesh, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, donate_argnums=donate_argnums, - ) + ) return p_eval_step @@ -132,7 +152,7 @@ def jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) def jit_train_and_eval_step( config, model, - mesh, + maybe_mpmd_mesh, state, state_mesh_shardings, train_step, @@ -141,11 +161,12 @@ def jit_train_and_eval_step( params_shardings=None, ): """Returns a JIT-compiled train and eval step function.""" + mesh = maybe_mpmd_mesh.lowering_mesh() if config.use_jaxpp else maybe_mpmd_mesh data_sharding = maxtext_utils.get_input_data_sharding(config, mesh) - p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings) + p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, maybe_mpmd_mesh, params_shardings) p_eval_step = None if eval_data_iterator: - p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) + p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step, maybe_mpmd_mesh) return p_train_step, p_eval_step @@ -171,7 +192,13 @@ def setup_train_loop(config, recorder, devices=None): with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): model = model_creation_utils.from_config(config, devices) - mesh = model.mesh + maybe_mpmd_mesh = model.mesh + if config.use_jaxpp: + assert isinstance(maybe_mpmd_mesh, jaxpp.MpmdMesh) + model.mesh = mesh = maybe_mpmd_mesh.lowering_mesh() + else: + assert isinstance(maybe_mpmd_mesh, jax.sharding.Mesh) + mesh = maybe_mpmd_mesh init_rng, checkpoint_manager, learning_rate_schedule, tx = create_training_tools(config, model, mesh) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -198,9 +225,17 @@ def setup_train_loop(config, recorder, devices=None): ) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + model, data_iterator, tx, config, init_rng, maybe_mpmd_mesh, checkpoint_manager ) + def make_line(keypath, array_or_array_ref): + sharding = array_or_array_ref.sharding + return (f"{jax.tree_util.keystr(keypath):<120}, {str(array_or_array_ref.dtype):<10}, " + f"{str(array_or_array_ref.shape):<26}, {sharding._to_xla_hlo_sharding(array_or_array_ref.ndim)}") + + max_logging.log("shardings/weights") + max_logging.log("\n".join(make_line(keypath, array_ref) for keypath, array_ref in jax.tree_util.tree_leaves_with_path(state))) + # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage @@ -243,7 +278,7 @@ def setup_train_loop(config, recorder, devices=None): checkpoint_manager, state_mesh_shardings, model, - mesh, + maybe_mpmd_mesh, learning_rate_schedule, data_iterator, eval_data_iterator, diff --git a/tests/train_compile_jaxpp_test.py b/tests/train_compile_jaxpp_test.py new file mode 100644 index 0000000000..9755ce5a3e --- /dev/null +++ b/tests/train_compile_jaxpp_test.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from MaxText.train_compile import main as train_compile_main + + +def get_args( + model_name, + num_nodes, + ici_dp, + ici_tp, + dcn_pp, + vp, + ga, + per_device_batch_size=2, + ici_cp=1, + ici_ep=1, + dcn_ep=1, + quantization=None, + disable_cache: bool = False, +): + res = ( + None, + "MaxText/configs/base.yml", + f"model_name={model_name}", + "attention=dot_product", + "remat_policy=minimal", + "dtype=bfloat16", + "max_target_length=2048", + f"per_device_batch_size={per_device_batch_size}", + "hardware=gpu", + # SPMD Parallelism + f"ici_data_parallelism={ici_dp}", + f"ici_tensor_parallelism={ici_tp}", + f"ici_context_parallelism={ici_cp}", + # Pipeline + f"dcn_pipeline_parallelism={dcn_pp}", + f"num_pipeline_microbatches={ga}", + f"num_pipeline_repeats={vp}", + # JaxPP + "use_jaxpp=True", + "schedule=interleaved_1f1b", + "compile_topology=a3", + f"compile_topology_num_slices={num_nodes}", + ) + if ici_ep > 1: + res = res + (f"ici_expert_parallelism={ici_ep}",) + if dcn_ep > 1: + res = res + (f"dcn_expert_parallelism={dcn_ep}",) + if quantization is not None: + res = res + (f"quantization={quantization}",) + if disable_cache: + res = res + ("jax_cache_dir=",) + return res + + +class TrainCompile(unittest.TestCase): + def test_compile_llama4(self): + train_compile_main( + get_args( + model_name="llama4-17b-16e", + num_nodes=32, + ici_dp=1, + ici_ep=2, + ici_tp=4, + dcn_ep=8, + dcn_pp=4, + vp=4, + ga=64, + per_device_batch_size=4 + ) + ) + + def test_compile_gpt3(self): + train_compile_main( + get_args( + model_name="gpt3-175b", + num_nodes=16, + ici_dp=2, + ici_tp=4, + dcn_pp=8, + vp=6, + ga=32, + ) + ) + + def test_compile_gpt3_fp8(self): + train_compile_main( + get_args( + model_name="gpt3-175b", + num_nodes=16, + ici_dp=2, + ici_tp=4, + dcn_pp=8, + vp=6, + ga=32, + quantization="fp8", + ) + ) + + def test_compile_llama3(self): + train_compile_main( + get_args( + model_name="llama3.3-70b", + num_nodes=8, + ici_dp=1, + ici_cp=2, + ici_tp=4, + dcn_pp=4, + vp=5, + ga=64, + per_device_batch_size=2 + ) + )