From 4767b04d7db5012c9553109cca4f3bcd3c3dd6c5 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 31 Jul 2023 20:48:47 +0800 Subject: [PATCH] Support Runtime tensor parallelism (#158) * works on interlm and vicuna * support GQA * remove comment * update readme, add logger, default tp=1 * remove log --- README.md | 14 +- README_zh-CN.md | 13 +- benchmark/profile_generation.py | 5 +- benchmark/profile_throughput.py | 9 +- lmdeploy/serve/turbomind/chatbot.py | 2 +- lmdeploy/serve/turbomind/utils.py | 80 +---- lmdeploy/turbomind/chat.py | 7 +- lmdeploy/turbomind/turbomind.py | 19 +- lmdeploy/turbomind/utils.py | 79 +++++ .../models/llama/LlamaDecoderLayerWeight.cc | 147 ++++++++- src/turbomind/utils/memory_utils.cu | 291 ++++++++++++++---- src/turbomind/utils/memory_utils.h | 17 +- 12 files changed, 492 insertions(+), 191 deletions(-) create mode 100644 lmdeploy/turbomind/utils.py diff --git a/README.md b/README.md index 9b8a7c696..5fa64e365 100644 --- a/README.md +++ b/README.md @@ -84,10 +84,13 @@ python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-ch python -m lmdeploy.turbomind.chat ./workspace ``` -```{note} -When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind. It is recommended to use NVIDIA cards such as 3090, V100, A100, etc. -Disable GPU ECC can free up 10% memory, try `sudo nvidia-smi --ecc-config=0` and reboot system. -``` +> **Note**
+> When inferring with FP16 precision, the InternLM-7B model requires at least 15.7 GB of GPU memory on TurboMind.<br>
+> It is recommended to use NVIDIA cards such as the 3090, V100, or A100.
+> Disabling GPU ECC can free up about 10% of GPU memory. Try `sudo nvidia-smi --ecc-config=0` and reboot the system.
+
+> **Note**<br>
+> Tensor parallelism is available for inference on multiple GPUs. Add `--tp=<num_gpu>` when running `chat` to enable runtime TP.
 
 #### Serving
 
@@ -163,6 +166,9 @@ Then adjust `workspace/triton_models/weights/config.ini`
 
 Here is [quantization test results](./docs/en/quantization.md).
 
+> **Warning**<br>
+> Runtime Tensor Parallelism is not available for quantized models. Please set `--tp` when running `deploy` to enable static TP.
+
 ## Contributing
 
 We appreciate all contributions to LMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 620186acb..8f6b480d5 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -83,10 +83,12 @@ python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-ch
 python3 -m lmdeploy.turbomind.chat ./workspace
 ```
 
-```{note}
-turbomind 在使用 FP16 精度推理 InternLM-7B 模型时,显存开销至少需要 15.7G。建议使用 3090, V100,A100等型号的显卡。
-关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。
-```
+> **Note**<br>
+> turbomind 在使用 FP16 精度推理 InternLM-7B 模型时,显存开销至少需要 15.7G。建议使用 3090, V100,A100等型号的显卡。
+> 关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。 + +> **Note**
+> 使用 Tensor 并发可以利用多张 GPU 进行推理。在 `chat` 时添加参数 `--tp=<num_gpu>` 可以启动运行时 TP。
 
 #### 部署推理服务
 
@@ -162,6 +164,9 @@ python3 -m lmdeploy.lite.apis.kv_qparams \
 
 这里是[量化测试结果](./docs/zh_cn/quantization.md)。
 
+> **Warning**<br>
+> 量化部署不支持运行时 Tensor 并发。如果希望使用 Tensor 并发,需要在 deploy 时配置 tp 参数。 + ## 贡献指南 我们感谢所有的贡献者为改进和提升 LMDeploy 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。 diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index 9249595f5..affec1d98 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -76,10 +76,11 @@ def main(model_path: str, concurrency: int = 1, input_seqlen: int = 0, output_seqlen: int = 512, - test_round: int = 10): + test_round: int = 10, + tp: int = 1): tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer') tokenizer = Tokenizer(tokenizer_model_path) - tm_model = TurboMind(model_path=model_path) + tm_model = TurboMind(model_path=model_path, tp=tp) warmup(tm_model, concurrency, output_seqlen) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 3a6500de6..d8100113c 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -54,11 +54,11 @@ def sample_requests( class Engine: - def __init__(self, model_path: str): + def __init__(self, model_path: str, tp: int = 1): tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer') tokenizer = Tokenizer(tokenizer_model_path) - tm_model = TurboMind(model_path=model_path) + tm_model = TurboMind(model_path=model_path, tp=tp) self.tm_model = tm_model self.tokenizer = tokenizer @@ -117,9 +117,10 @@ def process_request(self, requests, concurrency: int = 1): def main(dataset: str, model_path: str, concurrency: int = 1, - num_prompts: int = 1000): + num_prompts: int = 1000, + tp: int = 1): - engine = Engine(model_path) + engine = Engine(model_path, tp=tp) tokenizer = engine.tokenizer requests = sample_requests(dataset, num_prompts, tokenizer) diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py index 969faabdf..045698b5b 100644 --- a/lmdeploy/serve/turbomind/chatbot.py +++ b/lmdeploy/serve/turbomind/chatbot.py @@ -52,7 +52,7 @@ def stream_callback(que, result, error): def get_logger(log_file=None, log_level=logging.INFO): """Return the logger.""" - from .utils import get_logger + from lmdeploy.turbomind.utils import get_logger logger = get_logger('service.ft', log_file=log_file, log_level=log_level) return logger diff --git a/lmdeploy/serve/turbomind/utils.py b/lmdeploy/serve/turbomind/utils.py index ba3fd8921..bd1c3a16c 100644 --- a/lmdeploy/serve/turbomind/utils.py +++ b/lmdeploy/serve/turbomind/utils.py @@ -1,88 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -# Copyright (c) OpenMMLab. All rights reserved. -import logging -from typing import List, Optional, Union +from typing import List, Union import numpy as np import tritonclient.grpc as grpcclient from tritonclient.utils import np_to_triton_dtype -logger_initialized = {} - - -def get_logger(name: str, - log_file: Optional[str] = None, - log_level: int = logging.INFO, - file_mode: str = 'w'): - """Initialize and get a logger by name. - - If the logger has not been initialized, this method will initialize the - logger by adding one or two handlers, otherwise the initialized logger will - be directly returned. During initialization, a StreamHandler will always be - added. If `log_file` is specified, a FileHandler will also be added. - Args: - name (str): Logger name. - log_file (str | None): The log filename. If specified, a FileHandler - will be added to the logger. - log_level (int): The logger level. - file_mode (str): The file mode used in opening log file. - Defaults to 'w'. 
- Returns: - logging.Logger: The expected logger. - """ - # use logger in mmengine if exists. - try: - from mmengine.logging import MMLogger - if MMLogger.check_instance_created(name): - logger = MMLogger.get_instance(name) - else: - logger = MMLogger.get_instance(name, - logger_name=name, - log_file=log_file, - log_level=log_level, - file_mode=file_mode) - return logger - - except Exception: - pass - - logger = logging.getLogger(name) - if name in logger_initialized: - return logger - # handle hierarchical names - # e.g., logger "a" is initialized, then logger "a.b" will skip the - # initialization since it is a child of "a". - for logger_name in logger_initialized: - if name.startswith(logger_name): - return logger - - # handle duplicate logs to the console - for handler in logger.root.handlers: - if type(handler) is logging.StreamHandler: - handler.setLevel(logging.ERROR) - - stream_handler = logging.StreamHandler() - handlers = [stream_handler] - - if log_file is not None: - # Here, the default behaviour of the official logger is 'a'. Thus, we - # provide an interface to change the file mode to the default - # behaviour. - file_handler = logging.FileHandler(log_file, file_mode) - handlers.append(file_handler) - - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') - for handler in handlers: - handler.setFormatter(formatter) - handler.setLevel(log_level) - logger.addHandler(handler) - - logger.setLevel(log_level) - logger_initialized[name] = True - - return logger - def prepare_tensor(name, input_tensor): """Create grpcclient's InferInput instance according to a given tensor.""" diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py index 427735073..f3f2991d4 100644 --- a/lmdeploy/turbomind/chat.py +++ b/lmdeploy/turbomind/chat.py @@ -29,7 +29,10 @@ def valid_str(string, coding='utf-8'): return ret -def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0): +def main(model_path, + session_id: int = 1, + repetition_penalty: float = 1.0, + tp=1): """An example to perform model inference through the command line interface. 
@@ -39,7 +42,7 @@ def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0): """ tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer') tokenizer = Tokenizer(tokenizer_model_path) - tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id) + tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp) generator = tm_model.create_instance() nth_round = 1 diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index efe7e65b7..0e6859bef 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -13,6 +13,7 @@ import lmdeploy from lmdeploy.model import MODELS +from lmdeploy.turbomind.utils import get_logger # TODO: find another way import _turbomind lmdeploy_dir = osp.split(lmdeploy.__file__)[0] @@ -69,14 +70,11 @@ class TurboMind: Args: model_path (str): the path of turbomind's model - data_type (str): the data type eos_id (int): eos token id + tp (int): tensor parallel """ - def __init__(self, - model_path: str, - data_type: str = 'fp16', - eos_id: int = 2): + def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1): self.eos_id = eos_id # TODO: support mpi @@ -84,8 +82,9 @@ def __init__(self, node_num = 1 # read meta from model path - self.gpu_count = 1 + self.gpu_count = tp self.session_len = 2048 + data_type = 'fp16' ini_path = osp.join(model_path, 'triton_models/weights/config.ini') with open(ini_path, 'r') as f: parser = ConfigParser() @@ -97,10 +96,14 @@ def __init__(self, section_name = 'llama' if len(section_name) > 0: - self.gpu_count = parser.getint(section_name, - 'tensor_para_size') + tp_cfg = parser.getint(section_name, 'tensor_para_size') self.session_len = parser.getint(section_name, 'session_len') + if tp_cfg != 1 and tp_cfg != tp: + get_logger('turbomind').info( + f'found tp={tp_cfg} in config.ini.') + self.gpu_count = tp_cfg self.model_name = parser.get(section_name, 'model_name') + data_type = parser.get(section_name, 'weight_type') model = MODELS.get(self.model_name)() self.stop_words = _stop_words(model.stop_words) diff --git a/lmdeploy/turbomind/utils.py b/lmdeploy/turbomind/utils.py new file mode 100644 index 000000000..7b6d51a01 --- /dev/null +++ b/lmdeploy/turbomind/utils.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import Optional + +logger_initialized = {} + + +def get_logger(name: str, + log_file: Optional[str] = None, + log_level: int = logging.INFO, + file_mode: str = 'w'): + """Initialize and get a logger by name. + + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified, a FileHandler will also be added. + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. + Returns: + logging.Logger: The expected logger. + """ + # use logger in mmengine if exists. 
+ try: + from mmengine.logging import MMLogger + if MMLogger.check_instance_created(name): + logger = MMLogger.get_instance(name) + else: + logger = MMLogger.get_instance(name, + logger_name=name, + log_file=log_file, + log_level=log_level, + file_mode=file_mode) + return logger + + except Exception: + pass + + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + # handle hierarchical names + # e.g., logger "a" is initialized, then logger "a.b" will skip the + # initialization since it is a child of "a". + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + # handle duplicate logs to the console + for handler in logger.root.handlers: + if type(handler) is logging.StreamHandler: + handler.setLevel(logging.ERROR) + + stream_handler = logging.StreamHandler() + handlers = [stream_handler] + + if log_file is not None: + # Here, the default behaviour of the official logger is 'a'. Thus, we + # provide an interface to change the file mode to the default + # behaviour. + file_handler = logging.FileHandler(log_file, file_mode) + handlers.append(file_handler) + + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + for handler in handlers: + handler.setFormatter(formatter) + handler.setLevel(log_level) + logger.addHandler(handler) + + logger.setLevel(log_level) + logger_initialized[name] = True + + return logger diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index 48b43f8cd..e39cb0bef 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -21,6 +21,7 @@ #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" #include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/memory_utils.h" +#include namespace turbomind { @@ -99,25 +100,135 @@ void mallocWeights(LlamaDenseWeight& weights, bool bias) } template -void loadWeights(LlamaDenseWeight& w, std::string prefix, int rank, FtCudaDataType model_file_type) +void loadWeights(LlamaDenseWeight& w, + std::string prefix, + int rank, + FtCudaDataType model_file_type, + size_t tensor_para_size, + int slice_dim = 0, + std::vector slice_shape = {}) { - prefix += "." + std::to_string(rank); - const auto type = model_file_type; + auto max_prefix = prefix + "." + std::to_string(tensor_para_size - 1); + const auto type = model_file_type; + + bool enable_slice = true; + // Disable slice if tensor param rank is 1 + if (tensor_para_size <= 1) { + enable_slice = false; + } + else { + // Disable slice if weight has already been sliced + if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) { + TM_LOG_DEBUG("TP weight exists. Disable runtime TP."); + enable_slice = false; + } + } + + size_t dim0 = w.input_dims; + size_t dim1 = w.output_dims; + if (enable_slice) { + // multiple tp size for slice stride + if (slice_dim == 0) { + dim0 = dim0 * tensor_para_size; + if (slice_shape.size() == 0) { + slice_shape = {dim0}; + } + } + else { + dim1 = dim1 * tensor_para_size; + if (slice_shape.size() == 0) { + slice_shape = {dim1}; + } + } + + prefix += "." + std::to_string(0); + } + else { + prefix += "." 
+ std::to_string(rank); + } if (w.bias) { - loadWeightFromBin((T*)w.bias, {w.output_dims}, prefix + ".bias", type); + std::vector bias_slices{}; + if (enable_slice) { + if (slice_dim == 1) { + size_t start = 0; + ConcateSlice slice0{.slices = {{0, 1}}}; + ConcateSlice slice1{.slices = {{}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + bias_slices = {slice0, slice1}; + } + } + loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type, bias_slices); } const size_t bit_size = getBitSize(w.type); if (bit_size >= 16) { // fp16, fp32 - loadWeightFromBin((T*)w.kernel, {w.input_dims, w.output_dims}, prefix + ".weight", type); + std::vector weight_slices{}; + if (enable_slice) { + if (slice_dim == 1) { + size_t start = 0; + ConcateSlice slice0{.slices = {{0, dim0}}}; + ConcateSlice slice1{.slices = {{}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + weight_slices = {slice0, slice1}; + } + else { + size_t start = 0; + ConcateSlice slice0{.slices = {}}; + ConcateSlice slice1{.slices = {{0, dim1}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + weight_slices = {slice0, slice1}; + } + } + loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices); } else { // int8, int4 const int factor = sizeof(float) * 8 / bit_size; - FT_CHECK(w.input_dims % factor == 0); - const auto f32_type = FtCudaDataType::FP32; - loadWeightFromBin((float*)w.kernel, {w.input_dims / factor, w.output_dims}, prefix + ".qweight", f32_type); - loadWeightFromBin((T*)w.scales, {w.output_dims}, prefix + ".scales", type); - loadWeightFromBin((T*)w.zeros, {w.output_dims}, prefix + ".zeros", type); + FT_CHECK(dim0 % factor == 0); + const auto f32_type = FtCudaDataType::FP32; + std::vector weight_slices{}; + std::vector bias_slices{}; + if (enable_slice) { + if (slice_dim == 1) { + size_t start = 0; + ConcateSlice slice0{.slices = {{0, dim0}}}; + ConcateSlice slice1{.slices = {{}}}; + for (auto len : slice_shape) { + size_t stride = len / tensor_para_size; + slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + weight_slices = {slice0, slice1}; + + ConcateSlice bias_slice0{.slices = {{0, 1}}}; + bias_slices = {bias_slice0, slice1}; + } + else { + size_t start = 0; + ConcateSlice slice0{.slices = {}}; + ConcateSlice slice1{.slices = {{0, dim1}}}; + for (auto len : slice_shape) { + size_t stride = len / factor / tensor_para_size; + slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)}); + start += len; + } + weight_slices = {slice0, slice1}; + } + } + loadWeightFromBin((float*)w.kernel, {dim0 / factor, dim1}, prefix + ".qweight", f32_type, weight_slices); + loadWeightFromBin((T*)w.scales, {1, dim1}, prefix + ".scales", type, bias_slices); + loadWeightFromBin((T*)w.zeros, {1, dim1}, prefix + ".zeros", type, bias_slices); } } @@ -158,11 +269,17 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType (T*)self_attn_norm_weights, {hidden_units_}, dir_path + ".attention_norm.weight", model_file_type); loadWeightFromBin((T*)ffn_norm_weights, {hidden_units_}, dir_path + ".ffn_norm.weight", model_file_type); - 
loadWeights(self_attn_weights.qkv, dir_path + ".attention.w_qkv", tensor_para_rank_, type); - loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type); - loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type); - loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type); - loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type); + loadWeights(self_attn_weights.qkv, + dir_path + ".attention.w_qkv", + tensor_para_rank_, + type, + tensor_para_size_, + 1, + {head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_}); + loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_, 0); + loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1); + loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1); + loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0); // load kv_cache quant scale // if file not exist, get empty vector diff --git a/src/turbomind/utils/memory_utils.cu b/src/turbomind/utils/memory_utils.cu index 02be10577..a419e2a9c 100644 --- a/src/turbomind/utils/memory_utils.cu +++ b/src/turbomind/utils/memory_utils.cu @@ -301,58 +301,158 @@ template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size); // loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or // the product of the elements in shape is 0, this function will return an empty vector. template -std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) +std::vector +loadWeightFromBinHelper(std::vector shape, std::string filename, std::vector slices = {}) { if (shape.size() > 2) { printf("[ERROR] shape should have less than two dims \n"); return std::vector(); } + size_t dim0 = shape[0], dim1 = 1; if (shape.size() == 2) { dim1 = shape[1]; } - size_t size = dim0 * dim1; - if (size == 0) { - TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); - return std::vector(); - } - std::vector host_array(size); - std::ifstream in(filename, std::ios::in | std::ios::binary); - if (!in.is_open()) { - TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); - return std::vector(); + if (slices.size() == 0) { + size_t size = dim0 * dim1; + if (size == 0) { + TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return std::vector(); + } + + std::vector host_array(size); + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (!in.is_open()) { + TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); + return std::vector(); + } + + size_t loaded_data_size = sizeof(T) * size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + + TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); + in.read((char*)host_array.data(), loaded_data_size); + + size_t in_get_size = in.gcount(); + if (in_get_size != loaded_data_size) { + TM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", + filename.c_str(), + in_get_size, + loaded_data_size); + return std::vector(); + } + in.close(); + // If we succeed, return an array with values. 
+ return host_array; } + else { + // concate all slices on the same dims + + if (slices.size() != shape.size()) { + printf("[ERROR] slices should have same dims as shape \n"); + return std::vector(); + } - size_t loaded_data_size = sizeof(T) * size; - in.seekg(0, in.end); - in.seekg(0, in.beg); + // get slices + ConcateSlice slice0{.slices = {{0, dim0}}}; + ConcateSlice slice1{.slices = {{0, dim1}}}; + if (slices.size() > 0 && slices[0].slices.size() > 0) { + slice0 = slices[0]; + } + if (shape.size() == 2 && slices[1].slices.size() > 0) { + slice1 = slices[1]; + } - TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); - in.read((char*)host_array.data(), loaded_data_size); + size_t w0 = 0; + for (auto& s : slice0.slices) { + if (s.second > dim0) { + s.second = dim0; + } + if (s.second < s.first) { + printf("[ERROR] slice0: end < start \n"); + return std::vector(); + } + w0 += s.second - s.first; + } - size_t in_get_size = in.gcount(); - if (in_get_size != loaded_data_size) { - TM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", - filename.c_str(), - in_get_size, - loaded_data_size); - return std::vector(); + size_t w1 = 0; + for (auto& s : slice1.slices) { + if (s.second > dim1) { + s.second = dim1; + } + if (s.second < s.first) { + printf("[ERROR] slice1: end < start \n"); + return std::vector(); + } + w1 += s.second - s.first; + } + + size_t size = w0 * w1; + size_t loaded_data_size = size * sizeof(T); + + TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename + " with slice."); + if (size == 0) { + TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return std::vector(); + } + + std::vector host_array(size); + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (!in.is_open()) { + TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); + return std::vector(); + } + + char* host_ptr = (char*)host_array.data(); + if (slice1.slices.size() == 0 + || (slice1.slices.size() == 1 && slice1.slices[0].second - slice1.slices[0].first == dim1)) { + for (auto& s : slice0.slices) { + size_t read_size = (s.second - s.first) * dim1 * sizeof(T); + size_t pos = s.first * dim1; + in.seekg(pos * sizeof(T)); + in.read((char*)host_ptr, read_size); + host_ptr += read_size; + } + in.close(); + return host_array; + } + + { + for (auto& s0 : slice0.slices) { + // loop over outer slice + for (size_t line_id = s0.first; line_id < s0.second; ++line_id) { + // loop over lines + size_t pos0 = line_id * dim1; + for (auto& s1 : slice1.slices) { + // loop over inner slice + size_t pos = pos0 + s1.first; + size_t read_size = (s1.second - s1.first) * sizeof(T); + in.seekg(pos * sizeof(T)); + in.read(host_ptr, read_size); + host_ptr += read_size; + } + } + } + in.close(); + } + return host_array; } - in.close(); - // If we succeed, return an array with values. 
- return host_array; } -std::vector loadArrayFromBin(std::vector shape, std::string filename) +std::vector loadArrayFromBin(std::vector shape, std::string filename, std::vector slices) { - return loadWeightFromBinHelper(shape, filename); + return loadWeightFromBinHelper(shape, filename, slices); } template -int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) +int loadWeightFromBinFunc(T* ptr, + std::vector shape, + std::string filename, + std::vector slices = std::vector()) { - std::vector host_array = loadWeightFromBinHelper(shape, filename); + std::vector host_array = loadWeightFromBinHelper(shape, filename, slices); if (host_array.empty()) { return 0; @@ -371,49 +471,84 @@ int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filenam return 0; } -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(int8_t* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(half* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(float* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(half* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(int8_t* ptr, + std::vector shape, + std::string filename, + std::vector slices); #ifdef ENABLE_BF16 -template int -loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, std::vector shape, std::string filename); -template int -loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr, - std::vector shape, - std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(float* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc(half* ptr, + std::vector shape, + std::string filename, + std::vector slices); +template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr, + std::vector shape, + std::string filename, + std::vector slices); #endif // ENABLE_BF16 -template int loadWeightFromBinFunc(int* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(int* ptr, + std::vector shape, + std::string filename, + std::vector slices); #ifdef ENABLE_FP8 -template int -loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr, + std::vector shape, + std::string filename, 
+ std::vector slices); #endif // ENABLE_FP8 template -int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type) +int loadWeightFromBin(T* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices) { switch (model_file_type) { case FtCudaDataType::FP32: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; case FtCudaDataType::FP16: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; case FtCudaDataType::INT8: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; #ifdef ENABLE_BF16 case FtCudaDataType::BF16: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; #endif #ifdef ENABLE_FP8 case FtCudaDataType::FP8: - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); break; #endif default: @@ -424,28 +559,50 @@ int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, F } template<> -int loadWeightFromBin(int* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type) +int loadWeightFromBin(int* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices) { - loadWeightFromBinFunc(ptr, shape, filename); + loadWeightFromBinFunc(ptr, shape, filename, slices); return 0; } -template int -loadWeightFromBin(float* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); -template int -loadWeightFromBin(half* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); -template int -loadWeightFromBin(int8_t* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(float* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); +template int loadWeightFromBin(half* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); +template int loadWeightFromBin(int8_t* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); #ifdef ENABLE_BF16 -template int -loadWeightFromBin(__nv_bfloat16* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(__nv_bfloat16* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); #endif #ifdef ENABLE_FP8 -template int -loadWeightFromBin(__nv_fp8_e4m3* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(__nv_fp8_e4m3* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); #endif -template int -loadWeightFromBin(int* ptr, std::vector shape, std::string filename, FtCudaDataType model_file_type); +template int loadWeightFromBin(int* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type, + std::vector slices); template __global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size) diff --git a/src/turbomind/utils/memory_utils.h b/src/turbomind/utils/memory_utils.h index 6ce11e771..e51c90390 100644 --- a/src/turbomind/utils/memory_utils.h +++ b/src/turbomind/utils/memory_utils.h @@ -49,13 +49,20 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t 
size, cudaStream_t stream = template void cudaRandomUniform(T* buffer, const size_t size); +struct ConcateSlice { + std::vector> slices; +}; + template -int loadWeightFromBin(T* ptr, - std::vector shape, - std::string filename, - FtCudaDataType model_file_type = FtCudaDataType::FP32); +int loadWeightFromBin(T* ptr, + std::vector shape, + std::string filename, + FtCudaDataType model_file_type = FtCudaDataType::FP32, + std::vector slices = std::vector()); -std::vector loadArrayFromBin(std::vector shape, std::string filename); +std::vector loadArrayFromBin(std::vector shape, + std::string filename, + std::vector slices = std::vector()); // template // int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
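
The user-facing surface of this patch is small: `TurboMind` now accepts a `tp` argument and `chat` and the benchmark scripts forward it. A minimal sketch of exercising the runtime-TP path from Python follows; the `./workspace` directory and `tp=2` are placeholder assumptions for illustration, not part of the patch.

```python
# Minimal sketch, not part of the patch: assumes './workspace' was created by
# `python3 -m lmdeploy.serve.turbomind.deploy ...` and that 2 GPUs are visible.
from lmdeploy.turbomind.turbomind import TurboMind

# tp=2 asks TurboMind to split the fp16 weights across two GPUs at load time;
# with the default tp=1 the behaviour is unchanged from the previous release.
tm_model = TurboMind(model_path='./workspace', tp=2)
generator = tm_model.create_instance()
print(tm_model.model_name, tm_model.session_len)
```

The same effect is available from the command line as described in the README note, e.g. `python3 -m lmdeploy.turbomind.chat ./workspace --tp 2`.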
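
The precedence between static and runtime TP lives in `TurboMind.__init__`: a `tensor_para_size` other than 1 in `triton_models/weights/config.ini` (written when `deploy` was run with `--tp`) overrides the runtime `tp` argument. The sketch below mirrors that rule in a standalone helper; `resolve_gpu_count` is hypothetical and illustrative only, not a function in lmdeploy.

```python
# Standalone sketch of the tp-resolution rule applied in TurboMind.__init__
# (this helper is illustrative only and not part of lmdeploy).
from configparser import ConfigParser

def resolve_gpu_count(ini_path: str, runtime_tp: int = 1) -> int:
    parser = ConfigParser()
    with open(ini_path, 'r') as f:
        parser.read_file(f)
    # the weights config uses a 'llama' section (or another section name)
    section = 'llama' if parser.has_section('llama') else parser.sections()[0]
    tp_cfg = parser.getint(section, 'tensor_para_size')
    if tp_cfg != 1 and tp_cfg != runtime_tp:
        # weights were already split at deploy time (static TP): config.ini wins
        return tp_cfg
    return runtime_tp

# e.g. resolve_gpu_count('./workspace/triton_models/weights/config.ini', runtime_tp=2)
```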
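
On the C++ side, `loadWeights` builds `ConcateSlice` ranges so that each rank reads only its 1/tp share of every logical sub-tensor; for the fused `w_qkv`, the q, k, and v spans are sliced independently, which is what keeps GQA correct. The arithmetic is easier to see in a few lines of Python. This is a sketch of the fp16 path only (the quantized path additionally divides the row stride by the packing factor), and the toy widths stand in for `head_num * size_per_head` and `kv_head_num * size_per_head`.

```python
# Hedged sketch of the runtime-TP slicing rule implemented in C++ in
# LlamaDecoderLayerWeight.cc: each rank keeps a contiguous 1/tp span of every
# sub-tensor listed in `slice_shape`, along columns (slice_dim=1) or rows
# (slice_dim=0).
import numpy as np

def tp_slices(slice_shape, tp, rank):
    """Return [start, end) ranges, one per sub-tensor, for this rank."""
    ranges, start = [], 0
    for length in slice_shape:
        stride = length // tp
        ranges.append((start + stride * rank, start + stride * (rank + 1)))
        start += length
    return ranges

def shard(weight, slice_shape, tp, rank, slice_dim):
    parts = [
        weight[:, s:e] if slice_dim == 1 else weight[s:e, :]
        for s, e in tp_slices(slice_shape, tp, rank)
    ]
    return np.concatenate(parts, axis=slice_dim)

# Toy example: a fused "qkv" weight with q width 8 and k/v width 4 each,
# split column-wise across tp=2 ranks.
w_qkv = np.arange(6 * 16).reshape(6, 16)
print(shard(w_qkv, [8, 4, 4], tp=2, rank=0, slice_dim=1).shape)  # (6, 8)
```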