From a67466e1ade2529713498d2fa55793660bcc6bc7 Mon Sep 17 00:00:00 2001
From: Martin Courtois
Date: Thu, 12 Oct 2023 20:15:23 +0200
Subject: [PATCH] fix: rotary position embedding missing argument

---
 megatron/model/language_model.py | 1 +
 megatron/model/transformer.py    | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
index 5569f17347..0d544b2cd5 100644
--- a/megatron/model/language_model.py
+++ b/megatron/model/language_model.py
@@ -374,6 +374,7 @@ def __init__(self,
             # https://github.com/kingoflolz/mesh-transformer-jax/
             self.rotary_pos_emb = RotaryEmbedding(
                 rotary_dim,
+                rotary_percent=args.rotary_percent,
                 seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
             )
 
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index bc15671752..71337c818f 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -15,7 +15,7 @@
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_norm
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region_to_moe, reduce_scatter_to_sequence_parallel_region_from_moe
 from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_and_expert_parallel_group
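
For context, the `rotary_percent` argument that the patch forwards controls how much of each attention head's dimension the rotary embedding actually rotates (partial RoPE). The snippet below is a minimal standalone sketch of that partial-rotation idea in plain PyTorch; it is not the Megatron-LM implementation, and the names `rotary_frequencies` and `apply_partial_rotary` are illustrative only.

```python
# Standalone sketch of partial rotary position embedding (not the Megatron-LM code).
# `rotary_percent` mirrors the argument forwarded in the patch: only that fraction
# of each head's channels is rotated; the rest pass through unchanged.
import torch


def rotary_frequencies(seq_len: int, rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Rotation angles of shape [seq_len, rotary_dim] (classic RoPE formulation)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)      # [seq_len, rotary_dim // 2]
    return torch.cat((freqs, freqs), dim=-1)      # [seq_len, rotary_dim]


def apply_partial_rotary(x: torch.Tensor, rotary_percent: float = 1.0) -> torch.Tensor:
    """Rotate only the first `rotary_percent` fraction of the last dimension of x.

    x is laid out [seq_len, ..., head_dim]; rotary_dim is assumed to come out even.
    """
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * rotary_percent)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]

    freqs = rotary_frequencies(x.shape[0], rotary_dim)
    freqs = freqs.view(x.shape[0], *([1] * (x.dim() - 2)), rotary_dim)  # broadcast over batch/heads

    half = rotary_dim // 2
    rotate_half = torch.cat((-x_rot[..., half:], x_rot[..., :half]), dim=-1)
    x_rot = x_rot * freqs.cos() + rotate_half * freqs.sin()
    return torch.cat((x_rot, x_pass), dim=-1)


if __name__ == "__main__":
    # [seq_len, batch, num_heads, head_dim] -- the layout Megatron-style attention uses.
    q = torch.randn(8, 2, 4, 64)
    print(apply_partial_rotary(q, rotary_percent=0.5).shape)  # torch.Size([8, 2, 4, 64])
```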