Merge branch 'repeat' into 'main'
Remove unnecessary repeat_interleave to fix performance drop

See merge request ADLR/megatron-lm!852
jaredcasper committed Oct 16, 2023
2 parents bc1c67d + dd74ea0 commit 954a65b
Showing 1 changed file with 9 additions and 8 deletions.

megatron/model/transformer.py: 9 additions & 8 deletions
@@ -753,14 +753,15 @@ def forward(self, hidden_states, attention_mask,
         # ==================================
 
         # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
-        key_layer = key_layer.repeat_interleave(
-            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
-            dim = 2
-        )
-        value_layer = value_layer.repeat_interleave(
-            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
-            dim = 2
-        )
+        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
+            key_layer = key_layer.repeat_interleave(
+                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+                dim = 2
+            )
+            value_layer = value_layer.repeat_interleave(
+                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+                dim = 2
+            )
 
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
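Why the guard helps, in a minimal PyTorch sketch (illustrative, not the Megatron-LM code; shapes follow the [sk, b, ng, hn] comment in the diff, and the variable names here are made up): torch.Tensor.repeat_interleave always allocates and fills a new tensor, even when the repeat count is 1, so for plain multi-head attention (where the number of query groups equals the number of attention heads) the old code paid for a full copy of the key and value tensors on every forward pass. The new branch skips that no-op copy.

import torch

# Illustrative sizes: sequence length, batch, KV groups (ng), query heads (np), head dim.
sk, b, ng, num_heads, hn = 16, 2, 4, 8, 64
key_layer = torch.randn(sk, b, ng, hn)
value_layer = torch.randn(sk, b, ng, hn)

ratio = num_heads // ng  # query heads per KV group; 1 for plain multi-head attention
if ratio > 1:
    # repeat_interleave returns a new tensor even for a repeat count of 1,
    # so guarding on ratio > 1 avoids a pointless full copy when ng == np.
    key_layer = key_layer.repeat_interleave(ratio, dim=2)
    value_layer = value_layer.repeat_interleave(ratio, dim=2)

assert key_layer.shape == (sk, b, num_heads, hn)  # [sk, b, ng, hn] -> [sk, b, np, hn]

With ratio == 1 the branch is skipped and key_layer/value_layer are used as-is; grouped-query and multi-query models still get the expansion they need before the attention scores are computed.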
