From 1ea7874a62b726c9a149b9b5ba6e0e8ddc701f27 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Tue, 28 Jan 2025 04:42:59 +0000
Subject: [PATCH] cleanup, add import

Signed-off-by: ElizaWszola
---
 csrc/cpu/torch_bindings.cpp                   |  15 +-
 csrc/ops.h                                    |   9 +-
 .../cutlass_w8a8/grouped_mm_c3x.cu            | 178 ++++++------
 csrc/torch_bindings.cpp                       |   3 +-
 tests/kernels/test_cutlass.py                 |  36 ++--
 5 files changed, 81 insertions(+), 160 deletions(-)

diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 2de5e595c64b2..bbe6d2e8652d3 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -119,19 +119,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "                  Tensor b_scales, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
 
   // CUTLASS w8a8 grouped GEMM // TODO complete this
-// ops.def(
-//     "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, "
-//     "        Tensor b_scales, Tensor problem_sizes, "
-//     "        Tensor out_offsets, Tensor a_offsets, "
-//     "        Tensor b_offsets) -> ()");
-// ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);
+  // ops.def(
+  //     "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales,
+  //     " "        Tensor b_scales, Tensor problem_sizes, " "
+  //     Tensor out_offsets, Tensor a_offsets, " "        Tensor
+  //     b_offsets) -> ()");
+  // ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);
   ops.def(
       "compute_expert_offsets(Tensor! trg_a_ptrs,"
       "                       Tensor! a, Tensor topk_ids,"
       "                       Tensor! expert_offsets, SymInt num_experts) -> ()");
-  ops.impl("compute_expert_offsets", torch::kCUDA,
-           &compute_expert_offsets);
+  ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets);
 
   // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
   // quantization.
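Note (illustration, not part of the patch): the only contract the Python side sees for the op registered above is its schema — `compute_expert_offsets` takes a per-expert pointer table, the fp8 activations, the router's top-k expert ids, and an offsets tensor, and fills the two `Tensor!` arguments in place. A rough sketch of a caller is shown below; the shapes and dtypes are assumptions inferred from the kernel signature later in this patch (int32 `topk_ids`, int64 pointer/offset tables), not verified requirements.

import torch

# Illustrative sizes only; in vLLM these would come from the MoE router.
num_experts = 8
m, k = 256, 512
topk_ids = torch.randint(0, num_experts, (m,), dtype=torch.int32, device="cuda")
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)  # fp8 activations

# "Tensor!" arguments are written in place: a per-expert pointer table into `a`
# and a cumulative row-offset table (int64 assumed, matching the kernel's int64_t*).
trg_a_ptrs = torch.empty(num_experts, dtype=torch.int64, device="cuda")
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int64, device="cuda")

torch.ops._C.compute_expert_offsets(trg_a_ptrs, a, topk_ids, expert_offsets,
                                    num_experts)
# Conceptually: expert_offsets ~ [0] + cumsum(bincount(topk_ids)), and
# trg_a_ptrs[e] ~ a.data_ptr() + expert_offsets[e] * k (fp8 is 1 byte/element).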
diff --git a/csrc/ops.h b/csrc/ops.h
index 394fb172a476b..a67d7d757f3a7 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -167,11 +167,10 @@ void cutlass_grouped_mm(torch::Tensor& out_tensors,
                         torch::Tensor const& expert_offsets,
                         torch::Tensor const& problem_sizes);
 
-void compute_expert_offsets(torch::Tensor& trg_a_ptrs,
-                            torch::Tensor& a,
-                            const torch::Tensor& topk_ids,
-                            torch::Tensor& expert_offsets,
-                            const int64_t num_experts);
+void compute_expert_offsets(torch::Tensor& trg_a_ptrs, torch::Tensor& a,
+                            const torch::Tensor& topk_ids,
+                            torch::Tensor& expert_offsets,
+                            const int64_t num_experts);
 
 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
diff --git a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
index 957ef0cf06779..8ab15fd7be00a 100644
--- a/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
@@ -80,8 +80,6 @@ struct cutlass_3x_group_gemm {
   const int AlignmentC = 128 / cutlass::sizeof_bits::value;
 
   using EVTCompute = typename Epilogue::EVTCompute;
 
-  // the orig hat cutlass::epilogue::fusion::LinearCombination
-
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -143,88 +141,44 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
   bool per_act_token = a_scales.numel() != groups;
   bool per_out_ch = b_scales.numel() != groups;
 
-  // TORCH_CHECK((int)b_tensors.size() == groups,
-  //             "Number of B tensors must match number of groups.");
-  // TORCH_CHECK((int)out_tensors.size() == groups,
-  //             "Number of output tensors must match number of groups.");
-
-  // std::vector a_ptrs_host(groups);
-  // std::vector b_ptrs_host(groups);
-  // std::vector c_ptrs_host(groups);
-  // std::vector d_ptrs_host(groups);
-  // std::vector a_scales_ptrs_host(groups);
-  // std::vector b_scales_ptrs_host(groups);
-
-  // std::vector problem_sizes_host;
-  // problem_sizes_host.reserve(groups);
 
   int b_single_size = k_size * n_size;
   int b_scale_single_size = per_out_ch ? out_tensors.size(1) : 1;
 
   auto options_int =
       torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
 
-  torch::Tensor a_ptrs_base = torch::full({groups + 1},
-                                          (int64_t)a_tensors.data_ptr(),
-                                          options_int);
-  torch::Tensor out_ptrs_base = torch::full({groups + 1},
-                                            (int64_t)out_tensors.data_ptr(),
-                                            options_int);
-  torch::Tensor b_ptrs_base = torch::full({groups + 1},
-                                          (int64_t)b_tensors.data_ptr(),
-                                          options_int);
-  torch::Tensor a_scales_base = torch::full({groups + 1},
-                                            (int64_t)a_scales.data_ptr(),
-                                            options_int);
-  torch::Tensor b_scales_base = torch::full({groups + 1},
-                                            (int64_t)b_scales.data_ptr(),
-                                            options_int);
+  torch::Tensor a_ptrs_base =
+      torch::full({groups + 1}, reinterpret_cast<int64_t>(a_tensors.data_ptr()),
+                  options_int);
+  torch::Tensor out_ptrs_base = torch::full(
+      {groups + 1}, reinterpret_cast<int64_t>(out_tensors.data_ptr()),
+      options_int);
+  torch::Tensor b_ptrs_base =
+      torch::full({groups + 1}, reinterpret_cast<int64_t>(b_tensors.data_ptr()),
+                  options_int);
+  torch::Tensor a_scales_base =
+      torch::full({groups + 1}, reinterpret_cast<int64_t>(a_scales.data_ptr()),
+                  options_int);
+  torch::Tensor b_scales_base =
+      torch::full({groups + 1}, reinterpret_cast<int64_t>(b_scales.data_ptr()),
+                  options_int);
 
   torch::Tensor b_offsets = torch::arange(0, b_single_size * (groups + 1),
-                                      b_single_size, options_int);
+                                          b_single_size, options_int);
   torch::Tensor a_scales_offsets = torch::arange(0, groups + 1, options_int);
-  torch::Tensor b_scales_offsets = torch::arange(0, b_scale_single_size *
-                                      (groups + 1), b_scale_single_size,
-                                      options_int);
-
-  // multiply by offset of k 8-bit elements
-  torch::Tensor a_ptrs = a_ptrs_base.add(expert_offsets, a_tensors.size(1));
-  // multiply by offset of n 16-bit elements
-  torch::Tensor out_ptrs = out_ptrs_base.add(expert_offsets, 2 * out_tensors.size(1));
-  // multiply by offset of n 8-bit elements
-  torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets);
-
-  torch::Tensor a_scales_ptrs = a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, 4);
-  torch::Tensor b_scales_ptrs = b_scales_base.add(b_scales_offsets, 4);
-
-  // for (int g = 0; g < groups; ++g) {
-  //   // b_ptrs_host[g] =
-  //   //     reinterpret_cast(b_list[g].data_ptr());
-  //   // c_ptrs_host[g] =
-  //   //     reinterpret_cast(out_tensors[g].data_ptr());
-  //   // d_ptrs_host[g] = reinterpret_cast(out_tensors[g].data_ptr());
-  //   // a_scales_ptrs_host[g] =
-  //   //     reinterpret_cast(a_scales[g].data_ptr());
-  //   // b_scales_ptrs_host[g] =
-  //   //     reinterpret_cast(b_scales[g].data_ptr());
-
-  //   // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g],
-  //   //        c_ptrs_host[g], d_ptrs_host[g],)
-  //   // int64_t m = out_tensors[g].size(0);
-  //   // int64_t k = a_tensors.size(1);
-
-  //   // int64_t k_b = b_tensors[g].size(0);
-  //   // int64_t n = b_tensors[g].size(1);
-
-  //   // TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k,
-  //   //             " while B has k=", k_b);
-
-  //   // // Optionally, verify output shape matches (m,n)
-  //   // TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n,
-  //   //             "Output tensor shape does not match m,n from A,B: ", "Got ",
-  //   //             out_tensors[g].sizes(), " expected (", m, ", ", n, ")");
-
-  //   // problem_sizes_host.push_back({(int)m, (int)n, (int)k});
-  // }
+  torch::Tensor b_scales_offsets = torch::arange(
+      0, b_scale_single_size * (groups + 1), b_scale_single_size, options_int);
+
+  torch::Tensor a_ptrs = a_ptrs_base.add(
+      expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1));
+  torch::Tensor out_ptrs = out_ptrs_base.add(
+      expert_offsets, sizeof(ElementC_Type) * out_tensors.size(1));
+  torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets, sizeof(ElementAB_Type));
+
+  torch::Tensor a_scales_ptrs =
+      a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets,
+                        sizeof(ElementAccumulator));
+  torch::Tensor b_scales_ptrs =
+      b_scales_base.add(b_scales_offsets, sizeof(ElementAccumulator));
 
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = Stride<int64_t, Int<1>, Int<0>>;
@@ -239,7 +193,6 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
     int64_t lda = a_tensors.stride(0);    // row-major (m x k)
     int64_t ldb = a_tensors.stride(0);    // column-major (k x n)
     int64_t ldc = out_tensors.stride(0);  // row-major (m x n)
-    printf("strides: %ld %ld %ld\n", lda, ldb, ldc);
 
     a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}};
     b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}};
@@ -252,49 +205,33 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
       cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
           hw_info.device_id);
 
-  // auto problem_sizes_ptr = make_device_ptr(problem_sizes_host);
-  // ProblemShape prob_shape{groups, problem_sizes_ptr.get(),
-  //                         problem_sizes_host.data()};
   ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
       reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
-        problem_sizes.data_ptr());
+          problem_sizes.data_ptr());
   ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr};
 
-  // auto a_ptrs_ptr = make_device_ptr(a_ptrs_host);
-  // auto b_ptrs_ptr = make_device_ptr(b_ptrs_host);
-  // auto c_ptrs_ptr = make_device_ptr(c_ptrs_host);
-  // auto d_ptrs_ptr = make_device_ptr(d_ptrs_host);
-
-  // auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host);
-  // auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host);
-
   auto a_stride_ptr = make_device_ptr(a_stride_host);
   auto b_stride_ptr = make_device_ptr(b_stride_host);
   auto c_stride_ptr = make_device_ptr(c_stride_host);
 
-  // auto c_ptrs_ptr = make_device_ptr(c_ptrs_host);
-  // auto d_ptrs_ptr = make_device_ptr(d_ptrs_host);
-
   typename GemmKernel::MainloopArguments mainloop_args{
-      (const ElementAB_Type**)a_ptrs.data_ptr(), a_stride_ptr.get(),
-      (const ElementAB_Type**)b_ptrs.data_ptr(), b_stride_ptr.get()};
+      reinterpret_cast<const ElementAB_Type**>(a_ptrs.data_ptr()),
+      a_stride_ptr.get(),
+      reinterpret_cast<const ElementAB_Type**>(b_ptrs.data_ptr()),
+      b_stride_ptr.get()};
+
   // Currently, we are only able to do broadcast on either all or none a_scales
   // and on either all or none b_scales
   typename GemmKernel::EpilogueArguments epilogue_args{
-      Gemm::Epilogue::prepare_args(
-          (const ElementAccumulator**)a_scales_ptrs.data_ptr(),
-          (const ElementAccumulator**)b_scales_ptrs.data_ptr(),
-          per_act_token, per_out_ch),
-      (const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(),
-      (ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()};
-
-  // typename GemmKernel::EpilogueArguments epilogue_args{
-  //     Gemm::Epilogue::prepare_args(
-  //         (const ElementAccumulator**)a_scales_ptrs.data_ptr(),
-  //         b_scales_ptrs_ptr.get(),
-  //         per_act_token, per_out_ch),
-  //     (const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(),
-  //     (ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()};
+      Gemm::Epilogue::prepare_args(reinterpret_cast<const ElementAccumulator**>(
+                                       a_scales_ptrs.data_ptr()),
+                                   reinterpret_cast<const ElementAccumulator**>(
+                                       b_scales_ptrs.data_ptr()),
+                                   per_act_token, per_out_ch),
+      reinterpret_cast<const ElementC_Type**>(out_ptrs.data_ptr()),
+      c_stride_ptr.get(),
+      reinterpret_cast<ElementC_Type**>(out_ptrs.data_ptr()),
+      c_stride_ptr.get()};
 
   typename GemmKernel::Arguments args{
       cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
@@ -309,11 +246,9 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
       torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
 
-  // printf("before: %d\n", out_tensors[0][0]);
   auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
 
   cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
-  // printf("after: %d\n", out_tensors[0][0]);
 }
 
 template typename Epilogue>
 struct sm90_fp8_config_M64 {
-  // M in [1, 64]
   static_assert(std::is_same());
   using KernelSchedule =
       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
@@ -394,8 +328,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors,
 __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
                                      cutlass::float_e4m3_t* base_a_ptr,
                                      const int* __restrict__ topk_ids,
-                                     int64_t* expert_offsets,
-                                     int topk_length) {
+                                     int64_t* expert_offsets, int topk_length) {
   int expert_id = threadIdx.x;
   int num_experts = blockDim.x;
 
@@ -419,10 +352,12 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
 
 // // For a given "a" of size [M,K] performs a permutation of the M rows based
 // // on the given "perm" indices.
-// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr,
+// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const*
+// __restrict__ a_ptr,
 //                                         int const* __restrict__ perm_int_ptr,
-//                                         cutlass::float_e4m3_t* __restrict__ out_ptr,
-//                                         int size_m, int size_k, int block_rows) {
+//                                         cutlass::float_e4m3_t* __restrict__
+//                                         out_ptr, int size_m, int size_k, int
+//                                         block_rows) {
 //   int start_row = block_rows * blockIdx.x;
 //   int finish_row = start_row + block_rows;
 //   if (finish_row > size_m) {
@@ -459,17 +394,14 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
 //   };
 // }
 
-void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs,
-                                   torch::Tensor& a,
+void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, torch::Tensor& a,
                                    const torch::Tensor& topk_ids,
                                    torch::Tensor& expert_offsets,
                                    const int64_t num_experts) {
-  get_a_expert_offsets<<<1, num_experts>>>(
+  get_a_expert_offsets<<<1, num_experts>>>(
       (cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(),
-      (cutlass::float_e4m3_t*)a.data_ptr(),
-      (const int*)topk_ids.data_ptr(),
-      (int64_t*)expert_offsets.data_ptr(),
-      topk_ids.numel());
+      (cutlass::float_e4m3_t*)a.data_ptr(), (const int*)topk_ids.data_ptr(),
+      (int64_t*)expert_offsets.data_ptr(), topk_ids.numel());
 }
 
 // void permute_fp8_rows(torch::Tensor& a_ptr,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 2fdb13307424a..81a97c3887bda 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -337,8 +337,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "compute_expert_offsets(Tensor! trg_a_ptrs,"
       "                       Tensor! a, Tensor topk_ids,"
       "                       Tensor! expert_offsets, SymInt num_experts) -> ()");
-  ops.impl("compute_expert_offsets", torch::kCUDA,
-           &compute_expert_offsets);
+  ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets);
 
   // Check if cutlass sparse scaled_mm is supported for CUDA devices of the
   // given capability
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index 0b71c5a1e6a45..a8c8630a04487 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -2,12 +2,13 @@
 
 Run `pytest tests/kernels/test_cutlass.py`.
""" +import random from typing import Type import pytest import torch -from tests.kernels.utils import opcheck, stack_and_dev +from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -452,7 +453,6 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, b_tensors = [] a_scales_tensors = [] b_scales_tensors = [] - out_tensors = [] baseline_tensors = [] expert_offsets = torch.zeros((num_groups + 1), @@ -460,15 +460,13 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, dtype=torch.int32) problem_sizes = torch.zeros((num_groups, 3), - device=device, - dtype=torch.int32) + device=device, + dtype=torch.int32) alignment = 16 # 128 // 8 # For variation, each group has dimensions - # (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1)) n_g = alignment * random.randint(1, 64) k_g = alignment * random.randint(1, 64) - # one_b = to_fp8(torch.randn((n_g, k_g), device=device)) for g in range(num_groups): m_g = alignment * random.randint(1, 64) @@ -485,8 +483,6 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, # Create group-specific A and B (FP8) and output (FP16/FP32) a_g = to_fp8(torch.randn((m_g, k_g), device=device)) b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) - # b_g = one_b.clone().t() - c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype) # Set up A/B scales scale_a = torch.randn((m_a_scales, 1), device=device, @@ -497,12 +493,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, a_tensors.append(a_g) b_tensors.append(b_g) - out_tensors.append(c_g) a_scales_tensors.append(scale_a) b_scales_tensors.append(scale_b) - print(b_g.stride()) - # Compute baseline result for this group baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None) @@ -517,7 +510,7 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, for g in range(num_groups): a_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] = a_tensors[g] - b_tensors_stacked[g*n_g:(g+1)*n_g, :] = b_tensors[g].t() + b_tensors_stacked[g * n_g:(g + 1) * n_g, :] = b_tensors[g].t() b_tensors_stacked = b_tensors_stacked.t() a_scales_tensors_stacked = torch.empty( @@ -526,28 +519,27 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool, dtype=torch.float32) if per_act_token: for g in range(num_groups): - a_scales_tensors_stacked[expert_offsets[g]:expert_offsets[g + - 1]] = a_scales_tensors[g] + a_scales_tensors_stacked[ + expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] else: for g in range(num_groups): a_scales_tensors_stacked[g] = a_scales_tensors[g] - b_scales_tensors_stacked = torch.empty( - (num_groups, n_b_scales), - device=device, - dtype=torch.float32) + b_scales_tensors_stacked = torch.empty((num_groups, n_b_scales), + device=device, + dtype=torch.float32) for g in range(num_groups): b_scales_tensors_stacked[g] = b_scales_tensors[g] out_tensors_stacked = torch.zeros((expert_offsets[num_groups], n_g), - device=device, - dtype=out_dtype) + device=device, + dtype=out_dtype) torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked, b_tensors_stacked, a_scales_tensors_stacked, - b_scales_tensors_stacked, - expert_offsets, problem_sizes) + b_scales_tensors_stacked, expert_offsets, + problem_sizes) # Validate each group's result against the baseline for g in range(num_groups):