
Commit dd230cf

Merge branch 'develop' into develop
2 parents: 3aa6dfc + 1339e56

7 files changed (+25 additions, −26 deletions)

custom_ops/gpu_ops/get_padding_offset.cu

Lines changed: 6 additions & 6 deletions
@@ -34,7 +34,7 @@ __global__ void RemovePadding(int64_t *output_data,
   }
 }
 
-__global__ void GetPaddingOffsetKernel(int *padding_offset,
+__global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
                                        int *cum_offsets_out,
                                        int *cu_seqlens_q,
                                        int *cu_seqlens_k,
@@ -46,7 +46,7 @@ __global__ void GetPaddingOffsetKernel(int *padding_offset,
   const int ti = threadIdx.x;
   int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
-    padding_offset[bi * max_seq_len - cum_offset + i] = bi;
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
   }
   if (ti == 0) {
     cum_offsets_out[bi] = cum_offset;
@@ -75,7 +75,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
   auto x_remove_padding = paddle::empty(
       {token_num_data}, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::empty(
+  auto batch_id_per_token = paddle::empty(
       {token_num_data}, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_q =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
@@ -87,7 +87,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
 #endif
   GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
-      padding_offset.data<int>(),
+      batch_id_per_token.data<int>(),
       cum_offsets_out.data<int>(),
       cu_seqlens_q.data<int>(),
       cu_seqlens_k.data<int>(),
@@ -102,7 +102,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
       seq_length);
   return {x_remove_padding,
           cum_offsets_out,
-          padding_offset,
+          batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};  // , enc_token_num, dec_token_num};
 }
@@ -133,7 +133,7 @@ PD_BUILD_STATIC_OP(get_padding_offset)
     .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
     .Outputs({"x_remove_padding",
               "cum_offsets_out",
-              "padding_offset",
+              "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
     .SetKernelFn(PD_KERNEL(GetPaddingOffset))
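
For reference, batch_id_per_token (formerly padding_offset) maps each position of the packed, padding-free token stream back to the batch row it came from; the new name matches what the kernel has always stored. A minimal NumPy sketch of the same address arithmetic — illustrative only, not part of this commit:

import numpy as np

def batch_ids_per_token(seq_lens, max_seq_len):
    # Padding per row; cum_offsets mirrors the kernel's cum_offsets array.
    pads = max_seq_len - np.asarray(seq_lens)
    cum_offsets = np.cumsum(pads)
    out = np.empty(int(np.sum(seq_lens)), dtype=np.int32)
    for bi, sl in enumerate(seq_lens):
        # Exclusive prefix, as in: bi == 0 ? 0 : cum_offsets[bi - 1]
        cum_offset = 0 if bi == 0 else int(cum_offsets[bi - 1])
        for i in range(sl):
            # Same indexing as the CUDA kernel's write.
            out[bi * max_seq_len - cum_offset + i] = bi
    return out

print(batch_ids_per_token([2, 3], max_seq_len=4))  # -> [0 0 1 1 1]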

custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu

Lines changed: 6 additions & 6 deletions
@@ -41,7 +41,7 @@ __global__ void SpeculateRemovePadding(int64_t* output_data,
   }
 }
 
-__global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
+__global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
                                                 int* cum_offsets_out,
                                                 int* cu_seqlens_q,
                                                 int* cu_seqlens_k,
@@ -53,7 +53,7 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
   const int ti = threadIdx.x;
   int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
-    padding_offset[bi * max_seq_len - cum_offset + i] = bi;
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
   }
   if (ti == 0) {
     cum_offsets_out[bi] = cum_offset;
@@ -81,15 +81,15 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
   auto x_remove_padding = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::full(
+  auto batch_id_per_token = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_q =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_k =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
   SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
-      padding_offset.data<int>(),
+      batch_id_per_token.data<int>(),
       cum_offsets_out.data<int>(),
       cu_seqlens_q.data<int>(),
       cu_seqlens_k.data<int>(),
@@ -107,7 +107,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
       max_draft_tokens);
   return {x_remove_padding,
           cum_offsets_out,
-          padding_offset,
+          batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};  // , enc_token_num, dec_token_num};
 }
@@ -147,7 +147,7 @@ PD_BUILD_STATIC_OP(speculate_get_padding_offset)
               "seq_lens_encoder"})
     .Outputs({"x_remove_padding",
               "cum_offsets_out",
-              "padding_offset",
+              "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
     .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))

custom_ops/xpu_ops/src/ops/get_padding_offset.cc

Lines changed: 4 additions & 4 deletions
@@ -34,15 +34,15 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
   auto x_remove_padding = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::full(
+  auto batch_id_per_token = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_q =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_k =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   int r = baidu::xpu::api::plugin::get_padding_offset(
       xpu_ctx->x_context(),
-      padding_offset.data<int>(),
+      batch_id_per_token.data<int>(),
       cum_offsets_out.data<int>(),
       cu_seqlens_q.data<int>(),
       cu_seqlens_k.data<int>(),
@@ -55,7 +55,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
   return {x_remove_padding,
           cum_offsets_out,
-          padding_offset,
+          batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};
 }
@@ -86,7 +86,7 @@ PD_BUILD_OP(get_padding_offset)
     .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
     .Outputs({"x_remove_padding",
               "cum_offsets_out",
-              "padding_offset",
+              "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
     .SetKernelFn(PD_KERNEL(GetPaddingOffset))

custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 namespace xpu3 {
 namespace plugin {
 
-__global__ void get_padding_offset(int *padding_offset,
+__global__ void get_padding_offset(int *batch_id_per_token,
                                    int *cum_offsets_out,
                                    int *cu_seqlens_q,
                                    int *cu_seqlens_k,
@@ -20,7 +20,7 @@ __global__ void get_padding_offset(int *padding_offset,
   int tid = clusterid * ncores + cid;
 
   int buf_len = 32;
-  __simd__ int padding_offset_lm[buf_len];
+  __simd__ int batch_id_per_token_lm[buf_len];
   __simd__ int cum_offsets_lm[16];
   int seq_len_lm;
   for (int i = clusterid; i < bs; i += nclusters) {
@@ -32,11 +32,11 @@ __global__ void get_padding_offset(int *padding_offset,
     for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) {
       int cur_len = min(seq_len_lm - j, buf_len);
      for (int k = 0; k < cur_len; k++) {
-        padding_offset_lm[k] = cum_offsets_lm[0];
+        batch_id_per_token_lm[k] = i;
       }
       mfence_lm();
-      LM2GM(padding_offset_lm,
-            padding_offset + i * max_seq_len - cum_offsets_lm[0] + j,
+      LM2GM(batch_id_per_token_lm,
+            batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
             cur_len * sizeof(int));
     }
     if (cid == 0) {
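
Note that the write inside the inner loop here is a behavioral fix, not just a rename: the local-memory buffer was previously filled with cum_offsets_lm[0] (the row's cumulative padding offset), and now stores the row index i, bringing the Kunlun3 kernel in line with the batch-id semantics of the CUDA kernels above. On the earlier example (seq_lens = [2, 3], max_seq_len = 4), the values would differ as follows — illustrative only:

old = [0, 0, 2, 2, 2]  # cumulative padding offset of each token's row (pre-fix)
new = [0, 0, 1, 1, 1]  # batch id of each token's row (post-fix)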

fastdeploy/config.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def __init__(
         self.stop_seqs_max_len = 8
 
         # NOTE(gongshaotain): form _load_model_init_val()
-        self.top_p = 0.0
+        self.top_p = 1.0
         self.temperature = 1.0
         self.rope_theta = 10000.0
         self.penalty_score = 1.0
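
For context, top_p controls nucleus sampling: only the smallest set of tokens whose cumulative probability reaches p is kept. With top_p = 1.0 the nucleus is the whole vocabulary, so the default leaves the distribution untouched, whereas the old default of 0.0 degenerates to keeping at most the single most probable token. A hedged sketch of the mechanism — illustrative, not FastDeploy's actual sampler:

import numpy as np

def top_p_filter(probs, p):
    order = np.argsort(probs)[::-1]     # most probable token first
    cdf = np.cumsum(probs[order])
    keep = cdf - probs[order] < p       # tokens needed to reach mass p
    keep[0] = True                      # always keep the top token
    mask = np.zeros_like(probs, dtype=bool)
    mask[order[keep]] = True
    filtered = np.where(mask, probs, 0.0)
    return filtered / filtered.sum()    # renormalize the surviving tokens

probs = np.array([0.5, 0.3, 0.15, 0.05])
print(top_p_filter(probs, 1.0))  # unchanged: [0.5 0.3 0.15 0.05]
print(top_p_filter(probs, 0.6))  # only the top two survive: [0.625 0.375 0. 0.]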

fastdeploy/worker/gpu_model_runner.py

Lines changed: 0 additions & 1 deletion
@@ -419,7 +419,6 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
             self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len
             self.share_inputs["min_dec_len"][idx:idx + 1] = max_dec_len
             self.share_inputs["stop_flags"][idx:idx + 1] = False
-            self.share_inputs["top_p"][idx:idx + 1] = 0.0
             self.share_inputs["temperature"][idx:idx + 1] = 1
 
             self.share_inputs["first_token_ids"][
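
This removal pairs with the config change above: with the class default now top_p = 1.0, the dummy prefill inputs no longer pin top_p to 0.0 and instead keep whatever the shared inputs were initialized with, presumably so warm-up batches sample under the same neutral default as real requests.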

fastdeploy/worker/xpu_model_runner.py

Lines changed: 3 additions & 3 deletions
@@ -58,15 +58,15 @@ def xpu_pre_process(
     (
         ids_remove_padding,
         cum_offsets,
-        padding_offset,
+        batch_id_per_token,
         cu_seqlens_q,
         cu_seqlens_k,
     ) = get_padding_offset(input_ids, cum_offsets_now, token_num,
                            seq_lens_this_time)
 
     share_inputs["ids_remove_padding"] = None  # set this after adjust batch
     share_inputs["cum_offsets"] = cum_offsets
-    share_inputs["padding_offset"] = padding_offset
+    share_inputs["batch_id_per_token"] = batch_id_per_token
     share_inputs["cu_seqlens_q"] = cu_seqlens_q
     share_inputs["cu_seqlens_k"] = cu_seqlens_k
 
@@ -79,7 +79,7 @@ def xpu_pre_process(
         seq_lens_decoder=share_inputs["seq_lens_decoder"],
         seq_lens_this_time=share_inputs["seq_lens_this_time"],
         cum_offsets=share_inputs["cum_offsets"],
-        padding_offset=share_inputs["padding_offset"],
+        batch_id_per_token=share_inputs["batch_id_per_token"],
         cu_seqlens_q=share_inputs["cu_seqlens_q"],
         cu_seqlens_k=share_inputs["cu_seqlens_k"],
         block_tables=share_inputs["block_tables"],
