Cached cusparseLt
ilmarkov committed Nov 13, 2024
1 parent f45a83b commit b1aaea5
Showing 6 changed files with 107 additions and 57 deletions.
29 changes: 17 additions & 12 deletions benchmarks/cusparseLt_benchmarks/benchmark_24.py
@@ -13,6 +13,7 @@
from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
compress_to_torch_sparse_semi_structured_mat, dense_matmul, get_random_mat,
is_semi_structured_supported, semi_structured_sparse_dense_gemm)
from vllm._custom_ops import (semi_structured_fp8_prepare_mm, semi_structured_fp8_mm_prepared)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -79,15 +80,15 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,

a, b = make_rand_tensors(torch.int8, m, n, k)
# cutlass i8
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_matmul-w-scales",
dense_matmul, a, b, torch.int8))
# timers.append(
# bench_fn(label, sub_label, "cutlass_i8_i8_matmul-w-scales",
# dense_matmul, a, b, torch.int8))

# cusparseLt i8
timers.append(
bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4",
semi_structured_sparse_dense_gemm,
compress_to_torch_sparse_semi_structured_mat(a), b))
# timers.append(
# bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4",
# semi_structured_sparse_dense_gemm,
# compress_to_torch_sparse_semi_structured_mat(a), b))

if use_fp8:
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
@@ -101,6 +102,13 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4",
semi_structured_sparse_dense_gemm,
compress_to_torch_sparse_semi_structured_mat(a), b))

a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
handle = semi_structured_fp8_prepare_mm(a_compressed.packed, b)
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4_prepared",
semi_structured_fp8_mm_prepared,
torch.tensor([handle], dtype=torch.int64, device='cuda')))

return timers

@@ -114,9 +122,6 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(MKNs: Iterable[Tuple[int, int, int]],
use_fp8: bool) -> Iterable[TMeasurement]:
results = []
# MKNs = [(2048, 8192, 14336)]
# MKNs = [(32, 11008, 4096)]
MKNs = [(2048, 11008, 14336)]
for m, k, n in MKNs:
timers = bench(m, k, n, "gemm", f"MKN=({m}x{k}x{n})", use_fp8)
print_timers(timers)
@@ -181,8 +186,8 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
# with open(f"model_bench-{timestamp}.pkl", "wb") as f:
# pkl.dump(all_data, f)


if __name__ == '__main__':
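The new cusparseLt_fp8_fp8_2_4_prepared timer above measures only the execution of an already-built plan: the plan is constructed once with semi_structured_fp8_prepare_mm and then invoked by handle. The following is a minimal sketch of that flow, assuming the custom-op bindings added in this commit and a CUDA device with cusparseLt support; prepared_fp8_gemm is an illustrative helper name, not part of the diff.

import torch

from vllm._custom_ops import (semi_structured_fp8_destroy,
                              semi_structured_fp8_mm_prepared,
                              semi_structured_fp8_prepare_mm)
from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
    compress_to_torch_sparse_semi_structured_mat)


def prepared_fp8_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: fp8 (m, k) matrix pruned to a 2:4 sparsity pattern, b: dense fp8 (k, n),
    # both on the CUDA device.
    a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
    # Build and cache the cusparseLt matmul plan once; the returned int64 handle
    # identifies the cached entry.
    handle = semi_structured_fp8_prepare_mm(a_compressed.packed, b)
    # Execute the cached plan; the handle crosses the op boundary as a
    # single-element int64 CUDA tensor.
    out = semi_structured_fp8_mm_prepared(
        torch.tensor([handle], dtype=torch.int64, device='cuda'))
    # Release the cached descriptors and workspace when this shape is done.
    semi_structured_fp8_destroy(handle)
    return out
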
5 changes: 3 additions & 2 deletions csrc/ops.h
@@ -232,9 +232,10 @@ torch::Tensor cslt_mm_fp8_semi_structured(
const c10::optional<torch::Tensor>& bias_opt, bool transpose_result);

int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
const torch::Tensor& dense_B);
const torch::Tensor& dense_B,
const c10::optional<torch::Tensor>& bias_opt, bool transpose_result);

torch::Tensor cslt_mm_fp8_semi_structured_prepared(int64_t id);
torch::Tensor cslt_mm_fp8_semi_structured_prepared(const torch::Tensor& id);

void cslt_fp8_semi_structured_destroy(int64_t id);

104 changes: 67 additions & 37 deletions csrc/quantization/fp8_semi_structured/cusparseLt.cpp
@@ -58,27 +58,44 @@
" when calling `" #EXPR "`"); \
} while (0)



namespace vllm {
namespace cusparseLt {

cusparseLtHandle_t handle;
bool handle_initialized = false;
using cacheID = int64_t;

struct cusparseLtEntry {
// cusparseLtEntry(): device() {}
int m;
int n;
int k;
// cusparseLtEntry() {}
// void operator=(const cusparseLtEntry& entry) {
// sparse_input_descriptor = entry.sparse_input_descriptor;
// dense_input_descriptor = entry.dense_input_descriptor;
// res_descriptor = entry.res_descriptor;
// C_descriptor = entry.C_descriptor;
// matmul = entry.matmul;
// plan = entry.plan;

// sparse_mat_ptr = entry.sparse_mat_ptr;
// dense_mat_ptr = entry.dense_mat_ptr;

// device = std::move(entry.device);
// allocator = entry.allocator;
// out_dtype = std::move(entry.out_dtype);

// workspace_ptr = std::move(entry.workspace_ptr);

// m = entry.m;
// n = entry.n;
// k = entry.k;
// }

cusparseLtMatmulDescriptor_t matmul;
cusparseLtMatmulPlan_t plan;
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatDescriptor_t sparse_input_descriptor;
cusparseLtMatDescriptor_t dense_input_descriptor;
cusparseLtMatDescriptor_t res_descriptor;
cusparseLtMatDescriptor_t C_descriptor;

cusparseLtMatmulDescriptor_t matmul;
cusparseLtMatmulPlan_t plan;


void* sparse_mat_ptr;
void* dense_mat_ptr;

@@ -87,13 +104,23 @@ struct cusparseLtEntry {

c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator;
c10::DataPtr workspace_ptr;

int m;
int n;
int k;
};

std::map<cacheID, cusparseLtEntry> cusparseLt_cache;
cusparseLtHandle_t handle;
bool handle_initialized = false;
using cacheID = int64_t;


std::map<cacheID, cusparseLtEntry> cusparseLt_cache;
} // namespace cusparseLt
} // namespace vllm

vllm::cusparseLt::cusparseLtEntry entry;

torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input) {
TORCH_CHECK(input.scalar_type() == at::ScalarType::Float8_e4m3fn,
"Only float8 e4m3 is supported in vllm:cslt_compress");
@@ -128,15 +155,14 @@ torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input) {
return compressed_tensor;
}

// vllm::cusparseLt::cacheID cslt_prepare_mm_fp8_semi_structured(const
// torch::Tensor& compressed_A, const torch::Tensor& dense_B) {
int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
const torch::Tensor& dense_B) {
vllm::cusparseLt::cacheID cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
const torch::Tensor& dense_B,
const c10::optional<torch::Tensor>& bias_opt, bool transpose_result) {
TORCH_CHECK(compressed_A.scalar_type() == at::ScalarType::Float8_e4m3fn,
"Only float8 e4m3 is supported in vllm:cslt_compress");
namespace vc = vllm::cusparseLt;
if (!vc::handle_initialized) {
TORCH_CUDASPARSE_CHECK(cusparseLtInit(&vllm::cusparseLt::handle));
TORCH_CUDASPARSE_CHECK(cusparseLtInit(&vc::handle));
vc::handle_initialized = true;
}
vc::cacheID id;
@@ -145,7 +171,9 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
} else {
id = vc::cusparseLt_cache.rbegin()->first + 1;
}
vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id];

// vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id];
// vc::cusparseLtEntry entry;

float alpha = 1.0;
float beta = 0.0;
@@ -155,7 +183,6 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
cusparseComputeType compute_type = CUSPARSE_COMPUTE_32F;
auto compression_factor = 9;
auto out_dtype = dense_B.scalar_type();

int64_t k = dense_B.size(0);
int64_t n = dense_B.size(1);
int64_t m = (compressed_A.numel() * 16 / compression_factor) / k;
@@ -183,10 +210,9 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
"float32} for fp8 inputs");
break;
}

// initialize sparse descriptor
TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit(
&vc::handle, &entry.sparse_input_descriptor, m, k, k, 16, input_type,
&vc::handle, &(entry.sparse_input_descriptor), m, k, k, 16, input_type,
CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT));

// initialize dense descriptor
@@ -196,27 +222,27 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
(dense_B.is_contiguous()) ? n : k, 16, input_type, CUSPARSE_ORDER_ROW));

// initialize result descriptor
TORCH_CUDASPARSE_CHECK(
cusparseLtDenseDescriptorInit(&vc::handle, &entry.res_descriptor, m, n, m,
16, output_type, CUSPARSE_ORDER_ROW));
TORCH_CUDASPARSE_CHECK(cusparseLtDenseDescriptorInit(
&vc::handle, &entry.res_descriptor, m, n, (transpose_result) ? m : n, 16,
output_type,
(transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
TORCH_CUDASPARSE_CHECK(cusparseLtDenseDescriptorInit(
&vc::handle, &entry.C_descriptor, m, n, (transpose_result) ? m : n, 16, C_type,
(transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));

TORCH_CUDASPARSE_CHECK(
cusparseLtDenseDescriptorInit(&vc::handle, &entry.C_descriptor, m, n, n,
16, C_type, CUSPARSE_ORDER_ROW));
cusparseLtMatmulAlgSelection_t alg_sel;

TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescriptorInit(
&vc::handle, &entry.matmul, CUSPARSE_OPERATION_NON_TRANSPOSE,
(dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE
: CUSPARSE_OPERATION_TRANSPOSE,
&entry.sparse_input_descriptor, &entry.dense_input_descriptor,
&entry.C_descriptor, &entry.res_descriptor, compute_type));

TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit(
&vc::handle, &entry.alg_sel, &entry.matmul,
&vc::handle, &alg_sel, &entry.matmul,
CUSPARSELT_MATMUL_ALG_DEFAULT));
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanInit(
&vc::handle, &entry.plan, &entry.matmul, &entry.alg_sel));

&vc::handle, &entry.plan, &entry.matmul, &alg_sel));
size_t workspace_size;
TORCH_CUDASPARSE_CHECK(
cusparseLtMatmulGetWorkspace(&vc::handle, &entry.plan, &workspace_size));
@@ -228,18 +254,22 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,
entry.m = m;
entry.n = n;
entry.k = k;
entry.sparse_mat_ptr = compressed_A.data_ptr();
entry.dense_mat_ptr = dense_B.data_ptr();
return id;
}

torch::Tensor cslt_mm_fp8_semi_structured_prepared(
vllm::cusparseLt::cacheID id) {
const torch::Tensor& id_tensor) {
namespace vc = vllm::cusparseLt;
TORCH_CHECK(vc::handle_initialized,
"Call of matmul with unintialized matmul");
if (vc::cusparseLt_cache.count(id) == 0) {
TORCH_CHECK(false, "cusparse matmul Id is not found");
}
const auto& entry = vc::cusparseLt_cache[id];
// TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued");
// auto id = id_tensor.item<vc::cacheID>();
// if (vc::cusparseLt_cache.count(id) == 0) {
// TORCH_CHECK(false, "cusparse matmul Id is not found");
// }
// const auto& entry = vc::cusparseLt_cache[id];

auto res_tensor_options =
c10::TensorOptions().dtype(entry.out_dtype).device(entry.device);
@@ -262,7 +292,7 @@ void cslt_fp8_semi_structured_destroy(vllm::cusparseLt::cacheID id) {
if (vllm::cusparseLt::cusparseLt_cache.count(id) == 0) {
TORCH_CHECK(false, "cusparse matmul Id is not found");
}
auto& entry = vllm::cusparseLt::cusparseLt_cache[id];
// auto& entry = vllm::cusparseLt::cusparseLt_cache[id];

TORCH_CUDASPARSE_CHECK(
cusparseLtMatDescriptorDestroy(&entry.sparse_input_descriptor));
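The file above declares a std::map from an int64 cacheID to a cusparseLtEntry and allocates new ids as one past the largest existing key; in this commit the prepared-matmul path temporarily reads a single global entry rather than the map, but the bookkeeping is the same. For orientation, a rough Python analogue of that bookkeeping follows; the real cusparseLtEntry holds cusparseLt descriptors, a matmul plan, and a device workspace rather than Python objects.

from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class PlanEntry:
    # Stand-ins for the descriptors, plan, and workspace kept in cusparseLtEntry.
    m: int
    n: int
    k: int
    plan: Any = None
    workspace: Any = None


plan_cache: Dict[int, PlanEntry] = {}


def next_cache_id() -> int:
    # Mirrors cusparseLt_cache.rbegin()->first + 1, with 0 for an empty cache.
    return 0 if not plan_cache else max(plan_cache) + 1
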
4 changes: 2 additions & 2 deletions csrc/torch_bindings.cpp
@@ -336,11 +336,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

ops.def(
"cslt_prepare_mm_fp8_semi_structured(Tensor! compressed_A, Tensor! "
"denseB) -> int");
"denseB, Tensor!? bias, bool transpose_result) -> int");
ops.impl("cslt_prepare_mm_fp8_semi_structured", torch::kCUDA,
&cslt_prepare_mm_fp8_semi_structured);

ops.def("cslt_mm_fp8_semi_structured_prepared(int cacheId) -> Tensor");
ops.def("cslt_mm_fp8_semi_structured_prepared(Tensor cacheId) -> Tensor");
ops.impl("cslt_mm_fp8_semi_structured_prepared", torch::kCUDA,
&cslt_mm_fp8_semi_structured_prepared);

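Two schema details change above: cslt_prepare_mm_fp8_semi_structured now also takes an optional bias and a transpose_result flag, and cslt_mm_fp8_semi_structured_prepared takes its cache id as a Tensor instead of an int. A hedged call sketch against the Python wrappers from this diff; prepare_transposed and run_prepared are illustrative helpers, not part of the change.

import torch
from typing import Optional

from vllm import _custom_ops as ops


def prepare_transposed(a_packed: torch.Tensor, b_dense: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> int:
    # Request a transposed (column-major) result at prepare time; returns the
    # int64 cache id used by the prepared-matmul op.
    return ops.semi_structured_fp8_prepare_mm(a_packed, b_dense, bias,
                                              transpose_result=True)


def run_prepared(handle: int) -> torch.Tensor:
    # The op schema is now "Tensor cacheId", so the id is wrapped in a
    # single-element int64 CUDA tensor, as the updated benchmark does.
    return ops.semi_structured_fp8_mm_prepared(
        torch.tensor([handle], dtype=torch.int64, device='cuda'))
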
12 changes: 12 additions & 0 deletions tests/test_cusparseLt.cpp
@@ -0,0 +1,12 @@
#include <cusparseLt.h>

cusparseLtHandle_t handle;


struct Entry {
cusparseLtMatDescriptor_t sparse_input_descriptor;
};

int main() {

}
10 changes: 6 additions & 4 deletions vllm/_custom_ops.py
@@ -723,18 +723,20 @@ def semi_structured_fp8_mm(A_compressed: torch.Tensor,


def semi_structured_fp8_prepare_mm(A_compressed: torch.Tensor,
B_dense: torch.Tensor) -> int:
B_dense: torch.Tensor,
bias: Optional[torch.Tensor] = None,
transpose_result: bool = False) -> int:
assert A_compressed.dtype == torch.float8_e4m3fn
return torch.ops._C.cslt_prepare_mm_fp8_semi_structured(
A_compressed, B_dense)
A_compressed, B_dense, bias, transpose_result)


def semi_structured_fp8_mm_prepared(cacheId: int) -> torch.Tensor:
return torch.ops.cslt_mm_fp8_semi_structured_prepared(cacheId)
return torch.ops._C.cslt_mm_fp8_semi_structured_prepared(cacheId)


def semi_structured_fp8_destroy(cacheId: int):
torch.ops.cslt_fp8_semi_structured_destroy(cacheId)
torch.ops._C.cslt_fp8_semi_structured_destroy(cacheId)


# int8
