Skip to content

Commit

Permalink
minor
Browse files: browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Aug 28, 2024
1 parent cbf238d commit 6aea61c
Showing 1 changed file with 45 additions and 40 deletions.
85 changes: 45 additions & 40 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -213,57 +213,62 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
ft::FT_CHECK(false);
}

model_name_ = reader["model_config"]["model_name"].as<std::string>();
model_param_.head_num = reader["model_config"]["head_num"].as<int>();
model_param_.head_dim = reader["model_config"]["size_per_head"].as<int>();
model_param_.kv_head_num = reader["model_config"]["kv_head_num"].as<int>(0);
model_param_.hidden_units = reader["model_config"]["hidden_units"].as<int>();
model_param_.layer_num = reader["model_config"]["num_layer"].as<int>();
model_param_.inter_size = reader["model_config"]["inter_size"].as<int>();
model_param_.vocab_size = reader["model_config"]["vocab_size"].as<int>();
model_param_.norm_eps = reader["model_config"]["norm_eps"].as<float>();
model_param_.start_id = reader["model_config"]["start_id"].as<int>();
model_param_.end_id = reader["model_config"]["end_id"].as<int>();
attn_param_.cache_block_seq_len = reader["attention_config"]["cache_block_seq_len"].as<int>(0);
model_param_.quant_policy = reader["engine_config"]["quant_policy"].as<int>(0);
auto model_reader = reader["model_config"];
auto attention_reader = reader["attention_config"];
auto lora_reader = reader["lora_config"];
auto engine_reader = reader["engine_config"];

model_name_ = model_reader["model_name"].as<std::string>();
model_param_.head_num = model_reader["head_num"].as<int>();
model_param_.head_dim = model_reader["size_per_head"].as<int>();
model_param_.kv_head_num = model_reader["kv_head_num"].as<int>(0);
model_param_.hidden_units = model_reader["hidden_units"].as<int>();
model_param_.layer_num = model_reader["num_layer"].as<int>();
model_param_.inter_size = model_reader["inter_size"].as<int>();
model_param_.vocab_size = model_reader["vocab_size"].as<int>();
model_param_.norm_eps = model_reader["norm_eps"].as<float>();
model_param_.start_id = model_reader["start_id"].as<int>();
model_param_.end_id = model_reader["end_id"].as<int>();
attn_param_.cache_block_seq_len = attention_reader["cache_block_seq_len"].as<int>(0);
model_param_.quant_policy = engine_reader["quant_policy"].as<int>(0);

// Only weight classes need these
attn_bias_ = reader["model_config"]["attn_bias"].as<int>(0);
group_size_ = reader["model_config"]["group_size"].as<int>(0);
attn_bias_ = model_reader["attn_bias"].as<int>(0);
group_size_ = model_reader["group_size"].as<int>(0);

// rotary embedding parameters
attn_param_.rotary_embedding_dim = reader["attention_config"]["rotary_embedding"].as<int>();
attn_param_.rotary_embedding_base = reader["attention_config"]["rope_theta"].as<float>(10000.0f);
attn_param_.rope_scaling_type = reader["attention_config"]["rope_scaling_type"].as<std::string>("");
attn_param_.rope_scaling_factor = reader["attention_config"]["rope_scaling_factor"].as<float>(0.f);
attn_param_.low_freq_factor = reader["attention_config"]["low_freq_factor"].as<float>(1.0);
attn_param_.high_freq_factor = reader["attention_config"]["high_freq_factor"].as<float>(1.0);
attn_param_.max_position_embeddings = reader["attention_config"]["max_position_embeddings"].as<int>(0);
attn_param_.use_dynamic_ntk = reader["attention_config"]["use_dynamic_ntk"].as<int>(0);
attn_param_.use_logn_attn = reader["attention_config"]["use_logn_attn"].as<int>(0);
attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as<int>();
attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as<float>(10000.0f);
attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as<std::string>("");
attn_param_.rope_scaling_factor = attention_reader["rope_scaling_factor"].as<float>(0.f);
attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as<float>(1.0);
attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as<int>(0);
attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);

