Several fixes related to rotary position embeddings #35376

Open · wants to merge 1 commit into base: main
14 changes: 7 additions & 7 deletions src/transformers/models/aria/modeling_aria.py
@@ -437,16 +437,14 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
-position_ids (`torch.Tensor`, *optional*):
-    Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -537,6 +535,8 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+if position_embeddings is None:
+    raise ValueError("position_embeddings = (cos, sin) must be given")
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

@@ -603,13 +603,13 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
+position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
-position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
@@ -619,13 +619,13 @@ def forward(
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
+position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
-position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -963,24 +963,24 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
+position_embeddings,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
-position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
+position_embeddings=position_embeddings,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
-position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

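Note for context (not part of the diff): a minimal runnable sketch of the contract the Aria changes enforce, with toy shapes and a stand-in for the rotary module's output; only `rotate_half` and `apply_rotary_pos_emb` mirror the code above, everything else is illustrative.

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin arrive as (B, T, head_dim); unsqueezing dim 1 lets them
    # broadcast over the head axis of q/k, shaped (B, n_head, T, head_dim).
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

B, n_head, T, head_dim = 2, 4, 6, 8      # toy sizes
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
angles = torch.randn(B, T, head_dim)     # stand-in for position * inv_freq
cos, sin = angles.cos(), angles.sin()    # the `position_embeddings` tuple
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```

This is why the attention forward can now hard-require `position_embeddings`: the `(cos, sin)` pair is computed once per forward pass and threaded through every decoder layer instead of being recomputed from `position_ids`.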
6 changes: 3 additions & 3 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -230,7 +230,7 @@ def eager_attention_forward(


# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Removes the interleaving of cos and sin from GLM
@@ -240,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
-position_ids (`torch.Tensor`, *optional*):
-    Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -305,6 +303,8 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+if position_embeddings is None:
+    raise ValueError("position_embeddings = (cos, sin) must be given")
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

4 changes: 1 addition & 3 deletions src/transformers/models/bamba/modular_bamba.py
@@ -144,7 +144,7 @@ class BambaRotaryEmbedding(LlamaRotaryEmbedding):


# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Removes the interleaving of cos and sin from GLM
@@ -154,8 +154,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
-position_ids (`torch.Tensor`, *optional*):
-    Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
6 changes: 2 additions & 4 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -153,16 +153,14 @@ def rotate_half(x):


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
-position_ids (`torch.Tensor`, *optional*):
-    Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -532,7 +530,7 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)
-query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
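A note on why the call site has to change together with the signature: with the old signature, the trailing positional `None` bound to the dead `position_ids`; under the new signature it would silently bind to `unsqueeze_dim`. A tiny sketch with hypothetical stand-ins, not the real functions:

```python
def old_apply(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return unsqueeze_dim

def new_apply(q, k, cos, sin, unsqueeze_dim=1):
    return unsqueeze_dim

assert old_apply(0, 0, 0, 0, None) == 1     # None lands on position_ids, harmless
assert new_apply(0, 0, 0, 0) == 1           # updated call, as in the diff above
assert new_apply(0, 0, 0, 0, None) is None  # stale call: unsqueeze_dim becomes None
```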
60 changes: 36 additions & 24 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -41,21 +41,22 @@
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
-return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
+out = torch.cat((sin, cos), dim=1)
+return out


# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
-x = torch.stack((-x2, x1), dim=-1)
-return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
+return torch.concat((-x2, x1), dim=-1)


# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
-sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
-cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+sin = torch.repeat_interleave(sin, 2, dim=-1)
+cos = torch.repeat_interleave(cos, 2, dim=-1)
return (tensor * cos) + (rotate_every_two(tensor) * sin)
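For context, a shape-only sketch of the duplication above (toy sizes): each of the `rotary_dim // 2` frequencies is repeated for a pair of adjacent channels, and the singleton head axis now added at the call site (see the forward changes below) lets `cos`/`sin` broadcast over all heads.

```python
import torch

B, T, n_head, rotary_dim = 1, 3, 4, 8
sin = torch.randn(B, T, 1, rotary_dim // 2)           # as unsqueezed at the call site
sin = torch.repeat_interleave(sin, 2, dim=-1)         # (B, T, 1, rotary_dim): s0, s0, s1, s1, ...
x = torch.randn(B, T, n_head, rotary_dim)
assert (x * sin).shape == (B, T, n_head, rotary_dim)  # head axis broadcasts over the 1
```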


@@ -87,25 +88,24 @@ def __init__(self, config, layer_idx=None):

self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim
-pos_embd_dim = self.rotary_dim or self.embed_dim
+pos_embd_dim = self.rotary_dim or self.head_dim
+# `embed_positions` of shape `(max_positions, 2 * pos_embd_dim)`
self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

+# TODO: Add comment on the role of mp_num. Why this complex reshaping?
def _split_heads(self, x, n_head, dim_head, mp_num):
reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
return reshaped

-def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Merges attn_head_size dim and num_attn_heads dim into n_ctx
"""
-if len(tensor.shape) == 5:
-    tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
-elif len(tensor.shape) == 4:
-    tensor = tensor.permute(0, 2, 1, 3).contiguous()
-else:
-    raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+if not (4 <= tensor.dim() <= 5):
+    raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {tensor.dim()}")
+tensor = tensor.transpose(-2, -3).contiguous()
+new_shape = tensor.size()[:-2] + (self.num_attention_heads * self.head_dim,)
return tensor.view(new_shape)

def _attn(
@@ -153,33 +153,44 @@ def forward(
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
-qkv = self.qkv_proj(hidden_states)
+if position_ids is None:
+    raise ValueError("position_ids must be given")
+qkv = self.qkv_proj(hidden_states)  # (B, T, 3 * n_head * head_dim)
# TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
mp_num = 4
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

local_dim = self.head_dim * self.num_attention_heads // mp_num
query, value, key = torch.split(qkv_split, local_dim, dim=-1)
+# Shapes (B, T, mp_num, local_dim), local_dim = n_head * head_dim // mp_num
query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
-value = value.permute(0, 2, 1, 3)
+# query, key, value: (B, T, n_head, head_dim)
+value = value.transpose(1, 2)  # (B, n_head, T, head_dim)

embed_positions = self.embed_positions
if embed_positions.device != position_ids.device:
embed_positions = embed_positions.to(position_ids.device)
self.embed_positions = embed_positions

-sincos = embed_positions[position_ids]
+if position_ids.dim() == 1:
+    position_ids = position_ids.unsqueeze(0)
+embed_positions = embed_positions.unsqueeze(0).repeat(position_ids.shape[0], 1, 1)
+repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+sin = sin.unsqueeze(2)
+cos = cos.unsqueeze(2)
+# cos, sin: (B, T, 1, rotary_dim // 2)

if self.rotary_dim is not None:
-k_rot = key[:, :, :, : self.rotary_dim]
-k_pass = key[:, :, :, self.rotary_dim :]
+k_rot = key[..., : self.rotary_dim]
+k_pass = key[..., self.rotary_dim :]

-q_rot = query[:, :, :, : self.rotary_dim]
-q_pass = query[:, :, :, self.rotary_dim :]
+q_rot = query[..., : self.rotary_dim]
+q_pass = query[..., self.rotary_dim :]

k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
@@ -190,8 +201,9 @@ def forward(
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)

-key = key.permute(0, 2, 1, 3)
-query = query.permute(0, 2, 1, 3)
+key = key.transpose(1, 2)
+query = query.transpose(1, 2)
+# query, key, value: (B, n_head, T, head_dim)

# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
@@ -207,7 +219,7 @@ def forward(
# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

-attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

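Finally, a small standalone check (made-up sizes) that the `torch.gather`-based lookup introduced above agrees with plain advanced indexing once `position_ids` is batched, which is what makes the rewrite safe:

```python
import torch

max_pos, dim, B, T = 32, 16, 2, 5
embed_positions = torch.randn(max_pos, dim)           # stand-in sinusoidal table
position_ids = torch.randint(0, max_pos, (B, T))

table = embed_positions.unsqueeze(0).repeat(B, 1, 1)  # (B, max_pos, dim)
index = position_ids.unsqueeze(-1).repeat(1, 1, dim)  # (B, T, dim)
sincos = torch.gather(table, 1, index)                # (B, T, dim)
sin, cos = torch.split(sincos, dim // 2, dim=-1)

assert torch.equal(sincos, embed_positions[position_ids])
```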