Skip to content

Commit 5d7ab5c

Browse files
committed
[Metax] adapt cutlass moe for ernie-vl
1 parent 1f3ce65 commit 5d7ab5c

File tree

6 files changed

+174
-101
lines changed

6 files changed

+174
-101
lines changed

custom_ops/metax_ops/fused_moe.cu

Lines changed: 4 additions & 0 deletions
@@ -101,6 +101,10 @@ std::vector<paddle::Tensor> FusedExpertMoe(
   const auto input_type = input.dtype();
   auto output = paddle::empty_like(input);

+  if (output.dims()[0] == 0) {
+    return {output};
+  }
   switch (input_type) {
     case paddle::DataType::BFLOAT16:
       FusedMoeKernel<paddle::DataType::BFLOAT16,

custom_ops/metax_ops/moe_dispatch.cu

Lines changed: 8 additions & 0 deletions
@@ -178,6 +178,14 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
   auto permute_indices_per_token =
       GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);

+  if (token_rows == 0) {
+    return {permute_input,
+            tokens_expert_prefix_sum,
+            permute_indices_per_token,
+            top_k_weight,
+            top_k_indices};
+  }
   switch (input_type) {
     case paddle::DataType::BFLOAT16:
       MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,

custom_ops/metax_ops/moe_ffn.cu

Lines changed: 4 additions & 0 deletions
@@ -114,6 +114,10 @@ std::vector<paddle::Tensor> MoeExpertFFN(
   const auto input_type = permute_input.dtype();
   auto ffn_out = paddle::empty_like(permute_input);

+  if (permute_input.numel() == 0) {
+    return {ffn_out};
+  }
   switch (input_type) {
     case paddle::DataType::BFLOAT16:
       McMoeFFNKernel<paddle::DataType::BFLOAT16,

custom_ops/setup_ops.py

Lines changed: 2 additions & 0 deletions
@@ -612,6 +612,8 @@ def find_end_files(directory, end_str):
     "gpu_ops/text_image_gather_scatter.cu",
     "gpu_ops/text_image_index_out.cu",
     "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
+    "gpu_ops/limit_thinking_content_length_v1.cu",
+    "gpu_ops/limit_thinking_content_length_v2.cu",
     "gpu_ops/append_attn/mla_cache_kernel.cu",
     "gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
     "gpu_ops/moe/tritonmoe_preprocess.cu",

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 6 additions & 0 deletions
@@ -50,8 +50,12 @@
 elif current_platform.is_maca():
     from fastdeploy.model_executor.ops.gpu import (
         get_padding_offset,
+        limit_thinking_content_length_v1,
+        limit_thinking_content_length_v2,
         save_output,
         set_stop_value_multi_ends,
+        speculate_limit_thinking_content_length_v1,
+        speculate_limit_thinking_content_length_v2,
         step_paddle,
         update_inputs,
         update_inputs_v1,
@@ -770,7 +774,9 @@ def rebuild_padding(
         seq_lens_decoder,
         seq_lens_encoder,
         output_padding_offset,
+        first_token_out,
         max_input_length,
+        enable_logprob,
     )
 else:
     raise RuntimeError("Not supported platform")

0 commit comments

Comments (0)