diff --git a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
index 8e46b9a33cea3..004599c2b5d26 100644
--- a/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
+++ b/csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
@@ -58,35 +58,14 @@ using ElementAB_Type = cutlass::float_e4m3_t;
 // using ElementB = cutlass::float_e4m3_t;  // Element type for B matrix operand
 using ElementC_Type = cutlass::half_t;
 
-// // A matrix configuration
-// using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
-// constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value;  // Alignment of A matrix in units of elements (up to 16 bytes)
-
-// // B matrix configuration
-// using LayoutB = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
-// constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value;  // Alignment of B matrix in units of elements (up to 16 bytes)
-
-// // C/D matrix configuration
-// using LayoutC = cutlass::layout::ColumnMajor;  // Layout type for C and D matrix operands
-// constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value;  // Alignment of C matrix in units of elements (up to 16 bytes)
-
 // Core kernel configurations
 using ElementAccumulator = float;  // Element type for internal accumulation
 using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
 using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
-// using StageCountType = cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
-
-// Different configs for pingpong/cooperative
-// struct CooperativeConfig {
-// using KernelSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
-// using EpilogueSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperative;
-// using TileShape = cute::Shape;
-// using ClusterShape = cute::Shape;
-// };
-using LayoutA = cutlass::layout::RowMajor;
-using LayoutB = cutlass::layout::ColumnMajor;
-using LayoutC = cutlass::layout::ColumnMajor;
+using LayoutA = cutlass::layout::RowMajor;
+using LayoutB = cutlass::layout::ColumnMajor;
+using LayoutC = cutlass::layout::ColumnMajor;
 
 template typename Epilogue_,
@@ -107,8 +86,8 @@ struct cutlass_3x_group_gemm {
   using StrideC = cute::remove_pointer_t, cute::Int<0>>>;
 
-  const int AlignmentAB = 128 / cutlass::sizeof_bits::value;
-  const int AlignmentC = 128 / cutlass::sizeof_bits::value;
+  const int AlignmentAB = 128 / cutlass::sizeof_bits::value;
+  const int AlignmentC = 128 / cutlass::sizeof_bits::value;
 
   using EVTCompute = typename Epilogue::EVTCompute;
   // the orig hat cutlass::epilogue::fusion::LinearCombination
@@ -172,34 +151,25 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   std::vector d_ptrs_host(groups);
 
   for (int g = 0; g < groups; ++g) {
-    a_ptrs_host.at(g) = (ElementAB*)a.data_ptr();  // + a_offsets[g].item();
-    b_ptrs_host.at(g) = (ElementAB*)b.data_ptr();  // + b_offsets[g].item();
-    c_ptrs_host.at(g) = (ElementC*)out.data_ptr();  // + out_offsets[g].item();
-    d_ptrs_host.at(g) = (ElementC*)out.data_ptr();  // + out_offsets[g].item();
+    a_ptrs_host.at(g) = (ElementAB*)a.data_ptr() + a_offsets[g].item();
+    b_ptrs_host.at(g) = (ElementAB*)b.data_ptr() + b_offsets[g].item();
+    c_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item();
+    d_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item();
   }
 
-  // int32_t groups = a.size(0);
-  // int32_t m = a.size(1);
-  // int32_t n = b.size(2);
-  // int32_t k = a.size(2);
-
-  // int64_t lda = a.stride(1);
-  // int64_t ldb = b.stride(2);
-  // int64_t ldc = out.stride(1);
-
   using StrideA = typename Gemm::GemmKernel::InternalStrideA;
   using StrideB = typename Gemm::GemmKernel::InternalStrideB;
   using StrideC = typename Gemm::GemmKernel::InternalStrideC;
   using StrideD = typename Gemm::GemmKernel::InternalStrideD;
 
-  // StrideA stride_A{lda, cute::Int<1>{}, 0};
-  // StrideB stride_B{ldb, cute::Int<1>{}, 0};
-  // StrideC stride_C{ldc, cute::Int<1>{}, cute::Int<0>{}};
+  int64_t lda = a.stride(0);
+  int64_t ldb = b.stride(1);
+  int64_t ldc = out.stride(0);
 
-  // this should be vector of A ptrs
-  // auto ptr_A = static_cast(a.data_ptr());
-  // auto ptr_B = static_cast(b.data_ptr());
-  // auto ptr_C = static_cast(out.data_ptr());
+  std::vector a_stride_host(groups, StrideA{lda, cute::Int<1>{}, cute::Int<0>{}});
+  std::vector b_stride_host(groups, StrideB{ldb, cute::Int<1>{}, cute::Int<0>{}});
+  // TODO fix
+  std::vector c_stride_host(groups, StrideC{cute::Int<1>{}, ldc, cute::Int<0>{}});
 
   cutlass::platform::unique_ptr stride_A;
   cutlass::platform::unique_ptr stride_B;
@@ -212,7 +182,7 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   cutlass::platform::unique_ptr ptr_D;
 
   using GemmKernel = typename Gemm::GemmKernel;
-  
+
   cutlass::KernelHardwareInfo hw_info;
   // Change device_id to another value if you are running on a machine with multiple GPUs and wish
   // to use a GPU other than that with device ID 0.
@@ -241,38 +211,60 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
 
   const ElementAB** a_ptrs_device;
   cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*));
