72 commits
9ac3853
feat: add signal gemm api for SBO (Single Batch Overlap).
Sulfur6 Aug 28, 2025
79b5a45
feat: add launch config args for coorperative groups.
Sulfur6 Aug 28, 2025
6cf6ca1
feat: add signal gemm impl&runtime in jit kernels.
Sulfur6 Aug 28, 2025
1e2c135
feat: add signal gemm kernel.
Sulfur6 Aug 28, 2025
a5d23b0
feat: add signal gemm import in deep_gemm package.
Sulfur6 Aug 28, 2025
48c741a
feat: add test for signal gemm.
Sulfur6 Aug 28, 2025
9a2d3e5
feat: add signal threshold as a config arg for signal gemm.
Sulfur6 Aug 28, 2025
03f61d2
add comparation test.
Sulfur6 Aug 30, 2025
e86f1f9
avoid using bench kineto
Sulfur6 Aug 30, 2025
e7ebff1
test: modify generators to fit bs32 down gemm.
Sulfur6 Aug 31, 2025
8037369
more
Sulfur6 Aug 31, 2025
55c7943
more.
Sulfur6 Aug 31, 2025
99cc43b
more.
Sulfur6 Aug 31, 2025
4bfed47
fix
Sulfur6 Aug 31, 2025
4a14a40
more.
Sulfur6 Sep 1, 2025
701751c
more.
Sulfur6 Sep 1, 2025
8035121
more.
Sulfur6 Sep 1, 2025
02edd31
more.
Sulfur6 Sep 1, 2025
0fcd03c
feat: add param max_block_n
Sulfur6 Sep 1, 2025
2ec7414
test
Sulfur6 Sep 1, 2025
a63152f
more.
Sulfur6 Sep 1, 2025
92b29f6
more.
Sulfur6 Sep 1, 2025
be21bf6
more.
Sulfur6 Sep 1, 2025
ffa1140
more.
Sulfur6 Sep 2, 2025
4802c62
fix: ensure memory order and send location.
Sulfur6 Sep 2, 2025
6271138
more.
Sulfur6 Sep 2, 2025
57bb435
exp
Sulfur6 Sep 2, 2025
d6e63a3
exp
Sulfur6 Sep 2, 2025
a21c96f
rollback generator.
Sulfur6 Sep 2, 2025
51740cf
remove threadfence.
Sulfur6 Sep 2, 2025
656fe10
add threadfence.
Sulfur6 Sep 2, 2025
7e385a9
complete test.
Sulfur6 Sep 2, 2025
8a138ad
fix: turn comment from chinese to english.
Sulfur6 Sep 3, 2025
2dfdce5
refactor: merge signal gemm related api & sm90 impl into existing api…
Sulfur6 Sep 4, 2025
21da45d
refactor: add params in common.
Sulfur6 Sep 4, 2025
cd062ad
refactor: merge signal gemm kernel to fp8 gemm kernel.
Sulfur6 Sep 4, 2025
387f068
update tests.
Sulfur6 Sep 4, 2025
aa81d53
fix
Sulfur6 Sep 4, 2025
e252aaa
fix.
Sulfur6 Sep 4, 2025
1d0a364
fix.
Sulfur6 Sep 4, 2025
8c2e6b8
fix.
Sulfur6 Sep 4, 2025
3517989
fix.
Sulfur6 Sep 4, 2025
a086212
fix.
Sulfur6 Sep 4, 2025
e6c6977
fix.
Sulfur6 Sep 4, 2025
7af0f7a
debug
Sulfur6 Sep 4, 2025
2c9fa44
more
Sulfur6 Sep 4, 2025
96cdc90
fix.
Sulfur6 Sep 4, 2025
11eeed6
more.
Sulfur6 Sep 4, 2025
debb196
fix.
Sulfur6 Sep 4, 2025
7ee0480
fix
Sulfur6 Sep 4, 2025
0c0eac8
more.
Sulfur6 Sep 4, 2025
af77de9
remove comments.
Sulfur6 Sep 4, 2025
23be492
remove code dup.
Sulfur6 Sep 4, 2025
84526c5
remove signal gemm api.
Sulfur6 Sep 4, 2025
f858cad
fix.
Sulfur6 Sep 4, 2025
ee1a058
feat: use NamedBarrier instead of coorperative groups.
Sulfur6 Sep 5, 2025
ec8b7c1
refactor: unify test for overlap with test for masked gemm.
Sulfur6 Sep 6, 2025
ad8874a
ref
Sulfur6 Sep 6, 2025
59dc1aa
fix.
Sulfur6 Sep 6, 2025
5edbbc5
fix.
Sulfur6 Sep 6, 2025
bf0d62a
fix.
Sulfur6 Sep 6, 2025
b78061b
fix.
Sulfur6 Sep 6, 2025
848952b
fix.
Sulfur6 Sep 6, 2025
bbc09d6
add test for EP16 situation.
Sulfur6 Sep 6, 2025
4011a8a
remove code dup.
Sulfur6 Sep 6, 2025
b9f37f2
more.
Sulfur6 Sep 8, 2025
07aa61c
more.
Sulfur6 Sep 8, 2025
379a913
try to use atom.add.release.gpu.global.s32 instead of __threadfence w…
Sulfur6 Sep 8, 2025
9f769b4
remove redundant change.
Sulfur6 Sep 9, 2025
eb1091f
fix.
Sulfur6 Sep 9, 2025
ede008b
fix.
Sulfur6 Sep 9, 2025
d232a36
Merge remote-tracking branch 'origin/main' into sbo.v2.public
Sulfur6 Sep 16, 2025
24 changes: 19 additions & 5 deletions csrc/apis/gemm.hpp
@@ -169,14 +169,17 @@ static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torc
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
}

static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
static std::optional<std::pair<int, int>> m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& expected_m,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const int& max_block_n,
const bool& enable_overlap,
const std::optional<torch::Tensor>& signal) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
@@ -196,6 +199,12 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

if (enable_overlap) {
DG_HOST_ASSERT(signal.has_value());
DG_HOST_ASSERT(signal.value().is_contiguous());
DG_HOST_ASSERT(signal.value().scalar_type() == torch::kInt32);
}

// D must be N-major
check_major_type_cd(d);

@@ -207,9 +216,11 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T

// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
std::optional<std::pair<int, int>> result = std::nullopt;
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims,
max_block_n, enable_overlap, signal);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
Expand All @@ -219,6 +230,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
return result;
}

static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
@@ -436,7 +448,9 @@ static void register_apis(pybind11::module_& m) {
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false,
py::arg("max_block_n") = 256, py::arg("enable_overlap") = false,
py::arg("signal") = std::nullopt);
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
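For orientation, a minimal call-site sketch of the extended binding in Python. It assumes the function is re-exported by the `deep_gemm` package under the same name and that the inputs are prepared by the `generate_m_grouped_masked` helper shown later in this diff; the exact import paths are assumptions, not part of this change. The new keyword arguments default to the previous behaviour, so existing callers remain valid.

# Sketch only: module/import paths are assumed.
import deep_gemm
from generators import generate_m_grouped_masked  # tests/generators.py

num_groups, max_m, expected_m, n, k = 4, 4096, 256, 4096, 7168
a_fp8, b_fp8, masked_m, d, ref_d, signal = generate_m_grouped_masked(
    num_groups, max_m, expected_m, n, k, enable_overlap=True)

# With enable_overlap=True the call returns (block_m, signal_threshold), telling the
# caller how to interpret the per-M-block counters in `signal`; otherwise it returns None.
block_m, signal_threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(
    a_fp8, b_fp8, d, masked_m, expected_m,
    enable_overlap=True, signal=signal)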
12 changes: 9 additions & 3 deletions csrc/jit_kernels/heuristics/common.hpp
@@ -61,6 +61,7 @@ struct GemmConfig {
cute::UMMA::Major major_b;
bool with_accumulation;
int block_m, block_n, block_k;
int signal_threshold;
int num_stages, num_last_stages;

// Templated device configs
@@ -71,6 +72,8 @@ struct GemmConfig {
MulticastConfig multicast_config;
SharedMemoryConfig smem_config;
ThreadConfig thread_config;

bool enable_overlap;
};

static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
@@ -146,7 +149,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
const int& m, const int& n, const int& k, const int& num_groups,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
const bool& with_accumulation, const int& num_sms,
const int& max_block_n = 256, const bool& enable_overlap = false) {
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);

@@ -158,7 +162,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
block_ms = std::vector{64, 128};
std::vector<int> block_ns;
for (int i = 16; i <= 256; i += 16)
for (int i = 16; i <= max_block_n; i += 16)
block_ns.push_back(i);

// K block size is selected in a fixed manner
@@ -269,14 +273,16 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
.block_m = best_block_m,
.block_n = best_block_n,
.block_k = block_k,
.signal_threshold = ceil_div(n, best_block_n),
.num_stages = best_num_stages,
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
.num_sms = num_min_sms,
.tc_util = device_runtime->get_tc_util(),
.multicast_config = best_multicast_config,
// ReSharper disable once CppLocalVariableMightNotBeInitialized
.smem_config = best_smem_config,
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n),
.enable_overlap = enable_overlap
};

// Only SM100 BF16 kernels support tensor core control
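To make the two new `GemmConfig` fields concrete: `signal_threshold` is the number of N-tiles each M-block is split into, so a per-M-block counter that reaches it means every column tile of that row block has been written. A short worked example in Python (the particular `block_n` picked by the heuristic is an assumption; it is always a multiple of 16 no larger than `max_block_n`):

ceil_div = lambda a, b: (a + b - 1) // b

n, block_n = 4096, 256          # block_n assumed at the default max_block_n upper bound
signal_threshold = ceil_div(n, block_n)
assert signal_threshold == 16   # 16 tile completions before the row block is ready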
26 changes: 17 additions & 9 deletions csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
@@ -3,7 +3,6 @@
#include <torch/python.h>

#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
@@ -21,7 +20,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime>
GemmConfig gemm_config;
LaunchArgs launch_args;

void *sfb, *grouped_layout;
void *sfb, *grouped_layout, *signal;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
@@ -43,7 +42,7 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{}, {},
{}, {}
{}, {}, {}
>);
}};
)",
@@ -55,13 +54,13 @@ static void __instantiate_kernel() {{
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), args.gemm_config.enable_overlap);
}

static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.sfb, args.grouped_layout,
args.sfb, args.grouped_layout, args.signal,
args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
@@ -117,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = nullptr,
.signal = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
@@ -140,7 +140,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
m, n, k, 1, major_a, major_b,
m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());

@@ -176,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = m_indices.data_ptr(),
.signal = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
@@ -186,14 +187,17 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}

static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
static std::optional<std::pair<int, int>> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const int& max_block_n,
const bool& enable_overlap,
const std::optional<torch::Tensor>& signal) {
const auto& aligned_k = align(k, 128);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -202,7 +206,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
expected_m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
device_runtime->get_num_sms(), max_block_n, enable_overlap);

// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
@@ -236,6 +240,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = masked_m.data_ptr(),
.signal = enable_overlap ? signal.value().data_ptr() : nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
@@ -244,6 +249,9 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
return enable_overlap ?
std::optional(std::make_pair(config.block_m, config.signal_threshold)) :
std::nullopt;
}

} // namespace deep_gemm
10 changes: 10 additions & 0 deletions deep_gemm/include/deep_gemm/common/utils.cuh
@@ -144,6 +144,16 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) {
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
}

__device__ __forceinline__ void store_wait() {
asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory");
}

__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) {
int ret;
asm volatile ("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(value));
return ret;
}

template <uint32_t kNumBytes>
struct Vectorized {
static auto zeros() {
16 changes: 14 additions & 2 deletions deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
@@ -36,9 +36,9 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType>
uint32_t kNumSMs, GemmType kGemmType, bool kEnableOverlap>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -428,6 +428,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
cute::tma_store_arrive();
}
__syncwarp();

if constexpr (kEnableOverlap) {
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
store_wait();
}

cutlass::arch::NamedBarrier(kNumMathThreads).sync();

if (threadIdx.x == 0) {
atomic_add_release_global(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1);
}
}
}
}
#else
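The kernel-side contract added above: once the TMA stores of an output tile have drained (`store_wait`), the math threads synchronize on a named barrier and thread 0 release-increments `signal[group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx]` by one, so that counter reaches `ceil_div(shape_n, BLOCK_N)` exactly when the whole row block of D is in global memory. The overlap consumer itself is not part of this PR; the following host-side Python sketch only illustrates the intended polling contract (the consumer protocol is an assumption):

import torch

def row_block_ready(signal: torch.Tensor, group_idx: int, m_block_idx: int,
                    max_m: int, block_m: int, threshold: int) -> bool:
    # Mirrors the kernel's indexing: one int32 counter per (group, M-block).
    num_m_blocks = (max_m + block_m - 1) // block_m
    return int(signal[group_idx * num_m_blocks + m_block_idx].item()) >= threshold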
14 changes: 14 additions & 0 deletions deep_gemm/testing/numeric.py
@@ -1,6 +1,20 @@
import torch
from typing import Iterable

def check_signal(num_local_expert, max_m, block_m, threshold, signal, masked_m):
ceil_div = lambda a, b: (a + b - 1) // b

expert_len = max_m // block_m
for expert in range(num_local_expert):
mask = masked_m[expert]
start = expert * expert_len
end = expert * expert_len + expert_len
valid_len = ceil_div(mask, block_m)
for i in range(start, end):
if i < start + valid_len:
assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}'
else:
assert signal[i] == 0, f'{i=}, {signal[i]=}'

def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
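A brief usage sketch for the new `check_signal` helper, assuming the masked GEMM was launched with `enable_overlap=True` and returned `(block_m, threshold)` as in the API change above, and that `signal` and `masked_m` are the tensors passed to that launch:

# Counters for M-blocks covered by masked_m[expert] must have reached `threshold`;
# every remaining counter must still be zero.
check_signal(num_local_expert=num_groups, max_m=max_m, block_m=block_m,
             threshold=threshold, signal=signal, masked_m=masked_m)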
15 changes: 10 additions & 5 deletions tests/generators.py
@@ -88,9 +88,10 @@ def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
def enumerate_m_grouped_masked() -> Generator:
max_m = 4096
for kernel_type in get_kernel_types():
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
for n, k in ((4096, 7168), (7168, 2048), ):
yield kernel_type, num_groups, max_m, m, n, k
for enable_overlap in (False, True):
for num_groups, m in ((1, 1024), (2, 512), (4, 256), (16, 64), (16, 32)):
for n, k in ((4096, 7168), (7168, 2048), ):
yield kernel_type, enable_overlap, num_groups, max_m, m, n, k


def enumerate_k_grouped_contiguous():
@@ -191,7 +192,7 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:


def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
use_ue8m0: bool = False, use_bf16: bool = False):
use_ue8m0: bool = False, use_bf16: bool = False, enable_overlap: bool = False):
a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
@@ -211,7 +212,10 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0)
b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0)

return a_fp8, b_fp8, masked_m, d, ref_d
max_signal_size = num_groups * ceil_div(max_m, 64)
signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') if enable_overlap else None

return a_fp8, b_fp8, masked_m, d, ref_d, signal


def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool):
Expand All @@ -233,3 +237,4 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int]
a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
return k, a_fp8, b_fp8, c, d, ref_d

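Two details of the generator change are worth spelling out. First, `enumerate_m_grouped_masked` now yields an extra `enable_overlap` flag plus the EP16-style `(16, 64)` and `(16, 32)` group/M pairs, and `generate_m_grouped_masked` returns a sixth value, the zero-initialized `signal` buffer (`None` when overlap is off). Second, the buffer is sized with the smallest candidate `block_m` (64), since the heuristic's actual choice is only known after the call returns; with a larger `block_m` the kernel touches only a prefix of the buffer and the tail stays zero. A small Python sketch of the sizing:

ceil_div = lambda a, b: (a + b - 1) // b

num_groups, max_m = 4, 4096
allocated = num_groups * ceil_div(max_m, 64)     # 256 counters, worst case (block_m = 64)
used_if_128 = num_groups * ceil_div(max_m, 128)  # 128 counters actually touched if block_m = 128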