swa #8054

@@ -875,6 +875,7 @@ def forward_mixed(
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
causal=self.causal,
+ window_size=-1,

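🔴 Bug: Hard-coding `window_size` to -1 here means the newly added SWA mask branch never takes effect at its only call site.

The Attention layer already stores `layer.sliding_window` when `layer_types == "sliding_attention"`, and `append_attn_backend` also passes a positive `sliding_window` to the attention kernel. However, MLA mixed prefill on SM10 goes through this `mha_baseline` branch. With -1 hard-coded, a sliding-attention layer still runs full causal attention, so the output semantics no longer match the SWA configuration.

Suggested fix: compute the current layer's window before the call, following the existing attention convention. For example, pass `layer.sliding_window` for sliding layers (or, to stay consistent with the append backend, prefer the backend/model-level `self.sliding_window`), and pass -1/0 for non-sliding layers to mean full attention. Also add a mixed prefill + SWA alignment test that covers `kv_len > q_len`.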
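A minimal sketch of the window selection suggested above, not the project's actual code. `resolve_window_size` is a hypothetical helper, and it assumes that non-sliding layers either lack a `sliding_window` attribute or store `None` or a non-positive value in it; the real layer object may follow a different convention.

```python
from types import SimpleNamespace


def resolve_window_size(layer, full_attention=-1):
    """Pick the window to forward to the baseline attention call.

    Hypothetical helper: assumes a sliding layer exposes a positive
    `sliding_window`, and that any other value means full attention.
    """
    window = getattr(layer, "sliding_window", None)
    if window is not None and window > 0:
        return int(window)
    # Non-sliding layer: fall back to the "full causal attention" sentinel.
    return full_attention


# Stand-ins for real layer objects, for illustration only.
sliding_layer = SimpleNamespace(sliding_window=128)
full_layer = SimpleNamespace(sliding_window=None)
```

The call site could then pass `window_size=resolve_window_size(layer)` instead of the literal -1, assuming the layer object is reachable from `forward_mixed`.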

**self.flash_attn_kwargs,
)
return fmha_out
@@ -1155,7 +1156,7 @@ def flashmla_baseline(decoder_q, latent_cache, block_table, cache_seqlens, attn_
return res_baseline

@staticmethod
- def mha_baseline(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, softmax_scale):
+ def mha_baseline(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, window_size, softmax_scale):

assert causal, "Only support causal attention for now"
bsz = cu_seqlens_q.shape[0] - 1
@@ -1191,7 +1192,12 @@ def mha_baseline(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, softmax_scale):

 tmp_zeros = np.zeros((q_len, kv_len)) - 1
 for i in range(q_len):
-    tmp_zeros[i][: i + 1] = 0
+    if kv_len - q_len + i + 1 > window_size and window_size > 0:
+        ss = kv_len - q_len + i + 1 - window_size
+        tmp_zeros[i][ss : kv_len - q_len + i + 1] = 0
+    else:
+        # attention all before this `i` th q.
+        tmp_zeros[i][: kv_len - q_len + i + 1] = 0
 mask = tmp_zeros * 1000
 mask = paddle.to_tensor(mask, dtype=q.dtype)
 p = p + mask[None, :]
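To make the windowed mask in this hunk easier to check by hand, here is a dependency-free sketch of the same index arithmetic. `swa_causal_mask` is a hypothetical name, it returns 0/1 flags rather than the additive 0/-1000 values used in the diff, and it assumes `kv_len >= q_len`, with the queries aligned to the last `q_len` keys.

```python
def swa_causal_mask(q_len, kv_len, window_size):
    """Return a q_len x kv_len grid of flags; 1 marks a key the query may attend to.

    Mirrors the indexing in the diff: a positive window_size keeps only the
    last `window_size` visible keys, and any other value means full causal.
    """
    offset = kv_len - q_len  # keys that precede the first query of this chunk
    mask = []
    for i in range(q_len):
        end = offset + i + 1  # exclusive bound: query i sees keys [0, end)
        # Clip the left edge only when the window is active and shorter than the prefix.
        start = end - window_size if 0 < window_size < end else 0
        mask.append([1 if start <= j < end else 0 for j in range(kv_len)])
    return mask


# Mixed prefill case requested in the review: kv_len > q_len with a window of 3.
windowed = swa_causal_mask(q_len=2, kv_len=5, window_size=3)
# Same shapes with the -1 sentinel: plain causal attention.
full = swa_causal_mask(q_len=2, kv_len=5, window_size=-1)
```

With these inputs, `windowed` is `[[0, 1, 1, 1, 0], [0, 0, 1, 1, 1]]` and `full` is `[[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]`. Comparing the two shows which keys the hard-coded -1 would wrongly expose for a sliding-attention layer.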