def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
-    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+    sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
+    out = torch.cat((sin, cos), dim=1)
+    return out

# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
+    # Keep the interleaved pairing: apply_rotary_pos_emb repeat-interleaves sin/cos below,
+    # so a plain concat of (-x2, x1) would pair the wrong dimensions.
+    return torch.stack((-x2, x1), dim=-1).flatten(-2)  # rearrange(x, '... d j -> ... (d j)')

# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
-    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
-    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+    sin = torch.repeat_interleave(sin, 2, dim=-1)
+    cos = torch.repeat_interleave(cos, 2, dim=-1)
    return (tensor * cos) + (rotate_every_two(tensor) * sin)

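A standalone shape check of the helpers above under the new calling convention (toy sizes; the names here are local to this sketch, and sin/cos already carry the broadcast head axis that the caller now inserts via unsqueeze(2)):

import torch

num_pos, rotary_dim, n_head = 8, 16, 4
table = create_sinusoidal_positions(num_pos, rotary_dim)   # (num_pos, rotary_dim): sin half then cos half
sin, cos = torch.split(table, rotary_dim // 2, dim=-1)
sin, cos = sin[None, :, None, :], cos[None, :, None, :]    # (1, num_pos, 1, rotary_dim // 2)
q = torch.randn(1, num_pos, n_head, rotary_dim)            # (B, T, n_head, rotary_dim)
assert apply_rotary_pos_emb(q, sin, cos).shape == q.shape
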
@@ -87,25 +88,24 @@ def __init__(self, config, layer_idx=None):
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

        self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
+        pos_embd_dim = self.rotary_dim or self.head_dim
+        # `embed_positions` has shape `(max_positions, pos_embd_dim)`: the sin half and the cos half concatenated
        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

+    # TODO: Add comment on the role of mp_num. Why this complex reshaping?
    def _split_heads(self, x, n_head, dim_head, mp_num):
        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
        return reshaped

-    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Merges attn_head_size dim and num_attn_heads dim into n_ctx
        """
-        if len(tensor.shape) == 5:
-            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
-        elif len(tensor.shape) == 4:
-            tensor = tensor.permute(0, 2, 1, 3).contiguous()
-        else:
-            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        if not (4 <= tensor.dim() <= 5):
+            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {tensor.dim()}")
+        tensor = tensor.transpose(-2, -3).contiguous()
+        new_shape = tensor.size()[:-2] + (self.num_attention_heads * self.head_dim,)
        return tensor.view(new_shape)

    def _attn(
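On the TODO above about mp_num: a standalone sketch (toy sizes, not taken from the config) that walks the two reshapes in _split_heads, turning the mp_num-grouped projection back into per-head vectors:

import torch

B, T, n_head, head_dim, mp_num = 2, 5, 16, 64, 4
local_dim = n_head * head_dim // mp_num                           # 256
x = torch.randn(B, T, mp_num, local_dim)                          # one chunk of the fused qkv projection
step1 = x.reshape(x.shape[:-1] + (n_head // mp_num, head_dim))    # (B, T, mp_num, n_head // mp_num, head_dim)
step2 = step1.reshape(x.shape[:-2] + (-1,) + step1.shape[-1:])    # (B, T, n_head, head_dim)
assert step2.shape == (B, T, n_head, head_dim)
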
@@ -153,33 +153,44 @@ def forward(
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
-        qkv = self.qkv_proj(hidden_states)
+        if position_ids is None:
+            raise ValueError("position_ids must be given")
+        qkv = self.qkv_proj(hidden_states)  # (B, T, 3 * n_head * head_dim)
        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
        mp_num = 4
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        local_dim = self.head_dim * self.num_attention_heads // mp_num
        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
+        # Shapes (B, T, mp_num, local_dim), local_dim = n_head * head_dim // mp_num
        query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

        value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
-        value = value.permute(0, 2, 1, 3)
+        # query, key, value: (B, T, n_head, head_dim)
+        value = value.transpose(1, 2)  # (B, n_head, T, head_dim)

        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions

-        sincos = embed_positions[position_ids]
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0)
+        embed_positions = embed_positions.unsqueeze(0).repeat(position_ids.shape[0], 1, 1)
+        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+        sin = sin.unsqueeze(2)
+        cos = cos.unsqueeze(2)
+        # cos, sin: (B, T, 1, rotary_dim // 2)

        if self.rotary_dim is not None:
-            k_rot = key[:, :, :, : self.rotary_dim]
-            k_pass = key[:, :, :, self.rotary_dim :]
+            k_rot = key[..., : self.rotary_dim]
+            k_pass = key[..., self.rotary_dim :]

-            q_rot = query[:, :, :, : self.rotary_dim]
-            q_pass = query[:, :, :, self.rotary_dim :]
+            q_rot = query[..., : self.rotary_dim]
+            q_pass = query[..., self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
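For reference, the gather-based lookup above reproduces plain advanced indexing for batched position_ids; a standalone check with toy sizes:

import torch

max_pos, pos_dim, B, T = 32, 16, 2, 5
embed_positions = torch.randn(max_pos, pos_dim)
position_ids = torch.randint(0, max_pos, (B, T))
expanded = embed_positions.unsqueeze(0).repeat(B, 1, 1)        # (B, max_pos, pos_dim)
idx = position_ids.unsqueeze(-1).repeat(1, 1, pos_dim)         # (B, T, pos_dim)
assert torch.equal(torch.gather(expanded, 1, idx), embed_positions[position_ids])
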
@@ -190,8 +201,9 @@ def forward(
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

-        key = key.permute(0, 2, 1, 3)
-        query = query.permute(0, 2, 1, 3)
+        key = key.transpose(1, 2)
+        query = query.transpose(1, 2)
+        # query, key, value: (B, n_head, T, head_dim)

        # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
        # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
@@ -207,7 +219,7 @@ def forward(
        # compute self-attention: V x Softmax(QK^T)
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

-        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
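The simplified _merge_heads call relies on transpose(-2, -3) matching the old rank-4 permute(0, 2, 1, 3); a standalone toy check:

import torch

x = torch.randn(2, 16, 5, 64)  # (B, n_head, T, head_dim)
assert torch.equal(x.transpose(-2, -3), x.permute(0, 2, 1, 3))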