From e66117076c9d8931270176f7d50f365c351092e8 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 27 Jul 2024 05:03:45 -0700 Subject: [PATCH] llama : add support for llama 3.1 rope scaling factors (#8676) * Add llama 3.1 rope scaling factors to llama conversion and inference This commit generates the rope factors on conversion and adds them to the resulting model as a tensor. At inference time, these factors are passed to the `ggml_rope_ext` rope oepration, improving results for context windows above 8192 * Update convert_hf_to_gguf.py Co-authored-by: compilade * address comments * address comments * Update src/llama.cpp Co-authored-by: compilade * Update convert_hf_to_gguf.py Co-authored-by: compilade --------- Co-authored-by: compilade --- convert_hf_to_gguf.py | 28 ++++++++++++++++++++++++++++ src/llama.cpp | 14 ++++++++++++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4087187c19834..8ba3c5844d22e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1570,6 +1570,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + super().prepare_tensors() if self._experts is not None: diff --git a/src/llama.cpp b/src/llama.cpp index e5713b0b04d0a..7bb2dfd4625b3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2451,6 +2451,7 @@ struct llama_layer { // long rope factors struct ggml_tensor * rope_long = nullptr; struct ggml_tensor * rope_short = nullptr; + struct ggml_tensor * rope_freqs = nullptr; // bitnet scale struct ggml_tensor * wq_scale; @@ -6060,6 +6061,8 @@ static bool llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + if (n_expert == 0) { layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); @@ -8537,6 +8540,10 @@ struct llm_build_context { // choose long/short freq factors based on the context size const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; + if (model.layers[il].rope_freqs != nullptr) { + return model.layers[il].rope_freqs; + } + if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) { return model.layers[il].rope_long; } @@ -8731,6 +8738,9 @@ struct llm_build_context { // self-attention { + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + // compute Q and K and RoPE them struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -8754,14 +8764,14 @@ struct llm_build_context { } Qcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); Kcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow );