Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 52 additions & 11 deletions fastdeploy/model_executor/layers/attention/dsa_attention_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,28 @@ def forward_mixed(
"""
Mixed模式的前向传播
"""
res = DSAAttentionBackend.forward_static(
q, v, compressed_kv, k_pe, forward_meta.caches[2 * layer.layer_id], forward_meta, self.attn_softmax_scale
)
return res

@staticmethod
def forward_static(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 这个新增的 forward_static() 承接了 DSA 的 cache 写入、prefill、decode 和 mixed merge 逻辑,但本 PR 同时删除了原来覆盖 forward_mixed 分支的 tests/layers/test_dsa_attention_backend.py,没有补充替代测试。

当前保留的 tests/layers/test_dsa_attention_kv_cache.py 只覆盖 cache shape/create_host_kv_cache,无法守住这里的输出 shape、cache 写入和 indexer_topk 传参。建议保留或重写轻量单测,至少覆盖 prefill-only、decode-only、prefill+decode merge 三条分支,并 mock flash_mla/dsk_attn_write_cache 验证 compressed_kv + k_pe 在 backend 内构造后的调用参数。

q: paddle.Tensor,
indexer_topk: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
latent_cache: paddle.Tensor,
forward_meta: ForwardMeta,
attn_softmax_scale: float,
) -> paddle.Tensor:

latent_cache = forward_meta.caches[2 * layer.layer_id] if hasattr(forward_meta, "caches") else None
assert len(q.shape) == 3
assert len(compressed_kv.shape) == 2
assert len(k_pe.shape) == 3
assert k_pe.shape[1] == 1
assert compressed_kv.shape[0] == k_pe.shape[0]
assert len(latent_cache.shape) == 4

if current_platform.is_cuda():
import flash_mla
Expand All @@ -352,43 +372,64 @@ def forward_mixed(
"fp8_ds_mla",
)

q_num_heads = q.shape[1]
ceil64_num_heads = (q_num_heads + 63) // 64 * 64

fmha_out_prefill = None
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
new_q[:, :q_num_heads, :] = q
else:
new_q = q

# concat for invoking flash_mla_sparse_fwd!
kv = paddle.concat([compressed_kv.unsqueeze(1), k_pe], axis=-1)
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
q, # q_input.contiguous(),
k, # kv.unsqueeze(1),
v, # indexer_top_k.unsqueeze(1),
sm_scale=self.attn_softmax_scale,
new_q,
kv,
indexer_topk,
sm_scale=attn_softmax_scale,
)

assert len(fmha_out_prefill.shape) == 3
fmha_out_prefill = fmha_out_prefill[:, :q_num_heads, :].contiguous()

# Decode
# if k is None:
if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time
if forward_meta.max_len_tensor_cpu[2]:

tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
new_cache_shape = latent_cache.shape
assert new_cache_shape[1] == 1
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]

if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
new_q[:, :q_num_heads, :] = q
else:
new_q = q

fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
q.unsqueeze(1).contiguous(),
new_q.unsqueeze(1).contiguous(),
latent_cache.view(new_cache_shape),
None, # forward_meta.block_tables,
None, # cache_seqlens
512, # self.qk_nope_head_dim,
tile_scheduler_metadata,
None, # num_splits,
self.attn_softmax_scale,
attn_softmax_scale,
False, # causal
True, # is_fp8_kvcache
v, # indices,
indexer_topk, # indices,
None, # t.attn_sink,
None, # extra_k_cache,
None, # extra_indices_in_kvcache: Optional[torch.Tensor] = None,
None, # topk_length: Optional[torch.Tensor] = None,
None, # extra_topk_length: Optional[torch.Tensor] = None
)

fmha_out_decode = fmha_out_decode[:, :, :q_num_heads, :].contiguous()

if fmha_out_prefill is not None:

