Merge branch 'cherry-pick-0d7ebc39' into 'main'
Merge branch 'hongbinl/perf_fix' into '23.08'

See merge request ADLR/megatron-lm!838
jaredcasper committed Oct 12, 2023
2 parents 37bd99a + 993aa0f commit 96ebacc
Showing 1 changed file with 7 additions and 6 deletions.
megatron/core/transformer/attention.py (7 additions, 6 deletions)

@@ -244,12 +244,13 @@ def forward(
         # This is a noop for normal attention where ng == np. When using group query attention this
         # creates a view that has the keys and values virtually repeated along their dimension to
         # match the number of queries.
-        key = key.repeat_interleave(
-            self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
-        )
-        value = value.repeat_interleave(
-            self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
-        )
+        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
+            key = key.repeat_interleave(
+                self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
+            )
+            value = value.repeat_interleave(
+                self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
+            )
 
         if self.checkpoint_dot_product_attention:
             core_attn_out = self._checkpointed_attention_forward(query, key, value, attention_mask)
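For context, here is a minimal PyTorch sketch (not from the commit; shapes and numbers are illustrative assumptions) of why the guard helps: torch.repeat_interleave always materializes a new tensor, so under normal attention, where the head/group ratio is 1, the old code paid for a full copy of the keys and values that changed nothing. The guard skips that copy and only expands the key/value heads when group query attention actually needs it.

import torch

# Illustrative shapes (assumed, not from the commit):
# [sq, b, heads, hn] layout with heads on dim 2, as in the diff.
sq, b, ng, np_heads, hn = 4, 2, 2, 8, 16

key = torch.randn(sq, b, ng, hn)

# Group query attention: expand the ng key/value heads to match the
# np query heads by repeating each group along the head dimension.
ratio = np_heads // ng  # 4 in this sketch
expanded = key.repeat_interleave(ratio, dim=2)
assert expanded.shape == (sq, b, np_heads, hn)

# Normal attention (ng == np, ratio == 1): the result is numerically
# identical, but repeat_interleave still allocates a brand-new tensor.
same = key.repeat_interleave(1, dim=2)
assert torch.equal(same, key)              # same contents...
assert same.data_ptr() != key.data_ptr()   # ...different storage

# Hence the commit's guard: only call repeat_interleave when ratio > 1.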
