From 9f6a46930749d4c6696f6e8d92f603e64cb1e7ac Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 25 Oct 2024 15:27:51 +0000 Subject: [PATCH] Prepare for reproduce --- .../cusparseLt_benchmarks/benchmark_24.py | 14 ++++----- .../fp8_semi_structured/cusparseLt.cpp | 31 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/benchmarks/cusparseLt_benchmarks/benchmark_24.py b/benchmarks/cusparseLt_benchmarks/benchmark_24.py index 101a9bc20be6e..cc9a9e1c2603c 100644 --- a/benchmarks/cusparseLt_benchmarks/benchmark_24.py +++ b/benchmarks/cusparseLt_benchmarks/benchmark_24.py @@ -93,15 +93,15 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str, if use_fp8: a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) # cutlass fp8 - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul-w-scales", - dense_matmul, a, b, torch.float8_e4m3fn)) + # timers.append( + # bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul-w-scales", + # dense_matmul, a, b, torch.float8_e4m3fn)) # cusparseLt fp8 - timers.append( - bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4", - semi_structured_sparse_dense_gemm, - compress_to_torch_sparse_semi_structured_mat(a), b)) + # timers.append( + # bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4", + # semi_structured_sparse_dense_gemm, + # compress_to_torch_sparse_semi_structured_mat(a), b)) a_compressed = compress_to_torch_sparse_semi_structured_mat(a) handle = semi_structured_fp8_prepare_mm(a_compressed.packed, b) diff --git a/csrc/quantization/fp8_semi_structured/cusparseLt.cpp b/csrc/quantization/fp8_semi_structured/cusparseLt.cpp index 80462925fb993..be242283648d1 100644 --- a/csrc/quantization/fp8_semi_structured/cusparseLt.cpp +++ b/csrc/quantization/fp8_semi_structured/cusparseLt.cpp @@ -119,7 +119,7 @@ std::map cusparseLt_cache; } // namespace cusparseLt } // namespace vllm -vllm::cusparseLt::cusparseLtEntry entry; +// vllm::cusparseLt::cusparseLtEntry entry; torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input) { TORCH_CHECK(input.scalar_type() == at::ScalarType::Float8_e4m3fn, @@ -172,7 +172,7 @@ vllm::cusparseLt::cacheID cslt_prepare_mm_fp8_semi_structured(const torch::Tenso id = vc::cusparseLt_cache.rbegin()->first + 1; } - // vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id]; + vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id]; // vc::cusparseLtEntry entry; float alpha = 1.0; @@ -264,12 +264,12 @@ torch::Tensor cslt_mm_fp8_semi_structured_prepared( namespace vc = vllm::cusparseLt; TORCH_CHECK(vc::handle_initialized, "Call of matmul with unintialized matmul"); - // TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued"); - // auto id = id_tensor.item(); - // if (vc::cusparseLt_cache.count(id) == 0) { - // TORCH_CHECK(false, "cusparse matmul Id is not found"); - // } - // const auto& entry = vc::cusparseLt_cache[id]; + TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued"); + auto id = id_tensor.item(); + if (vc::cusparseLt_cache.count(id) == 0) { + TORCH_CHECK(false, "cusparse matmul Id is not found"); + } + const auto& entry = vc::cusparseLt_cache[id]; auto res_tensor_options = c10::TensorOptions().dtype(entry.out_dtype).device(entry.device); @@ -287,14 +287,15 @@ torch::Tensor cslt_mm_fp8_semi_structured_prepared( } void cslt_fp8_semi_structured_destroy(const torch::Tensor& id_tensor) { - TORCH_CHECK(vllm::cusparseLt::handle_initialized, + namespace vc = vllm::cusparseLt; + TORCH_CHECK(vc::handle_initialized, "Call of destroy cusparseId with unintialized cusparseLt"); - // TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued"); - // auto id = id_tensor.item(); - // if (vllm::cusparseLt::cusparseLt_cache.count(id) == 0) { - // TORCH_CHECK(false, "cusparse matmul Id is not found"); - // } - // auto& entry = vllm::cusparseLt::cusparseLt_cache[id]; + TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued"); + auto id = id_tensor.item(); + if (vc::cusparseLt_cache.count(id) == 0) { + TORCH_CHECK(false, "cusparse matmul Id is not found"); + } + auto& entry = vc::cusparseLt_cache[id]; TORCH_CUDASPARSE_CHECK( cusparseLtMatDescriptorDestroy(&entry.sparse_input_descriptor));