-  cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
+  cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
   cutlass::platform::unique_ptr> a_ptrs_ptr(
     a_ptrs_device
   );
 
   const ElementAB** b_ptrs_device;
   cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*));
-  cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
+  cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
   cutlass::platform::unique_ptr> b_ptrs_ptr(
     b_ptrs_device
  );
 
   const ElementC** c_ptrs_device;
   cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*));
-  cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
+  cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
   cutlass::platform::unique_ptr> c_ptrs_ptr(
     c_ptrs_device
   );
 
+  // TODO if we start with empty values here, no need to copy
   ElementC** d_ptrs_device;
   cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*));
-  cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
+  cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
   cutlass::platform::unique_ptr> d_ptrs_ptr(
     d_ptrs_device
   );
 
+  StrideA* a_stride_device;
+  cudaMalloc(&a_stride_device, groups * sizeof(StrideA*));
+  cudaMemcpy(a_stride_device, a_stride_host.data(), groups, cudaMemcpyHostToDevice);
+  cutlass::platform::unique_ptr> a_stride_ptr(
+    a_stride_device
+  );
+
+  StrideB* b_stride_device;
+  cudaMalloc(&b_stride_device, groups * sizeof(StrideB*));
+  cudaMemcpy(b_stride_device, b_stride_host.data(), groups, cudaMemcpyHostToDevice);
+  cutlass::platform::unique_ptr> b_stride_ptr(
+    b_stride_device
+  );
+
+  StrideC* c_stride_device;
+  cudaMalloc(&c_stride_device, groups * sizeof(StrideC*));
+  cudaMemcpy(c_stride_device, c_stride_host.data(), groups, cudaMemcpyHostToDevice);
+  cutlass::platform::unique_ptr> c_stride_ptr(
+    c_stride_device
+  );
+
   typename GemmKernel::MainloopArguments mainloop_args{
-      a_ptrs_ptr.get(), stride_A.get(), b_ptrs_ptr.get(), stride_B.get()};
+      a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()};
   typename GemmKernel::EpilogueArguments epilogue_args{
       Gemm::Epilogue::prepare_args(
           std::forward(epilogue_params)...),
-      c_ptrs_ptr.get(), stride_C.get(), d_ptrs_ptr.get(), stride_D.get()};
+      c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()};
 
   typename GemmKernel::Arguments args{
       cutlass::gemm::GemmUniversalMode::kGrouped,
@@ -296,11 +288,8 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
 
   CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr()));
 
-  // #if defined(ENABLE_SM90_KERNEL_LEVEL)
-  // printf("did run through\n");
-  cutlass::Status status = gemm_op.run();
-  CUTLASS_CHECK(status);
-  // #endif
+  cutlass::Status status = gemm_op.run();
+  CUTLASS_CHECK(status);
 }
 
@@ -367,7 +356,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
-  // int32_t m = a.size(1);
+  // int32_t m = a.size(1);
   using Cutlass3xGemmDefault = typename sm90_fp8_config_default