from fastdeploy.model_executor.ops.gpu import (
Expand All @@ -402,7 +443,7 @@ def forward_mixed(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_heads * 4,
q_num_heads * 4,
128,
1,
)
Expand Down
121 changes: 118 additions & 3 deletions fastdeploy/model_executor/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,78 @@
)


import triton
import triton.language as tl


@enable_compat_on_triton_kernel
@triton.jit
def get_swa_indexer_top_k_kernel(
    indexer_top_k,
    block_tables,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    batch_id_per_token,
    max_page_per_seq: tl.constexpr,
    window_size: tl.constexpr,
    page_size: tl.constexpr,
):
    """
    Write sliding-window KV indices for one token into ``indexer_top_k``.

    Each program instance (``tl.program_id(0)``) handles a single token and
    stores at most ``window_size`` int indices, starting with the token's own
    position and walking backwards. Slots past the valid window are never
    written, so they keep whatever value the caller pre-filled.

    Args:
        indexer_top_k: Output pointer; addressed as ``token_id * window_size + slot``,
            i.e. assumed contiguous with ``window_size`` entries per token.
        block_tables: Pointer to the per-sequence page table, addressed as
            ``batch_id * max_page_per_seq + page``.
        cu_seqlens_q: Pointer indexed by batch id; its value is subtracted from
            the global token id to obtain the token's offset within its sequence.
        seq_lens_encoder: Pointer indexed by batch id; a value > 0 selects the
            "encoder" branch below.
        seq_lens_decoder: Pointer indexed by batch id; its value is added to the
            in-step offset to form the token's position in the sequence.
        batch_id_per_token: Pointer indexed by token id; a negative value makes
            the program exit without writing anything.
        max_page_per_seq: Row stride (in elements) of ``block_tables``.
        window_size: Maximum number of indices written per token.
        page_size: Number of token slots per cache page.
    """
    token_id = tl.program_id(0)

    # Advance the output pointer to this token's row.
    indexer_top_k += token_id * window_size

    batch_id = tl.load(batch_id_per_token + token_id)
    # Negative batch id: nothing to do for this token (presumably a padding
    # slot -- TODO confirm against the producer of batch_id_per_token).
    if batch_id < 0:
        return

    # Advance the page-table pointer to this sequence's row.
    block_tables += batch_id * max_page_per_seq

    kv_len = tl.load(seq_lens_decoder + batch_id)
    encoder_len = tl.load(seq_lens_encoder + batch_id)
    cu_q_len = tl.load(cu_seqlens_q + batch_id)
    # Offset of the token inside this step's query tokens, shifted by kv_len.
    token_id_in_this_batch = token_id - cu_q_len + kv_len

    # The window cannot reach before position 0 of the sequence.
    valid_window_size = min(token_id_in_this_batch + 1, window_size)

    # Walk from the current position backwards; slot 0 holds the token itself,
    # slot k holds the position k steps earlier.
    for idx in range(token_id_in_this_batch, token_id_in_this_batch - valid_window_size, -1):
        if encoder_len > 0:
            # encoder case: the stored index is cu_q_len + idx, i.e. relative
            # to the flattened query-token axis rather than to the paged cache.
            # NOTE(review): idx already includes kv_len, so this looks like it
            # assumes kv_len == 0 whenever encoder_len > 0 -- confirm.
            tmp = cu_q_len + idx
            tl.store(indexer_top_k + token_id_in_this_batch - idx, tmp)
        else:
            # Paged case: translate the logical position into a cache slot via
            # the page table (page number * page_size + offset within page).
            tmp = tl.load(block_tables + idx // page_size)
            tmp = tmp * page_size + idx % page_size
            tl.store(indexer_top_k + token_id_in_this_batch - idx, tmp)


def get_swa_indexer_top_k(
    indexer_top_k,
    block_tables,
    cu_seqlens_q,
    seq_lens_encoder,
    seq_lens_decoder,
    batch_id_per_token,
    page_size=64,
):
    """
    Launch ``get_swa_indexer_top_k_kernel`` to fill ``indexer_top_k`` in place.

    One kernel program is launched per token (grid size is
    ``indexer_top_k.shape[0]``); the window size handed to the kernel is
    ``indexer_top_k.shape[2]``.

    Args:
        indexer_top_k: Output tensor of shape ``[token_num, 1, window_size]``.
            It is modified in place; entries the kernel does not write keep
            their incoming value.
        block_tables: 2-D page table; ``block_tables.shape[1]`` is forwarded as
            ``max_page_per_seq``.
        cu_seqlens_q: Forwarded unchanged to the kernel.
        seq_lens_encoder: Forwarded unchanged to the kernel.
        seq_lens_decoder: Forwarded unchanged to the kernel.
        batch_id_per_token: Forwarded unchanged to the kernel.
        page_size: Number of token slots per cache page, passed to the kernel as
            a compile-time constant. Defaults to 64, the value that was
            previously hard-coded here.

    Returns:
        None. The result is written into ``indexer_top_k``.

    Raises:
        AssertionError: If ``indexer_top_k`` is not 3-D or its middle dimension
            is not 1.
    """
    assert indexer_top_k.ndim == 3, "indexer_top_k must be a 3-D tensor"
    assert indexer_top_k.shape[1] == 1, "indexer_top_k.shape[1] must be 1"

    token_num = indexer_top_k.shape[0]
    # With no tokens there is nothing to write, so skip the kernel launch.
    if token_num == 0:
        return

    grid = (token_num,)

    get_swa_indexer_top_k_kernel[grid](
        indexer_top_k,
        block_tables,
        cu_seqlens_q,
        seq_lens_encoder,
        seq_lens_decoder,
        batch_id_per_token,
        max_page_per_seq=block_tables.shape[1],
        window_size=indexer_top_k.shape[2],
        page_size=page_size,
    )


class DeepSeekV3MLP(nn.Layer):
"""
DeepSeekV3MLP, for Dense FFN and Shared Experts Layer.
Expand Down Expand Up @@ -534,6 +606,52 @@ def forward(
)
else:
attn_out = fmqa_out

if False:
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
q_input.reshape_(
[
-1,
self.num_attention_heads_tp,
self.kv_lora_rank + self.qk_rope_head_dim,
]
)

self.index_topk = 512
indexer_top_k = paddle.full([q_input.shape[0], 1, self.index_topk], -1, dtype="int32")

get_swa_indexer_top_k(
indexer_top_k,
forward_meta.block_tables,
forward_meta.cu_seqlens_q,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
)

from fastdeploy.model_executor.layers.attention import DSAAttentionBackend

fmqa_out = DSAAttentionBackend.forward_static(
q=q_input.contiguous(),
indexer_topk=indexer_top_k,
compressed_kv=compressed_kv,
k_pe=key_pe,
latent_cache=forward_meta.caches[self.layer_id],

This comment was marked as outdated.

forward_meta=forward_meta,
attn_softmax_scale=self.attn_softmax_scale,
)

fmqa_out = fmqa_out.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose([1, 0, 2])

fmqa_out = (
self.kv_b_proj_bmm(fmqa_out, proj_type="v")
.transpose([1, 0, 2])
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
)
attn_out = fmqa_out

if self.use_gated_attn:
gated_attn_act = getattr(self.fd_config.model_config, "gated_attn_act", "sigmoid")
if gated_attn_act == "sigmoid":
Expand All @@ -547,7 +665,6 @@ def forward(


import triton
import triton.language as tl


@enable_compat_on_triton_kernel
Expand Down Expand Up @@ -894,12 +1011,10 @@ def forward(
q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1)

compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
kv = paddle.concat([compressed_kv, key_pe.squeeze(1)], axis=-1)

# dsa attention
fmha_out = self.dsa_attn(
q=q_input.contiguous(),
k=kv.unsqueeze(1).contiguous(),
v=indexer_top_k.unsqueeze(1).contiguous(),
qkv=None,
compressed_kv=compressed_kv,
Expand Down
Loading
Loading