fixing GPTQ #148

Open · wants to merge 2 commits into base: gh/HDCharles/8/base
213 changes: 209 additions & 4 deletions GPTQ.py
@@ -66,10 +66,12 @@ def __init__(

def add_input(self, args):
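# record each argument as a GPTQMultiTensor (one entry per calibration sample) so torch
# functions applied to the recorded inputs later are dispatched per sample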
if self.inputs is None:
self.inputs = [MultiInput([arg]) for arg in args]
# self.inputs = [MultiInput([arg]) for arg in args]
self.inputs = [GPTQMultiTensor([arg]) for arg in args]
else:
self.inputs = [
multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
# multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
multi.add_tensors(arg) for (multi, arg) in zip(self.inputs, args)
]

def get_recorded_inputs(self):
@@ -129,6 +131,199 @@ def cuda(self):
self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]


class GPTQMultiTensor(torch.Tensor):
"""
"""
# todo need default shape/dtype
@staticmethod
def __new__(cls, input, **kwargs):
if isinstance(input, (list, tuple)):
input = input[0]
kwargs["dtype"]=kwargs.get("dtype", input.dtype)
shape = kwargs.pop("shape", input.shape)
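# build a wrapper subclass whose shape/dtype come from the first wrapped tensor;
# the actual data is kept in self.values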
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

def __init__(self, input, **kwargs):
self.values = []
self.add_tensors(input)
self.debug = True

def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.values})"
)

def add_tensors(self, input):
if isinstance(input, (tuple, list)):
for inp in input:
self.add_tensors(inp)
else:
assert isinstance(input, torch.Tensor), f"GPTQMultiTensor can only use add_tensors for Tensors or lists of tensors but got {type(input)}"
self.values.append(input)
return self

def count(self):
return len(self.values)

def cuda(self):
self.values = [val.cuda() for val in self.values]
return self

def cpu(self):
self.values = [val.cpu() for val in self.values]
return self

def configure_quantization_mode(
self,
get_qparams_func,
quantize_func,
dequantize_func,
combine_qparams_list_func,
make_names_and_values_dict_func,
skip_layer_func,
):
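# store the quantization callbacks (qparam computation, quantize/dequantize,
# qparam combining, layer-skip test) for later use when a linear layer is quantized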
self.get_qparams_func = get_qparams_func
self.quantize_func = quantize_func
self.dequantize_func = dequantize_func
self.combine_qparams_list_func = combine_qparams_list_func
self.skip_layer_func = skip_layer_func
self.make_names_and_values_dict_func = make_names_and_values_dict_func
return self

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
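# intercept every torch function called on a GPTQMultiTensor: most ops are
# replayed once per wrapped tensor and regrouped into new GPTQMultiTensors;
# nn.functional.linear (unless skip_gptq=True) goes through the GPTQ path below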
# with torch._C.DisableTorchFunctionSubclass():
# is_set_item = str(func)=="<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>"
# if is_set_item:
# breakpoint()
# try:
# new_arg1=[None if x == slice(None) else x for x in args[1]]
# return torch.ops.aten.index_put(args[0], new_arg1, args[2])
# except Exception as e:
# print(e)
# print("?A?")
# breakpoint()
# print("?")
# if func == torch.ops.aten.index_put_:
# breakpoint()

def tensors_to_cuda(args):
new_args = []
for x in args:
new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
return new_args

def flat_to_grouped(flat):
# size of biggest MultiTensor
multi_tensor_size = max(
[x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat]
)
# convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
grouped = list(
zip(
*[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat]
)
)
return grouped

# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
# where A is nontensor, b's,c's are tensors
def grouped_to_flat(grouped):
# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
flat_tups = list(zip(*grouped))
# convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
flattened = [
cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
]
# check that collapsing each non-Tensor tuple down to its first element is OK,
# i.e. all elements of the tuple are equal (empty tuples are treated as equal)
non_tensors_equal = all(
all(tup[0] == x for x in tup)  # check all elements match
for tup in flat_tups
if not isinstance(tup[0], torch.Tensor)  # only look at tuples of non-Tensors
)
return flattened, non_tensors_equal

kwargs = {} if kwargs is None else kwargs
# combine args and kwargs and flatten away nested lists and tuples
flat_args, spec = tree_flatten((args, kwargs))
# move single tensors to cuda

# flat_args = tensors_to_cuda(flat_args)

# convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
grouped_args = flat_to_grouped(flat_args)

do_gptq_linear = (
func is nn.functional.linear
# and id(args[1]) in self.id_to_name
and not skip_gptq
# and not (self.skip_layer_func)
)

# run function for each of the multitensors and return a multitensor
if not do_gptq_linear:
outputs = []
with torch._C.DisableTorchFunctionSubclass():
for inp in grouped_args:
# inp = tensors_to_cuda(inp)
cur_args, cur_kwargs = tree_unflatten(inp, spec)
try:
out = func(*cur_args, **cur_kwargs)
outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
except Exception as e:
print(e)
print("?B?")
breakpoint()
print("?")
try:
# flatten each per-sample output into its leaves
grouped_outputs = [tree_flatten(x)[0] for x in outputs]
out_spec = tree_flatten(outputs[0])[1]
# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
assert non_tensors_equal, (
f"ERR: found a function in model: {func} which "
+"caused an error in GPTQMultiTensor; function dispatch only works for functions"
+" with Tensor outputs or whose non-Tensor outputs are identical across all inputs"
)
return tree_unflatten(flat_outputs, out_spec)
except Exception as e:
print(e)
print("?C?")
breakpoint()
print("?")

# do GPTQ if do_gptq_linear is true
total_batches = 0
H=0
for inp in grouped_args:
# inp = tensors_to_cuda(inp)
cur_args, cur_kwargs = tree_unflatten(inp, spec)
x = cur_args[0].float()
shape = x.shape
n = 1 if len(shape) == 2 else shape[0]
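# running estimate of the Hessian proxy H ~ 2 * E[x x^T] over all calibration
# samples: shrink the previous estimate by old/(old + n), then add this batch's
# contribution with the 2/total_batches factor folded into x before the matmul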
H*= total_batches / (total_batches + n)
total_batches += n
x = (
(2 / total_batches) ** (1 / 2) *
x.reshape(-1, shape[-1]).t().float()

)
H += x.matmul(x.t())
W = args[1].to(H.device)
DQ = W+.01
# Q, DQ, qparams = args[0].faster_quant(H, W.detach())
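# NOTE: DQ = W + .01 is a debugging stand-in while faster_quant (above) is
# commented out; the linear is then re-dispatched with the modified weight and
# skip_gptq=True so the GPTQ branch is not entered again recursively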

new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_gptq = True)
# if args[0].debug:
return new_out

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
breakpoint()
pass


class GenericGPTQRunner(fx.Interpreter):
"""
This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +345,7 @@ def __init__(
}

# trace model for one input
one_input = [multi.values[0].cpu() for multi in inputs]
one_input = tuple([multi.values[0].cpu() for multi in inputs])
exported_model = torch._dynamo.export(
model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
)(*one_input)
@@ -161,7 +356,7 @@ def __init__(
self.groupsize = groupsize
self.inputs = inputs
self.gptq_done = False
self.debug = False
self.debug = True

def configure_quantization_mode(
self,
@@ -312,6 +507,16 @@ def SQNR(x, y):
print(
"SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
) # matches
qparams_after = self.get_qparams_func(DQ)
Q_after = self.quantize_func(DQ, qparams_after)
print(
"abs difference of Q-quant(DQ)", (Q-Q_after).abs().sum()
)
DQ_after_after = self.dequantize_func(Q_after, qparams_after).to(DQ.dtype)
print(
"SQNR for DQ(Q(DQ)) vs DQ", SQNR(DQ, DQ_after_after)
)
breakpoint()

print(
"SQNR for weight (can be low)", SQNR(W, DQ.cuda())
7 changes: 6 additions & 1 deletion model.py
@@ -78,10 +78,16 @@ def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]


k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
breakpoint()
v_out[:, :, input_pos] = v_val
breakpoint()
# k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
# v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)


return k_out, v_out

@@ -174,7 +180,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
14 changes: 10 additions & 4 deletions quantize.py
@@ -11,6 +11,8 @@
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor

from GPTQ import GenericGPTQRunner, InputRecorder

try:
from GPTQ import GenericGPTQRunner, InputRecorder
from eval import get_task_dict, evaluate, lm_eval
@@ -286,6 +288,10 @@ def create_quantized_state_dict(
pad_calibration_inputs,
) -> "StateDict":
inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
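# debug/WIP: run a forward pass on CPU with the recorded inputs; since these are
# GPTQMultiTensors, the pass exercises GPTQMultiTensor.__torch_function__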
self.mod=self.mod.to("cpu")
inputs=[x.cpu() if hasattr(x, "cpu") else x for x in inputs]
self.mod(*inputs)

print("Tracing model for GPTQ")
GPTQ_runner = GenericGPTQRunner(
self.mod,
@@ -438,12 +444,12 @@ def convert_for_runtime(self, use_cuda):
return self.mod

class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
from model import find_multiple
self.mod = mod
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding = padding
self.padding_allowed = padding_allowed
self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
self.quantize_func = lambda w, qparams: \
group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
@@ -453,7 +459,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
[torch.cat(x, dim=1) for x in zip(*qparams_list)]
# skip unless padding_allowed=True or it's correctly sized
self.skip_layer_func = lambda linear_weight: not (
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
)
# we need to do the padding here, both for q and the qparams if necessary
def make_names_and_values_dict_func(q, qparams):
@@ -472,7 +478,7 @@ def make_names_and_values_dict_func(q, qparams):


def convert_for_runtime(self, use_cuda):
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
return self.mod

class WeightOnlyInt4Linear(torch.nn.Module):
19 changes: 19 additions & 0 deletions run.sh
@@ -0,0 +1,19 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
# echo "base"
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 1
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
# echo "quant good"

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5

# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5