Commit: Cutlass grouped gemm files

Signed-off-by: ElizaWszola <[email protected]>
ElizaWszola committed Dec 6, 2024
1 parent 2298e69 · commit 1825ef8
Showing 7 changed files with 514 additions and 3 deletions.
9 changes: 6 additions & 3 deletions CMakeLists.txt
@@ -209,13 +209,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
     cutlass
     GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-    GIT_TAG v3.5.1
+    # GIT_TAG v3.5.1
+    GIT_TAG dbdae514e03f83968f8b7dd4fb064071b9bfbdd1
     GIT_PROGRESS TRUE

     # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
     # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
     # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-    GIT_SHALLOW TRUE
+    GIT_SHALLOW FALSE
   )
   FetchContent_MakeAvailable(cutlass)

@@ -261,7 +262,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
+    set(SRCS
+      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+      "csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
7 changes: 7 additions & 0 deletions csrc/cpu/torch_bindings.cpp
@@ -118,6 +118,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
// CUTLASS w8a8 grouped GEMM // TODO complete this
ops.def(
"cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, "
" Tensor b_scales, Tensor problem_sizes, "
" Tensor out_offsets, Tensor a_offsets, "
" Tensor b_offsets) -> ()");
ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
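
For orientation, a minimal sketch of how the op registered above could be invoked from C++ through the PyTorch dispatcher. This is a hypothetical caller, not part of the commit: the "_C" namespace (the usual TORCH_EXTENSION_NAME in vLLM) and the use of torch::Tensor& for the mutable "Tensor! out" argument are assumptions.

    #include <ATen/core/dispatch/Dispatcher.h>
    #include <torch/torch.h>

    // Hypothetical caller: looks up the schema defined by ops.def() above and
    // dispatches with the signature that ops.impl() binds. The "_C" extension
    // namespace is an assumption.
    void call_cutlass_grouped_mm(
        torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b,
        torch::Tensor const& a_scales, torch::Tensor const& b_scales,
        torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets,
        torch::Tensor const& a_offsets, torch::Tensor const& b_offsets) {
      static auto op =
          c10::Dispatcher::singleton()
              .findSchemaOrThrow("_C::cutlass_grouped_mm", "")
              .typed<void(torch::Tensor&, torch::Tensor const&,
                          torch::Tensor const&, torch::Tensor const&,
                          torch::Tensor const&, torch::Tensor const&,
                          torch::Tensor const&, torch::Tensor const&,
                          torch::Tensor const&)>();
      op.call(out, a, b, a_scales, b_scales, problem_sizes, out_offsets,
              a_offsets, b_offsets);
    }
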
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -145,6 +145,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b_scales,
                        c10::optional<torch::Tensor> const& bias);

+void cutlass_grouped_mm(torch::Tensor& out, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        torch::Tensor const& problem_sizes,
+                        torch::Tensor const& out_offsets,
+                        torch::Tensor const& a_offsets,
+                        torch::Tensor const& b_offsets);
+
 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
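
To make the new signature concrete, a usage sketch under stated assumptions: per-group operands concatenated into single tensors and located via element offsets, int8 inputs with fp32 scales, and problem_sizes holding one (m, n, k) row per group. None of these conventions are documented by this commit; treat every dtype and layout choice below as illustrative only.

    #include <torch/torch.h>

    #include "ops.h"  // declares cutlass_grouped_mm (added by this commit)

    int main() {
      // Two equally sized groups, purely illustrative.
      int64_t const m = 16, n = 32, k = 64;
      auto i8  = torch::dtype(torch::kInt8).device(torch::kCUDA);
      auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
      auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);

      // Assumed layout: groups concatenated along the row dimension.
      auto a   = torch::randint(-8, 8, {2 * m, k}, i8);
      auto b   = torch::randint(-8, 8, {2 * k, n}, i8);
      auto out = torch::empty({2 * m, n}, f32);  // output dtype assumed

      auto a_scales = torch::ones({2}, f32);
      auto b_scales = torch::ones({2}, f32);

      // One (m, n, k) triple per group; offsets in elements (assumed).
      auto problem_sizes = torch::tensor({{m, n, k}, {m, n, k}}, i32);
      auto out_offsets   = torch::tensor({int64_t{0}, m * n}, i32);
      auto a_offsets     = torch::tensor({int64_t{0}, m * k}, i32);
      auto b_offsets     = torch::tensor({int64_t{0}, k * n}, i32);

      cutlass_grouped_mm(out, a, b, a_scales, b_scales, problem_sizes,
                         out_offsets, a_offsets, b_offsets);
      return 0;
    }
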
Diffs for the remaining 4 changed files are not shown.
