From 369fec6c7cde4a4b7ab33b37ef313f87d399c614 Mon Sep 17 00:00:00 2001
From: Helen Ngo
Date: Thu, 7 Nov 2024 15:28:03 -0800
Subject: [PATCH] ADLR/megatron-lm!2294 - Fix signature for multi-latent attention KV cache update

---
 megatron/core/transformer/multi_latent_attention.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py
index d637e2b448..108e6a5c1b 100644
--- a/megatron/core/transformer/multi_latent_attention.py
+++ b/megatron/core/transformer/multi_latent_attention.py
@@ -116,6 +116,7 @@ def forward(
         packed_seq_params=None,
         position_ids=None,
     ):
+        """Forward pass for multi-latent attention"""
         assert rotary_pos_emb is None, "Rotary position embeddings should not be passed into MLA."
 
         # hidden_states: [sq, b, h]
@@ -138,8 +139,8 @@ def forward(
         # Adjust key, value for inference
         # ===================================================
         # rotary_pos_emb = None
-        key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_params, key, value, rotary_pos_emb=None
+        query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
+            inference_params, query, key, value, rotary_pos_emb=None
         )
 
         # ==================================
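
Illustration (not part of the patch): a minimal, self-contained Python sketch of the call-shape change above. The _AttentionStub class and AttnMaskType enum below are hypothetical stand-ins, not Megatron-LM code; only the argument order and the five-value return mirror the '+' lines of the hunk, which now pass `query` into _adjust_key_value_for_inference and receive it back alongside key, value, the (ignored) rotary embedding slot, and the attention mask type.

from enum import Enum


class AttnMaskType(Enum):
    # Hypothetical stand-in for Megatron's attention-mask-type enum.
    causal = 1


class _AttentionStub:
    def _adjust_key_value_for_inference(
        self, inference_params, query, key, value, rotary_pos_emb=None
    ):
        # The real method updates the inference KV cache; this stub only
        # echoes the tensors back to show the five-value return shape.
        return query, key, value, rotary_pos_emb, AttnMaskType.causal


attn = _AttentionStub()
q, k, v = "q", "k", "v"  # placeholders for the actual attention tensors

# New call shape (matches the '+' lines of the patch): query is passed in and
# returned, and the rotary-embedding slot is discarded with `_`.
query, key, value, _, attn_mask_type = attn._adjust_key_value_for_inference(
    None, q, k, v, rotary_pos_emb=None
)
assert attn_mask_type is AttnMaskType.causal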