Skip to content

Commit

Permalink
Fix destroy function
Browse files (browse the repository at this point in the history)
  • Loading branch information
ilmarkov committed Nov 13, 2024
1 parent b1aaea5 commit c36401c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
11 changes: 8 additions & 3 deletions benchmarks/cusparseLt_benchmarks/benchmark_24.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import (
compress_to_torch_sparse_semi_structured_mat, dense_matmul, get_random_mat,
is_semi_structured_supported, semi_structured_sparse_dense_gemm)
from vllm._custom_ops import (semi_structured_fp8_prepare_mm, semi_structured_fp8_mm_prepared)
from vllm._custom_ops import (semi_structured_fp8_prepare_mm, semi_structured_fp8_mm_prepared, semi_structured_fp8_destroy)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
Expand Down Expand Up @@ -105,11 +105,12 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,

a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
handle = semi_structured_fp8_prepare_mm(a_compressed.packed, b)
id = torch.tensor([handle], dtype=torch.int64, device='cuda')
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4_prepared",
semi_structured_fp8_mm_prepared,
torch.tensor([handle], dtype=torch.int64, device='cuda')))

id))
semi_structured_fp8_destroy(id)
return timers


Expand All @@ -122,6 +123,10 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(MKNs: Iterable[Tuple[int, int, int]],
use_fp8: bool) -> Iterable[TMeasurement]:
results = []
# MKNs = [(2048, 8192, 14336)]
MKNs = [(32, 11008, 4096)]
# MKNs = [(2048, 11008, 14336)]

for m, k, n in MKNs:
timers = bench(m, k, n, "gemm", f"MKN=({m}x{k}x{n})", use_fp8)
print_timers(timers)
Expand Down
2 changes: 1 addition & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,6 @@ int64_t cslt_prepare_mm_fp8_semi_structured(const torch::Tensor& compressed_A,

torch::Tensor cslt_mm_fp8_semi_structured_prepared(const torch::Tensor& id);

void cslt_fp8_semi_structured_destroy(int64_t id);
void cslt_fp8_semi_structured_destroy(const torch::Tensor& id_tensor);

#endif
10 changes: 6 additions & 4 deletions csrc/quantization/fp8_semi_structured/cusparseLt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,14 @@ torch::Tensor cslt_mm_fp8_semi_structured_prepared(
return res;
}

void cslt_fp8_semi_structured_destroy(vllm::cusparseLt::cacheID id) {
void cslt_fp8_semi_structured_destroy(const torch::Tensor& id_tensor) {
TORCH_CHECK(vllm::cusparseLt::handle_initialized,
"Call of destroy cusparseId with unintialized cusparseLt");
if (vllm::cusparseLt::cusparseLt_cache.count(id) == 0) {
TORCH_CHECK(false, "cusparse matmul Id is not found");
}
// TORCH_CHECK(id_tensor.numel() == 1, "ID has to be single valued");
// auto id = id_tensor.item<vc::cacheID>();
// if (vllm::cusparseLt::cusparseLt_cache.count(id) == 0) {
// TORCH_CHECK(false, "cusparse matmul Id is not found");
// }
// auto& entry = vllm::cusparseLt::cusparseLt_cache[id];

TORCH_CUDASPARSE_CHECK(
Expand Down
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("cslt_mm_fp8_semi_structured_prepared", torch::kCUDA,
&cslt_mm_fp8_semi_structured_prepared);

ops.def("cslt_fp8_semi_structured_destroy(int cacheId) -> ()");
ops.def("cslt_fp8_semi_structured_destroy(Tensor cacheId) -> ()");
ops.impl("cslt_fp8_semi_structured_destroy", torch::kCUDA,
&cslt_fp8_semi_structured_destroy);
#endif
Expand Down

0 comments on commit c36401c

Please sign in to comment.