Skip to content

Commit

Permalink
feat(llama): support rope scaling arguments to improve flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
tengomucho committed Dec 13, 2024
1 parent fcffabe commit 3d960ad
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions jetstream_pt/third_party/llama/model_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ def forward(
return out


def apply_scaling(freqs: torch.Tensor):
def apply_scaling(freqs: torch.Tensor, config: model_args.RopeScalingArgs):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
scale_factor = config.factor
low_freq_factor = config.low_freq_factor
high_freq_factor = config.high_freq_factor
old_context_len = config.original_max_position_embeddings

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
Expand All @@ -197,12 +197,15 @@ def apply_scaling(freqs: torch.Tensor):


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    rope_scaling_config: "model_args.RopeScalingArgs | None" = None,
):
    """Precompute the complex rotary-embedding (RoPE) frequency table.

    Args:
        dim: per-head embedding dimension; only ``dim // 2`` frequencies
            are generated (one per pair of channels).
        end: number of positions to precompute (the maximum sequence length).
        theta: RoPE base; larger values slow the frequency decay.
        rope_scaling_config: optional rope-scaling parameters (factor,
            low/high frequency factors, original context length). When
            ``None`` no scaling is applied — this preserves the behavior
            of the previous ``use_scaled=False`` default.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` whose entries are
        unit-magnitude rotations ``exp(i * t * freq)``.
    """
    # Base inverse frequencies: theta^(-2k/dim) for k = 0 .. dim//2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    if rope_scaling_config is not None:
        # Rescale frequencies for long-context models (e.g. Llama 3.1).
        freqs = apply_scaling(freqs, rope_scaling_config)
    # Outer product yields the rotation angle for every (position, frequency).
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
Expand Down Expand Up @@ -251,6 +254,7 @@ def __init__(
self.params.dim // self.params.n_heads,
self.params.max_seq_len * 2,
theta=self.params.rope_theta,
rope_scaling_config=self.params.rope_scaling_args,
)

self.register_buffer("freqs_cis", freqs_cis)
Expand Down

0 comments on commit 3d960ad

Please sign in to comment.