Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation
matthewdouglas committed Oct 7, 2024
1 parent 0500c31 commit 0cc5c95
Showing 10 changed files with 382 additions and 243 deletions.
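The core of the change is visible in the autograd hunks below: MatMul8bitLt no longer reorders its operands into the col32/col_turing/col_ampere tile layouts before the int8 GEMM. F.igemmlt is now called on the row-major quantized tensors directly, and the CxB/CxBt buffers that held the tiled copies are marked for deprecation. Below is a minimal before/after sketch of that call pattern; it uses only the bitsandbytes.functional calls that appear in this diff, while the toy shapes, CUDA device, and dtypes are illustrative assumptions rather than code from the commit.

# Sketch only: call-pattern change in the int8 matmul path.
# Function names and signatures follow the diff below; everything else is assumed.
import torch
import bitsandbytes.functional as F

A = torch.randn(4, 64, dtype=torch.float16, device="cuda")   # activations
B = torch.randn(32, 64, dtype=torch.float16, device="cuda")  # weight (out_features, in_features)

# Row-wise int8 quantization with per-row absmax statistics.
CA, CAt, SCA, SCAt, _ = F.double_quant(A)
CB, CBt, SCB, SCBt, _ = F.double_quant(B)

# Old path (removed by this commit): reorder both operands into GPU-specific
# tile layouts before the GEMM.
#   C32A, SA = F.transform(CA, "col32")
#   CxB, SB = F.transform(CB, to_order=F.get_special_format_str())
#   out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)

# New path: igemmlt consumes the row-major int8 tensors directly.
out32, Sout32 = F.igemmlt(CA, CB)
output = F.mm_dequant(out32, Sout32, SCA, SCB).to(A.dtype)  # int32 accumulator -> fp16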
83 changes: 39 additions & 44 deletions bitsandbytes/autograd/_functions.py
@@ -245,11 +245,11 @@ class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False
CB = None
CxB = None
CxB = None # TODO: Deprecate/remove
SB = None
SCB = None

CxBt = None
CxBt = None # TODO: Deprecate/remove
SBt = None
CBt = None

@@ -263,7 +263,7 @@ class MatmulLtState:
has_fp16_weights = True
memory_efficient_backward = False
use_pool = False
formatB = F.get_special_format_str()
formatB = "row" # F.get_special_format_str() TODO: Deprecate/remove

def reset_grads(self):
self.CB = None
@@ -283,9 +283,6 @@ def tile_indices(self):


class MatMul8bitLt(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
@@ -306,7 +303,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
@@ -328,14 +324,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None and using_igemmlt:
# B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None

# 2. Quantize B
@@ -345,19 +334,17 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
if is_transposed:
B = B.contiguous()

if (state.is_training and not has_grad) or state.CxB is None:
if (state.is_training and not has_grad) or state.CB is None:
state.reset_grads()

# quantize...
(
CB,
state.CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B.to(torch.float16))
if using_igemmlt:
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
state.CB = CB
else:
has_grad = False

@@ -372,17 +359,18 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
if state.CxB is not None:
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
else:
outliers = state.CB[:, state.idx.long()].clone()

# if state.CxB is not None:
# outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
# else:
outliers = state.CB[:, state.idx.long()].clone()

state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]

shapeB = state.SB[0] if state.SB else B.shape
shapeB = state.CB.shape

if len(input_shape) == 3:
output_shape = (input_shape[0], input_shape[1], shapeB[0])
@@ -391,13 +379,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):

# 3. Matmul
if using_igemmlt:
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
out32, Sout32 = F.igemmlt(CA, state.CB)

if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
# TODO: Fused bias for fp32/bf16?
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)

@@ -417,7 +406,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
# 5. Save state
ctx.state = state

ctx.formatB = formatB
ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

@@ -437,10 +425,10 @@ def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
grad_A = grad_B = grad_bias = None

@@ -454,33 +442,39 @@ def backward(ctx, grad_output):

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
# CxAt, SAt = F.transform(CAt, formatB, transpose=True)
# C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
# gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
# grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
gradB32, SgradB32 = F.igemmlt(
Cgradt.t(), CAt.t()
) # issue here in test_linear_serialization w/ has fp16 weights
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

if req_gradA:
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
# C32grad, Sgrad = F.transform(Cgrad, "col32")
# if state.CxBt is None:
# state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
# gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
# grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t())
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CxB is not None:
CB = (
undo_layout(state.CxB, state.tile_indices)
.to(ctx.dtype_A)
.mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
)
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
# elif state.CxB is not None:
# CB = (
# undo_layout(state.CxB, state.tile_indices)
# .to(ctx.dtype_A)
# .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
# )
# grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception("State must contain either CBt or CB or CxB matrix for backward")
raise Exception("State must contain either CBt or CB matrix for backward")

return grad_A, grad_B, None, grad_bias, None

@@ -564,6 +558,7 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None

if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
(Remainder of this file's diff and the other 9 changed files are not shown.)
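The mixed-precision decomposition branch gets the same row-major treatment: instead of F.extract_outliers on the tiled CxB/SB pair, outlier columns are now sliced straight out of state.CB and dequantized with state.SCB, as in the forward hunk above. A rough sketch of that decomposition follows; the outlier index set and the final fp16 accumulation are assumptions for illustration (in the real path, idx is derived from the threshold-based coo_tensorA, and the fp16 product is added in the step labeled "# 4. Mixed-precision decomposition matmul", which lies outside the hunks shown here).

# Sketch only: int8 GEMM plus fp16 outlier columns, assuming state.threshold > 0.
# CB/SCB stand in for state.CB/state.SCB; idx is a made-up outlier set.
import torch
import bitsandbytes.functional as F

A = torch.randn(4, 64, dtype=torch.float16, device="cuda")
B = torch.randn(32, 64, dtype=torch.float16, device="cuda")

CA, CAt, SCA, SCAt, _ = F.double_quant(A)
CB, CBt, SCB, SCBt, _ = F.double_quant(B)

idx = torch.tensor([3, 17], device=A.device)  # columns treated as outliers

# Row-major outlier extraction (replaces F.extract_outliers on CxB/SB).
outliers = CB[:, idx.long()].clone()
subB = (outliers * SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
subA = A[:, idx.long()]
CA[:, idx.long()] = 0   # exclude outlier columns from the int8 GEMM
CAt[:, idx.long()] = 0

# Int8 GEMM on the remaining columns, then the fp16 outlier product on top.
out32, Sout32 = F.igemmlt(CA, CB)
output = F.mm_dequant(out32, Sout32, SCA, SCB).to(A.dtype)
output += torch.matmul(subA, subB)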