Skip to content

Commit

Permalink
Prepare for reproduce
Browse files Browse the repository at this point in the history
  • Loading branch information
ilmarkov committed Nov 13, 2024
1 parent c36401c commit 9f6a469
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
14 changes: 7 additions & 7 deletions benchmarks/cusparseLt_benchmarks/benchmark_24.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,
if use_fp8:
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
# cutlass fp8
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul-w-scales",
dense_matmul, a, b, torch.float8_e4m3fn))
# timers.append(
# bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul-w-scales",
# dense_matmul, a, b, torch.float8_e4m3fn))

# cusparseLt fp8
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4",
semi_structured_sparse_dense_gemm,
compress_to_torch_sparse_semi_structured_mat(a), b))
# timers.append(
# bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4",
# semi_structured_sparse_dense_gemm,
# compress_to_torch_sparse_semi_structured_mat(a), b))

a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
handle = semi_structured_fp8_prepare_mm(a_compressed.packed, b)
Expand Down
31 changes: 16 additions & 15 deletions csrc/quantization/fp8_semi_structured/cusparseLt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ std::map<cacheID, cusparseLtEntry> cusparseLt_cache;
} // namespace cusparseLt
} // namespace vllm

vllm::cusparseLt::cusparseLtEntry entry;
// vllm::cusparseLt::cusparseLtEntry entry;

torch::Tensor cslt_compress_fp8_semi_structured(const torch::Tensor& input) {
TORCH_CHECK(input.scalar_type() == at::ScalarType::Float8_e4m3fn,
Expand Down Expand Up @@ -172,7 +172,7 @@ vllm::cusparseLt::cacheID cslt_prepare_mm_fp8_semi_structured(const torch::Tenso
id = vc::cusparseLt_cache.rbegin()->first + 1;
}

// vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id];
vc::cusparseLtEntry& entry = vc::cusparseLt_cache[id];
// vc::cusparseLtEntry entry;

float alpha = 1.0;
Expand Down Expand Up @@ -264,12 +264,12 @@ torch::Tensor cslt_mm_fp8_semi_structured_prepared(
namespace vc = vllm::cusparseLt;
TORCH_CHECK(vc::handle_initialized,
"Call of matmul with unintialized matmul");
// TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued");
// auto id = id_tensor.item<vc::cacheID>();
// if (vc::cusparseLt_cache.count(id) == 0) {
// TORCH_CHECK(false, "cusparse matmul Id is not found");
// }
// const auto& entry = vc::cusparseLt_cache[id];
TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued");
auto id = id_tensor.item<vc::cacheID>();
if (vc::cusparseLt_cache.count(id) == 0) {
TORCH_CHECK(false, "cusparse matmul Id is not found");
}
const auto& entry = vc::cusparseLt_cache[id];

auto res_tensor_options =
c10::TensorOptions().dtype(entry.out_dtype).device(entry.device);
Expand All @@ -287,14 +287,15 @@ torch::Tensor cslt_mm_fp8_semi_structured_prepared(
}

void cslt_fp8_semi_structured_destroy(const torch::Tensor& id_tensor) {
TORCH_CHECK(vllm::cusparseLt::handle_initialized,
namespace vc = vllm::cusparseLt;
TORCH_CHECK(vc::handle_initialized,
"Call of destroy cusparseId with unintialized cusparseLt");
// TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued");
// auto id = id_tensor.item<vc::cacheID>();
// if (vllm::cusparseLt::cusparseLt_cache.count(id) == 0) {
// TORCH_CHECK(false, "cusparse matmul Id is not found");
// }
// auto& entry = vllm::cusparseLt::cusparseLt_cache[id];
TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued");
auto id = id_tensor.item<vc::cacheID>();
if (vc::cusparseLt_cache.count(id) == 0) {
TORCH_CHECK(false, "cusparse matmul Id is not found");
}
auto& entry = vc::cusparseLt_cache[id];

TORCH_CUDASPARSE_CHECK(
cusparseLtMatDescriptorDestroy(&entry.sparse_input_descriptor));
Expand Down

0 comments on commit 9f6a469

Please sign in to comment.