
rotary_emb_base to rotary_embedding_base
lvhan028 committed Sep 8, 2023
1 parent 59441ea commit 02ab71b
Showing 8 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/turbomind/kernels/decoder_masked_multihead_attention.h
@@ -121,7 +121,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base<T> {
     int max_position_embeddings = 0;
     bool use_dynamic_ntk = false;
     bool use_logn_attn = false;
-    float rotary_emb_base = 10000.0f;
+    float rotary_embedding_base = 10000.0f;
 };
 
 template<class T>
@@ -1378,19 +1378,19 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     q = add(q, q_bias);
     k = add(k, k_bias);
 
-    float rotary_emb_base = params.rotary_emb_base;
+    float rotary_embedding_base = params.rotary_embedding_base;
     if (params.use_dynamic_ntk) {
         // +1 because of `length_per_sample == context_length - 1`
-        rotary_emb_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
-                                                    params.max_position_embeddings,
-                                                    params.rotary_embedding_dim,
-                                                    rotary_emb_base);
+        rotary_embedding_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
+                                                          params.max_position_embeddings,
+                                                          params.rotary_embedding_dim,
+                                                          rotary_embedding_base);
     }
 
     // Padded len
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0) {
-        apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_emb_base, params.timestep - padd_len);
+        apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_embedding_base, params.timestep - padd_len);
     }
 
     if (params.use_logn_attn) {
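
Note on the value being renamed: when `use_dynamic_ntk` is enabled, the kernel recomputes the RoPE base from the current sequence length via `rotary_embedding_get_base` before applying the rotary embedding. The helper's body is not part of this diff; the sketch below shows one common dynamic NTK-style formulation purely for orientation, and is an assumption rather than the turbomind implementation.

```cpp
// Hedged sketch of a dynamic NTK-style base adjustment. Assumed formula;
// the actual rotary_embedding_get_base in turbomind may differ in detail.
#include <algorithm>
#include <cmath>

inline float dynamic_ntk_base_sketch(int   context_len,
                                     int   max_position_embeddings,
                                     int   rotary_embedding_dim,
                                     float rotary_embedding_base)
{
    // How far the current context exceeds the trained context window.
    float ratio = static_cast<float>(context_len) / max_position_embeddings;
    // Alpha: 2^(ceil(log2(ratio) + 1)) - 1, clamped to at least 1 (never shrink the base).
    float alpha = std::max(std::exp2(std::ceil(std::log2(ratio) + 1.f)) - 1.f, 1.f);
    // Stretch the base so the rotary frequencies span the longer context.
    return rotary_embedding_base
           * std::pow(alpha, static_cast<float>(rotary_embedding_dim) / (rotary_embedding_dim - 2));
}
```
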
14 changes: 7 additions & 7 deletions src/turbomind/kernels/unfused_attention_kernels.cu
@@ -863,7 +863,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                     int kv_head_num,
                                                     int size_per_head,
                                                     int rotary_embedding_dim,
-                                                    float rotary_emb_base,
+                                                    float rotary_embedding_base,
                                                     int max_position_embeddings,
                                                     bool use_dynamic_ntk,
                                                     bool use_logn_attn)
@@ -933,12 +933,12 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
     const int timestep = history_len + seq_idx;
 
     if (use_dynamic_ntk) {
-        rotary_emb_base = mmha::rotary_embedding_get_base(
-            context_len, max_position_embeddings, rotary_embedding_dim, rotary_emb_base);
+        rotary_embedding_base = mmha::rotary_embedding_get_base(
+            context_len, max_position_embeddings, rotary_embedding_dim, rotary_embedding_base);
     }
 
     // TODO: unused computation on k if GQA is used
-    mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_emb_base, timestep);
+    mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, timestep);
 
     if (use_logn_attn) {
         // +1 to convert to context length at the timestep
@@ -990,7 +990,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                         kv_head_num,              \
                                         size_per_head,            \
                                         rotary_embedding_dim,     \
-                                        rotary_emb_base,          \
+                                        rotary_embedding_base,    \
                                         max_position_embeddings,  \
                                         use_dynamic_ntk,          \
                                         use_logn_attn);
@@ -1011,7 +1011,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
                                     const int kv_head_num,
                                     const int size_per_head,
                                     const int rotary_embedding_dim,
-                                    float rotary_emb_base,
+                                    float rotary_embedding_base,
                                     int max_position_embeddings,
                                     bool use_dynamic_ntk,
                                     bool use_logn_attn,
@@ -1041,7 +1041,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
                                         const int kv_head_num,           \
                                         const int size_per_head,         \
                                         const int rotary_embedding_dim,  \
-                                        float rotary_emb_base,           \
+                                        float rotary_embedding_base,     \
                                         int max_position_embeddings,     \
                                         bool use_dynamic_ntk,            \
                                         bool use_logn_attn,              \
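
The renamed base feeds directly into the rotary frequencies used by `mmha::apply_rotary_embedding`. That routine operates on packed vector types per thread and is not shown in this diff; below is a minimal scalar sketch of standard RoPE, assuming the usual interleaved even/odd channel pairing, included only to make the roles of `rotary_embedding_base` and `timestep` concrete.

```cpp
// Hedged sketch: standard rotary position embedding applied to one query/key
// vector at a given timestep. Scalar reference only; not the mmha kernel code.
#include <cmath>
#include <vector>

void apply_rotary_embedding_sketch(std::vector<float>& q,
                                   std::vector<float>& k,
                                   int                 rotary_embedding_dim,
                                   float               rotary_embedding_base,
                                   int                 timestep)
{
    for (int i = 0; i + 1 < rotary_embedding_dim; i += 2) {
        // Per-pair frequency base^(-i/dim), then the rotation angle for this position.
        float inv_freq = std::pow(rotary_embedding_base, -static_cast<float>(i) / rotary_embedding_dim);
        float angle    = timestep * inv_freq;
        float c = std::cos(angle), s = std::sin(angle);

        float q0 = q[i], q1 = q[i + 1];
        q[i]     = q0 * c - q1 * s;
        q[i + 1] = q0 * s + q1 * c;

        float k0 = k[i], k1 = k[i + 1];
        k[i]     = k0 * c - k1 * s;
        k[i + 1] = k0 * s + k1 * c;
    }
}
```
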
2 changes: 1 addition & 1 deletion src/turbomind/kernels/unfused_attention_kernels.h
@@ -79,7 +79,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
                                     const int kv_head_num,
                                     const int size_per_head,
                                     const int rotary_embedding_dim,
-                                    float rotary_emb_base,
+                                    float rotary_embedding_base,
                                     int max_position_embeddings,
                                     bool use_dynamic_ntk,
                                     bool use_logn_attn,
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -175,7 +175,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
                                     local_kv_head_num_,
                                     size_per_head_,
                                     params_.rotray_embedding_dim,
-                                    params_.rotary_emb_base,
+                                    params_.rotary_embedding_base,
                                     params_.max_position_embeddings,
                                     params_.use_dynamic_ntk,
                                     params_.use_logn_attn,
6 changes: 3 additions & 3 deletions src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -61,7 +61,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
                                                       const int kv_head_num,
                                                       const int size_per_head,
                                                       const int rotary_embedding_dim,
-                                                      const float rotary_emb_base,
+                                                      const float rotary_embedding_base,
                                                       const int max_position_embeddings,
                                                       const bool use_dynamic_ntk,
                                                       const bool use_logn_attn,
@@ -130,7 +130,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
 
     params.hidden_size_per_head = size_per_head;
     params.rotary_embedding_dim = rotary_embedding_dim;
-    params.rotary_emb_base = rotary_emb_base;
+    params.rotary_embedding_base = rotary_embedding_base;
     params.max_position_embeddings = max_position_embeddings;
     params.use_dynamic_ntk = use_dynamic_ntk;
     params.use_logn_attn = use_logn_attn;
@@ -263,7 +263,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
                                       local_kv_head_num_,
                                       size_per_head_,
                                       params_.rotray_embedding_dim,
-                                      params_.rotary_emb_base,
+                                      params_.rotary_embedding_base,
                                       params_.max_position_embeddings,
                                       params_.use_dynamic_ntk,
                                       params_.use_logn_attn,
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/llama_params.h
@@ -6,7 +6,7 @@ namespace turbomind {
 
 struct LlamaAttentionParams {
     int rotray_embedding_dim;
-    float rotary_emb_base;
+    float rotary_embedding_base;
     int max_position_embeddings;
     bool use_dynamic_ntk;
     bool use_logn_attn;
2 changes: 1 addition & 1 deletion src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -137,7 +137,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     group_size_ = reader.GetInteger("llama", "group_size", 0);
 
     attn_params_.rotray_embedding_dim = reader.GetInteger("llama", "rotary_embedding");
-    attn_params_.rotary_emb_base = reader.GetFloat("llama", "rope_theta", 10000.0f);
+    attn_params_.rotary_embedding_base = reader.GetFloat("llama", "rope_theta", 10000.0f);
     attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0);
     attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0);
     attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0);
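
On the configuration side, nothing changes at the interface: the renamed field is still filled from the same `rope_theta` key (defaulting to 10000.0 when absent), alongside the other `[llama]` keys read above. As a rough illustration only, an INI section consumed by this reader might look like the following; the values are made-up examples, not defaults shipped with any particular model.

```ini
; Illustrative [llama] section (hypothetical values) matching the keys read in LlamaTritonModel.cc
[llama]
group_size = 0
rotary_embedding = 128
rope_theta = 10000.0
max_position_embeddings = 2048
use_dynamic_ntk = 0
use_logn_attn = 0
```
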
