Use YAML config instead of INI config
lvhan028 committed Aug 23, 2024
1 parent 3ffb0c4 commit 8227b67
Showing 10 changed files with 101 additions and 104 deletions.
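In short, the converter now writes `triton_models/weights/config.yaml` instead of `config.ini`, and every reader switches from configparser/INIReader to PyYAML/yaml-cpp. Below is a minimal before/after sketch of the Python read path; the workspace path is a placeholder and only the `[llama]` section and `model_arch` key are taken from this diff, so treat it as illustrative rather than the exact lmdeploy API:

import configparser
import yaml

# Before: sectioned INI, every value is a string under the [llama] section
parser = configparser.ConfigParser()
parser.read('workspace/triton_models/weights/config.ini')
model_arch = parser['llama']['model_arch']

# After: flat YAML document, safe_load returns a dict with typed scalars
with open('workspace/triton_models/weights/config.yaml') as f:
    config = yaml.safe_load(f)
model_arch = config['model_arch']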
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -42,6 +42,8 @@ endif()
option(BUILD_PY_FFI "Build python ffi" ON)
option(BUILD_TEST "Build tests" OFF)

find_package(yaml-cpp REQUIRED)

include(FetchContent)
if (BUILD_TEST)
FetchContent_Declare(
2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -13,7 +13,7 @@ ARG PYTHON_VERSION=3.10
ARG TORCH_VERSION=2.3.0
ARG TORCHVISION_VERSION=0.18.0

RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl &&\
RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl libyaml-cpp-dev &&\
curl https://sh.rustup.rs -sSf | sh -s -- -y &&\
add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \
ninja-build rapidjson-dev libgoogle-glog-dev gdb python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
16 changes: 8 additions & 8 deletions lmdeploy/archs.py
@@ -157,15 +157,15 @@ def get_model_arch(model_path: str):
"""
if os.path.exists(os.path.join(model_path, 'triton_models', 'weights')):
# the turbomind model
import configparser
import yaml
config_file = os.path.join(model_path, 'triton_models', 'weights',
'config.ini')
config = configparser.ConfigParser()
config.read(config_file)
model_arch = config['llama']['model_arch']
tm_config = TurbomindEngineConfig()
for key in config['llama']:
setattr(tm_config, key, config['llama'][key])
'config.yaml')
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
model_arch = config['model_arch']
from .turbomind.deploy.target_model.base import TurbomindModelConfig
tm_config = TurbomindModelConfig.from_dict(config)

return model_arch, tm_config
else:
# transformers model
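With the workspace branch rewritten as above, a hedged usage sketch ('workspace' is a placeholder for a directory produced by `lmdeploy convert`):

from lmdeploy.archs import get_model_arch

model_arch, tm_config = get_model_arch('workspace')
# model_arch is the architecture string stored in config.yaml;
# tm_config is a TurbomindModelConfig built via from_dict(...) rather than
# a TurbomindEngineConfig filled in attribute by attribute.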
11 changes: 5 additions & 6 deletions lmdeploy/serve/async_engine.py
@@ -23,18 +23,17 @@

def get_names_from_model(model_path: str, model_name: str = None):
"""Get model name and chat template name from workspace model."""
from configparser import ConfigParser
triton_model_path = os.path.join(model_path, 'triton_models', 'weights')
if not os.path.exists(triton_model_path):
chat_template_name = best_match_model(model_path)
else:
# `model_path` refers to a turbomind model, reading
# chat_template_name from the config
ini_path = os.path.join(triton_model_path, 'config.ini')
with open(ini_path, 'r') as f:
parser = ConfigParser()
parser.read_file(f)
chat_template_name = parser['llama']['chat_template']
config_path = os.path.join(triton_model_path, 'config.yaml')
import yaml
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)
chat_template_name = config['chat_template']
model_name = model_name if model_name else model_path
return model_name, chat_template_name

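A hedged usage sketch of the helper above (again, 'workspace' stands in for any converted turbomind directory):

from lmdeploy.serve.async_engine import get_names_from_model

# For a converted workspace the chat template name comes from config.yaml;
# for a plain HF model directory it falls back to best_match_model().
model_name, chat_template_name = get_names_from_model('workspace')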
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/source_model/llama.py
@@ -7,7 +7,6 @@
import torch
from safetensors.torch import load_file

from lmdeploy.archs import get_model_arch
from lmdeploy.tokenizer import Tokenizer

from .base import INPUT_MODELS, BaseInputModel, BaseReader
@@ -141,6 +140,7 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
ckpt_path = model_path
self.ckpt_path = ckpt_path
self.ckpt_files = self.get_ckpt()
from lmdeploy.archs import get_model_arch
_, self.model_config = get_model_arch(model_path)
self.model_config = self.model_config.to_dict()

18 changes: 10 additions & 8 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import configparser
import copy
import inspect
import io
@@ -10,6 +9,7 @@

import torch
import tqdm
import yaml
from mmengine import Registry
from pydantic.dataclasses import dataclass

@@ -203,13 +203,15 @@ def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
def export_config(self) -> None:
"""export turbomind config."""
if self.to_file:
config = configparser.ConfigParser()
cfg = dict(llama=self.cfg.__dict__)
for section, key_values in cfg.items():
config[section] = key_values
config_path = osp.join(self.out_dir, 'config.ini')
with open(config_path, 'w') as f:
config.write(f)
config_path = osp.join(self.out_dir, 'config.yaml')
with open(config_path, 'w') as file:
    yaml.safe_dump(self.cfg.__dict__, file)

def export_weight(self, param: torch.Tensor, name: str) -> None:
"""export turbomind weight."""
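The export now boils down to yaml.safe_dump of the config dataclass's __dict__. A tiny round-trip sketch, with a made-up two-field dict standing in for self.cfg.__dict__:

import yaml

cfg_dict = {'model_name': 'internlm2', 'tensor_para_size': 1}  # stand-in for self.cfg.__dict__
text = yaml.safe_dump(cfg_dict)          # what export_config() writes to config.yaml
assert yaml.safe_load(text) == cfg_dict  # safe_load restores the same typed dict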
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/supported_models.py
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.archs import get_model_arch
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')
@@ -74,6 +73,7 @@ def _is_head_dim_128(cfg):
if os.path.exists(triton_model_path):
support_by_turbomind = True
else:
from lmdeploy.archs import get_model_arch
arch, cfg = get_model_arch(model_path)

if arch in SUPPORTED_ARCHS.keys():
25 changes: 12 additions & 13 deletions lmdeploy/turbomind/turbomind.py
@@ -3,7 +3,6 @@
import os.path as osp
import sys
from concurrent.futures import ThreadPoolExecutor
from configparser import ConfigParser
from itertools import repeat
from queue import LifoQueue, Queue
from typing import Dict, Iterable, List, Union
@@ -18,7 +17,6 @@
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger, get_model

from .deploy.converter import SUPPORTED_FORMATS, get_tm_model
from .deploy.target_model.base import TurbomindModelConfig
from .supported_models import is_supported
from .utils import ModelSource, get_model_source
@@ -172,6 +170,8 @@ def _from_hf(self, model_source: ModelSource, model_path: str,
if engine_config is None:
logger.warning('input engine config is None, using the default')
engine_config = TurbomindEngineConfig()

from .deploy.converter import SUPPORTED_FORMATS
assert engine_config.model_format in SUPPORTED_FORMATS, \
f'The model format should be in {SUPPORTED_FORMATS}'

@@ -180,6 +180,7 @@ def _from_hf(self, model_source: ModelSource, model_path: str,
'Plz try pytorch engine instead.')

# convert transformers model into turbomind model format
from .deploy.converter import get_tm_model
tm_model = get_tm_model(model_path, self.model_name,
self.chat_template_name, engine_config)

@@ -212,20 +213,18 @@ def _from_workspace(self, model_path: str,
def _from_workspace(self, model_path: str,
engine_config: TurbomindEngineConfig):
"""Load model which is converted by `lmdeploy convert`"""
ini_path = osp.join(model_path, 'triton_models', 'weights',
'config.ini')
config_path = osp.join(model_path, 'triton_models', 'weights',
'config.yaml')
# load cfg
with open(ini_path, 'r') as f:
parser = ConfigParser()
parser.read_file(f)
section_name = 'llama'
_cfg = parser._sections[section_name]
import yaml
with open(config_path, 'r') as f:
    _cfg = yaml.safe_load(f)
cfg = TurbomindModelConfig.from_dict(_cfg)

# check whether input tp is valid
if cfg.tensor_para_size != 1 and \
self.gpu_count != cfg.tensor_para_size:
logger.info(f'found tp={cfg.tensor_para_size} in config.ini.')
logger.info(f'found tp={cfg.tensor_para_size} in config.yaml.')
self.gpu_count = cfg.tensor_para_size

if engine_config is not None:
@@ -237,13 +236,13 @@ def _from_workspace(self, model_path: str,
cfg.chat_template_name = self.chat_template_name
# update cfg
self.config = cfg

# create model
logger.warning(f'model_config:\n\n{cfg.toini()}')
logger.warning(f'model_config:\n\n{cfg}')
weight_dir = osp.join(model_path, 'triton_models', 'weights')
model_comm = _tm.AbstractTransformerModel.create_llama_model(
model_dir=weight_dir,
config=cfg.toini(),
config=yaml.safe_dump(cfg.__dict__),
tensor_para_size=self.gpu_count,
data_type=self.config.weight_type)

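Condensed, the new workspace loading path reads config.yaml into a TurbomindModelConfig and hands the C++ backend a YAML string instead of an INI blob. A sketch under the assumption that 'workspace' is a converted model directory:

import os.path as osp

import yaml

from lmdeploy.turbomind.deploy.target_model.base import TurbomindModelConfig

weight_dir = osp.join('workspace', 'triton_models', 'weights')
with open(osp.join(weight_dir, 'config.yaml')) as f:
    cfg = TurbomindModelConfig.from_dict(yaml.safe_load(f))

config_str = yaml.safe_dump(cfg.__dict__)  # passed to create_llama_model(config=...)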
2 changes: 1 addition & 1 deletion src/turbomind/triton_backend/llama/CMakeLists.txt
@@ -25,5 +25,5 @@ set(llama_triton_backend_files
find_package(CUDAToolkit REQUIRED)
add_library(LlamaTritonBackend STATIC ${llama_triton_backend_files})
set_property(TARGET LlamaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils CUDA::cublasLt)
target_link_libraries(LlamaTritonBackend PUBLIC TransformerTritonBackend Llama tensor memory_utils CUDA::cublasLt ${YAML_CPP_LIBRARIES})
target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14)
125 changes: 60 additions & 65 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -19,7 +19,6 @@
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.cc

#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "3rdparty/INIReader.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaInstanceComm.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
@@ -30,6 +29,7 @@
#include "src/turbomind/utils/cuda_utils.h"
#include <cuda_runtime.h>
#include <mutex>
#include <yaml-cpp/yaml.h>

namespace ft = turbomind;

@@ -189,81 +189,76 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
weights_(ft::getDeviceCount()),
enable_custom_all_reduce_(enable_custom_all_reduce)
{
INIReader reader;
FT_CHECK_WITH_INFO(!(config.empty() && model_dir.empty()), "invalid init options");

if (!model_dir.empty()) {
model_dir_ = model_dir;
const std::string inifile{model_dir + "/config.ini"};
reader = INIReader(inifile);
if (reader.ParseError() < 0) {
TM_LOG_ERROR("[ERROR] Can't load %s", inifile.c_str());
ft::FT_CHECK(false);
YAML::Node reader;

try {
if (!model_dir.empty()) {
model_dir_ = model_dir;
const std::string config_file{model_dir + "/config.yaml"};
reader = YAML::LoadFile(config_file);
}
}

if (!config.empty()) {
std::FILE* tmpf = std::tmpfile();
std::fputs(config.c_str(), tmpf);
std::rewind(tmpf);
reader = INIReader(tmpf);
if (reader.ParseError() < 0) {
TM_LOG_ERROR("[ERROR] Can't init with config %s", config.c_str());
ft::FT_CHECK(false);
if (!config.empty()) {
reader = YAML::Load(config);
}
} catch (const YAML::Exception& e) {
std::cerr << "Error reading YAML config: " << e.what() << std::endl;
ft::FT_CHECK(false);
}

model_name_ = reader.Get("llama", "model_name");
model_param_.head_num = reader.GetInteger("llama", "head_num");
model_param_.head_dim = reader.GetInteger("llama", "size_per_head");
model_param_.kv_head_num = reader.GetInteger("llama", "kv_head_num", 0);
model_param_.hidden_units = reader.GetInteger("llama", "hidden_units");
model_param_.layer_num = reader.GetInteger("llama", "num_layer");
model_param_.inter_size = reader.GetInteger("llama", "inter_size");
model_param_.vocab_size = reader.GetInteger("llama", "vocab_size");
model_param_.norm_eps = reader.GetFloat("llama", "norm_eps");
model_param_.start_id = reader.GetInteger("llama", "start_id");
model_param_.end_id = reader.GetInteger("llama", "end_id");
attn_param_.cache_block_seq_len = reader.GetInteger("llama", "cache_block_seq_len", 0);
model_param_.quant_policy = reader.GetInteger("llama", "quant_policy", 0);
model_name_ = reader["model_name"].as<std::string>();
model_param_.head_num = reader["head_num"].as<int>();
model_param_.head_dim = reader["size_per_head"].as<int>();
model_param_.kv_head_num = reader["kv_head_num"].as<int>(0);
model_param_.hidden_units = reader["hidden_units"].as<int>();
model_param_.layer_num = reader["num_layer"].as<int>();
model_param_.inter_size = reader["inter_size"].as<int>();
model_param_.vocab_size = reader["vocab_size"].as<int>();
model_param_.norm_eps = reader["norm_eps"].as<float>();
model_param_.start_id = reader["start_id"].as<int>();
model_param_.end_id = reader["end_id"].as<int>();
attn_param_.cache_block_seq_len = reader["cache_block_seq_len"].as<int>(0);
model_param_.quant_policy = reader["quant_policy"].as<int>(0);

// Only weight classes need these
attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
group_size_ = reader.GetInteger("llama", "group_size", 0);
attn_bias_ = reader["attn_bias"].as<int>(0);
group_size_ = reader["group_size"].as<int>(0);

// rotary embedding parameters
attn_param_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding");
attn_param_.rotary_embedding_base = reader.GetFloat("llama", "rope_theta", 10000.0f);
attn_param_.rope_scaling_type = reader.Get("llama", "rope_scaling_type", "");
attn_param_.rope_scaling_factor = reader.GetFloat("llama", "rope_scaling_factor", 0.f);
attn_param_.low_freq_factor = reader.GetFloat("llama", "low_freq_factor", 1.0);
attn_param_.high_freq_factor = reader.GetFloat("llama", "high_freq_factor", 1.0);
attn_param_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0);
attn_param_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0);
attn_param_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0);

attn_param_.original_max_position_embeddings = reader.GetInteger("llama", "original_max_position_embeddings", 0);

engine_param_.max_batch_size = reader.GetInteger("llama", "max_batch_size", 0);
engine_param_.max_prefill_token_num = reader.GetInteger("llama", "max_prefill_token_num", 0);
engine_param_.max_context_token_num = reader.GetInteger("llama", "max_context_token_num", 0);
engine_param_.session_len = reader.GetInteger("llama", "session_len", 0);
engine_param_.step_length = reader.GetInteger("llama", "step_length", 0);

engine_param_.cache_max_block_count = reader.GetFloat("llama", "cache_max_entry_count", 0);
engine_param_.cache_chunk_size = reader.GetInteger("llama", "cache_chunk_size", 0);
engine_param_.enable_prefix_caching = reader.GetBoolean("llama", "enable_prefix_caching", false);

engine_param_.num_tokens_per_iter = reader.GetInteger("llama", "num_tokens_per_iter", 0);
engine_param_.max_prefill_iters = reader.GetInteger("llama", "max_prefill_iters", 1);

lora_param_.policy = ft::getLoraPolicy(reader.Get("llama", "lora_policy", ""));
lora_param_.r = reader.GetInteger("llama", "lora_r", 0);
lora_param_.scale = reader.GetFloat("llama", "lora_scale", 0);
lora_param_.max_wo_r = reader.GetInteger("llama", "lora_max_wo_r", 0);
lora_param_.rank_pattern = getLoraPattern<int>(reader.Get("llama", "lora_rank_pattern", ""),
attn_param_.rotary_embedding_dim = reader["rotary_embedding"].as<int>();
attn_param_.rotary_embedding_base = reader["rope_theta"].as<float>(10000.0f);
attn_param_.rope_scaling_type = reader["rope_scaling_type"].as<std::string>("");
attn_param_.rope_scaling_factor = reader["rope_scaling_factor"].as<float>(0.f);
attn_param_.low_freq_factor = reader["low_freq_factor"].as<float>(1.0);
attn_param_.high_freq_factor = reader["high_freq_factor"].as<float>(1.0);
attn_param_.max_position_embeddings = reader["max_position_embeddings"].as<int>(0);
attn_param_.use_dynamic_ntk = reader["use_dynamic_ntk"].as<int>(0);
attn_param_.use_logn_attn = reader["use_logn_attn"].as<int>(0);

attn_param_.original_max_position_embeddings = reader["original_max_position_embeddings"].as<int>(0);

engine_param_.max_batch_size = reader["max_batch_size"].as<int>(0);
engine_param_.max_prefill_token_num = reader["max_prefill_token_num"].as<int>(0);
engine_param_.max_context_token_num = reader["max_context_token_num"].as<int>(0);
engine_param_.session_len = reader["session_len"].as<int>(0);
engine_param_.step_length = reader["step_length"].as<int>(0);

engine_param_.cache_max_block_count = reader["cache_max_entry_count"].as<float>(0);
engine_param_.cache_chunk_size = reader["cache_chunk_size"].as<int>(0);
engine_param_.enable_prefix_caching = reader["enable_prefix_caching"].as<bool>(false);

engine_param_.num_tokens_per_iter = reader["num_tokens_per_iter"].as<int>(0);
engine_param_.max_prefill_iters = reader["max_prefill_iters"].as<int>(1);

lora_param_.policy = ft::getLoraPolicy(reader["lora_policy"].as<std::string>(""));
lora_param_.r = reader["lora_r"].as<int>(0);
lora_param_.scale = reader["lora_scale"].as<float>(0);
lora_param_.max_wo_r = reader["lora_max_wo_r"].as<int>(0);
lora_param_.rank_pattern = getLoraPattern<int>(reader["lora_rank_pattern"].as<std::string>(""),
[](const std::string& s) { return std::stoi(s); });
lora_param_.scale_pattern = getLoraPattern<float>(reader.Get("llama", "lora_scale_pattern", ""),
lora_param_.scale_pattern = getLoraPattern<float>(reader["lora_scale_pattern"].as<std::string>(""),
[](const std::string& s) { return std::stof(s); });
handleMissingParams();

@@ -273,7 +268,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
const auto device_count = ft::getDeviceCount();
engines_.resize(device_count);

const std::string weight_type_str = reader.Get("llama", "weight_type");
const std::string weight_type_str = reader["weight_type"].as<std::string>();
if (weight_type_str == "fp16") {
weight_type_ = ft::WeightType::kFP16;
}
