[vLLM] Support vLLM CPU backend and provide QBits acceleration (#1551)
Co-authored-by: VincyZhang <[email protected]>
Co-authored-by: Wang, Chang <[email protected]>
3 people committed May 24, 2024
1 parent 93b12e9 commit 0e13607
Showing 6 changed files with 292 additions and 57 deletions.
35 changes: 35 additions & 0 deletions examples/vllm/README.md
@@ -0,0 +1,35 @@
# vLLM Acceleration with ITREX

Intel Extension for Transformers (ITREX) integrates the vLLM CPU backend and offers an optional [QBits Module](../../docs/qbits.md) to accelerate vLLM inference on CPUs.

## Installation Methods

1. vLLM CPU installation: Install vLLM from source by following the instructions provided [here](https://docs.vllm.ai/en/latest/getting_started/cpu-installation.html).

2. ITREX installation: Install ITREX by following this [guide](../../docs/get_started.md).

3. Dependencies: Install the additional dependencies listed in `requirement.txt` in this directory.

Note: torch==2.3.0+cpu is required, and vllm==0.4.2+cpu has been validated; a quick version check is sketched below.
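
As a quick sanity check of the environment (a minimal sketch; the version strings in the comments are expectations, not guaranteed output):

```python
# Verify that the CPU builds noted above are the ones installed.
import torch
import vllm

print(torch.__version__)  # expected: a 2.3.0+cpu build
print(vllm.__version__)   # validated with vLLM 0.4.2 on CPU
```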

## Usage Example

ITREX provides a script that demonstrates vLLM inference acceleration. Run it with the following command (adjust the model path and core binding to your machine):
```bash
numactl -m 0 -C 0-55 python vllm_acceleration_example.py --model_path=/home/model/chatglm2-6b --prompt=你好
```
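
For reference, the default (non-benchmark) path of the script boils down to the snippet below; the model path is just the example path used above, so substitute your own checkpoint:

```python
# Minimal sketch of the script's default path: load the model through ITREX with
# the vLLM CPU backend enabled and generate from a single prompt.
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/home/model/chatglm2-6b", use_vllm=True)
output = model.generate("Once upon a time")
print(output)
```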

## Supported and Validated Models
In principle, all models listed in the [vLLM Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html) page can be accelerated.

We have validated the following models with vllm==0.4.2+cpu:
* [THUDM/chatglm2-6b](https://hf-mirror.com/THUDM/chatglm2-6b)
* [meta-llama/Llama-2-7b-chat-hf](https://hf-mirror.com/meta-llama/Llama-2-7b-chat-hf)
* [baichuan-inc/Baichuan2-7B-Chat](https://hf-mirror.com/baichuan-inc/Baichuan2-7B-Chat)
* [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b)
* [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
* [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
* [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
* [Qwen/CodeQwen1.5-7B-Chat](https://huggingface.co/Qwen/CodeQwen1.5-7B-Chat)

If you encounter any problems, please let us know.
3 changes: 3 additions & 0 deletions examples/vllm/requirement.txt
@@ -0,0 +1,3 @@
accelerate
datasets
peft
85 changes: 85 additions & 0 deletions examples/vllm/vllm_acceleration_example.py
@@ -0,0 +1,85 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
import os
from vllm import LLM, SamplingParams
from typing import List, Optional
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig
from transformers import AutoTokenizer


def main(args_in: Optional[List[str]] = None) -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Model name: String", required=True)
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        help="Prompt to start generation with: String (default: empty)",
        default="Once upon a time",
    )
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument("--use_neural_speed", action="store_true")
    args = parser.parse_args(args_in)
    print(args)

    if args.benchmark:
        if args.use_neural_speed:
            # Optional Neural Speed path with RTN weight-only quantization.
            os.environ["NEURAL_SPEED_VERBOSE"] = "1"
            woq_config = RtnConfig(bits=4, weight_dtype="int4", compute_dtype="int8", scale_dtype="bf16")
            model_with_ns = AutoModelForCausalLM.from_pretrained(args.model_path, quantization_config=woq_config)

            tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
            inputs = tokenizer(args.prompt, return_tensors="pt").input_ids

            T5 = time.time()
            output = model_with_ns.generate(inputs, max_new_tokens=32)
            T6 = time.time()
            print("neural speed output = ", output)

        # Baseline: native vLLM engine.
        llm = LLM(model=args.model_path, trust_remote_code=True)
        sampling_params = SamplingParams(max_tokens=32)
        T1 = time.time()
        original_outputs = llm.generate(args.prompt, sampling_params)  # Generate texts from the prompts.
        T2 = time.time()
        vllm_latency = (T2 - T1) * 1000

        # ITREX path: vLLM backend with QBits acceleration.
        model = AutoModelForCausalLM.from_pretrained(args.model_path, use_vllm=True)
        T3 = time.time()
        optimized_output = model.generate(args.prompt, sampling_params)
        T4 = time.time()
        qbits_latency = (T4 - T3) * 1000

        print("original outputs = ", original_outputs)
        print("input_tokens_length = ", len(original_outputs[0].prompt_token_ids))
        print("output_tokens_length = ", len(original_outputs[0].outputs[0].token_ids))

        print("optimized outputs = ", optimized_output)
        print("input_tokens_length = ", len(optimized_output[0].prompt_token_ids))
        print("output_tokens_length = ", len(optimized_output[0].outputs[0].token_ids))

        print('The qbits optimized generate:%.2f ms' % qbits_latency)
        print('The original vLLM generate:%.2f ms' % vllm_latency)

        return

    model = AutoModelForCausalLM.from_pretrained(args.model_path, use_vllm=True)
    output = model.generate(args.prompt)
    print(output)


if __name__ == "__main__":
    main()
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from ..utils import DTYPE_BITS_MAPPING
from functools import reduce
@@ -23,19 +24,19 @@
from peft.tuners.lora import LoraLayer, LoraModel
from peft.utils.other import transpose
from intel_extension_for_transformers.transformers.llm.quantization.autograd import (
matmul_kbit,
)
matmul_kbit, )
import intel_extension_for_transformers.qbits as qbits # pylint: disable=E0611, E0401


class DropoutQBits_(torch.autograd.Function):

@staticmethod
def forward(ctx, input, probability):
mask = qbits.dropout_fwd(input, probability)
if any(ctx.needs_input_grad[:1]):
ctx.tensors = (mask,)
ctx.tensors = (mask, )
else:
ctx.tensors = (None,)
ctx.tensors = (None, )
return input

@staticmethod
@@ -51,6 +52,7 @@ def backward(ctx, grad_output):


class DropoutQBits(torch.nn.Module):

def __init__(self, p=0.0):
super().__init__()
self.p = p
@@ -63,6 +65,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


class ParamsQBits(torch.nn.Parameter):

def __new__(
cls,
data=None,
@@ -87,6 +90,7 @@ def __new__(


class QuantizedLinearQBits(torch.nn.Linear):

def __init__(
self,
input_features,
@@ -156,6 +160,9 @@ def forward(self, x: torch.Tensor):
shape[-1] = self.out_features
out = out.view(shape)

if os.environ.get("backend", None) == "use_vllm":
return out, None

return out

def set_fp_weights_bias(self, weight_data, bias=None):
@@ -264,33 +271,24 @@ def quant_weight_w_scale(self, weight, scale, zp, group_size=-1):
if zp is not None:
zp = zp.to(device)
if group_size == -1:
return (
weight.div_(scale).round_()
if zp is None
else weight.div_(scale).add_(zp).round_()
)
return (weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_())
int_weight = torch.zeros(weight.shape).to(device)
leng = weight.shape[1] // group_size
tail_flag = False if weight.shape[1] % group_size == 0 else True
for i in range(leng):
int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(
scale[:, i].unsqueeze(1)
)
int_weight_tmp = weight[:, i * group_size:(i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
if zp is not None:
int_weight_tmp.add_(zp[:, i].unsqueeze(1))
int_weight[:, i * group_size : (i + 1) * group_size].copy_(
int_weight_tmp.round_()
)
int_weight[:, i * group_size:(i + 1) * group_size].copy_(int_weight_tmp.round_())
if tail_flag:
int_weight_tmp = weight[:, leng * group_size :].div_(
scale[:, -1].unsqueeze(1)
)
int_weight_tmp = weight[:, leng * group_size:].div_(scale[:, -1].unsqueeze(1))
if zp is not None:
int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
int_weight[:, leng * group_size:].copy_(int_weight_tmp.round_())
return int_weight

def recover_qparms(self):

def recover_idx(ret_idx, k, blocksize):
g_idx = torch.zeros(k, dtype=int)
value_range = (k + blocksize - 1) // blocksize
@@ -328,18 +326,12 @@ def recover_int_weight(g_idx, int_weight):
else:
g_idx = None
weight_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 6)
weight_dtype = "".join(
chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist()
)
weight_dtype = "".join(chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist())
bits = 4 if weight_dtype in ["nf4", "int4_clip", "fp4", "int4_fullrange"] else 8
compute_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 7)
compute_dtype = "".join(
chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist()
)
compute_dtype = "".join(chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist())
scales_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 8)
scales_dtype = "".join(
chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist()
)
scales_dtype = "".join(chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist())
if scales_dtype is None:
assert False, "scales dtype only support fp32."
scales = qbits.acquire_packed_weight_info(self.weight, 9)
@@ -356,9 +348,7 @@ def recover_int_weight(g_idx, int_weight):

revert_wei = torch.zeros(in_features, out_features, dtype=torch.float)

qbits.dequantize_packed_weight(
self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype
)
qbits.dequantize_packed_weight(self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype)

int_weight = self.quant_weight_w_scale(
revert_wei.t(),
@@ -426,9 +416,7 @@ def __init__(
except:
qbits_customop_available = False
if lora_dropout > 0 and qbits_customop_available:
self.lora_dropout = torch.nn.ModuleDict(
{adapter_name: DropoutQBits(p=lora_dropout)}
)
self.lora_dropout = torch.nn.ModuleDict({adapter_name: DropoutQBits(p=lora_dropout)})

def merge(self, safe_merge: bool = False) -> None:
"""Merge the active adapter weights into the base weights.
@@ -440,10 +428,8 @@ def merge(self, safe_merge: bool = False) -> None:
NaNs. Defaults to `False`.
"""
if self.merged:
print(
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)
print(f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}.")
w_dequant = torch.zeros(
self.out_features,
self.in_features,
@@ -468,8 +454,7 @@ def merge(self, safe_merge: bool = False) -> None:

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken")

w_data = orig_weights
else:
@@ -541,13 +526,10 @@ def unmerge(self) -> None:
)

def get_delta_weight(self, adapter) -> torch.Tensor:
return (
transpose(
self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
False,
)
* self.scaling[adapter]
)
return (transpose(
self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
False,
) * self.scaling[adapter])

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.disable_adapters:
@@ -602,24 +584,18 @@ def _create_new_module(self, lora_config, adapter_name, target, **kwargs):
bias = kwargs.pop("bias", False)
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
print(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
print("fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False.")
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
kwargs["compute_dtype"] = target.compute_dtype
kwargs["compress_statistics"] = target.compress_statistics
kwargs["weight_dtype"] = target.weight_dtype
kwargs["scale_dtype"] = target.scale_dtype
kwargs["blocksize"] = target.blocksize
kwargs["scheme"] = target.scheme
new_module = QuantizedLoraLinearQBits(
adapter_name, in_features, out_features, bias=bias, **kwargs
)
new_module = QuantizedLoraLinearQBits(adapter_name, in_features, out_features, bias=bias, **kwargs)
else:
new_module = QBitsLoraModel._create_new_module_(
lora_config, adapter_name, target, **kwargs
)
new_module = QBitsLoraModel._create_new_module_(lora_config, adapter_name, target, **kwargs)
return new_module

