diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index b96697bc0779e6..7f1dbb4c20d5ce 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -437,7 +437,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -445,8 +445,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -537,6 +535,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -603,13 +603,13 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -619,13 +619,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -963,24 +963,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index c89d8d7853008d..3797d8bd3786ac 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ 
-230,7 +230,7 @@ def eager_attention_forward( # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Removes the interleaving of cos and sin from GLM @@ -240,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -305,6 +303,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 7fb35f48fb3b76..46046042b3edc4 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -144,7 +144,7 @@ class BambaRotaryEmbedding(LlamaRotaryEmbedding): # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Removes the interleaving of cos and sin from GLM @@ -154,8 +154,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 90a02dd5bb9fee..61f2c4c5d715b1 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -153,7 +153,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -161,8 +161,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. 
unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -532,7 +530,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 5c8f1b3957ab38..2bb8dcdce41733 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -41,21 +41,22 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim)) sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float() - return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) + sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + out = torch.cat((sin, cos), dim=1) + return out # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two def rotate_every_two(x: torch.Tensor) -> torch.Tensor: x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + return torch.concat((-x2, x1), dim=-1) # Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor: - sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3) - cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3) + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) return (tensor * cos) + (rotate_every_two(tensor) * sin) @@ -87,25 +88,24 @@ def __init__(self, config, layer_idx=None): self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.rotary_dim = config.rotary_dim - pos_embd_dim = self.rotary_dim or self.embed_dim + pos_embd_dim = self.rotary_dim or self.head_dim + # `embed_positions` of shape `(max_positions, 2 * pos_embd_dim)` self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim) + # TODO: Add comment on the role of mp_num. Why this complex reshaping? 
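The CodeGen hunks above rework `create_sinusoidal_positions` and the sin/cos handling in `apply_rotary_pos_emb`. A minimal sketch of the table these helpers operate on, assuming a standalone re-implementation for illustration (the constants and shapes mirror the diff; the snippet itself is not part of the patch):

import torch

# One row per position; the first half of each row is sin, the second half cos,
# exactly as returned by `create_sinusoidal_positions` above.
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
    sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
    return torch.cat((sin, cos), dim=1)

table = create_sinusoidal_positions(num_pos=2048, dim=64)
print(table.shape)  # torch.Size([2048, 64]) -> (max_positions, 2 * (rotary_dim // 2))
sin, cos = torch.split(table, table.shape[-1] // 2, dim=-1)
print(sin.shape, cos.shape)  # torch.Size([2048, 32]) torch.Size([2048, 32])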
def _split_heads(self, x, n_head, dim_head, mp_num): reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head)) reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:]) return reshaped - def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor: """ Merges attn_head_size dim and num_attn_heads dim into n_ctx """ - if len(tensor.shape) == 5: - tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() - elif len(tensor.shape) == 4: - tensor = tensor.permute(0, 2, 1, 3).contiguous() - else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") - new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + if not (4 <= tensor.dim() <= 5): + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {tensor.dim()}") + tensor = tensor.transpose(-2, -3).contiguous() + new_shape = tensor.size()[:-2] + (self.num_attention_heads * self.head_dim,) return tensor.view(new_shape) def _attn( @@ -153,33 +153,44 @@ def forward( Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: - qkv = self.qkv_proj(hidden_states) + if position_ids is None: + raise ValueError("position_ids must be given") + qkv = self.qkv_proj(hidden_states) # (B, T, 3 * n_head * head_dim) # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic mp_num = 4 qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) local_dim = self.head_dim * self.num_attention_heads // mp_num query, value, key = torch.split(qkv_split, local_dim, dim=-1) + # Shapes (B, T, mp_num, local_dim), local_dim = n_head * head_dim // mp_num query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) - value = value.permute(0, 2, 1, 3) + # query, key, value: (B, T, n_head, head_dim) + value = value.transpose(1, 2) # (B, n_head, T, head_dim) embed_positions = self.embed_positions if embed_positions.device != position_ids.device: embed_positions = embed_positions.to(position_ids.device) self.embed_positions = embed_positions - sincos = embed_positions[position_ids] + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + embed_positions = embed_positions.unsqueeze(0).repeat(position_ids.shape[0], 1, 1) + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + sin = sin.unsqueeze(2) + cos = cos.unsqueeze(2) + # cos, sin: (B, T, 1, rotary_dim // 2) if self.rotary_dim is not None: - k_rot = key[:, :, :, : self.rotary_dim] - k_pass = key[:, :, :, self.rotary_dim :] + k_rot = key[..., : self.rotary_dim] + k_pass = key[..., self.rotary_dim :] - q_rot = query[:, :, :, : self.rotary_dim] - q_pass = query[:, :, :, self.rotary_dim :] + q_rot = query[..., : self.rotary_dim] + q_pass = query[..., self.rotary_dim :] k_rot = apply_rotary_pos_emb(k_rot, sin, cos) q_rot = apply_rotary_pos_emb(q_rot, sin, cos) @@ -190,8 +201,9 @@ def forward( key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - key = key.permute(0, 2, 1, 3) - query = query.permute(0, 2, 1, 3) + key = key.transpose(1, 2) + query = query.transpose(1, 2) + 
# query, key, value: (B, n_head, T, head_dim) # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32. # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38 @@ -207,7 +219,7 @@ def forward( # compute self-attention: V x Softmax(QK^T) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a65d3ee64a234a..a3d1a0abf9fc03 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -176,11 +176,10 @@ def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -188,8 +187,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
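The `apply_rotary_pos_emb` docstring above describes the `unsqueeze_dim` broadcasting. A minimal sketch of that mechanic, assuming Llama-style layouts (q, k of shape (batch, num_heads, seq_len, head_dim), cos/sin of shape (batch, seq_len, head_dim)) and the split-half `rotate_half`; Cohere and GLM use an interleaved variant, so this snippet is illustrative rather than the model code:

import torch

def rotate_half(x):
    # Split-half convention (Llama/Aria-style).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)  # (batch, 1, seq_len, head_dim)
    sin = sin.unsqueeze(unsqueeze_dim)  # broadcasts across the num_heads dimension of q and k
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

batch, num_heads, seq_len, head_dim = 2, 8, 5, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)  # stand-ins for a real rotary table
sin = torch.randn(batch, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # both torch.Size([2, 8, 5, 64])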
For example, note @@ -286,15 +283,17 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -372,15 +371,17 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " @@ -476,14 +477,16 @@ class CohereSdpaAttention(CohereAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( @@ -492,13 +495,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -581,13 +584,13 @@ def __init__(self, config: CohereConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -615,13 +618,13 @@ def forward( # Self Attention hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) # Fully Connected @@ -861,24 +864,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 1ffa4bffddc3df..b339e6103d884d 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -161,11 +161,10 @@ def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] x2 = x[..., 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - return rot_x + return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -173,8 +172,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
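The Cohere hunks above make `position_embeddings` a required, early positional argument of every attention class and decoder layer, with the (cos, sin) pair computed once in the model forward. A hedged sketch of that calling pattern; `ToyRotary`, `ToyLayer` and `ToyModel` are hypothetical stand-ins, not code from the patch:

import torch
from torch import nn

class ToyRotary(nn.Module):
    def __init__(self, head_dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, position_ids):
        # (batch, seq_len) -> cos, sin of shape (batch, seq_len, head_dim)
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

class ToyLayer(nn.Module):
    def forward(self, hidden_states, position_embeddings, attention_mask=None):
        if position_embeddings is None:
            raise ValueError("position_embeddings = (cos, sin) must be given")
        cos, sin = position_embeddings
        return hidden_states  # attention omitted; only the argument plumbing is shown

class ToyModel(nn.Module):
    def __init__(self, num_layers=2, head_dim=64):
        super().__init__()
        self.rotary_emb = ToyRotary(head_dim)
        self.layers = nn.ModuleList(ToyLayer() for _ in range(num_layers))

    def forward(self, hidden_states, position_ids):
        position_embeddings = self.rotary_emb(position_ids)  # computed once, reused by every layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings=position_embeddings)
        return hidden_states

out = ToyModel()(torch.randn(1, 5, 64), torch.arange(5)[None])
print(out.shape)  # torch.Size([1, 5, 64])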
For example, note diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 3f2e7c384d7d63..17bcfcc40adeee 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -84,7 +84,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -92,8 +92,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -154,6 +152,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return torch.tensor(0.0) + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -482,7 +481,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 5df5435bb1229a..3b20b17a46e01a 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -88,31 +88,60 @@ def __init__(self, dim: int): super().__init__() # Generate and save the inverse frequency buffer (non trainable) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) - inv_freq = inv_freq self.register_buffer("inv_freq", inv_freq) self._seq_len_cached = None self._cos_cached = None self._sin_cached = None + self._positions_ids_cached = None - def _update_cos_sin_tables(self, x, seq_dimension=2): - seq_len = x.shape[seq_dimension] - - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: - self._seq_len_cached = seq_len - t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) - freqs = torch.outer(t, self.inv_freq) + def _update_cos_sin_tables( + self, + x: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + seq_dimension: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Reset the tables if the sequence length has changed, position_ids + # has changed, or if we're on a new device (possibly due to tracing for + # instance) + device_changed = (self._cos_cached is not None) and (self._cos_cached.device != x.device) + t = None 
+ if position_ids is not None: + if ( + device_changed + or self._positions_ids_cached is None + or not torch.equal(position_ids, self._positions_ids_cached) + ): + # RoPE embeddings depends on position_ids + # Caching makes sense: position_ids is the same for every layer + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + self._positions_ids_cached = torch.clone(position_ids) + t = position_ids.unsqueeze(-1).type_as(self.inv_freq).to(x.device) + else: + seq_len = x.shape[seq_dimension] + if device_changed or seq_len != self._seq_len_cached: + self._seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device)[None, :, None].type_as(self.inv_freq) + if t is not None: + inv_freq = self.inv_freq[None, None, :].expand(*t.shape[:2], -1) + t = t.expand(-1, -1, inv_freq.shape[-1]) + freqs = t * inv_freq emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - - self._cos_cached = emb.cos()[None, None, :, :] - self._sin_cached = emb.sin()[None, None, :, :] + self._cos_cached = emb.cos().unsqueeze(1) + self._sin_cached = emb.sin().unsqueeze(1) return self._cos_cached, self._sin_cached - def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + x=k, position_ids=position_ids, seq_dimension=-2 + ) return ( apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), @@ -276,14 +305,17 @@ def __init__(self, config, position_embedding_type=None): self.is_decoder = config.is_decoder def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, T, num_heads * head_size) + # Returned tensor: (B, num_heads, T, head_size) new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) + return x.transpose(1, 2) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, @@ -316,6 +348,9 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + # query_layer: (B, num_heads, T, head_size) + # key_layer, value_layer: (B, num_heads, *, head_size), where + # * = T in the default case # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, @@ -334,7 +369,7 @@ def forward( past_key_value = (key_layer, value_layer) if self.position_embedding_type == "rotary": - query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer, position_ids) # Take the dot product between "query" and "key" to get the raw attention scores. 
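The ESM hunk above derives the rotary angles from `position_ids` instead of a cached `arange` over the sequence length. A small illustrative sketch (not the patch code) showing that the two paths agree for contiguous positions, while only the `position_ids` path tracks shifted positions such as left padding:

import torch

dim, seq_len = 16, 6
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)

# Contiguous positions: position_ids-based angles equal the old outer-product cache.
position_ids = torch.arange(seq_len)[None, :]                  # (1, seq_len)
freqs_from_ids = position_ids[..., None].float() * inv_freq    # (1, seq_len, dim // 2)
freqs_from_arange = torch.outer(torch.arange(seq_len).float(), inv_freq)[None]
print(torch.allclose(freqs_from_ids, freqs_from_arange))       # True

# Shifted positions (e.g. left padding): only the position_ids path reflects the offset.
shifted = position_ids + 3
freqs_shifted = shifted[..., None].float() * inv_freq
emb = torch.cat((freqs_shifted, freqs_shifted), dim=-1)        # (1, seq_len, dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)                                               # torch.Size([1, 6, 16])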
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -426,6 +461,7 @@ def forward( self, hidden_states, attention_mask=None, + position_ids: Optional[torch.Tensor] = None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -436,6 +472,7 @@ def forward( self_outputs = self.self( hidden_states_ln, attention_mask, + position_ids, head_mask, encoder_hidden_states, encoder_attention_mask, @@ -491,6 +528,7 @@ def forward( self, hidden_states, attention_mask=None, + position_ids: Optional[torch.Tensor] = None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -502,6 +540,7 @@ def forward( self_attention_outputs = self.attention( hidden_states, attention_mask, + position_ids, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, @@ -569,6 +608,7 @@ def forward( self, hidden_states, attention_mask=None, + position_ids: Optional[torch.Tensor] = None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -602,6 +642,7 @@ def forward( layer_module.__call__, hidden_states, attention_mask, + position_ids, layer_head_mask, encoder_hidden_states, encoder_attention_mask, @@ -612,6 +653,7 @@ def forward( layer_outputs = layer_module( hidden_states, attention_mask, + position_ids, layer_head_mask, encoder_hidden_states, encoder_attention_mask, diff --git a/src/transformers/models/esm/modeling_tf_esm.py b/src/transformers/models/esm/modeling_tf_esm.py index 0e5cf3d8f61f8a..9a485057d291d7 100644 --- a/src/transformers/models/esm/modeling_tf_esm.py +++ b/src/transformers/models/esm/modeling_tf_esm.py @@ -107,17 +107,32 @@ def build(self, input_shape): 1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) ) - def _compute_cos_sin(self, x, seq_dimension=2): - seq_len = tf.shape(x)[seq_dimension] - - t = tf.range(seq_len, dtype=self.inv_freq.dtype) - freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication - emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :] + def _compute_cos_sin( + self, + x: tf.Tensor, + position_ids: tf.Tensor | None = None, + seq_dimension: int = 2, + ) -> Tuple[tf.Tensor, tf.Tensor]: + if position_ids is not None: + t = tf.cast(position_ids[:, :, None], self.inv_freq.dtype) + else: + seq_len = tf.shape(x)[seq_dimension] + t = tf.range(seq_len, dtype=self.inv_freq.dtype)[None, :, None] + inv_freq = self.inv_freq[None, None, :] + freqs = t * inv_freq + emb = tf.concat((freqs, freqs), axis=-1) return tf.cos(emb), tf.sin(emb) - def call(self, q: tf.Tensor, k: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: - cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2) + def call( + self, + q: tf.Tensor, + k: tf.Tensor, + position_ids: tf.Tensor | None = None, + ) -> Tuple[tf.Tensor, tf.Tensor]: + cos_emb, sin_emb = self._compute_cos_sin(x=k, position_ids=position_ids, seq_dimension=-2) + cos_emb = cos_emb[:, None, :, :] + sin_emb = sin_emb[:, None, :, :] return ( apply_rotary_pos_emb(q, cos_emb, sin_emb), @@ -319,6 +334,7 @@ def __init__(self, config, position_embedding_type=None, name=None): self.config = config def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + # Returned tensor has shape (B, num_heads, T, head_size) new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] x = tf.reshape(x, new_x_shape) return tf.transpose(x, perm=(0, 2, 1, 3)) @@ -327,6 +343,7 @@ def call( self, hidden_states: tf.Tensor, attention_mask: tf.Tensor | None = None, + position_ids: 
tf.Tensor | None = None, head_mask: tf.Tensor | None = None, encoder_hidden_states: tf.Tensor | None = None, encoder_attention_mask: tf.Tensor | None = None, @@ -378,7 +395,7 @@ def call( past_key_value = (key_layer, value_layer) if self.position_embedding_type == "rotary": - query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer, position_ids) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) @@ -484,6 +501,7 @@ def call( self, hidden_states, attention_mask=None, + position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -495,6 +513,7 @@ def call( self_outputs = self.self( hidden_states_ln, attention_mask, + position_ids, head_mask, encoder_hidden_states, encoder_attention_mask, @@ -591,6 +610,7 @@ def call( self, hidden_states, attention_mask=None, + position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -603,6 +623,7 @@ def call( self_attention_outputs = self.attention( hidden_states, attention_mask, + position_ids, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, @@ -630,6 +651,7 @@ def call( cross_attention_outputs = self.crossattention( attention_output, attention_mask, + position_ids, head_mask, encoder_hidden_states, encoder_attention_mask, @@ -688,6 +710,7 @@ def call( self, hidden_states, attention_mask=None, + position_ids=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, @@ -713,6 +736,7 @@ def call( layer_outputs = layer_module( hidden_states, attention_mask, + position_ids, layer_head_mask, encoder_hidden_states, encoder_attention_mask, @@ -1049,6 +1073,7 @@ def call( encoder_outputs = self.encoder( hidden_states=embedding_output, attention_mask=extended_attention_mask, + position_ids=position_ids, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e0e4ff424cb47d..6167f80fe4e1a3 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -81,7 +81,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -89,8 +89,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note @@ -330,6 +328,7 @@ def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, @@ -338,8 +337,9 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -480,6 +480,7 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, @@ -488,8 +489,9 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -618,6 +620,7 @@ def __init__(self, config: FalconConfig, layer_idx=None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, @@ -626,7 +629,6 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ): residual = hidden_states @@ -640,6 +642,7 @@ def forward( # Self attention. 
attn_outputs = self.self_attention( attention_layernorm_out, + position_embeddings=position_embeddings, layer_past=layer_past, attention_mask=attention_mask, position_ids=position_ids, @@ -648,7 +651,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) attention_output = attn_outputs[0] @@ -961,6 +963,7 @@ def forward( outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, + position_embeddings, alibi, causal_mask, position_ids, @@ -969,11 +972,11 @@ def forward( use_cache, output_attentions, cache_position, - position_embeddings, ) else: outputs = block( hidden_states, + position_embeddings=position_embeddings, layer_past=past_key_values, attention_mask=causal_mask, position_ids=position_ids, @@ -982,7 +985,6 @@ def forward( output_attentions=output_attentions, alibi=alibi, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = outputs[0] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 71cd6b6158ca0b..997e21e9660d23 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -163,7 +163,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -171,8 +171,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
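In the gradient-checkpointing branches throughout this diff, the layer is invoked with positional arguments only, so `position_embeddings` moves right after `hidden_states` to match the updated signatures. A sketch of why the order matters, using a hypothetical `layer_forward` (not code from the patch) with `torch.utils.checkpoint`:

import torch
from torch.utils.checkpoint import checkpoint

def layer_forward(hidden_states, position_embeddings, attention_mask=None):
    # Positional order must match the signature: hidden_states first, then (cos, sin).
    cos, sin = position_embeddings
    return hidden_states * cos.mean() + sin.mean()  # placeholder computation

hidden_states = torch.randn(1, 5, 8, requires_grad=True)
position_embeddings = (torch.randn(1, 5, 8), torch.randn(1, 5, 8))
out = checkpoint(layer_forward, hidden_states, position_embeddings, None, use_reentrant=False)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([1, 5, 8])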
For example, note @@ -263,6 +261,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -318,13 +318,13 @@ def __init__(self, config: GemmaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -334,13 +334,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -585,24 +585,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 29b6f8a1946173..6189fc6ccbe2cf 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -425,24 +425,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 67fc6c86a3bac6..5509ea42484a1c 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -97,7 +97,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies 
Rotary Position Embedding to the query and key tensors. Args: @@ -105,8 +105,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -208,6 +206,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 48b12411361aff..fd90d55926af1f 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -244,6 +244,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 706847650b818e..f63666eeed1d99 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -119,7 +119,7 @@ def rotate_half(x): return torch.stack((-x2, x1), dim=-1).flatten(-2) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -127,8 +127,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
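Note that GLM keeps its interleaved `rotate_half` (stack then flatten) while models like Llama/Aria use the split-half form. A small comparison on a toy vector, illustrative only and not part of the patch, showing that the two conventions pair different coordinates and are therefore not interchangeable without also changing the cos/sin layout:

import torch

def rotate_half_interleaved(x):  # GLM-style: pairs (x0, x1), (x2, x3), ...
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def rotate_half_split(x):        # Llama-style: pairs (x0, x_{d/2}), (x1, x_{d/2+1}), ...
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

x = torch.arange(1.0, 9.0)
print(rotate_half_interleaved(x))  # tensor([-2., 1., -4., 3., -6., 5., -8., 7.])
print(rotate_half_split(x))        # tensor([-5., -6., -7., -8., 1., 2., 3., 4.])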
For example, note @@ -194,6 +192,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -334,13 +334,13 @@ def __init__(self, config: GlmConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -350,13 +350,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -595,24 +595,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index ec07be10fb6a55..dfa28a35981a06 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -46,7 +46,7 @@ def rotate_half(x): return torch.stack((-x2, x1), dim=-1).flatten(-2) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -54,8 +54,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 07514a37c6f2fa..c632085ff092e6 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -49,7 +49,9 @@ class GPTNeoXConfig(PretrainedConfig): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. rotary_pct (`float`, *optional*, defaults to 0.25): - percentage of hidden dimensions to allocate to rotary embeddings + Percentage of hidden dimensions to allocate to rotary embeddings. + Note: In most other models, this parameter is called + `partial_rotary_factor`. rotary_emb_base (`int`, *optional*, defaults to 10000) base for computing rotary embeddings frequency attention_dropout (`float`, *optional*, defaults to 0.0): diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 98418cb02d65ba..44c257aafcc059 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -271,12 +271,10 @@ def __init__(self, config, layer_idx=None): "The hidden size is not divisble by the number of attention heads! Make sure to update them" ) self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = int(self.head_size * config.rotary_pct) - self.rope_theta = config.rotary_emb_base + self.rotary_ndims = int(self.head_size * config.partial_rotary_factor) self._init_bias(config.max_position_embeddings) self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) - self.rotary_emb = GPTNeoXRotaryEmbedding(config=self.config) if layer_idx is None: logger.warning_once( @@ -305,6 +303,7 @@ def _init_bias(self, max_positions, device=None): def forward( self, hidden_states: torch.FloatTensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.FloatTensor, position_ids: torch.LongTensor, head_mask: Optional[torch.FloatTensor] = None, @@ -313,18 +312,19 @@ def forward( output_attentions: Optional[bool] = False, padding_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, seq_len, _ = hidden_states.shape # Apply attention-specific projections and rope query, key, value, present = self._attn_projections_and_rope( hidden_states=hidden_states, + position_embeddings=position_embeddings, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) # Checking for fallbacks in case an unsupported feature is requested @@ -403,12 +403,14 @@ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): def _attn_projections_and_rope( self, hidden_states: torch.FloatTensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: torch.LongTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): + if position_embeddings is None: + raise 
ValueError("position_embeddings = (cos, sin) must be given") # Compute QKV # Attention heads [batch, seq_len, hidden_size] # --> [batch, seq_len, (np * 3 * head_size)] @@ -563,7 +565,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -571,8 +573,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -604,14 +604,6 @@ def forward(self, hidden_states): return hidden_states -GPT_NEOX_ATTENTION_CLASSES = { - "eager": GPTNeoXAttention, - "flash_attention_2": GPTNeoXFlashAttention2, - "sdpa": GPTNeoXSdpaAttention, - "flex_attention": GPTNeoXAttention, -} - - class GPTNeoXLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() @@ -620,12 +612,13 @@ def __init__(self, config, layer_idx): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_dropout = nn.Dropout(config.hidden_dropout) self.post_mlp_dropout = nn.Dropout(config.hidden_dropout) - self.attention = GPT_NEOX_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.attention = GPTNeoXAttention(config, layer_idx) self.mlp = GPTNeoXMLP(config) def forward( self, hidden_states: Optional[torch.FloatTensor], + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -633,10 +626,10 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): attention_layer_outputs = self.attention( self.input_layernorm(hidden_states), + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, layer_past=layer_past, @@ -644,7 +637,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) attn_output = self.post_attention_dropout(attn_output) @@ -878,6 +870,7 @@ def forward( outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, head_mask[i], @@ -885,11 +878,11 @@ def forward( None, output_attentions, cache_position, - position_embeddings, ) else: outputs = layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, head_mask=head_mask[i], @@ -897,7 +890,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) 
hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index fba67ae03a5979..8356c3654b7650 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -83,9 +83,7 @@ def __init__(self, config, use_bias=False, layer_idx=None): ) self.layer_idx = layer_idx - self.rotary_ndims = int(self.head_size * config.rotary_pct) - self.rope_theta = config.rotary_emb_base - self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config) + self.rotary_ndims = int(self.head_size * config.partial_rotary_factor) self.attention_dropout = nn.Dropout(config.attention_dropout) self.norm_factor = math.sqrt(self.head_size) @@ -98,6 +96,7 @@ def __init__(self, config, use_bias=False, layer_idx=None): def forward( self, hidden_states: torch.FloatTensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.FloatTensor, position_ids: torch.LongTensor, head_mask: Optional[torch.FloatTensor] = None, @@ -105,7 +104,6 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): # Compute QKV # Attention heads [batch, seq_len, hidden_size] @@ -297,7 +295,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -305,8 +303,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note @@ -377,6 +373,7 @@ def __init__(self, config, layer_number): def forward( self, hidden_states: Optional[torch.FloatTensor], + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -384,12 +381,12 @@ def forward( layer_past: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ): residual = hidden_states ln_out = self.input_layernorm(hidden_states) attention_layer_outputs, attn_bias = self.attention( ln_out, + position_embeddings=position_embeddings, attention_mask=attention_mask, layer_past=layer_past, head_mask=head_mask, @@ -397,7 +394,6 @@ def forward( output_attentions=output_attentions, position_ids=position_ids, cache_position=cache_position, - position_embeddings=position_embeddings, ) attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions) outputs = attention_layer_outputs[1:] @@ -623,6 +619,7 @@ def forward( outputs = layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, head_mask=head_mask[i], @@ -630,7 +627,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index 1b93f259b05b12..1c8f31e0394ec1 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -49,7 +49,9 @@ class GPTJConfig(PretrainedConfig): n_head (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. rotary_dim (`int`, *optional*, defaults to 64): - Number of dimensions in the embedding that Rotary Position Embedding is applied to. + Number of dimensions in the embedding of each head that Rotary Position + Embedding is applied to. If `rotary_dim=None`, RoPE is applied to the + full size `n_embd // n_head` (this is not the default). n_inner (`int`, *optional*, defaults to None): Dimensionality of the inner feed-forward layers. 
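The clarified `rotary_dim` docstring matches how the attention modules now size the sinusoidal table: per head, falling back to `head_dim` rather than `embed_dim` when `rotary_dim` is `None`. A short illustration using the stock GPT-J sizes (assumed here, not read from a checkpoint):

```python
# GPT-J-6B style sizes, used only for illustration.
n_embd, n_head, rotary_dim = 4096, 16, 64
head_dim = n_embd // n_head          # 256

pos_embd_dim = rotary_dim or head_dim
print(pos_embd_dim)                  # 64: RoPE covers the first 64 dims of each head

rotary_dim = None
pos_embd_dim = rotary_dim or head_dim
print(pos_embd_dim)                  # 256: falls back to the full head size, not n_embd
```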
`None` will set it to 4 times n_embd activation_function (`str`, *optional*, defaults to `"gelu_new"`): diff --git a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py index 9f0d4d6e860003..b0258a116efa92 100644 --- a/src/transformers/models/gptj/modeling_flax_gptj.py +++ b/src/transformers/models/gptj/modeling_flax_gptj.py @@ -120,16 +120,15 @@ def create_sinusoidal_positions(num_pos, dim): def rotate_every_two(tensor): - rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1) - rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,)) - return rotate_half_tensor + x1 = tensor[:, :, :, ::2] + x2 = tensor[:, :, :, 1::2] + return jnp.concatenate((-x2, x1), axis=-1) -def apply_rotary_pos_emb(tensor, sincos): - sin_pos, cos_pos = sincos - sin_pos = sin_pos[:, :, None, :].repeat(2, 3) - cos_pos = cos_pos[:, :, None, :].repeat(2, 3) - return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos) +def apply_rotary_pos_emb(tensor, sin, cos): + sin = sin.repeat(2, axis=-1) + cos = cos.repeat(2, axis=-1) + return (tensor * cos) + (rotate_every_two(tensor) * sin) class FlaxGPTJAttention(nn.Module): @@ -140,10 +139,17 @@ class FlaxGPTJAttention(nn.Module): def setup(self): config = self.config + + self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) + self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads - + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and" + f" `num_heads`: {self.num_heads})." + ) self.rotary_dim = config.rotary_dim dense = partial( @@ -157,11 +163,9 @@ def setup(self): self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() self.out_proj = dense() - self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) - self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") - pos_embd_dim = self.rotary_dim or self.embed_dim + pos_embd_dim = self.rotary_dim or self.head_dim self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim) def _split_heads(self, hidden_states): @@ -212,16 +216,21 @@ def __call__( init_cache: bool = False, output_attentions: bool = False, ): - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query) - key = self._split_heads(key) - value = self._split_heads(value) + if position_ids is None: + raise ValueError("position_ids must be given") + if len(position_ids.shape) == 1: + position_ids = position_ids.reshape(1, -1) + query = self._split_heads(self.q_proj(hidden_states)) + key = self._split_heads(self.k_proj(hidden_states)) + value = self._split_heads(self.v_proj(hidden_states)) + # query, key, value: (B, T, n_head, head_dim) sincos = jnp.take(self.embed_positions, position_ids, axis=0) - sincos = jnp.split(sincos, 2, axis=-1) + sin, cos = jnp.split(sincos, 2, axis=-1) + sin = sin[:, :, None, :] + cos = cos[:, :, None, :] + # cos, sin: (B, T, 1, rotary_dim // 2) or (1, T, 1, rotary_dim // 2) + if self.rotary_dim is not None: k_rot = key[:, :, :, : self.rotary_dim] k_pass = key[:, :, :, self.rotary_dim :] @@ -229,14 +238,14 @@ def __call__( q_rot = query[:, :, :, : self.rotary_dim] q_pass = query[:, :, :, self.rotary_dim :] - k_rot = apply_rotary_pos_emb(k_rot, sincos) - q_rot = 
apply_rotary_pos_emb(q_rot, sincos) + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) key = jnp.concatenate([k_rot, k_pass], axis=-1) query = jnp.concatenate([q_rot, q_pass], axis=-1) else: - key = apply_rotary_pos_emb(key, sincos) - query = apply_rotary_pos_emb(query, sincos) + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) query_length, key_length = query.shape[1], key.shape[1] diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 00749b7eb07fbc..9c9c37d2ff3d8b 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -61,24 +61,25 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim)) sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float() - return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) + sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + out = torch.cat((sin, cos), dim=1) + return out @torch.fx.wrap def get_embed_positions(embed_positions, position_ids): - return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1) + return embed_positions.to(position_ids.device).unsqueeze(0).expand(position_ids.shape[0], -1, -1) def rotate_every_two(x: torch.Tensor) -> torch.Tensor: x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + return torch.concat((-x2, x1), dim=-1) def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor: - sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3) - cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3) + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) return (tensor * cos) + (rotate_every_two(tensor) * sin) @@ -86,7 +87,6 @@ class GPTJAttention(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() self.config = config - max_positions = config.max_position_embeddings self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) @@ -108,52 +108,51 @@ def __init__(self, config, layer_idx=None): f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and" f" `num_attention_heads`: {self.num_attention_heads})." 
) - self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + self.scale_attn = self.head_dim**0.5 + self.rotary_dim = config.rotary_dim + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) - self.rotary_dim = config.rotary_dim - pos_embd_dim = self.rotary_dim or self.embed_dim - self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim) - def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + pos_embd_dim = self.rotary_dim or self.head_dim + # `embed_positions` of shape `(max_positions, pos_embd_dim)` + self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim) + + def _split_heads(self, hidden_states: torch.Tensor, transpose: bool = True) -> torch.Tensor: """ Splits hidden dim into attn_head_size and num_attention_heads """ - new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - tensor = tensor.view(new_shape) - if rotary: - return tensor - if len(tensor.shape) == 5: - return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) - elif len(tensor.shape) == 4: - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.head_dim) + hidden_states = hidden_states.view(new_shape) + if not (4 <= hidden_states.dim() <= 5): + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {hidden_states.dim()}") + if transpose: + # Shape is (batch, blocks, head, block_length, head_features) or + # (batch, head, seq_length, head_features) + return hidden_states.transpose(-2, -3) else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + return hidden_states - def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + def _merge_heads(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Merges attn_head_size dim and num_attn_heads dim into hidden dim """ - if len(tensor.shape) == 5: - tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() - elif len(tensor.shape) == 4: - tensor = tensor.permute(0, 2, 1, 3).contiguous() - else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") - new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) - return tensor.view(new_shape) + if not (4 <= hidden_states.dim() <= 5): + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {hidden_states.dim()}") + hidden_states = hidden_states.transpose(-2, -3).contiguous() + new_shape = hidden_states.size()[:-2] + (self.num_attention_heads * self.head_dim,) + return hidden_states.view(new_shape) def _attn( self, - query, - key, - value, - attention_mask=None, - head_mask=None, - ): + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: # Keep the attention weights computation in fp32 to avoid overflow issues query = query.to(torch.float32) key = key.to(torch.float32) @@ -178,11 +177,22 @@ def _attn( return attn_output, attn_weights def _get_embed_positions(self, position_ids): + """ + 
This method does not subselect according to `position_ids`, it only + deals with device and shape. + + Args: + position_ids: Position indices + + Returns: + `embed_positions`, with device and shape according to `position_ids` + + """ embed_positions = self.embed_positions if embed_positions.device != position_ids.device: embed_positions = embed_positions.to(position_ids.device) self.embed_positions = embed_positions - return embed_positions.repeat(position_ids.shape[0], 1, 1) + return embed_positions.unsqueeze(0).expand(position_ids.shape[0], -1, -1) def forward( self, @@ -198,13 +208,14 @@ def forward( Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) - key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) - value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + if position_ids is None: + raise ValueError("position_ids must be given") + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + query = self._split_heads(self.q_proj(hidden_states)) + key = self._split_heads(self.k_proj(hidden_states)) + value = self._split_heads(self.v_proj(hidden_states)) + # query, key, value: (B, n_head, T, head_dim) if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): # The logic to conditionally copy to GPU could not be traced, so we do this @@ -213,16 +224,19 @@ def forward( else: embed_positions = self._get_embed_positions(position_ids) - repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + repeated_position_ids = position_ids.unsqueeze(-1).expand(-1, -1, embed_positions.shape[-1]) sincos = torch.gather(embed_positions, 1, repeated_position_ids) sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + sin = sin.unsqueeze(1) + cos = cos.unsqueeze(1) + # cos, sin: (B, 1, T, rotary_dim // 2) or (1, 1, T, rotary_dim // 2) if self.rotary_dim is not None: - k_rot = key[:, :, :, : self.rotary_dim] - k_pass = key[:, :, :, self.rotary_dim :] + k_rot = key[..., : self.rotary_dim] + k_pass = key[..., self.rotary_dim :] - q_rot = query[:, :, :, : self.rotary_dim] - q_pass = query[:, :, :, self.rotary_dim :] + q_rot = query[..., : self.rotary_dim] + q_pass = query[..., self.rotary_dim :] k_rot = apply_rotary_pos_emb(k_rot, sin, cos) q_rot = apply_rotary_pos_emb(q_rot, sin, cos) @@ -233,9 +247,6 @@ def forward( key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - key = key.permute(0, 2, 1, 3) - query = query.permute(0, 2, 1, 3) - if layer_past is not None: cache_kwargs = { "sin": sin, @@ -248,7 +259,7 @@ def forward( # compute self-attention: V x Softmax(QK^T) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -288,13 +299,14 @@ def forward( Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query, self.num_attention_heads, 
self.head_dim, True) - key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) - value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + if position_ids is None: + raise ValueError("position_ids must be given") + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + query = self._split_heads(self.q_proj(hidden_states), transpose=False) + key = self._split_heads(self.k_proj(hidden_states), transpose=False) + value = self._split_heads(self.v_proj(hidden_states), transpose=False) + # query, key, value: (B, T, n_head, head_dim) if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): # The logic to conditionally copy to GPU could not be traced, so we do this @@ -306,13 +318,16 @@ def forward( repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) sincos = torch.gather(embed_positions, 1, repeated_position_ids) sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + sin = sin.unsqueeze(2) + cos = cos.unsqueeze(2) + # cos, sin: (B, T, 1, rotary_dim // 2) if self.rotary_dim is not None: - k_rot = key[:, :, :, : self.rotary_dim] - k_pass = key[:, :, :, self.rotary_dim :] + k_rot = key[..., : self.rotary_dim] + k_pass = key[..., self.rotary_dim :] - q_rot = query[:, :, :, : self.rotary_dim] - q_pass = query[:, :, :, self.rotary_dim :] + q_rot = query[..., : self.rotary_dim] + q_pass = query[..., self.rotary_dim :] k_rot = apply_rotary_pos_emb(k_rot, sin, cos) q_rot = apply_rotary_pos_emb(q_rot, sin, cos) @@ -323,13 +338,6 @@ def forward( key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - # tanspose to have the desired shape - # before transpose: batch_size x seq_length x num_attention_heads x head_dim - # after transpose: batch_size x num_attention_heads x seq_length x head_dim - key = key.permute(0, 2, 1, 3) - query = query.permute(0, 2, 1, 3) - # value: batch_size x num_attention_heads x seq_length x head_dim - if layer_past is not None: cache_kwargs = { "sin": sin, @@ -337,15 +345,12 @@ def forward( "partial_rotation_size": self.rotary_dim, "cache_position": cache_position, } + # layer_past expects (B, num_heads, T, head_dim) shape + key = key.transpose(1, 2) + value = value.transpose(1, 2) key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) - - # The Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we need to keep the original shape for query and key, and reshape value - # to have the correct shape. - key = key.permute(0, 2, 1, 3).contiguous() - query = query.permute(0, 2, 1, 3).contiguous() - value = value.permute(0, 2, 1, 3).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. 
Hence, we need @@ -469,7 +474,6 @@ def forward( outputs = (hidden_states,) + outputs else: outputs = (hidden_states,) + outputs[1:] - return outputs # hidden_states, present, (attentions) diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index a931287adfcd01..44a8e26805636b 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -65,22 +65,25 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor: def rotate_every_two(x: tf.Tensor) -> tf.Tensor: - rotate_half_tensor = tf.stack((-x[:, :, :, 1::2], x[:, :, :, ::2]), axis=-1) - new_shape = shape_list(rotate_half_tensor)[:-2] + [tf.math.reduce_prod(shape_list(rotate_half_tensor)[-2:])] - rotate_half_tensor = tf.reshape(rotate_half_tensor, new_shape) - return rotate_half_tensor + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + return tf.concat((-x2, x1), axis=-1) -def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor: - sin_pos, cos_pos = sincos - sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3) - cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3) - return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos) +def apply_rotary_pos_emb(tensor: tf.Tensor, sin: tf.Tensor, cos: tf.Tensor) -> tf.Tensor: + sin = tf.repeat(sin, 2, axis=-1) + cos = tf.repeat(cos, 2, axis=-1) + return (tensor * cos) + (rotate_every_two(tensor) * sin) class TFGPTJAttention(keras.layers.Layer): def __init__(self, config: GPTJConfig, **kwargs): super().__init__(**kwargs) + self.config = config + max_positions = config.max_position_embeddings + + self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) + self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) self.embed_dim = config.hidden_size self.num_attention_heads = config.num_attention_heads @@ -93,9 +96,6 @@ def __init__(self, config: GPTJConfig, **kwargs): self.scale_attn = self.head_dim**0.5 self.rotary_dim = config.rotary_dim - self.attn_dropout = keras.layers.Dropout(config.attn_pdrop) - self.resid_dropout = keras.layers.Dropout(config.resid_pdrop) - self.q_proj = keras.layers.Dense( self.embed_dim, use_bias=False, @@ -121,13 +121,12 @@ def __init__(self, config: GPTJConfig, **kwargs): name="out_proj", ) - self.max_positions = config.max_position_embeddings self.lower_triangle_mask = tf.reshape( - tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8), - (1, 1, self.max_positions, self.max_positions), + tf.cast(tf.experimental.numpy.tril(tf.ones((max_positions, max_positions))), tf.int8), + (1, 1, max_positions, max_positions), ) - pos_embd_dim = self.rotary_dim or self.embed_dim - self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim) + pos_embd_dim = self.rotary_dim or self.head_dim + self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim) def get_causal_mask(self, key_length, query_length) -> tf.Tensor: return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool) @@ -136,27 +135,28 @@ def get_causal_mask(self, key_length, query_length) -> tf.Tensor: def get_masked_bias(dtype: tf.DType) -> tf.Tensor: return tf.cast(tf.constant(-1e9), dtype) - def _split_heads(self, hidden_states: tf.Tensor, rotary: bool) -> tf.Tensor: + def _split_heads(self, hidden_states: tf.Tensor) -> tf.Tensor: """ Splits hidden dim into attn_head_size and num_attention_heads """ new_shape = shape_list(hidden_states)[:-1] + 
[self.num_attention_heads, self.head_dim] hidden_states = tf.reshape(hidden_states, new_shape) - if rotary: - return hidden_states - if len(shape_list(hidden_states)) == 4: + num_dims = len(shape_list(hidden_states)) + if num_dims == 4: return tf.transpose(hidden_states, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) - if len(shape_list(hidden_states)) == 5: + elif num_dims == 5: return tf.transpose(hidden_states, (0, 1, 3, 2, 4)) # (batch, blocks, head, block_length, head_features) - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}") + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}") def _merge_heads(self, hidden_states: tf.Tensor) -> tf.Tensor: """ Merges attn_head_size dim and num_attn_heads dim into hidden dim """ - if len(shape_list(hidden_states)) == 4: + num_dims = len(shape_list(hidden_states)) + if num_dims == 4: hidden_states = tf.transpose(hidden_states, (0, 2, 1, 3)) - elif len(shape_list(hidden_states)) == 5: + elif num_dims == 5: hidden_states = tf.transpose(hidden_states, (0, 1, 3, 2, 4)) else: raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}") @@ -171,17 +171,16 @@ def _attn( attention_mask: tf.Tensor | None = None, head_mask: tf.Tensor | None = None, ) -> Tuple[tf.Tensor, tf.Tensor]: - # compute causal mask from causal mask buffer - query_length, key_length = shape_list(query)[-2], shape_list(key)[-2] - causal_mask = self.get_causal_mask(key_length, query_length) - # Keep the attention weights computation in fp32 to avoid overflow issues query = tf.cast(query, tf.float32) key = tf.cast(key, tf.float32) + # compute causal mask from causal mask buffer + query_length, key_length = shape_list(query)[-2], shape_list(key)[-2] + causal_mask = self.get_causal_mask(key_length, query_length) + attn_weights = tf.matmul(query, key, transpose_b=True) attn_weights = tf.where(causal_mask, attn_weights, self.get_masked_bias(attn_weights.dtype)) - attn_weights = attn_weights / self.scale_attn if attention_mask is not None: @@ -210,16 +209,21 @@ def call( use_cache: bool = False, output_attentions: bool = False, ): - query = self.q_proj(hidden_states) - key = self.k_proj(hidden_states) - value = self.v_proj(hidden_states) - - query = self._split_heads(query, True) - key = self._split_heads(key, True) - value = self._split_heads(value, False) + if position_ids is None: + raise ValueError("position_ids must be given") + if len(position_ids.shape) == 1: + position_ids = position_ids.reshape(1, -1) + query = self._split_heads(self.q_proj(hidden_states)) + key = self._split_heads(self.k_proj(hidden_states)) + value = self._split_heads(self.v_proj(hidden_states)) + # query, key, value: (B, n_head, T, head_dim) sincos = tf.cast(tf.gather(self.embed_positions, position_ids, axis=0), hidden_states.dtype) - sincos = tf.split(sincos, 2, axis=-1) + sin, cos = tf.split(sincos, 2, axis=-1) + sin = sin[:, None, :, :] + cos = cos[:, None, :, :] + # cos, sin: (B, 1, T, rotary_dim // 2) or (1, 1, T, rotary_dim // 2) + if self.rotary_dim is not None: k_rot = key[:, :, :, : self.rotary_dim] k_pass = key[:, :, :, self.rotary_dim :] @@ -227,17 +231,14 @@ def call( q_rot = query[:, :, :, : self.rotary_dim] q_pass = query[:, :, :, self.rotary_dim :] - k_rot = apply_rotary_pos_emb(k_rot, sincos) - q_rot = apply_rotary_pos_emb(q_rot, sincos) + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) 
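The Flax, PyTorch and TF GPT-J helpers in these hunks now share one convention: `rotate_every_two` returns the de-interleaved halves concatenated (negated odd half first), and `apply_rotary_pos_emb` only repeats each `sin`/`cos` entry twice along the last axis, leaving the broadcasting head axis to the caller. A sketch of what the rewritten helpers compute, on toy tensors:

```python
import torch

def rotate_every_two(x):
    # de-interleave even/odd feature pairs, then concatenate the negated odd half first
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(tensor, sin, cos):
    # the caller has already added the broadcasting axis; just expand each entry to a pair
    sin = torch.repeat_interleave(sin, 2, dim=-1)
    cos = torch.repeat_interleave(cos, 2, dim=-1)
    return (tensor * cos) + (rotate_every_two(tensor) * sin)

x = torch.arange(4.0).view(1, 1, 1, 4)      # [x0, x1, x2, x3]
print(rotate_every_two(x))                  # tensor([[[[-1., -3.,  0.,  2.]]]])

sin = torch.zeros(1, 1, 1, 2)               # (B, 1, T, rotary_dim // 2), pre-unsqueezed by the caller
cos = torch.ones(1, 1, 1, 2)
print(apply_rotary_pos_emb(x, sin, cos))    # with sin=0, cos=1 the input passes through unchanged
```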
key = tf.concat((k_rot, k_pass), axis=-1) query = tf.concat((q_rot, q_pass), axis=-1) else: - key = apply_rotary_pos_emb(key, sincos) - query = apply_rotary_pos_emb(query, sincos) - - key = tf.transpose(key, (0, 2, 1, 3)) - query = tf.transpose(query, (0, 2, 1, 3)) + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) if layer_past is not None: past_key = layer_past[0] diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 7e758947b6dd8a..7bc3f8794dd29a 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -54,7 +54,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -62,8 +62,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -154,6 +152,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -245,13 +245,13 @@ def __init__(self, config: GraniteConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -283,13 +283,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -597,24 +597,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, 
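In the Llama-style helper used by Granite and the models below, the surviving `unsqueeze_dim` argument is what lets one `(cos, sin)` pair serve both memory layouts described in the docstring. A toy sketch of the broadcasting it controls (shapes are illustrative):

```python
import torch

batch, heads, seq, head_dim = 2, 4, 6, 8
cos = torch.randn(batch, seq, head_dim)

# q laid out as (batch, heads, seq, head_dim): insert the broadcast axis at position 1
q_bhsd = torch.randn(batch, heads, seq, head_dim)
print((q_bhsd * cos.unsqueeze(1)).shape)    # torch.Size([2, 4, 6, 8])

# q laid out as (batch, seq, heads, head_dim): insert the broadcast axis at position 2
q_bshd = torch.randn(batch, seq, heads, head_dim)
print((q_bshd * cos.unsqueeze(2)).shape)    # torch.Size([2, 6, 4, 8])
```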
output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 698280085f1852..36ba842689abc1 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -48,13 +48,13 @@ def __init__(self, config: GraniteConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -86,13 +86,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -187,24 +187,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 1c4c06bbc8d71e..34b94986544381 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -82,6 +82,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -231,7 +232,7 @@ def rotate_half(x): # Copied from transformers.models.granite.modeling_granite.apply_rotary_pos_emb with Granite->GraniteMoe -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -239,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. 
unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -457,15 +456,17 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -535,14 +536,16 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -636,15 +639,17 @@ class GraniteMoeSdpaAttention(GraniteMoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( @@ -653,13 +658,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -739,6 +744,7 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -746,7 +752,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -781,13 +786,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) @@ -1058,6 +1063,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -1065,11 +1071,11 @@ def forward( use_cache, cache_position, output_router_logits, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1077,7 +1083,6 @@ def forward( use_cache=use_cache, cache_position=cache_position, output_router_logits=output_router_logits, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index ae7470d789b27e..0a7fc071b2e7c5 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -112,6 +112,7 @@ def load_balancing_loss_func( if router_logits is None or not isinstance(router_logits, tuple): return 0 + compute_device = None if isinstance(router_logits, tuple): compute_device = router_logits[0].device concatenated_router_logits = torch.cat( diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index a2a86fd4c22f4a..461fee182a19dc 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -87,6 +87,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -459,7 +460,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -467,8 +468,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/llama/modeling_flax_llama.py b/src/transformers/models/llama/modeling_flax_llama.py index 26a2c2bb09a3d2..dee04225d0bee3 100644 --- a/src/transformers/models/llama/modeling_flax_llama.py +++ b/src/transformers/models/llama/modeling_flax_llama.py @@ -133,7 +133,8 @@ def create_sinusoidal_positions(num_pos, dim): freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") emb = np.concatenate((freqs, freqs), axis=-1) - out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) + emb = emb[:, None, :] + out = np.concatenate((np.sin(emb), np.cos(emb)), axis=-1) return jnp.array(out[:, :, :num_pos]) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index df46e15bce0009..9c7198c0f7b5ba 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -151,7 +151,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -159,8 +159,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
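The Flax Llama change to `create_sinusoidal_positions` (and the identical Flax Mistral change further down) is a pure refactor: adding the singleton axis before the `sin`/`cos` concatenation produces the same array as adding it to each half afterwards. A quick NumPy check of that equivalence, on small toy sizes:

```python
import numpy as np

num_pos, dim = 16, 8
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
emb = np.concatenate((freqs, freqs), axis=-1)

# old form: unsqueeze each half separately
old = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)

# new form: unsqueeze once, then concatenate
emb = emb[:, None, :]
new = np.concatenate((np.sin(emb), np.cos(emb)), axis=-1)

print(old.shape, np.allclose(old, new))     # (16, 1, 16) True
```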
For example, note @@ -267,6 +265,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -322,13 +322,13 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -338,13 +338,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -583,24 +583,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 1440ce1e075c95..8ffc341804e6e8 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -438,7 +438,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -446,8 +446,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
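For the Llama-family hunks, the decoder stack computes the rotary table once per forward pass and threads it through every layer; with `position_embeddings` now the first argument after `hidden_states`, the checkpointed and non-checkpointed branches can use the same slot. A hedged sketch of that loop with simplified toy modules (`ToyRotary` and `ToyLayer` are stand-ins, not the actual `LlamaModel` code):

```python
import torch
import torch.nn as nn

class ToyRotary(nn.Module):
    # stand-in for the model-level rotary embedding returning (cos, sin)
    def __init__(self, head_dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, position_ids):
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

class ToyLayer(nn.Module):
    def forward(self, hidden_states, position_embeddings, attention_mask=None):
        cos, sin = position_embeddings                       # required, no BC fallback
        return hidden_states + 0.0 * (cos.sum() + sin.sum())  # placeholder for attention + MLP

layers = nn.ModuleList([ToyLayer() for _ in range(2)])
rotary = ToyRotary(head_dim=8)

hidden_states = torch.randn(2, 5, 32)
position_ids = torch.arange(5).unsqueeze(0).expand(2, -1)
position_embeddings = rotary(hidden_states, position_ids)    # computed once, reused by all layers

for layer in layers:
    # positional slot right after hidden_states, matching the checkpointed call sites
    hidden_states = layer(hidden_states, position_embeddings)
print(hidden_states.shape)                                   # torch.Size([2, 5, 32])
```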
For example, note diff --git a/src/transformers/models/mistral/modeling_flax_mistral.py b/src/transformers/models/mistral/modeling_flax_mistral.py index 3bff2a6281220e..b65d3803a29947 100644 --- a/src/transformers/models/mistral/modeling_flax_mistral.py +++ b/src/transformers/models/mistral/modeling_flax_mistral.py @@ -203,7 +203,8 @@ def create_sinusoidal_positions(num_pos, dim): freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") emb = np.concatenate((freqs, freqs), axis=-1) - out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) + emb = emb[:, None, :] + out = np.concatenate((np.sin(emb), np.cos(emb)), axis=-1) return jnp.array(out[:, :, :num_pos]) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 90c38895b4280b..e9fc0f5698e424 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -64,7 +64,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -72,8 +72,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note @@ -155,6 +153,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -229,13 +229,13 @@ def __init__(self, config: MistralConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -245,13 +245,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -555,24 +555,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 362233a21b70f4..c85ad9e730dc04 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -56,6 +56,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 84ed327d9be920..3183a3c3f2e857 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -177,7 +177,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -185,8 +185,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. 
cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -268,6 +266,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -324,6 +324,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -331,7 +332,6 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -683,6 +683,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -690,11 +691,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -702,7 +703,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) @@ -915,6 +915,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index a6069f69b33421..52cf2876852057 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -88,6 +88,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -246,6 +247,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -253,7 +255,6 @@ def forward( 
output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -392,6 +393,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -399,11 +401,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -411,7 +413,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 6523ab6812179c..38c1b0792520f7 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -617,7 +617,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -625,8 +625,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 237fba6f645fa5..f2b34c1c752818 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -269,7 +269,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -277,8 +277,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index f0281f57cf1c75..424a7b94d100fe 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -381,7 +381,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -389,8 +389,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note diff --git a/src/transformers/models/nemotron/configuration_nemotron.py b/src/transformers/models/nemotron/configuration_nemotron.py index 7690703127ac92..0720ede2cdb74d 100644 --- a/src/transformers/models/nemotron/configuration_nemotron.py +++ b/src/transformers/models/nemotron/configuration_nemotron.py @@ -76,7 +76,8 @@ class NemotronConfig(PretrainedConfig): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. - partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding. + partial_rotary_factor (`float`, *optional*, defaults to 1.0): + Percentage of the query and keys which will have rotary embedding. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. 
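For reference (illustrative numbers, not part of the patch), this is how `partial_rotary_factor` maps to the number of channels that actually receive rotary embeddings, mirroring the `rotary_ndims` split introduced for Nemotron just below:

import torch

hidden_size, num_attention_heads = 4096, 32
head_dim = hidden_size // num_attention_heads         # 128
partial_rotary_factor = 0.5                           # previously documented default; the patch documents 1.0
rotary_ndims = int(head_dim * partial_rotary_factor)  # 64 channels get RoPE, 64 pass through unchanged

q = torch.randn(2, num_attention_heads, 10, head_dim)
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
# apply_rotary_pos_emb is applied to the rotary slice only, then the slices are re-joined:
q = torch.cat((q_rot, q_pass), dim=-1)
assert q.shape[-1] == head_dim                        # with factor 1.0 the pass-through slice is empty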
attention_dropout (`float`, *optional*, defaults to 0.0): @@ -119,7 +120,7 @@ def __init__( eos_token_id=3, tie_word_embeddings=False, rope_theta=10000.0, - partial_rotary_factor=0.5, + partial_rotary_factor=1.0, attention_bias=False, attention_dropout=0.0, mlp_bias=False, diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 75618f1c7e00c7..033a9dd37bb6c4 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -237,7 +237,8 @@ def __init__(self, config: NemotronConfig, layer_idx: Optional[int] = None): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta - self.partial_rotary_factor = config.partial_rotary_factor + head_size = config.hidden_size // config.num_attention_heads + self.rotary_ndims = int(head_size * config.partial_rotary_factor) self.is_causal = True self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) @@ -256,6 +257,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) is required") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -266,9 +269,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Compute rotary embeddings on rotary_ndims + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + key_pass = key_states[..., self.rotary_ndims :] + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_states = torch.cat((query_states, query_pass), dim=-1) + key_states = torch.cat((key_states, key_pass), dim=-1) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -330,6 +339,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) is required") if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " @@ -351,9 +362,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Compute rotary embeddings on rotary_ndims + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + 
key_pass = key_states[..., self.rotary_ndims :] + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_states = torch.cat((query_states, query_pass), dim=-1) + key_states = torch.cat((key_states, key_pass), dim=-1) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -438,6 +455,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) is required") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -446,13 +465,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -465,9 +484,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if position_embeddings is not None: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # Compute rotary embeddings on rotary_ndims + query_rot = query_states[..., : self.rotary_ndims] + query_pass = query_states[..., self.rotary_ndims :] + key_rot = key_states[..., : self.rotary_ndims] + key_pass = key_states[..., self.rotary_ndims :] + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_states = torch.cat((query_states, query_pass), dim=-1) + key_states = torch.cat((key_states, key_pass), dim=-1) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -533,13 +558,13 @@ def __init__(self, config: NemotronConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -571,13 +596,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -823,24 +848,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, 
use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 39bfa726deeedf..da27df015d414b 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -70,7 +70,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -78,8 +78,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -233,13 +231,13 @@ def __init__(self, config: OlmoConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -249,13 +247,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -559,24 +557,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 89b5f4abe1c39c..d3692d7de04b15 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -59,7 
+59,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -67,8 +67,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -235,13 +233,13 @@ def __init__(self, config: Olmo2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -249,13 +247,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -560,24 +558,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 5f119170804466..c5b74f41a8544c 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -234,13 +234,13 @@ def __init__(self, config: Olmo2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ 
-248,13 +248,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index fa3c2f3cd4d11b..44abe4f6d81e48 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -83,6 +83,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -231,7 +232,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -239,8 +240,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
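A brief aside on the repeated `compute_device = None` additions (a sketch with made-up shapes, not the library code): the explicit initialization mainly keeps the name defined for readers and static checkers, since the early return above it already filters out non-tuple inputs; the branch then moves every layer's gate logits onto the first layer's device before concatenation, which matters when a model is sharded across devices.

import torch

gate_logits = (
    torch.randn(4, 8),  # router logits from layer 0; its device becomes the reference
    torch.randn(4, 8),  # router logits from layer 1, possibly on another device
)
compute_device = None
if isinstance(gate_logits, tuple):
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([g.to(compute_device) for g in gate_logits], dim=0)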
For example, note @@ -330,15 +329,17 @@ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_norm(self.q_proj(hidden_states)) @@ -412,15 +413,17 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -513,14 +516,16 @@ class OlmoeSdpaAttention(OlmoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( @@ -529,13 +534,13 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -666,6 +671,7 @@ def __init__(self, config: OlmoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -673,7 +679,6 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -708,13 +713,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -981,6 +986,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -988,11 +994,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1000,7 +1006,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 27712741b7c28f..8cd447a85136f4 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -130,7 +130,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -138,8 +138,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
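The contract enforced by these signatures is that `(cos, sin)` is computed once per forward pass and passed explicitly, since `position_embeddings` is now a required positional parameter. A toy mirror of that contract (the layer body is a placeholder and the names are illustrative, not the library's):

import torch

def toy_layer_forward(hidden_states, position_embeddings, attention_mask=None):
    if position_embeddings is None:
        raise ValueError("position_embeddings = (cos, sin) must be given")
    cos, sin = position_embeddings
    return hidden_states * cos + hidden_states * sin  # placeholder for RoPE + attention

seq_len, head_dim = 6, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()  # shape (seq_len, head_dim)

out = toy_layer_forward(torch.randn(1, seq_len, head_dim), (cos, sin))
# toy_layer_forward(torch.randn(1, seq_len, head_dim), None) now raises immediately
# instead of quietly proceeding without rotary embeddings.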
For example, note @@ -231,14 +229,16 @@ def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() # [batch_size, seq_length, 3 x hidden_size] @@ -326,13 +326,13 @@ def __init__(self, config: PersimmonConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -365,13 +365,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states @@ -626,24 +626,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 5aa038d3ccfaa8..337c0e5e8cf992 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -47,7 +47,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -55,8 +55,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. 
unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -147,6 +145,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -232,13 +232,13 @@ def __init__(self, config: PhiConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -248,13 +248,13 @@ def forward( # Self Attention attn_outputs, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) attn_outputs = self.resid_dropout(attn_outputs) @@ -557,24 +557,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 0faa4629f1a768..a4ed37f0596864 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -54,6 +54,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -128,13 +130,13 @@ def __init__(self, config: PhiConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC 
**kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -144,13 +146,13 @@ def forward( # Self Attention attn_outputs, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) attn_outputs = self.resid_dropout(attn_outputs) @@ -242,24 +244,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index 4940f43e5bffe3..f87f0b55940a95 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -15,6 +15,8 @@ """Phi-3 model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -204,7 +206,8 @@ def _rope_scaling_validation(self): raise ValueError( f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" ) - if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + required_size = int(math.ceil((self.hidden_size // self.num_attention_heads) / 2)) + if not len(rope_scaling_short_factor) == required_size: raise ValueError( f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" ) @@ -215,7 +218,7 @@ def _rope_scaling_validation(self): raise ValueError( f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" ) - if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + if not len(rope_scaling_long_factor) == required_size: raise ValueError( f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" ) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 908fd982b9c73c..13a1048fe94ff0 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -80,11 +80,11 @@ class Phi3RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base + self.dim = dim - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() @@ -241,7 +241,7 @@ def rotate_half(x): # Copied from 
transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -249,8 +249,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -386,7 +384,7 @@ def forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models @@ -486,7 +484,7 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models @@ -602,7 +600,7 @@ def forward( kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 8f6b092da6e6ad..9ffeafd4da30c8 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -91,6 +91,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -156,8 +157,11 @@ def __init__( else: self.rope_type = "default" self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + self.rotary_ndims = config.hidden_size // config.num_attention_heads - def forward(self, x, seq_len=None): + def forward(self, x, seq_len: int): + if seq_len is None: + raise ValueError("seq_len must be given") mscale = None if self.config.rope_scaling and seq_len: mscale = ( @@ -171,7 +175,9 @@ def forward(self, x, seq_len=None): freqs = torch.outer(t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - return (emb.cos() * mscale).to(x.dtype), (emb.sin() * mscale).to(x.dtype) + cos = emb.cos() + sin = emb.sin() + return (cos 
* mscale).to(x.dtype), (sin * mscale).to(x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -270,14 +276,16 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -338,14 +346,16 @@ class PhimoeFlashAttention2(PhimoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -432,14 +442,16 @@ class PhimoeSdpaAttention(PhimoeAttention): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( @@ -448,12 +460,12 @@ def forward( ) return super().forward( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -813,6 +825,7 @@ def __init__(self, config: PhimoeConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, @@ -820,7 +833,6 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -852,13 +864,13 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states @@ -1116,6 +1128,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -1123,11 +1136,11 @@ def forward( output_router_logits, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1135,7 +1148,6 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 03886d4a528478..fc8258c92093ed 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -127,7 +127,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -135,8 +135,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
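One reason the positional call sites are reordered as well: `self._gradient_checkpointing_func(decoder_layer.__call__, ...)` forwards its arguments positionally, so once `position_embeddings` becomes the second parameter of the layer, every checkpointed call has to list it second. A tiny sketch of that constraint (the wrapper below is a stand-in, not the actual checkpointing utility):

def checkpoint_like(fn, *args):
    # a checkpointing wrapper can only forward what it receives, positionally
    return fn(*args)

def layer(hidden_states, position_embeddings, attention_mask=None, position_ids=None):
    cos, sin = position_embeddings
    return hidden_states

out = checkpoint_like(layer, "hidden", ("cos", "sin"), None, None)  # order must match the signature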
For example, note @@ -175,11 +173,13 @@ def __init__(self, config): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") batch_size, patches, _ = hidden_states.size() @@ -261,8 +261,8 @@ def __init__(self, config): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: torch.Tensor, - position_embeddings: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ @@ -280,8 +280,8 @@ def forward( hidden_states = self.attention_norm(hidden_states) hidden_states, attn_weights = self.attention( hidden_states=hidden_states, - attention_mask=attention_mask, position_embeddings=position_embeddings, + attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states @@ -310,8 +310,8 @@ def __init__(self, config): def forward( self, inputs_embeds, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -353,15 +353,15 @@ def forward( layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, - attention_mask, position_embeddings, + attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, - attention_mask, position_embeddings=position_embeddings, + attention_mask=attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 5dba7594e7e9a1..8dbcb952feed62 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -64,7 +64,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -72,8 +72,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note @@ -155,6 +153,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -242,13 +242,13 @@ def __init__(self, config: Qwen2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -258,13 +258,13 @@ def forward( # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -568,24 +568,24 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, - position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings, **flash_attn_kwargs, ) diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 718abd01090c2b..ccad8d533bfd98 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -52,6 +52,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if position_embeddings is None: + raise ValueError("position_embeddings = (cos, sin) must be given") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 1ce41509a5c0d1..d0689b867e34fe 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -94,6 +94,7 @@ def load_balancing_loss_func( if gate_logits is None or not isinstance(gate_logits, tuple): return 0 + compute_device = None if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) @@ -240,7 +241,7 @@ def rotate_half(x): # Copied from 
transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -248,8 +249,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -360,7 +359,7 @@ def forward( value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index ef98ae5e3f508f..fb2102491aebf5 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -54,6 +54,8 @@ def __init__( self.temporal_patch_size = temporal_patch_size +# TODO: Add comment for `rope_scaling["mrope_section"]`. This parameter +# is mandatory, but it is unclear what it should be set to. class Qwen2VLConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 566141d3f75c27..92f33481410e27 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -217,17 +217,14 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. mrope_section(`List(int)`): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and + sin so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos and sin have the shape [batch_size, seq_len, head_dim]. 
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 566141d3f75c27..92f33481410e27 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -217,17 +217,14 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
         mrope_section(`List(int)`):
             Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and
+            sin so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos and sin have the shape [batch_size, seq_len, head_dim]. Then, if q and
             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            cos and sin broadcastable to the shapes of q and k. Similarly, if q and k have
             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
@@ -524,23 +521,19 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
-        self.rotary_emb = Qwen2VLRotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
-        )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if position_embeddings is None:
+            raise ValueError("position_embeddings = (cos, sin) must be given")
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -617,14 +610,16 @@ def __init__(self, *args, **kwargs):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
+        if position_embeddings is None:
+            raise ValueError("position_embeddings = (cos, sin) must be given")
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -719,14 +714,16 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if position_embeddings is None:
+            raise ValueError("position_embeddings = (cos, sin) must be given")
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
@@ -735,13 +732,13 @@ def forward(
             )
             return super().forward(
                 hidden_states=hidden_states,
+                position_embeddings=position_embeddings,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-                position_embeddings=position_embeddings,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -825,13 +822,13 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -863,13 +860,13 @@ def forward(
 
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
         )
         hidden_states = residual + hidden_states
@@ -1120,24 +1117,24 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
                     causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
+                    position_embeddings=position_embeddings,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = layer_outputs[0]
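The pattern above repeats for every attention backend: the per-layer rotary embedding modules are removed and `position_embeddings = (cos, sin)` becomes a required argument that the top-level model computes once and threads through every decoder layer. A minimal sketch of the resulting calling convention (function and variable names here are illustrative, not the exact transformers API):

def run_decoder(layers, rotary_emb, hidden_states, position_ids, causal_mask=None):
    # Compute the rotary (cos, sin) pair once per forward pass ...
    position_embeddings = rotary_emb(hidden_states, position_ids)
    for layer in layers:
        # ... and pass it explicitly to every layer; omitting it now raises
        # "position_embeddings = (cos, sin) must be given" inside the attention modules.
        hidden_states = layer(
            hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=causal_mask,
            position_ids=position_ids,
        )[0]
    return hidden_states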
"""RecurrentGemma model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -155,4 +157,6 @@ def __init__( @property def layers_block_type(self): - return (self.block_types * 100)[: self.num_hidden_layers] + len_bt = len(self.block_types) + sz = int(math.ceil(self.num_hidden_layers / len_bt)) + return (self.block_types * sz)[: self.num_hidden_layers] diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 74fc2085c36519..6da1fc69c026ad 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -71,9 +71,8 @@ def extra_repr(self): class RecurrentGemmaRotaryEmbedding(nn.Module): def __init__(self, dim, base=10000, device=None): super().__init__() - self.dim = dim self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() @@ -103,7 +102,7 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -111,8 +110,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 74fc2085c36519..6da1fc69c026ad 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -71,9 +71,8 @@ def extra_repr(self):
 class RecurrentGemmaRotaryEmbedding(nn.Module):
     def __init__(self, dim, base=10000, device=None):
         super().__init__()
-        self.dim = dim
         self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
 
     @torch.no_grad()
@@ -103,7 +102,7 @@ def rotate_half(x):
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -111,8 +110,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -155,14 +152,14 @@ def __init__(self, config: RecurrentGemmaConfig):
         self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
-        self.partial_rotary_factor = config.partial_rotary_factor
+        self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
 
         self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
         self.rotary_emb = RecurrentGemmaRotaryEmbedding(
-            int(self.partial_rotary_factor * self.head_dim),
+            self.rotary_ndims,
             base=config.rope_theta,
         )
 
@@ -187,9 +184,11 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
 
         # Partial rotary embedding
-        query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1)
-        key_rot, key_pass = torch.chunk(key_states, int(1 / self.partial_rotary_factor), dim=-1)
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+        query_rot = query_states[..., : self.rotary_ndims]
+        query_pass = query_states[..., self.rotary_ndims :]
+        key_rot = key_states[..., : self.rotary_ndims]
+        key_pass = key_states[..., self.rotary_ndims :]
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)
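The old `torch.chunk`-based split only really supported `partial_rotary_factor == 0.5`, since it unpacks exactly two chunks; slicing by `rotary_ndims` reproduces that case and also handles factors that do not split the head dimension in half. A small standalone check of the equivalence (illustrative only):

import torch

def split_partial_rotary(x, rotary_ndims):
    # Rotate only the first `rotary_ndims` channels; pass the rest through unchanged.
    return x[..., :rotary_ndims], x[..., rotary_ndims:]

# head_dim = 8, partial_rotary_factor = 0.5 -> rotary_ndims = 4
x = torch.arange(8.0).view(1, 1, 1, 8)
rot, passthrough = split_partial_rotary(x, rotary_ndims=4)
chunk_rot, chunk_pass = torch.chunk(x, 2, dim=-1)  # the old 0.5-only behaviour
assert torch.equal(rot, chunk_rot) and torch.equal(passthrough, chunk_pass)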
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 7214a36e9a3921..406e2f881e5ac9 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -136,7 +136,7 @@ def rotate_half(x):
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -144,8 +144,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -227,7 +225,6 @@ def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None):
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.rope_theta = config.rope_theta
         self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
         self.is_causal = True
 
@@ -249,19 +246,20 @@ def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None):
         )
         self.attention_dropout = nn.Dropout(config.attention_dropout)
 
-        self.rotary_emb = StableLmRotaryEmbedding(config=self.config)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if position_embeddings is None:
+            raise ValueError("position_embeddings = (cos, sin) must be given")
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -341,14 +339,16 @@ class StableLmSdpaAttention(StableLmAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if position_embeddings is None:
+            raise ValueError("position_embeddings = (cos, sin) must be given")
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
@@ -357,13 +357,13 @@ def forward(
             )
             return super().forward(
                 hidden_states=hidden_states,
+                position_embeddings=position_embeddings,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
-                position_embeddings=position_embeddings,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -463,13 +463,13 @@ def __init__(self, *args, **kwargs):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # StableLmFlashAttention2 attention does not support output_attentions
@@ -571,13 +571,13 @@ def __init__(self, config: StableLmConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -611,13 +611,13 @@ def forward(
         # Self Attention
         self_attn_output, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
         )
 
         # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
@@ -881,24 +881,24 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
                     causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
+                    position_embeddings=position_embeddings,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = layer_outputs[0]
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 3b4fdbcb81ccc4..814eeabad4fa85 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -83,7 +83,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -91,8 +91,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -233,13 +231,13 @@ def __init__(self, config: Starcoder2Config, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -249,13 +247,13 @@ def forward(
         # Self Attention
         hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -561,13 +559,13 @@ def forward(
 
                 layer_outputs = decoder_layer(
                     hidden_states,
+                    position_embeddings=position_embeddings,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings,
                     **flash_attn_kwargs,
                 )
 
diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py
index 32d64cd167ba50..61fc3271d81f5a 100644
--- a/src/transformers/models/starcoder2/modular_starcoder2.py
+++ b/src/transformers/models/starcoder2/modular_starcoder2.py
@@ -223,13 +223,13 @@ def forward(
 
                 layer_outputs = decoder_layer(
                     hidden_states,
+                    position_embeddings=position_embeddings,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings,
                     **flash_attn_kwargs,
                 )
 
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 929bbb13a56e80..e74cb043d06154 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2651,6 +2651,9 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
 
         tf_model_class = getattr(transformers, tf_model_class_name)
 
+        if hasattr(config, "debug_output"):
+            config.debug_output = True
+
         pt_model = model_class(config).eval()
         tf_model = tf_model_class(config)
 
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 116e26e7834f26..c6c21a13f2795c 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -199,6 +199,14 @@
         "giou_cost",
         "giou_loss_coefficient",
     ],
+    "GPTNeoXConfig": [
+        "rotary_emb_base",  # Doubles for rope_theta
+        "rotary_pct",  # Doubles for partial_rotary_factor
+    ],
+    "GPTNeoXJapaneseConfig": [
+        "rotary_emb_base",  # Doubles for rope_theta
+        "rotary_pct",  # Doubles for partial_rotary_factor
+    ],
 }