Merge branch 'hn-fix-mla-kv-cache' into 'main'
Fix signature for multi-latent attention KV cache update

See merge request ADLR/megatron-lm!2294
ericharper committed Nov 7, 2024
2 parents 19515ac + 369fec6 commit bc8c4f3
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions megatron/core/transformer/multi_latent_attention.py
@@ -116,6 +116,7 @@ def forward(
         packed_seq_params=None,
         position_ids=None,
     ):
+        """Forward pass for multi-latent attention"""
         assert rotary_pos_emb is None, "Rotary position embeddings should not be passed into MLA."

         # hidden_states: [sq, b, h]
@@ -138,8 +139,8 @@ def forward(
         # Adjust key, value for inference
         # ===================================================
         # rotary_pos_emb = None
-        key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_params, key, value, rotary_pos_emb=None
+        query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
+            inference_params, query, key, value, rotary_pos_emb=None
         )

         # ==================================
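
For reference, the updated call site implies that `_adjust_key_value_for_inference` now accepts and returns `query` alongside `key` and `value` when adjusting tensors for the KV cache during inference. The sketch below is a minimal, hypothetical stand-in inferred only from the call site in this diff, not Megatron-LM's actual implementation; the `AttnMaskType` enum and the tensor shapes in the usage example are illustrative assumptions.

from enum import Enum

import torch


class AttnMaskType(Enum):
    # Illustrative stand-in for Megatron's attention mask-type enum.
    causal = 1
    no_mask = 2


def _adjust_key_value_for_inference(inference_params, query, key, value, rotary_pos_emb=None):
    """Hypothetical sketch matching the call-site signature in the diff.

    Returns (query, key, value, rotary_pos_emb, attn_mask_type); the MLA caller
    discards the fourth value because it always passes rotary_pos_emb=None.
    """
    if inference_params is None:
        # No KV cache in play (training or full-sequence prefill): tensors pass
        # through unchanged under a causal mask.
        return query, key, value, rotary_pos_emb, AttnMaskType.causal

    # In real incremental decoding, the cached key/value would be updated here and
    # the current-step slices of query/key/value returned; a single new token
    # attending to the cache needs no causal mask.
    return query, key, value, rotary_pos_emb, AttnMaskType.no_mask


# Usage mirroring the diff's unpacking pattern with [sq, b, np, hn]-shaped tensors.
q = k = v = torch.randn(4, 2, 8, 64)
query, key, value, _, attn_mask_type = _adjust_key_value_for_inference(
    None, q, k, v, rotary_pos_emb=None
)
print(attn_mask_type)  # AttnMaskType.causal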
