diff --git a/README.md b/README.md
index c462b72..33d7216 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,8 @@
 mistralai/Mistral-7B-v0.1
 mistralai/Mistral-7B-Instruct-v0.1
 mistralai/Mistral-7B-Instruct-v0.2
 meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3.1-8B
+meta-llama/Meta-Llama-3.1-70B
 meta-llama/Meta-Llama-3.1-405B
 ```
@@ -93,8 +95,10 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | Llama-2-70B | Base | OOM ||
 | | 8-bit | 19.13 | 1322.58 |
 | | 4-bit (G=32) | 25.25 | 1097.66 |
-| Llama-3-8B | Base | 94.25 | 1411.95 |
-| | 8-bit | 139.55 | 1047.23 |
+| Llama-3.1-8B | Base | 93.89 | 1410.76 |
+| | 8-bit | 137.64 | 1030.89 |
+| Llama-3.1-70B | Base | OOM ||
+| | 8-bit | 18.04 | 1253.78 |
 
 ### Speculative Sampling
 [Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s
@@ -110,10 +114,14 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | | 2 | 21.32 | 1481.87 |
 | | 4 | 38.01 | 1340.76 |
 | | 8 | 62.50 | 1135.29 |
-| Llama-3-8B | 1 | 94.19 | 1411.76 |
-| | 2 | 150.48 | 1208.80 |
-| | 4 | 219.77 | 991.63 |
-| | 8 | 274.65 | 768.55 |
+| Llama-3.1-8B | 1 | 93.83 | 1408.37 |
+| | 2 | 149.10 | 1197.32 |
+| | 4 | 217.21 | 986.32 |
+| | 8 | 276.01 | 772.60 |
+| Llama-3.1-70B | 1 | OOM | |
+| | 2 | 16.03 | 1130.81 |
+| | 4 | 37.45 | 1360.53 |
+| | 8 | 58.78 | 1129.61 |
 
 ### Tensor Parallelism + Quantization
 | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
@@ -121,6 +129,8 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | Llama-2-70B | Base | 62.50 | 1135.29 |
 | | 8-bit | 80.44 | 752.04 |
 | | 4-bit (G=32) | 90.77 | 548.10 |
+| Llama-3.1-70B | Base | 58.78 | 1129.61 |
+| | 8-bit | 75.58 | 726.57 |
 | Llama-3.1-405B | 8-bit | 15.60 | 815.87 |
 
 ### AMD
diff --git a/model.py b/model.py
index 6799206..c265799 100644
--- a/model.py
+++ b/model.py
@@ -70,6 +70,12 @@ def from_name(cls, name: str):
     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-8b": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
+    "llama-3.1-70b": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
     "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
         rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
     ),
 }
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index f08eaba..02a8f04 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -116,7 +116,7 @@ def permute(w, n_head):
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
     if 'llama-3' in model_name.lower():
-        if 'llama-3.1' in model_name.lower():
+        if 'llama-3.1-405b' in model_name.lower():
             original_dir = checkpoint_dir / "original" / "mp16"
         else:
             original_dir = checkpoint_dir / "original"
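
Note for reviewers: the `rope_scaling` parameters added to the `llama-3.1-8b` and `llama-3.1-70b` configs are the same ones the `llama-3.1-405b` entry already carries. For reference, below is a minimal sketch of the "frequency smoothing" scheme those parameters describe, following Meta's published Llama 3.1 reference implementation; the function name and standalone wiring are illustrative, not necessarily the exact code path in this repo's model.py.

```python
# Illustrative sketch of how the rope_scaling dict above is conventionally
# applied to RoPE inverse frequencies (per Meta's Llama 3.1 reference code).
import math
import torch

def apply_rope_scaling_sketch(freqs: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    factor = rope_scaling["factor"]                      # 8.0
    low_freq_factor = rope_scaling["low_freq_factor"]    # 1.0
    high_freq_factor = rope_scaling["high_freq_factor"]  # 4.0
    old_context_len = rope_scaling["original_max_position_embeddings"]  # 8192

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band (short wavelengths): left untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band (long wavelengths): stretched by `factor`.
            new_freqs.append(freq / factor)
        else:
            # Transition band: interpolate smoothly between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

# Example: inverse frequencies for head_dim=128 at rope_base=500000, then scaled.
# inv_freq = 1.0 / (500000 ** (torch.arange(0, 128, 2).float() / 128))
# scaled = apply_rope_scaling_sketch(inv_freq, dict(factor=8.0, low_freq_factor=1.0,
#     high_freq_factor=4.0, original_max_position_embeddings=8192))
```

The net effect is that long-wavelength rotary bands are stretched by `factor` so a model pretrained at 8192 positions can address the new 131072-token `block_size`, while the short-wavelength bands that encode local order stay intact.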
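
On the `convert_hf_checkpoint.py` change: narrowing the match from `llama-3.1` to `llama-3.1-405b` matters because (an assumption worth double-checking against your downloads) only the 405B Hugging Face repo nests the original Meta weights under `original/mp16` (its 16-way model-parallel shards), while the 3.1 8B and 70B repos keep them directly under `original/`. The old check therefore pointed 8B/70B conversion at a directory that doesn't exist. A hypothetical standalone sketch of the resulting selection logic:

```python
# Hypothetical illustration of the checkpoint-path selection after this fix;
# directory-layout assumptions are stated above, not taken from this repo.
from pathlib import Path

def original_weights_dir(checkpoint_dir: Path, model_name: str) -> Path:
    name = model_name.lower()
    if 'llama-3' in name:
        if 'llama-3.1-405b' in name:
            # Only the 405B download nests its original weights one level deeper.
            return checkpoint_dir / "original" / "mp16"
        return checkpoint_dir / "original"
    return checkpoint_dir

# e.g. original_weights_dir(Path("checkpoints/meta-llama/Meta-Llama-3.1-8B"),
#     "Meta-Llama-3.1-8B") -> checkpoints/meta-llama/Meta-Llama-3.1-8B/original
```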