Skip to content

Commit

Permalink
Support Runtime tensor parallelism (InternLM#158)
Browse files Browse the repository at this point in the history
* works on interlm and vicuna

* support GQA

* remove comment

* update readme, add logger, default tp=1

* remove log
  • Loading branch information
grimoire authored Jul 31, 2023
1 parent 981a461 commit 4767b04
Show file tree
Hide file tree
Showing 12 changed files with 492 additions and 191 deletions.
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-ch
python -m lmdeploy.turbomind.chat ./workspace
```

```{note}
When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind. It is recommended to use NVIDIA cards such as 3090, V100, A100, etc.
Disable GPU ECC can free up 10% memory, try `sudo nvidia-smi --ecc-config=0` and reboot system.
```
> **Note**<br />
> When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind. <br />
> It is recommended to use NVIDIA cards such as 3090, V100, A100, etc.
> Disable GPU ECC can free up 10% memory, try `sudo nvidia-smi --ecc-config=0` and reboot system.
> **Note**<br />
> Tensor parallel is available to perform inference on multiple GPUs. Add `--tp=<num_gpu>` on `chat` to enable runtime TP.
#### Serving

Expand Down Expand Up @@ -163,6 +166,9 @@ Then adjust `workspace/triton_models/weights/config.ini`

Here is [quantization test results](./docs/en/quantization.md).

> **Warning**<br />
> runtime Tesnor Parallel for quantilized model is not available. Please setup `--tp` on `deploy` to enable static TP.
## Contributing

We appreciate all contributions to LMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
Expand Down
13 changes: 9 additions & 4 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-ch
python3 -m lmdeploy.turbomind.chat ./workspace
```

```{note}
turbomind 在使用 FP16 精度推理 InternLM-7B 模型时,显存开销至少需要 15.7G。建议使用 3090, V100,A100等型号的显卡。
关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。
```
> **Note**<br />
> turbomind 在使用 FP16 精度推理 InternLM-7B 模型时,显存开销至少需要 15.7G。建议使用 3090, V100,A100等型号的显卡。<br />
> 关闭显卡的 ECC 可以腾出 10% 显存,执行 `sudo nvidia-smi --ecc-config=0` 重启系统生效。
> **Note**<br />
> 使用 Tensor 并发可以利用多张 GPU 进行推理。在 `chat` 时添加参数 `--tp=<num_gpu>` 可以启动运行时 TP。
#### 部署推理服务

Expand Down Expand Up @@ -162,6 +164,9 @@ python3 -m lmdeploy.lite.apis.kv_qparams \

这里是[量化测试结果](./docs/zh_cn/quantization.md)

> **Warning**<br />
> 量化部署不支持运行时 Tensor 并发。如果希望使用 Tensor 并发,需要在 deploy 时配置 tp 参数。
## 贡献指南

我们感谢所有的贡献者为改进和提升 LMDeploy 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。
Expand Down
5 changes: 3 additions & 2 deletions benchmark/profile_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def main(model_path: str,
concurrency: int = 1,
input_seqlen: int = 0,
output_seqlen: int = 512,
test_round: int = 10):
test_round: int = 10,
tp: int = 1):
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = TurboMind(model_path=model_path)
tm_model = TurboMind(model_path=model_path, tp=tp)

warmup(tm_model, concurrency, output_seqlen)

Expand Down
9 changes: 5 additions & 4 deletions benchmark/profile_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def sample_requests(

class Engine:

def __init__(self, model_path: str):
def __init__(self, model_path: str, tp: int = 1):
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = TurboMind(model_path=model_path)
tm_model = TurboMind(model_path=model_path, tp=tp)
self.tm_model = tm_model
self.tokenizer = tokenizer

Expand Down Expand Up @@ -117,9 +117,10 @@ def process_request(self, requests, concurrency: int = 1):
def main(dataset: str,
model_path: str,
concurrency: int = 1,
num_prompts: int = 1000):
num_prompts: int = 1000,
tp: int = 1):

engine = Engine(model_path)
engine = Engine(model_path, tp=tp)
tokenizer = engine.tokenizer

requests = sample_requests(dataset, num_prompts, tokenizer)
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/serve/turbomind/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def stream_callback(que, result, error):

def get_logger(log_file=None, log_level=logging.INFO):
"""Return the logger."""
from .utils import get_logger
from lmdeploy.turbomind.utils import get_logger
logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
return logger

