From 9ac3853d27d3e9a11167a35e2a2dd2e67e547e90 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 28 Aug 2025 17:39:48 +0800 Subject: [PATCH 01/71] feat: add signal gemm api for SBO (Single Batch Overlap). Co-authored-by: Zqy11 <841971412@qq.com> --- csrc/apis/gemm.hpp | 53 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index a6bd3447..2b9367d5 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -221,6 +221,55 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair m_grouped_fp8_gemm_nt_signal(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const torch::Tensor& signal, + const int& expected_m, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + DG_HOST_ASSERT(signal.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a.first); + const auto& [num_groups_, n, k_] = get_shape<3>(b.first); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + DG_HOST_ASSERT(signal.scalar_type() == torch::kInt32); + + // D must be N-major + check_major_type_cd(d); + + // Transform scaling factors + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); + + // Dispatch implementation - only SM90 + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + return sm90_m_grouped_fp8_gemm_signal_1d2d(a.first, sfa, b.first, sfb, d, masked_m, signal, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, const std::pair& b, const torch::Tensor& d, @@ -437,6 +486,10 @@ static void register_apis(pybind11::module_& m) { py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("expected_m"), py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nt_signal", &m_grouped_fp8_gemm_nt_signal, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("signal"), + py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); 
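+    // NOTES: the signal variant returns a `(block_m, signal_threshold)` pair; the caller sizes `signal` as
+    // `num_groups * ceil_div(max_m, block_m)` int32 counters, and a row block of `d` is considered complete
+    // once its counter reaches the threshold (one atomic increment is issued per N block, see the SM90 kernel below).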
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), py::arg("ks_tensor"), py::arg("c") = std::nullopt, From 79b5a453806b98c9a8ba8f52c2a2e279035bed7f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 28 Aug 2025 17:42:19 +0800 Subject: [PATCH 02/71] feat: add launch config args for coorperative groups. Co-authored-by: Zqy11 <841971412@qq.com> --- csrc/jit/handle.hpp | 58 +++++++++++++++++++++++++++---------- csrc/jit/kernel_runtime.hpp | 13 ++++++--- 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index e05cf92c..39542048 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -37,7 +37,8 @@ static void unload_library(const LibraryHandle& library) { static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + const dim3& grid_dim, const dim3& block_dim, + const int& cluster_dim, const bool& cooperative = false) { if (smem_size > 0) DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -49,14 +50,27 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.numAttrs = 0; config.attrs = nullptr; - // NOTES: must use `static` or the `attr` will be deconstructed - static LaunchAttrHandle attr; + // 支持多个属性 + static LaunchAttrHandle attrs[2]; + int attr_count = 0; + if (cluster_dim > 1) { - attr.id = cudaLaunchAttributeClusterDimension; - attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; - config.attrs = &attr; - config.numAttrs = 1; + attrs[attr_count].id = cudaLaunchAttributeClusterDimension; + attrs[attr_count].val.clusterDim = {static_cast(cluster_dim), 1, 1}; + attr_count++; + } + + if (cooperative) { + attrs[attr_count].id = cudaLaunchAttributeCooperative; + attrs[attr_count].val.cooperative = 1; + attr_count++; + } + + if (attr_count > 0) { + config.attrs = attrs; + config.numAttrs = attr_count; } + return config; } @@ -95,7 +109,7 @@ static void unload_library(const LibraryHandle& library) { static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& cooperative = false) { if (smem_size > 0) DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); @@ -111,16 +125,30 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.numAttrs = 0; config.attrs = nullptr; + // 支持多个属性 // NOTES: must use `static` or the `attr` will be deconstructed - static LaunchAttrHandle attr; + static LaunchAttrHandle attrs[2]; + int attr_count = 0; + if (cluster_dim > 1) { - attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - attr.value.clusterDim.x = cluster_dim; - attr.value.clusterDim.y = 1; - attr.value.clusterDim.z = 1; - config.attrs = &attr; - config.numAttrs = 1; + attrs[attr_count].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attrs[attr_count].value.clusterDim.x = cluster_dim; + attrs[attr_count].value.clusterDim.y = 1; + attrs[attr_count].value.clusterDim.z = 1; + attr_count++; } + + if (cooperative) { + attrs[attr_count].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE; + attrs[attr_count].value.cooperative = 1; + attr_count++; + } + + if 
(attr_count > 0) { + config.attrs = attrs; + config.numAttrs = attr_count; + } + return config; } diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp index ba66eeb8..2f7c2d16 100644 --- a/csrc/jit/kernel_runtime.hpp +++ b/csrc/jit/kernel_runtime.hpp @@ -13,12 +13,17 @@ struct LaunchArgs { int num_threads; int smem_size; int cluster_dim; + bool cooperative; - LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, + const int& cluster_dim = 1, const bool& cooperative = false): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), + cluster_dim(cluster_dim), cooperative(cooperative) {} - LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, + const int& cluster_dim = 1, const bool& cooperative = false): + grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), + cluster_dim(cluster_dim), cooperative(cooperative) {} }; class KernelRuntime final { From 6cf6ca16bf01932385c44da713831866c1560a34 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 28 Aug 2025 17:43:58 +0800 Subject: [PATCH 03/71] feat: add signal gemm impl&runtime in jit kernels. Co-authored-by: Zqy11 <841971412@qq.com> --- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 124 +++++++++++++++++- 1 file changed, 122 insertions(+), 2 deletions(-) diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 3afc2d33..34531f5c 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -3,7 +3,6 @@ #include #include "../../jit/compiler.hpp" -#include "../../jit/device_runtime.hpp" #include "../../jit/kernel_runtime.hpp" #include "../../utils/exception.hpp" #include "../../utils/format.hpp" @@ -140,7 +139,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::Kernel1D2D, - m, n, k, 1, major_a, major_b, + m, n, k, num_groups, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), false, device_runtime->get_num_sms()); @@ -246,4 +245,125 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to SM90FP8Gemm1D2DRuntime::launch(runtime, args); } +class SM90FP8SignalGemm1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *sfb, *grouped_layout, *signal; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + CUtensorMap tensor_map_sfa; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_signal_gemm_1d2d_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", + get_compiled_dim(args.m, 'm', args.compiled_dims), + get_compiled_dim(args.n, 'n', 
args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, args.gemm_config.smem_config.swizzle_cd_mode, args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sfb, args.grouped_layout, args.signal, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.tensor_map_sfa)); } }; static std::pair sm90_m_grouped_fp8_gemm_signal_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, const torch::Tensor& masked_m, const torch::Tensor& signal, const int& num_groups, const int& m, const int& n, const int& k, const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { const auto& aligned_k = align(k, 128); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); DG_HOST_ASSERT(signal.scalar_type() == torch::kInt32); const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::Kernel1D2D, expected_m, n, k, num_groups, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), config.block_k, static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, config.smem_config.swizzle_a_mode); const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, config.smem_config.swizzle_b_mode); const auto& tensor_map_d = make_tma_cd_desc(d, m, n, SM90ArchSpec::get_cd_store_block_m(config.block_m), SM90ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(-2)), num_groups, config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, config.block_m, config.block_k, num_groups, 0); // Launch with cooperative groups support const SM90FP8SignalGemm1D2DRuntime::Args& args = { .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, .compiled_dims = compiled_dims, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast, true), // Enable cooperative groups .sfb = sfb.data_ptr(), .grouped_layout = masked_m.data_ptr(), .signal = signal.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, .tensor_map_sfa = tensor_map_sfa, }; const auto& code = 
SM90FP8SignalGemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_signal_1d2d", code); + SM90FP8SignalGemm1D2DRuntime::launch(runtime, args); + return std::make_pair(config.block_m, config.signal_threshold); +} + } // namespace deep_gemm From 1e2c135cb15834f8a2090fdcdccb92d2144c2bfc Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 28 Aug 2025 17:44:56 +0800 Subject: [PATCH 04/71] feat: add signal gemm kernel. Co-authored-by: Zqy11 <841971412@qq.com> --- .../impls/sm90_fp8_signal_gemm_1d2d.cuh | 451 ++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh new file mode 100644 index 00000000..1f2540ba --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh @@ -0,0 +1,451 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace cg = cooperative_groups; +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) { + if (num_former_iters == kNumFormerIters) { + inner_launch_k_iterations(func, cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_signal_gemm_1d2d_impl(float* sfb, int* grouped_layout, int32_t* signal, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == kNumMathThreads) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_sfa[kNumStages]; + float* smem_sfb; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_sfa[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE); + } + smem_sfb = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + struct SkipComputation {}; + struct NotSkipComputation {}; + auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) { + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? 
BLOCK_K / 8 : 0; + + // NOTES: for too-many branches (> 5), we disable this optimization + // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value + outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { + if (skip_computation) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); + } else if (shape_k % kFullKOfAllStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } + }, func, kShouldOptimize ? num_former_iters : 0); + }; + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. 
`full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + // Issue TMA A + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[s]; + uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), + smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(shape_k_scales, 1, k_idx / BLOCK_K), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }, false, 0); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + } + } + } else { + cg::coalesced_group group = cg::coalesced_threads(); + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); + auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(local_sfb + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? 
empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { + constexpr bool kSkipComputation = cute::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages); + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Read B scales + float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + } + __syncwarp(); + + group.sync(); + + if (threadIdx.x == 0) { + atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); + } + + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "Signal GEMM kernel only supports SM90 architecture"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop From a5d23b0800795a0a54c9df6785ca3c161cd69ae8 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 28 Aug 2025 17:46:09 +0800 Subject: [PATCH 05/71] feat: add signal gemm import in deep_gemm package. Co-authored-by: Zqy11 <841971412@qq.com> --- deep_gemm/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 169e2e6b..302732e5 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -28,6 +28,7 @@ m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nn_contiguous, m_grouped_fp8_gemm_nt_masked, + m_grouped_fp8_gemm_nt_signal, k_grouped_fp8_gemm_tn_contiguous, # BF16 GEMMs bf16_gemm_nt, bf16_gemm_nn, From 48c741af9d7f2a67a6264dff6d546bf8fd8e9f20 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 28 Aug 2025 18:08:38 +0800 Subject: [PATCH 06/71] feat: add test for signal gemm. 
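The test drives the new API end to end. As a minimal usage sketch (assuming `num_groups`, `max_m`, `expected_m_per_group` and the FP8 operands come from the generators added below; the 64 used for sizing mirrors the test and is only an upper bound, the actual `block_m` is returned by the call):

    import torch
    import deep_gemm

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    # a/b are (FP8 tensor, scales) pairs, d is the BF16 output and masked_m the
    # per-group valid row counts, as produced by generate_m_grouped_masked.
    signal = torch.zeros(num_groups * ceil_div(max_m, 64), dtype=torch.int32, device='cuda')
    block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_signal(
        a, b, d, masked_m, signal, expected_m_per_group)
    # signal[g * ceil_div(max_m, block_m) + i] reaches `threshold` once the i-th
    # row block of group g in `d` has been stored.
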
Co-authored-by: Zqy11 <841971412@qq.com> --- tests/generators.py | 24 +++++++++++ tests/test_fp8.py | 39 +++++++++++++++++ tests/test_signal_gemm.py | 91 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+) create mode 100644 tests/test_signal_gemm.py diff --git a/tests/generators.py b/tests/generators.py index 82cdbdcc..22c22ffa 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -233,3 +233,27 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int] a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0) b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) return k, a_fp8, b_fp8, c, d, ref_d + + +def generate_m_grouped_signal(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, + use_ue8m0: bool = False, use_bf16: bool = False): + a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.einsum('gmk,gnk->gmn', a, b) + + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + assert masked_m.amax().item() <= max_m + + if use_bf16: + return a, b, masked_m, d, ref_d + + a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float)) + b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) + b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + + return a_fp8, b_fp8, masked_m, d, ref_d diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 0c7d3cea..b9ebfc99 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -159,6 +159,44 @@ def test_func(): print() +def test_m_grouped_gemm_signal() -> None: + print('Testing m-grouped signal GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. 
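+    # NOTES: the kernel touches at most the first num_groups * ceil_div(max_m, block_m) int32 entries of the
+    # (flattened) signal buffer, one counter per (group, m-block) pair; the (num_groups, max_m) allocation
+    # used below is simply a comfortable upper bound.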
+ for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + # Test correctness + for i in range(10): + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + # Create signal tensor + signal = torch.zeros((num_groups, max_m), dtype=torch.int32, device='cuda') + result = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + for j in range(num_groups): + diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Construct full cases + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + # Create signal tensor + signal = torch.zeros((num_groups, max_m), dtype=torch.int32, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + print() + + if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -172,3 +210,4 @@ def test_func(): test_m_grouped_gemm_contiguous() test_m_grouped_gemm_masked() test_k_grouped_gemm_contiguous() + test_m_grouped_gemm_signal() diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py new file mode 100644 index 00000000..1dfd963a --- /dev/null +++ b/tests/test_signal_gemm.py @@ -0,0 +1,91 @@ +import copy +import random +import time +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench, bench_kineto, + calc_diff, count_bytes +) + +from generators import ( + KernelType, get_ue8m0_usage, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous +) + +def ceil_div(a, b): + return (a + b - 1) // b + +def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, masked_m): + signal = combine_signal.cpu().tolist() + # print(signal) + + expert_len = max_m // block_m + # print(len(signal)) + for expert in range(num_local_expert): + mask = masked_m[expert] + start = expert * expert_len + end = expert * (expert_len + 1) + if mask == 0: + for i in range(start, end): + assert signal[i] == 0, f'{i=}, {signal[i]=}' + else: + valid_len = ceil_div(mask, block_m) + for i in range(start, end): + if i < start + valid_len: + assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' + else: + assert signal[i] == 0, f'{i=}, {signal[i]=}' + +def test_m_grouped_gemm_signal() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly 
decrease. + for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + # Test correctness + for i in range(10): + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + max_signal_size = num_groups * ceil_div(max_m, 64) + combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') + block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + check_signal(num_groups, max_m, block_m, threshold, combine_signal, masked_m) + for j in range(num_groups): + diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + print(f' > Correctness ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}) checked ') + + # Construct full cases + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + max_signal_size = num_groups * ceil_div(max_m, 64) + combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_signal_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + print() + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_m_grouped_gemm_signal() \ No newline at end of file From 9a2d3e55b63c2b329ce6bcb1bf8a1c125959aa6e Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 28 Aug 2025 19:58:39 +0800 Subject: [PATCH 07/71] feat: add signal threshold as a config arg for signal gemm. 
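The kernel issues one atomicAdd per output tile on the counter of that tile's m-block, so a counter is final exactly when it reaches ceil_div(n, block_n); this patch exposes that value as `signal_threshold` in `GemmConfig` so it can be returned to the caller. A consumer-side readiness check could then look like this sketch (hypothetical helper, not part of the patch):

    import torch

    def row_block_ready(signal: torch.Tensor, group: int, m_block: int,
                        blocks_per_group: int, threshold: int) -> bool:
        # blocks_per_group == ceil_div(max_m, block_m), threshold == ceil_div(n, block_n)
        return int(signal[group * blocks_per_group + m_block]) >= threshold
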
--- csrc/jit_kernels/heuristics/common.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 3ed4d2a1..491f3fcb 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -61,6 +61,7 @@ struct GemmConfig { cute::UMMA::Major major_b; bool with_accumulation; int block_m, block_n, block_k; + int signal_threshold; int num_stages, num_last_stages; // Templated device configs @@ -266,6 +267,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k .block_m = best_block_m, .block_n = best_block_n, .block_k = block_k, + .signal_threshold = ceil_div(n, best_block_n), .num_stages = best_num_stages, .num_last_stages = ceil_div(k, block_k) % best_num_stages, .num_sms = num_min_sms, From 03f61d2b826381d97a1fdc3feaa7ce3ede5fd507 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 30 Aug 2025 23:07:48 +0800 Subject: [PATCH 08/71] add comparation test. --- tests/test_signal_gemm.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 1dfd963a..f656dd8c 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -79,6 +79,40 @@ def test_func(): f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') print() + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + # Test correctness + for i in range(10): + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + for j in range(num_groups): + diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Construct full cases + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + print() + if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -88,4 +122,5 @@ def test_func(): print('Library path:') print(f' > {deep_gemm.__path__}\n') - test_m_grouped_gemm_signal() \ No newline at end of file + test_m_grouped_gemm_signal() + test_m_grouped_gemm_masked() \ No newline at end of file From 
e86f1f91a320c0b8a8618b5001c1dddba43ce5b2 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 30 Aug 2025 23:11:58 +0800 Subject: [PATCH 09/71] avoid using bench kineto --- tests/test_signal_gemm.py | 56 +++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index f656dd8c..55163f0f 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -62,21 +62,21 @@ def test_m_grouped_gemm_signal() -> None: print(f' > Correctness ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}) checked ') # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - max_signal_size = num_groups * ceil_div(max_m, 64) - combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - - # Test performance with fixed shapes - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_signal_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + # a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + # max_signal_size = num_groups * ceil_div(max_m, 64) + # combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') + + # # noinspection PyShadowingNames + # def test_func(): + # deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # # Test performance with fixed shapes + # valid_m = masked_m.sum().item() + # t = bench_kineto(test_func, 'fp8_signal_gemm', suppress_kineto_output=True) + # print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + # f'{t * 1e6:4.0f} us | ' + # f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + # f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') print() @@ -98,19 +98,19 @@ def test_m_grouped_gemm_masked() -> None: assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - - # Test performance with fixed shapes - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + # a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, 
use_ue8m0=use_ue8m0) + + # # noinspection PyShadowingNames + # def test_func(): + # deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # # Test performance with fixed shapes + # valid_m = masked_m.sum().item() + # t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + # print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + # f'{t * 1e6:4.0f} us | ' + # f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + # f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') print() if __name__ == '__main__': From e7ebff177d60bb693d5ea97a4af5edf79bd4022f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sun, 31 Aug 2025 15:15:58 +0800 Subject: [PATCH 10/71] test: modify generators to fit bs32 down gemm. --- tests/generators.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/generators.py b/tests/generators.py index 22c22ffa..e8d608aa 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -86,10 +86,15 @@ def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator: def enumerate_m_grouped_masked() -> Generator: - max_m = 4096 + max_m = 2048 + # max_m = 4096 + # for kernel_type in get_kernel_types(): + # for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + # for n, k in ((4096, 7168), (7168, 2048), ): + # yield kernel_type, num_groups, max_m, m, n, k for kernel_type in get_kernel_types(): - for num_groups, m in ((1, 1024), (2, 512), (4, 256)): - for n, k in ((4096, 7168), (7168, 2048), ): + for num_groups, m in ((16, 32), (16, 64)): + for n, k in ((7168, 2048), ): yield kernel_type, num_groups, max_m, m, n, k From 803736950645ba02d4aa837ef921dbb10e0e4b33 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sun, 31 Aug 2025 15:26:36 +0800 Subject: [PATCH 11/71] more --- tests/test_signal_gemm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 55163f0f..20296909 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -49,7 +49,7 @@ def test_m_grouped_gemm_signal() -> None: disable_ue8m0_cast = not use_ue8m0 # Test correctness - for i in range(10): + for i in range(2): a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) max_signal_size = num_groups * ceil_div(max_m, 64) combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') @@ -90,13 +90,15 @@ def test_m_grouped_gemm_masked() -> None: disable_ue8m0_cast = not use_ue8m0 # Test correctness - for i in range(10): + for i in range(2): a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + print(f' > Correctness ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}) checked ') + # Construct full cases # a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) From 55c79433841fc0bb50aed2dd0a51a85800058a05 Mon Sep 17 00:00:00 2001 From: 
Eric Wong Date: Sun, 31 Aug 2025 15:35:18 +0800 Subject: [PATCH 12/71] more. --- tests/test_signal_gemm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 20296909..f3508949 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -92,7 +92,10 @@ def test_m_grouped_gemm_masked() -> None: # Test correctness for i in range(2): a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + origin_sms = deep_gemm.get_num_sms() + deep_gemm.set_num_sms(origin_sms - 3) deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + deep_gemm.set_num_sms(origin_sms) for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' From 99cc43bbc52d283f4e050d6d27a15cc8a8cef4e6 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sun, 31 Aug 2025 15:41:44 +0800 Subject: [PATCH 13/71] more. --- tests/test_signal_gemm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index f3508949..1250ae54 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -93,8 +93,10 @@ def test_m_grouped_gemm_masked() -> None: for i in range(2): a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) origin_sms = deep_gemm.get_num_sms() + print(f'Origin SMS: {origin_sms}') deep_gemm.set_num_sms(origin_sms - 3) deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + print(f'Current SMS: {deep_gemm.get_num_sms()}') deep_gemm.set_num_sms(origin_sms) for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) From 4bfed476be528d0c9042b24cf25abab71447520d Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sun, 31 Aug 2025 16:44:45 +0800 Subject: [PATCH 14/71] fix --- tests/test_signal_gemm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 1250ae54..b84790ef 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -53,7 +53,12 @@ def test_m_grouped_gemm_signal() -> None: a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) max_signal_size = num_groups * ceil_div(max_m, 64) combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') + origin_sms = deep_gemm.get_num_sms() + print(f'Origin SMS: {origin_sms}') + deep_gemm.set_num_sms(origin_sms - 3) block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + print(f'Current SMS: {deep_gemm.get_num_sms()}') + deep_gemm.set_num_sms(origin_sms) check_signal(num_groups, max_m, block_m, threshold, combine_signal, masked_m) for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) @@ -92,12 +97,7 @@ def test_m_grouped_gemm_masked() -> None: # Test correctness for i in range(2): a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - origin_sms = deep_gemm.get_num_sms() - print(f'Origin SMS: {origin_sms}') - 
deep_gemm.set_num_sms(origin_sms - 3) deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - print(f'Current SMS: {deep_gemm.get_num_sms()}') - deep_gemm.set_num_sms(origin_sms) for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' From 4a14a40fcb0a16c98592d07288aa1211e58d09b4 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 14:25:50 +0800 Subject: [PATCH 15/71] more. --- csrc/jit_kernels/heuristics/common.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 491f3fcb..e9940d88 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -210,6 +210,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k } } DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); + printf("Best block size: (%d, %d)\n", best_block_m, best_block_n); // Decide the number of TMA multicasts and whether broadcast on A MulticastConfig best_multicast_config = {1, true}; From 701751cd309df55a85d4b75b79d6efec6a600eb4 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 14:33:29 +0800 Subject: [PATCH 16/71] more. --- csrc/jit_kernels/heuristics/common.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index e9940d88..0ab4058d 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -210,7 +210,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k } } DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); - printf("Best block size: (%d, %d)\n", best_block_m, best_block_n); + printf("Best num waves: %d, Best block size: (%d, %d)\n", best_num_waves, best_block_m, best_block_n); // Decide the number of TMA multicasts and whether broadcast on A MulticastConfig best_multicast_config = {1, true}; From 8035121fb7f72dc71b1605fd7ddf60a9e71c8727 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 15:11:46 +0800 Subject: [PATCH 17/71] more. --- csrc/jit_kernels/heuristics/common.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 0ab4058d..c310145d 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -150,6 +150,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const bool& with_accumulation, const int& num_sms) { DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); + printf("m: %d, n: %d, k: %d, num_groups: %d, major_a: %d, major_b: %d, ab_dtype: %d, cd_dtype: %d, with_accumulation: %d, num_sms: %d\n", + m, n, k, num_groups, static_cast(major_a), static_cast(major_b), + static_cast(ab_dtype), static_cast(cd_dtype), with_accumulation, num_sms); // Select M/N block sizes // TODO: support `% 16 == 8` block size on SM90 From 02edd319ef8ac81b3aadd477a1ad1bac00ef353f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 15:17:30 +0800 Subject: [PATCH 18/71] more. 
--- csrc/jit_kernels/heuristics/common.hpp | 8 ++++---- tests/test_signal_gemm.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index c310145d..d988a5e5 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -150,9 +150,6 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const bool& with_accumulation, const int& num_sms) { DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); - printf("m: %d, n: %d, k: %d, num_groups: %d, major_a: %d, major_b: %d, ab_dtype: %d, cd_dtype: %d, with_accumulation: %d, num_sms: %d\n", - m, n, k, num_groups, static_cast(major_a), static_cast(major_b), - static_cast(ab_dtype), static_cast(cd_dtype), with_accumulation, num_sms); // Select M/N block sizes // TODO: support `% 16 == 8` block size on SM90 @@ -205,6 +202,10 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k } } + if (block_m == 64 and block_n == 176) { + printf("success: %d, num_waves: %d, last_util: %d, best_block: (%d, %d), best_num_waves: %d, best_last_util: %d\n", success, num_waves, last_util, best_block_m, best_block_n, best_num_waves, best_last_util); + } + // Replace with the new config if successful if (success) { best_block_m = block_m, best_block_n = block_n; @@ -213,7 +214,6 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k } } DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); - printf("Best num waves: %d, Best block size: (%d, %d)\n", best_num_waves, best_block_m, best_block_n); // Decide the number of TMA multicasts and whether broadcast on A MulticastConfig best_multicast_config = {1, true}; diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index b84790ef..e6b601c3 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -54,10 +54,8 @@ def test_m_grouped_gemm_signal() -> None: max_signal_size = num_groups * ceil_div(max_m, 64) combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') origin_sms = deep_gemm.get_num_sms() - print(f'Origin SMS: {origin_sms}') deep_gemm.set_num_sms(origin_sms - 3) block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - print(f'Current SMS: {deep_gemm.get_num_sms()}') deep_gemm.set_num_sms(origin_sms) check_signal(num_groups, max_m, block_m, threshold, combine_signal, masked_m) for j in range(num_groups): From 0fcd03c0040448c853585ed90e70731ba1672ecd Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 16:23:39 +0800 Subject: [PATCH 19/71] feat: add param max_block_n --- csrc/apis/gemm.hpp | 5 +++-- csrc/jit_kernels/heuristics/common.hpp | 8 ++------ csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 4 ++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 2b9367d5..2796644f 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -229,7 +229,8 @@ static std::pair m_grouped_fp8_gemm_nt_signal(const std::pair> recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const int& max_block_n = 256) { // Shape must be `[G, M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = 
get_major_type_ab(b.first); @@ -489,7 +490,7 @@ static void register_apis(pybind11::module_& m) { m.def("m_grouped_fp8_gemm_nt_signal", &m_grouped_fp8_gemm_nt_signal, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("signal"), py::arg("expected_m"), py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false, py::arg("max_block_n") = 256); m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), py::arg("ks_tensor"), py::arg("c") = std::nullopt, diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index d988a5e5..dd460585 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -147,7 +147,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const int& m, const int& n, const int& k, const int& num_groups, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, - const bool& with_accumulation, const int& num_sms) { + const bool& with_accumulation, const int& num_sms, const int& max_block_n = 256) { DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); @@ -156,7 +156,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ? std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256}; std::vector block_ns; - for (int i = 16; i <= 256; i += 16) + for (int i = 16; i <= max_block_n; i += 16) block_ns.push_back(i); // K block size is selected in a fixed manner @@ -202,10 +202,6 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k } } - if (block_m == 64 and block_n == 176) { - printf("success: %d, num_waves: %d, last_util: %d, best_block: (%d, %d), best_num_waves: %d, best_last_util: %d\n", success, num_waves, last_util, best_block_m, best_block_n, best_num_waves, best_last_util); - } - // Replace with the new config if successful if (success) { best_block_m = block_m, best_block_n = block_n; diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 34531f5c..c0d66b57 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -309,7 +309,7 @@ static std::pair sm90_m_grouped_fp8_gemm_signal_1d2d(const torch::Tens const int& num_groups, const int& m, const int& n, const int& k, const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, const int& max_block_n = 256) { const auto& aligned_k = align(k, 128); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); @@ -319,7 +319,7 @@ static std::pair sm90_m_grouped_fp8_gemm_signal_1d2d(const torch::Tens GemmType::MGroupedMasked, KernelType::Kernel1D2D, expected_m, n, k, num_groups, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), false, - device_runtime->get_num_sms()); + device_runtime->get_num_sms(), max_block_n); // Requires no TMA splits DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == 
config.block_k); From 2ec74145d0ff48e921230fca01bdb5af80939a78 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 16:26:38 +0800 Subject: [PATCH 20/71] test --- tests/test_signal_gemm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index e6b601c3..e9facfd6 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -39,7 +39,7 @@ def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, ma else: assert signal[i] == 0, f'{i=}, {signal[i]=}' -def test_m_grouped_gemm_signal() -> None: +def test_m_grouped_gemm_signal(max_block_n) -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. @@ -55,7 +55,7 @@ def test_m_grouped_gemm_signal() -> None: combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') origin_sms = deep_gemm.get_num_sms() deep_gemm.set_num_sms(origin_sms - 3) - block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, max_block_n=max_block_n) deep_gemm.set_num_sms(origin_sms) check_signal(num_groups, max_m, block_m, threshold, combine_signal, masked_m) for j in range(num_groups): @@ -127,5 +127,6 @@ def test_m_grouped_gemm_masked() -> None: print('Library path:') print(f' > {deep_gemm.__path__}\n') - test_m_grouped_gemm_signal() - test_m_grouped_gemm_masked() \ No newline at end of file + for max_block_n in (144, 160, 192): + test_m_grouped_gemm_signal(max_block_n) + # test_m_grouped_gemm_masked() \ No newline at end of file From a63152fd9da92100776fab4770cb12c34df21f80 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 16:33:43 +0800 Subject: [PATCH 21/71] more. --- csrc/apis/gemm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 2796644f..d2f69919 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -265,7 +265,7 @@ static std::pair m_grouped_fp8_gemm_nt_signal(const std::pairget_arch_major(); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { return sm90_m_grouped_fp8_gemm_signal_1d2d(a.first, sfa, b.first, sfb, d, masked_m, signal, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims, max_block_n); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } From 92b29f6ed33772dc5c20b01988a89474244ff575 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 16:52:51 +0800 Subject: [PATCH 22/71] more. 
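This patch whitelists block_n == 176 for the SM90 1D2D kernels and makes the test sweep max_block_n over (144, 160, 192). As a quick cross-check of the rule quoted in the sm90.hpp comment (a block has "too many scaling factors" when block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k), here is a minimal sketch, assuming block_k == 128 as the fixed K block size of the FP8 1D2D path:

    // Minimal sketch (assumes block_k == 128): evaluate the scaling-factor rule
    // from the sm90.hpp comment for the large block_n candidates.
    #include <cstdio>
    #include <numeric>

    int main() {
        constexpr int block_k = 128;
        for (const int block_n : {144, 160, 176, 192}) {
            const int g = std::gcd(block_n, block_k);
            const bool flagged = g != block_n - block_k;   // "too many scaling factors"
            std::printf("block_n=%d: gcd=%d, block_n-block_k=%d -> %s\n",
                        block_n, g, block_n - block_k, flagged ? "flagged" : "ok");
        }
        return 0;
    }

Only 176 trips the rule (gcd 16 versus difference 48), which is consistent with this whitelisting being rolled back two patches later.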
--- csrc/jit_kernels/heuristics/sm90.hpp | 2 +- tests/test_signal_gemm.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index 16ca018c..6cbb367e 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -42,7 +42,7 @@ struct SM90ArchSpec { // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k` // Or too many register spills - if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192)) + if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 176 and block_n != 192)) return false; // Avoid bank conflicts for FP32 output diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index e9facfd6..81eb9761 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -127,6 +127,5 @@ def test_m_grouped_gemm_masked() -> None: print('Library path:') print(f' > {deep_gemm.__path__}\n') - for max_block_n in (144, 160, 192): - test_m_grouped_gemm_signal(max_block_n) + test_m_grouped_gemm_signal() # test_m_grouped_gemm_masked() \ No newline at end of file From be21bf6ca64430e92829b3c85438a2e35b6c19f9 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 1 Sep 2025 16:54:52 +0800 Subject: [PATCH 23/71] more. --- tests/test_signal_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 81eb9761..449d6635 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -39,7 +39,7 @@ def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, ma else: assert signal[i] == 0, f'{i=}, {signal[i]=}' -def test_m_grouped_gemm_signal(max_block_n) -> None: +def test_m_grouped_gemm_signal(max_block_n=256) -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. From ffa1140ae507bf632efca2ffd8eec0466cf0beed Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 11:04:57 +0800 Subject: [PATCH 24/71] more. --- csrc/jit_kernels/heuristics/sm90.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index 6cbb367e..16ca018c 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -42,7 +42,7 @@ struct SM90ArchSpec { // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k` // Or too many register spills - if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 176 and block_n != 192)) + if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192)) return false; // Avoid bank conflicts for FP32 output From 4802c62d3960c32e072ddaf350e75ceefeceb643 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 11:19:59 +0800 Subject: [PATCH 25/71] fix: ensure memory order and send location. 
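The epilogue sequence is now: the threads that issued the D-tile TMA stores wait for them to complete (cute::tma_store_wait<0>()), the math threads synchronize, a __threadfence() publishes the stores, and only then does a single thread bump the per-(group, m-block) counter with atomicAdd. Signalling before the asynchronous stores have landed could expose a row block whose bytes are still in flight. The consumer side lives outside this repo; a minimal sketch of how a downstream combine kernel might wait on the counter (wait_row_block is an illustrative name; it assumes threshold counts the producer tiles that each add 1 to a row block, which is the value the host API returns in later patches, and num_m_blocks == ceil_div(shape_m, BLOCK_M)):

    // Consumer-side sketch (not part of DeepGEMM): block until one row block of D
    // is fully written, then fence so later reads of D are ordered after the flag.
    __device__ void wait_row_block(const volatile int* signal,
                                   int group_idx, int num_m_blocks,
                                   int m_block_idx, int threshold) {
        const volatile int* flag = signal + group_idx * num_m_blocks + m_block_idx;
        if (threadIdx.x == 0) {
            while (*flag < threshold)
                ;              // spin on the per-row-block counter
            __threadfence();   // order later global reads after the observed flag
        }
        __syncthreads();       // release the rest of the block once the flag is seen
    }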
Co-authored-by: Zqy11 <841971412@qq.com> Co-authored-by: AniZpZ --- .../include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh index 1f2540ba..d974f9f8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh @@ -432,7 +432,12 @@ sm90_fp8_signal_gemm_1d2d_impl(float* sfb, int* grouped_layout, int32_t* signal, } __syncwarp(); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + cute::tma_store_wait<0>(); + } + group.sync(); + __threadfence(); if (threadIdx.x == 0) { atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); From 6271138ab88545e323c6bd505686b8c1ef28ccce Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 12:15:10 +0800 Subject: [PATCH 26/71] more. --- tests/test_signal_gemm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 449d6635..f2af4d8a 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -127,5 +127,6 @@ def test_m_grouped_gemm_masked() -> None: print('Library path:') print(f' > {deep_gemm.__path__}\n') - test_m_grouped_gemm_signal() + for max_block_n in (144, 160, 192): + test_m_grouped_gemm_signal(max_block_n) # test_m_grouped_gemm_masked() \ No newline at end of file From 57bb435ec2d744f151c1c14fc0ae62e9aadc3a6a Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 14:29:54 +0800 Subject: [PATCH 27/71] exp --- .../include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh index d974f9f8..9e16a0e9 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh @@ -432,9 +432,9 @@ sm90_fp8_signal_gemm_1d2d_impl(float* sfb, int* grouped_layout, int32_t* signal, } __syncwarp(); - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { - cute::tma_store_wait<0>(); - } + // if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + // cute::tma_store_wait<0>(); + // } group.sync(); __threadfence(); From d6e63a3bb1d81ba3df55e10a506ae99ee2fa9e55 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 14:36:45 +0800 Subject: [PATCH 28/71] exp --- .../include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh index 9e16a0e9..80534140 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh @@ -432,12 +432,12 @@ sm90_fp8_signal_gemm_1d2d_impl(float* sfb, int* grouped_layout, int32_t* signal, } __syncwarp(); - // if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { - // cute::tma_store_wait<0>(); - // } + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + cute::tma_store_wait<0>(); + } group.sync(); - __threadfence(); + // __threadfence(); if (threadIdx.x == 0) { atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); From a21c96f66d7c822b95ad6d246a12a43cff42916a Mon Sep 17 00:00:00 2001 
From: Eric Wong Date: Tue, 2 Sep 2025 15:30:13 +0800 Subject: [PATCH 29/71] rollback generator. --- tests/generators.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/generators.py b/tests/generators.py index e8d608aa..22c22ffa 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -86,15 +86,10 @@ def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator: def enumerate_m_grouped_masked() -> Generator: - max_m = 2048 - # max_m = 4096 - # for kernel_type in get_kernel_types(): - # for num_groups, m in ((1, 1024), (2, 512), (4, 256)): - # for n, k in ((4096, 7168), (7168, 2048), ): - # yield kernel_type, num_groups, max_m, m, n, k + max_m = 4096 for kernel_type in get_kernel_types(): - for num_groups, m in ((16, 32), (16, 64)): - for n, k in ((7168, 2048), ): + for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for n, k in ((4096, 7168), (7168, 2048), ): yield kernel_type, num_groups, max_m, m, n, k From 51740cfa95fb4fe2f1e8419df16f4a3366932c54 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 15:30:31 +0800 Subject: [PATCH 30/71] remove threadfence. --- deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh index 80534140..0d4ab50b 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh @@ -437,7 +437,6 @@ sm90_fp8_signal_gemm_1d2d_impl(float* sfb, int* grouped_layout, int32_t* signal, } group.sync(); - // __threadfence(); if (threadIdx.x == 0) { atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); From 656fe10e5eabdb02b17d2a098c1064676bd558e9 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 16:13:33 +0800 Subject: [PATCH 31/71] add threadfence. --- .../include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh index 0d4ab50b..6b55b7ee 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh @@ -437,7 +437,9 @@ sm90_fp8_signal_gemm_1d2d_impl(float* sfb, int* grouped_layout, int32_t* signal, } group.sync(); - + + __threadfence(); + if (threadIdx.x == 0) { atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); } From 7e385a9e9d2ce0804a4a153f9031b6e275858d85 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 2 Sep 2025 16:15:24 +0800 Subject: [PATCH 32/71] complete test. 
--- tests/test_signal_gemm.py | 61 +++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index f2af4d8a..84471423 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -65,21 +65,21 @@ def test_m_grouped_gemm_signal(max_block_n=256) -> None: print(f' > Correctness ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}) checked ') # Construct full cases - # a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - # max_signal_size = num_groups * ceil_div(max_m, 64) - # combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') - - # # noinspection PyShadowingNames - # def test_func(): - # deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - - # # Test performance with fixed shapes - # valid_m = masked_m.sum().item() - # t = bench_kineto(test_func, 'fp8_signal_gemm', suppress_kineto_output=True) - # print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' - # f'{t * 1e6:4.0f} us | ' - # f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - # f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + max_signal_size = num_groups * ceil_div(max_m, 64) + combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_signal_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') print() @@ -103,19 +103,19 @@ def test_m_grouped_gemm_masked() -> None: print(f' > Correctness ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}) checked ') # Construct full cases - # a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - - # # noinspection PyShadowingNames - # def test_func(): - # deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - - # # Test performance with fixed shapes - # valid_m = masked_m.sum().item() - # t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - # print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' - # f'{t * 1e6:4.0f} us | ' - # f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - # f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, 
masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') print() if __name__ == '__main__': @@ -127,6 +127,5 @@ def test_m_grouped_gemm_masked() -> None: print('Library path:') print(f' > {deep_gemm.__path__}\n') - for max_block_n in (144, 160, 192): - test_m_grouped_gemm_signal(max_block_n) - # test_m_grouped_gemm_masked() \ No newline at end of file + test_m_grouped_gemm_signal() + test_m_grouped_gemm_masked() \ No newline at end of file From 8a138ad4c41bc0935d6bb81c409bc699c6e9f976 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Wed, 3 Sep 2025 14:09:19 +0800 Subject: [PATCH 33/71] fix: turn comment from chinese to english. --- csrc/jit/handle.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index 39542048..38c23b47 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -50,7 +50,7 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.numAttrs = 0; config.attrs = nullptr; - // 支持多个属性 + // Support for multiple attributes static LaunchAttrHandle attrs[2]; int attr_count = 0; @@ -125,7 +125,7 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.numAttrs = 0; config.attrs = nullptr; - // 支持多个属性 + // Support for multiple attributes // NOTES: must use `static` or the `attr` will be deconstructed static LaunchAttrHandle attrs[2]; int attr_count = 0; From 2dfdce54b9e8ff8ca753fd0e59631af85bd69259 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:01:24 +0800 Subject: [PATCH 34/71] refactor: merge signal gemm related api & sm90 impl into existing api & impl. 
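After the merge, the standalone signal entry point goes away and the overlap path is reached through m_grouped_fp8_gemm_nt_masked, which now returns an optional (block_m, signal_threshold) pair when enable_overlap is set. A host-side sketch of the intended call shape, based on the signature introduced here; the wrapper name and tensor variables are illustrative, and the caller is assumed to hand in a pre-zeroed int32 signal tensor since the kernel only ever increments it:

    // Sketch of driving the merged entry point with overlap enabled. Assumes the
    // header declaring m_grouped_fp8_gemm_nt_masked (csrc/apis/gemm.hpp) is on the
    // include path; the optional result carries (block_m, signal_threshold).
    #include <optional>
    #include <string>
    #include <utility>
    #include <torch/torch.h>

    std::optional<std::pair<int, int>> launch_overlapped_group_gemm(
            const std::pair<torch::Tensor, torch::Tensor>& a,   // {fp8 data, scales}
            const std::pair<torch::Tensor, torch::Tensor>& b,
            const torch::Tensor& d, const torch::Tensor& masked_m,
            const torch::Tensor& signal, const int expected_m) {
        return m_grouped_fp8_gemm_nt_masked(
            a, b, d, masked_m, expected_m,
            std::nullopt,             // recipe: default for the scaling-factor dtypes
            "nk",                     // compiled_dims
            false,                    // disable_ue8m0_cast
            256,                      // max_block_n
            /*enable_overlap=*/true,
            signal);
    }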
Co-authored-by: Zqy11 <841971412@qq.com> Co-authored-by: AniZpZ --- csrc/apis/gemm.hpp | 21 +++++++++++++---- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 23 ++++++++++++------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index d2f69919..d1b36339 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -169,14 +169,17 @@ static void m_grouped_fp8_gemm_nn_contiguous(const std::pair& a, +static std::optional> m_grouped_fp8_gemm_nt_masked(const std::pair& a, const std::pair& b, const torch::Tensor& d, const torch::Tensor& masked_m, const int& expected_m, std::optional> recipe, const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + const bool& disable_ue8m0_cast, + const int& max_block_n, + const bool& enable_overlap, + const std::optional& signal) { // Shape must be `[G, M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -196,6 +199,12 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pairget_arch_major(); + std::optional> result = std::nullopt; if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, @@ -219,6 +229,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair m_grouped_fp8_gemm_nt_signal(const std::pair& a, @@ -486,7 +497,9 @@ static void register_apis(pybind11::module_& m) { m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("expected_m"), py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false, + py::arg("max_block_n") = 256, py::arg("enable_overlap") = false + py::arg("signal") = std::nullopt); m.def("m_grouped_fp8_gemm_nt_signal", &m_grouped_fp8_gemm_nt_signal, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("signal"), py::arg("expected_m"), py::arg("recipe") = std::nullopt, diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index c0d66b57..c0223043 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -20,7 +20,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime GemmConfig gemm_config; LaunchArgs launch_args; - void *sfb, *grouped_layout; + void *sfb, *grouped_layout, *signal; CUtensorMap tensor_map_a; CUtensorMap tensor_map_b; CUtensorMap tensor_map_d; @@ -42,7 +42,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, - {}, {} + {}, {}, {} >); }}; )", @@ -54,13 +54,13 @@ static void __instantiate_kernel() {{ args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), args.gemm_config.enable_overlap); } static void launch_impl(const 
KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.sfb, args.grouped_layout, + args.sfb, args.grouped_layout, args.signal args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.tensor_map_sfa)); @@ -116,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, config.multicast_config.num_multicast), .sfb = sfb.data_ptr(), .grouped_layout = nullptr, + .signal = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, @@ -175,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons config.multicast_config.num_multicast), .sfb = sfb.data_ptr(), .grouped_layout = m_indices.data_ptr(), + .signal = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, @@ -185,14 +187,17 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons SM90FP8Gemm1D2DRuntime::launch(runtime, args); } -static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, +static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, const torch::Tensor& masked_m, const int& num_groups, const int& m, const int& n, const int& k, const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const int& max_block_n, + const bool& enable_overlap, + const std::optional& signal) { const auto& aligned_k = align(k, 128); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); @@ -201,7 +206,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to GemmType::MGroupedMasked, KernelType::Kernel1D2D, expected_m, n, k, num_groups, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), false, - device_runtime->get_num_sms()); + device_runtime->get_num_sms(), max_block_n, enable_overlap); // Requires no TMA splits DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); @@ -232,9 +237,10 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, enable_overlap), .sfb = sfb.data_ptr(), .grouped_layout = masked_m.data_ptr(), + .signal = enable_overlap ? signal.value().data_ptr() : nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, @@ -243,6 +249,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); SM90FP8Gemm1D2DRuntime::launch(runtime, args); + return enable_overlap ? std::make_pair(config.block_m, config.signal_threshold) : std::nullopt; } class SM90FP8SignalGemm1D2DRuntime final: public LaunchRuntime { From 21da45dc1bfee90e4df49da057d758e37cb054bb Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:01:56 +0800 Subject: [PATCH 35/71] refactor: add params in common. 
Co-authored-by: Zqy11 <841971412@qq.com> Co-authored-by: AniZpZ --- csrc/jit_kernels/heuristics/common.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index dd460585..5c5a4f4e 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -72,6 +72,8 @@ struct GemmConfig { MulticastConfig multicast_config; SharedMemoryConfig smem_config; ThreadConfig thread_config; + + bool enable_overlap; }; static bool is_multicast_legal(const int& shape_dim, const int& block_dim, @@ -147,7 +149,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const int& m, const int& n, const int& k, const int& num_groups, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, - const bool& with_accumulation, const int& num_sms, const int& max_block_n = 256) { + const bool& with_accumulation, const int& num_sms, + const int& max_block_n = 256, const bool& enable_overlap = false) { DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); From cd062ad1e8a2428a57a119965e270e4e824deb6b Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:02:25 +0800 Subject: [PATCH 36/71] refactor: merge signal gemm kernel to fp8 gemm kernel. Co-authored-by: Zqy11 <841971412@qq.com> Co-authored-by: AniZpZ --- .../deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 5a65d69e..ef5a5197 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -3,6 +3,7 @@ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunknown-attributes" +#include #include #include @@ -36,9 +37,9 @@ template + uint32_t kNumSMs, GemmType kGemmType, bool kEnableOverlap> __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void -sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -232,6 +233,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } } } else { + if constexpr (kEnableOverlap) + cg::coalesced_group group = cg::coalesced_threads(); // Math warp-groups for WGMMA cutlass::arch::warpgroup_reg_alloc(); @@ -428,6 +431,19 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, cute::tma_store_arrive(); } __syncwarp(); + + if constexpr (kEnableOverlap) { + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + cute::tma_store_wait<0>(); + } + + group.sync(); + __threadfence(); + + if (threadIdx.x == 0) { + atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); + } + } } } #else From 387f068c60bf8d8905e0199e0292e8b0751c95ce Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:14:33 +0800 Subject: [PATCH 37/71] update tests. 
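The tests now drive the overlap path through the merged API. The signal buffer they pass follows the contract implied by the kernel-side indexing, signal[group * ceil_div(shape_m, BLOCK_M) + m_block_idx]: one int32 slot per (group, m-block), sized with the smallest BLOCK_M the SM90 heuristic can pick (64) so it fits whatever block size gets chosen, and zeroed before every launch because the kernel only atomicAdds into it. A small sketch of that contract; the helper name is illustrative:

    // Allocate a zeroed signal buffer large enough for any BLOCK_M the heuristic
    // may select (64, 128 or 256 on this path), indexed as group * num_m_blocks + m_block.
    #include <torch/torch.h>

    torch::Tensor make_signal_buffer(const int num_groups, const int max_m) {
        constexpr int min_block_m = 64;   // smallest BLOCK_M candidate
        const int slots = num_groups * ((max_m + min_block_m - 1) / min_block_m);
        return torch::zeros({slots}, torch::dtype(torch::kInt32).device(torch::kCUDA));
    }

Reusing a buffer without re-zeroing it would let stale counts from a previous launch satisfy the threshold early.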
--- tests/test_signal_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 84471423..d94233c7 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -55,7 +55,7 @@ def test_m_grouped_gemm_signal(max_block_n=256) -> None: combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') origin_sms = deep_gemm.get_num_sms() deep_gemm.set_num_sms(origin_sms - 3) - block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, max_block_n=max_block_n) + block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, signal=combine_signal, max_block_n=max_block_n, enable_overlap=True) deep_gemm.set_num_sms(origin_sms) check_signal(num_groups, max_m, block_m, threshold, combine_signal, masked_m) for j in range(num_groups): @@ -71,7 +71,7 @@ def test_m_grouped_gemm_signal(max_block_n=256) -> None: # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, signal=combine_signal, max_block_n=max_block_n, enable_overlap=True) # Test performance with fixed shapes valid_m = masked_m.sum().item() From aa81d53e00dcff02069f783a13fccd857f1d2dfe Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:31:56 +0800 Subject: [PATCH 38/71] fix --- csrc/apis/gemm.hpp | 5 +++-- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index d1b36339..092f7a53 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -219,7 +219,8 @@ static std::optional> m_grouped_fp8_gemm_nt_masked(const std std::optional> result = std::nullopt; if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims, + max_block_n, enable_overlap, signal); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); @@ -498,7 +499,7 @@ static void register_apis(pybind11::module_& m) { py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("expected_m"), py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false, - py::arg("max_block_n") = 256, py::arg("enable_overlap") = false + py::arg("max_block_n") = 256, py::arg("enable_overlap") = false, py::arg("signal") = std::nullopt); m.def("m_grouped_fp8_gemm_nt_signal", &m_grouped_fp8_gemm_nt_signal, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("signal"), diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index c0223043..6fbc6e67 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -60,7 +60,7 @@ static void __instantiate_kernel() {{ static void launch_impl(const 
KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.sfb, args.grouped_layout, args.signal + args.sfb, args.grouped_layout, args.signal, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.tensor_map_sfa)); @@ -249,7 +249,9 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); SM90FP8Gemm1D2DRuntime::launch(runtime, args); - return enable_overlap ? std::make_pair(config.block_m, config.signal_threshold) : std::nullopt; + return enable_overlap ? + std::optional(std::make_pair(config.block_m, config.signal_threshold)) : + std::nullopt; } class SM90FP8SignalGemm1D2DRuntime final: public LaunchRuntime { From e252aaa127651eb88ff3ec39ee2e2dee62cb374f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:35:09 +0800 Subject: [PATCH 39/71] fix. --- tests/test_signal_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index d94233c7..19dab79e 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -55,7 +55,7 @@ def test_m_grouped_gemm_signal(max_block_n=256) -> None: combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') origin_sms = deep_gemm.get_num_sms() deep_gemm.set_num_sms(origin_sms - 3) - block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, combine_signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, signal=combine_signal, max_block_n=max_block_n, enable_overlap=True) + block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, signal=combine_signal, max_block_n=max_block_n, enable_overlap=True) deep_gemm.set_num_sms(origin_sms) check_signal(num_groups, max_m, block_m, threshold, combine_signal, masked_m) for j in range(num_groups): From 1d0a364d8f43850cbfa04cf2a2d9ec45b5a516d8 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:37:13 +0800 Subject: [PATCH 40/71] fix. --- deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index ef5a5197..b39a75b1 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -15,6 +15,7 @@ #include #include +namespace cg = cooperative_groups; namespace deep_gemm { using namespace deep_gemm::sm90; From 8c2e6b865b209a10f435e1ffa16bb6c4605509be Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:43:49 +0800 Subject: [PATCH 41/71] fix. 
--- deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index b39a75b1..2aad2108 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -234,8 +234,10 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, } } } else { - if constexpr (kEnableOverlap) - cg::coalesced_group group = cg::coalesced_threads(); + cg::coalesced_group group; + if constexpr (kEnableOverlap) { + group = cg::coalesced_threads(); + } // Math warp-groups for WGMMA cutlass::arch::warpgroup_reg_alloc(); From 3517989a7bd8895ec2cd30233c36afe0bbd707f3 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:46:21 +0800 Subject: [PATCH 42/71] fix. --- deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 2aad2108..2d56cd21 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -234,9 +234,9 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, } } } else { - cg::coalesced_group group; + std::optional group; if constexpr (kEnableOverlap) { - group = cg::coalesced_threads(); + group.emplace(cg::coalesced_threads()); } // Math warp-groups for WGMMA cutlass::arch::warpgroup_reg_alloc(); @@ -440,7 +440,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, cute::tma_store_wait<0>(); } - group.sync(); + group.value().sync(); __threadfence(); if (threadIdx.x == 0) { From a086212879c1c9cb30aea0fcf255d5352e01646a Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:47:07 +0800 Subject: [PATCH 43/71] fix. --- tests/test_signal_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 19dab79e..3bdd7315 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -75,7 +75,7 @@ def test_func(): # Test performance with fixed shapes valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_signal_gemm', suppress_kineto_output=True) + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' f'{t * 1e6:4.0f} us | ' f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' From e6c697732ae2dc124a1d504e125d1e185075e7ab Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:47:30 +0800 Subject: [PATCH 44/71] fix. --- tests/test_signal_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 3bdd7315..d3047a58 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -40,7 +40,7 @@ def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, ma assert signal[i] == 0, f'{i=}, {signal[i]=}' def test_m_grouped_gemm_signal(max_block_n=256) -> None: - print('Testing m-grouped masked GEMM:') + print('Testing m-grouped signal GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. 
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): From 7af0f7a50d7e9e0ff7fad7dfe196ce709dd19d33 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:54:40 +0800 Subject: [PATCH 45/71] debug --- tests/test_signal_gemm.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index d3047a58..f39fea73 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -24,20 +24,30 @@ def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, ma expert_len = max_m // block_m # print(len(signal)) + flag = True for expert in range(num_local_expert): mask = masked_m[expert] start = expert * expert_len end = expert * (expert_len + 1) if mask == 0: for i in range(start, end): - assert signal[i] == 0, f'{i=}, {signal[i]=}' + if signal[i] != 0: + flag = False + # assert signal[i] == 0, f'{i=}, {signal[i]=}' else: valid_len = ceil_div(mask, block_m) + print(f'{start=}, {end=}, {valid_len=}') for i in range(start, end): if i < start + valid_len: - assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' + # assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' + if signal[i] != threshold: + print(f'{i=}, {signal[i]=}, {threshold=}') + flag = False else: - assert signal[i] == 0, f'{i=}, {signal[i]=}' + # assert signal[i] == 0, f'{i=}, {signal[i]=}' + if signal[i] != 0: + print(f'{i=}, {signal[i]=}') + flag = False def test_m_grouped_gemm_signal(max_block_n=256) -> None: print('Testing m-grouped signal GEMM:') From 2c9fa4452d744c21075d5f7c328df2ca23cec491 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 17:57:25 +0800 Subject: [PATCH 46/71] more --- tests/test_signal_gemm.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index f39fea73..49a852e8 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -20,34 +20,25 @@ def ceil_div(a, b): def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, masked_m): signal = combine_signal.cpu().tolist() - # print(signal) + print(f'signal = {signal}') + print(f'masked_m = {masked_m.cpu().tolist()}') expert_len = max_m // block_m # print(len(signal)) - flag = True for expert in range(num_local_expert): mask = masked_m[expert] start = expert * expert_len end = expert * (expert_len + 1) if mask == 0: for i in range(start, end): - if signal[i] != 0: - flag = False - # assert signal[i] == 0, f'{i=}, {signal[i]=}' + assert signal[i] == 0, f'{i=}, {signal[i]=}' else: valid_len = ceil_div(mask, block_m) - print(f'{start=}, {end=}, {valid_len=}') for i in range(start, end): if i < start + valid_len: - # assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' - if signal[i] != threshold: - print(f'{i=}, {signal[i]=}, {threshold=}') - flag = False + assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' else: - # assert signal[i] == 0, f'{i=}, {signal[i]=}' - if signal[i] != 0: - print(f'{i=}, {signal[i]=}') - flag = False + assert signal[i] == 0, f'{i=}, {signal[i]=}' def test_m_grouped_gemm_signal(max_block_n=256) -> None: print('Testing m-grouped signal GEMM:') From 96cdc9038f7ded1359f47cad4a6e20239af68520 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:14:28 +0800 Subject: [PATCH 47/71] fix. 
--- csrc/jit_kernels/heuristics/common.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 5c5a4f4e..737d98cc 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -278,7 +278,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k .multicast_config = best_multicast_config, // ReSharper disable once CppLocalVariableMightNotBeInitialized .smem_config = best_smem_config, - .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n) + .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n), + .enable_overlap = enable_overlap }; // Only SM100 BF16 kernels support tensor core control From 11eeed6621b47be94e6d46316858a7381c1ffa0a Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:22:59 +0800 Subject: [PATCH 48/71] more. --- deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 2d56cd21..8278f988 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -234,10 +234,6 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, } } } else { - std::optional group; - if constexpr (kEnableOverlap) { - group.emplace(cg::coalesced_threads()); - } // Math warp-groups for WGMMA cutlass::arch::warpgroup_reg_alloc(); @@ -440,7 +436,9 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, cute::tma_store_wait<0>(); } + cg::coalesced_group group = cg::coalesced_threads(); group.value().sync(); + __threadfence(); if (threadIdx.x == 0) { From debb196dc54b13b1c9f20d47a281f7c5c2ec26f6 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:24:25 +0800 Subject: [PATCH 49/71] fix. 
--- deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 8278f988..c4c68521 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -437,7 +437,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, } cg::coalesced_group group = cg::coalesced_threads(); - group.value().sync(); + group.sync(); __threadfence(); From 7ee0480a0043d8af18a7c5e6c47c81fa9961553f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:29:09 +0800 Subject: [PATCH 50/71] fix --- tests/test_signal_gemm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 49a852e8..be2f3ab0 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -20,15 +20,16 @@ def ceil_div(a, b): def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, masked_m): signal = combine_signal.cpu().tolist() + maskm = masked_m.cpu().tolist() print(f'signal = {signal}') print(f'masked_m = {masked_m.cpu().tolist()}') expert_len = max_m // block_m # print(len(signal)) for expert in range(num_local_expert): - mask = masked_m[expert] + mask = maskm[expert] start = expert * expert_len - end = expert * (expert_len + 1) + end = expert * expert_len + expert_len if mask == 0: for i in range(start, end): assert signal[i] == 0, f'{i=}, {signal[i]=}' From 0c0eac831fc45a16534c96846efae72b3f0c73fe Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:34:18 +0800 Subject: [PATCH 51/71] more. --- tests/test_signal_gemm.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index be2f3ab0..f3c3d827 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -30,16 +30,12 @@ def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, ma mask = maskm[expert] start = expert * expert_len end = expert * expert_len + expert_len - if mask == 0: - for i in range(start, end): + valid_len = ceil_div(mask, block_m) + for i in range(start, end): + if i < start + valid_len: + assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' + else: assert signal[i] == 0, f'{i=}, {signal[i]=}' - else: - valid_len = ceil_div(mask, block_m) - for i in range(start, end): - if i < start + valid_len: - assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' - else: - assert signal[i] == 0, f'{i=}, {signal[i]=}' def test_m_grouped_gemm_signal(max_block_n=256) -> None: print('Testing m-grouped signal GEMM:') From af77de9ff745438f209cc0a8329de535fbe43b9b Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:35:28 +0800 Subject: [PATCH 52/71] remove comments. 
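With the debug prints removed, the expectation check_signal enforces is compact: group g owns max_m // block_m consecutive slots, the first ceil_div(masked_m[g], block_m) of them must equal threshold, and the rest must stay zero. A small worked sketch of that layout (the shapes, masked_m values and the threshold value are illustrative placeholders, not values computed by the library):

    // Worked example of the layout check_signal enforces.
    #include <cstdio>

    int main() {
        constexpr int max_m = 2048, block_m = 64, threshold = 56;   // illustrative only
        constexpr int masked_m[4] = {100, 0, 64, 65};
        constexpr int expert_len = max_m / block_m;                 // 32 slots per group
        for (int g = 0; g < 4; ++g) {
            const int valid = (masked_m[g] + block_m - 1) / block_m;    // 2, 0, 1, 2
            std::printf("group %d: slots [%d, %d) must equal %d, remaining %d slots must be 0\n",
                        g, g * expert_len, g * expert_len + valid, threshold,
                        expert_len - valid);
        }
        return 0;
    }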
--- tests/test_signal_gemm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index f3c3d827..fd8978db 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -21,11 +21,8 @@ def ceil_div(a, b): def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, masked_m): signal = combine_signal.cpu().tolist() maskm = masked_m.cpu().tolist() - print(f'signal = {signal}') - print(f'masked_m = {masked_m.cpu().tolist()}') expert_len = max_m // block_m - # print(len(signal)) for expert in range(num_local_expert): mask = maskm[expert] start = expert * expert_len From 23be492bab7711b06aa61c69a823efda6ca4dca2 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:39:27 +0800 Subject: [PATCH 53/71] remove code dup. --- csrc/apis/gemm.hpp | 54 --- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 125 +---- .../impls/sm90_fp8_signal_gemm_1d2d.cuh | 457 ------------------ 3 files changed, 1 insertion(+), 635 deletions(-) delete mode 100644 deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 092f7a53..c11969ae 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -233,56 +233,6 @@ static std::optional> m_grouped_fp8_gemm_nt_masked(const std return result; } -static std::pair m_grouped_fp8_gemm_nt_signal(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& masked_m, - const torch::Tensor& signal, - const int& expected_m, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast, - const int& max_block_n = 256) { - // Shape must be `[G, M, K] @ [G, N, K].mT` - const auto& major_a = get_major_type_ab(a.first); - const auto& major_b = get_major_type_ab(b.first); - DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(masked_m.is_contiguous()); - DG_HOST_ASSERT(signal.is_contiguous()); - - // Type and shape checks - const auto& [num_groups, m, k] = get_shape<3>(a.first); - const auto& [num_groups_, n, k_] = get_shape<3>(b.first); - const auto& [num_groups__, m_, n_] = get_shape<3>(d); - const auto& num_groups___ = static_cast(masked_m.numel()); - DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); - DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); - DG_HOST_ASSERT(signal.scalar_type() == torch::kInt32); - - // D must be N-major - check_major_type_cd(d); - - // Transform scaling factors - if (not recipe.has_value()) - recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast); - const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); - - // Dispatch implementation - only SM90 - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - return sm90_m_grouped_fp8_gemm_signal_1d2d(a.first, sfa, b.first, sfb, d, masked_m, signal, - num_groups, m, n, k, 
expected_m, major_a, major_b, compiled_dims, max_block_n); - } else { - DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); - } -} - static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, const std::pair& b, const torch::Tensor& d, @@ -501,10 +451,6 @@ static void register_apis(pybind11::module_& m) { py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false, py::arg("max_block_n") = 256, py::arg("enable_overlap") = false, py::arg("signal") = std::nullopt); - m.def("m_grouped_fp8_gemm_nt_signal", &m_grouped_fp8_gemm_nt_signal, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("signal"), - py::arg("expected_m"), py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false, py::arg("max_block_n") = 256); m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), py::arg("ks_tensor"), py::arg("c") = std::nullopt, diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 6fbc6e67..a81fed27 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -252,127 +252,4 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co return enable_overlap ? std::optional(std::make_pair(config.block_m, config.signal_threshold)) : std::nullopt; -} - -class SM90FP8SignalGemm1D2DRuntime final: public LaunchRuntime { -public: - struct Args { - int m, n, k, num_groups; - const std::string& compiled_dims; - - GemmConfig gemm_config; - LaunchArgs launch_args; - - void *sfb, *grouped_layout, *signal; - CUtensorMap tensor_map_a; - CUtensorMap tensor_map_b; - CUtensorMap tensor_map_d; - CUtensorMap tensor_map_sfa; - }; - - static std::string generate_impl(const Args& args) { - return fmt::format(R"( -#include - -using namespace deep_gemm; - -static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&sm90_fp8_signal_gemm_1d2d_impl< - {}, {}, {}, - {}, - {}, {}, {}, - {}, - {}, {}, - {}, {}, - {}, {}, - {}, {} - >); -}}; -)", - get_compiled_dim(args.m, 'm', args.compiled_dims), - get_compiled_dim(args.n, 'n', args.compiled_dims), - get_compiled_dim(args.k, 'k', args.compiled_dims), - args.num_groups, - args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, - args.gemm_config.smem_config.swizzle_cd_mode, - args.gemm_config.num_stages, args.gemm_config.num_last_stages, - args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, - args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); - } - - static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { - DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.sfb, args.grouped_layout, args.signal, - args.m, args.n, args.k, - args.tensor_map_a, args.tensor_map_b, - args.tensor_map_d, args.tensor_map_sfa)); - } -}; - -static std::pair sm90_m_grouped_fp8_gemm_signal_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const torch::Tensor& d, - const torch::Tensor& masked_m, - const torch::Tensor& signal, - const int& num_groups, const int& m, const int& n, const int& k, - const int& expected_m, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims, 
const int& max_block_n = 256) { - const auto& aligned_k = align(k, 128); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(signal.scalar_type() == torch::kInt32); - - const auto& config = get_best_config( - GemmType::MGroupedMasked, KernelType::Kernel1D2D, - expected_m, n, k, num_groups, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, - device_runtime->get_num_sms(), max_block_n); - - // Requires no TMA splits - DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); - DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), num_groups, - config.smem_config.swizzle_cd_mode); - const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, num_groups, 0); - - // Launch with cooperative groups support - const SM90FP8SignalGemm1D2DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast, - true), // 启用cooperative groups - .sfb = sfb.data_ptr(), - .grouped_layout = masked_m.data_ptr(), - .signal = signal.data_ptr(), - .tensor_map_a = tensor_map_a, - .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, - .tensor_map_sfa = tensor_map_sfa, - }; - const auto& code = SM90FP8SignalGemm1D2DRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_signal_1d2d", code); - SM90FP8SignalGemm1D2DRuntime::launch(runtime, args); - return std::make_pair(config.block_m, config.signal_threshold); -} - -} // namespace deep_gemm +} \ No newline at end of file diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh deleted file mode 100644 index 6b55b7ee..00000000 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_signal_gemm_1d2d.cuh +++ /dev/null @@ -1,457 +0,0 @@ -#pragma once - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunknown-attributes" - -#include -#include -#include - -#include -#include -#include - -#include -#include -#include - -namespace cg = cooperative_groups; -namespace deep_gemm { - -using namespace deep_gemm::sm90; - -template -__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) { - if (num_former_iters == kNumFormerIters) { - inner_launch_k_iterations(func, cute::Int{}); - return; - } - - if constexpr (kNumFormerIters + kGap <= kEnd) - outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); -} - 
-template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void -sm90_fp8_signal_gemm_1d2d_impl(float* sfb, int* grouped_layout, int32_t* signal, - uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, - const __grid_constant__ cute::TmaDescriptor tensor_map_a, - const __grid_constant__ cute::TmaDescriptor tensor_map_b, - const __grid_constant__ cute::TmaDescriptor tensor_map_d, - const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { -#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) - // Scaling checks - DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); - - // Types - using WGMMA = typename FP8MMASelector::type; - using Barrier = cutlass::arch::ClusterTransactionBarrier; - DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); - - // Overwrite shape constants if the compiler gives - shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; - shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; - shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; - - // Shared memory - static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); - static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); - const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)); - - // Configs - constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); - const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); - - // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == kNumMathThreads) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail - cute::prefetch_tma_descriptor(&tensor_map_a); - cute::prefetch_tma_descriptor(&tensor_map_b); - cute::prefetch_tma_descriptor(&tensor_map_sfa); - cute::prefetch_tma_descriptor(&tensor_map_d); - } - __syncwarp(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); - - // Data on shared memory - auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); - __nv_fp8_e4m3* smem_a[kNumStages]; - __nv_fp8_e4m3* smem_b[kNumStages]; - float* smem_sfa[kNumStages]; - float* smem_sfb; - - // TMA Barrier for both divisible and non-divisible cases - Barrier* full_barriers[kNumStages]; - Barrier* empty_barriers[kNumStages]; - - // Fill shared memory pointers - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - smem_sfa[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE); - } - smem_sfb = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE)); - - // Fill barriers - auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - full_barriers[i] = barrier_start_ptr + i; - empty_barriers[i] = barrier_start_ptr + kNumStages + i; - } - - // Initialize barriers - DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); - if (threadIdx.x == kNumMathThreads) { - // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, - // even with TMA multicast disabled, we want to make the behavior aligned - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - full_barriers[i]->init(1); - empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); - } - - // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); - cutlass::arch::fence_barrier_init(); - } - - // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); - - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - struct SkipComputation {}; - struct NotSkipComputation {}; - auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) { - constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; - constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; - constexpr uint32_t kEnd = kShouldOptimize ? 
BLOCK_K / 8 : 0; - - // NOTES: for too-many branches (> 5), we disable this optimization - // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value - outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { - if (skip_computation) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); - } else if (shape_k % kFullKOfAllStages == 0) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); - } else { - for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); - func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); - } - }, func, kShouldOptimize ? num_former_iters : 0); - }; - - // Register reconfigurations - constexpr uint32_t kNumTMARegisters = 40; - constexpr uint32_t kNumMathRegisters = 232; - - // Block scheduler - uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); - - if (threadIdx.x >= kNumMathThreads) { - // TMA warp-group for loading data - cutlass::arch::warpgroup_reg_dealloc(); - - // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; - - // Assign TMA multicast number into A and B - // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. - const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. 
`full_barriers` registers, if all the access indices are constant - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - - // Issue TMA A - constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; - auto& full_barrier = *full_barriers[s]; - uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); - tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), - smem_sfa[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(shape_k_scales, 1, k_idx / BLOCK_K), - num_tma_multicast_a); - - // Issue TMA B - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }, false, 0); - } - - // To safely deconstruct distributed shared barriers, we need another round of empty waits - if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); - } - } - } else { - cg::coalesced_group group = cg::coalesced_threads(); - // Math warp-groups for WGMMA - cutlass::arch::warpgroup_reg_alloc(); - - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); - const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Decide the number of scales B to load - DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); - uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; - if constexpr (not kMustUseUniformedScaleB) { - num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; - num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; - } - uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); - - // Load B scales with math warp-groups - // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks - if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); - auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; - #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) - st_shared(smem_sfb + i, __ldg(local_sfb + i)); - } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Accumulation for WGMMA or CUDA promotion - constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); - DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); - float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; - - // Empty barrier arrival - auto empty_barrier_arrive = [&](uint32_t s) { - if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? 
empty_barriers[s]->arrive() : void(); - } else { - auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); - } - }; - - // Launch MMAs - launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { - constexpr bool kSkipComputation = cute::is_same_v; - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages); - - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales); - - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - - // TODO: remove some useless computation for unaligned Ms - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset); - - // Commit WGMMA instructions - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); - - // Notify barrier arrival at the last warpgroup wave - if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); - - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; - - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - shifted_accum[i * 4 + 3] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } - } - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); - - // TMA checks - constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); - constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); - constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; - DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); - DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, - "Unaligned TMA store or too many TMA store instructions"); - DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); - - // Wait last TMA store to be finished - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) - cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Write back to shared memory using STSM and issue TMA stores - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // Swizzle or padding into the correct address - uint8_t* smem_ptr = nullptr; - if constexpr (kSwizzleDMode > 0) { - // Calculate the swizzling atom offset and in-atom offset - constexpr uint32_t kNumBankGroupBytes = 16; - auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); - - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` - // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` - constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? 
(in_atom_offset) : (bank_group_index % 8); - col ^= row % (kSwizzleDMode / 16); - - // Add back into the base pointer - // NOTES: think twice before modifying this, as changes may affect the number of instructions - smem_ptr = reinterpret_cast(smem_d) + // Base pointer - warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset - m_offset * kSwizzleDMode + // Wave offset - atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - } else { - // No swizzling, just padding - smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); - } - - // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), - __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), - smem_ptr - ); - } - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - - // Use TMA store to write back to global memory - // TODO: compatible with FP32 output - constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; - DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { - auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; - auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, - n_block_idx * BLOCK_N + in_block_n_offset, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - cute::tma_store_arrive(); - } - __syncwarp(); - - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { - cute::tma_store_wait<0>(); - } - - group.sync(); - - __threadfence(); - - if (threadIdx.x == 0) { - atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); - } - - } - } -#else - if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "Signal GEMM kernel only supports SM90 architecture"); -#endif -} - -}; // namespace deep_gemm - -#pragma clang diagnostic pop From 84526c52b5748f376954d0b11c8ce76eb12f1a23 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:40:22 +0800 Subject: [PATCH 54/71] remove signal gemm api. --- deep_gemm/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 302732e5..169e2e6b 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -28,7 +28,6 @@ m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nn_contiguous, m_grouped_fp8_gemm_nt_masked, - m_grouped_fp8_gemm_nt_signal, k_grouped_fp8_gemm_tn_contiguous, # BF16 GEMMs bf16_gemm_nt, bf16_gemm_nn, From f858cad7ba21de05f251485d107d90446811c54f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Thu, 4 Sep 2025 18:41:01 +0800 Subject: [PATCH 55/71] fix. --- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index a81fed27..a625c118 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -252,4 +252,5 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co return enable_overlap ? 
std::optional(std::make_pair(config.block_m, config.signal_threshold)) : std::nullopt; +} } \ No newline at end of file From ee1a058c15a67d782c8a3a5008e423881918f4a7 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Fri, 5 Sep 2025 11:19:44 +0800 Subject: [PATCH 56/71] feat: use NamedBarrier instead of coorperative groups. --- csrc/jit/handle.hpp | 60 +++++-------------- csrc/jit/kernel_runtime.hpp | 15 ++--- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 2 +- .../deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 6 +- 4 files changed, 23 insertions(+), 60 deletions(-) diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index 38c23b47..e56e04b9 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -37,8 +37,7 @@ static void unload_library(const LibraryHandle& library) { static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, - const int& cluster_dim, const bool& cooperative = false) { + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { if (smem_size > 0) DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -50,27 +49,14 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.numAttrs = 0; config.attrs = nullptr; - // Support for multiple attributes - static LaunchAttrHandle attrs[2]; - int attr_count = 0; - + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attr; if (cluster_dim > 1) { - attrs[attr_count].id = cudaLaunchAttributeClusterDimension; - attrs[attr_count].val.clusterDim = {static_cast(cluster_dim), 1, 1}; - attr_count++; - } - - if (cooperative) { - attrs[attr_count].id = cudaLaunchAttributeCooperative; - attrs[attr_count].val.cooperative = 1; - attr_count++; - } - - if (attr_count > 0) { - config.attrs = attrs; - config.numAttrs = attr_count; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; } - return config; } @@ -109,7 +95,7 @@ static void unload_library(const LibraryHandle& library) { static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& cooperative = false) { + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { if (smem_size > 0) DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); @@ -125,30 +111,16 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.numAttrs = 0; config.attrs = nullptr; - // Support for multiple attributes // NOTES: must use `static` or the `attr` will be deconstructed - static LaunchAttrHandle attrs[2]; - int attr_count = 0; - + static LaunchAttrHandle attr; if (cluster_dim > 1) { - attrs[attr_count].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - attrs[attr_count].value.clusterDim.x = cluster_dim; - attrs[attr_count].value.clusterDim.y = 1; - attrs[attr_count].value.clusterDim.z = 1; - attr_count++; + attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attr.value.clusterDim.x = cluster_dim; + attr.value.clusterDim.y = 1; + attr.value.clusterDim.z = 1; + config.attrs = &attr; + config.numAttrs = 1; } - - if (cooperative) { - attrs[attr_count].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE; - attrs[attr_count].value.cooperative = 1; - 
attr_count++; - } - - if (attr_count > 0) { - config.attrs = attrs; - config.numAttrs = attr_count; - } - return config; } @@ -160,4 +132,4 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& #endif -} // namespace deep_gemm +} // namespace deep_gemm \ No newline at end of file diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp index 2f7c2d16..bf4ae38e 100644 --- a/csrc/jit/kernel_runtime.hpp +++ b/csrc/jit/kernel_runtime.hpp @@ -13,17 +13,12 @@ struct LaunchArgs { int num_threads; int smem_size; int cluster_dim; - bool cooperative; - LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, - const int& cluster_dim = 1, const bool& cooperative = false): - grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), - cluster_dim(cluster_dim), cooperative(cooperative) {} + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} - LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, - const int& cluster_dim = 1, const bool& cooperative = false): - grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), - cluster_dim(cluster_dim), cooperative(cooperative) {} + LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} }; class KernelRuntime final { @@ -120,4 +115,4 @@ class LaunchRuntime { } }; -} // namespace deep_gemm +} // namespace deep_gemm \ No newline at end of file diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index a625c118..17f3edbe 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -237,7 +237,7 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast, enable_overlap), + config.multicast_config.num_multicast), .sfb = sfb.data_ptr(), .grouped_layout = masked_m.data_ptr(), .signal = enable_overlap ? signal.value().data_ptr() : nullptr, diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index c4c68521..468324c2 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -3,7 +3,6 @@ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunknown-attributes" -#include #include #include @@ -15,7 +14,6 @@ #include #include -namespace cg = cooperative_groups; namespace deep_gemm { using namespace deep_gemm::sm90; @@ -436,9 +434,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, cute::tma_store_wait<0>(); } - cg::coalesced_group group = cg::coalesced_threads(); - group.sync(); - + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); __threadfence(); if (threadIdx.x == 0) { From ec8b7c1af48a3e7cdeb510e62fec0fa8aca9f2d8 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:31:08 +0800 Subject: [PATCH 57/71] refactor: unify test for overlap with test for masked gemm. 
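The unified path this refactor targets, as a minimal sketch rather than the authoritative test (the group count, shapes, and the `tests/generators.py` import are illustrative assumptions):

    import deep_gemm
    from deep_gemm.testing import calc_diff, check_signal
    from generators import generate_m_grouped_masked  # helper under tests/

    num_groups, max_m, expected_m, n, k = 4, 4096, 256, 4096, 7168
    a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(
        num_groups, max_m, expected_m, n, k, enable_overlap=True)

    # With `enable_overlap=True` and a pre-zeroed `signal` buffer, the call
    # returns the tile height and the per-tile completion threshold that the
    # kernel accumulates into `signal`.
    block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(
        a, b, d, masked_m, expected_m, enable_overlap=True, signal=signal)
    check_signal(num_groups, max_m, block_m, threshold, signal, masked_m)
    for j in range(num_groups):
        assert calc_diff(d[j, :masked_m[j].item()],
                         ref_d[j, :masked_m[j].item()]) < 0.001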
--- deep_gemm/testing/numeric.py | 14 +++++++++ tests/generators.py | 38 +++++++----------------- tests/test_fp8.py | 57 ++++++++---------------------------- tests/test_signal_gemm.py | 2 ++ 4 files changed, 38 insertions(+), 73 deletions(-) diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py index d06a03b9..59f6acd2 100644 --- a/deep_gemm/testing/numeric.py +++ b/deep_gemm/testing/numeric.py @@ -1,6 +1,20 @@ import torch from typing import Iterable +def check_signal(num_local_expert, max_m, block_m, threshold, signal, masked_m): + ceil_div = lambda a, b: (a + b - 1) // b + + expert_len = max_m // block_m + for expert in range(num_local_expert): + mask = masked_m[expert] + start = expert * expert_len + end = expert * expert_len + expert_len + valid_len = ceil_div(mask, block_m) + for i in range(start, end): + if i < start + valid_len: + assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' + else: + assert signal[i] == 0, f'{i=}, {signal[i]=}' def calc_diff(x: torch.Tensor, y: torch.Tensor): x, y = x.double(), y.double() diff --git a/tests/generators.py b/tests/generators.py index 22c22ffa..315e18c8 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -88,9 +88,10 @@ def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator: def enumerate_m_grouped_masked() -> Generator: max_m = 4096 for kernel_type in get_kernel_types(): - for num_groups, m in ((1, 1024), (2, 512), (4, 256)): - for n, k in ((4096, 7168), (7168, 2048), ): - yield kernel_type, num_groups, max_m, m, n, k + for enable_overlap in (False, True): + for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for n, k in ((4096, 7168), (7168, 2048), ): + yield kernel_type, enable_overlap, num_groups, max_m, m, n, k def enumerate_k_grouped_contiguous(): @@ -191,7 +192,7 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, - use_ue8m0: bool = False, use_bf16: bool = False): + use_ue8m0: bool = False, use_bf16: bool = False, enable_overlap: bool = False): a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) @@ -211,7 +212,11 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) - return a_fp8, b_fp8, masked_m, d, ref_d + ceil_div = lambda a, b: (a + b - 1) // b + max_signal_size = num_groups * ceil_div(max_m, 64) + signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') if enable_overlap else None + + return a_fp8, b_fp8, masked_m, d, ref_d, signal def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool): @@ -234,26 +239,3 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int] b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) return k, a_fp8, b_fp8, c, d, ref_d - -def generate_m_grouped_signal(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, - use_ue8m0: bool = False, use_bf16: bool = False): - a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) - b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - d = torch.empty((num_groups, max_m, n), 
device='cuda', dtype=torch.bfloat16) - ref_d = torch.einsum('gmk,gnk->gmn', a, b) - - masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) - for j in range(num_groups): - masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) - assert masked_m.amax().item() <= max_m - - if use_bf16: - return a, b, masked_m, d, ref_d - - a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float)) - b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) - for i in range(num_groups): - a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) - b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) - - return a_fp8, b_fp8, masked_m, d, ref_d diff --git a/tests/test_fp8.py b/tests/test_fp8.py index b9ebfc99..443262bb 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -6,7 +6,8 @@ import deep_gemm from deep_gemm.testing import ( bench, bench_kineto, - calc_diff, count_bytes + calc_diff, count_bytes, + check_signal, ) from generators import ( @@ -97,15 +98,20 @@ def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. - for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + for kernel_type, enable_overlap, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' use_ue8m0 = get_ue8m0_usage(kernel_type) disable_ue8m0_cast = not use_ue8m0 # Test correctness for i in range(10): - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + result = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) + + if enable_overlap: + block_m, threshold = result + check_signal(num_groups, max_m, block_m, threshold, signal, masked_m) + for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' @@ -115,12 +121,12 @@ def test_m_grouped_gemm_masked() -> None: # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) # Test performance with fixed shapes valid_m = masked_m.sum().item() t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}, enable_overlap={enable_overlap}): ' f'{t * 1e6:4.0f} us | ' f'{2 * valid_m * n * k / t / 1e12:4.0f} 
TFLOPS | ' f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') @@ -159,44 +165,6 @@ def test_func(): print() -def test_m_grouped_gemm_signal() -> None: - print('Testing m-grouped signal GEMM:') - - # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. - for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): - kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' - use_ue8m0 = get_ue8m0_usage(kernel_type) - disable_ue8m0_cast = not use_ue8m0 - - # Test correctness - for i in range(10): - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - # Create signal tensor - signal = torch.zeros((num_groups, max_m), dtype=torch.int32, device='cuda') - result = deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - for j in range(num_groups): - diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) - assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' - - # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - # Create signal tensor - signal = torch.zeros((num_groups, max_m), dtype=torch.int32, device='cuda') - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_signal(a, b, d, masked_m, signal, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - - # Test performance with fixed shapes - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') - print() - - if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -210,4 +178,3 @@ def test_func(): test_m_grouped_gemm_contiguous() test_m_grouped_gemm_masked() test_k_grouped_gemm_contiguous() - test_m_grouped_gemm_signal() diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index fd8978db..e4a72d12 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -19,6 +19,8 @@ def ceil_div(a, b): return (a + b - 1) // b def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, masked_m): + ceil_div = lambda a, b: (a + b - 1) // b + signal = combine_signal.cpu().tolist() maskm = masked_m.cpu().tolist() From ad8874adb482e62406948c1d2b782b14899e5729 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:37:27 +0800 Subject: [PATCH 58/71] ref --- tests/test_signal_gemm.py | 63 +++++++-------------------------------- 1 file changed, 11 insertions(+), 52 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index e4a72d12..d4a83518 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -36,80 +36,40 @@ def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, ma else: assert signal[i] == 0, f'{i=}, {signal[i]=}' -def test_m_grouped_gemm_signal(max_block_n=256) -> None: - print('Testing m-grouped signal GEMM:') - - # TODO: when the actual `m` is greater than 
`expected_m_per_group`, efficiency may significantly decrease. - for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): - kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' - use_ue8m0 = get_ue8m0_usage(kernel_type) - disable_ue8m0_cast = not use_ue8m0 - - # Test correctness - for i in range(2): - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - max_signal_size = num_groups * ceil_div(max_m, 64) - combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') - origin_sms = deep_gemm.get_num_sms() - deep_gemm.set_num_sms(origin_sms - 3) - block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, signal=combine_signal, max_block_n=max_block_n, enable_overlap=True) - deep_gemm.set_num_sms(origin_sms) - check_signal(num_groups, max_m, block_m, threshold, combine_signal, masked_m) - for j in range(num_groups): - diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) - assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' - - print(f' > Correctness ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}) checked ') - - # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - max_signal_size = num_groups * ceil_div(max_m, 64) - combine_signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, signal=combine_signal, max_block_n=max_block_n, enable_overlap=True) - - # Test performance with fixed shapes - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') - print() - def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. 
- for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + for kernel_type, enable_overlap, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' use_ue8m0 = get_ue8m0_usage(kernel_type) disable_ue8m0_cast = not use_ue8m0 # Test correctness - for i in range(2): - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + for i in range(10): + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + result = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) + + if enable_overlap: + block_m, threshold = result + check_signal(num_groups, max_m, block_m, threshold, signal, masked_m) + for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' - print(f' > Correctness ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}) checked ') - # Construct full cases a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) # Test performance with fixed shapes valid_m = masked_m.sum().item() t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}, enable_overlap={enable_overlap}): ' f'{t * 1e6:4.0f} us | ' f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') @@ -124,5 +84,4 @@ def test_func(): print('Library path:') print(f' > {deep_gemm.__path__}\n') - test_m_grouped_gemm_signal() test_m_grouped_gemm_masked() \ No newline at end of file From 59dc1aa0ff7de1c19e3fd9b2928b0c0331bb361e Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:38:05 +0800 Subject: [PATCH 59/71] fix. 
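With the duplicate removed, the shared helper in `deep_gemm.testing` carries the layout contract: `max_m // block_m` signal slots per expert, of which the first `ceil(masked_m / block_m)` must reach the reported threshold while the rest stay zero. A tiny worked example, with all values made up for illustration:

    from deep_gemm.testing import check_signal

    # 2 experts, max_m=128, block_m=64 -> 2 slots per expert; expert 0 covers
    # ceil(70 / 64) = 2 tiles, expert 1 is empty; threshold=3 is arbitrary.
    check_signal(num_local_expert=2, max_m=128, block_m=64, threshold=3,
                 signal=[3, 3, 0, 0], masked_m=[70, 0])  # all asserts pass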
--- tests/test_signal_gemm.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index d4a83518..2c2682fb 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -6,7 +6,8 @@ import deep_gemm from deep_gemm.testing import ( bench, bench_kineto, - calc_diff, count_bytes + calc_diff, count_bytes, + check_signal, ) from generators import ( @@ -15,28 +16,6 @@ generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous ) -def ceil_div(a, b): - return (a + b - 1) // b - -def check_signal(num_local_expert, max_m, block_m, threshold, combine_signal, masked_m): - ceil_div = lambda a, b: (a + b - 1) // b - - signal = combine_signal.cpu().tolist() - maskm = masked_m.cpu().tolist() - - expert_len = max_m // block_m - for expert in range(num_local_expert): - mask = maskm[expert] - start = expert * expert_len - end = expert * expert_len + expert_len - valid_len = ceil_div(mask, block_m) - for i in range(start, end): - if i < start + valid_len: - assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}' - else: - assert signal[i] == 0, f'{i=}, {signal[i]=}' - - def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') From 5edbbc501e8f7217e9e506b3e368b913b2181c56 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:40:06 +0800 Subject: [PATCH 60/71] fix. --- tests/generators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/generators.py b/tests/generators.py index 315e18c8..b79a04c8 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -212,7 +212,6 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) - ceil_div = lambda a, b: (a + b - 1) // b max_signal_size = num_groups * ceil_div(max_m, 64) signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') if enable_overlap else None From bf0d62a6c04b5b3d4a50dee5588ec38fab634c71 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:40:58 +0800 Subject: [PATCH 61/71] fix. 
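A sketch of the updated generator contract that this unpacking fix follows (shapes taken from the enumerated cases; illustrative, not the test itself):

    from generators import generate_m_grouped_masked  # helper under tests/

    # The generator now returns six values; `signal` stays None unless
    # `enable_overlap=True` is passed.
    a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(
        2, 4096, 512, 4096, 7168)
    assert signal is None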
--- tests/test_fp8.py | 2 +- tests/test_signal_gemm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 443262bb..353455eb 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -117,7 +117,7 @@ def test_m_grouped_gemm_masked() -> None: assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) # noinspection PyShadowingNames def test_func(): diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index 2c2682fb..cb1654a3 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -39,7 +39,7 @@ def test_m_grouped_gemm_masked() -> None: assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) # noinspection PyShadowingNames def test_func(): From b78061ba0f43051df6c8a581bf99fedbc0597650 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:42:50 +0800 Subject: [PATCH 62/71] fix. --- tests/test_signal_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py index cb1654a3..e07232e1 100644 --- a/tests/test_signal_gemm.py +++ b/tests/test_signal_gemm.py @@ -27,7 +27,7 @@ def test_m_grouped_gemm_masked() -> None: # Test correctness for i in range(10): - a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap) result = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) if enable_overlap: @@ -39,7 +39,7 @@ def test_m_grouped_gemm_masked() -> None: assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' # Construct full cases - a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap) # noinspection PyShadowingNames def test_func(): From 848952bab29507058889f50da6cdf4023f2487dd Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:45:59 +0800 Subject: [PATCH 63/71] fix. 
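The flag has to reach the generator so the signal buffer actually exists when overlap is requested. Roughly what it allocates in that case (mirrors the sizing in `tests/generators.py`; the 64-row granularity is that sizing, not the block_m the heuristics will pick):

    import torch

    num_groups, max_m = 4, 4096
    ceil_div = lambda a, b: (a + b - 1) // b
    # One int32 slot per 64-row tile per group, zero-initialized so tiles the
    # kernel never touches remain 0 for `check_signal`.
    signal = torch.zeros(num_groups * ceil_div(max_m, 64),
                         dtype=torch.int32, device='cuda')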
--- tests/test_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 353455eb..04500505 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -105,7 +105,7 @@ def test_m_grouped_gemm_masked() -> None: # Test correctness for i in range(10): - a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap) result = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) if enable_overlap: @@ -117,7 +117,7 @@ def test_m_grouped_gemm_masked() -> None: assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' # Construct full cases - a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) + a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap) # noinspection PyShadowingNames def test_func(): From bbc09d67f97f7291cb63c40f5ffaff37ffb60510 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 16:51:41 +0800 Subject: [PATCH 64/71] add test for EP16 situation. --- tests/generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generators.py b/tests/generators.py index b79a04c8..8f0cce99 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -89,7 +89,7 @@ def enumerate_m_grouped_masked() -> Generator: max_m = 4096 for kernel_type in get_kernel_types(): for enable_overlap in (False, True): - for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for num_groups, m in ((1, 1024), (2, 512), (4, 256), (16, 64), (16, 32)): for n, k in ((4096, 7168), (7168, 2048), ): yield kernel_type, enable_overlap, num_groups, max_m, m, n, k From 4011a8a383a91b80825aee586b862ad6fedffc3e Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Sat, 6 Sep 2025 17:22:33 +0800 Subject: [PATCH 65/71] remove code dup. --- tests/test_signal_gemm.py | 66 --------------------------------------- 1 file changed, 66 deletions(-) delete mode 100644 tests/test_signal_gemm.py diff --git a/tests/test_signal_gemm.py b/tests/test_signal_gemm.py deleted file mode 100644 index e07232e1..00000000 --- a/tests/test_signal_gemm.py +++ /dev/null @@ -1,66 +0,0 @@ -import copy -import random -import time -import torch - -import deep_gemm -from deep_gemm.testing import ( - bench, bench_kineto, - calc_diff, count_bytes, - check_signal, -) - -from generators import ( - KernelType, get_ue8m0_usage, - enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, - generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous -) - -def test_m_grouped_gemm_masked() -> None: - print('Testing m-grouped masked GEMM:') - - # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. 
- for kernel_type, enable_overlap, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): - kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' - use_ue8m0 = get_ue8m0_usage(kernel_type) - disable_ue8m0_cast = not use_ue8m0 - - # Test correctness - for i in range(10): - a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap) - result = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) - - if enable_overlap: - block_m, threshold = result - check_signal(num_groups, max_m, block_m, threshold, signal, masked_m) - - for j in range(num_groups): - diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) - assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' - - # Construct full cases - a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap) - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal) - - # Test performance with fixed shapes - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}, enable_overlap={enable_overlap}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') - print() - -if __name__ == '__main__': - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.manual_seed(0) - random.seed(0) - - print('Library path:') - print(f' > {deep_gemm.__path__}\n') - - test_m_grouped_gemm_masked() \ No newline at end of file From b9f37f25eeb473340f33e0f44dfd8e93d86723f3 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 8 Sep 2025 15:51:58 +0800 Subject: [PATCH 66/71] more. 
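This patch and the two that follow (67-68) reshape the overlap epilogue of sm90_fp8_gemm_1d2d.cuh. Condensed into a self-contained device sketch of the ordering they converge on; the real kernel guards this with kEnableOverlap, lets only the threads that issued the TMA stores for the D tile call store_wait(), and synchronizes via cutlass::arch::NamedBarrier(kNumMathThreads) rather than the plain __syncthreads() used here:

// Helpers as introduced in common/utils.cuh by the hunks below.
__device__ __forceinline__ void store_wait() {
    asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory");
}

__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) {
    int ret;
    asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;"
                 : "=r"(ret) : "l"(addr), "r"(value));
    return ret;
}

// Simplified epilogue: wait for the bulk stores of this D tile, make every
// thread in the block agree the tile is in global memory, then have one
// thread publish completion with release semantics.
__device__ void publish_tile(int* signal, const int slot, const bool issued_tma_store) {
    if (issued_tma_store)
        store_wait();
    __syncthreads();
    if (threadIdx.x == 0)
        atomic_add_release_global(signal + slot, 1);
}
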
--- deep_gemm/include/deep_gemm/common/utils.cuh | 4 ++++ deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index fc84b696..e6837df9 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -144,6 +144,10 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) { asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); } +__device__ __forceinline__ void store_wait() { + asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory"); +} + template struct Vectorized { static auto zeros() { diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 468324c2..10397483 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -431,7 +431,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, if constexpr (kEnableOverlap) { if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { - cute::tma_store_wait<0>(); + store_wait(); } cutlass::arch::NamedBarrier(kNumMathThreads).sync(); From 07aa61cece0d40a7cf82a6f2460ff83da67c7668 Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 8 Sep 2025 16:06:01 +0800 Subject: [PATCH 67/71] more. --- deep_gemm/include/deep_gemm/common/utils.cuh | 6 ++++++ deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index e6837df9..3a9cdc06 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -148,6 +148,12 @@ __device__ __forceinline__ void store_wait() { asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory"); } +__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) { + int ret; + asm volatile ("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(value)); + return ret; +} + template struct Vectorized { static auto zeros() { diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 10397483..dffda6af 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -438,7 +438,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, __threadfence(); if (threadIdx.x == 0) { - atomicAdd(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); + atomic_add_release_global(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); } } } From 379a913986c602bc0e87e0c187ca75cad366f93b Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Mon, 8 Sep 2025 16:09:14 +0800 Subject: [PATCH 68/71] try to use atom.add.release.gpu.global.s32 instead of __threadfence with atomicAdd. 
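Rationale, as a note ahead of the hunk below: the release-scoped atom.add.release.gpu already orders the producer's prior global writes (the stored D tile, completed by the preceding cp.async.bulk.wait_group) before the increment at GPU scope, so a consumer on the same device that reads the counter with acquire semantics also sees the tile, and the separate __threadfence() becomes redundant. A hypothetical consumer-side counterpart (wait_signal is not part of this series):

// Spin with an acquire load until the producer's release-ordered increments
// reach `threshold`; the matching BLOCK_M rows of D are then visible too.
__device__ __forceinline__ void wait_signal(const int* signal, const int slot, const int threshold) {
    int observed;
    do {
        asm volatile("ld.acquire.gpu.global.s32 %0, [%1];"
                     : "=r"(observed)
                     : "l"(signal + slot)
                     : "memory");
    } while (observed < threshold);
}
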
--- deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index dffda6af..11f611fb 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -435,7 +435,6 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal, } cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - __threadfence(); if (threadIdx.x == 0) { atomic_add_release_global(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1); From 9f769b40daefe260d7218297b42d507ae5b0c83f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 9 Sep 2025 17:44:24 +0800 Subject: [PATCH 69/71] remove redundant change. --- csrc/jit/handle.hpp | 2 +- csrc/jit/kernel_runtime.hpp | 2 +- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index e56e04b9..e05cf92c 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -132,4 +132,4 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& #endif -} // namespace deep_gemm \ No newline at end of file +} // namespace deep_gemm diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp index bf4ae38e..ba66eeb8 100644 --- a/csrc/jit/kernel_runtime.hpp +++ b/csrc/jit/kernel_runtime.hpp @@ -115,4 +115,4 @@ class LaunchRuntime { } }; -} // namespace deep_gemm \ No newline at end of file +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 17f3edbe..fc0eff32 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -253,4 +253,5 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co std::optional(std::make_pair(config.block_m, config.signal_threshold)) : std::nullopt; } -} \ No newline at end of file + +} } // namespace deep_gemm \ No newline at end of file From eb1091f2a9f93527f95cb8d805bdb7875aa1bbad Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 9 Sep 2025 17:45:36 +0800 Subject: [PATCH 70/71] fix. --- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index fc0eff32..802d46be 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -254,4 +254,4 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co std::nullopt; } -} } // namespace deep_gemm \ No newline at end of file +} // namespace deep_gemm \ No newline at end of file From ede008bfb9d4a18e19abb784455b6628dd736c0f Mon Sep 17 00:00:00 2001 From: Eric Wong Date: Tue, 9 Sep 2025 17:46:20 +0800 Subject: [PATCH 71/71] fix. --- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 802d46be..49b4c9ad 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -254,4 +254,4 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co std::nullopt; } -} // namespace deep_gemm \ No newline at end of file +} // namespace deep_gemm