diff --git a/.github/workflows/package.yml b/.github/workflows/package.yml index e5b235a..389192c 100644 --- a/.github/workflows/package.yml +++ b/.github/workflows/package.yml @@ -53,42 +53,42 @@ jobs: # - "3.13" steps: - uses: actions/checkout@v4 - + - name: Checkout bitsandbytes uses: actions/checkout@v4 with: repository: bitsandbytes-foundation/bitsandbytes path: bitsandbytes - + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip - + - name: Install build tools run: | sudo apt-get update sudo apt-get install -y build-essential cmake - + - name: Compile bitsandbytes with CPU backend run: | cd bitsandbytes cmake -DCOMPUTE_BACKEND=cpu -S . && make cd .. - + - name: Download build artifacts uses: actions/download-artifact@v4 with: name: dist path: dist/ - + - name: Install dependencies and built package run: | python -m pip install --upgrade pip pip install ./bitsandbytes pip install dist/*.whl - + - name: Test import works run: | python -c " @@ -98,7 +98,7 @@ jobs: print('✅ bitsandbytes_intel import successful') print('✅ All imports successful - no XPU operations tested, as for that we would need to configure the XPU runner..') " - + # - name: Test with pytest # run: pytest diff --git a/docker/Dockerfile b/docker/Dockerfile index 9803ff8..ec8e4b0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -35,14 +35,13 @@ RUN pip install --no-cache-dir \ ruff # Then install PyTorch-dependent packages with constraint to use existing torch -RUN pip install --no-cache-dir \ - --extra-index-url https://download.pytorch.org/whl/xpu \ - -C torch==2.6.0+xpu \ - transformers \ - accelerate \ - bitsandbytes +RUN pip install transformers accelerate bitsandbytes + +# Copy the bitsandbytes-intel repository into /workspace/src/bnb and install it. +COPY .. ${WORKSPACE}/src/bnb +RUN cd ${WORKSPACE}/src/bnb && pip install . 
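The "Test import works" step above can also be reproduced locally as a standalone smoke test. A minimal sketch, not part of this diff, assuming the bitsandbytes wheel and this package are both installed in the active environment (the op lookup mirrors the guard used by register_ops in ops.py):

import torch

import bitsandbytes  # noqa: F401 -- provides the bitsandbytes custom op namespace
import bitsandbytes_intel  # noqa: F401 -- the package built and installed above

# Same guard register_ops() relies on: the base op must exist before the XPU overrides can attach.
assert hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"), "bitsandbytes ops not found"
print("imports OK; XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())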
COPY --chmod=755 docker/entrypoint.sh /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] -CMD ["sleep", "infinity"] \ No newline at end of file +CMD ["sleep", "infinity"] diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 8979d69..5c33986 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -4,4 +4,4 @@ set -euo pipefail pip install --no-deps -e /workspace/src/bnb pip install --no-deps -e /workspace/src/bnb_intel -exec "$@" \ No newline at end of file +exec "$@" diff --git a/src/bitsandbytes_intel/cpu_xpu_common.py b/src/bitsandbytes_intel/cpu_xpu_common.py deleted file mode 100644 index 13d20ee..0000000 --- a/src/bitsandbytes_intel/cpu_xpu_common.py +++ /dev/null @@ -1,600 +0,0 @@ -import subprocess -from typing import Optional -import warnings - -import torch -import torch.nn.functional as F - -from bitsandbytes.functional import ( - QuantState, - create_dynamic_map, - get_4bit_type, -) - -try: - # to support Intel CPU/GPU (XPU) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None - ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu()) -except BaseException: - ipex_cpu = None - ipex_xpu = None - ipex_cpu_only = None - - -gxx_available = False -try: - subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output - gxx_available = True -except BaseException: - warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.") - - -Tensor = torch.Tensor - - -def _torch_version_prereq(major, minor): - ver_major = int(torch.__version__.split(".")[0]) - ver_minor = int(torch.__version__.split(".")[1]) - return ver_major * 32 + ver_minor >= major * 32 + minor - - -def _ipex_cpu_version_prereq(major, minor): - if ipex_cpu is not None: - ver_major = ipex_cpu.__version__.split(".")[0] - ver_minor = ipex_cpu.__version__.split(".")[1] - return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor - return False - - -def _ipex_xpu_version_prereq(major, minor): - if ipex_xpu is not None: - ver_major = ipex_xpu.__version__.split(".")[0] - ver_minor = ipex_xpu.__version__.split(".")[1] - return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor - return False - - -def _maybe_torch_compile(func): - # torch.compile requires g++ and pytorch >= 2.0 - if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu: - options = {} - # fx_graph_cache requires pytorch >= 2.2 - if _torch_version_prereq(2, 2): - options.update({"fx_graph_cache": True}) - return torch.compile(func, dynamic=True, options=options) - return func - - -def reverse_4bit_compress_format(weight): - out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device) - out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device) - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out - - -@_maybe_torch_compile -def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - """ - Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. - If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in - the original tensor and they are kept in COO format: (rows, cols, values) - If threshold == 0.0, there are no outliers. - Args: - A The tensor to be analyzed and quantized. - col_stats Absolute max values of each column of A. If it is not None, use the values directly. - Otherwise, find the values. 
- row_stats Absolute max values of each row of A. If it is not None, use the values directly. - Otherwise, find the values. - out_col Output buffer for the result quantized per column if it is not None - out_row Output buffer for the result quantized per row if it is not None - threshold The threshold for finding outliers if it is > 0.0. Otherwise it has no effect. - Return: - A tuple of output quantized per row, output quantized per column, absolute max values of - each row of A, absolute max values of each column of A, outliers in COO format - """ - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d" - rows = A.shape[0] - A = A.reshape(rows, cols) - - def get_row_col_stats(A): - row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row - col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col - return row_stats, col_stats - - def quant_to_int8(A, stats): - return torch.clamp(torch.round(A * (127.0 / stats)), -128, 127).to(torch.int8) - - if threshold == 0.0: - if row_stats is None or col_stats is None: - row_stats, col_stats = get_row_col_stats(A) - outlier_cols = None - else: - outlier_indices = torch.abs(A) >= threshold # find outliers - outlier_cols = torch.argwhere(outlier_indices.any(dim=0)).view(-1) - outlier_values = A[outlier_indices].clone() - - # outlier_indices = torch.abs(A) >= threshold # find outliers - # outlier_coord = outlier_indices.nonzero() # get outlier coordinates - # outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor - # outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor - # outlier_values = A[outlier_indices] # outlier values for COO sparse tensor - # coo_tensor = COOSparseTensor( - # A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values - # ) - if row_stats is None or col_stats is None: - A[outlier_indices] = 0 # zero out outliers - row_stats, col_stats = get_row_col_stats(A) - - quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) - quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) - - if outlier_cols is not None: - A[outlier_indices] = outlier_values # restore outliers for later use - - if rows > 1: - # zero out outlier columns for all rows - quant_by_row[:, outlier_cols] = 0 - - if out_row is not None: - out_row.copy_(quant_by_row) - else: - out_row = quant_by_row - if out_col is not None: - out_col.copy_(quant_by_col) - else: - out_col = quant_by_col - # Return float stats to align with CUDA impl - return out_row, out_col, row_stats.float(), col_stats.float(), outlier_cols - - -def int8_linear_matmul_impl( - A: torch.Tensor, - B: torch.Tensor, - out: Optional[torch.Tensor] = None, - dtype=torch.int32, -) -> torch.Tensor: - """ - Do GEMMM computation. Data type: int8 * int8 -> int32. 
- Args: - A Activation of linear, data type is int8 - B Weight of linear, data type is int8 - out Specified output tensor if it is not None - dtype Data type of output - Return: - A tuple of GEMM result in dtype and Sout - """ - - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - if out is not None: - assert out.dtype == dtype - - dimsA = A.ndim - dimsB = B.ndim - shapeA = A.shape - shapeB = B.shape - assert dimsA in [2, 3], "Only two or three dimensional matrices are supported for argument A" - assert dimsB == 2, "Only two dimensional matrices are supported for argument B" - - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - n = shapeB[0] - k = shapeA[-1] - assert shapeA[-1] == shapeB[-1], f"Shapes of A and B do not match, got {shapeA} and {shapeB}" - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, n), device=A.device, dtype=A.dtype) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype) - - A_reshaped = A.reshape(m, k) - - # torch._int_mm is available on CPU since torch 2.4, XPU since torch 2.6 - if ( - A.device.type == "cpu" and _torch_version_prereq(2, 4) - # or (A.device.type == "xpu" and _torch_version_prereq(2, 6) - ): - C = torch._int_mm(A_reshaped, B.T).to(dtype) - else: - C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype) - if C.ndim != dimsA: - assert dimsA == 3 - shapeOut = (shapeA[0], m // shapeA[0], C.shape[-1]) - C = C.reshape(shapeOut) - if out is not None: - out.copy_(C) - else: - out = C - - return out - - -@_maybe_torch_compile -def int8_mm_dequant_impl( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - out: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - compute_dtype=torch.float32, - output_dtype=torch.float32, -) -> torch.Tensor: - """ - Dequant and add bias - out = A_int32 * (abs_max_A * abs_max_B) / 127 * 127 + bias - Args: - A The output of int8 gemm, whose dtype is int32 - row_stats Absolute max value of each row of input (A) of gemm - col_stats Absolute max value of each row of weight (B) of gemm - out Output buffer - bias Bias of linear - compute_dtype Data type for computation - output_dtype Data type for output - Return: - The result - """ - assert A.dtype == torch.int32 - out_shape = A.shape - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) - - if compute_dtype not in [torch.float32, torch.bfloat16]: - warnings.warn( - f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead" - ) - compute_dtype = torch.bfloat16 - A_reshaped = A.reshape(out_shape).to(compute_dtype) - row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) - col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) - out = A_reshaped * row_stats * col_stats / (127 * 127) - if bias is not None: - out = out + bias.to(compute_dtype) - out = out.to(output_dtype) - return out - - -NF4_QUANT_TABLE = [ - -1.0 - 1e-2, # 0b0000 - -0.8480964004993439, # 0b0001 - -0.6106329262256622, # 0b0010 - -0.4599952697753906, # 0b0011 - -0.33967943489551544, # 0b0100 - -0.23460740596055984, # 0b0101 - -0.13791173323988914, # 0b0110 - -0.045525018125772476, # 0b0111 - 0.03979014977812767, # 0b1000 - 0.1202552504837513, # 0b1001 - 0.2035212516784668, # 0b1010 - 0.2920137718319893, # 0b1011 - 0.3893125355243683, # 0b1100 - 0.5016634166240692, # 0b1101 - 
0.6427869200706482, # 0b1110 - 0.8614784181118011, # 0b1111 -] - - -FP4_QUANT_TABLE = { - 0 - 1e-2: 0, # 0b0000 - 0.00260417: 1, # 0b0001 - 0.0859375: 6, # 0b0110 - 0.20833333: 7, # 0b0111 - 0.29166667: 4, # 0b0100 - 0.4166667: 5, # 0b0101 - 0.583333: 2, # 0b0010 - 0.8333333: 3, # 0b0011 -} - -INT8_QUANT_TABLE = create_dynamic_map().tolist() - - -def quantize_4bit_impl( - A: Tensor, - absmax: Tensor = None, - out: Tensor = None, - blocksize=64, - compress_statistics=False, - quant_type="nf4", - quant_storage=torch.uint8, -) -> Tensor: - """ - Quantize tensor A in blocks of 4-bit values. - - Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor (8-bit). - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now - quant_storage: torch.dtype - We can use bytes to convert storage type. - - Returns - ------- - torch.Tensor: - The 8-bit tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype, int): - The quantization state to undo the quantization. - """ - if quant_type not in ["nf4", "fp4", "int8"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.") - if quant_type == "fp4": - warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.") - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - n = A.numel() - input_shape = A.shape - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - - if absmax is None: - absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) - - if out is None: - out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) - - rem = n % blocksize - has_rem = rem > 0 - - # Scale tensor to [-1, 1] - A_reshaped = A.reshape(n) - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - # map [-1, 1] to nf4/fp4 - out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device) - if quant_type == "nf4": - for i in range(len(NF4_QUANT_TABLE)): - out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i - elif quant_type == "fp4": - sign = scaled_A < 0 - abs_scaled_A = torch.abs(scaled_A) - for key, val in FP4_QUANT_TABLE.items(): - out_uint8[abs_scaled_A > key] = val - out_uint8 += sign.to(torch.uint8) * 8 - elif quant_type == "int8": - map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device) - diff = torch.abs(scaled_A.unsqueeze(-1) - map) - out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) - - if quant_type == "int8": - out = out_uint8 - code = torch.Tensor(INT8_QUANT_TABLE).to(A.device) - else: - if out_uint8.size(-1) % 2: - out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) - out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2]) - code = get_4bit_type(quant_type, device=A.device) - - if compress_statistics: - offset = absmax.mean() - absmax -= offset - 
qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8") - del absmax - state = QuantState( - absmax=qabsmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - offset=offset, - state2=state2, - ) - else: - state = QuantState( - absmax=absmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - ) - - if quant_storage != torch.uint8: - bytes_value = out.cpu().numpy().tobytes() - out = torch.frombuffer(bytes_value, dtype=quant_storage).to(A.device) - - return out.reshape(-1, 1), state - - -def dequant_8bit(A, offset, quant_state): - assert A.dtype == torch.uint8 - absmax = quant_state.code[A.reshape(-1).int()] - blocks = absmax.shape[-1] // 256 - res = absmax.shape[-1] % 256 - if res != 0: - absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0) - absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1) - absmax = absmax[: blocks * 256 + res] - absmax = absmax.reshape(A.shape) - absmax += offset - return absmax - - -# Compile will fail in torch.frombuffer -# @_maybe_torch_compile -def dequantize_4bit_impl( - A: Tensor, - quant_state=None, - absmax: Tensor = None, - out: Tensor = None, - blocksize: int = 64, - quant_type="nf4", -) -> Tensor: - """ - Dequantizes 4-bit blockwise quantized values. - Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. - - Parameters - ---------- - A : torch.Tensor - The input 8-bit tensor (packed 4-bit values). - quant_state : QuantState - object with quantisation stats, incl. absmax values, original tensor shape and original dtype. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - Dequantized output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - Returns - ------- - torch.Tensor: - Dequantized tensor. - """ - transpose = True if A.shape[0] == 1 else False - A = A.reshape(-1) - device = A.device - if A.dtype != torch.uint8: - bytes_value = A.cpu().numpy().tobytes() - A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device) - - if quant_state is None: - assert absmax is not None and out is not None - - quant_state = QuantState( - absmax=absmax, - shape=out.shape, - dtype=out.dtype, - blocksize=blocksize, - quant_type=quant_type, - ) - - else: - absmax = quant_state.absmax - - if quant_type not in ["nf4", "fp4"]: - raise NotImplementedError( - f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." 
- ) - - if quant_state.nested: - absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2) - - if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False): - ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2) - A = reverse_4bit_compress_format(ipex_weight) - quant_state.ipex = False - - # Map nf4 to [-1, 1] - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue - quant_state.code = quant_state.code.to(quant_state.dtype) - out_dq = quant_state.code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - if out is None: - out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) - out_reshaped = out.reshape(-1) - out_reshaped[: n - rem] = ( - out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1) - ).reshape(-1) - out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1] - else: - out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype) - - # take transpose here because weight is transposed (again) for computation - if transpose: - out = out.t() - - return out - - -# Do not need torch.compile here as we are calling torch/ipex kernel -def gemm_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - out: Optional[torch.Tensor] = None, - transposed_A=False, - transposed_B=False, - state: QuantState = None, -) -> torch.Tensor: - """ - Matrix-matrix multiplication with 4-bit quantization. - - Parameters - ---------- - A : torch.Tensor - The first input tensor. Usually the activation tensor. - B : torch.Tensor - The second input tensor. Usually the weight tensor. - out : torch.Tensor - The output tensor. - transposed_A : bool - Whether A is transposed - transposed_B : bool - Whether B is transposed - state : QuantState - Contains quantization info, such as blocksize and dtype - - Returns - ------- - torch.Tensor: - GEMM output tensor. 
- """ - if getattr(state, "ipex", False): - # compute_dtype: 1 indicates fp16, 2 indicates bf16 - compute_dtype = 2 if A.dtype == torch.bfloat16 else 1 - output = torch.ops.torch_ipex.woq_linear( - A, - B, - "nf4", - state.shape, - state.new_scales, - state.new_zeros, - None, - None, - state.blocksize, - compute_dtype, - 1, - state.compensation, - ) - else: - dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize) - output = torch.matmul(A, dqB.to(A.dtype)) - if out is not None: - out.copy_(output) - else: - out = output - return out diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index 8ffe13b..8e353fb 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -1,31 +1,67 @@ from collections.abc import Sequence import math +from typing import Optional import torch -from .cpu_xpu_common import int8_linear_matmul_impl +from .xpu import ( + _ipex_xpu_version_prereq, + dequantize_4bit_impl, + dequantize_blockwise_ipex_impl, + dequantize_blockwise_torch_impl, + gemv_4bit_impl, + int8_linear_matmul_impl, + int8_mm_dequant_impl, + ipex_cpu, + ipex_xpu, + optimizer_update_8bit_blockwise, + quantize_4bit_impl, + quantize_blockwise_impl, +) print("Loading ops module") -def register_ops(): +def register_xpu_ops(): print("Registering XPU implementations") - # Check if the operator exists - if not hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"): - raise RuntimeError("bitsandbytes::int8_linear_matmul not found! Make sure bitsandbytes is installed") - - @torch.library.impl("bitsandbytes::int8_linear_matmul", "XPU") + # Register the int8_linear_matmul implementation + @torch.library.impl("bitsandbytes::int8_linear_matmul", "xpu") def int8_linear_matmul_xpu(A: torch.Tensor, B: torch.Tensor): - print("int8_linear_matmul_xpu called with tensors of shape:", A.shape, B.shape) return int8_linear_matmul_impl(A, B) - @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "XPU") + @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "xpu") def int8_linear_matmul_xpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - print("int8_linear_matmul_xpu_out called with tensors of shape:", A.shape, B.shape) - return int8_linear_matmul_impl(A, B, out) + return int8_linear_matmul_impl(A, B) + + # Register the int8_mm_dequant implementation + @torch.library.impl("bitsandbytes::int8_mm_dequant", "xpu") + def int8_mm_dequant_xpu( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return int8_mm_dequant_impl(A, row_stats, col_stats, dtype, bias) + + # Register the quantize_4bit implementation + @torch.library.impl("bitsandbytes::quantize_4bit", "xpu") + def quantize_4bit_xpu( + A: torch.Tensor, + blocksize: int, + quant_type: str, + quant_storage: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_4bit_impl( + A, + blocksize, + quant_type, + quant_storage, + ) - @torch.library.impl("bitsandbytes::dequantize_4bit.out", "XPU") + # Register the dequantize_4bit implementation + @torch.library.impl("bitsandbytes::dequantize_4bit", "xpu") def dequantize_4bit_xpu( A: torch.Tensor, absmax: torch.Tensor, @@ -33,19 +69,94 @@ def dequantize_4bit_xpu( quant_type: str, shape: Sequence[int], dtype: torch.dtype, - out: torch.Tensor, ) -> torch.Tensor: - # TODO - # if quant_type == "nf4" and getattr(quant_state, "ipex", False): - # output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, 
blocksize).t() - # else: - # output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) + return dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) - # return output - raise NotImplementedError + # Register the quantize_blockwise implementation + @torch.library.impl("bitsandbytes::quantize_blockwise", "xpu") + def quantize_blockwise_xpu( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_blockwise_impl(A, code, blocksize) + + # Register the dequantize_blockwise implementation + dequantize_blockwise_impl = ( + dequantize_blockwise_ipex_impl if _ipex_xpu_version_prereq(2, 7) else dequantize_blockwise_torch_impl + ) + + @torch.library.impl("bitsandbytes::dequantize_blockwise", "xpu") + def dequantize_blockwise_xpu( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + ) -> torch.Tensor: + return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype) + + # Register the gemv_4bit implementation + @torch.library.impl("bitsandbytes::gemv_4bit", "xpu") + def gemv_4bit_xpu( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + ) -> torch.Tensor: + return gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize) + + # Register the optimizer_update_8bit_blockwise implementation + @torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "xpu") + def optimizer_update_8bit_blockwise_xpu( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + optimizer_update_8bit_blockwise( + optimizer_name, + g, + p, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + skip_zeros, + ) print("Successfully registered XPU implementation") + +def register_hpu_ops(): print("Registering HPU implementations") @torch.library.impl("bitsandbytes::dequantize_4bit", "HPU") @@ -77,4 +188,61 @@ def quantize_4bit_hpu( print("Successfully registered HPU implementations") +def register_ipex_ops(): + print("Registering IPEX implementations") + + # Register the dequantize_nf4_ipex implementation + if ipex_cpu: + from bitsandbytes.utils import _reverse_4bit_compress_format + + @torch.library.impl("bitsandbytes::dequantize_nf4_ipex", "cpu") + def dequantize_nf4_ipex_cpu( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2) + A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1) + return torch.ops.bitsandbytes.dequantize_4bit.default( + A, + absmax, + blocksize, + "nf4", + shape, + dtype, + ) + + if ipex_xpu: + + @torch.library.impl("bitsandbytes::dequantize_nf4_ipex", "xpu") + def dequantize_nf4_ipex_xpu( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, 
blocksize).t().to(dtype) + + print("Successfully registered IPEX implementation") + + +def register_ops(): + # Check if the operator exists + if not hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"): + raise RuntimeError("bitsandbytes::int8_linear_matmul not found! Make sure bitsandbytes is installed") + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + register_xpu_ops() + # TODO: Need to check HPU + elif hasattr(torch.backends, "hpu") and torch.backends.hpu.is_available(): + register_hpu_ops() + if ipex_cpu or ipex_xpu: + register_ipex_ops() + + print("ops module loaded") diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py new file mode 100644 index 0000000..9820fb2 --- /dev/null +++ b/src/bitsandbytes_intel/xpu.py @@ -0,0 +1,472 @@ +from collections.abc import Sequence +import subprocess +from typing import Optional +import warnings + +import torch +import torch.nn.functional as F + +try: + # to support Intel CPU/GPU (XPU) backend + import intel_extension_for_pytorch as ipex + + ipex_cpu = ipex if ipex._C._has_cpu() else None + ipex_xpu = ipex if ipex._C._has_xpu() else None + ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu()) +except BaseException: + ipex_cpu = None + ipex_xpu = None + ipex_cpu_only = None + + +gxx_available = False +try: + subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output + gxx_available = True +except BaseException: + warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.") + + +Tensor = torch.Tensor + + +def _torch_version_prereq(major, minor): + ver_major = int(torch.__version__.split(".")[0]) + ver_minor = int(torch.__version__.split(".")[1]) + return ver_major * 32 + ver_minor >= major * 32 + minor + + +def _ipex_xpu_version_prereq(major, minor): + if ipex_xpu is not None: + ver_major = ipex_xpu.__version__.split(".")[0] + ver_minor = ipex_xpu.__version__.split(".")[1] + return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor + return False + + +str2optimizer8bit_blockwise = {} +if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7): + str2optimizer8bit_blockwise = { + "adam": ( + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16, + ), + } + + +def _maybe_torch_compile(func): + # torch.compile requires g++ and pytorch >= 2.0 + if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu: + options = {} + # fx_graph_cache requires pytorch >= 2.2 + if _torch_version_prereq(2, 2): + options.update({"fx_graph_cache": True}) + return torch.compile(func, dynamic=True, options=options) + return func + + +def transform( + A: torch.Tensor, + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[tuple[torch.Size, str]] = None, +): + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For CPU/XPU, it returns the original tensor if transpose=False. 
+ Otherwise, it returns the transpose of A + """ + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state + + +# Applied from cpu int8_linear_matmul op +def int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor): + return torch._int_mm( + A.reshape(-1, A.shape[-1]), + B.t(), + ).reshape(*A.shape[:-1], B.shape[0]) + + +@_maybe_torch_compile +def int8_mm_dequant_impl( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + A_calc = A.view(-1, A.shape[-1]) + row_stats = row_stats.reshape(-1).unsqueeze(-1) + col_stats = col_stats.reshape(-1).unsqueeze(0) + + out = A_calc * (row_stats * col_stats) * 6.200124e-05 + if bias is not None: + out += bias + + return out.to(dtype or torch.float16) + + +_NF4_QUANT_TABLE = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=torch.float32, + device="xpu", +) +_FP4_QUANT_TABLE = torch.tensor( + [ + 0.0000, + 0.0052, + 0.6667, + 1.0000, + 0.3333, + 0.5000, + 0.1667, + 0.2500, + 0.0000, + -0.0052, + -0.6667, + -1.0000, + -0.3333, + -0.5000, + -0.1667, + -0.2500, + ], + dtype=torch.float32, + device="xpu", +) +CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} + + +def quantize_blockwise_impl( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize tensor A in blocks of 8-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to int8. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + code : torch.Tensor + The quantization code. + blocksize : int + The blocksize used in quantization. + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + torch.Tensor: + The absmax. 
+ """ + n = A.numel() + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + rem = n % blocksize + has_rem = rem > 0 + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) + + return out_uint8, absmax + + +def dequantize_blockwise_torch_impl( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, +) -> torch.Tensor: + assert A.dtype == torch.uint8 + out = code[A.reshape(-1).int()] + blocks = out.shape[-1] // blocksize + res = out.shape[-1] % blocksize + if res != 0: + out = F.pad(out, (0, blocksize - res), mode="constant", value=0) + out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) + out = out[: blocks * blocksize + res] + out = out.reshape(A.shape) + + return out + + +# Currently only works for XPU +def dequantize_blockwise_ipex_impl( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, +) -> torch.Tensor: + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + + out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) + # void cdequantize_blockwise_fp32( + # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) + if dtype == torch.float16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) + elif dtype == torch.bfloat16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) + elif dtype == torch.float32: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + + return out + + +# Copied from cpu quantize_4bit op +def quantize_4bit_impl( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}") + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + A_reshaped = A.reshape(n) + A_com_reshaped = A_reshaped[: n - rem].reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 
1)), -1, 1) + scaled = scaled.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled = torch.cat([scaled, scaled_rem], dim=0) + # Quantize with the lookup table + quant_table = CODE[quant_type].to(scaled.device) + quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - quant_table), dim=-1, keepdim=True).to(torch.uint8) + + # Pack two quantized values per byte + packed = quantized[::2] << 4 | quantized[1::2] + + if quant_storage != torch.uint8: + packed = packed.squeeze().view(quant_storage).unsqueeze(1) + + return packed, absmax.float() + + +# Copied from cpu dequantize_4bit op +# Compile will fail in torch.frombuffer +# @_maybe_torch_compile +def dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + # Enable non uint8 dtype + device = A.device + if A.dtype != torch.uint8: + bytes_value = A.cpu().numpy().tobytes() + A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device) + + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(out_dq.device).to(dtype) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + out = out.reshape(-1, *shape[1:]).to(dtype) + + return out + + +# Copied from cpu gemv_4bit op +def gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> torch.Tensor: + # Applied from dequantize_4bit + B = B.view(-1, 1) + upper = (B >> 4).to(torch.int64) + lower = (B & 0x0F).to(torch.int64) + blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) + B_dq = code[blocks] * absmax[:, None] + B_dq = B_dq.reshape(-1, *shapeB[1:]).to(A.dtype) + + # User called gemv with B.t(), so we need to transpose it back. 
+    # if B.shape[0] == 1:
+    #     B_dq = B_dq.t()
+
+    return torch.nn.functional.linear(
+        A,
+        B_dq,
+        bias=None,
+    )
+
+
+def optimizer_update_8bit_blockwise(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    optim_func = None
+    if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7):
+        raise RuntimeError("Please install intel_extension_for_pytorch >= 2.7 for 8bit optimizer backend on XPU device.")
+
+    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
+    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
+    elif (
+        g.dtype == torch.bfloat16
+        and state1.dtype == torch.uint8
+        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
+    ):
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
+    optim_func(
+        p,
+        g,
+        state1,
+        state2,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        step,
+        lr,
+        qmap1,
+        qmap2,
+        absmax1,
+        absmax2,
+        weight_decay,
+        gnorm_scale,
+        skip_zeros,
+        g.numel(),
+    )
+
+
+def optimizer_update_32bit(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    beta1: float,
+    eps: float,
+    step: int,
+    lr: float,
+    state2: Optional[torch.Tensor] = None,
+    beta2: float = 0.0,
+    beta3: float = 0.0,
+    alpha: float = 0.0,
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    unorm_vec: Optional[torch.Tensor] = None,
+    max_unorm: float = 0.0,
+    skip_zeros=False,
+) -> None:
+    raise NotImplementedError
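How the new int8 helpers compose is easiest to see end to end. A minimal sketch, not part of this diff: shapes are illustrative, device="xpu" assumes an Intel GPU with a PyTorch build where torch._int_mm works on XPU (torch >= 2.6, per the comment removed from cpu_xpu_common.py), and the constant 6.200124e-05 inside int8_mm_dequant_impl is 1 / (127 * 127), which undoes the two symmetric int8 scalings:

import torch

from bitsandbytes_intel.xpu import int8_linear_matmul_impl, int8_mm_dequant_impl

A = torch.randn(32, 64, device="xpu")   # activations
B = torch.randn(48, 64, device="xpu")   # weight, laid out as (out_features, in_features)

row_stats = A.abs().amax(dim=1)  # per-row absmax of A (float32, as the impl checks)
col_stats = B.abs().amax(dim=1)  # per-output-channel absmax of B
A_i8 = torch.round(A * (127.0 / row_stats[:, None])).to(torch.int8)
B_i8 = torch.round(B * (127.0 / col_stats[:, None])).to(torch.int8)

C_i32 = int8_linear_matmul_impl(A_i8, B_i8)            # int8 x int8 -> int32 via torch._int_mm
C = int8_mm_dequant_impl(C_i32, row_stats, col_stats)  # rescaled output, float16 unless dtype is passed
print("mean abs error vs. float matmul:", (C.float() - A @ B.t()).abs().mean().item())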
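Similarly, a hedged round-trip sketch for the blockwise nf4 path. It assumes an XPU device is present (the module builds _NF4_QUANT_TABLE and _FP4_QUANT_TABLE on the "xpu" device at import time, so the import itself needs one), and the blocksize and shape are illustrative:

import torch

from bitsandbytes_intel.xpu import dequantize_4bit_impl, quantize_4bit_impl

W = torch.randn(64, 64, dtype=torch.bfloat16, device="xpu")
packed, absmax = quantize_4bit_impl(W, blocksize=64, quant_type="nf4", quant_storage=torch.uint8)
W_dq = dequantize_4bit_impl(packed, absmax, blocksize=64, quant_type="nf4", shape=W.shape, dtype=W.dtype)

# Blockwise 4-bit quantization is lossy: each 64-element block keeps one absmax scale
# plus a 4-bit index into the nf4 table, so expect a small, non-zero reconstruction error.
print("mean abs error:", (W - W_dq).abs().mean().item())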