Commit 57b3f33

update

lvhan028 committed Jul 24, 2024
1 parent 3de086f
Showing 3 changed files with 13 additions and 13 deletions.
16 changes: 8 additions & 8 deletions src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -22,8 +22,8 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
 const int* cu_block_num,
 const float* rope_base,
 int rope_dim,
-int original_max_position_embeddings,
 float rope_ti_scale,
+int original_max_position_embeddings,
 float rope_scaling_factor,
 float low_freq_factor,
 float high_freq_factor,
@@ -205,8 +205,8 @@ void invokeProcessKV_v2(char** blocks,
 const int* cu_block_num,
 const float* rope_base,
 int rope_dim,
-int original_max_position_embeddings,
 float rope_ti_scale,
+int original_max_position_embeddings,
 float rope_scaling_factor,
 float low_freq_factor,
 float high_freq_factor,
@@ -247,8 +247,8 @@ void invokeProcessKV_v2(char** blocks,
 cu_block_num,
 rope_base,
 rope_dim,
-original_max_position_embeddings,
 rope_ti_scale,
+original_max_position_embeddings,
 rope_scaling_factor,
 low_freq_factor,
 high_freq_factor,
@@ -282,8 +282,8 @@ void invokeProcessKV_v2(char** blocks,
 const int* cu_block_num, \
 const float* rope_base, \
 int rope_dim, \
-int original_max_position_embeddings, \
 float rope_ti_scale, \
+int original_max_position_embeddings, \
 float rope_scaling_factor, \
 float low_freq_factor, \
 float high_freq_factor, \
@@ -313,8 +313,8 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k,
 const int* cu_block_num,
 const float* rope_base,
 int rope_dim,
-int original_max_position_embeddings,
 float rope_ti_scale,
+int original_max_position_embeddings,
 float rope_scaling_factor,
 float low_freq_factor,
 float high_freq_factor,
@@ -440,8 +440,8 @@ void invokeFlattenKV_v2(T* k,
 const int* cu_block_num,
 const float* rope_base,
 int rope_dim,
-int original_max_position_embeddings,
 float rope_ti_scale,
+int original_max_position_embeddings,
 float rope_scaling_factor,
 float low_freq_factor,
 float high_freq_factor,
@@ -479,8 +479,8 @@ void invokeFlattenKV_v2(T* k,
 cu_block_num,
 rope_base,
 rope_dim,
-original_max_position_embeddings,
 rope_ti_scale,
+original_max_position_embeddings,
 rope_scaling_factor,
 low_freq_factor,
 high_freq_factor,
@@ -511,8 +511,8 @@ void invokeFlattenKV_v2(T* k,
 const int* cu_block_num, \
 const float* rope_base, \
 int rope_dim, \
-int original_max_position_embeddings, \
 float rope_ti_scale, \
+int original_max_position_embeddings, \
 float rope_scaling_factor, \
 float low_freq_factor, \
 float high_freq_factor, \
8 changes: 4 additions & 4 deletions src/turbomind/kernels/attention/kv_cache_utils_v2.h
@@ -18,8 +18,8 @@ void invokeProcessKV_v2(char** blocks,
 const int* cu_block_num,
 const float* rope_base,
 int rope_dim,
-int original_max_position_embeddings,
 float rope_ti_scale,
+int original_max_position_embeddings,
 float rope_scaling_factor,
 float low_freq_factor,
 float high_freq_factor,
@@ -49,8 +49,8 @@ void invokeProcessKV_v2_(const AttentionParams<T>& params)
 params.block_iter_params.cu_block_nums,
 params.rope_theta,
 params.rotary_embedding_dim,
-params.original_max_position_embeddings,
 params.rope_ti_scale,
+params.original_max_position_embeddings,
 params.rope_scaling_factor,
 params.low_freq_factor,
 params.high_freq_factor,
@@ -76,8 +76,8 @@ void invokeFlattenKV_v2(T* k,
 const int* cu_block_num,
 const float* rope_base,
 int rope_dim,
-int original_max_position_embeddings,
 float rope_ti_scale,
+int original_max_position_embeddings,
 float rope_scaling_factor,
 float low_freq_factor,
 float high_freq_factor,
@@ -106,8 +106,8 @@ void invokeFlattenKV_v2_(const AttentionParams<T>& params, int sum_k_len)
 params.block_iter_params.cu_block_nums,
 nullptr, // params.rope_theta,
 params.rotary_embedding_dim,
-params.original_max_position_embeddings,
 params.rope_ti_scale,
+params.original_max_position_embeddings,
 params.rope_scaling_factor,
 params.low_freq_factor,
 params.high_freq_factor,
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/llama_params.h
@@ -13,8 +13,8 @@ struct LlamaAttentionParams {
 int rotary_embedding_dim;
 float rotary_embedding_base;
 int max_position_embeddings;
-int original_max_position_embeddings;
 std::string rope_scaling_type;
+int original_max_position_embeddings;
 float rope_scaling_factor;
 float low_freq_factor;
 float high_freq_factor;
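
All three files make the same mechanical change: original_max_position_embeddings moves from before rope_ti_scale (and, in LlamaAttentionParams, from before rope_scaling_type) to directly after it, with the kernel signatures, launchers, instantiation macros, header declarations, and call sites all updated in the same commit. The sketch below uses hypothetical, simplified names rather than the real turbomind signatures; it shows why the call sites cannot lag behind the declarations: C++ binds arguments by position, and int and float convert implicitly, so a caller still using the old order compiles cleanly while passing swapped values.

#include <cstdio>

// Minimal sketch, not the turbomind API: a stand-in for a reordered
// signature, using the new parameter order from this commit
// (rope_ti_scale before original_max_position_embeddings).
static void apply_rope(int rope_dim, float rope_ti_scale, int original_max_position_embeddings)
{
    std::printf("rope_dim=%d rope_ti_scale=%g original_max_position_embeddings=%d\n",
                rope_dim, rope_ti_scale, original_max_position_embeddings);
}

int main()
{
    // Caller updated together with the declaration: every value lands in
    // the intended parameter.
    apply_rope(128, 1.0f, 8192);

    // Caller left on the old order: still compiles via implicit int/float
    // conversions, but silently passes rope_ti_scale = 8192 and
    // original_max_position_embeddings = 1.
    apply_rope(128, 8192, 1.0f);
    return 0;
}

At most, a compiler flags the stale call with optional warnings such as -Wfloat-conversion, which is why a reordering like this has to land atomically across declaration, definition, macros, and callers.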
