
Commit 5e5646e

[BUGFIX] llama_4_scaling wrongly passed to DeepseekAttention (#29908)
Signed-off-by: juliendenize <[email protected]>
1 parent 0a9caca commit 5e5646e


vllm/model_executor/models/deepseek_v2.py

Lines changed: 10 additions & 5 deletions
@@ -1135,6 +1135,8 @@ def __init__(
             dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
         )
 
+        self.use_mha = use_mha
+
         if use_mha:
             attn_cls = DeepseekAttention
         elif model_config.use_mla:
@@ -1196,11 +1198,14 @@ def forward(
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            llama_4_scaling=llama_4_scaling,
-        )
+
+        attn_kwargs = {
+            "positions": positions,
+            "hidden_states": hidden_states,
+        }
+        if not self.use_mha:
+            attn_kwargs["llama_4_scaling"] = llama_4_scaling
+        hidden_states = self.self_attn(**attn_kwargs)
 
         if (
             not isinstance(self.self_attn, DeepseekAttention)

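The pattern in the fix is to build the attention call's keyword arguments conditionally, so that llama_4_scaling is only forwarded to attention variants whose forward() accepts it. Below is a minimal, self-contained sketch of that pattern. PlainAttention, ScaledAttention, and DecoderLayer are hypothetical stand-ins for illustration only; they are not vLLM's actual DeepseekAttention, MLA, or decoder-layer implementations, and the attention math is omitted.

from typing import Optional

import torch
from torch import nn


class PlainAttention(nn.Module):
    """Stand-in for an attention module whose forward() does not accept
    llama_4_scaling (playing the role of DeepseekAttention in this sketch)."""

    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        # Real attention math is omitted; only the signature matters here.
        return hidden_states


class ScaledAttention(nn.Module):
    """Stand-in for an attention variant that does accept llama_4_scaling."""

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if llama_4_scaling is not None:
            hidden_states = hidden_states * llama_4_scaling
        return hidden_states


class DecoderLayer(nn.Module):
    def __init__(self, use_mha: bool) -> None:
        super().__init__()
        # Remember which attention flavor was chosen so forward() can build
        # the keyword arguments accordingly, as the fix does.
        self.use_mha = use_mha
        self.self_attn = PlainAttention() if use_mha else ScaledAttention()

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn_kwargs = {"positions": positions, "hidden_states": hidden_states}
        # Only forward llama_4_scaling to modules that accept it; passing it
        # unconditionally to PlainAttention would raise a TypeError.
        if not self.use_mha:
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        return self.self_attn(**attn_kwargs)


if __name__ == "__main__":
    layer = DecoderLayer(use_mha=True)
    out = layer(torch.arange(4), torch.randn(4, 16))
    print(out.shape)  # torch.Size([4, 16])

With the guard in place, a layer built with use_mha=True never receives an unexpected llama_4_scaling keyword; passing one unconditionally to a forward() that lacks that parameter would raise a TypeError, which mirrors the misrouted argument this commit addresses.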