diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py
index d637e2b448..108e6a5c1b 100644
--- a/megatron/core/transformer/multi_latent_attention.py
+++ b/megatron/core/transformer/multi_latent_attention.py
@@ -116,6 +116,7 @@ def forward(
         packed_seq_params=None,
         position_ids=None,
     ):
+        """Forward pass for multi-latent attention"""
         assert rotary_pos_emb is None, "Rotary position embeddings should not be passed into MLA."
 
         # hidden_states: [sq, b, h]
@@ -138,8 +139,8 @@ def forward(
         # Adjust key, value for inference
         # ===================================================
         # rotary_pos_emb = None
-        key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_params, key, value, rotary_pos_emb=None
+        query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
+            inference_params, query, key, value, rotary_pos_emb=None
         )
 
         # ==================================
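
For context, the change above threads `query` through the inference adjustment: `_adjust_key_value_for_inference` now takes and returns `query` in addition to `key` and `value`. The snippet below is a minimal sketch of the interface implied by this call site, assuming a pass-through when no inference state is present; it is not the actual Megatron-LM implementation, which is a method on the attention class and returns an `AttnMaskType` enum rather than a string.

```python
# Hypothetical stand-in illustrating the updated call signature seen in the diff.
from typing import Optional, Tuple

import torch


def _adjust_key_value_for_inference(
    inference_params,                      # opaque inference/KV-cache state (assumed)
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], str]:
    """Sketch only: returns query/key/value unchanged when no inference state is given."""
    attn_mask_type = "causal"              # placeholder; the real code uses an AttnMaskType enum
    if inference_params is None:
        return query, key, value, rotary_pos_emb, attn_mask_type
    # ... KV-cache update and slicing of query/key/value would happen here ...
    return query, key, value, rotary_pos_emb, attn_mask_type
```

A caller would then unpack five values, matching the new call site in the diff: `query, key, value, _, attn_mask_type = _adjust_key_value_for_inference(inference_params, query, key, value, rotary_pos_emb=None)`.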