diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 5569f17347..4cbdd2eef5 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -366,14 +366,12 @@ def __init__(self, rotary_dim = args.hidden_size // args.num_attention_heads \ if args.kv_channels is None else args.kv_channels - if args.rotary_percent < 1.0: - rotary_dim = int(rotary_dim * args.rotary_percent) - # partial rotary embeddings, which is better than full rotary # Wang and Komatsuzaki et al # https://github.com/kingoflolz/mesh-transformer-jax/ self.rotary_pos_emb = RotaryEmbedding( rotary_dim, + args.rotary_percent, seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor )