
MagicDec: Breaking Throughput-Latency Trade-off for Long Context Generation with Speculative Decoding

Jian Chen*1, Vashisth Tiwari*1, Ranajoy Sadhukhan*1, Zhuoming Chen1, Jinyuan Shi2, Ian En-Hsu Yen2, Beidi Chen1,3
1Carnegie Mellon University 2Moffett AI 3Meta AI (FAIR)
[Paper] | [Blog]

Update

  • Supports flashinfer and paged attention to further accelerate inference.
  • Supports SnapKV-based drafting for higher speculation quality.
  • Supports Qwen2.5-[7B,14B,32B], Yi-1.5-[6B,34B], and Mistral-7B-v0.1.
  • Please make sure your PyTorch version is at least 2.5 to use new features such as the custom all-reduce.

Installation

Environment Set Up

conda create -n magicdec python=3.11
conda activate magicdec
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
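
To sanity-check the environment (an optional verification, not part of the original setup instructions), confirm that PyTorch is 2.5 or later with CUDA available and that flashinfer imports cleanly:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import flashinfer"  # should exit without an ImportError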

Prepare Checkpoints

Currently, we support Llama-2-7b and its long-context variant Llama-2-7b-32k, Llama-2-13b, Llama-2-70b, Llama-3-8b, Llama-3-70b, llama-68m, TinyLlama, Llama-3.1-8b, Llama-3.1-70b, Llama-3.2-1b, Qwen2.5-[7B,14B,32B], Yi-1.5-[6B,34B], and Mistral-7B-v0.1.

First, download the checkpoints you need with download.py. Set --repo_id to the Hugging Face repository ID to download from, --hf_token to your Hugging Face API token, and --out_dir to the directory where the checkpoint should be saved.

python download.py --repo_id meta-llama/Meta-Llama-3.1-8B --hf_token "YOUR HUGGINGFACE API TOKEN" --out_dir checkpoints/meta-llama/Meta-Llama-3.1-8B

Then convert the downloaded checkpoint with convert_hf_checkpoint.py. Set --checkpoint_dir to the directory where the checkpoint was just saved. The script will generate a new model.pth file in that directory.

python convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
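
As an optional check (not part of the original instructions), verify that the conversion produced a model.pth in the checkpoint directory:

ls -lh checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth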

Evaluations

We conducted all experiments in the paper on 8xA100, 8xH100, and 8xL40 GPUs, using PG-19 as the dataset throughout.

Baseline

We used the new one-shot and two-shot all-reduce introduced in PyTorch 2.5 by setting ENABLE_INTRA_NODE_COMM=1. The benchmark flags are:

  • --nproc_per_node: the number of GPUs used for tensor parallelism.
  • --model: the path to the model.pth checkpoint we want to serve.
  • --model_name: the repository ID of the checkpoint, used to load the tokenizer.
  • --rank_group: the list of GPU IDs used for tensor parallelism.
  • --B: the batch size.
  • --prefix_len: the prefill length.
  • --max_len: the maximum number of tokens to generate for each sequence.
  • --printoutput: print the output after generation finishes.
  • --compile: use torch.compile to accelerate generation.

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/baseline_benchmark.py --model checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --B 1 --prefix_len 3969 --max_len 4096 --printoutput --compile
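
The flags move together: --nproc_per_node must match the size of --rank_group. For instance, a hypothetical 4-GPU run with a larger batch and a longer prefill (illustrative numbers, not results from the paper) would look like:

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=4 tests/baseline_benchmark.py --model checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 --B 32 --prefix_len 8000 --max_len 8096 --compile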

Standalone Draft

For the standalone-draft experiments, --target and --model set the target and draft checkpoints, respectively. --model_name should be set to the repository ID of the target model, which is used to load the corresponding tokenizer. --rank_group lists the GPU IDs used for tensor parallelism of the target model, and --draft_rank_group lists the GPU IDs used for tensor parallelism of the draft model. --draft_budget sets the KV budget for the draft model. Set --draft_budget to -1 in tests/StreamingLLM/longspec_benchmark.py to disable KV compression for the draft model (i.e., use the full KV cache, as in original speculative decoding); an illustrative command is shown after the StreamingLLM example below.

SnapKV-based Drafting

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/SnapKV/longspec_benchmark.py --target checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model checkpoints/meta-llama/Llama-3.2-1B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --draft_rank_group 0 1 2 3 --gamma 3 --B 64 --prefix_len 16032 --max_len 16128 --draft_budget 257 --benchmark --compile

StreamingLLM-based Drafting

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/StreamingLLM/longspec_benchmark.py --target checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model checkpoints/meta-llama/Llama-3.2-1B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --draft_rank_group 0 1 2 3 --gamma 3 --B 64 --prefix_len 16032 --max_len 16128 --draft_budget 257 --benchmark --compile
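
To run the original speculative decoding without draft KV compression, the same StreamingLLM command can be reused with --draft_budget -1; this is a sketch under that assumption, keeping every other flag unchanged:

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/StreamingLLM/longspec_benchmark.py --target checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model checkpoints/meta-llama/Llama-3.2-1B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --draft_rank_group 0 1 2 3 --gamma 3 --B 64 --prefix_len 16032 --max_len 16128 --draft_budget -1 --benchmark --compile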

Self-Speculation

Self-speculation is similar to the standalone-draft setup, but there is no separate draft model to configure: the target model drafts for itself.

SnapKV-based Drafting

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/SnapKV/selfspec_benchmark.py --model checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --gamma 3 --B 64 --prefix_len 16032 --max_len 16128 --draft_budget 257 --benchmark --compile

StreamingLLM-based Drafting

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 tests/StreamingLLM/selfspec_benchmark.py --model checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --model_name meta-llama/Meta-Llama-3.1-8B --rank_group 0 1 2 3 4 5 6 7 --gamma 3 --B 64 --prefix_len 16032 --gen_len 16128 --draft_budget 257 --benchmark --compile

Citation

If you find MagicDec useful or relevant to your project and research, please kindly cite our paper:

@article{chen2024magicdec,
  title={MagicDec: Breaking the Latency-Throughput Tradeoff for Long Context Generation with Speculative Decoding},
  author={Chen, Jian and Tiwari, Vashisth and Sadhukhan, Ranajoy and Chen, Zhuoming and Shi, Jinyuan and Yen, Ian En-Hsu and Chen, Beidi},
  journal={arXiv preprint arXiv:2408.11049},
  year={2024}
}
