Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/README.md
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,6 @@
"quantization=int8",
"quantize_kvcache=True"
]
}
},
]
}
26 changes: 26 additions & 0 deletions jaxpp.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASE_IMAGE
FROM $BASE_IMAGE AS base
ARG JAX_INSTALL_URL

COPY requirements.txt /tmp/requirements.txt
RUN uv pip install -U pip && uv pip install --no-cache-dir -U -r /tmp/requirements.txt

COPY --chown=$USER_UID:$USER_GID . maxtext

RUN uv pip install --no-cache-dir -e '/workdir/jaxpp[dev]'
RUN uv pip install --no-cache-dir -e /workdir/maxtext[cuda_12] --resolution=lowest && \
if [[ -n "$JAX_INSTALL_URL" ]]; then uv pip install $JAX_INSTALL_URL; fi

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we don't need this?

78 changes: 78 additions & 0 deletions jaxpp.README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Overview

This repository is a fork of [MaxText](https://github.com/AI-Hypercomputer/maxtext) created for training with [JaxPP](https://github.com/NVIDIA/jaxpp).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's rm this jaxpp.README.md for now


# Notable changes

The changes between this repo and the upstream MaxText is kept minimal in general.
Some of the notable changes are listed below.

* The `__call__` method of the `Decoder` class in [src/MaxText/layers/decoders.py](src/MaxText/layers/decoders.py)
calls `jaxpp.pipeline_enter_stage` to mark stage boundaries for pipeline parallelism.
* The `maybe_initialize_jax_distributed_system` function in [src/MaxText/max_utils.py](src/MaxText/max_utils.py)
creates `RemoteMpmdMesh` to be used by JaxPP.
* [src/MaxText/train.py](src/MaxText/train.py) contains changes to
* Enable pipeline parallelism for the train step, and
* Mark the pipeline loop in the train step with `jaxpp.treduce`.

# Docker image

For ease of use, we provide a docker image with this fork under `/workdir/maxtext`.
The docker image has all the dependencies that are needed to use MaxText with JaxPP installed.

## Building and Testing Docker Container

The build process uses the JaxPP base image as a starting point. Follow the instructions at [JaxPP's Building the Base Image](https://github.com/NVIDIA/jaxpp#building-the-base-image) to build the `jaxpp-base` image first.

### Prerequisites
- Docker installed and configured
- NVIDIA Container Toolkit installed
- JaxPP base image built and available locally

### Building the Main Image

After building the base image, you can build the main image:

```bash
# Check if jaxpp-base image exists
if [ -z "$(docker images -q jaxpp-base)" ]; then
echo "Error: jaxpp-base image not found. Please build it first using the instructions at https://github.com/NVIDIA/jaxpp#building-the-base-image."
else
docker build --force-rm=true \
-f jaxpp.Dockerfile \
--build-arg BASE_IMAGE=jaxpp-base \
-t maxtext-jaxpp .
fi
```

### Running Tests

The container includes several test suites for different models:

1. **Tiny Llama4 Model Tests**:
```bash
docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \
-e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \
"nvidia-smi && CONFIG_FILE=./scripts/llama4_proxy_config.sh bash scripts/test_1gpu_config.sh"
```

2. **Tiny Mixtral Model Tests**:
```bash
docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \
-e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \
"nvidia-smi && MODEL_CONFIG='model_name=mixtral-8x7b override_model_config=True base_num_decoder_layers=2 base_emb_dim=512 base_mlp_dim=1792' bash scripts/test_1gpu_config.sh"
```

3. **Tiny Mistral Model Tests**:
```bash
docker run --gpus=all --shm-size=10.24gb --ulimit memlock=-1 --ulimit stack=67108864 \
-e CUDA_VISIBLE_DEVICES=0 --rm --workdir /workdir/maxtext maxtext-jaxpp \
"nvidia-smi && bash MODEL_CONFIG='model_name=mistral-7b override_model_config=True base_num_decoder_layers=2' bash scripts/test_1gpu_config.sh"
```

Note: The tests require GPU access and sufficient GPU memory.

# Profiling

Profiling is enabled by default in the 6th step, and the first 7 steps are ignored in the performance statistics.
It allows the performance statstics to be collected without the profiling overhead while producing the profiling data while running the benchmarks.
8 changes: 8 additions & 0 deletions scripts/deepseek3_proxy_config.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export MODEL_CONFIG="
model_name=deepseek3-671b
override_model_config=True
base_num_decoder_layers=4
base_emb_dim=896
base_mlp_dim=2304
base_moe_mlp_dim=256
"
5 changes: 5 additions & 0 deletions scripts/llama3.3_proxy_config.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
export MODEL_CONFIG="
model_name=llama3.3-70b
override_model_config=True
base_num_decoder_layers=2
"
8 changes: 8 additions & 0 deletions scripts/llama4_proxy_config.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export MODEL_CONFIG="
model_name=llama4-17b-16e
override_model_config=True
base_num_decoder_layers=2
base_emb_dim=640
base_mlp_dim=2048
base_moe_mlp_dim=2048
"
25 changes: 25 additions & 0 deletions scripts/local_mc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

if [ -z "$N_PROCS" ] || [ -z "$N_GPUS" ] || [ -z "$COMMAND" ]; then
echo "N_PROCS, N_GPUS, and COMMAND must be set"
exit 1
fi

seq 0 $(($N_PROCS - 1)) | xargs -P $N_PROCS -I {} bash -c ' \
n_gpus=$2; \
start=$(({} * n_gpus)); \
end=$((start + n_gpus - 1)); \
JAX_COORDINATOR_IP="localhost" JAX_COORDINATOR_PORT=1234 NNODES=$1 NODE_RANK={} \
CUDA_VISIBLE_DEVICES=$(seq -s, $start $end) $3' _ $N_PROCS $N_GPUS "$COMMAND"
18 changes: 18 additions & 0 deletions scripts/run_local_mc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
export TEST_CONFIG="
override_model_config=True dataset_type=synthetic steps=10
"

export COMMAND="python3 -u -m MaxText.train src/MaxText/configs/base.yml \
base_output_directory=run_local_mc_outputs \
run_name=run_$(date +%Y-%m-%d-%H:%M:%S) \
enable_checkpointing=false \
async_checkpointing=false \
dtype=bfloat16 \
weight_dtype=bfloat16 \
hardware=gpu \
$MODEL_CONFIG \
$TEST_CONFIG \
$PARALLELISM_CONFIG
$JAXPP_CONFIG"

bash ./scripts/local_mc.sh
24 changes: 24 additions & 0 deletions scripts/run_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

bash scripts/test_8gpu_llama4_proxy.sh

RAY_ADDRESS=local python -m MaxText.train src/MaxText/configs/base.yml run_name=runner_jaxpp_$(date +%Y-%m-%d-%H-%M) base_output_directory=/tmp/log hardware=gpu dataset_type=synthetic model_name=gpt3-52k steps=20 dtype=bfloat16 max_target_length=2048 per_device_batch_size=4 dcn_data_parallelism=1 ici_data_parallelism=2 ici_tensor_parallelism=2 ici_pipeline_parallelism=2 num_pipeline_repeats=1 num_pipeline_microbatches=8 enable_checkpointing=false use_jaxpp=True schedule=interleaved_1f1b

tests=$(python3 -m pytest --co -q -W ignore::DeprecationWarning tests/train_compile_jaxpp_test.py | awk '/^[[:space:]]*$/{exit} {print}')

for t in $tests; do
echo $t
python3 -m pytest --log-cli-level=INFO -s "$t"
done
28 changes: 28 additions & 0 deletions scripts/test_1gpu_config.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
if [ -n "$MODEL_CONFIG" ] && [ -n "$CONFIG_FILE" ]; then
echo "Error: both MODEL_CONFIG and CONFIG_FILE are set"
exit 1
fi

if [ -n "$CONFIG_FILE" ]; then
source $CONFIG_FILE
fi

export N_PROCS=1
export N_GPUS=1

# Run plain JAX config
bash ./scripts/run_local_mc.sh

# Run JaxPP config
export PARALLELISM_CONFIG="ici_pipeline_parallelism=1"

export JAXPP_CONFIG="
scan_layers=False
use_jaxpp=True
schedule=interleaved_1f1b
num_pipeline_microbatches=4
num_pipeline_repeats=1
max_target_length=64
"

bash ./scripts/run_local_mc.sh
22 changes: 22 additions & 0 deletions scripts/test_8gpu_deepseek3_proxy.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
source scripts/deepseek3_proxy_config.sh

export PARALLELISM_CONFIG="
dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2
ici_data_parallelism=1
ici_tensor_parallelism=2
ici_expert_parallelism=2
ici_fsdp_parallelism=1
"

export JAXPP_CONFIG="
scan_layers=False
use_jaxpp=True
schedule=interleaved_1f1b
num_pipeline_microbatches=4
num_pipeline_repeats=1
"

export N_PROCS=2
export N_GPUS=4

bash ./scripts/run_local_mc.sh
30 changes: 30 additions & 0 deletions scripts/test_8gpu_llama3.3_proxy.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export JAX_USE_SHARDY_PARTITIONER=0
export JAXPP_ENABLE_LICM=1
export NVTE_FUSED_ATTN=1
# --xla_dump_hlo_pass_re=.*
export XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_hlo_as_text --xla_dump_to='./llama3-hlos-pp2' --xla_gpu_enable_latency_hiding_scheduler=true"
source scripts/llama3.3_proxy_config.sh

export PARALLELISM_CONFIG="
dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2
ici_data_parallelism=1
ici_context_parallelism=2
ici_tensor_parallelism=2
ici_fsdp_parallelism=1
per_device_batch_size=1
max_target_length=8192
"

export JAXPP_CONFIG="
scan_layers=False
use_jaxpp=True
schedule=interleaved_1f1b
num_pipeline_microbatches=4
num_pipeline_repeats=1
profiler=xplane
"

export N_PROCS=8
export N_GPUS=1

bash ./scripts/run_local_mc.sh
22 changes: 22 additions & 0 deletions scripts/test_8gpu_llama4_proxy.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
source scripts/llama4_proxy_config.sh

export PARALLELISM_CONFIG="
dcn_pipeline_parallelism=1 ici_pipeline_parallelism=2
ici_data_parallelism=1
ici_tensor_parallelism=2
ici_expert_parallelism=2
ici_fsdp_parallelism=1
"

export JAXPP_CONFIG="
scan_layers=False
use_jaxpp=True
schedule=interleaved_1f1b
num_pipeline_microbatches=4
num_pipeline_repeats=1
"

export N_PROCS=2
export N_GPUS=4

bash ./scripts/run_local_mc.sh
Loading