Dev/ipex woq (#1225)
Co-authored-by: VincyZhang <[email protected]>
PenghuiCheng and VincyZhang authored Feb 24, 2024
1 parent e879faa commit 8569d7b
Showing 13 changed files with 1,181 additions and 72 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/script/formatScan/pylint.sh
@@ -28,7 +28,8 @@ else
fi
# install packages
pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@83dbfbf6070324f3e5872f63e49d49ff7ef4c9b3
pip install accelerate nlpaug nltk schema optimum-intel==1.11.0 optimum==1.13.3 peft==0.6.2
pip install accelerate nlpaug nltk schema optimum-intel optimum peft
pip install --upgrade --force-reinstall transformers

echo "[DEBUG] list pipdeptree..."
pip install pipdeptree
@@ -8,6 +8,7 @@ omit =
*/intel_extension_for_transformers/llm/amp/**
*/intel_extension_for_transformers/llm/evaluation/**
*/intel_extension_for_transformers/llm/quantization/**
*/intel_extension_for_transformers/llm/utils/generation/**
*/intel_extension_for_transformers/llm/library/**
*/intel_extension_for_transformers/llm/operator/**
*/intel_extension_for_transformers/llm/runtime/**
@@ -6,6 +6,7 @@
from transformers import AutoConfig, AutoTokenizer
from transformers.generation import GenerationConfig
import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.llm.utils.generation import _beam_search, _greedy_search
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from intel_extension_for_transformers.llm.quantization.utils import convert_dtype_str2torch
from transformers.utils import check_min_version
@@ -36,6 +37,7 @@
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--do_profiling", action="store_true")
parser.add_argument("--disable_optimize_transformers", action="store_true")
parser.add_argument("--profile_token_latency", action="store_true")
parser.add_argument("--iters", default=10, type=int, help="num iter")
parser.add_argument("--num_warmup", default=3, type=int, help="num warmup")
@@ -49,7 +51,7 @@
help="tasks list for accuracy validation")
# ============WeightOnlyQuant configs===============
parser.add_argument("--woq", action="store_true")
parser.add_argument("--woq_algo", default="RTN", choices=['RTN'],
parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'GPTQ'],
help="Weight-only parameter.")
parser.add_argument("--woq_dtype", type=str, default="int4_fullrange",
choices=["int4_fullrange"])
@@ -58,6 +60,32 @@
parser.add_argument("--woq_enable_mse_search", action="store_true")
parser.add_argument("--device", default="xpu")
parser.add_argument("--compute_dtype", default="fp16")
parser.add_argument(
"--gptq_percdamp",
type=float,
default=0.01,
help="Percent of the average Hessian diagonal to use for dampening.",
)
parser.add_argument(
"--gptq_block_size",
type=int,
default=128,
help="Block size. sub weight matrix size to run GPTQ.",
)
parser.add_argument(
"--gptq_nsamples", type=int, default=128, help="Number of calibration data samples."
)
parser.add_argument(
"--gptq_use_max_length",
action="store_true",
help="Set all sequence length to be same length of args.gptq_pad_max_length",
)
parser.add_argument(
"--gptq_pad_max_length",
type=int,
default=2048,
help="Calibration dataset sequence max length, this should align with your model config",
)
# ============BitsAndBytes configs==============
parser.add_argument("--bitsandbytes", action="store_true")
parser.add_argument("--load_in_4bit", type=bool, default=False)
@@ -77,8 +105,7 @@
trust_remote_code=args.trust_remote_code,
revision=args.revision,
)
generation_config = GenerationConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
generation_config.do_sample = False

user_model = None

# tokenizer
@@ -90,18 +117,38 @@

quantization_config = None
if args.woq:
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype,
group_size=args.woq_group_size, scale_dtype=args.compute_dtype
) #default is A16W4G16
if args.woq_algo == "GPTQ":
algorithm_args = {
"act_order": False,
"percdamp": args.gptq_percdamp,
"block_size": args.gptq_block_size,
"nsamples": args.gptq_nsamples,
"use_max_length": args.gptq_use_max_length,
"pad_max_length": args.gptq_pad_max_length,
}
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.compute_dtype,
scale_dtype=args.compute_dtype,
weight_dtype=args.woq_dtype,
scheme=args.woq_scheme,
group_size=args.woq_group_size,
algorithm=args.woq_algo,
tokenizer=tokenizer,
algorithm_args=algorithm_args,
)
else:
quantization_config = WeightOnlyQuantConfig(
compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype,
group_size=args.woq_group_size, scale_dtype=args.compute_dtype
) #default is A16W4G16

# get model
if quantization_config is not None:
user_model = AutoModelForCausalLM.from_pretrained(args.model,
device_map=args.device,
quantization_config=quantization_config,
trust_remote_code=args.trust_remote_code,
fp16=True,
torch_dtype=torch.float16,
use_neural_speed=False
)
elif args.load_in_4bit or args.load_in_8bit:
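
As a point of reference, here is a minimal, hypothetical sketch of how the new GPTQ branch above could be driven outside the script. The model id, scheme, and group size are placeholders (the script takes them from args.woq_scheme and args.woq_group_size), and the keyword arguments simply mirror the diff rather than a verified API reference.

import torch
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,
)

model_id = "Qwen/Qwen-7B"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# GPTQ needs the tokenizer for calibration plus per-algorithm arguments,
# matching the algorithm_args dictionary built in the script.
quantization_config = WeightOnlyQuantConfig(
    compute_dtype="fp16",
    scale_dtype="fp16",
    weight_dtype="int4_fullrange",
    scheme="sym",          # placeholder for args.woq_scheme
    group_size=32,         # placeholder for args.woq_group_size
    algorithm="GPTQ",
    tokenizer=tokenizer,
    algorithm_args={
        "act_order": False,
        "percdamp": 0.01,         # see --gptq_percdamp
        "block_size": 128,        # see --gptq_block_size
        "nsamples": 128,          # see --gptq_nsamples
        "use_max_length": False,  # see --gptq_use_max_length
        "pad_max_length": 2048,   # see --gptq_pad_max_length
    },
)

user_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="xpu",
    quantization_config=quantization_config,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    use_neural_speed=False,
)
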
@@ -117,16 +164,24 @@
tokenizer.save_pretrained(args.output_dir)

if args.benchmark:
prompt = "它完成了,并提交了。你可以在Android和网络上玩美味生存。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子."
if config.model_type == "qwen":
prompt = "它完成了,并提交了。你可以在Android和网络上玩美味生存。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子."
else:
prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun."

input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)

user_model = AutoModelForCausalLM.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
if user_model is None else user_model
user_model = ipex.optimize_transformers(
user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype)
user_model = user_model.to(memory_format=torch.channels_last)
if not args.disable_optimize_transformers:
print("Optimize with IPEX...")
user_model = ipex.optimize_transformers(
user_model.eval(), device=args.device, inplace=True, woq=(hasattr(user_model, "quantization_config")), dtype=torch_dtype)
else:
print("Disabled optimization with IPEX...")
# start
num_iter = args.iters
num_warmup = args.num_warmup
@@ -136,7 +191,10 @@

generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=args.num_beams)
if args.profile_token_latency:
generate_kwargs["token_latency"] = True
ipex.transformers.optimize.convert_function(user_model, "greedy_search", _greedy_search)
if args.disable_optimize_transformers:
ipex.transformers.optimize.convert_function(user_model, "beam_search", _beam_search)
user_model.config.token_latency = True

total_time = 0.0
total_list = []
@@ -205,12 +263,16 @@
user_model = AutoModelForCausalLM.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
if user_model is None else user_model
user_model = ipex.optimize_transformers(
user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype)
if not args.disable_optimize_transformers:
print("Optimize with IPEX...")
user_model = ipex.optimize_transformers(
user_model.eval(), device=args.device, inplace=True, woq=(hasattr(user_model, "quantization_config")), dtype=torch_dtype)
else:
print("Disabled optimization with IPEX...")
results = evaluate(
model="hf-causal",
model_args='pretrained='+args.model+',tokenizer=' + args.model + \
',dtype=float32, trust_remote_code=' + str(args.trust_remote_code),
',dtype=float32,trust_remote_code=' + str(args.trust_remote_code),
user_model=user_model,
batch_size=args.batch_size,
tasks=args.tasks,
19 changes: 14 additions & 5 deletions intel_extension_for_transformers/llm/quantization/utils.py
@@ -22,9 +22,9 @@
import os
from accelerate import init_empty_weights
from datasets import load_dataset
from intel_extension_for_transformers.transformers.utils.utility import LazyImport
from neural_compressor import quantization
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
from neural_compressor.utils.utility import LazyImport
from neural_compressor.config import PostTrainingQuantConfig
from ...utils.utils import is_ipex_available
from transformers import AutoTokenizer
@@ -349,6 +349,13 @@ def default_calib_func(model):
if config.algorithm in ["TEQ", "RTN", "GPTQ"]:
calib_func = None

orig_dtype = torch.float32
for param in model.parameters():
orig_dtype = param.dtype
if orig_dtype != torch.float32:
model.to(dtype=torch.float32)
break

inc_model = quantization.fit(model,
conf,
calib_func=calib_func,
@@ -363,7 +370,6 @@ def default_calib_func(model):
None,
config,
device=device)
return q_model.to("xpu")
else:
if config.algorithm == "GPTQ":
inc_model = inc_model.export_compressed_model(use_optimum_format=True)
@@ -381,9 +387,12 @@ def default_calib_func(model):
}

setattr(config, "gptq_quantize_config", quantize_config)
return replace_linear(inc_model, None, None, config, device=device)

return replace_linear(inc_model.model, None, None, config, device=device)
q_model = replace_linear(inc_model, None, None, config, device=device)
else:
q_model = replace_linear(inc_model.model, None, None, config, device=device)
if orig_dtype != torch.float32:
q_model.to(dtype=orig_dtype)
return q_model.to(device)

def convert_dtype_str2torch(str_dtype):
if str_dtype == "int8":
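
The change above also makes the weight-only path dtype-aware: a model loaded in fp16/bf16 is temporarily upcast to fp32 before quantization.fit, and the original dtype is restored on the returned quantized model. A standalone sketch of that save/restore pattern follows, with a hypothetical run_quantization callable standing in for the quantization.fit + replace_linear flow; it is an illustration, not a repository API.

import torch

def quantize_preserving_dtype(model, run_quantization):
    # Remember the dtype of the first parameter (fp16/bf16 checkpoints are common).
    orig_dtype = torch.float32
    for param in model.parameters():
        orig_dtype = param.dtype
        if orig_dtype != torch.float32:
            # The quantization path here expects fp32 weights, so upcast first.
            model.to(dtype=torch.float32)
        break

    q_model = run_quantization(model)  # hypothetical stand-in for the INC flow

    # Hand back a model in the caller's original dtype.
    if orig_dtype != torch.float32:
        q_model.to(dtype=orig_dtype)
    return q_model
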
19 changes: 19 additions & 0 deletions intel_extension_for_transformers/llm/utils/generation/__init__.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .beam_search import _beam_search
from .greedy_search import _greedy_search
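
These new _greedy_search and _beam_search entry points are consumed by the benchmark script above when token-latency profiling is requested. Below is a minimal wiring sketch, assuming user_model is an already-loaded causal LM as in that script; the helper name is hypothetical.

import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.llm.utils.generation import (
    _beam_search,
    _greedy_search,
)

def enable_token_latency(user_model, optimized_with_ipex=True):
    # Swap in the latency-aware greedy search loop.
    ipex.transformers.optimize.convert_function(user_model, "greedy_search", _greedy_search)
    if not optimized_with_ipex:
        # Beam search is only patched when ipex.optimize_transformers was skipped,
        # mirroring the --disable_optimize_transformers branch in the script.
        ipex.transformers.optimize.convert_function(user_model, "beam_search", _beam_search)
    user_model.config.token_latency = True
    return user_model
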