From 298c44337c2a1072c98ff982f13bb1b37011d79b Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Thu, 28 Mar 2024 01:42:40 -0700
Subject: [PATCH 1/2] fixing GPTQ

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 GPTQ.py     | 213 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 model.py    |   7 +-
 quantize.py |  14 +++-
 run.sh      |  19 +++++
 4 files changed, 244 insertions(+), 9 deletions(-)
 create mode 100644 run.sh

diff --git a/GPTQ.py b/GPTQ.py
index e1279bd..3cf78fd 100644
--- a/GPTQ.py
+++ b/GPTQ.py
@@ -66,10 +66,12 @@ def __init__(
 
     def add_input(self, args):
         if self.inputs is None:
-            self.inputs = [MultiInput([arg]) for arg in args]
+            # self.inputs = [MultiInput([arg]) for arg in args]
+            self.inputs = [GPTQMultiTensor([arg]) for arg in args]
         else:
             self.inputs = [
-                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                # multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                multi.add_tensors(arg) for (multi, arg) in zip(self.inputs, args)
             ]
 
     def get_recorded_inputs(self):
@@ -129,6 +131,199 @@ def cuda(self):
         self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
 
 
+class GPTQMultiTensor(torch.Tensor):
+    """
+    """
+    # todo need default shape/dtype
+    @staticmethod
+    def __new__(cls, input, **kwargs):
+        if isinstance(input, (list, tuple)):
+            input = input[0]
+        kwargs["dtype"]=kwargs.get("dtype", input.dtype)
+        shape = kwargs.pop("shape", input.shape)
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, input, **kwargs):
+        self.values = []
+        self.add_tensors(input)
+        self.debug = True
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.values})"
+        )
+
+    def add_tensors(self, input):
+        if isinstance(input, (tuple, list)):
+            for inp in input:
+                self.add_tensors(inp)
+        else:
+            assert isinstance(input, torch.Tensor), f"MultiTensor can only use add_input for Tensors or lists of tensors but got {type(input)}"
+            self.values.append(input)
+        return self
+
+    def count(self):
+        return len(self.values)
+
+    def cuda(self):
+        self.values = [val.cuda() for val in self.values]
+        return self
+
+    def cpu(self):
+        self.values = [val.cpu() for val in self.values]
+        return self
+
+    def configure_quantization_mode(
+        self,
+        get_qparams_func,
+        quantize_func,
+        dequantize_func,
+        combine_qparams_list_func,
+        make_names_and_values_dict_func,
+        skip_layer_func,
+    ):
+        self.get_qparams_func = get_qparams_func
+        self.quantize_func = quantize_func
+        self.dequantize_func = dequantize_func
+        self.combine_qparams_list_func = combine_qparams_list_func
+        self.skip_layer_func = skip_layer_func
+        self.make_names_and_values_dict_func = make_names_and_values_dict_func
+        return self
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
+        # with torch._C.DisableTorchFunctionSubclass():
+        # is_set_item = str(func)==""
+        # if is_set_item:
+        #     breakpoint()
+        #     try:
+        #         new_arg1=[None if x == slice(None) else x for x in args[1]]
+        #         return torch.ops.aten.index_put(args[0], new_arg1, args[2])
+        #     except Exception as e:
+        #         print(e)
+        #         print("?A?")
+        #         breakpoint()
+        #         print("?")
+        # if func == torch.ops.aten.index_put_:
+        #     breakpoint()
+
+        def tensors_to_cuda(args):
+            new_args = []
+            for x in args:
+                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+            return new_args
+
+        def flat_to_grouped(flat):
+            # size of biggest MultiTensor
+            multi_tensor_size = max(
+                [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat]
+            )
+            # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
+            grouped = list(
+                zip(
+                    *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat]
+                )
+            )
+            return grouped
+
+        # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+        # where A is nontensor, b's,c's are tensors
+        def grouped_to_flat(grouped):
+            # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
+            flat_tups = list(zip(*grouped))
+            # convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+            flattened = [
+                cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
+            ]
+            # need to check that getting rid of all but one from each nonTensor tuple is OK
+            non_tensors_equal=min([True]+[
+                min([True]+[  # handle situation where tuples have size 0
+                    tup[0]==x for x in tup  # check all elements match
+                ]) for tup in flat_tups if not isinstance(tup[0], torch.Tensor)  # look at tuples of nonTensors
+            ])
+            return flattened, non_tensors_equal
+
+        kwargs = {} if kwargs is None else kwargs
+        # combine args and kwargs and remove lists and tuples
+        flat_args, spec = tree_flatten((args, kwargs))
+        # move single tensors to cuda
+
+        # flat_args = tensors_to_cuda(flat_args)
+
+        # convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
+        grouped_args = flat_to_grouped(flat_args)
+
+        do_gptq_linear = (
+            func is nn.functional.linear
+            # and id(args[1]) in self.id_to_name
+            and not skip_gptq
+            # and not (self.skip_layer_func)
+        )
+
+        # run function for each of the multitensors and return a multitensor
+        if not do_gptq_linear:
+            outputs = []
+            with torch._C.DisableTorchFunctionSubclass():
+                for inp in grouped_args:
+                    # inp = tensors_to_cuda(inp)
+                    cur_args, cur_kwargs = tree_unflatten(inp, spec)
+                    try:
+                        out = func(*cur_args, **cur_kwargs)
+                        outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
+                    except Exception as e:
+                        print(e)
+                        print("?B?")
+                        breakpoint()
+                        print("?")
+            try:
+                # each output
+                grouped_outputs = [tree_flatten(x)[0] for x in outputs]
+                out_spec = tree_flatten(outputs[0])[1]
+                # convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
+                flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
+                assert non_tensors_equal, (
+                    f"ERR: found a function in model: {func} which "
+                    +"caused an error in GPTQMultiInput, the function dispatch only works for functions"
+                    +" with Tensor outputs or that have the same non-Tensor output value for all across all inputs"
+                )
+                return tree_unflatten(flat_outputs, out_spec)
+            except Exception as e:
+                print(e)
+                print("?C?")
+                breakpoint()
+                print("?")
+
+        # do GPTQ if quantize_linear is true
+        total_batches = 0
+        H=0
+        for inp in grouped_args:
+            # inp = tensors_to_cuda(inp)
+            cur_args, cur_kwargs = tree_unflatten(inp, spec)
+            x = cur_args[0].float()
+            shape = x.shape
+            n = 1 if len(shape) == 2 else shape[0]
+            H*= total_batches / (total_batches + n)
+            total_batches += n
+            x = (
+                (2 / total_batches) ** (1 / 2) *
+                x.reshape(-1, shape[-1]).t().float()
+
+            )
+            H += x.matmul(x.t())
+        W = args[1].to(H.device)
+        DQ = W+.01
+        # Q, DQ, qparams = args[0].faster_quant(H, W.detach())
+
+        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_gptq = True)
+        # if args[0].debug:
+        return new_out
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        breakpoint()
+        pass
+
+
 class GenericGPTQRunner(fx.Interpreter):
     """
     This is a generic GPTQ runner that takes an existing model and applies GPTQ.
@@ -150,7 +345,7 @@ def __init__(
         }
 
         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]
+        one_input = tuple([multi.values[0].cpu() for multi in inputs])
         exported_model = torch._dynamo.export(
             model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
@@ -161,7 +356,7 @@ def __init__(
         self.groupsize = groupsize
         self.inputs = inputs
         self.gptq_done = False
-        self.debug = False
+        self.debug = True
 
     def configure_quantization_mode(
         self,
@@ -312,6 +507,16 @@ def SQNR(x, y):
                 print(
                     "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
                 )  # matches
+                qparams_after = self.get_qparams_func(DQ)
+                Q_after = self.quantize_func(DQ, qparams_after)
+                print(
+                    "abs difference of Q-quant(DQ)", (Q-Q_after).abs().sum()
+                )
+                DQ_after_after = self.dequantize_func(Q_after, qparams_after).to(DQ.dtype)
+                print(
+                    "SQNR for DQ(Q(DQ)) vs DQ", SQNR(DQ, DQ_after_after)
+                )
+                breakpoint()
                 print(
                     "SQNR for weight (can be low)", SQNR(W, DQ.cuda())

diff --git a/model.py b/model.py
index a47d00d..9f40c19 100644
--- a/model.py
+++ b/model.py
@@ -78,10 +78,16 @@ def update(self, input_pos, k_val, v_val):
         # input_pos: [S], k_val: [B, H, S, D]
         assert input_pos.shape[0] == k_val.shape[2]
+
         k_out = self.k_cache
         v_out = self.v_cache
         k_out[:, :, input_pos] = k_val
+        breakpoint()
         v_out[:, :, input_pos] = v_val
+        breakpoint()
+        # k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        # v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
+
         return k_out, v_out
@@ -174,7 +180,6 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         kv_size = self.n_local_heads * self.head_dim
         q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
-
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
diff --git a/quantize.py b/quantize.py
index a9b3f79..49222ab 100644
--- a/quantize.py
+++ b/quantize.py
@@ -11,6 +11,8 @@
 import torch.nn.functional as F
 from sentencepiece import SentencePieceProcessor
 
+from GPTQ import GenericGPTQRunner, InputRecorder
+
 try:
     from GPTQ import GenericGPTQRunner, InputRecorder
     from eval import get_task_dict, evaluate, lm_eval
@@ -286,6 +288,10 @@ def create_quantized_state_dict(
         pad_calibration_inputs,
     ) -> "StateDict":
         inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
+        self.mod=self.mod.to("cpu")
+        inputs=[x.cpu() if hasattr(x, "cpu") else x for x in inputs]
+        self.mod(*inputs)
+
         print("Tracing model for GPTQ")
         GPTQ_runner = GenericGPTQRunner(
             self.mod,
@@ -438,12 +444,12 @@ def convert_for_runtime(self, use_cuda):
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         from model import find_multiple
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
         self.quantize_func = lambda w, qparams: \
             group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
@@ -453,7 +459,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
             [torch.cat(x, dim=1) for x in zip(*qparams_list)]
         # skip unless padding=True or its correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding_allowed
         )
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
@@ -472,7 +478,7 @@ def make_names_and_values_dict_func(q, qparams):
 
     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod
 
 class WeightOnlyInt4Linear(torch.nn.Module):
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..80d6877
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,19 @@
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
+# echo "base"
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 1
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5
+
+# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
+# echo "quant good"
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+
+# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
+# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
+# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth
+
+# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5

From 7392a31d3ae0e7c727cd16814ae392bc49b958a3 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Thu, 28 Mar 2024 01:50:18 -0700
Subject: [PATCH 2/2] Update on "fixing GPTQ"

Summary:
Trying to fix the issue with the kv_cache update by moving tracing from the fx
tracer to a tensor subclass. So far this has less success than the fx tracer.
The fx tracer breaks because

    k_out[:, :, input_pos] = k_val

gets traced as

    new_var = torch.ops.aten.index_put_(k_out, [None, None, input_pos], k_val)

with new_var never being accessed afterward: new_var becomes the correct
MultiInput value, but it is then lost. The subclass, on the other hand, tries
to use the func "" which does not appear to mutate k_out, so the attempt to
turn it into a MultiTensor fails.

Test Plan:
sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
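
Not part of the patch: a minimal sketch of the dispatch behavior described in the
summary, using a made-up LoggingTensor subclass and toy cache shapes (both are
assumptions for illustration, not code from this repo). It shows that the
slice-assignment cache update does reach __torch_function__, but as a setitem-style
call whose return value the caller throws away, so a subclass handler can only affect
the cache by mutating it in place:

    import torch

    class LoggingTensor(torch.Tensor):
        # toy subclass: report which func the cache update dispatches to
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print("dispatched:", func)
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

    k_cache = torch.zeros(1, 2, 8, 4)                           # stand-in for self.k_cache
    k_val = torch.ones(1, 2, 3, 4).as_subclass(LoggingTensor)   # stand-in for k_val
    input_pos = torch.tensor([0, 1, 2])

    # the same kind of update KVCache.update does; the handler sees the call,
    # but whatever it returns is discarded by the assignment statement
    k_cache[:, :, input_pos] = k_val
    print(k_cache[0, 0, :4, 0])  # tensor([1., 1., 1., 0.])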