From dd74ea0b9a40b4dd5c8eacf8306bc0d63c94e54c Mon Sep 17 00:00:00 2001
From: seaofocean
Date: Mon, 16 Oct 2023 03:27:17 +0000
Subject: [PATCH] Remove unnecessary repeat_interleave to fix performance drop

---
 megatron/model/transformer.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 71337c818f..fd76edcedd 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -753,14 +753,15 @@ def forward(self, hidden_states, attention_mask,
         # ==================================
 
         # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
-        key_layer = key_layer.repeat_interleave(
-            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
-            dim = 2
-        )
-        value_layer = value_layer.repeat_interleave(
-            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
-            dim = 2
-        )
+        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
+            key_layer = key_layer.repeat_interleave(
+                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+                dim = 2
+            )
+            value_layer = value_layer.repeat_interleave(
+                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
+                dim = 2
+            )
 
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
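
A minimal sketch (outside the patch, with assumed example shapes) of why the guard helps: when num_attention_heads_per_partition equals num_query_groups_per_partition (plain multi-head attention rather than grouped-query attention), the repeat factor is 1, yet repeat_interleave still allocates a new tensor and copies every element on each forward pass. Guarding on a factor greater than 1 skips that copy.

import torch

# Assumed example values; the real ones come from the model configuration.
sk, b, ng, hn = 2048, 4, 32, 128     # sequence length, batch, query groups, head dim
heads_per_partition = 32             # equals ng here, i.e. no GQA expansion is needed
factor = heads_per_partition // ng   # == 1 in this case

key_layer = torch.randn(sk, b, ng, hn)

# Even with a repeat factor of 1, repeat_interleave returns a fresh copy:
expanded = key_layer.repeat_interleave(factor, dim=2)
assert expanded.data_ptr() != key_layer.data_ptr()

# The patched code only expands when there is actually something to expand:
if factor > 1:
    key_layer = key_layer.repeat_interleave(factor, dim=2)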