Expand Down
80 changes: 1 addition & 79 deletions lmdeploy/serve/turbomind/utils.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Union
from typing import List, Union

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

logger_initialized = {}


def get_logger(name: str,
log_file: Optional[str] = None,
log_level: int = logging.INFO,
file_mode: str = 'w'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified, a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
# use logger in mmengine if exists.
try:
from mmengine.logging import MMLogger
if MMLogger.check_instance_created(name):
logger = MMLogger.get_instance(name)
else:
logger = MMLogger.get_instance(name,
logger_name=name,
log_file=log_file,
log_level=log_level,
file_mode=file_mode)
return logger

except Exception:
pass

logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger

# handle duplicate logs to the console
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)

stream_handler = logging.StreamHandler()
handlers = [stream_handler]

if log_file is not None:
# Here, the default behaviour of the official logger is 'a'. Thus, we
# provide an interface to change the file mode to the default
# behaviour.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)

formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)

logger.setLevel(log_level)
logger_initialized[name] = True

return logger


def prepare_tensor(name, input_tensor):
"""Create grpcclient's InferInput instance according to a given tensor."""
Expand Down
7 changes: 5 additions & 2 deletions lmdeploy/turbomind/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def valid_str(string, coding='utf-8'):
return ret


def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0):
def main(model_path,
session_id: int = 1,
repetition_penalty: float = 1.0,
tp=1):
"""An example to perform model inference through the command line
interface.
Expand All @@ -39,7 +42,7 @@ def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0):
"""
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id)
tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
generator = tm_model.create_instance()

nth_round = 1
Expand Down
19 changes: 11 additions & 8 deletions lmdeploy/turbomind/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import lmdeploy
from lmdeploy.model import MODELS
from lmdeploy.turbomind.utils import get_logger

# TODO: find another way import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
Expand Down Expand Up @@ -69,23 +70,21 @@ class TurboMind:
Args:
model_path (str): the path of turbomind's model
data_type (str): the data type
eos_id (int): eos token id
tp (int): tensor parallel
"""

def __init__(self,
model_path: str,
data_type: str = 'fp16',
eos_id: int = 2):
def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1):
self.eos_id = eos_id

# TODO: support mpi
node_id = 0
node_num = 1

# read meta from model path
self.gpu_count = 1
self.gpu_count = tp
self.session_len = 2048
data_type = 'fp16'
ini_path = osp.join(model_path, 'triton_models/weights/config.ini')
with open(ini_path, 'r') as f:
parser = ConfigParser()
Expand All @@ -97,10 +96,14 @@ def __init__(self,
section_name = 'llama'

if len(section_name) > 0:
self.gpu_count = parser.getint(section_name,
'tensor_para_size')
tp_cfg = parser.getint(section_name, 'tensor_para_size')
self.session_len = parser.getint(section_name, 'session_len')
if tp_cfg != 1 and tp_cfg != tp:
get_logger('turbomind').info(
f'found tp={tp_cfg} in config.ini.')
self.gpu_count = tp_cfg
self.model_name = parser.get(section_name, 'model_name')
data_type = parser.get(section_name, 'weight_type')
model = MODELS.get(self.model_name)()
self.stop_words = _stop_words(model.stop_words)

Expand Down
79 changes: 79 additions & 0 deletions lmdeploy/turbomind/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional

logger_initialized = {}


def get_logger(name: str,
log_file: Optional[str] = None,
log_level: int = logging.INFO,
file_mode: str = 'w'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified, a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
# use logger in mmengine if exists.
try:
from mmengine.logging import MMLogger
if MMLogger.check_instance_created(name):
logger = MMLogger.get_instance(name)
else:
logger = MMLogger.get_instance(name,
logger_name=name,
log_file=log_file,
log_level=log_level,
file_mode=file_mode)
return logger

except Exception:
pass

logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger

# handle duplicate logs to the console
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)

stream_handler = logging.StreamHandler()
handlers = [stream_handler]

if log_file is not None:
# Here, the default behaviour of the official logger is 'a'. Thus, we
# provide an interface to change the file mode to the default
# behaviour.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)

formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)

logger.setLevel(log_level)
logger_initialized[name] = True

return logger
Loading

0 comments on commit 4767b04

Please sign in to comment.