attn_param_.original_max_position_embeddings =
reader["attention_config"]["original_max_position_embeddings"].as<int>(0);
attention_reader["original_max_position_embeddings"].as<int>(0);

engine_param_.max_batch_size = reader["engine_config"]["max_batch_size"].as<int>(0);
engine_param_.max_prefill_token_num = reader["engine_config"]["max_prefill_token_num"].as<int>(0);
engine_param_.max_context_token_num = reader["engine_config"]["max_context_token_num"].as<int>(0);
engine_param_.session_len = reader["model_config"]["session_len"].as<int>(0);
engine_param_.max_batch_size = engine_reader["max_batch_size"].as<int>(0);
engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as<int>(0);
engine_param_.max_context_token_num = engine_reader["max_context_token_num"].as<int>(0);
engine_param_.session_len = model_reader["session_len"].as<int>(0);

engine_param_.cache_max_block_count = reader["engine_config"]["cache_max_entry_count"].as<float>(0);
engine_param_.cache_chunk_size = reader["engine_config"]["cache_chunk_size"].as<int>(0);
engine_param_.enable_prefix_caching = reader["engine_config"]["enable_prefix_caching"].as<bool>(false);
engine_param_.cache_max_block_count = engine_reader["cache_max_entry_count"].as<float>(0);
engine_param_.cache_chunk_size = engine_reader["cache_chunk_size"].as<int>(0);
engine_param_.enable_prefix_caching = engine_reader["enable_prefix_caching"].as<bool>(false);

engine_param_.num_tokens_per_iter = reader["engine_config"]["num_tokens_per_iter"].as<int>(0);
engine_param_.max_prefill_iters = reader["engine_config"]["max_prefill_iters"].as<int>(1);
engine_param_.num_tokens_per_iter = engine_reader["num_tokens_per_iter"].as<int>(0);
engine_param_.max_prefill_iters = engine_reader["max_prefill_iters"].as<int>(1);

lora_param_.policy = ft::getLoraPolicy(reader["lora_config"]["lora_policy"].as<std::string>(""));
lora_param_.r = reader["lora_config"]["lora_r"].as<int>(0);
lora_param_.scale = reader["lora_config"]["lora_scale"].as<float>(0);
lora_param_.max_wo_r = reader["lora_config"]["lora_max_wo_r"].as<int>(0);
lora_param_.rank_pattern = getLoraPattern<int>(reader["lora_config"]["lora_rank_pattern"].as<std::string>(""),
lora_param_.r = lora_reader["lora_r"].as<int>(0);
lora_param_.scale = lora_reader["lora_scale"].as<float>(0);
lora_param_.max_wo_r = lora_reader["lora_max_wo_r"].as<int>(0);
lora_param_.rank_pattern = getLoraPattern<int>(lora_reader["lora_rank_pattern"].as<std::string>(""),
[](const std::string& s) { return std::stoi(s); });
lora_param_.scale_pattern = getLoraPattern<float>(reader["lora_config"]["lora_scale_pattern"].as<std::string>(""),
lora_param_.scale_pattern = getLoraPattern<float>(lora_reader["lora_scale_pattern"].as<std::string>(""),
[](const std::string& s) { return std::stof(s); });
handleMissingParams();

Expand All @@ -273,7 +278,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
const auto device_count = ft::getDeviceCount();
engines_.resize(device_count);

const std::string weight_type_str = reader["model_config"]["weight_type"].as<std::string>();
const std::string weight_type_str = model_reader["weight_type"].as<std::string>();
if (weight_type_str == "fp16") {
weight_type_ = ft::WeightType::kFP16;
}
Expand Down

0 comments on commit 6aea61c

Please sign in to comment.