From 6aea61c34b8d4aba6b927f4ebbd71fc8cc5c2b98 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 28 Aug 2024 16:58:30 +0800 Subject: [PATCH] minor --- .../triton_backend/llama/LlamaTritonModel.cc | 85 ++++++++++--------- 1 file changed, 45 insertions(+), 40 deletions(-) diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index aef22a068..ab7e32232 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -213,57 +213,62 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, ft::FT_CHECK(false); } - model_name_ = reader["model_config"]["model_name"].as(); - model_param_.head_num = reader["model_config"]["head_num"].as(); - model_param_.head_dim = reader["model_config"]["size_per_head"].as(); - model_param_.kv_head_num = reader["model_config"]["kv_head_num"].as(0); - model_param_.hidden_units = reader["model_config"]["hidden_units"].as(); - model_param_.layer_num = reader["model_config"]["num_layer"].as(); - model_param_.inter_size = reader["model_config"]["inter_size"].as(); - model_param_.vocab_size = reader["model_config"]["vocab_size"].as(); - model_param_.norm_eps = reader["model_config"]["norm_eps"].as(); - model_param_.start_id = reader["model_config"]["start_id"].as(); - model_param_.end_id = reader["model_config"]["end_id"].as(); - attn_param_.cache_block_seq_len = reader["attention_config"]["cache_block_seq_len"].as(0); - model_param_.quant_policy = reader["engine_config"]["quant_policy"].as(0); + auto model_reader = reader["model_config"]; + auto attention_reader = reader["attention_config"]; + auto lora_reader = reader["lora_config"]; + auto engine_reader = reader["engine_config"]; + + model_name_ = model_reader["model_name"].as(); + model_param_.head_num = model_reader["head_num"].as(); + model_param_.head_dim = model_reader["size_per_head"].as(); + model_param_.kv_head_num = 
model_reader["kv_head_num"].as(0); + model_param_.hidden_units = model_reader["hidden_units"].as(); + model_param_.layer_num = model_reader["num_layer"].as(); + model_param_.inter_size = model_reader["inter_size"].as(); + model_param_.vocab_size = model_reader["vocab_size"].as(); + model_param_.norm_eps = model_reader["norm_eps"].as(); + model_param_.start_id = model_reader["start_id"].as(); + model_param_.end_id = model_reader["end_id"].as(); + attn_param_.cache_block_seq_len = attention_reader["cache_block_seq_len"].as(0); + model_param_.quant_policy = engine_reader["quant_policy"].as(0); // Only weight classes need these - attn_bias_ = reader["model_config"]["attn_bias"].as(0); - group_size_ = reader["model_config"]["group_size"].as(0); + attn_bias_ = model_reader["attn_bias"].as(0); + group_size_ = model_reader["group_size"].as(0); // rotary embedding parameters - attn_param_.rotary_embedding_dim = reader["attention_config"]["rotary_embedding"].as(); - attn_param_.rotary_embedding_base = reader["attention_config"]["rope_theta"].as(10000.0f); - attn_param_.rope_scaling_type = reader["attention_config"]["rope_scaling_type"].as(""); - attn_param_.rope_scaling_factor = reader["attention_config"]["rope_scaling_factor"].as(0.f); - attn_param_.low_freq_factor = reader["attention_config"]["low_freq_factor"].as(1.0); - attn_param_.high_freq_factor = reader["attention_config"]["high_freq_factor"].as(1.0); - attn_param_.max_position_embeddings = reader["attention_config"]["max_position_embeddings"].as(0); - attn_param_.use_dynamic_ntk = reader["attention_config"]["use_dynamic_ntk"].as(0); - attn_param_.use_logn_attn = reader["attention_config"]["use_logn_attn"].as(0); + attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as(); + attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as(10000.0f); + attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as(""); + attn_param_.rope_scaling_factor = 
attention_reader["rope_scaling_factor"].as(0.f); + attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as(1.0); + attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as(1.0); + attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as(0); + attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as(0); + attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as(0); attn_param_.original_max_position_embeddings = - reader["attention_config"]["original_max_position_embeddings"].as(0); + attention_reader["original_max_position_embeddings"].as(0); - engine_param_.max_batch_size = reader["engine_config"]["max_batch_size"].as(0); - engine_param_.max_prefill_token_num = reader["engine_config"]["max_prefill_token_num"].as(0); - engine_param_.max_context_token_num = reader["engine_config"]["max_context_token_num"].as(0); - engine_param_.session_len = reader["model_config"]["session_len"].as(0); + engine_param_.max_batch_size = engine_reader["max_batch_size"].as(0); + engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as(0); + engine_param_.max_context_token_num = engine_reader["max_context_token_num"].as(0); + engine_param_.session_len = model_reader["session_len"].as(0); - engine_param_.cache_max_block_count = reader["engine_config"]["cache_max_entry_count"].as(0); - engine_param_.cache_chunk_size = reader["engine_config"]["cache_chunk_size"].as(0); - engine_param_.enable_prefix_caching = reader["engine_config"]["enable_prefix_caching"].as(false); + engine_param_.cache_max_block_count = engine_reader["cache_max_entry_count"].as(0); + engine_param_.cache_chunk_size = engine_reader["cache_chunk_size"].as(0); + engine_param_.enable_prefix_caching = engine_reader["enable_prefix_caching"].as(false); - engine_param_.num_tokens_per_iter = reader["engine_config"]["num_tokens_per_iter"].as(0); - engine_param_.max_prefill_iters = reader["engine_config"]["max_prefill_iters"].as(1); + 
engine_param_.num_tokens_per_iter = engine_reader["num_tokens_per_iter"].as(0); + engine_param_.max_prefill_iters = engine_reader["max_prefill_iters"].as(1); - lora_param_.policy = ft::getLoraPolicy(reader["lora_config"]["lora_policy"].as("")); + lora_param_.policy = ft::getLoraPolicy(lora_reader["lora_policy"].as("")); - lora_param_.r = reader["lora_config"]["lora_r"].as(0); - lora_param_.scale = reader["lora_config"]["lora_scale"].as(0); - lora_param_.max_wo_r = reader["lora_config"]["lora_max_wo_r"].as(0); - lora_param_.rank_pattern = getLoraPattern(reader["lora_config"]["lora_rank_pattern"].as(""), + lora_param_.r = lora_reader["lora_r"].as(0); + lora_param_.scale = lora_reader["lora_scale"].as(0); + lora_param_.max_wo_r = lora_reader["lora_max_wo_r"].as(0); + lora_param_.rank_pattern = getLoraPattern(lora_reader["lora_rank_pattern"].as(""), [](const std::string& s) { return std::stoi(s); }); - lora_param_.scale_pattern = getLoraPattern(reader["lora_config"]["lora_scale_pattern"].as(""), + lora_param_.scale_pattern = getLoraPattern(lora_reader["lora_scale_pattern"].as(""), [](const std::string& s) { return std::stof(s); }); handleMissingParams(); @@ -273,7 +278,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, const auto device_count = ft::getDeviceCount(); engines_.resize(device_count); - const std::string weight_type_str = reader["model_config"]["weight_type"].as(); + const std::string weight_type_str = model_reader["weight_type"].as(); if (weight_type_str == "fp16") { weight_type_ = ft::WeightType::kFP16; }