From 0cc5c95efc213ffc8ed64ad34376fa4e4a3114ac Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 7 Oct 2024 13:05:20 -0400 Subject: [PATCH] Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation --- bitsandbytes/autograd/_functions.py | 83 ++++---- bitsandbytes/functional.py | 165 +++++++-------- bitsandbytes/nn/modules.py | 7 +- bitsandbytes/research/autograd/_functions.py | 33 +-- csrc/ops.cu | 103 +++++++++- csrc/ops.cuh | 2 +- csrc/pythonInterface.cpp | 19 ++ tests/test_autograd.py | 12 +- tests/test_functional.py | 199 +++++++++++-------- tests/test_modules.py | 2 +- 10 files changed, 382 insertions(+), 243 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d33dd1bc5..01845a131 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,11 +245,11 @@ class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None force_no_igemmlt: bool = False CB = None - CxB = None + CxB = None # TODO: Deprecate/remove SB = None SCB = None - CxBt = None + CxBt = None # TODO: Deprecate/remove SBt = None CBt = None @@ -263,7 +263,7 @@ class MatmulLtState: has_fp16_weights = True memory_efficient_backward = False use_pool = False - formatB = F.get_special_format_str() + formatB = "row" # F.get_special_format_str() TODO: Deprecate/remove def reset_grads(self): self.CB = None @@ -283,9 +283,6 @@ def tile_indices(self): class MatMul8bitLt(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - @staticmethod def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt @@ -306,7 +303,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # 3. Matmul # 4. Mixed-precision decomposition matmul # 5. Save state - formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() @@ -328,14 +324,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - else: - if state.CxB is None and using_igemmlt: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - if not state.has_fp16_weights and state.CxB is None and using_igemmlt: - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None # 2. Quantize B @@ -345,19 +334,17 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): if is_transposed: B = B.contiguous() - if (state.is_training and not has_grad) or state.CxB is None: + if (state.is_training and not has_grad) or state.CB is None: state.reset_grads() + + # quantize... 
( - CB, + state.CB, state.CBt, state.SCB, state.SCBt, coo_tensorB, ) = F.double_quant(B.to(torch.float16)) - if using_igemmlt: - state.CxB, state.SB = F.transform(CB, to_order=formatB) - else: - state.CB = CB else: has_grad = False @@ -372,17 +359,18 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) # else: # state.idx = outlier_idx - if state.CxB is not None: - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - else: - outliers = state.CB[:, state.idx.long()].clone() + + # if state.CxB is not None: + # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + # else: + outliers = state.CB[:, state.idx.long()].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] - shapeB = state.SB[0] if state.SB else B.shape + shapeB = state.CB.shape if len(input_shape) == 3: output_shape = (input_shape[0], input_shape[1], shapeB[0]) @@ -391,13 +379,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # 3. Matmul if using_igemmlt: - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + out32, Sout32 = F.igemmlt(CA, state.CB) + if bias is None or bias.dtype == torch.float16: # we apply the fused bias here output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) output = output.to(A.dtype) else: # apply bias separately + # TODO: Fused bias for fp32/bf16? output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) output = output.to(A.dtype).add_(bias) @@ -417,7 +406,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # 5. Save state ctx.state = state - ctx.formatB = formatB ctx.grad_shape = input_shape ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype @@ -437,10 +425,10 @@ def backward(ctx, grad_output): if ctx.is_empty: bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states - formatB = ctx.formatB state = ctx.state grad_A = grad_B = grad_bias = None @@ -454,33 +442,39 @@ def backward(ctx, grad_output): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: - CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + # CxAt, SAt = F.transform(CAt, formatB, transpose=True) + # C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) + # gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + gradB32, SgradB32 = F.igemmlt( + Cgradt.t(), CAt.t() + ) # issue here in test_linear_serialization w/ has fp16 weights grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: if state.CBt is not None: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + # C32grad, Sgrad = F.transform(Cgrad, "col32") + # if 
state.CxBt is None: + # state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + # gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - elif state.CxB is not None: - CB = ( - undo_layout(state.CxB, state.tile_indices) - .to(ctx.dtype_A) - .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - ) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + # elif state.CxB is not None: + # CB = ( + # undo_layout(state.CxB, state.tile_indices) + # .to(ctx.dtype_A) + # .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + # ) + # grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception("State must contain either CBt or CB or CxB matrix for backward") + raise Exception("State must contain either CBt or CB matrix for backward") return grad_A, grad_B, None, grad_bias, None @@ -564,6 +558,7 @@ def matmul_4bit( bias=None, ): assert quant_state is not None + if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 34b3c0293..7f07778ef 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,9 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import ctypes as ct -from functools import reduce # Required in Python 3 import itertools -import operator +from math import prod from typing import Any, Dict, Optional, Tuple import numpy as np @@ -16,12 +15,6 @@ from .cextension import lib - -# math.prod not compatible with python < 3.8 -def prod(iterable): - return reduce(operator.mul, iterable, 1) - - name2qmap = {} if lib and lib.compiled_with_cuda: @@ -421,15 +414,9 @@ def create_quantile_map(A, total_bits=8): return q +# TODO: Deprecate def get_special_format_str(): - if not torch.cuda.is_available(): - return "col_turing" - major, _minor = torch.cuda.get_device_capability() - if major <= 7: - return "col_turing" - if major == 8: - return "col_ampere" - return "col_turing" + return "row" def is_on_gpu(tensors): @@ -2302,84 +2289,68 @@ def batched_igemm( return out -def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - shapeA = SA[0] - shapeB = SB[0] - dimsA = len(shapeA) - dimsB = len(shapeB) - assert dimsB == 2, "Only two dimensional matrices are supported for argument B" - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - - rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) - - if dimsA == 2 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") - elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") - - assert dimsB != 3, "len(B.shape)==3 not supported" +def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): + # + # To use the IMMA tensor core kernels without special Turing/Ampere layouts, + # cublasLt has some rules, namely: A must be transposed, B must not be transposed. + # The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major. + # This will typically be used with row-major tensors to efficiently + # calculate the linear layer with `C = B @ A.T` without any transformations. + # We will swap A and B in the API invocation, so that we get `C = A @ B.T`. + # + # Quick explanation: + # With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`. + # To get row-major output, `C.T = (A @ B.T).T = B @ A.T`. + # + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + dimsA = A.ndim + dimsB = B.ndim + assert A.device.type == "cuda" assert B.device.type == "cuda" assert A.dtype == torch.int8 assert B.dtype == torch.int8 + assert dimsA == 2, "Only two dimensional matrices are supported for argument B" + assert dimsB in [2, 3], "Only two or three dimensional matrices are supported for argument A" + assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}" + + shapeC = (*shapeB[:-1], shapeA[0]) + Sout = (shapeC, "row") + + if out is None: + out = torch.empty(shapeC, device=A.device, dtype=dtype) + assert out.dtype == dtype - assert SA[1] == "col32" - assert SB[1] in ["col_turing", "col_ampere"] - assert Sout[1] == "col32" - assert ( - shapeA[-1] == shapeB[-1] - ), f"Matmullt only supports A @ B^T. 
Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" - formatB = SB[1] + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ B={shapeA}" + prev_device = A.device torch.cuda.set_device(A.device) - ptr = CUBLAS_Context.get_instance().get_context(A.device) + ctx = CUBLAS_Context.get_instance().get_context(A.device) ptrA = get_ptr(A) ptrB = get_ptr(B) ptrC = get_ptr(out) - - k = shapeA[-1] - lda = ct.c_int32(m * 32) - if formatB == "col_turing": - # turing: tiles with rows filled up to multiple of 8 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) - else: - # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - - ldc = ct.c_int32(m * 32) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - - has_error = 0 ptrRowScale = get_ptr(None) + m, n, k, lda, ldb, ldc = map(ct.c_int32, (m, n, k, lda, ldb, ldc)) + is_on_gpu([A, B, out]) - if formatB == "col_turing": - if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - else: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif formatB == "col_ampere": - if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - else: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + + if dtype == torch.int32: + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + else: + has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") + raise NotImplementedError("igemmlt not implemented!") if has_error: print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") @@ -2392,6 +2363,26 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): assert A.dtype == torch.int32 + + compute_dtype = torch.float32 + + A_calc = A.view(-1, A.shape[-1]).to(compute_dtype) + row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) + col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) + + # TODO support out != None + + out = A_calc * (row_stats * col_stats) * 6.200124e-5 # .to(torch.float16) + + if bias is not None: + # assert bias.dtype == torch.float16 + out.add_(bias) + + return out.to(torch.float16) + + +def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): + assert A.dtype == torch.int32 if bias is not None: assert bias.dtype == torch.float16 out_shape = quant_state[0] @@ -2553,6 +2544,21 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + # TODO: Optimize/write CUDA kernel for this. Currently vectorwise_quant will recalculate row/col stats. 
+ # TODO: Support threshold + + # if out_col is None: + # out_col = torch.zeros(A.shape, device=A.device, dtype=torch.int8) + # if out_row is None: + # out_row = torch.zeros(A.shape, device=A.device, dtype=torch.int8) + + out_col, Scol = vectorwise_quant(A, dim=0) + out_row, Srow = vectorwise_quant(A, dim=1) + + return out_row, out_col, Srow.flatten().float(), Scol.flatten().float(), None # coo_tensor + + +def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -2949,6 +2955,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): def extract_outliers(A, SA, idx): + # TODO: Implement for row-major shapeA = SA[0] formatA = SA[1] assert formatA in ["col_turing", "col_ampere"] diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 6c78494aa..1e5a334ee 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1009,11 +1009,8 @@ def forward(self, x: torch.Tensor): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) if not self.state.has_fp16_weights: - if self.state.CB is not None and self.state.CxB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + if self.state.CB is not None: + self.weight.data = self.state.CB return out diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index b194b8777..5f8b2c437 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -204,7 +204,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # 3. Matmul # 4. Mixed-precision decomposition matmul # 5. Save state - formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() @@ -227,14 +226,11 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 state.subB = B[:, idx].t().contiguous() state.idx = idx else: - if state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + if state.SB is None: + state.SB = (state.CB.shape, "row") else: - # print('A shape', A.shape) - if not state.has_fp16_weights and state.CxB is None: - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + if not state.has_fp16_weights and state.SB is None: + state.SB = (state.CB.shape, "row") subA = None # 2. 
Quantize B @@ -245,16 +241,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 if is_transposed: B = B.contiguous() - if (state.is_training and not has_grad) or state.CxB is None: + if (state.is_training and not has_grad) or state.SB is None: state.reset_grads() ( - CB, + state.CB, state.CBt, state.SCB, state.SCBt, coo_tensorB, ) = F.double_quant(B.to(torch.float16)) - state.CxB, state.SB = F.transform(CB, to_order=formatB) + state.SB = (state.CB.shape, "row") else: has_grad = False @@ -269,7 +265,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) # else: # state.idx = outlier_idx - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + outliers = state.CB[:, state.idx.long()].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 @@ -283,8 +280,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + out32, Sout32 = F.igemmlt(CA, state.CB) # we apply the fused bias here if bias is None or bias.dtype == torch.float16: @@ -301,7 +297,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # 5. Save state ctx.state = state - ctx.formatB = formatB ctx.grad_shape = input_shape ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype @@ -324,7 +319,6 @@ def backward(ctx, grad_output): req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states - formatB = ctx.formatB state = ctx.state grad_A = grad_B = grad_bias = None @@ -345,12 +339,7 @@ def backward(ctx, grad_output): if req_gradA: if state.CBt is not None: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) - # print('back B shape', state.CxBt.shape) - # print('back grad shape', C32grad.shape) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: diff --git a/csrc/ops.cu b/csrc/ops.cu index 7ca854baf..1f259d67f 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -422,6 +422,101 @@ template void trans #endif } +template int igemmlt( + cublasLtHandle_t ltHandle, + int m, int n, int k, + const int8_t * A, + const int8_t * B, + void * C, + float * row_scale, + int lda, int ldb, int ldc +) { + + // Calculate C = A^T @ B, in col-major layout. + // + // Use the IMMA kernels requires: + // * A must be transposed and B must be non-transposed. + // * All leading dimensions must be multiples of 4. + // * Dimensions m and k must be multiples of 4. + // * All pointers must be 4-byte aligned; 16-byte alignment preferred. + // + + + int has_error = 0; + + // this is the default + cublasLtOrder_t col_major = CUBLASLT_ORDER_COL; + + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t aDesc, bDesc, cDesc; + cublasOperation_t opT = CUBLAS_OP_T; + + cudaDataType_t outType = DTYPE_OUT == 32 ? 
CUDA_R_32I : CUDA_R_8I; + cudaDataType_t scaleType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_32F; + + cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); + + // Default layout order is col major + + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); + + if (DTYPE_OUT == 32) { + int alpha = 1, beta = 0; + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int32_t*)C, cDesc, + (int32_t*)C, cDesc, + NULL, NULL, 0, 0 + )); + } else { + if (!SCALE_ROWS) { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, 0 + )); + } else { + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + float beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute( + matmulDesc, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointerMode, + sizeof(alphaVec) + )); + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, + row_scale, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, 0 + )); + } + } + + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc)); + has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + + if(has_error == 1) + printf("error detected"); + + return has_error; +} + template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { #ifdef NO_CUBLASLT @@ -729,8 +824,8 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { - int num_blocks = (m+3)/4; - kgemm_4bit_inference_naive<<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + int num_blocks = (m+7)/8; + kgemm_4bit_inference_naive<<< num_blocks, 256, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -772,6 +867,10 @@ template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, 
int lda, int ldb, int ldc); +template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index b0ecc4622..ab0185242 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -175,7 +175,7 @@ void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, i void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount); - +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index f0ee84c29..09b9b62a9 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -175,6 +175,15 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } + int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } @@ -316,6 +325,16 @@ extern "C" Context *get_context(){ return new Context(); } ContextCusparse *get_cusparse(){ return new ContextCusparse(); } + int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int 
cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } //{ (cublasLtHandle_t)context->m_handle; return 0; } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 9da665a2d..89dce644b 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -198,10 +198,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool assert (idx == 0).sum().item() < n * 0.02 -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +# @pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +# @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) # [64,0] +@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [48], ids=id_formatter("dim4")) @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) @pytest.mark.parametrize( "funcs", diff --git a/tests/test_functional.py b/tests/test_functional.py index 1cca04511..522af516c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -570,10 +570,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans torch.testing.assert_close(A, out2) -@pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) +# @pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) +# @pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [512], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): @@ -585,20 +589,17 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B = 
torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - A2, SA = F.transform(A, "col32") - B2, SB = F.transform(B, "col_turing") - C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_close(C1, C3.float()) + C2, SC = F.igemmlt(A, B) + torch.testing.assert_close(C1, C2.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) - C1 = torch.matmul(A.float(), B.float()) + # B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) + # C1 = torch.matmul(A.float(), B.float()) - B2t, SBt = F.transform(B, "col_turing", transpose=True) - C2, SC = F.igemmlt(A2, B2t, SA, SBt) - C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_close(C1, C3.float()) + # B2t, SBt = F.transform(B, "col", transpose=True) + # C2, SC = F.igemmlt(A2, B2t, SA, SBt) #B2t, A2, SBt, SA) + # C3, S = F.nvidia_transform(C2, "row", state=SC) + # torch.testing.assert_close(C1, C2.float()) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @@ -622,17 +623,15 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) - C32A, SA = F.transform(CA, "col32") - CxB, SB = F.transform(CB, to_order=formatB) - out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB) - output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt) + out1_32, Sout1_32 = F.igemmlt(CA, CB) + output = F.mm_dequant(out1_32, Sout1_32, statsA, statsB) # print('') # print(output.flatten()[:10]) # print(C1.flatten()[:10]) # print(C2.flatten()[:10]) - # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) # transpose # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) @@ -801,17 +800,18 @@ def test_bench_8bit_training(batch, seq, model, hidden): # print(t8) -@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +# @pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) -def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): - inner = torch.randint(1, 128, size=(1,)).item() +def test_dequant_mm(dim1, dim4, dims, has_bias): + inner = 128 # torch.randint(1, 128, size=(1,)).item() bias = None if has_bias: bias = torch.randn(dim4, device="cuda", dtype=torch.float16) - formatB = F.get_special_format_str() + for i in range(1): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") @@ -822,12 +822,9 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - A2, SA = F.nvidia_transform(A1, "col32") - B2, SB = F.nvidia_transform(B1, formatB) - C2, SC = F.igemmlt(A2, B2, SA, SB) + C2, SC = 
F.igemmlt(A1, B1) - C3, S = F.nvidia_transform(C2, "row", state=SC) - C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) if has_bias: C4 += bias @@ -840,8 +837,9 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" - C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) - # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) + C5 = F.mm_dequant(C2, SC, maxA, maxB, bias=bias) + C5 /= std + torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) @@ -890,8 +888,10 @@ def test_colrow_absmax(dim1, dim2, dims): assert nnz_block_ptr2 is None -@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) def test_double_quant(dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() @@ -926,9 +926,12 @@ def test_double_quant(dim1, dim2): ( pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") for (dim1, dim4, inner) in zip( - get_test_dims(1, 4 * 1024, n=4), - get_test_dims(1, 4 * 1024, n=4), - get_test_dims(1, 4 * 1024, n=4), + (1, 8, 2048, 4096), + (2, 128, 2048, 4096), + (4, 256, 512, 4096), + # get_test_dims(1, 4 * 1024, n=4), + # get_test_dims(1, 4 * 1024, n=4), + # get_test_dims(1, 4 * 1024, n=4), ) ), ) @@ -949,17 +952,11 @@ def test_integrated_igemmlt(dim1, dim4, inner): torch.testing.assert_close(C1a, A1, rtol=0, atol=1) torch.testing.assert_close(C2a, B1, rtol=0, atol=1) - A2, SA = F.nvidia_transform(C1a, "col32") - B2, SB = F.nvidia_transform(C2a, "col_turing") - outC32, SC = F.igemmlt(A2, B2, SA, SB) - out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + out2, SC = F.igemmlt(A1, B1) - A2, SA = F.nvidia_transform(A1, "col32") - B2, SB = F.nvidia_transform(B1, "col_turing") - C2, SC = F.igemmlt(A2, B2, SA, SB) + C2, SC = F.igemmlt(A1, B1) - C3, S = F.nvidia_transform(C2, "row", state=SC) - out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) err1 = torch.abs(out1 - out2).mean().item() err2 = torch.abs(out1 - out3).mean().item() @@ -999,7 +996,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: @@ -1012,7 +1009,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) - outC32, SC = F.igemmlt(A2, B2, SA, SB) + outC32, SC = F.igemmlt(A2, B2) out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") @@ -1080,7 +1077,7 @@ def test_row_scale_bench(dim1, 
dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1089,7 +1086,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2, SA, SB) + outC32, SC = F.igemmlt(A2, B2) torch.cuda.synchronize() print("vector-wise", time.time() - t0) @@ -1132,10 +1129,11 @@ def test_overflow(): a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) - Ca, Sa = F.nvidia_transform(a, "col32") - Cb, Sb = F.nvidia_transform(b, formatB) + # Ca, Sa = F.nvidia_transform(a, "col32") + # Cb, Sb = F.nvidia_transform(b, formatB) - c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) + # c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) + c = F.igemmlt(a, b) c2 = torch.matmul(a.float(), b.float().t()) @@ -1238,25 +1236,21 @@ def test_spmm_bench(): @pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 - formatB = "col_turing" + # formatB = "col_turing" for i in range(k): A = torch.randn(dim1, dim2).cuda().half() w1 = torch.randn(dim1, dim2).cuda().half() out1 = torch.matmul(A, w1.t()) Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - CTw1, Sw1 = F.transform(Cw1, formatB) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - C32A, SA = F.transform(CA, "col32") - out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) - C32A, SA = F.transform(CA, "col32") - out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) assert coo_tensor is not None @@ -1484,7 +1478,12 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), - [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")], + [ + # pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"), + # pytest.param(2, 128, 6656, 4 * 6656, id="batch=2, seq=128, model=6656, hidden=26k"), + # pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"), + pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k") + ], ) @pytest.mark.benchmark def test_bench_matmul(batch, seq, model, hidden): @@ -1557,19 +1556,45 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # bnb.matmul(A, B) - # torch.cuda.synchronize() - # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul(A, B) + torch.cuda.synchronize() + print( + f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): 
+ bnb.matmul(A, B, threshold=6.0) + torch.cuda.synchronize() + print( + f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + out32, Sout32 = F.igemmlt(CA, CB) + torch.cuda.synchronize() + print( + f"no overhead igemmlt [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + # C32A, SA = F.transform(CA, "col32") + + # CxB, SB = F.transform(CB, to_order=formatB) # torch.cuda.synchronize() # t0 = time.time() # for i in range(iters): - # bnb.matmul(A, B, threshold=6.0) + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # torch.cuda.synchronize() - # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) # C32A, SA = F.transform(CA, "col32") @@ -1610,21 +1635,25 @@ def test_bench_matmul(batch, seq, model, hidden): # torch.cuda.synchronize() # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # linear8bit(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # linear8bit(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + linear8bit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linear8bit(A) + torch.cuda.synchronize() + print( + f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) - # linearMixedBit(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # linearMixedBit(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + linearMixedBit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linearMixedBit(A) + torch.cuda.synchronize() + print( + f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) # linear8bit_train(A) # torch.cuda.synchronize() @@ -2144,7 +2173,7 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False) err1 = sum(errs1) / len(errs1) / math.sqrt(dim) err2 = sum(errs2) / len(errs2) / math.sqrt(dim) err3 = sum(errs3) / len(errs3) / math.sqrt(dim) diff --git a/tests/test_modules.py b/tests/test_modules.py index 2176f1d48..d5c968395 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -310,7 +310,7 @@ def test_linear8bitlt_inference(threshold): b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) if i == 1: - assert l1.state.CxB is not None + assert l1.state.CB is not None def test_linear8bitlt_accumulated_gradient():
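For reference, the refactored int8 path can be exercised end to end roughly as follows. This is a minimal sketch assuming a CUDA device and the signatures touched in this patch (double_quant, igemmlt, mm_dequant); the shapes are illustrative only, not taken from the tests.

import torch
import bitsandbytes.functional as F

# Illustrative shapes: 16 tokens, 64 input features, 32 output features.
A = torch.randn(16, 64, device="cuda", dtype=torch.float16)  # activations
B = torch.randn(32, 64, device="cuda", dtype=torch.float16)  # weights (out_features, in_features)

# Row-wise int8 quantization; double_quant now returns row-major int8 tensors.
CA, CAt, SCA, SCAt, _ = F.double_quant(A)
CB, CBt, SCB, SCBt, _ = F.double_quant(B)

# Old flow (removed by this patch):
#   C32A, SA = F.transform(CA, "col32")
#   CxB, SB  = F.transform(CB, to_order="col_turing")  # or "col_ampere"
#   out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# New flow: igemmlt takes the row-major int8 tensors directly and computes A @ B^T.
out32, Sout32 = F.igemmlt(CA, CB)            # int32 accumulator, shape (16, 32)
out = F.mm_dequant(out32, Sout32, SCA, SCB)  # dequantize back to fp16

print((out - A @ B.t()).abs().max())         # small quantization error expected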
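The new mm_dequant replaces the dedicated kernel with a plain PyTorch expression; the constant 6.200124e-5 is approximately 1/(127*127), undoing both int8 scalings at once. The same math restated as a standalone helper (hypothetical name, for illustration only):

import torch
from typing import Optional

def mm_dequant_reference(out32: torch.Tensor, row_stats: torch.Tensor,
                         col_stats: torch.Tensor,
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # out32:     (rows, cols) int32 result of igemmlt on row-quantized A and B
    # row_stats: per-row absmax of A, length rows
    # col_stats: per-row absmax of B, length cols
    out = out32.float() * row_stats.view(-1, 1).float() * col_stats.view(1, -1).float() * 6.200124e-5
    if bias is not None:
        out += bias.float()
    return out.half()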
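The operand swap in the new igemmlt wrapper relies only on the identity (A @ B^T)^T = B @ A^T: asking cublasLt for weights^T @ activations in column-major uses the same memory layout as activations @ weights^T in row-major, so no col32/col_turing/col_ampere transforms are needed. A quick numeric check of that identity:

import torch

m, n, k = 4, 3, 5
A = torch.randn(m, k)   # activations, row-major
B = torch.randn(n, k)   # weights, row-major
C = A @ B.t()           # (m, n), what the linear layer wants
assert torch.allclose(C.t().contiguous(), B @ A.t())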