Reduce BM/BN/BK from 64/32/64 to 48/12/48
We initially chose 64/32/64 to make batch processing faster on an NVIDIA
A100, but when the code was run on a $300 AMD Radeon RX 6800 it destroyed
performance, slowing LLaVA image processing down by 10x, possibly because
this card has a small L1 cache or very few registers per thread. This
change is meant as a stopgap: it accepts a modest slowdown in batched
operations on more expensive graphics cards in exchange for making cheaper
graphics cards usable at all. Until there is a better way to determine the
optimal tile sizes at runtime, anyone who is seriously interested in
performance should consider cuBLAS/hipBLAS.

This change also fixes the tinyblas header so builds work on all systems.
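
For context, a back-of-the-envelope sketch of what the smaller tiles buy on
memory-constrained GPUs: the kernels in this diff request
sizeof(float) * (BM * BK + BK * BN) bytes of dynamic shared memory per block,
so shrinking the tiles cuts that footprint by more than half. The numbers
below follow only from the constants changed in this commit; per-SM shared
memory and register limits vary by GPU and are not asserted here.

#include <cstdio>

// Dynamic shared memory requested per block by the tinyblas kernels,
// mirroring the launch expression sizeof(float) * (BM * BK + BK * BN).
static size_t tile_smem_bytes(size_t BM, size_t BN, size_t BK) {
    return sizeof(float) * (BM * BK + BK * BN);
}

int main() {
    // Old tiles 64/32/64: 4 * (4096 + 2048) = 24576 bytes (24 KiB) per block.
    std::printf("64/32/64: %zu bytes\n", tile_smem_bytes(64, 32, 64));
    // New tiles 48/12/48: 4 * (2304 + 576) = 11520 bytes (11.25 KiB) per block.
    std::printf("48/12/48: %zu bytes\n", tile_smem_bytes(48, 12, 48));
    return 0;
}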
jart committed Jan 4, 2024
1 parent 190f96f commit b7bf60d
Showing 3 changed files with 211 additions and 179 deletions.
42 changes: 26 additions & 16 deletions llama.cpp/ggml-cuda.cu
@@ -21,27 +21,26 @@
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#define CUBLAS_COMPUTE_16F TINYBLAS_COMPUTE_16F
#define CUBLAS_COMPUTE_32F TINYBLAS_COMPUTE_32F
#define CUBLAS_GEMM_DEFAULT TINYBLAS_GEMM_DEFAULT_TENSOR_OP
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP TINYBLAS_GEMM_DEFAULT_TENSOR_OP
#define CUBLAS_OP_N TINYBLAS_OP_N
#define CUBLAS_OP_T TINYBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS TINYBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_SUPPORTED TINYBLAS_STATUS_NOT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasGemmAlgo_t hipblasGemmAlgo_t
#define cublasOperation_t hipblasOperation_t
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasGemmAlgo_t tinyblasGemmAlgo_t
#define cublasOperation_t tinyblasOperation_t
#define cublasComputeType_t tinyblasComputeType_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasHandle_t hipblasHandle_t
#define cublasHandle_t tinyblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasStatus_t hipblasStatus_t
#define cublasStatus_t tinyblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
@@ -92,7 +91,19 @@

#elif defined(GGML_USE_TINYBLAS)
#include "tinyblas.cu"
#define cublasHandle_t cudaStream_t
#define CUBLAS_COMPUTE_16F TINYBLAS_COMPUTE_16F
#define CUBLAS_COMPUTE_32F TINYBLAS_COMPUTE_32F
#define CUBLAS_OP_N TINYBLAS_OP_N
#define CUBLAS_OP_T TINYBLAS_OP_T
#define CUBLAS_GEMM_DEFAULT TINYBLAS_GEMM_DEFAULT_TENSOR_OP
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP TINYBLAS_GEMM_DEFAULT_TENSOR_OP
#define CUBLAS_STATUS_SUCCESS TINYBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_SUPPORTED TINYBLAS_STATUS_NOT_SUPPORTED
#define cublasGemmAlgo_t tinyblasGemmAlgo_t
#define cublasOperation_t tinyblasOperation_t
#define cublasComputeType_t tinyblasComputeType_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasHandle_t tinyblasHandle_t
#define cublasStatus_t tinyblasStatus_t
#define cublasSgemm tinyblasSgemm
#define cublasGemmEx tinyblasGemmEx
#define cublasGemmBatchedEx tinyblasGemmBatchedEx
@@ -109,7 +120,6 @@
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
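
As an aside on how the GGML_USE_TINYBLAS hunk above works: it is a
preprocessor shim, so existing cuBLAS-style call sites in ggml-cuda.cu
compile unchanged while the names resolve to tinyblas. A minimal sketch of
the idea, assuming GGML_USE_TINYBLAS is defined and tinyblas.cu is compiled
in; the wrapper gemm_f32 and its arguments are illustrative, not code from
this repository.

static void gemm_f32(cudaStream_t stream, int m, int n, int k,
                     const float *A, int lda, const float *B, int ldb,
                     float *C, int ldc) {
    const float alpha = 1.0f, beta = 0.0f;
    // The macros above rewrite this cuBLAS-style call into
    // tinyblasSgemm(stream, TINYBLAS_OP_T, TINYBLAS_OP_N, ...), since
    // cublasSgemm, CUBLAS_OP_T, and cublasStatus_t are all #defined.
    cublasStatus_t st = cublasSgemm(stream, CUBLAS_OP_T, CUBLAS_OP_N,
                                    m, n, k, &alpha, A, lda, B, ldb,
                                    &beta, C, ldc);
    // tinyblas does not fall back for unsupported argument combinations,
    // so callers are expected to check for TINYBLAS_STATUS_NOT_SUPPORTED.
    (void)st;
}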
188 changes: 94 additions & 94 deletions llamafile/tinyblas.cu
@@ -18,12 +18,12 @@
#include "tinyblas.h"

#define READ(A, trans, ld, i, j) \
(((trans) == CUBLAS_OP_N) ? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
(((trans) == TINYBLAS_OP_N) ? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
#define READ16(A, trans, ld, i, j) __half2float(READ(A, trans, ld, i, j))

#define BM 64
#define BN 32
#define BK 64
#define BM 48
#define BN 12
#define BK 48
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
@@ -45,13 +45,13 @@ static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
// we copy into As from A
for (j = 0; j < BM && x + j < m; ++j) {
As[(j * BK) + i] =
READ(A, CUBLAS_OP_T, lda, x + j, blob + i);
READ(A, TINYBLAS_OP_T, lda, x + j, blob + i);
}
for (; j < BM; ++j) As[(j * BK) + i] = 0;
// we copy into Bs from B
for (j = 0; j < BN && y + j < n; ++j) {
Bs[(i * BN) + j] =
READ(B, CUBLAS_OP_N, ldb, blob + i, y + j);
READ(B, TINYBLAS_OP_N, ldb, blob + i, y + j);
}
for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
} else { // UNLIKELY
@@ -110,36 +110,36 @@ static __global__ void tinyblasS_entry(int m, int n, int k,
}
}

static bool check_args(cublasOperation_t transa, cublasOperation_t transb,
static bool check_args(tinyblasOperation_t transa, tinyblasOperation_t transb,
const void *pAlpha, cudaDataType_t Atype,
cudaDataType_t Btype, const void *pBeta,
cudaDataType_t Ctype, cublasComputeType_t computeType) {
return (transa == CUBLAS_OP_T &&
transb == CUBLAS_OP_N &&
cudaDataType_t Ctype, tinyblasComputeType_t computeType) {
return (transa == TINYBLAS_OP_T &&
transb == TINYBLAS_OP_N &&
Atype == CUDA_R_16F &&
Btype == CUDA_R_16F &&
(Ctype == CUDA_R_16F ||
Ctype == CUDA_R_32F) &&
((computeType == CUBLAS_COMPUTE_16F &&
((computeType == TINYBLAS_COMPUTE_16F &&
__half2float(*(half *)pAlpha) == 1.0f &&
__half2float(*(half *)pBeta) == 0.0f) ||
(computeType == CUBLAS_COMPUTE_32F &&
(computeType == TINYBLAS_COMPUTE_32F &&
*(float *)pAlpha == 1.0f &&
*(float *)pBeta == 0.0f)));
}

cublasStatus_t tinyblasSgemm(cudaStream_t stream,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc) {
if (transa != CUBLAS_OP_T || transb != CUBLAS_OP_N ||
tinyblasStatus_t tinyblasSgemm(tinyblasHandle_t stream,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc) {
if (transa != TINYBLAS_OP_T || transb != TINYBLAS_OP_N ||
*alpha != 1.0f || *beta != 0.0f) {
return CUBLAS_STATUS_NOT_SUPPORTED;
return TINYBLAS_STATUS_NOT_SUPPORTED;
}

dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
@@ -148,7 +148,7 @@ cublasStatus_t tinyblasSgemm(cudaStream_t stream,
tinyblasS_entry<<<maxblocks, maxthreads,
(sizeof(float) * (BM * BK + BK * BN)), stream>>>(
m, n, k, A, lda, B, ldb, C, ldc);
return CUBLAS_STATUS_SUCCESS;
return TINYBLAS_STATUS_SUCCESS;
}

static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
@@ -171,13 +171,13 @@ static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
// we copy into As from A
for (j = 0; j < BM && x + j < m; ++j) {
As[(j * BK) + i] =
READ16(A, CUBLAS_OP_T, lda, x + j, blob + i);
READ16(A, TINYBLAS_OP_T, lda, x + j, blob + i);
}
for (; j < BM; ++j) As[(j * BK) + i] = 0;
// we copy into Bs from B
for (j = 0; j < BN && y + j < n; ++j) {
Bs[(i * BN) + j] =
READ16(B, CUBLAS_OP_N, ldb, blob + i, y + j);
READ16(B, TINYBLAS_OP_N, ldb, blob + i, y + j);
}
for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
} else { // UNLIKELY
@@ -242,28 +242,28 @@ static __global__ void tinyblasGE_entry(int m, int n, int k, const half *A,
}
}

cublasStatus_t tinyblasGemmEx(cudaStream_t stream,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType_t Atype,
int lda,
const void *B,
cudaDataType_t Btype,
int ldb,
const void *beta,
void *C,
cudaDataType_t Ctype,
int ldc,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo) {
tinyblasStatus_t tinyblasGemmEx(tinyblasHandle_t stream,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType_t Atype,
int lda,
const void *B,
cudaDataType_t Btype,
int ldb,
const void *beta,
void *C,
cudaDataType_t Ctype,
int ldc,
tinyblasComputeType_t computeType,
tinyblasGemmAlgo_t algo) {
if (!check_args(transa, transb, alpha, Atype, Btype, beta, Ctype,
computeType)) {
return CUBLAS_STATUS_NOT_SUPPORTED;
return TINYBLAS_STATUS_NOT_SUPPORTED;
}

dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
@@ -272,7 +272,7 @@ cublasStatus_t tinyblasGemmEx(cudaStream_t stream,
tinyblasGE_entry<<<maxblocks, maxthreads,
(sizeof(float) * (BM * BK + BK * BN)), stream>>>(
m, n, k, (const half *)A, lda, (const half *)B, ldb, C, Ctype, ldc);
return CUBLAS_STATUS_SUCCESS;
return TINYBLAS_STATUS_SUCCESS;
}

// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmbatchedex
@@ -309,29 +309,29 @@ static __global__ void tinyblasGBE_entry(int m, int n, int k,
}
}

cublasStatus_t tinyblasGemmBatchedEx(cudaStream_t stream,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *const Aarray[],
cudaDataType_t Atype,
int lda,
const void *const Barray[],
cudaDataType_t Btype,
int ldb,
const void *beta,
void *const Carray[],
cudaDataType_t Ctype,
int ldc,
int batchCount,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo) {
tinyblasStatus_t tinyblasGemmBatchedEx(tinyblasHandle_t stream,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *const Aarray[],
cudaDataType_t Atype,
int lda,
const void *const Barray[],
cudaDataType_t Btype,
int ldb,
const void *beta,
void *const Carray[],
cudaDataType_t Ctype,
int ldc,
int batchCount,
tinyblasComputeType_t computeType,
tinyblasGemmAlgo_t algo) {
if (!check_args(transa, transb, alpha, Atype, Btype, beta, Ctype,
computeType)) {
return CUBLAS_STATUS_NOT_SUPPORTED;
return TINYBLAS_STATUS_NOT_SUPPORTED;
}

dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
@@ -341,7 +341,7 @@ cublasStatus_t tinyblasGemmBatchedEx(cudaStream_t stream,
(sizeof(float) * (BM * BK + BK * BN)), stream>>>(
m, n, k, (const half **)Aarray, lda, (const half **)Barray, ldb,
Carray, Ctype, ldc, batchCount);
return CUBLAS_STATUS_SUCCESS;
return TINYBLAS_STATUS_SUCCESS;
}

// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
@@ -372,13 +372,13 @@ static __device__ void matmul_block2d_sb(int m, int n, int k, int x, int y,
// we copy into As from A
for (j = 0; j < BM && x + j < m; ++j) {
As[(j * BK) + i] =
READ16(A, CUBLAS_OP_T, lda, x + j, blob + i);
READ16(A, TINYBLAS_OP_T, lda, x + j, blob + i);
}
for (; j < BM; ++j) As[(j * BK) + i] = 0;
// we copy into Bs from B
for (j = 0; j < BN && y + j < n; ++j) {
Bs[(i * BN) + j] =
READ16(B, CUBLAS_OP_N, ldb, blob + i, y + j);
READ16(B, TINYBLAS_OP_N, ldb, blob + i, y + j);
}
for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
} else { // UNLIKELY
@@ -457,30 +457,30 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k,
}
}

cublasStatus_t tinyblasGemmStridedBatchedEx(cudaStream_t stream,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const void *pAlpha,
const void *A,
cudaDataType_t Atype,
int lda,
long long int strideA,
const void *B,
cudaDataType_t Btype,
int ldb,
long long int strideB,
const void *pBeta,
void *C,
cudaDataType_t Ctype,
int ldc,
long long int strideC,
int batchCount,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo) {
tinyblasStatus_t tinyblasGemmStridedBatchedEx(tinyblasHandle_t stream,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m, int n, int k,
const void *pAlpha,
const void *A,
cudaDataType_t Atype,
int lda,
long long int strideA,
const void *B,
cudaDataType_t Btype,
int ldb,
long long int strideB,
const void *pBeta,
void *C,
cudaDataType_t Ctype,
int ldc,
long long int strideC,
int batchCount,
tinyblasComputeType_t computeType,
tinyblasGemmAlgo_t algo) {
if (!check_args(transa, transb, pAlpha, Atype, Btype, pBeta, Ctype,
computeType)) {
return CUBLAS_STATUS_NOT_SUPPORTED;
return TINYBLAS_STATUS_NOT_SUPPORTED;
}

// call the entry function
@@ -492,5 +492,5 @@ cublasStatus_t tinyblasGemmStridedBatchedEx(cudaStream_t stream,
m, n, k, (const half*)A, lda, strideA, (const half*)B, ldb, strideB,
C, Ctype, ldc, strideC, batchCount);

return CUBLAS_STATUS_SUCCESS;
return TINYBLAS_STATUS_SUCCESS;
}
160 changes: 91 additions & 69 deletions llamafile/tinyblas.h
@@ -8,75 +8,97 @@
#include <cuda_fp16.h>
#endif

cublasStatus_t tinyblasSgemm(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc);
enum tinyblasOperation_t {
TINYBLAS_OP_N = 111,
TINYBLAS_OP_T = 112,
TINYBLAS_OP_C = 113,
};

cublasStatus_t tinyblasGemmEx(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType_t Atype,
int lda,
const void *B,
cudaDataType_t Btype,
int ldb,
const void *beta,
void *C,
cudaDataType_t Ctype,
int ldc,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo);
enum tinyblasStatus_t {
TINYBLAS_STATUS_SUCCESS = 0,
TINYBLAS_STATUS_NOT_SUPPORTED = 7,
};

cublasStatus_t tinyblasGemmBatchedEx(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *const Aarray[],
cudaDataType_t Atype,
int lda,
const void *const Barray[],
cudaDataType_t Btype,
int ldb,
const void *beta,
void *const Carray[],
cudaDataType_t Ctype,
int ldc,
int batchCount,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo);
enum tinyblasComputeType_t {
TINYBLAS_COMPUTE_16F = 150,
TINYBLAS_COMPUTE_32F = 151,
};

cublasStatus_t tinyblasGemmStridedBatchedEx(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const void *pAlpha,
const void *A,
cudaDataType_t Atype,
int lda,
long long int strideA,
const void *B,
cudaDataType_t Btype,
int ldb,
long long int strideB,
const void *pBeta,
void *C,
cudaDataType_t Ctype,
int ldc,
long long int strideC,
int batchCount,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo);
enum tinyblasGemmAlgo_t {
TINYBLAS_GEMM_DEFAULT_TENSOR_OP = 160,
};

#define tinyblasHandle_t cudaStream_t

tinyblasStatus_t tinyblasSgemm(tinyblasHandle_t handle,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc);

tinyblasStatus_t tinyblasGemmEx(tinyblasHandle_t handle,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType_t Atype,
int lda,
const void *B,
cudaDataType_t Btype,
int ldb,
const void *beta,
void *C,
cudaDataType_t Ctype,
int ldc,
tinyblasComputeType_t computeType,
tinyblasGemmAlgo_t algo);

tinyblasStatus_t tinyblasGemmBatchedEx(tinyblasHandle_t handle,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *const Aarray[],
cudaDataType_t Atype,
int lda,
const void *const Barray[],
cudaDataType_t Btype,
int ldb,
const void *beta,
void *const Carray[],
cudaDataType_t Ctype,
int ldc,
int batchCount,
tinyblasComputeType_t computeType,
tinyblasGemmAlgo_t algo);

tinyblasStatus_t tinyblasGemmStridedBatchedEx(tinyblasHandle_t handle,
tinyblasOperation_t transa,
tinyblasOperation_t transb,
int m, int n, int k,
const void *pAlpha,
const void *A,
cudaDataType_t Atype,
int lda,
long long int strideA,
const void *B,
cudaDataType_t Btype,
int ldb,
long long int strideB,
const void *pBeta,
void *C,
cudaDataType_t Ctype,
int ldc,
long long int strideC,
int batchCount,
tinyblasComputeType_t computeType,
tinyblasGemmAlgo_t algo);
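
For reference, check_args in tinyblas.cu constrains how these entry points
may be called: A transposed, B not transposed, fp16 inputs, fp16 or fp32
output, and alpha/beta fixed at 1 and 0. A hedged sketch of a conforming
tinyblasGemmEx call follows; d_A, d_B, d_C, the dimensions, and the stream
are placeholders, not code from this repository.

// Sketch only: d_A and d_B are device fp16 matrices, d_C is an fp32 output.
// Anything outside the supported argument combination returns
// TINYBLAS_STATUS_NOT_SUPPORTED instead of computing a result.
float alpha = 1.0f, beta = 0.0f;
cudaStream_t stream;  // assumed created earlier with cudaStreamCreate
tinyblasStatus_t st = tinyblasGemmEx(
    stream, TINYBLAS_OP_T, TINYBLAS_OP_N,
    m, n, k,
    &alpha, d_A, CUDA_R_16F, lda,
            d_B, CUDA_R_16F, ldb,
    &beta,  d_C, CUDA_R_32F, ldc,
    TINYBLAS_COMPUTE_32F, TINYBLAS_GEMM_DEFAULT_TENSOR_OP);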
