MaxText supports pretraining, finetuning, and decoding for the 7B and 70B flavors of Llama2. To get started with decoding and finetuning of Llama2, you will first need to download the weights and the tokenizer from Meta.
The file test_llama2_7b.sh shows how to convert the PyTorch weights to Orbax checkpoint format and then use the converted checkpoint for decoding and finetuning. It also shows how to run pretraining and how to run decoding on the finetuned model checkpoint.
Model FLOPs utilization (MFU) for training on v4, v5p, and v5e TPUs with MaxText:
| Model | v4-128 (bf16) | v5p-128 (bf16) | v5e-256 (bf16) |
|---|---|---|---|
| Llama2-70b | 57% | 65% | 57% |
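
For readers unfamiliar with the metric, the sketch below shows one common way to estimate MFU: achieved model FLOPs per second divided by the aggregate peak FLOPs of the hardware. It assumes the widely used approximation of 6 FLOPs per parameter per token for training (forward plus backward, ignoring attention FLOPs), which may differ from how MaxText computes the numbers in the table. The function names and all input values are hypothetical and chosen only to make the arithmetic easy to follow; they are not measurements.

```python
def training_flops_per_token(num_params: float) -> float:
    """Approximate training FLOPs per token as 6 * N.

    Assumption: forward + backward cost ~6 FLOPs per parameter per token,
    with attention FLOPs ignored.
    """
    return 6.0 * num_params


def model_flops_utilization(
    tokens_per_second: float,
    num_params: float,
    num_chips: int,
    peak_flops_per_chip: float,
) -> float:
    """Return MFU as a fraction: achieved model FLOPs/s over peak FLOPs/s."""
    achieved = tokens_per_second * training_flops_per_token(num_params)
    peak = num_chips * peak_flops_per_chip
    return achieved / peak


# Hypothetical round numbers, not real benchmark data:
# a 1B-parameter model processing 50,000 tokens/s on 10 chips,
# each with an assumed peak of 1e14 FLOPs/s.
mfu = model_flops_utilization(
    tokens_per_second=50_000,
    num_params=1e9,
    num_chips=10,
    peak_flops_per_chip=1e14,
)
print(f"MFU: {mfu:.1%}")
```

To apply this to a real run, substitute the measured token throughput, the model's parameter count, and the per-chip peak bf16 FLOPs from the hardware specification.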