10 changes: 5 additions & 5 deletions bitsandbytes/backends/cuda/ops.py
@@ -8,7 +8,7 @@
 from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
 
 from ..._ops import register_kernel
-from ...cextension import HIP_ENVIRONMENT, lib
+from ...cextension import ROCM_WARP_SIZE_64, lib
 
 
 @register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -211,7 +211,7 @@ def _get_col_absmax(
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
 
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -269,7 +269,7 @@ def _(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -303,7 +303,7 @@ def _dequantize_blockwise_impl(
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -385,7 +385,7 @@ def _dequantize_4bit_impl(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
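Across these four hunks the rule is identical: warp-64 ROCm devices drop 64 from the set of supported blocksizes, because the 64-element kernel variants (2 values per thread, i.e. 32 threads) are only compiled when the wavefront is 32 wide (see the `#if BNB_WARP_SIZE == 32` guards in csrc/kernels.hip below). A minimal sketch of the shared rule, as a hypothetical `_check_blocksize` helper that is not part of this PR:

    import torch

    # Hypothetical helper mirroring the repeated torch._check pattern above;
    # shown only to make the rule explicit, not part of this PR.
    _BASE_BLOCKSIZES = [4096, 2048, 1024, 512, 256, 128]

    def _check_blocksize(blocksize: int, rocm_warp_size_64: bool) -> None:
        # A 64-element block at 2 values per thread launches only 32 threads,
        # less than one 64-wide wavefront, so 64 is excluded on those devices.
        supported = _BASE_BLOCKSIZES if rocm_warp_size_64 else [*_BASE_BLOCKSIZES, 64]
        torch._check(
            blocksize in supported,
            lambda: f"blocksize must be one of {supported}, got {blocksize}",
        )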
9 changes: 8 additions & 1 deletion bitsandbytes/cextension.py
@@ -9,7 +9,13 @@
 import torch
 
 from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
-from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
+from bitsandbytes.cuda_specs import (
+    CUDASpecs,
+    get_cuda_specs,
+    get_cuda_version_tuple,
+    get_rocm_gpu_arch,
+    get_rocm_warpsize,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -298,6 +304,7 @@ def get_native_library() -> BNBNativeLibrary:
 
 
 ROCM_GPU_ARCH = get_rocm_gpu_arch()
+ROCM_WARP_SIZE_64 = get_rocm_warpsize() == 64
 
 HIP_ENVIRONMENT = False
 BNB_BACKEND = "CPU"
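`ROCM_WARP_SIZE_64` is evaluated once at import time; downstream modules branch on it instead of on `HIP_ENVIRONMENT`, so warp-32 ROCm GPUs keep the smaller 64-element blocksize. An illustrative consumer (both names exist after this change):

    from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_WARP_SIZE_64

    # HIP_ENVIRONMENT still answers "is this a ROCm build?", while
    # ROCM_WARP_SIZE_64 answers the narrower "is the wavefront 64 wide?".
    default_blocksize = 128 if ROCM_WARP_SIZE_64 else 64
    print(HIP_ENVIRONMENT, ROCM_WARP_SIZE_64, default_blocksize)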
26 changes: 26 additions & 0 deletions bitsandbytes/cuda_specs.py
@@ -100,3 +100,29 @@ def get_rocm_gpu_arch() -> str:
                 """,
             )
         return "unknown"
+
+
+def get_rocm_warpsize() -> int:
+    """Get the ROCm wavefront (warp) size."""
+    logger = logging.getLogger(__name__)
+    try:
+        if torch.version.hip:
+            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
+            match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
+            if match:
+                return int(match.group(1))
+            else:
+                # Default to 64 to be safe.
+                return 64
+        else:
+            # NVIDIA GPUs always use a warp size of 32.
+            return 32
+    except Exception as e:
+        logger.error(f"Could not detect ROCm warp size: {e}. Defaulting to 64. (Some 4-bit functions may not work!)")
+        if torch.cuda.is_available():
+            logger.warning(
+                """
+ROCm warp size detection failed despite ROCm being available.
+                """,
+            )
+        return 64
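`get_rocm_warpsize` shells out to rocminfo and parses the "Wavefront Size" field. A quick sanity check of the regex against a captured line (the sample string is illustrative, not taken from this PR):

    import re

    # Illustrative rocminfo excerpt; real output varies by ROCm version, but
    # the "NN(0xNN)" shape is exactly what the regex targets.
    sample = "  Wavefront Size:             64(0x40)\n"

    match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", sample)
    assert match is not None and int(match.group(1)) == 64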
14 changes: 7 additions & 7 deletions bitsandbytes/functional.py
@@ -15,7 +15,7 @@
 
 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
 
-from .cextension import HIP_ENVIRONMENT, lib
+from .cextension import ROCM_WARP_SIZE_64, lib
 
 name2qmap = {}
 
@@ -806,7 +806,7 @@ def quantize_fp4(
     quant_storage=torch.uint8,
 ):
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
 
 
@@ -819,7 +819,7 @@ def quantize_nf4(
     quant_storage=torch.uint8,
 ):
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
 
 
@@ -857,7 +857,7 @@ def quantize_4bit(
     """
 
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
 
     input_shape = A.shape
 
@@ -912,7 +912,7 @@ def dequantize_fp4(
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
 
 
@@ -924,7 +924,7 @@ def dequantize_nf4(
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
 
 
@@ -964,7 +964,7 @@ def dequantize_4bit(
     """
 
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
 
     if quant_state is None:
         assert absmax is not None and out is not None
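With the rename in place, `blocksize=None` still resolves to a platform-dependent default: 64 on CUDA and on warp-32 ROCm devices, 128 on warp-64 ROCm devices. A hedged usage sketch, assuming a working CUDA or ROCm build of bitsandbytes:

    import torch
    from bitsandbytes.functional import dequantize_nf4, quantize_nf4

    A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

    # blocksize=None lets the branches above pick the platform default.
    packed, quant_state = quantize_nf4(A)
    A_hat = dequantize_nf4(packed, quant_state)

    assert A_hat.shape == A.shape
    assert quant_state.blocksize in (64, 128)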
4 changes: 2 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 
 import bitsandbytes as bnb
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.functional import QuantState
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
@@ -221,7 +221,7 @@ def __new__(
         data = torch.empty(0)
 
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
 
     self = torch.Tensor._make_subclass(cls, data, requires_grad)
     self.blocksize = blocksize
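Since `Params4bit.__new__` resolves `blocksize=None` the same way, modules built on top of it inherit the platform default with no call-site changes. A sketch, assuming the stock `Linear4bit` module (which leaves `blocksize` unset):

    import bitsandbytes as bnb

    # Linear4bit does not pass a blocksize, so Params4bit.__new__ above
    # resolves None to the platform default at construction time.
    layer = bnb.nn.Linear4bit(4096, 4096, quant_type="nf4")
    print(layer.weight.blocksize)  # 64 on CUDA / warp-32 ROCm, 128 on warp-64 ROCm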
8 changes: 6 additions & 2 deletions csrc/common_hip.cuh
@@ -1,7 +1,11 @@
 #pragma once
 
-#define BNB_WARP_SIZE warpSize
+#ifdef __GFX9__
+#define BNB_WARP_SIZE 64
+#else
+#define BNB_WARP_SIZE 32
+#endif
 
 // These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs
-#define BNB_MAX_THREADS_PER_SM 2048
+#define BNB_MAX_THREADS_PER_CU 2048
 #define BNB_BF16_AVAILABLE true
69 changes: 41 additions & 28 deletions csrc/kernels.hip
@@ -1881,7 +1881,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
 // rowStats [rows]
 // out [rows, cols]
 template<typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
+__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
 __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
 
   // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
@@ -1945,7 +1945,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
 }
 
 template<typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
+__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
 __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
   using BlockReduceT = hipcub::BlockReduce<float, THREADS>;
 
@@ -2057,11 +2057,6 @@ __global__ void kdequant_mm_int32_fp16(
 #define DENORM 1.0f/127.0f
 #define MAX_SPARSE_COUNT 32
 #define SMEM_SIZE 8*256
-#if defined(__GFX9__)
-#define WARP_SIZE 64
-#else
-#define WARP_SIZE 32
-#endif
 template <typename T, int SPMM_ITEMS, int BITS>
 __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
 {
@@ -2082,9 +2077,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
   const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
   const int local_row_idx = rowidx[offset];
 
-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int warp_idx = threadIdx.x % WARP_SIZE;
-  const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
+  const int warp_id = threadIdx.x / BNB_WARP_SIZE;
+  const int warp_idx = threadIdx.x % BNB_WARP_SIZE;
+  const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS;
   const int num_items = BITS == 8 ? 8 : 8;
   int idx_col_B = warp_offset;
   int local_idx_col_B_offset = 0;
@@ -2104,7 +2099,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
   }
 
   // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
-  // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
+  // we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart
   // we have a total of 128 bytes for the bank with a bank size of 4 bytes
   // added 3 bytes = 6 values between warps should reduce bank conflicts
   __shared__ half smem_dequant_stats[SMEM_SIZE];
@@ -2657,15 +2652,15 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 {
 
   // per threadblock:
-  // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
+  // load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
   // 4 warps -> 4 loads per iter
-  // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
-  typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];
+  // 1xBNB_WARP_SIZE * BNB_WARP_SIZEx4 -> 1x4 outputs per thread block
+  typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];
 
-  const int warp_idx = threadIdx.x / WARP_SIZE;
-  const int warp_lane = threadIdx.x % WARP_SIZE;
-  const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
+  const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
+  const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
+  const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx;
   const int offset_B = ldb * row_B;
   const int num_values_8bit = num_values_4bit/2;
   float local_C = 0.0f;
@@ -2684,7 +2679,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 
   // A: [1, K]
   // B: [M, K]
-  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
+  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit)
   {
     const int inner_idx_halved = inner_idx/2;
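In kgemm_4bit_inference_naive each warp owns one row of B and its lanes stride across K, so the wavefront width sets both the rows handled per block and the per-lane stride. A small Python model of the index arithmetic above (THREADS=128 and num_values_4bit=32 are assumptions for illustration, not values fixed by this diff):

    # Model of the warp/lane indexing in kgemm_4bit_inference_naive; it mirrors
    # the BNB_WARP_SIZE arithmetic above and is not executable GPU code.
    THREADS = 128
    NUM_VALUES_4BIT = 32

    def row_and_stride(thread_idx: int, block_idx: int, warp_size: int):
        warp_idx = thread_idx // warp_size
        warp_lane = thread_idx % warp_size
        row_b = (THREADS // warp_size) * block_idx + warp_idx
        k_start = warp_lane * NUM_VALUES_4BIT
        k_stride = warp_size * NUM_VALUES_4BIT
        return row_b, k_start, k_stride

    # With warp_size=32 a block covers 4 rows of B; with warp_size=64 it
    # covers 2, and each lane's stride across K doubles.
    print(row_and_stride(70, 0, 32))  # (2, 192, 1024)
    print(row_and_stride(70, 0, 64))  # (1, 192, 2048)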

@@ -2996,23 +2991,29 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#endif
 
 MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#endif
 
 MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#endif
 
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
@@ -3021,23 +3022,29 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#endif
 
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#endif
 
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#endif
 
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
@@ -3046,23 +3053,29 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#endif
 
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#endif
 
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#endif
 
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);