diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index d6c32322ff592..80a326cdc5ef4 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -118,7 +118,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); -// CUTLASS w8a8 grouped GEMM // TODO complete this + // CUTLASS w8a8 grouped GEMM // TODO complete this ops.def( "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, " " Tensor b_scales, Tensor problem_sizes, " diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu index 004599c2b5d26..db86bd1a4b466 100644 --- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu +++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu @@ -33,12 +33,12 @@ using namespace cute; #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 -#define ENABLE_SM90_KERNEL_LEVEL 1 + #define ENABLE_SM90_KERNEL_LEVEL 1 #endif namespace { - // A wrapper for the GEMM kernel that is used to guard against compilation on +// A wrapper for the GEMM kernel that is used to guard against compilation on // architectures that will never use the kernel. The purpose of this is to // reduce the size of the compiled binary. // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef @@ -47,32 +47,36 @@ template struct enable_sm90_or_later : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); - #endif +#endif } }; -using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group -using ElementAB_Type = cutlass::float_e4m3_t; // Element type for A matrix operand -// using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using ProblemShape = + cutlass::gemm::GroupProblemShape>; // + // per group +using ElementAB_Type = + cutlass::float_e4m3_t; // Element type for A matrix operand +// using ElementB = cutlass::float_e4m3_t; // +// Element type for B matrix operand using ElementC_Type = cutlass::half_t; // Core kernel configurations -using ElementAccumulator = float; // Element type for internal accumulation -using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature -using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::ColumnMajor; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_group_gemm { - using ElementAB = ElementAB_; using ElementC = ElementC_; using ElementAccumulator = float; @@ -84,42 +88,36 @@ struct cutlass_3x_group_gemm { using Epilogue = Epilogue_; - using StrideC = cute::remove_pointer_t, cute::Int<0>>>; + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; - const int AlignmentAB = 128 / 
cutlass::sizeof_bits::value; - const int AlignmentC = 128 / cutlass::sizeof_bits::value; + const int AlignmentAB = 128 / cutlass::sizeof_bits::value; + const int AlignmentC = 128 / cutlass::sizeof_bits::value; using EVTCompute = typename Epilogue::EVTCompute; - // the orig hat cutlass::epilogue::fusion::LinearCombination - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, - TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementAccumulator, - ElementC, LayoutC*, 4, - ElementC, LayoutC*, 4, - EpilogueSchedule, EVTCompute - >::CollectiveOp; + // the orig hat cutlass::epilogue::fusion::LinearCombination + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, 4, ElementC, LayoutC*, 4, + EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< static_cast(CEStorageSize)>; -using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, - ElementAB, LayoutA*, 16, - ElementAB, LayoutB*, 16, - ElementAccumulator, - TileShape, ClusterShape, - Stages, KernelSchedule - >::CollectiveOp; + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, 16, ElementAB, LayoutB*, + 16, ElementAccumulator, TileShape, ClusterShape, Stages, + KernelSchedule>::CollectiveOp; using KernelType = enable_sm90_or_later>; + ProblemShape, CollectiveMainloop, CollectiveEpilogue>>; struct GemmKernel : public KernelType {}; }; @@ -127,20 +125,19 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder template struct ItemDeleter { void operator()(T* ptr) { - cudaFree(ptr); // noexcept + cudaFree(ptr); // noexcept } }; template void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets, - EpilogueArgs&&... epilogue_params) { + torch::Tensor const& b, + torch::Tensor const& problem_sizes, + torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, + torch::Tensor const& b_offsets, + EpilogueArgs&&... 
epilogue_params) { using ElementAB = typename Gemm::ElementAB; - // using ElementC = typename Gemm::ElementC; using ElementC = typename Gemm::ElementC; using ElementAcc = float; @@ -151,43 +148,48 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, std::vector d_ptrs_host(groups); for (int g = 0; g < groups; ++g) { - a_ptrs_host.at(g) = (ElementAB*)a.data_ptr() + a_offsets[g].item(); - b_ptrs_host.at(g) = (ElementAB*)b.data_ptr() + b_offsets[g].item(); - c_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); - d_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item(); + a_ptrs_host.at(g) = + static_cast(a.data_ptr()) + a_offsets[g].item(); + b_ptrs_host.at(g) = + static_cast(b.data_ptr()) + b_offsets[g].item(); + c_ptrs_host.at(g) = + static_cast(out.data_ptr()) + out_offsets[g].item(); + d_ptrs_host.at(g) = + static_cast(out.data_ptr()) + out_offsets[g].item(); + printf("%d %d %d\n", a_offsets[g].item(), + b_offsets[g].item(), out_offsets[g].item()); } - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - std::vector a_stride_host(groups, StrideA{lda, cute::Int<1>{}, cute::Int<0>{}}); - std::vector b_stride_host(groups, StrideB{ldb, cute::Int<1>{}, cute::Int<0>{}}); - // TODO fix - std::vector c_stride_host(groups, StrideC{cute::Int<1>{}, ldc, cute::Int<0>{}}); + using GemmKernel = typename Gemm::GemmKernel; - cutlass::platform::unique_ptr stride_A; - cutlass::platform::unique_ptr stride_B; - cutlass::platform::unique_ptr stride_C; - cutlass::platform::unique_ptr stride_D; + using StrideA = typename GemmKernel::InternalStrideA; + using StrideB = typename GemmKernel::InternalStrideB; + using StrideC = typename GemmKernel::InternalStrideC; + // using StrideD = typename GemmKernel::InternalStrideD; - cutlass::platform::unique_ptr ptr_A; - cutlass::platform::unique_ptr ptr_B; - cutlass::platform::unique_ptr ptr_C; - cutlass::platform::unique_ptr ptr_D; + std::vector a_stride_host(groups); + std::vector b_stride_host(groups); + std::vector c_stride_host(groups); - using GemmKernel = typename Gemm::GemmKernel; + for (int g = 0; g < groups; ++g) { + int32_t m = problem_sizes[g][0].item(); + int32_t n = problem_sizes[g][1].item(); + int32_t k = problem_sizes[g][2].item(); + a_stride_host[g] = StrideA{k, cute::Int<1>{}, cute::Int<0>{}}; // m x k, + // row + b_stride_host[g] = StrideB{k, cute::Int<1>{}, cute::Int<0>{}}; // k x n, + // col + c_stride_host[g] = StrideC{n, cute::Int<1>{}, cute::Int<0>{}}; // m x n, + // row + } cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with multiple GPUs and wish - // to use a GPU other than that with device ID 0. + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. 
hw_info.device_id = 0; - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); using SingleProblemShape = typename ProblemShape::UnderlyingProblemShape; @@ -203,76 +205,83 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, SingleProblemShape* problem_sizes_device; int32_t problem_sizes_size = groups * sizeof(SingleProblemShape); cudaMalloc(&problem_sizes_device, problem_sizes_size); - cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), groups, - cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> problem_sizes_ptr( - problem_sizes_device); - ProblemShape prob_shape{groups, problem_sizes_ptr.get(), problem_sizes_host.data()}; + cudaMemcpy(problem_sizes_device, problem_sizes_host.data(), + groups * sizeof(SingleProblemShape), cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + problem_sizes_ptr(problem_sizes_device); + ProblemShape prob_shape{groups, problem_sizes_ptr.get(), + problem_sizes_host.data()}; + + // ElementAB* a_host_print; + // int numel = a.numel(); + // cudaMalloc(&a_host_print, groups * sizeof(ElementAB)); + // cudaMemcpy(a_host_print, static_cast(a.data_ptr()), numel* + // sizeof(ElementAB), cudaMemcpyDeviceToHost); + // cudaMemcpy(static_cast(a.data_ptr()), a_host_print, numel* + // sizeof(ElementAB), cudaMemcpyHostToDevice); cudaFree(a_host_print); const ElementAB** a_ptrs_device; cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> a_ptrs_ptr( - a_ptrs_device - ); + cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups * sizeof(ElementAB*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + a_ptrs_ptr(a_ptrs_device); const ElementAB** b_ptrs_device; cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*)); - cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> b_ptrs_ptr( - b_ptrs_device - ); + cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups * sizeof(ElementAB*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + b_ptrs_ptr(b_ptrs_device); const ElementC** c_ptrs_device; cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups, cudaMemcpyHostToDevice); - cutlass::platform::unique_ptr> c_ptrs_ptr( - c_ptrs_device - ); + cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups * sizeof(ElementC*), + cudaMemcpyHostToDevice); + cutlass::platform::unique_ptr> + c_ptrs_ptr(c_ptrs_device); - // TODO if we start with empty values here, no need to copy ElementC** d_ptrs_device; cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*)); - cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups, cudaMemcpyHostToDevice); + cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups * sizeof(ElementC*), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> d_ptrs_ptr( - d_ptrs_device - ); + d_ptrs_device); StrideA* a_stride_device; - cudaMalloc(&a_stride_device, groups * sizeof(StrideA*)); - cudaMemcpy(a_stride_device, a_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&a_stride_device, groups * sizeof(StrideA)); + cudaMemcpy(a_stride_device, a_stride_host.data(), groups * sizeof(StrideA), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> a_stride_ptr( - a_stride_device - ); + a_stride_device); StrideB* b_stride_device; 
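// Aside: the per-group uploads in this function all follow one pattern
// (cudaMalloc a device array, cudaMemcpy the host vector into it, then hand
// the raw pointer to a unique_ptr whose deleter calls cudaFree); the same
// steps continue below for b_stride_device and c_stride_device. A minimal
// sketch of a helper that factors this pattern out is shown here for
// reference. CudaFreeDeleter, device_array and to_device_array are
// illustrative names invented for this sketch; they are not part of this
// patch or of CUTLASS, and error checking is elided just as it is in the
// surrounding code.

#include <cuda_runtime.h>
#include <memory>
#include <vector>

struct CudaFreeDeleter {
  void operator()(void* ptr) const { cudaFree(ptr); }
};

template <typename T>
using device_array = std::unique_ptr<T[], CudaFreeDeleter>;

// Copy a host vector into freshly allocated device memory and return an
// owning pointer, so every per-group pointer/stride array is uploaded the
// same way.
template <typename T>
device_array<T> to_device_array(std::vector<T> const& host) {
  T* dev = nullptr;
  cudaMalloc(&dev, host.size() * sizeof(T));
  cudaMemcpy(dev, host.data(), host.size() * sizeof(T),
             cudaMemcpyHostToDevice);
  return device_array<T>(dev);
}

// e.g. auto b_stride_ptr = to_device_array(b_stride_host);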
- cudaMalloc(&b_stride_device, groups * sizeof(StrideB*)); - cudaMemcpy(b_stride_device, b_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&b_stride_device, groups * sizeof(StrideB)); + cudaMemcpy(b_stride_device, b_stride_host.data(), groups * sizeof(StrideB), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> b_stride_ptr( - b_stride_device - ); + b_stride_device); StrideC* c_stride_device; - cudaMalloc(&c_stride_device, groups * sizeof(StrideC*)); - cudaMemcpy(c_stride_device, c_stride_host.data(), groups, cudaMemcpyHostToDevice); + cudaMalloc(&c_stride_device, groups * sizeof(StrideC)); + cudaMemcpy(c_stride_device, c_stride_host.data(), groups * sizeof(StrideC), + cudaMemcpyHostToDevice); cutlass::platform::unique_ptr> c_stride_ptr( - c_stride_device - ); + c_stride_device); typename GemmKernel::MainloopArguments mainloop_args{ - a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()}; + a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), + b_stride_ptr.get()}; typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()}; + c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), + c_stride_ptr.get()}; typename GemmKernel::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - prob_shape, - mainloop_args, - epilogue_args, - hw_info - }; + cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, + epilogue_args, hw_info}; // Launch the CUTLASS GEMM kernel. using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; @@ -284,18 +293,14 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace = torch::empty(workspace_size, workspace_options); - // // auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr())); - - cutlass::Status status = gemm_op.run(); + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); - } // typedef InType = cutlass::float_e4m3_t; // typedef OutType = torch::half; -// typedef Epilogue = ScaledEpilogueBias; template typename Epilogue> @@ -304,12 +309,13 @@ struct sm90_fp8_config_default { static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; template ()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; template ()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + 
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule>; }; -} +} // namespace // TODO hardcode types here? -void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets) { - +void cutlass_grouped_mm_sm90( + torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, torch::Tensor const& b_offsets) { TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); // int32_t m = a.size(1); - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_fp8_config_default< + ElementAB_Type, ElementC_Type, vllm::c3x::ScaledEpilogue>::Cutlass3xGemm; // using Cutlass3xGemmM64 = - // typename sm90_fp8_config_M64::Cutlass3xGemm; + // typename sm90_fp8_config_M64::Cutlass3xGemm; // using Cutlass3xGemmM128 = - // typename sm90_fp8_config_M128::Cutlass3xGemm; - + // typename sm90_fp8_config_M128::Cutlass3xGemm; // // uint32_t const m = a.size(0); // uint32_t const mp2 = @@ -373,14 +378,16 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, // if (mp2 <= 64) { // // m in [1, 64] - // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // cutlass_group_gemm_caller(out, a, b, a_scales, + // b_scales); // } else if (mp2 <= 128) { // // m in (64, 128] - // cutlass_group_gemm_caller(out, a, b, a_scales, b_scales); + // cutlass_group_gemm_caller(out, a, b, a_scales, + // b_scales); // } else { // // m in (128, inf) - cutlass_group_gemm_caller(out, a, b, problem_sizes, - out_offsets, a_offsets, b_offsets, a_scales, b_scales); + cutlass_group_gemm_caller( + out, a, b, problem_sizes, out_offsets, a_offsets, b_offsets, a_scales, + b_scales); // } - } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 78225f9b0db0a..961437893dee0 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -28,13 +28,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& problem_sizes, - torch::Tensor const& out_offsets, - torch::Tensor const& a_offsets, - torch::Tensor const& b_offsets); +void cutlass_grouped_mm_sm90( + torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& problem_sizes, torch::Tensor const& out_offsets, + torch::Tensor const& a_offsets, torch::Tensor const& b_offsets); #endif diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index a97c8f307df32..563a3f433d98b 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -2,11 +2,11 @@ Run `pytest tests/kernels/test_cutlass.py`. 
""" +import random from typing import Optional, Type import pytest import torch -import random from tests.kernels.utils import opcheck from vllm import _custom_ops as ops @@ -455,41 +455,43 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): def test_cutlass_support_opcheck(): opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + # TODO fix scales @pytest.mark.parametrize("m,n,k", [(2048, 2048, 2048)]) @pytest.mark.parametrize("num_groups", [10]) -@pytest.mark.parametrize("per_act_token", [False])# [True, False]) -@pytest.mark.parametrize("per_out_ch", [True])# [True, False]) -@pytest.mark.parametrize("use_bias", [False])# [True, False]) +@pytest.mark.parametrize("per_act_token", [False]) # [True, False]) +@pytest.mark.parametrize("per_out_ch", [True]) # [True, False]) +@pytest.mark.parametrize("use_bias", [False]) # [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, - per_act_token: bool, - per_out_ch: bool, use_bias: bool): + per_act_token: bool, per_out_ch: bool, + use_bias: bool): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. device = "cuda" out_dtype = torch.half - alignment = 16 # 128 // 8 + alignment = 16 # 128 // 8 problem_sizes = torch.empty((num_groups, 3), device="cpu") - offsets_a = torch.empty((num_groups), device="cpu") - offsets_b = torch.empty((num_groups), device="cpu") - offsets_c = torch.empty((num_groups), device="cpu") + offsets_a = torch.empty((num_groups), device="cpu", dtype=torch.int32) + offsets_b = torch.empty((num_groups), device="cpu", dtype=torch.int32) + offsets_c = torch.empty((num_groups), device="cpu", dtype=torch.int32) tot_a = 0 tot_b = 0 tot_c = 0 + m = alignment * random.randint(1, 64) + n = alignment * random.randint(1, 64) + k = alignment * random.randint(1, 64) for g in range(num_groups): - m = alignment * random.randint(1, 64) - n = alignment * random.randint(1, 64) - k = alignment * random.randint(1, 64) tot_a += m tot_b += k tot_c += m - offsets_a[g] = m * k - offsets_b[g] = k * n - offsets_c[g] = m * n + print(m, n, k) + offsets_a[g] = g * m * k + offsets_b[g] = g * k * n + offsets_c[g] = g * m * n problem_sizes[g][0] = m problem_sizes[g][1] = n problem_sizes[g][2] = k @@ -497,32 +499,67 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int, a = to_fp8(torch.randn((tot_a, k), device=device)) b = to_fp8(torch.randn((tot_b, n), device=device).t()) c = torch.zeros((tot_c, n), device=device).to(out_dtype) + baseline = torch.zeros((tot_c, n), device=device).to(out_dtype) - print(tot_a, tot_b, tot_c) + # print(a) + # print(b) - print(a.stride(), b.stride(), c.stride()) + # print(offsets_a) + # print(offsets_b) + # print(offsets_c) + # print(tot_a, tot_b, tot_c) + + # print(a.stride(), b.stride(), c.stride()) # m_a_scales = m if per_act_token else 1 # n_b_scales = n if per_out_ch else 1 - scale_a = (torch.randn((tot_a if per_act_token else num_groups), - device=device, - dtype=torch.float32)) - scale_b = (torch.randn((tot_b if per_act_token else num_groups), - device=device, - dtype=torch.float32)) + # scale_a = (torch.randn((tot_a if per_act_token else num_groups), + # device=device, + # dtype=torch.float32)) + # scale_b = (torch.randn((tot_b if per_act_token else num_groups), + # device=device, + # dtype=torch.float32)) + + scale_a = (torch.ones((tot_a if per_act_token else 
num_groups), + device=device, + dtype=torch.float32)) + scale_b = (torch.ones((tot_b if per_act_token else num_groups), + device=device, + dtype=torch.float32)) + # if use_bias: # bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10 # else: # bias = None + print(a) + # TODO strides we can get later the same way as in scaled_mm_c3x.cu torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes, offsets_c, offsets_a, offsets_b) - # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + # baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, None) + # print(a.dtype) + # print(a) + + # torch.set_printoptions(profile='full') + # # print(c[2*m:3*m]) + # print(torch.max(c, dim=1)) + # print(torch.max(c, dim=0)) print(c) + for g in range(num_groups): + baseline[g * m:(g + 1) * m] = baseline_scaled_mm( + a[g * m:(g + 1) * m], + b.t()[g * k:(g + 1) * k], + scale_a[g * m:(g + 1) * m] if per_act_token else scale_a[g], + scale_b[g * k:(g + 1) * k] if per_act_token else scale_b[g], + out_dtype, None) + print(baseline[g * m:(g + 1) * m]) + print(c[g * m:(g + 1) * m]) + print("*") + # torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-2) # opcheck(torch.ops._C.cutlass_scaled_mm,
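For reference, the per-group baseline loop above amounts to the following arithmetic: each group g is an independent scaled matmul D_g = scale_a[g] * scale_b[g] * (A_g @ B_g), where A_g is the g-th m x k block of a, B_g is the g-th k x n block of b.t(), and D_g is written to the g-th m x n block of the output, with the blocks packed back to back at the element offsets g*m*k, g*k*n and g*m*n that the test stores in offsets_a/b/c. Below is a minimal sketch of that reference on flat float buffers, written in the same C++ register as the kernel code above; grouped_scaled_mm_ref is an illustrative name, not part of the patch, the test itself runs baseline_scaled_mm on the FP8 tensors, and unequal per-group sizes would come from problem_sizes in the general case.

#include <cstdint>
#include <vector>

// Reference semantics of the grouped op for equal-sized groups with one scale
// per group: a holds the packed row-major A blocks, b_t is the flat storage of
// b.t() (so the g-th k x n block starts at element g*k*n), and out receives
// the packed row-major m x n result blocks.
void grouped_scaled_mm_ref(std::vector<float> const& a,
                           std::vector<float> const& b_t,
                           std::vector<float> const& scale_a,
                           std::vector<float> const& scale_b, int m, int n,
                           int k, int groups, std::vector<float>& out) {
  for (int g = 0; g < groups; ++g) {
    int64_t const a_off = int64_t(g) * m * k;  // offsets_a[g] in the test
    int64_t const b_off = int64_t(g) * k * n;  // offsets_b[g]
    int64_t const c_off = int64_t(g) * m * n;  // offsets_c[g]
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int kk = 0; kk < k; ++kk) {
          acc += a[a_off + int64_t(i) * k + kk] *
                 b_t[b_off + int64_t(kk) * n + j];
        }
        out[c_off + int64_t(i) * n + j] = scale_a[g] * scale_b[g] * acc;
      }
    }
  }
}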