From b1aaea5c48d8f579fabcd79784829a3f8c00abd8 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 25 Oct 2024 14:35:31 +0000 Subject: [PATCH] Cached cusparseLt --- .../cusparseLt_benchmarks/benchmark_24.py | 29 +++-- csrc/ops.h | 5 +- .../fp8_semi_structured/cusparseLt.cpp | 104 +++++++++++------- csrc/torch_bindings.cpp | 4 +- tests/test_cusparseLt.cpp | 12 ++ vllm/_custom_ops.py | 10 +- 6 files changed, 107 insertions(+), 57 deletions(-) create mode 100644 tests/test_cusparseLt.cpp diff --git a/benchmarks/cusparseLt_benchmarks/benchmark_24.py b/benchmarks/cusparseLt_benchmarks/benchmark_24.py index 426fb653598a2..4599964421bc0 100644 --- a/benchmarks/cusparseLt_benchmarks/benchmark_24.py +++ b/benchmarks/cusparseLt_benchmarks/benchmark_24.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import ( compress_to_torch_sparse_semi_structured_mat, dense_matmul, get_random_mat, is_semi_structured_supported, semi_structured_sparse_dense_gemm) +from vllm._custom_ops import (semi_structured_fp8_prepare_mm, semi_structured_fp8_mm_prepared) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -79,15 +80,15 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str, a, b = make_rand_tensors(torch.int8, m, n, k) # cutlass i8 - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_matmul-w-scales", - dense_matmul, a, b, torch.int8)) + # timers.append( + # bench_fn(label, sub_label, "cutlass_i8_i8_matmul-w-scales", + # dense_matmul, a, b, torch.int8)) # cusparseLt i8 - timers.append( - bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4", - semi_structured_sparse_dense_gemm, - compress_to_torch_sparse_semi_structured_mat(a), b)) + # timers.append( + # bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4", + # semi_structured_sparse_dense_gemm, + # compress_to_torch_sparse_semi_structured_mat(a), b)) if use_fp8: a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) @@ -101,6 +102,13 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str, bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4", semi_structured_sparse_dense_gemm, compress_to_torch_sparse_semi_structured_mat(a), b)) + + a_compressed = compress_to_torch_sparse_semi_structured_mat(a) + handle = semi_structured_fp8_prepare_mm(a_compressed.packed, b) + timers.append( + bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4_prepared", + semi_structured_fp8_mm_prepared, + torch.tensor([handle], dtype=torch.int64, device='cuda'))) return timers @@ -114,9 +122,6 @@ def print_timers(timers: Iterable[TMeasurement]): def run(MKNs: Iterable[Tuple[int, int, int]], use_fp8: bool) -> Iterable[TMeasurement]: results = [] - # MKNs = [(2048, 8192, 14336)] - # MKNs = [(32, 11008, 4096)] - MKNs = [(2048, 11008, 14336)] for m, k, n in MKNs: timers = bench(m, k, n, "gemm", f"MKN=({m}x{k}x{n})", use_fp8) print_timers(timers) @@ -181,8 +186,8 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: for d in model_bench_data: all_data.extend(d) # pickle all data - with open(f"model_bench-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) + # with open(f"model_bench-{timestamp}.pkl", "wb") as f: + # pkl.dump(all_data, f) if __name__ == '__main__': diff --git a/csrc/ops.h b/csrc/ops.h index 690ea5b18939a..655cd0d9d555b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -232,9 +232,10 @@ torch::Tensor cslt_mm_fp8_semi_structured( const c10::optional& bias_opt, bool transpose_result); int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, - const 
torch::Tensor& dense_B); + const torch::Tensor& dense_B, + const c10::optional& bias_opt, bool transpose_result); -torch::Tensor cslt_mm_fp8_semi_structured_prepared(int64_t id); +torch::Tensor cslt_mm_fp8_semi_structured_prepared(const torch::Tensor& id); void cslt_fp8_semi_structured_destroy(int64_t id); diff --git a/csrc/quantization/fp8_semi_structured/cusparseLt.cpp b/csrc/quantization/fp8_semi_structured/cusparseLt.cpp index a79076567b3d2..42975d5ff0eea 100644 --- a/csrc/quantization/fp8_semi_structured/cusparseLt.cpp +++ b/csrc/quantization/fp8_semi_structured/cusparseLt.cpp @@ -58,27 +58,44 @@ " when calling `" #EXPR "`"); \ } while (0) + + namespace vllm { namespace cusparseLt { -cusparseLtHandle_t handle; -bool handle_initialized = false; -using cacheID = int64_t; - struct cusparseLtEntry { - // cusparseLtEntry(): device() {} - int m; - int n; - int k; + // cusparseLtEntry() {} + // void operator=(const cusparseLtEntry& entry) { + // sparse_input_descriptor = entry.sparse_input_descriptor; + // dense_input_descriptor = entry.dense_input_descriptor; + // res_descriptor = entry.res_descriptor; + // C_descriptor = entry.C_descriptor; + // matmul = entry.matmul; + // plan = entry.plan; + + // sparse_mat_ptr = entry.sparse_mat_ptr; + // dense_mat_ptr = entry.dense_mat_ptr; + + // device = std::move(entry.device); + // allocator = entry.allocator; + // out_dtype = std::move(entry.out_dtype); + + // workspace_ptr = std::move(entry.workspace_ptr); + + // m = entry.m; + // n = entry.n; + // k = entry.k; + // } - cusparseLtMatmulDescriptor_t matmul; - cusparseLtMatmulPlan_t plan; - cusparseLtMatmulAlgSelection_t alg_sel; cusparseLtMatDescriptor_t sparse_input_descriptor; cusparseLtMatDescriptor_t dense_input_descriptor; cusparseLtMatDescriptor_t res_descriptor; cusparseLtMatDescriptor_t C_descriptor; + cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulPlan_t plan; + + void* sparse_mat_ptr; void* dense_mat_ptr; @@ -87,13 +104,23 @@ struct cusparseLtEntry { c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator; c10::DataPtr workspace_ptr; + + int m; + int n; + int k; }; -std::map cusparseLt_cache; +cusparseLtHandle_t handle; +bool handle_initialized = false; +using cacheID = int64_t; + +std::map cusparseLt_cache; } // namespace cusparseLt } // namespace vllm +vllm::cusparseLt::cusparseLtEntry entry; + torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input) { TORCH_CHECK(input.scalar_type() == at::ScalarType::Float8_e4m3fn, "Only float8 e4m3 is supported in vllm:cslt_compress"); @@ -128,15 +155,14 @@ torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input) { return compressed_tensor; } -// vllm::cusparseLt::cacheID cslt_prepare_mm_fp8_semi_structured(const -// torch::Tensor& compressed_A, const torch::Tensor& dense_B) { -int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, - const torch::Tensor& dense_B) { +vllm::cusparseLt::cacheID cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, + const torch::Tensor& dense_B, + const c10::optional& bias_opt, bool transpose_result) { TORCH_CHECK(compressed_A.scalar_type() == at::ScalarType::Float8_e4m3fn, "Only float8 e4m3 is supported in vllm:cslt_compress"); namespace vc = vllm::cusparseLt; if (!vc::handle_initialized) { - TORCH_CUDASPARSE_CHECK(cusparseLtInit(&vllm::cusparseLt::handle)); + TORCH_CUDASPARSE_CHECK(cusparseLtInit(&vc::handle)); vc::handle_initialized = true; } vc::cacheID id; @@ -145,7 +171,9 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const 
torch::Tensor& compressed_A, } else { id = vc::cusparseLt_cache.rbegin()->first + 1; } - vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id]; + + // vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id]; + // vc::cusparseLtEntry entry; float alpha = 1.0; float beta = 0.0; @@ -155,7 +183,6 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, cusparseComputeType compute_type = CUSPARSE_COMPUTE_32F; auto compression_factor = 9; auto out_dtype = dense_B.scalar_type(); - int64_t k = dense_B.size(0); int64_t n = dense_B.size(1); int64_t m = (compressed_A.numel() * 16 / compression_factor) / k; @@ -183,10 +210,9 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, "float32} for fp8 inputs"); break; } - // initialize sparse descriptor TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit( - &vc::handle, &entry.sparse_input_descriptor, m, k, k, 16, input_type, + &vc::handle, &(entry.sparse_input_descriptor), m, k, k, 16, input_type, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT)); // initialize dense descriptor @@ -196,13 +222,15 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, (dense_B.is_contiguous()) ? n : k, 16, input_type, CUSPARSE_ORDER_ROW)); // initialize result descriptor - TORCH_CUDASPARSE_CHECK( - cusparseLtDenseDescriptorInit(&vc::handle, &entry.res_descriptor, m, n, m, - 16, output_type, CUSPARSE_ORDER_ROW)); + TORCH_CUDASPARSE_CHECK(cusparseLtDenseDescriptorInit( + &vc::handle, &entry.res_descriptor, m, n, (transpose_result) ? m : n, 16, + output_type, + (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW)); + TORCH_CUDASPARSE_CHECK(cusparseLtDenseDescriptorInit( + &vc::handle, &entry.C_descriptor, m, n, (transpose_result) ? m : n, 16, C_type, + (transpose_result) ? 
CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW)); - TORCH_CUDASPARSE_CHECK( - cusparseLtDenseDescriptorInit(&vc::handle, &entry.C_descriptor, m, n, n, - 16, C_type, CUSPARSE_ORDER_ROW)); + cusparseLtMatmulAlgSelection_t alg_sel; TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescriptorInit( &vc::handle, &entry.matmul, CUSPARSE_OPERATION_NON_TRANSPOSE, @@ -210,13 +238,11 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, : CUSPARSE_OPERATION_TRANSPOSE, &entry.sparse_input_descriptor, &entry.dense_input_descriptor, &entry.C_descriptor, &entry.res_descriptor, compute_type)); - TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit( - &vc::handle, &entry.alg_sel, &entry.matmul, + &vc::handle, &alg_sel, &entry.matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)); TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanInit( - &vc::handle, &entry.plan, &entry.matmul, &entry.alg_sel)); - + &vc::handle, &entry.plan, &entry.matmul, &alg_sel)); size_t workspace_size; TORCH_CUDASPARSE_CHECK( cusparseLtMatmulGetWorkspace(&vc::handle, &entry.plan, &workspace_size)); @@ -228,18 +254,22 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A, entry.m = m; entry.n = n; entry.k = k; + entry.sparse_mat_ptr = compressed_A.data_ptr(); + entry.dense_mat_ptr = dense_B.data_ptr(); return id; } torch::Tensor cslt_mm_fp8_semi_structured_prepared( - vllm::cusparseLt::cacheID id) { + const torch::Tensor& id_tensor) { namespace vc = vllm::cusparseLt; TORCH_CHECK(vc::handle_initialized, "Call of matmul with unintialized matmul"); - if (vc::cusparseLt_cache.count(id) == 0) { - TORCH_CHECK(false, "cusparse matmul Id is not found"); - } - const auto& entry = vc::cusparseLt_cache[id]; + // TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued"); + // auto id = id_tensor.item(); + // if (vc::cusparseLt_cache.count(id) == 0) { + // TORCH_CHECK(false, "cusparse matmul Id is not found"); + // } + // const auto& entry = vc::cusparseLt_cache[id]; auto res_tensor_options = c10::TensorOptions().dtype(entry.out_dtype).device(entry.device); @@ -262,7 +292,7 @@ void cslt_fp8_semi_structured_destroy(vllm::cusparseLt::cacheID id) { if (vllm::cusparseLt::cusparseLt_cache.count(id) == 0) { TORCH_CHECK(false, "cusparse matmul Id is not found"); } - auto& entry = vllm::cusparseLt::cusparseLt_cache[id]; + // auto& entry = vllm::cusparseLt::cusparseLt_cache[id]; TORCH_CUDASPARSE_CHECK( cusparseLtMatDescriptorDestroy(&entry.sparse_input_descriptor)); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 33ef571363937..755338fb6f559 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -336,11 +336,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "cslt_prepare_mm_fp8_semi_structured(Tensor! compressed_A, Tensor! " - "denseB) -> int"); + "denseB, Tensor!? 
bias, bool transpose_result) -> int");
   ops.impl("cslt_prepare_mm_fp8_semi_structured", torch::kCUDA,
            &cslt_prepare_mm_fp8_semi_structured);
 
-  ops.def("cslt_mm_fp8_semi_structured_prepared(int cacheId) -> Tensor");
+  ops.def("cslt_mm_fp8_semi_structured_prepared(Tensor cacheId) -> Tensor");
   ops.impl("cslt_mm_fp8_semi_structured_prepared", torch::kCUDA,
            &cslt_mm_fp8_semi_structured_prepared);
 
diff --git a/tests/test_cusparseLt.cpp b/tests/test_cusparseLt.cpp
new file mode 100644
index 0000000000000..9c8d3cb813ef1
--- /dev/null
+++ b/tests/test_cusparseLt.cpp
@@ -0,0 +1,12 @@
+ #include <cusparseLt.h>
+
+cusparseLtHandle_t handle;
+
+
+struct Entry {
+  cusparseLtMatDescriptor_t sparse_input_descriptor;
+};
+
+int main() {
+
+}
\ No newline at end of file
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 34c9525cb7401..387403ff4d889 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -723,18 +723,20 @@ def semi_structured_fp8_mm(A_compressed: torch.Tensor,
 
 
 def semi_structured_fp8_prepare_mm(A_compressed: torch.Tensor,
-                                   B_dense: torch.Tensor) -> int:
+                                   B_dense: torch.Tensor,
+                                   bias: Optional[torch.Tensor] = None,
+                                   transpose_result: bool = False) -> int:
     assert A_compressed.dtype == torch.float8_e4m3fn
     return torch.ops._C.cslt_prepare_mm_fp8_semi_structured(
-        A_compressed, B_dense)
+        A_compressed, B_dense, bias, transpose_result)
 
 
 def semi_structured_fp8_mm_prepared(cacheId: int) -> torch.Tensor:
-    return torch.ops.cslt_mm_fp8_semi_structured_prepared(cacheId)
+    return torch.ops._C.cslt_mm_fp8_semi_structured_prepared(cacheId)
 
 
 def semi_structured_fp8_destroy(cacheId: int):
-    torch.ops.cslt_fp8_semi_structured_destroy(cacheId)
+    torch.ops._C.cslt_fp8_semi_structured_destroy(cacheId)
 
 
 # int8
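
Below is a minimal end-to-end sketch of how the cached cusparseLt path added by this patch is meant to be driven from Python. It mirrors the `cusparseLt_fp8_fp8_2_4_prepared` hunk in benchmark_24.py and the wrappers in vllm/_custom_ops.py. The random-input construction and the explicit 2:4 pruning step are illustrative assumptions (the benchmark builds its inputs through `get_random_mat`), and the final destroy call shows the intended handle lifecycle rather than something the benchmark itself performs.

    import torch

    from vllm._custom_ops import (semi_structured_fp8_prepare_mm,
                                  semi_structured_fp8_mm_prepared,
                                  semi_structured_fp8_destroy)
    from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
        compress_to_torch_sparse_semi_structured_mat)

    m, k, n = 2048, 11008, 14336

    # Illustrative inputs: prune A to a 2:4 pattern (keep the 2 largest of every
    # 4 consecutive values along k), then cast both operands to fp8 e4m3.
    a = torch.rand(m, k, dtype=torch.float16, device='cuda').reshape(m, k // 4, 4)
    _, keep = a.topk(2, dim=-1)
    a = (a * torch.zeros_like(a).scatter_(-1, keep, 1.0)).reshape(m, k)
    a = a.to(torch.float8_e4m3fn)
    b = torch.rand(k, n, dtype=torch.float16, device='cuda').to(torch.float8_e4m3fn)

    # Compress A once and build/cache the cusparseLt matmul plan; the op
    # returns an integer cache id.
    a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
    handle = semi_structured_fp8_prepare_mm(a_compressed.packed, b)

    # The prepared matmul takes the id wrapped in a tensor, matching the
    # "cslt_mm_fp8_semi_structured_prepared(Tensor cacheId) -> Tensor" schema.
    id_tensor = torch.tensor([handle], dtype=torch.int64, device='cuda')
    for _ in range(10):
        out = semi_structured_fp8_mm_prepared(id_tensor)

    # Release the cached descriptors, plan and workspace for this id.
    semi_structured_fp8_destroy(handle)

The point of the prepared variant is that descriptor setup, algorithm selection and workspace allocation happen once in `semi_structured_fp8_prepare_mm`, so the steady-state loop only pays for the `cusparseLtMatmul` launch.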