
Commit

Used cache position instead to get the seq len
garg-amit committed Oct 4, 2024
1 parent 18830f5 commit 36891bf
Showing 1 changed file with 1 addition and 4 deletions.
src/transformers/models/phimoe/modeling_phimoe.py (1 addition, 4 deletions)

@@ -1130,10 +1130,7 @@ def forward(

         hidden_states = inputs_embeds

-        kv_seq_len = hidden_states.shape[-2]
-        if past_key_values is not None:
-            kv_seq_len += past_key_values.get_usable_length(kv_seq_len)
-        position_embeddings = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
+        position_embeddings = self.rotary_emb(hidden_states, seq_len=cache_position[-1] + 1)

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
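For context, a minimal sketch (not part of the commit) of why `cache_position[-1] + 1` recovers the same value as the deleted `kv_seq_len` computation: in the Transformers cache API, `cache_position` holds the absolute positions of only the tokens processed in the current forward pass, so the last entry plus one equals the total key/value length including the cache. The concrete numbers below are hypothetical.

    import torch

    past_len = 5    # tokens already in the KV cache (hypothetical)
    new_tokens = 3  # tokens in the current forward pass (hypothetical)

    # cache_position holds the absolute positions of only the *current* tokens.
    cache_position = torch.arange(past_len, past_len + new_tokens)  # tensor([5, 6, 7])

    # Old computation (deleted above): current length plus cached length.
    kv_seq_len = new_tokens + past_len

    # New computation: last absolute position + 1 yields the same total length.
    assert int(cache_position[-1]) + 1 == kv_seq_len  # 8 == 8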
