This repository implements BanditSpec, a speculative decoding framework that adaptively balances exploration and exploitation using bandit algorithms to accelerate autoregressive generation in large language models (LLMs). The framework is compatible with both LLaMA and Qwen2 architectures.
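The core idea, treating each candidate speculation length `gamma` as a bandit arm and trading off exploration against exploitation, can be sketched with a simple UCB rule. The candidate `gamma` values, the reward definition (e.g. accepted draft tokens per step), and the function name below are illustrative assumptions, not the repository's actual algorithm:

```python
import math

def ucb_select(counts, rewards, t, c=2.0):
    """Pick the arm (candidate gamma) with the highest UCB score.

    counts[i]  - times arm i was played so far
    rewards[i] - cumulative reward for arm i (e.g. accepted tokens per step)
    t          - current round (1-indexed)
    """
    # Play each arm once before applying the UCB formula.
    for i, n in enumerate(counts):
        if n == 0:
            return i
    scores = [
        rewards[i] / counts[i] + math.sqrt(c * math.log(t) / counts[i])
        for i in range(len(counts))
    ]
    return max(range(len(scores)), key=scores.__getitem__)

# Illustrative candidate speculation lengths (arms).
gammas = [1, 2, 4, 8]
counts = [0] * len(gammas)
rewards = [0.0] * len(gammas)
```

After each decoding step the chosen arm's count and reward are updated, so arms that yield more accepted draft tokens get pulled more often while under-explored arms retain a bonus.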
- `eagle_llama.py`: Defines the Eagle (Li Y. et al., 2024) draft model based on LLaMA.
- `eagle_qwen.py`: Defines the Eagle (Li Y. et al., 2024) draft model based on Qwen2.
- `llama.py`, `qwen.py`: Customized versions of the LLaMA and Qwen2 architectures.
- `generate_utils.py`: Implements the core decoding strategies, including BanditSpec.
- `inference_length.py`: Main script for throughput benchmarking across different batch sizes and strategies.
- `llama_long.png`: Visualization of throughput improvement comparisons.
```shell
pip install torch transformers fairscale flash-attn tqdm
```
⚠️ Make sure `flash-attn` is compiled for your CUDA and PyTorch versions.
Download the EAGLE draft models from the official repository: https://github.com/SafeAILab/EAGLE
project/
├── inference_length.py
├── eagle_llama.py
├── eagle_qwen.py
├── llama.py
├── qwen.py
├── generate_utils.py
├── llama_long.png
├── llama_model/ # contains config.json and pytorch_model.bin for LLaMA
└── eagle_model/ # contains config.json and pytorch_model.bin for Eagle
Modify `inference_length.py` to set:

```python
target_path = "llama_model"
eagle_path = "eagle_model"
```

Then run:

```shell
python inference_length.py
```

This will run decoding experiments across:
- Different batch sizes
- Various `gamma` values
- Baselines such as `Best Arm`, `Worst Arm`, and fixed `gamma`
```
bsz  spec_quota  gamma       throughput
10   256         BanditSpec  1.43
20   256         gamma=1     1.61
...
```
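To post-process such logs, a small parser can turn the whitespace-delimited rows into dictionaries. The parser is an illustrative sketch, not shipped with the repository; column names follow the sample output above:

```python
def parse_results(text):
    """Parse the whitespace-delimited benchmark log into a list of dicts."""
    lines = [ln for ln in text.strip().splitlines()
             if ln.strip() and ln.strip() != "..."]
    header = lines[0].split()
    rows = []
    for line in lines[1:]:
        row = dict(zip(header, line.split()))
        row["bsz"] = int(row["bsz"])
        row["throughput"] = float(row["throughput"])
        rows.append(row)
    return rows

sample = """\
bsz spec_quota gamma throughput
10 256 BanditSpec 1.43
20 256 gamma=1 1.61
...
"""
results = parse_results(sample)
```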
Li Y., Wei F., Zhang C., et al. EAGLE: Speculative sampling requires rethinking feature uncertainty. arXiv preprint arXiv:2401.15077, 2024.