optimize phi3 memory usage again (#11848)
MeouSker77 authored Aug 19, 2024
1 parent 3cd4e87 commit 9490781
Showing 1 changed file with 4 additions and 4 deletions.
python/llm/src/ipex_llm/transformers/models/phi3.py (8 changes: 4 additions & 4 deletions)
@@ -177,12 +177,12 @@ def attention_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
 
+    # use in-place div, add and softmax to avoid double attn_weights memory usage
+    attn_weights.div_(math.sqrt(self.head_dim))
     if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
-
+        attn_weights.add_(attention_mask)
     attn_weights = attention_softmax(attn_weights, self.training)
 
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
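For context, here is a minimal sketch (not part of the commit; shapes are made up for illustration) of why the in-place div_/add_ calls matter: each out-of-place operation on the (batch, num_heads, seq_len, seq_len) attention-score tensor materializes a second tensor of the same size before the old one can be freed, briefly doubling peak memory, while the in-place variants reuse the existing buffer.

    import torch

    # Illustrative sizes only; the real shapes come from the model
    # config and the prompt length.
    batch, num_heads, seq_len, head_dim = 1, 32, 512, 96
    attn_weights = torch.randn(batch, num_heads, seq_len, seq_len)
    attention_mask = torch.zeros(batch, 1, seq_len, seq_len)

    # Out-of-place: allocates a fresh (batch, num_heads, seq_len, seq_len)
    # result tensor, so two score-sized buffers coexist briefly.
    scaled = attn_weights / (head_dim ** 0.5)

    # In-place: writes into the existing buffer, so peak memory stays at
    # one score-sized allocation (the mask broadcasts into attn_weights).
    attn_weights.div_(head_dim ** 0.5)
    attn_weights.add_(attention_mask)

The commit's comment suggests the same reasoning applies to the softmax step: attention_softmax is presumably a helper that applies softmax without allocating a second score-sized tensor, though its implementation is not shown in this diff.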
