diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp
index 68ac7505..7db0dc89 100644
--- a/csrc/apis/gemm.hpp
+++ b/csrc/apis/gemm.hpp
@@ -175,14 +175,17 @@ static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
-static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
+static std::optional<std::pair<int, int>> m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                          const std::pair<torch::Tensor, torch::Tensor>& b,
                                          const torch::Tensor& d,
                                          const torch::Tensor& masked_m,
                                          const int& expected_m,
                                          std::optional<std::tuple<int, int, int>> recipe,
                                          const std::string& compiled_dims,
-                                         const bool& disable_ue8m0_cast) {
+                                         const bool& disable_ue8m0_cast,
+                                         const int& max_block_n,
+                                         const bool& enable_overlap,
+                                         const c10::optional<torch::Tensor>& signal) {
     // Shape must be `[G, M, K] @ [G, N, K].mT`
     const auto& major_a = get_major_type_ab(a.first);
     const auto& major_b = get_major_type_ab(b.first);
@@ -202,6 +205,12 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
     const auto& arch_major = device_runtime->get_arch_major();
+    std::optional<std::pair<int, int>> result = std::nullopt;
     if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
-        sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
-                                            num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
+        result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
+                                                     num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims,
+                                                     max_block_n, enable_overlap, signal);
     } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
         sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
                                              num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
@@ -225,6 +236,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
+    return result;
 }
diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp
index 455223bc..5427d138 100644
--- a/csrc/jit_kernels/heuristics/common.hpp
+++ b/csrc/jit_kernels/heuristics/common.hpp
@@ -63,6 +63,7 @@ struct GemmConfig {
     cute::UMMA::Major major_b;
     bool with_accumulation;
     int block_m, block_n, block_k;
+    int signal_threshold;
     int num_stages, num_last_stages;

     // Templated device configs
@@ -73,6 +74,8 @@ struct GemmConfig {
     MulticastConfig multicast_config;
     SharedMemoryConfig smem_config;
     ThreadConfig thread_config;
+
+    bool enable_overlap;
 };

 static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
@@ -151,7 +154,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
                                   const int& m, const int& n, const int& k, const int& num_groups,
                                   const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                   const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
-                                  const bool& with_accumulation, const int& num_sms) {
+                                  const bool& with_accumulation, const int& num_sms,
+                                  const int& max_block_n = 256, const bool& enable_overlap = false) {
     DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
     DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
@@ -161,7 +165,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         block_ms = std::vector<int>{get_mk_alignment_for_contiguous_layout()};
     if (gemm_type == GemmType::MGroupedMasked)   // Exclude 256 for performance
         block_ms = std::vector<int>{64, 128};
-    const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
+    const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype, max_block_n);

     // K block size is selected in a fixed manner
     const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
@@ -271,6 +275,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         .block_m = best_block_m, .block_n = best_block_n, .block_k = block_k,
+        .signal_threshold = ceil_div(n, best_block_n),
         .num_stages = best_num_stages, .num_last_stages = ceil_div(k, block_k) % best_num_stages,
         .num_sms = num_min_sms,
@@ -278,7 +283,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         .multicast_config = best_multicast_config,
         // ReSharper disable once CppLocalVariableMightNotBeInitialized
         .smem_config = best_smem_config,
-        .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
+        .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n),
+        .enable_overlap = enable_overlap
     };

     // Only SM100 BF16 kernels support tensor core control
diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp
index e62a13cc..d0d16980 100644
--- a/csrc/jit_kernels/heuristics/sm100.hpp
+++ b/csrc/jit_kernels/heuristics/sm100.hpp
@@ -12,7 +12,7 @@ namespace deep_gemm {
 struct SM100ArchSpec {
     static constexpr int smem_capacity = 232448;

-    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
+    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) {
         // 16 is for better SM usage
         // Stride 32 is due to low-performance swizzle-16/32B
         std::vector<int> candidates = {16};
diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp
index 133e2da0..d411206b 100644
--- a/csrc/jit_kernels/heuristics/sm90.hpp
+++ b/csrc/jit_kernels/heuristics/sm90.hpp
@@ -11,11 +11,11 @@ namespace deep_gemm {
 struct SM90ArchSpec {
     static constexpr int smem_capacity = 232448;

-    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
+    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) {
         // Avoid bank conflicts for FP32 output
         const auto& start = cd_dtype == torch::kFloat ? 8 : 16;
         std::vector<int> candidates;
-        for (int i = start; i <= 256; i += 16)
+        for (int i = start; i <= max_block_n; i += 16)
             candidates.push_back(i);
         return candidates;
     }
diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
index ced8d17a..f08bce8f 100644
--- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
+++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
@@ -22,7 +22,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime> {
         GemmConfig gemm_config;
         LaunchArgs launch_args;

-        void *sfb, *grouped_layout;
+        void *sfb, *grouped_layout, *signal;
         CUtensorMap tensor_map_a;
         CUtensorMap tensor_map_b;
         CUtensorMap tensor_map_d;
@@ -44,7 +44,8 @@ static void __instantiate_kernel() {{
         {}, {}, {},
         {}, {}, {},
-        {}, {}, {}
+        {}, {}, {},
+        {}
     >);
 }};
 )",
@@ -57,13 +58,14 @@ static void __instantiate_kernel() {{
            args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
            args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
            args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
-           get_default_epilogue_type(args.epilogue_type));
+           get_default_epilogue_type(args.epilogue_type),
+           args.gemm_config.enable_overlap);
     }

     static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
         // TODO: optimize `args` copy
         DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
-            args.sfb, args.grouped_layout,
+            args.sfb, args.grouped_layout, args.signal,
             args.m, args.n, args.k,
             args.tensor_map_a, args.tensor_map_b,
             args.tensor_map_d, args.tensor_map_sfa));
@@ -121,6 +123,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                      config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = nullptr,
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -181,6 +184,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
                                      config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = m_indices.data_ptr(),
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -191,14 +195,17 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
     MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
 }

-static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
+static std::optional<std::pair<int, int>> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d, const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
-                                                const std::string& compiled_dims) {
+                                                const std::string& compiled_dims,
+                                                const int& max_block_n,
+                                                const bool& enable_overlap,
+                                                const c10::optional<torch::Tensor>& signal) {
     const auto& aligned_k = align(k, 128);
     DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
     DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -207,7 +214,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
         GemmType::MGroupedMasked, KernelType::Kernel1D2D,
         expected_m, n, k, num_groups, major_a, major_b,
         torch::kFloat8_e4m3fn, d.scalar_type(), false,
-        device_runtime->get_num_sms());
+        device_runtime->get_num_sms(), max_block_n, enable_overlap);

     // Requires no TMA splits
     DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
@@ -242,6 +249,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
                                      config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = masked_m.data_ptr(),
+        .signal = enable_overlap ? signal.value().data_ptr() : nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -250,6 +258,9 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
     const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
     MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
+    return enable_overlap ?
+        std::optional(std::make_pair(config.block_m, config.signal_threshold)) :
+        std::nullopt;
 }

 } // namespace deep_gemm
diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp
index 2b2b5187..14f3b15d 100644
--- a/csrc/python_api.cpp
+++ b/csrc/python_api.cpp
@@ -137,8 +137,16 @@ void m_grouped_fp8_gemm_nn_contiguous_wrapper(const torch::Tensor& a_val, const
     deep_gemm::gemm::m_grouped_fp8_gemm_nn_contiguous({a_val, a_scale}, {b_val, b_scale}, d, m_indices, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast);
 }

-void m_grouped_fp8_gemm_nt_masked_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, const torch::Tensor& masked_m, int64_t expected_m, const c10::optional<std::vector<int64_t>>& recipe, const std::string& compiled_dims, bool disable_ue8m0_cast) {
-    deep_gemm::gemm::m_grouped_fp8_gemm_nt_masked({a_val, a_scale}, {b_val, b_scale}, d, masked_m, expected_m, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast);
+std::tuple<c10::optional<int64_t>, c10::optional<int64_t>> m_grouped_fp8_gemm_nt_masked_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, const torch::Tensor& masked_m, int64_t expected_m, const c10::optional<std::vector<int64_t>>& recipe, const std::string& compiled_dims, bool disable_ue8m0_cast, int64_t max_block_n, bool enable_overlap, const c10::optional<torch::Tensor>& signal) {
+    auto result = deep_gemm::gemm::m_grouped_fp8_gemm_nt_masked({a_val, a_scale}, {b_val, b_scale}, d, masked_m, expected_m, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast, max_block_n, enable_overlap, signal);
+
+    if (!result) {
+        return std::make_tuple(c10::nullopt, c10::nullopt);
+    }
+    return std::make_tuple(
+        c10::optional<int64_t>(result->first),
+        c10::optional<int64_t>(result->second)
+    );
 }

 void k_grouped_fp8_gemm_nt_contiguous_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, c10::List<int64_t> ks, const torch::Tensor& ks_tensor, const c10::optional<torch::Tensor>& c, c10::IntArrayRef recipe, const std::string& compiled_dims) {
@@ -342,17 +350,20 @@ TORCH_LIBRARY(deep_gemm, m) {
             deep_gemm_wrappers::m_grouped_fp8_gemm_nn_contiguous_wrapper(a_val, a_scale, b_val, b_scale, d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
         });

-    m.def(R"(m_grouped_fp8_gemm_nt_masked(Any a, Any b, Tensor d, Tensor masked_m, int expected_m, int[]? recipe=None, str compiled_dims="nk", bool disable_ue8m0_cast=False) -> ())");
+    m.def(R"(m_grouped_fp8_gemm_nt_masked(Any a, Any b, Tensor d, Tensor masked_m, int expected_m, int[]? recipe=None, str compiled_dims="nk", bool disable_ue8m0_cast=False, int max_block_n=256, bool enable_overlap=False, Tensor? signal=None) -> (int?, int?))");
     m.impl("m_grouped_fp8_gemm_nt_masked", torch::kCUDA,
            [](const c10::IValue& a_input, const c10::IValue& b_input,
               const torch::Tensor& d, const torch::Tensor& masked_m,
               int64_t expected_m, const c10::optional<std::vector<int64_t>>& recipe,
               const std::string& compiled_dims,
-              bool disable_ue8m0_cast) {
+              bool disable_ue8m0_cast,
+              int64_t max_block_n,
+              bool enable_overlap,
+              const c10::optional<torch::Tensor>& signal) {
             auto [a_val, a_scale] = parse_tensor_or_tuple(a_input);
             auto [b_val, b_scale] = parse_tensor_or_tuple(b_input);
-            deep_gemm_wrappers::m_grouped_fp8_gemm_nt_masked_wrapper(a_val, a_scale, b_val, b_scale, d, masked_m, expected_m, recipe, compiled_dims, disable_ue8m0_cast);
+            return deep_gemm_wrappers::m_grouped_fp8_gemm_nt_masked_wrapper(a_val, a_scale, b_val, b_scale, d, masked_m, expected_m, recipe, compiled_dims, disable_ue8m0_cast, max_block_n, enable_overlap, signal);
         });

     m.def(R"(k_grouped_fp8_gemm_nt_contiguous(Any a, Any b, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c=None, int[] recipe=[1, 1, 128], str compiled_dims="mn") -> ())");
diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh
index 0b7ff116..d590e614 100644
--- a/deep_gemm/include/deep_gemm/common/utils.cuh
+++ b/deep_gemm/include/deep_gemm/common/utils.cuh
@@ -158,6 +158,16 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) {
     asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
 }

+__device__ __forceinline__ void store_wait() {
+    asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory");
+}
+
+__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) {
+    int ret;
+    asm volatile ("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(value));
+    return ret;
+}
+
 template <typename dtype_t>
 struct Vectorized {
     static auto zeros() {
diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
index 5a92d7d4..ea4b5057 100644
--- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
+++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
@@ -38,9 +38,9 @@ template
-          typename epilogue_type_t>
+          typename epilogue_type_t, bool kEnableOverlap>
 __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
-sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
+sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal,
                         uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -395,6 +395,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
             cute::tma_store_arrive();
         }
         __syncwarp();
+
+        if constexpr (kEnableOverlap) {
+            if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
+                store_wait();
+            }
+
+            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+
+            if (threadIdx.x == 0) {
+                atomic_add_release_global(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1);
+            }
+        }
     }
 }
 #else
diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py
index d06a03b9..37a88d43 100644
--- a/deep_gemm/testing/numeric.py
+++ b/deep_gemm/testing/numeric.py
@@ -17,3 +17,18 @@ def count_bytes(*tensors):
         elif t is not None:
             total += t.numel() * t.element_size()
     return total
+
+def check_signal(num_local_expert, max_m, block_m, threshold, signal, masked_m):
+    ceil_div = lambda a, b: (a + b - 1) // b
+
+    expert_len = max_m // block_m
+    for expert in range(num_local_expert):
+        mask = masked_m[expert]
+        start = expert * expert_len
+        end = expert * expert_len + expert_len
+        valid_len = ceil_div(mask, block_m)
+        for i in range(start, end):
+            if i < start + valid_len:
+                assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}'
+            else:
+                assert signal[i] == 0, f'{i=}, {signal[i]=}'
diff --git a/setup.py b/setup.py
index e5b96657..8d9d29f8 100644
--- a/setup.py
+++ b/setup.py
@@ -93,7 +93,7 @@ def prepare_includes(self):
         ]
     },
     ext_modules=[
-        CUDAExtension(name='deep_gemm_cpp',
+        CUDAExtension(name='deep_gemm.deep_gemm_cpp',
                       sources=sources,
                       include_dirs=build_include_dirs,
                       libraries=build_libraries,
diff --git a/tests/generators.py b/tests/generators.py
index d856217e..0d06505a 100644
--- a/tests/generators.py
+++ b/tests/generators.py
@@ -113,9 +113,10 @@ def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
 def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
     max_m = 4096
     for kernel_type in get_kernel_types(dtype):
-        for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
-            for n, k in ((4096, 7168), (7168, 2048), ):
-                yield kernel_type, num_groups, max_m, m, n, k
+        for enable_overlap in (False, True):
+            for num_groups, m in ((1, 1024), (2, 512), (4, 256), (16, 64), (16, 32)):
+                for n, k in ((4096, 7168), (7168, 2048), ):
+                    yield kernel_type, enable_overlap, num_groups, max_m, m, n, k


 def enumerate_k_grouped_contiguous():
@@ -218,7 +219,7 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
 def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
-                              use_ue8m0: bool = False, use_bf16: bool = False):
+                              use_ue8m0: bool = False, use_bf16: bool = False, enable_overlap: bool = False):
     a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
     b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
     d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
@@ -238,7 +239,10 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
         a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0)
         b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0)

-    return a_fp8, b_fp8, masked_m, d, ref_d
+    max_signal_size = num_groups * ceil_div(max_m, 64)
+    signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') if enable_overlap else None
+
+    return a_fp8, b_fp8, masked_m, d, ref_d, signal


 def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], use_ue8m0: bool):
diff --git a/tests/test_fp8.py b/tests/test_fp8.py
index 7415e07c..1a8e424d 100644
--- a/tests/test_fp8.py
+++ b/tests/test_fp8.py
@@ -6,7 +6,8 @@
 import deep_gemm
 from deep_gemm.testing import (
     bench, bench_kineto,
-    calc_diff, count_bytes
+    calc_diff, count_bytes,
+    check_signal,
 )

 from generators import (
@@ -90,30 +91,37 @@ def test_m_grouped_gemm_masked() -> None:
     print('Testing m-grouped masked GEMM:')

     # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
-    for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn):
+    for kernel_type, enable_overlap, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn):
         kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
         use_ue8m0 = get_ue8m0_usage(kernel_type)
         disable_ue8m0_cast = not use_ue8m0

         # Test correctness
         for i in range(10):
-            a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
-            deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
+            a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap)
+            result = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal)
+
+            if enable_overlap:
+                block_m, threshold = result
+                check_signal(num_groups, max_m, block_m, threshold, signal, masked_m)
+
             for j in range(num_groups):
                 diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
                 assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'

         # Construct full cases
-        a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
+        a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap)
+
         # noinspection PyShadowingNames
         def test_func():
-            deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
+            deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal)
+
         # Test performance with fixed shapes
         valid_m = masked_m.sum().item()
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): '
+        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}, enable_overlap={enable_overlap}): '
               f'{t * 1e6:4.0f} us | '
               f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
               f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
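
Usage note (not part of the patch): the overlap path is exercised end-to-end by the updated tests/test_fp8.py above. The sketch below distills the call pattern from that test; it is a minimal illustration, not authoritative API documentation. It assumes an SM90 GPU (the only backend this patch wires to enable_overlap), that it is run from the tests/ directory so generators.py is importable, and purely illustrative shapes.

import torch

import deep_gemm
from deep_gemm.testing import check_signal
from generators import generate_m_grouped_masked

# Illustrative shapes; any case accepted by the masked grouped GEMM works here.
num_groups, max_m, expected_m_per_group, n, k = 4, 4096, 256, 4096, 7168

# With enable_overlap=True the test generator also allocates `signal`, a
# zero-initialized int32 buffer with one counter per (group, M-block).
a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(
    num_groups, max_m, expected_m_per_group, n, k, enable_overlap=True)

# When overlap is enabled the op returns (block_m, signal_threshold). Each
# thread block increments signal[group * ceil(max_m / block_m) + m_block] by
# one after its output tile is stored, so a counter reaching signal_threshold
# (= ceil(n / block_n)) means that M-block of the group is fully written.
block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(
    a, b, d, masked_m, expected_m_per_group,
    disable_ue8m0_cast=True, enable_overlap=True, signal=signal)

# A downstream consumer (e.g. a communication kernel) could poll `signal` on
# device to start reading finished M-blocks early; here we simply validate the
# counters after the GEMM has completed.
torch.cuda.synchronize()
check_signal(num_groups, max_m, block_m, threshold, signal, masked_m)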