Llama 3.1 RoPE scaling (#205)
* feat(llama): import RoPE scaling code

This is imported from the original Llama reference implementation:
https://github.com/meta-llama/llama-models/blob/7890266c5a3ccd29e739d53a71ea968bcf4ca400/models/llama3/reference_impl/model.py#L45

Note that the function has no effect on the original model code as long as
the use_scaled parameter is false (the default).

* feat(llama): add RopeScalingArgs

These are aligned with the HF ones, making it easier to implement RoPE
scaling as it is done in Llama 3.1 (a rough mapping sketch follows this commit message).

* feat(llama): support rope scaling arguments to improve flexibility

* chore: relax safetensors pattern on download

* feat: untie weights when needed (e.g. Llama 3.2-1B)

* feat: add support for Llama 3.1, 3.2, and 3.3 models
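
For reference, the new RopeScalingArgs fields mirror the rope_scaling block that Hugging Face Llama 3.1 checkpoints carry in config.json. The sketch below is illustrative only (the hf_rope dict is hand-written here, not read from an actual config file); it shows how such a block would map onto the dataclass added in this commit:

from jetstream_pt.third_party.llama.model_args import RopeScalingArgs

# Example rope_scaling block as published for Llama 3.1 (illustrative values).
hf_rope = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3",
}

rope_scaling_args = RopeScalingArgs(
    factor=hf_rope["factor"],
    low_freq_factor=hf_rope["low_freq_factor"],
    high_freq_factor=hf_rope["high_freq_factor"],
    original_max_position_embeddings=hf_rope["original_max_position_embeddings"],
)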
tengomucho authored Dec 16, 2024
1 parent 08e4977 commit bb174b6
Showing 5 changed files with 126 additions and 5 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -67,6 +67,12 @@ meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
meta-llama/Llama-3.1-8B
meta-llama/Llama-3.1-8B-Instruct
meta-llama/Llama-3.2-1B
meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.3-70B
meta-llama/Llama-3.3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
3 changes: 3 additions & 0 deletions jetstream_pt/cli.py
@@ -44,6 +44,9 @@
def shard_weights(env, weights, weight_shardings):
  """Shard weights according to weight_shardings"""
  sharded = {}
  # Some output and embedding weights might be tied; in this case, untie them.
  if weights["output.weight"].device.type == "meta":
    weights["output.weight"] = weights["tok_embeddings.weight"].clone()
  for key, val in weights.items():
    sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
    with jax.default_device(jax.devices("cpu")[0]):
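
Background on the untying above: small checkpoints such as Llama 3.2-1B tie the output projection to the token embedding, so only tok_embeddings.weight is stored and the untouched output.weight parameter remains a meta tensor after loading. A minimal standalone sketch of the same untying step, with toy tensor shapes and outside the jetstream_pt loading path:

import torch

# Hypothetical weight dict after loading a tied checkpoint: the embedding is
# materialized, while the tied output projection is still an empty meta tensor.
weights = {
    "tok_embeddings.weight": torch.randn(8, 4),
    "output.weight": torch.empty(8, 4, device="meta"),
}

# Untie by cloning the embedding into a real tensor, as shard_weights now does.
if weights["output.weight"].device.type == "meta":
    weights["output.weight"] = weights["tok_embeddings.weight"].clone()

assert weights["output.weight"].device.type == "cpu"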
11 changes: 10 additions & 1 deletion jetstream_pt/fetch_models.py
@@ -60,6 +60,9 @@ class ModelInfo:
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
_llama3_70 = _llama2_70
_llama3_1_8b = _llama3_8
_llama3_2_1b = ModelInfo(llama_model.Transformer, 16, 8, 64, 4)
_llama3_3_70b = _llama2_70

_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)

@@ -78,6 +81,12 @@ class ModelInfo:
"meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
"meta-llama/Meta-Llama-3-70B": _llama3_70,
"meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
"meta-llama/Llama-3.1-8B": _llama3_1_8b,
"meta-llama/Llama-3.1-8B-Instruct": _llama3_1_8b,
"meta-llama/Llama-3.2-1B": _llama3_2_1b,
"meta-llama/Llama-3.2-1B-Instruct": _llama3_2_1b,
"meta-llama/Llama-3.3-70B": _llama3_3_70b,
"meta-llama/Llama-3.3-70B-Instruct": _llama3_3_70b,
"google/gemma-2b": _gemma_2b,
"google/gemma-2b-it": _gemma_2b,
"google/gemma-7b": _gemma_7b,
@@ -215,7 +224,7 @@ def _hf_download(
      local_dir_use_symlinks=False,
      token=hf_token,
      allow_patterns=[
-         "model-?????-of-?????.safetensors",
+         "model*.safetensors",
          "*.json",
          "*.model",
65 changes: 65 additions & 0 deletions jetstream_pt/third_party/llama/model_args.py
@@ -5,6 +5,16 @@
from typing import Optional


@dataclasses.dataclass
class RopeScalingArgs:
  """Rope scaling configuration parameters."""

  factor: float = 8.0
  low_freq_factor: float = 1.0
  high_freq_factor: float = 4.0
  original_max_position_embeddings: int = 8192


@dataclasses.dataclass
class ModelArgs:
  """Model configuration parameters."""
@@ -29,6 +39,7 @@ class ModelArgs:
  device = "cpu"

  rope_theta: float = 10000.0
  rope_scaling_args: RopeScalingArgs = None


def get_arg(
@@ -103,6 +114,60 @@ def get_arg(
        "vocab_size": 128256,
        "rope_theta": 500000.0,
    }
  elif model_name == "llama-3.1-8b":
    data = {
        "dim": 4096,
        "vocab_size": 128256,
        "multiple_of": 1024,
        "ffn_dim_multiplier": 1.3,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,
        "norm_eps": 1e-05,
        "rope_theta": 500000.0,
        "rope_scaling_args": RopeScalingArgs(
            factor=8.0,
            low_freq_factor=1.0,
            high_freq_factor=4.0,
            original_max_position_embeddings=8192,
        ),
    }
  elif model_name == "llama-3.2-1b":
    data = {
        "dim": 2048,
        "vocab_size": 128256,
        "multiple_of": 1024,
        "ffn_dim_multiplier": 1.5,
        "n_layers": 16,
        "n_heads": 32,
        "n_kv_heads": 8,
        "norm_eps": 1e-05,
        "rope_theta": 500000.0,
        "rope_scaling_args": RopeScalingArgs(
            factor=32.0,
            low_freq_factor=1.0,
            high_freq_factor=4.0,
            original_max_position_embeddings=8192,
        ),
    }
  elif model_name == "llama-3.3-70b":
    data = {
        "dim": 8192,
        "vocab_size": 128256,
        "multiple_of": 1024,
        "ffn_dim_multiplier": 1.3,
        "n_layers": 80,
        "n_heads": 64,
        "n_kv_heads": 8,
        "norm_eps": 1e-05,
        "rope_theta": 500000.0,
        "rope_scaling_args": RopeScalingArgs(
            factor=8.0,
            low_freq_factor=1.0,
            high_freq_factor=4.0,
            original_max_position_embeddings=8192,
        ),
    }

  return ModelArgs(
      max_seq_len=seqlen,
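
As a usage sketch (hypothetical in its details: it assumes ModelArgs exposes fields matching the dictionary keys above, the way the Meta reference ModelArgs is laid out), the scaling configuration travels inside ModelArgs roughly like this:

from jetstream_pt.third_party.llama import model_args

# Roughly the Llama 3.2-1B configuration, with long-context RoPE scaling attached.
args = model_args.ModelArgs(
    dim=2048,
    n_layers=16,
    n_heads=32,
    n_kv_heads=8,
    vocab_size=128256,
    norm_eps=1e-05,
    rope_theta=500000.0,
    rope_scaling_args=model_args.RopeScalingArgs(factor=32.0),
)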
46 changes: 42 additions & 4 deletions jetstream_pt/third_party/llama/model_exportable.py
@@ -4,6 +4,7 @@
from typing import Any, List, Optional
import copy
import jax
import math
import torch
import torch.nn.functional as F
import functools
@@ -170,12 +171,42 @@ def forward(
    return out


def apply_scaling(freqs: torch.Tensor, config: model_args.RopeScalingArgs):
  # Values obtained from grid search
  scale_factor = config.factor
  low_freq_factor = config.low_freq_factor
  high_freq_factor = config.high_freq_factor
  old_context_len = config.original_max_position_embeddings

  low_freq_wavelen = old_context_len / low_freq_factor
  high_freq_wavelen = old_context_len / high_freq_factor
  new_freqs = []
  for freq in freqs:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:
      new_freqs.append(freq)
    elif wavelen > low_freq_wavelen:
      new_freqs.append(freq / scale_factor)
    else:
      assert low_freq_wavelen != high_freq_wavelen
      smooth = (old_context_len / wavelen - low_freq_factor) / (
          high_freq_factor - low_freq_factor
      )
      new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
  return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
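
For intuition about the bands in apply_scaling, here is a small self-contained sketch (not part of the diff) that evaluates the same thresholds with the Llama 3.1 defaults: wavelengths shorter than 2048 positions are kept, wavelengths longer than 8192 positions are divided by the scale factor, and the band in between is smoothly interpolated.

import math

old_context_len = 8192.0
low_freq_factor, high_freq_factor, scale_factor = 1.0, 4.0, 8.0
low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0
high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0

for wavelen in (100.0, 4096.0, 20000.0):
    if wavelen < high_freq_wavelen:
        band = "high-frequency band: kept as-is"
    elif wavelen > low_freq_wavelen:
        band = "low-frequency band: divided by scale_factor"
    else:
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        band = f"middle band: interpolated with smooth={smooth:.2f}"
    print(f"wavelength {wavelen:>8.1f} -> {band}")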


 def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0
-) -> torch.Tensor:
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_scaling_config: model_args.RopeScalingArgs = None,
+):
   freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-  t = torch.arange(end, device=freqs.device)  # type: ignore
-  freqs = torch.outer(t, freqs).float()  # type: ignore
+  t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+  if rope_scaling_config is not None:
+    freqs = apply_scaling(freqs, rope_scaling_config)
+  freqs = torch.outer(t, freqs)
   freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
   return freqs_cis

@@ -223,6 +254,7 @@ def __init__(
        self.params.dim // self.params.n_heads,
        self.params.max_seq_len * 2,
        theta=self.params.rope_theta,
        rope_scaling_config=self.params.rope_scaling_args,
    )

    self.register_buffer("freqs_cis", freqs_cis)
@@ -306,6 +338,12 @@ def from_hf_model_id(cls, model_id, env, is_tiny=False):
"meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
"meta-llama/Meta-Llama-3-70B": "llama-3-70b",
"meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
"meta-llama/Llama-3.1-8B": "llama-3.1-8b",
"meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-8b",
"meta-llama/Llama-3.2-1B": "llama-3.2-1b",
"meta-llama/Llama-3.2-1B-Instruct": "llama-3.2-1b",
"meta-llama/Llama-3.3-70B": "llama-3.3-70b",
"meta-llama/Llama-3.3-70B-Instruct": "llama-3.3-70b",
}.get(model_id)
assert name
args = model_args.get_model_args(
