From 0e13607721b275e35437ab3d7b4d5a423cc00c3c Mon Sep 17 00:00:00 2001 From: Zhenzhong1 <109137058+Zhenzhong1@users.noreply.github.com> Date: Fri, 24 May 2024 16:18:01 +0800 Subject: [PATCH] [vLLM] Support vLLM CPU backend and provide QBits acceleration (#1551) Co-authored-by: VincyZhang Co-authored-by: Wang, Chang --- examples/vllm/README.md | 35 ++++++++ examples/vllm/requirement.txt | 3 + examples/vllm/vllm_acceleration_example.py | 85 ++++++++++++++++++ .../llm/quantization/nn/modules.py | 90 +++++++------------ .../transformers/modeling/modeling_auto.py | 89 ++++++++++++++++++ tests/Nightly/test_vllm.py | 47 ++++++++++ 6 files changed, 292 insertions(+), 57 deletions(-) create mode 100644 examples/vllm/README.md create mode 100644 examples/vllm/requirement.txt create mode 100644 examples/vllm/vllm_acceleration_example.py create mode 100644 tests/Nightly/test_vllm.py diff --git a/examples/vllm/README.md b/examples/vllm/README.md new file mode 100644 index 00000000000..14ea00e6e20 --- /dev/null +++ b/examples/vllm/README.md @@ -0,0 +1,35 @@ +# vLLM Acceleration with ITREX + +Intel extension for transformers(ITREX) integrates the vLLM CPU backend and offers optional [QBits Module](../../docs/qbits.md) to accelerate the vLLM inference on CPUs. + +## Installation Methods + +1. vLLM Installation with CPU: Install vLLM from source code following the instructions provided [here](https://docs.vllm.ai/en/latest/getting_started/cpu-installation.html). + +2. ITREX Installation: Install the ITREX following the [link](../../docs/get_started.md) + +3. Dependencies: Install some additional dependencies that may be used. The dependencies are listed in the current directory. + +Note: torch==2.3.0+cpu is required and vllm==0.4.2+cpu is validated. + +## Usage Example + +ITREX provides a script that demonstrates the vLLM inference acceleration. Run it with the following command: +```bash +numactl -m 0 -C 0-55 python vllm_acceleration_example.py --model_path=/home/model/chatglm2-6b --prompt=你好 +``` + +## Supported and Validated Models +All models listed in the [vLLM Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html) can be accelerated theoretically. + +We have validated the majority of existing models using vLLM==0.4.2+cpu: +* [THUDM/chatglm2-6b](https://hf-mirror.com/THUDM/chatglm2-6b) +* [meta-llama/Llama-2-7b-chat-hf](https://hf-mirror.com/meta-llama/Llama-2-7b-chat-hf) +* [baichuan-inc/Baichuan2-7B-Chat](https://hf-mirror.com/baichuan-inc/Baichuan2-7B-Chat) +* [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) +* [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) +* [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) +* [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) +* [Qwen/CodeQwen1.5-7B-Chat](https://huggingface.co/Qwen/CodeQwen1.5-7B-Chat) + +If you encounter any problems, please let us know. diff --git a/examples/vllm/requirement.txt b/examples/vllm/requirement.txt new file mode 100644 index 00000000000..b52bcc4a901 --- /dev/null +++ b/examples/vllm/requirement.txt @@ -0,0 +1,3 @@ +accelerate +datasets +peft diff --git a/examples/vllm/vllm_acceleration_example.py b/examples/vllm/vllm_acceleration_example.py new file mode 100644 index 00000000000..b56487c38bd --- /dev/null +++ b/examples/vllm/vllm_acceleration_example.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import time +import os +from vllm import LLM, SamplingParams +from typing import List, Optional +from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig +from transformers import AutoTokenizer + + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, help="Model name: String", required=True) + parser.add_argument( + "-p", + "--prompt", + type=str, + help="Prompt to start generation with: String (default: empty)", + default="Once upon a time", + ) + parser.add_argument("--benchmark", action="store_true") + parser.add_argument("--use_neural_speed", action="store_true") + args = parser.parse_args(args_in) + print(args) + + if args.benchmark: + if args.use_neural_speed: + os.environ["NEURAL_SPEED_VERBOSE"] = "1" + woq_config = RtnConfig(bits=4, weight_dtype="int4", compute_dtype="int8", scale_dtype="bf16") + model_with_ns = AutoModelForCausalLM.from_pretrained(args.model_path, quantization_config=woq_config) + + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + inputs = tokenizer(args.prompt, return_tensors="pt").input_ids + + T5 = time.time() + output = model_with_ns.generate(inputs, max_new_tokens=32) + T6 = time.time() + print("neural speed output = ", output) + + llm = LLM(model=args.model_path, trust_remote_code=True) + sampling_params = SamplingParams(max_tokens=32) + T1 = time.time() + original_outputs = llm.generate(args.prompt, sampling_params) # Generate texts from the prompts. + T2 = time.time() + vllm_latency = (T2 - T1) * 1000 + + model = AutoModelForCausalLM.from_pretrained(args.model_path, use_vllm=True) + T3 = time.time() + optimized_output = model.generate(args.prompt, sampling_params) + T4 = time.time() + qbits_latency = (T4 - T3) * 1000 + + print("original outputs = ", original_outputs) + print("input_tokens_length = ", len(original_outputs[0].prompt_token_ids)) + print("output_tokens_length = ", len(original_outputs[0].outputs[0].token_ids)) + + print("optimized outputs = ", optimized_output) + print("input_tokens_length = ", len(optimized_output[0].prompt_token_ids)) + print("output_tokens_length = ", len(optimized_output[0].outputs[0].token_ids)) + + print('The qbits optimized generate:%.2f ms' % qbits_latency) + print('The original vLLM generate:%.2f ms' % vllm_latency) + + return + + model = AutoModelForCausalLM.from_pretrained(args.model_path, use_vllm=True) + output = model.generate(args.prompt) + print(output) + + +if __name__ == "__main__": + main() diff --git a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py index 664118379c4..0e073b258bc 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import torch from ..utils import DTYPE_BITS_MAPPING from functools import reduce @@ -23,19 +24,19 @@ from peft.tuners.lora import LoraLayer, LoraModel from peft.utils.other import transpose from intel_extension_for_transformers.transformers.llm.quantization.autograd import ( - matmul_kbit, -) + matmul_kbit, ) import intel_extension_for_transformers.qbits as qbits # pylint: disable=E0611, E0401 class DropoutQBits_(torch.autograd.Function): + @staticmethod def forward(ctx, input, probability): mask = qbits.dropout_fwd(input, probability) if any(ctx.needs_input_grad[:1]): - ctx.tensors = (mask,) + ctx.tensors = (mask, ) else: - ctx.tensors = (None,) + ctx.tensors = (None, ) return input @staticmethod @@ -51,6 +52,7 @@ def backward(ctx, grad_output): class DropoutQBits(torch.nn.Module): + def __init__(self, p=0.0): super().__init__() self.p = p @@ -63,6 +65,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class ParamsQBits(torch.nn.Parameter): + def __new__( cls, data=None, @@ -87,6 +90,7 @@ def __new__( class QuantizedLinearQBits(torch.nn.Linear): + def __init__( self, input_features, @@ -156,6 +160,9 @@ def forward(self, x: torch.Tensor): shape[-1] = self.out_features out = out.view(shape) + if os.environ.get("backend", None) == "use_vllm": + return out, None + return out def set_fp_weights_bias(self, weight_data, bias=None): @@ -264,33 +271,24 @@ def quant_weight_w_scale(self, weight, scale, zp, group_size=-1): if zp is not None: zp = zp.to(device) if group_size == -1: - return ( - weight.div_(scale).round_() - if zp is None - else weight.div_(scale).add_(zp).round_() - ) + return (weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_()) int_weight = torch.zeros(weight.shape).to(device) leng = weight.shape[1] // group_size tail_flag = False if weight.shape[1] % group_size == 0 else True for i in range(leng): - int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_( - scale[:, i].unsqueeze(1) - ) + int_weight_tmp = weight[:, i * group_size:(i + 1) * group_size].div_(scale[:, i].unsqueeze(1)) if zp is not None: int_weight_tmp.add_(zp[:, i].unsqueeze(1)) - int_weight[:, i * group_size : (i + 1) * group_size].copy_( - int_weight_tmp.round_() - ) + int_weight[:, i * group_size:(i + 1) * group_size].copy_(int_weight_tmp.round_()) if tail_flag: - int_weight_tmp = weight[:, leng * group_size :].div_( - scale[:, -1].unsqueeze(1) - ) + int_weight_tmp = weight[:, leng * group_size:].div_(scale[:, -1].unsqueeze(1)) if zp is not None: int_weight_tmp.add_(zp[:, -1].unsqueeze(1)) - int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_()) + int_weight[:, leng * group_size:].copy_(int_weight_tmp.round_()) return int_weight def recover_qparms(self): + def recover_idx(ret_idx, k, blocksize): g_idx = torch.zeros(k, dtype=int) value_range = (k + blocksize - 1) // blocksize @@ -328,18 +326,12 @@ def recover_int_weight(g_idx, int_weight): else: g_idx = None weight_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 6) - weight_dtype = "".join( - chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist() - ) + weight_dtype = "".join(chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist()) bits = 4 if weight_dtype in ["nf4", "int4_clip", "fp4", "int4_fullrange"] else 8 compute_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 7) - compute_dtype = "".join( - chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist() - ) + compute_dtype = "".join(chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist()) scales_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 8) - scales_dtype = "".join( - chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist() - ) + scales_dtype = "".join(chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist()) if scales_dtype is None: assert False, "scales dtype only support fp32." scales = qbits.acquire_packed_weight_info(self.weight, 9) @@ -356,9 +348,7 @@ def recover_int_weight(g_idx, int_weight): revert_wei = torch.zeros(in_features, out_features, dtype=torch.float) - qbits.dequantize_packed_weight( - self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype - ) + qbits.dequantize_packed_weight(self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype) int_weight = self.quant_weight_w_scale( revert_wei.t(), @@ -426,9 +416,7 @@ def __init__( except: qbits_customop_available = False if lora_dropout > 0 and qbits_customop_available: - self.lora_dropout = torch.nn.ModuleDict( - {adapter_name: DropoutQBits(p=lora_dropout)} - ) + self.lora_dropout = torch.nn.ModuleDict({adapter_name: DropoutQBits(p=lora_dropout)}) def merge(self, safe_merge: bool = False) -> None: """Merge the active adapter weights into the base weights. @@ -440,10 +428,8 @@ def merge(self, safe_merge: bool = False) -> None: NaNs. Defaults to `False`. """ if self.merged: - print( - f"Already following adapters were merged {','.join(self.merged_adapters)}. " - f"You are now additionally merging {','.join(self.active_adapters)}." - ) + print(f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}.") w_dequant = torch.zeros( self.out_features, self.in_features, @@ -468,8 +454,7 @@ def merge(self, safe_merge: bool = False) -> None: if not torch.isfinite(orig_weights).all(): raise ValueError( - f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" - ) + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken") w_data = orig_weights else: @@ -541,13 +526,10 @@ def unmerge(self) -> None: ) def get_delta_weight(self, adapter) -> torch.Tensor: - return ( - transpose( - self.lora_B[adapter].weight @ self.lora_A[adapter].weight, - False, - ) - * self.scaling[adapter] - ) + return (transpose( + self.lora_B[adapter].weight @ self.lora_A[adapter].weight, + False, + ) * self.scaling[adapter]) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.disable_adapters: @@ -602,10 +584,8 @@ def _create_new_module(self, lora_config, adapter_name, target, **kwargs): bias = kwargs.pop("bias", False) in_features, out_features = target.in_features, target.out_features if kwargs["fan_in_fan_out"]: - print( - "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " - "Setting fan_in_fan_out to False." - ) + print("fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False.") kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False kwargs["compute_dtype"] = target.compute_dtype kwargs["compress_statistics"] = target.compress_statistics @@ -613,13 +593,9 @@ def _create_new_module(self, lora_config, adapter_name, target, **kwargs): kwargs["scale_dtype"] = target.scale_dtype kwargs["blocksize"] = target.blocksize kwargs["scheme"] = target.scheme - new_module = QuantizedLoraLinearQBits( - adapter_name, in_features, out_features, bias=bias, **kwargs - ) + new_module = QuantizedLoraLinearQBits(adapter_name, in_features, out_features, bias=bias, **kwargs) else: - new_module = QBitsLoraModel._create_new_module_( - lora_config, adapter_name, target, **kwargs - ) + new_module = QBitsLoraModel._create_new_module_(lora_config, adapter_name, target, **kwargs) return new_module diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index ee5ac78d18e..e52154e6b0a 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -86,6 +86,7 @@ has_file, ) +import torch.nn.functional as F from typing import Union if is_ipex_available() and is_intel_gpu_available(): @@ -338,6 +339,94 @@ class _BaseQBitsAutoModelClass: @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + use_vllm = kwargs.pop("use_vllm", None) + if use_vllm is not None: + logger.info("The backend is vLLM.") + from vllm import LLM # pylint: disable=E1101 + from vllm.model_executor.model_loader import get_model_loader # pylint: disable=E0611 + from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 + from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ColumnParallelLinear, + RowParallelLinear) # pylint: disable=E1101 + + os.environ["backend"] = "use_vllm" + llm = LLM(model=pretrained_model_name_or_path, trust_remote_code=True) # Create an vllm instance. + model = llm.llm_engine.model_executor.driver_worker.model_runner.model # pylint: disable=E1101 + print("Original model =", model) + + original_parameter_memo = dict() + original_params_dict = dict(model.named_parameters(remove_duplicate=False)) + for name in original_params_dict.keys(): + params = original_params_dict[name] + if "qkv_proj" in name or "gate_up_proj" in name: + input_dim = getattr(params, "input_dim", None) + output_dim = getattr(params, "output_dim", None) + original_parameter_memo[name] = (input_dim, output_dim, params.weight_loader) + + class linear_adaptor(torch.nn.Linear): + + def __init__(self, in_features: int, out_features: int, bias: bool = True, \ + device=None, dtype=None) -> None: + super().__init__(in_features, out_features, bias, device, dtype) + + def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: + return F.linear(input, self.weight, self.bias), None + + for name, module in model.named_modules(): + bias_flag = False + if isinstance(module, QKVParallelLinear) or isinstance(module, MergedColumnParallelLinear) or \ + isinstance(module, RowParallelLinear) or isinstance(module, ColumnParallelLinear): + out_feature = module.weight.shape[0] + in_feature = module.weight.shape[1] + if getattr(module, "bias", False) != None: + bias_flag = True + weight_dtype = module.weight.dtype + + torch_linear = linear_adaptor(in_features=in_feature, + out_features=out_feature, + bias=bias_flag, + dtype=weight_dtype) + module_traversal = model + all_module_names = name.split('.') + all_module_names_except_last = all_module_names[:-1] + for sub_module_name in all_module_names_except_last: + module_traversal = module_traversal._modules[sub_module_name] + + module_traversal._modules[all_module_names[-1]] = copy.deepcopy(torch_linear) + + print("Optimized model =", model) + loader = get_model_loader(llm.llm_engine.load_config) # pylint: disable=E1101 + + weights_iterator = loader._get_weights_iterator(llm.llm_engine.model_config.model, + llm.llm_engine.model_config.revision, + fall_back_to_pt=True) + + from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611 + params_dict = dict(model.named_parameters(remove_duplicate=False)) + for name in params_dict.keys(): + params = params_dict[name] + if hasattr(params, "weight_loader") == False: + if "qkv_proj" in name or "gate_up_proj" in name: + original_params = original_parameter_memo[name] + setattr(params, "input_dim", original_params[0]) + setattr(params, "output_dim", original_params[1]) + setattr(params, "weight_loader", original_params[2]) + else: + setattr(params, "weight_loader", default_weight_loader) + + model.load_weights(weights_iterator) + + print("INC quantizing...") + config = RtnConfig(compute_dtype="bf16", + group_size=128, + scale_dtype="bf16", + weight_dtype="int4_clip", + bits=4) + model = convert_to_quantized_model(model, config) + + return llm + # use for neuralspeed gguf gguf_file = kwargs.pop("gguf_file", None) if gguf_file is not None: diff --git a/tests/Nightly/test_vllm.py b/tests/Nightly/test_vllm.py new file mode 100644 index 00000000000..938b12fdfb7 --- /dev/null +++ b/tests/Nightly/test_vllm.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import torch +import unittest +import neural_compressor.adaptor.pytorch as nc_torch +from transformers import AutoTokenizer, TextStreamer +from intel_extension_for_transformers.transformers import AutoModelForCausalLM +PT_VERSION = nc_torch.get_torch_version() + +class TestVLLM(unittest.TestCase): + + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree("./runtime_outs", ignore_errors=True) + + @unittest.skipIf(PT_VERSION.release < Version("2.3.0").release, + "Please use PyTroch 2.3.0 or higher version for vllm") + def test_use_vllm_api(self): + model_name = "/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf" + prompt = "Once upon a time" + model = AutoModelForCausalLM.from_pretrained(model_name, use_vllm=True) + output = model.generate(prompt) + print("output = ", output) + + +if __name__ == "__main__": + unittest.main()