Commit 99dacd5

Several fixes related to rotary position embeddings
First part of the resolution of huggingface#35233:

- Changes related to `position_embeddings` being a mandatory argument
- Remove the `position_ids` argument of `apply_rotary_pos_emb`
- Replace `torch.stack` by `torch.cat`; the former requires equal shapes
- `esm`: RoPE depends on `position_ids`, which was ignored
- `gpt_neox`: selection of the attention compute type via class removed
- `gptj`, `codegen`: RoPE must be applied per head; also fixes some shape issues
- `nemotron`: `config.partial_rotary_factor` was not implemented
1 parent 5cabc75 commit 99dacd5
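For orientation, below is a minimal sketch of the call pattern the Llama-style models touched here converge on: the rotary module computes `(cos, sin)` once per forward pass, attention layers receive them as a mandatory `position_embeddings` tuple, and `apply_rotary_pos_emb` no longer accepts `position_ids`. The helper bodies mirror the aria/bamba/chameleon hunks below; the toy shapes are illustrative and not part of the diff.

import torch

def rotate_half(x):
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # New signature: the deprecated `position_ids` argument is gone; cos/sin
    # already correspond to the positions of the current tokens.
    cos = cos.unsqueeze(unsqueeze_dim)  # broadcast over the head axis
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# Toy tensors: (batch, n_heads, seq_len, head_dim) for q/k, (batch, seq_len, head_dim) for cos/sin.
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 5, 8)
cos = torch.randn(2, 5, 8)
sin = torch.randn(2, 5, 8)

# Attention modules now require position_embeddings = (cos, sin) and raise if it is None.
position_embeddings = (cos, sin)
q, k = apply_rotary_pos_emb(q, k, *position_embeddings)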

File tree: 65 files changed (+660, -521 lines)


src/transformers/models/aria/modeling_aria.py

Lines changed: 7 additions & 7 deletions
@@ -437,16 +437,14 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note

@@ -537,6 +535,8 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if position_embeddings is None:
+            raise ValueError("position_embeddings = (cos, sin) must be given")
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

@@ -603,13 +603,13 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states

@@ -619,13 +619,13 @@ def forward(
         # Self Attention
         hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states

@@ -963,24 +963,24 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
                     causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
+                    position_embeddings=position_embeddings,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings,
                     **flash_attn_kwargs,
                 )
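A note on the last hunk above: `_gradient_checkpointing_func` forwards its arguments positionally, so once `position_embeddings` becomes the second positional parameter of the decoder layer's `forward`, it must be passed in that slot rather than as a trailing keyword. A rough standalone sketch of that constraint, using `torch.utils.checkpoint.checkpoint` and a dummy layer in place of the real Aria modules:

import torch
from torch.utils.checkpoint import checkpoint

def layer_forward(hidden_states, position_embeddings, attention_mask=None):
    # Stand-in for a decoder layer's forward: position_embeddings must be the
    # second positional argument to match the checkpointed call below.
    cos, sin = position_embeddings
    return hidden_states * cos + hidden_states * sin  # dummy computation

hidden_states = torch.randn(1, 4, 8, requires_grad=True)
cos, sin = torch.randn(1, 4, 8), torch.randn(1, 4, 8)

# Mirrors self._gradient_checkpointing_func(decoder_layer.__call__, hidden_states, position_embeddings, ...)
out = checkpoint(layer_forward, hidden_states, (cos, sin), use_reentrant=False)
out.sum().backward()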

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 3 additions & 3 deletions
@@ -230,7 +230,7 @@ def eager_attention_forward(


 # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Removes the interleaving of cos and sin from GLM

@@ -240,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note

@@ -305,6 +303,8 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if position_embeddings is None:
+            raise ValueError("position_embeddings = (cos, sin) must be given")
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
src/transformers/models/bamba/modular_bamba.py

Lines changed: 1 addition & 3 deletions
@@ -144,7 +144,7 @@ class BambaRotaryEmbedding(LlamaRotaryEmbedding):


 # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Removes the interleaving of cos and sin from GLM

@@ -154,8 +154,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note

src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 2 additions & 4 deletions
@@ -153,16 +153,14 @@ def rotate_half(x):


 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note

@@ -532,7 +530,7 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache

src/transformers/models/codegen/modeling_codegen.py

Lines changed: 36 additions & 24 deletions
@@ -41,21 +41,22 @@
 def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
     sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
-    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+    sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
+    out = torch.cat((sin, cos), dim=1)
+    return out


 # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
 def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
+    return torch.concat((-x2, x1), dim=-1)


 # Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
 def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
-    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
-    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+    sin = torch.repeat_interleave(sin, 2, dim=-1)
+    cos = torch.repeat_interleave(cos, 2, dim=-1)
     return (tensor * cos) + (rotate_every_two(tensor) * sin)


@@ -87,25 +88,24 @@ def __init__(self, config, layer_idx=None):

         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
+        pos_embd_dim = self.rotary_dim or self.head_dim
+        # `embed_positions` of shape `(max_positions, 2 * pos_embd_dim)`
         self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

+    # TODO: Add comment on the role of mp_num. Why this complex reshaping?
     def _split_heads(self, x, n_head, dim_head, mp_num):
         reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
         reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
         return reshaped

-    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
         """
         Merges attn_head_size dim and num_attn_heads dim into n_ctx
         """
-        if len(tensor.shape) == 5:
-            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
-        elif len(tensor.shape) == 4:
-            tensor = tensor.permute(0, 2, 1, 3).contiguous()
-        else:
-            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        if not (4 <= tensor.dim() <= 5):
+            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {tensor.dim()}")
+        tensor = tensor.transpose(-2, -3).contiguous()
+        new_shape = tensor.size()[:-2] + (self.num_attention_heads * self.head_dim,)
         return tensor.view(new_shape)

     def _attn(

@@ -153,33 +153,44 @@ def forward(
         Tuple[torch.Tensor, Tuple[torch.Tensor]],
         Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
     ]:
-        qkv = self.qkv_proj(hidden_states)
+        if position_ids is None:
+            raise ValueError("position_ids must be given")
+        qkv = self.qkv_proj(hidden_states)  # (B, T, 3 * n_head * head_dim)
         # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
         mp_num = 4
         qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

         local_dim = self.head_dim * self.num_attention_heads // mp_num
         query, value, key = torch.split(qkv_split, local_dim, dim=-1)
+        # Shapes (B, T, mp_num, local_dim), local_dim = n_head * head_dim // mp_num
         query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
         key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

         value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
-        value = value.permute(0, 2, 1, 3)
+        # query, key, value: (B, T, n_head, head_dim)
+        value = value.transpose(1, 2)  # (B, n_head, T, head_dim)

         embed_positions = self.embed_positions
         if embed_positions.device != position_ids.device:
             embed_positions = embed_positions.to(position_ids.device)
             self.embed_positions = embed_positions

-        sincos = embed_positions[position_ids]
+        if position_ids.dim() == 1:
+            position_ids = position_ids.unsqueeze(0)
+        embed_positions = embed_positions.unsqueeze(0).repeat(position_ids.shape[0], 1, 1)
+        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
         sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+        sin = sin.unsqueeze(2)
+        cos = cos.unsqueeze(2)
+        # cos, sin: (B, T, 1, rotary_dim // 2)

         if self.rotary_dim is not None:
-            k_rot = key[:, :, :, : self.rotary_dim]
-            k_pass = key[:, :, :, self.rotary_dim :]
+            k_rot = key[..., : self.rotary_dim]
+            k_pass = key[..., self.rotary_dim :]

-            q_rot = query[:, :, :, : self.rotary_dim]
-            q_pass = query[:, :, :, self.rotary_dim :]
+            q_rot = query[..., : self.rotary_dim]
+            q_pass = query[..., self.rotary_dim :]

             k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
             q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

@@ -190,8 +201,9 @@ def forward(
             key = apply_rotary_pos_emb(key, sin, cos)
             query = apply_rotary_pos_emb(query, sin, cos)

-        key = key.permute(0, 2, 1, 3)
-        query = query.permute(0, 2, 1, 3)
+        key = key.transpose(1, 2)
+        query = query.transpose(1, 2)
+        # query, key, value: (B, n_head, T, head_dim)

         # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
         # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38

@@ -207,7 +219,7 @@ def forward(
         # compute self-attention: V x Softmax(QK^T)
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

-        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self._merge_heads(attn_output)
         attn_output = self.out_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
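To make the shape comments above concrete, here is a small self-contained sketch of the per-head rotary application as restructured in this file: `sin`/`cos` are gathered per position and unsqueezed to `(B, T, 1, rotary_dim // 2)` so they broadcast over the head axis of `query`/`key` laid out as `(B, T, n_head, head_dim)`, and only the first `rotary_dim` channels are rotated. The helper bodies mirror the diff; the toy sizes and the random `sincos` tensor are stand-ins for the gathered sinusoidal table.

import torch

def rotate_every_two(x):
    # Matches the updated helper: split even/odd channels and concatenate (-x2, x1).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(tensor, sin, cos):
    # sin/cos arrive as (B, T, 1, rotary_dim // 2); repeat along the channel dim up to rotary_dim.
    sin = torch.repeat_interleave(sin, 2, dim=-1)
    cos = torch.repeat_interleave(cos, 2, dim=-1)
    return (tensor * cos) + (rotate_every_two(tensor) * sin)

B, T, n_head, head_dim, rotary_dim = 2, 5, 4, 16, 8
query = torch.randn(B, T, n_head, head_dim)  # per-head layout, before transposing to (B, n_head, T, head_dim)
sincos = torch.randn(B, T, rotary_dim)       # stand-in for the gathered sinusoidal table
sin, cos = torch.split(sincos, rotary_dim // 2, dim=-1)
sin, cos = sin.unsqueeze(2), cos.unsqueeze(2)  # (B, T, 1, rotary_dim // 2), broadcasts over n_head

q_rot = apply_rotary_pos_emb(query[..., :rotary_dim], sin, cos)  # rotate only the first rotary_dim channels
query = torch.cat([q_rot, query[..., rotary_dim:]], dim=-1)
query = query.transpose(1, 2)  # (B, n_head, T, head_dim), ready for attention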
