Skip to content

Commit

Permalink
Check-in user guide for w4a16 LLM deployment (InternLM#224)
Browse files Browse the repository at this point in the history
* tmp

* update

* update

* update

* update

* update

* remove

* update

* update
  • Loading branch information
lvhan028 authored Aug 14, 2023
1 parent 6829684 commit 8e8629d
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀.
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. Check [this](./docs/en/w4a16.md) guide for detailed info
- \[2023/08\] LMDeploy has launched on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models.
- \[2023/08\] LMDeploy supports 4-bit quantization using the [AWQ](https://arxiv.org/abs/2306.00978) algorithm.
- \[2023/07\] TurboMind supports Llama-2 70B with GQA.
Expand Down
2 changes: 1 addition & 1 deletion README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ______________________________________________________________________

## 更新 🎉

- \[2023/08\] TurboMind 支持 4-bit 推理,速度是 FP16 的 2.4 倍,是目前最快的开源实现🚀
- \[2023/08\] TurboMind 支持 4-bit 推理,速度是 FP16 的 2.4 倍,是目前最快的开源实现🚀。部署方式请看[这里](./docs/zh_cn/w4a16.md)
- \[2023/08\] LMDeploy 开通了 [HuggingFace Hub](https://huggingface.co/lmdeploy) ,提供开箱即用的 4-bit 模型
- \[2023/08\] LMDeploy 支持使用 [AWQ](https://arxiv.org/abs/2306.00978) 算法进行 4-bit 量化
- \[2023/07\] TurboMind 支持使用 GQA 的 Llama-2 70B 模型
Expand Down
22 changes: 14 additions & 8 deletions benchmark/profile_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from queue import Queue
from threading import Thread
from typing import List

import fire
import numpy as np
Expand All @@ -29,25 +30,29 @@ def infer(model, session_id: int, input_ids: str, output_seqlen: int,
tokens.append(token)

# TODO: ignore first token
first_token_latency = timestamps[0] - start
first_token_latency = np.round(timestamps[0] - start, 2)
if len(timestamps) == 1:
token_latency = timestamps[0] - start
token_latency = np.round(timestamps[0] - start, 2)
token = tokens[0]
else:
token_latency = timestamps[-1] - timestamps[0]
token_latency = np.round(timestamps[-1] - timestamps[0], 2)
token = tokens[-1] - tokens[0]
stats.append([first_token_latency, token, token_latency])
que.put((session_id, stats))


def warmup(model, concurrency: int, output_seqlen: int, warmup_round: int = 4):
def warmup(model,
concurrency: int,
input_ids: List[int],
output_seqlen: int,
warmup_round: int = 2):
print('start to warmup ...')

def _infer(model, session_id):
chatbot = model.create_instance()
for _ in range(warmup_round):
for _ in chatbot.stream_infer(session_id,
input_ids=[1],
input_ids=input_ids,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True,
Expand Down Expand Up @@ -82,11 +87,12 @@ def main(model_path: str,
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = TurboMind(model_path=model_path, tp=tp)

warmup(tm_model, concurrency, output_seqlen)

# make up a prompt that can be tokenized into {input_seqlen} tokens
prompt = '' if input_seqlen == 0 else 'hi' + ' hi' * (input_seqlen - 1)
input_ids = tokenizer.encode(prompt)

warmup(tm_model, concurrency, input_ids, output_seqlen)

que = Queue()
procs = []
_start = time.perf_counter()
Expand Down Expand Up @@ -134,7 +140,7 @@ def main(model_path: str,
f'{first_token_latency_ave:.2f}s\ntoken latency(min, max, ave): '
f'{token_latency_min:.2f}s, {token_latency_max:.2f}s, '
f'{token_latency_ave:.2f}s\n'
f'throughput: {throughput} token/s\n{"-" * 50}')
f'throughput: {throughput:.2f} token/s\n{"-" * 50}')


if __name__ == '__main__':
Expand Down
13 changes: 0 additions & 13 deletions docs/en/serving.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,6 @@ bash workspace/service_docker_up.sh

</details>

<details open>
<summary><b>7B with INT4 weight only quantization</b></summary>

```shell
python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-7b-chat-hf \
--model_format awq \
--group_size 128 \
--quant_path /path/to/awq-quant-weight.pt
bash workspace/service_docker_up.sh
```

</details>

## Serving [LLaMA](https://github.com/facebookresearch/llama)

Weights for the LLaMA models can be obtained from by filling out [this form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
Expand Down
101 changes: 101 additions & 0 deletions docs/en/w4a16.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# W4A16 LLM Model Deployment

LMDeploy supports LLM model inference of 4-bit weight, with the minimum requirement for NVIDIA graphics cards being sm80, such as A10, A100, Geforce 30/40 series.

Before proceeding with the inference, please ensure that lmdeploy(>=v0.0.4) is installed.

```shell
pip install lmdeploy
```

## 4-bit LLM model Inference

You can download the pre-quantized 4-bit weight models from LMDeploy's [model zoo](https://huggingface.co/lmdeploy) and conduct inference using the following command.

Alternatively, you can quantize 16-bit weights to 4-bit weights following the ["4-bit Weight Quantization"](#4-bit-weight-quantization) section, and then perform inference as per the below instructions.

Take the 4-bit Llama-2-chat-7B model from the model zoo as an example:

```shell
git-lfs install
git clone https://huggingface.co/lmdeploy/llama2-chat-7b-w4
```

As demonstrated in the command below, first convert the model's layout using `turbomind.deploy`, and then you can interact with the AI assistant in the terminal

```shell

## Convert the model's layout and store it in the default path, ./workspace.
python3 -m lmdeploy.serve.turbomind.deploy \
--model-name llama2 \
--model-path ./llama2-chat-7b-w4 \
--model-format awq \
--group-size 128

## inference
python3 -m lmdeploy.turbomind.chat ./workspace
```

## Serve with gradio

If you wish to interact with the model via web ui, please initiate the gradio server as indicated below:

```shell
python3 -m lmdeploy.serve.turbomind ./workspace --server_name {ip_addr} ----server_port {port}
```

Subsequently, you can open the website `http://{ip_addr}:{port}` in your browser and interact with the model

## Inference Performance

We benchmarked the Llama-2-7B-chat and Llama-2-13B-chat models with 4-bit quantization on NVIDIA GeForce RTX 4090 using [profile_generation.py](https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_generation.py). And we measure the token generation throughput (tokens/s) by setting a single prompt token and generating 512 tokens. All the results are measured for single batch inference.

| model | llm-awq | mlc-llm | turbomind |
| ---------------- | ------- | ------- | --------- |
| Llama-2-7B-chat | 112.9 | 159.4 | 206.4 |
| Llama-2-13B-chat | N/A | 90.7 | 115.8 |

Memory (GB) comparison results between 4-bit and 16-bit model with context size 2048 and 4096 respectively,

| model | 16bit(2048) | 4bit(2048) | 16bit(4096) | 4bit(4096) |
| ---------------- | ----------- | ---------- | ----------- | ---------- |
| Llama-2-7B-chat | 15.1 | 6.3 | 16.2 | 7.5 |
| Llama-2-13B-chat | OOM | 10.3 | OOM | 12.0 |

```shell
python benchmark/profile_generation.py \
./workspace \
--concurrency 1 --input_seqlen 1 --output_seqlen 512
```

## 4-bit Weight Quantization

It includes two steps:

- generate quantization parameter
- quantize model according to the parameter

### Step 1: Generate Quantization Parameter

```shell
python3 -m lmdeploy.lite.apis.calibrate \
--model $HF_MODEL \
--calib_dataset 'c4' \ # Calibration dataset, supports c4, ptb, wikitext2, pileval
--calib_samples 128 \ # Number of samples in the calibration set, if memory is insufficient, you can appropriately reduce this
--calib_seqlen 2048 \ # Length of a single piece of text, if memory is insufficient, you can appropriately reduce this
--work_dir $WORK_DIR \ # Folder storing Pytorch format quantization statistics parameters and post-quantization weight
```

### Step2: Quantize Weights

LMDeploy employs AWQ algorithm for model weight quantization.

```shell
python3 -m lmdeploy.lite.apis.auto_awq \
--model $HF_MODEL \
--w_bits 4 \ # Bit number for weight quantization
--w_group_size 128 \ # Group size for weight quantization statistics
--work_dir $WORK_DIR \ # Directory saving quantization parameters from Step 1
```

After the quantization is complete, the quantized model is saved to `$WORK_DIR`. Then you can proceed with model inference according to the instructions in the ["4-Bit Weight Model Inference"](#4-bit-llm-model-inference) section.
97 changes: 97 additions & 0 deletions docs/zh_cn/w4a16.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# W4A16 LLM 模型部署

LMDeploy 支持 4bit 权重模型的推理,**对 NVIDIA 显卡的最低要求是 sm80**,比如A10,A100,Gerforce 30/40系列。

在推理之前,请确保安装了 lmdeploy,版本 >= v0.0.4

```shell
pip install lmdeploy
```

## 4bit 权重模型推理

你可以直接从 LMDeploy 的 [model zoo](https://huggingface.co/lmdeploy) 下载已经量化好的 4bit 权重模型,直接使用下面的命令推理。也可以根据["4bit 权重量化"](#4bit-权重量化)章节的内容,把 16bit 权重量化为 4bit 权重,然后再按下述说明推理

以 4bit 的 Llama-2-chat-7B 模型为例,可以从 model zoo 直接下载:

```shell
git-lfs install
git clone https://huggingface.co/lmdeploy/llama2-chat-7b-w4
```

执行以下命令,即可在终端与模型对话:

```shell

## 转换模型的layout,存放在默认路径 ./workspace 下
python3 -m lmdeploy.serve.turbomind.deploy \
--model-name llama2 \
--model-path ./llama2-chat-7b-w4 \
--model-format awq \
--group-size 128

## 推理
python3 -m lmdeploy.turbomind.chat ./workspace
```

## 启动 gradio 服务

如果想通过 webui 与模型对话,请执行以下命令启动 gradio 服务

```shell
python3 -m lmdeploy.serve.turbomind ./workspace --server_name {ip_addr} ----server_port {port}
```

然后,在浏览器中打开 http://{ip_addr}:{port},即可在线对话

## 推理速度

我们在 NVIDIA GeForce RTX 4090 上使用 [profile_generation.py](https://github.com/InternLM/lmdeploy/blob/main/benchmark/profile_generation.py),分别测试了 4-bit Llama-2-7B-chat 和 Llama-2-13B-chat 模型的 token 生成速度。测试配置为 batch size = 1,(prompt_tokens, completion_tokens) = (1, 512)

| model | llm-awq | mlc-llm | turbomind |
| ---------------- | ------- | ------- | --------- |
| Llama-2-7B-chat | 112.9 | 159.4 | 206.4 |
| Llama-2-13B-chat | N/A | 90.7 | 115.8 |

上述两个模型的16bit 和 4bit 权重,分别使用 turbomind 推理时,各自在context size 为 2048 和 4096 配置下,所占的显存对比如下:

| model | 16bit(2048) | 4bit(2048) | 16bit(4096) | 4bit(4096) |
| ---------------- | ----------- | ---------- | ----------- | ---------- |
| Llama-2-7B-chat | 15.1 | 6.3 | 16.2 | 7.5 |
| Llama-2-13B-chat | OOM | 10.3 | OOM | 12.0 |

```shell
python benchmark/profile_generation.py \
./workspace \
--concurrency 1 --input_seqlen 1 --output_seqlen 512
```

## 4bit 权重量化

4bit 权重量化包括 2 步:

- 生成量化参数
- 根据量化参数,量化模型权重

### 第一步:生成量化参数

```shell
python3 -m lmdeploy.lite.apis.calibrate \
--model $HF_MODEL \
--calib_dataset 'c4' \ # 校准数据集,支持 c4, ptb, wikitext2, pileval
--calib_samples 128 \ # 校准集的样本数,如果显存不够,可以适当调小
--calib_seqlen 2048 \ # 单条的文本长度,如果显存不够,可以适当调小
--work_dir $WORK_DIR \ # 保存 Pytorch 格式量化统计参数和量化后权重的文件夹
```

### 第二步:量化权重模型

LMDeploy 使用 AWQ 算法对模型权重进行量化。在执行下面的命令时,需要把步骤1的`$WORK_DIR`传入。量化结束后,权重文件也会存放在这个目录中。然后就可以根据 ["4bit权重模型推理"](#4bit-权重模型推理)章节的说明,进行模型推理。

```shell
python3 -m lmdeploy.lite.apis.auto_awq \
--model $HF_MODEL \
--w_bits 4 \ # 权重量化的 bit 数
--w_group_size 128 \ # 权重量化分组统计尺寸
--work_dir $WORK_DIR \ # 步骤 1 保存量化参数的目录
```

0 comments on commit 8e8629d

Please sign in to comment.