
Commit b98ec86
Merge branch 'yueshen/rotary_scaling_fix_llama3_1' into 'main'
rotary_scaling fix for llama3.1 and 3.2

See merge request ADLR/megatron-lm!2180
ericharper committed Oct 4, 2024
2 parents fde8bb1 + 843a22e commit b98ec86
Showing 9 changed files with 50 additions and 32 deletions.
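The change wires a rope-scaling option through the PTQ + TensorRT-LLM export path for Llama 3.1/3.2. For orientation, below is a minimal sketch of the Llama 3.1-style rotary frequency scaling that such an option typically enables; the constants (factor 8, low/high-frequency factors 1 and 4, 8192-token original context) are the published Llama 3.1 defaults and are an assumption here, not something taken from this diff:

```python
import math
import torch

def llama31_scale_inv_freq(inv_freq: torch.Tensor,
                           factor: float = 8.0,
                           low_freq_factor: float = 1.0,
                           high_freq_factor: float = 4.0,
                           original_context_len: int = 8192) -> torch.Tensor:
    """Hedged sketch: Llama 3.1-style scaling applied to RoPE inverse frequencies."""
    low_freq_wavelen = original_context_len / low_freq_factor
    high_freq_wavelen = original_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Long wavelengths (low frequencies) are slowed down by `factor`; short ones are kept.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Wavelengths in between are smoothly interpolated between the two regimes.
    smooth = (original_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)
```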
48 changes: 35 additions & 13 deletions examples/export/ptq_and_trtllm_export/README.md
@@ -74,7 +74,7 @@ cd ../..

Now launch the PTQ + TensorRT-LLM export script,
```sh
bash examples/inference/quantization/ptq_trtllm_minitron_8b ./Minitron-8B-Base None
bash examples/export/ptq_and_trtllm_export/ptq_trtllm_minitron_8b ./Minitron-8B-Base None
```
By default, `cnn_dailymail` is used for calibration. The `GPTModel` will have quantizers for simulating the
quantization effect. The checkpoint will be saved optionally (with quantizers as additional states) and can
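
As a rough illustration of what that calibration step does, here is a hedged sketch using TensorRT Model Optimizer's `mtq.quantize` entry point; `model` and `calib_dataloader` are placeholders, and the real loop lives in `text_generation_ptq.py` further down in this diff:

```python
import modelopt.torch.quantization as mtq

# Hedged sketch: `model` and `calib_dataloader` are placeholders; the actual script
# builds them from the Megatron GPTModel and the cnn_dailymail calibration set.
def calib_forward_loop(model):
    for batch in calib_dataloader:
        model(batch)  # forward passes let the inserted quantizers collect statistics

# Insert quantizers and calibrate them; FP8 is shown purely as an example config.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, calib_forward_loop)
```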
@@ -104,12 +104,12 @@ export trtllm_options=" \
--checkpoint_dir /tmp/trtllm_ckpt \
--output_dir /tmp/trtllm_engine \
--max_input_len 2048 \
--max_output_len 512 \
--max_seq_len 512 \
--max_batch_size 8 "

trtllm-build ${trtllm_options}

python examples/inference/quantization/trtllm_text_generation.py --tokenizer nvidia/Minitron-8B-Base
python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer nvidia/Minitron-8B-Base
```

### mistral-12B FP8 Quantization and TensorRT-LLM Deployment
@@ -139,7 +139,7 @@ huggingface-cli login
Now launch the PTQ + TensorRT-LLM checkpoint export script,

```sh
bash examples/inference/quantization/ptq_trtllm_mistral_12b.sh ./Mistral-NeMo-12B-Base None
bash examples/export/ptq_and_trtllm_export/ptq_trtllm_mistral_12b.sh ./Mistral-NeMo-12B-Base None
```

Then build the TensorRT engine and run the text generation example using the newly built TensorRT engine
@@ -149,12 +149,12 @@ export trtllm_options=" \
--checkpoint_dir /tmp/trtllm_ckpt \
--output_dir /tmp/trtllm_engine \
--max_input_len 2048 \
--max_output_len 512 \
--max_seq_len 512 \
--max_batch_size 8 "

trtllm-build ${trtllm_options}

python examples/inference/quantization/trtllm_text_generation.py --tokenizer mistralai/Mistral-Nemo-Base-2407
python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer mistralai/Mistral-Nemo-Base-2407
```


@@ -165,7 +165,7 @@ python examples/inference/quantization/trtllm_text_generation.py --tokenizer mis
> that we support.
```sh
bash examples/inference/quantization/ptq_trtllm_llama_7b.sh ${CHECKPOINT_DIR}
bash examples/export/ptq_and_trtllm_export/ptq_trtllm_llama_7b.sh ${CHECKPOINT_DIR}
```

The script expects `${CHECKPOINT_DIR}` to have the following structure:
@@ -184,8 +184,23 @@ The script expects `${CHECKPOINT_DIR}` to have the following structure:
In short, in addition to the converted Llama Megatron checkpoint, also place the Hugging Face checkpoint inside as the source of the tokenizer.
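
Put differently, the Hugging Face copy is only consulted for its tokenizer; a hedged illustration (the path is hypothetical, the `transformers` call is the standard one):

```python
from transformers import AutoTokenizer

# Hypothetical path: the Hugging Face checkpoint placed inside the checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("/path/to/checkpoint_dir/llama-2-7b-hf")
```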

Then build the TensorRT engine and run the text generation example using the newly built TensorRT engine

```sh
export trtllm_options=" \
--checkpoint_dir /tmp/trtllm_ckpt \
--output_dir /tmp/trtllm_engine \
--max_input_len 2048 \
--max_seq_len 512 \
--max_batch_size 8 "

trtllm-build ${trtllm_options}

python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer meta-llama/Llama-2-7b
```

### llama3-8b / llama3.1-8b INT8 SmoothQuant and TensorRT-LLM Deployment
> **NOTE:** For llama3.1, the missing rope_scaling parameter will be fixed in modelopt-0.17 and trtllm-0.12.
> **NOTE:** For llama3.1, the missing rope_scaling parameter will be fixed in modelopt-0.19 and trtllm-0.13.
> **NOTE:** There are two ways to acquire the checkpoint. Users can follow
> the instruction in `docs/llama2.md` to convert the checkpoint to megatron legacy `GPTModel` format and
@@ -199,16 +214,23 @@ If users choose to download the model from NGC, first extract the sharded checkp
tar -xvf 8b_pre_trained_bf16.nemo
```

> **NOTE:** You need a token generated from huggingface.co/settings/tokens and access to meta-llama/Llama-3.1-8B or meta-llama/Llama-3-8B on huggingface
```sh
pip install -U "huggingface_hub[cli]"
huggingface-cli login
```

Now launch the PTQ + TensorRT-LLM checkpoint export script for llama-3,

```sh
bash examples/inference/quantization/ptq_trtllm_llama3_8b.sh ./llama-3-8b-nemo_v1.0 None
bash examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh ./llama-3-8b-nemo_v1.0 None
```

or llama-3.1

```sh
bash examples/inference/quantization/ptq_trtllm_llama3_1_8b.sh ./llama-3_1-8b-nemo_v1.0 None
bash examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh ./llama-3_1-8b-nemo_v1.0 None
```

Then build the TensorRT engine and run the text generation example using the newly built TensorRT engine
@@ -218,14 +240,14 @@ export trtllm_options=" \
--checkpoint_dir /tmp/trtllm_ckpt \
--output_dir /tmp/trtllm_engine \
--max_input_len 2048 \
--max_output_len 512 \
--max_seq_len 512 \
--max_batch_size 8 "

trtllm-build ${trtllm_options}

python examples/inference/quantization/trtllm_text_generation.py --tokenizer meta-llama/Meta-Llama-3-8B
python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer meta-llama/Meta-Llama-3-8B
# For llama-3

python examples/inference/quantization/trtllm_text_generation.py --tokenizer meta-llama/Meta-Llama-3.1-8B
python examples/export/ptq_and_trtllm_export/trtllm_text_generation.py --tokenizer meta-llama/Meta-Llama-3.1-8B
# For llama-3.1
```
6 changes: 2 additions & 4 deletions examples/export/ptq_and_trtllm_export/ptq_trtllm_llama2_7b.sh
@@ -66,7 +66,7 @@ options=" \
--tokenizer-model ${TOKENIZER_MODEL} \
--save-interval 1000000 \
--use-dist-ckpt \
--load ${CHECKPOINT_LOAD_DIR}
--load ${CHECKPOINT_LOAD_DIR} \
--fp16"

# Precompile CUDA extensions
@@ -76,7 +76,5 @@ python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_
launch_config="--nproc_per_node=${TP}"

# Launch multi-process with torchrun
torchrun ${launch_config} examples/inference/quantization/text_generation_ptq.py ${options} ${additional_options}
torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}

# This script is using mpi4py which will fork multiple processes.
python examples/inference/quantization/trtllm_text_generation.py ${trtllm_options}
examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_1_8b.sh
@@ -63,9 +63,10 @@ options=" \
--tokenizer-type HuggingFaceTokenizer \
--tokenizer-model meta-llama/Meta-Llama-3.1-8B \
--save-interval 1000000 \
--use-rope-scaling \
--use-dist-ckpt \
--load ${CHECKPOINT_LOAD_DIR}
--rotary-base 500000
--load ${CHECKPOINT_LOAD_DIR} \
--rotary-base 500000 \
--fp16"

# Precompile CUDA extensions
@@ -75,4 +76,4 @@ python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_
launch_config="--nproc_per_node=${TP}"

# Launch multi-process with torchrun
torchrun ${launch_config} examples/inference/quantization/text_generation_ptq.py ${options} ${additional_options}
torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
6 changes: 3 additions & 3 deletions examples/export/ptq_and_trtllm_export/ptq_trtllm_llama3_8b.sh
@@ -64,8 +64,8 @@ options=" \
--tokenizer-model meta-llama/Meta-Llama-3-8B \
--save-interval 1000000 \
--use-dist-ckpt \
--load ${CHECKPOINT_LOAD_DIR}
--rotary-base 500000
--load ${CHECKPOINT_LOAD_DIR} \
--rotary-base 500000 \
--fp16"

# Precompile CUDA extensions
@@ -75,4 +75,4 @@ python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_
launch_config="--nproc_per_node=${TP}"

# Launch multi-process with torchrun
torchrun ${launch_config} examples/inference/quantization/text_generation_ptq.py ${options} ${additional_options}
torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
@@ -71,4 +71,4 @@ python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_
launch_config="--nproc_per_node=${TP}"

# Launch multi-process with torchrun
torchrun ${launch_config} examples/inference/quantization/text_generation_ptq.py ${options} ${additional_options}
torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
@@ -72,4 +72,4 @@ python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_
launch_config="--nproc_per_node=${TP}"

# Launch multi-process with torchrun
torchrun ${launch_config} examples/inference/quantization/text_generation_ptq.py ${options} ${additional_options}
torchrun ${launch_config} examples/export/ptq_and_trtllm_export/text_generation_ptq.py ${options} ${additional_options}
7 changes: 1 addition & 6 deletions examples/export/ptq_and_trtllm_export/text_generation_ptq.py
@@ -6,12 +6,11 @@
import sys
from pathlib import Path

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../")))

import modelopt.torch.quantization as mtq
import torch
from datasets import load_dataset
from modelopt.torch.utils.distributed import set_data_parallel_group, set_tensor_parallel_group
from tqdm import tqdm

# [ModelOpt]: changing the default model provider to the ModelOpt version
@@ -179,10 +178,6 @@ def hf_dataset_forword_loop_func(model):
if args.calib_dataset is not None:
ptq_forward_loop_func = hf_dataset_forword_loop_func

# Setting data parallel and tensor parallel group
set_data_parallel_group(mpu.get_data_parallel_group())
set_tensor_parallel_group(mpu.get_tensor_model_parallel_group())

if args.export_quant_cfg in QUANT_CFG_CHOICES:
mtq_config = QUANT_CFG_CHOICES[args.export_quant_cfg]
if "*output_layer*" not in mtq_config["quant_cfg"]:
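The surviving lines select a quantization config by name and make sure the output layer stays unquantized; a hedged sketch of that pattern follows (only `QUANT_CFG_CHOICES`, `args.export_quant_cfg`, and the `"*output_layer*"` key are visible in the diff; the specific configs and the disable value are assumptions):

```python
import argparse
import modelopt.torch.quantization as mtq

# Hedged sketch of the config-selection pattern in text_generation_ptq.py.
QUANT_CFG_CHOICES = {
    "fp8": mtq.FP8_DEFAULT_CFG,          # assumed mapping, not shown in the diff
    "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
}

parser = argparse.ArgumentParser()
parser.add_argument("--export-quant-cfg", choices=list(QUANT_CFG_CHOICES), default="fp8")
args = parser.parse_args()

mtq_config = QUANT_CFG_CHOICES[args.export_quant_cfg]
if "*output_layer*" not in mtq_config["quant_cfg"]:
    # Keep the LM head unquantized unless the config already covers it.
    mtq_config["quant_cfg"]["*output_layer*"] = {"enable": False}
```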
3 changes: 2 additions & 1 deletion megatron/core/models/gpt/gpt_model.py
@@ -91,10 +91,11 @@ def __init__(
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder

# These 2 attributes are needed for TensorRT-LLM export.
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling

if self.pre_process:
self.embedding = LanguageModelEmbedding(
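The four attributes stored a few lines above exist so an exporter can later reconstruct the rotary-embedding configuration from the model object; a hedged illustration of that read-back (the function and dict keys are illustrative, only the attribute names come from the diff):

```python
def collect_rotary_export_config(model) -> dict:
    """Hedged sketch: how a TensorRT-LLM export path might read the attributes above."""
    return {
        "max_position_embeddings": model.max_position_embeddings,
        "rotary_percent": model.rotary_percent,
        "rotary_base": model.rotary_base,
        "rotary_scaling": model.rotary_scaling,  # newly propagated by this commit
    }
```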
1 change: 1 addition & 0 deletions megatron/inference/gpt/model_provider.py
@@ -64,6 +64,7 @@ def model_provider(pre_process=True, post_process=True, parallel_output=True) ->
"position_embedding_type": args.position_embedding_type,
"rotary_percent": args.rotary_percent,
"rotary_base": args.rotary_base,
"rope_scaling": args.use_rope_scaling,
}

model = model_type(**model_kwargs)
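Taken together with the `gpt_model.py` change, the new kwarg lets the `--use-rope-scaling` flag flow end to end roughly as sketched below (a hedged sketch; only the names visible in this diff are real):

```python
def build_gpt_model(args, model_type):
    """Hedged sketch of the flag's end-to-end flow; details outside this diff are assumed."""
    model_kwargs = {
        "rotary_percent": args.rotary_percent,
        "rotary_base": args.rotary_base,        # e.g. 500000 for Llama 3 / 3.1
        "rope_scaling": args.use_rope_scaling,  # the new kwarg added above
    }
    # GPTModel stores the value as self.rotary_scaling (see gpt_model.py above).
    return model_type(**model_kwargs)
```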
