# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# This file is adapted from language_model.py in Megatron-LM

+from typing import Literal, Optional
+
import torch
from torch import einsum, nn
from domino.arguments import get_args
from domino.tensor_parallel.partition import _initialize_affine_weight_gpu, set_tensor_model_parallel_attributes
from domino.tensor_parallel.partition import ColumnParallelLinear, RowParallelLinearNoComm

+from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from megatron.model.utils import get_norm
+
from deepspeed.runtime.domino.transformer import DominoTransformer

def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -45,12 +50,18 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(config, num_tokentypes,
                       encoder_attn_mask_type,
                       pre_process=True, post_process=True):
+    args = get_args()
    language_model = TransformerLanguageModel(
        config,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        pre_process=pre_process,
-        post_process=post_process
+        post_process=post_process,
+        position_embedding_type=args.position_embedding_type,
+        rotary_percent=args.rotary_percent,
+        rotary_base=args.rotary_base,
+        rope_scaling=args.use_rope_scaling,
+        seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
    )

    return language_model
@@ -85,37 +96,18 @@ def forward(self, input_ids, position_ids):
        return combined_embeds


-class RotaryEmbedding(nn.Module):
-    def __init__(self, dim, seq_len_interpolation_factor=None):
-        super().__init__()
-        self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq, persistent=False)
-
-    def forward(self, max_seq_len, offset=0):
-        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
-        if self.seq_len_interpolation_factor is not None:
-            seq = seq.type_as(self.inv_freq)
-            seq *= 1 / self.seq_len_interpolation_factor
-        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
-        # first part even vector components, second part odd vector components,
-        # 2 * dim in dimension size
-        emb = torch.cat((freqs, freqs), dim=-1)
-        # emb [seq_length, .., dim]
-        return emb[:, None, None, :]
-
-    # def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-    #     state_dict.pop(f'{prefix}inv_freq', None)
-    #     return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-
class TransformerLanguageModel(DominoModule):
    def __init__(self,
                 config,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 pre_process=True,
-                 post_process=True):
+                 post_process=True,
+                 position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
+                 rotary_percent: float = 1.0,
+                 rotary_base: int = 10000,
+                 rope_scaling: bool = False,
+                 seq_len_interpolation_factor: Optional[float] = None):

        args = get_args()
        super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=True)
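For orientation, here is a minimal, self-contained sketch of the frequency table that the deleted local RotaryEmbedding produced; the RotaryEmbedding now imported from megatron.core.models.common.embeddings is expected to build the same non-interleaved table while additionally honoring rotary_base, rotary_interleaved, and rope_scaling. The helper name rotary_freqs is illustrative and not part of either codebase.

```python
from typing import Optional

import torch


def rotary_freqs(max_seq_len: int, dim: int, base: float = 10000.0,
                 seq_len_interpolation_factor: Optional[float] = None) -> torch.Tensor:
    # Inverse frequencies 1 / base^(2i/dim), as in the deleted class.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    seq = torch.arange(max_seq_len, dtype=inv_freq.dtype)
    if seq_len_interpolation_factor is not None:
        # Position interpolation: squeeze positions to cover longer contexts.
        seq = seq / seq_len_interpolation_factor
    freqs = torch.outer(seq, inv_freq)       # [max_seq_len, dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [max_seq_len, dim]
    return emb[:, None, None, :]             # [max_seq_len, 1, 1, dim]
```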
@@ -127,6 +119,11 @@ def __init__(self,
        self.init_method = config.init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.encoder_hidden_state = None
+        self.position_embedding_type = position_embedding_type
+        self.rotary_percent = rotary_percent
+        self.rotary_base = rotary_base
+        self.rotary_scaling = rope_scaling
+        self.seq_length = config.seq_length

        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
@@ -138,19 +135,18 @@ def __init__(self,
        self.use_rotary_position_embeddings = \
            args.position_embedding_type == 'rope'
        if self.use_rotary_position_embeddings:
-            self.seq_length = args.seq_length
-            rotary_dim = args.hidden_size // args.num_attention_heads \
-                if args.kv_channels is None else args.kv_channels
-            if args.rotary_percent < 1.0:
-                rotary_dim = int(rotary_dim * args.rotary_percent)
            self.rotary_pos_emb = RotaryEmbedding(
-                rotary_dim,
-                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
+                kv_channels=config.kv_channels,
+                rotary_percent=rotary_percent,
+                rotary_interleaved=config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+                rotary_base=rotary_base,
+                rope_scaling=rope_scaling,
            )

        self.encoder = DominoTransformer(
            config, ModelType.encoder_or_decoder, mpu,
-            fused_layer_norm, _initialize_affine_weight_gpu,
+            get_norm, _initialize_affine_weight_gpu,
            ColumnParallelLinear, RowParallelLinearNoComm, apply_rotary_pos_emb,
            bias_dropout_add_fused_train, bias_dropout_add_fused_inference,
            self_attn_mask_type=self.encoder_attn_mask_type,
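The apply_rotary_pos_emb callable handed to DominoTransformer is not part of this diff. As a hedged sketch only, the standard non-interleaved rotation it is assumed to perform (consistent with the cat((freqs, freqs)) layout sketched earlier, and with rotary_percent < 1.0 leaving the trailing channels untouched) looks roughly like this; rotate_half and apply_rope are illustrative names.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # t: [seq, batch, heads, head_dim]; freqs: [seq, 1, 1, rot_dim].
    # Only the first rot_dim channels are rotated; the remainder (present
    # when rotary_percent < 1.0) passes through unchanged.
    rot_dim = freqs.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = t_rot * freqs.cos() + rotate_half(t_rot) * freqs.sin()
    return torch.cat((t_rot, t_pass), dim=-1)
```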