20 changes: 16 additions & 4 deletions csrc/apis/gemm.hpp
@@ -175,14 +175,17 @@ static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torc
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
}

static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
static std::optional<std::pair<int, int>> m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& expected_m,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const bool& disable_ue8m0_cast,
const int& max_block_n,
const bool& enable_overlap,
const c10::optional<torch::Tensor>& signal) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
@@ -202,6 +205,12 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

if (enable_overlap) {
DG_HOST_ASSERT(signal.has_value());
DG_HOST_ASSERT(signal.value().is_contiguous());
DG_HOST_ASSERT(signal.value().scalar_type() == torch::kInt32);
}

// D must be N-major
check_major_type_cd(d);

@@ -213,9 +222,11 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T

// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
std::optional<std::pair<int, int>> result = std::nullopt;
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims,
max_block_n, enable_overlap, signal);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
@@ -225,6 +236,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
return result;
}

static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
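
With enable_overlap, the API contract changes in two ways: the caller must pass a contiguous int32 signal buffer (checked by the new asserts), and the function returns the chosen (block_m, signal_threshold) pair instead of nothing. Below is a minimal sketch of a buffer that satisfies those checks; the sizing convention (one counter per (group, m-block) slot) is inferred from the check_signal helper further down, and the concrete sizes plus the block_m = 64 floor are illustrative assumptions, not part of this PR.

    import torch

    num_groups, max_m = 8, 1024      # illustrative sizes
    min_block_m = 64                 # masked grouped GEMM only considers block_m in {64, 128}
    num_slots = num_groups * ((max_m + min_block_m - 1) // min_block_m)
    # Contiguous, int32 and zero-filled so the kernel's atomic increments count up from 0.
    signal = torch.zeros(num_slots, dtype=torch.int32, device='cuda')
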
12 changes: 9 additions & 3 deletions csrc/jit_kernels/heuristics/common.hpp
@@ -63,6 +63,7 @@ struct GemmConfig {
cute::UMMA::Major major_b;
bool with_accumulation;
int block_m, block_n, block_k;
int signal_threshold;
int num_stages, num_last_stages;

// Templated device configs
@@ -73,6 +74,8 @@ struct GemmConfig {
MulticastConfig multicast_config;
SharedMemoryConfig smem_config;
ThreadConfig thread_config;

bool enable_overlap;
};

static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
Expand Down Expand Up @@ -151,7 +154,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
const int& m, const int& n, const int& k, const int& num_groups,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
const bool& with_accumulation, const int& num_sms,
const int& max_block_n = 256, const bool& enable_overlap = false) {
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);

@@ -161,7 +165,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
block_ms = std::vector{64, 128};
const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype, max_block_n);

// K block size is selected in a fixed manner
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
@@ -271,14 +275,16 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
.block_m = best_block_m,
.block_n = best_block_n,
.block_k = block_k,
.signal_threshold = ceil_div(n, best_block_n),
.num_stages = best_num_stages,
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
.num_sms = num_min_sms,
.tc_util = device_runtime->get_tc_util(),
.multicast_config = best_multicast_config,
// ReSharper disable once CppLocalVariableMightNotBeInitialized
.smem_config = best_smem_config,
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n),
.enable_overlap = enable_overlap
};

// Only SM100 BF16 kernels support tensor core control
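
Here signal_threshold = ceil_div(n, best_block_n) is simply the number of BLOCK_N tiles in one m-block row of the output: a per-(group, m-block) counter that is incremented once per stored tile reaches exactly this value when the row is fully written. A small illustration of the arithmetic (values made up):

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    n, block_n = 4096, 128
    signal_threshold = ceil_div(n, block_n)   # 32 tiles per m-block row
    assert signal_threshold == 32
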
2 changes: 1 addition & 1 deletion csrc/jit_kernels/heuristics/sm100.hpp
@@ -12,7 +12,7 @@ namespace deep_gemm {
struct SM100ArchSpec {
static constexpr int smem_capacity = 232448;

static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) {
// 16 is for better SM usage
// Stride 32 is due to low-performance swizzle-16/32B
std::vector<int> candidates = {16};
4 changes: 2 additions & 2 deletions csrc/jit_kernels/heuristics/sm90.hpp
@@ -11,11 +11,11 @@ namespace deep_gemm {
struct SM90ArchSpec {
static constexpr int smem_capacity = 232448;

static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) {
// Avoid bank conflicts for FP32 output
const auto& start = cd_dtype == torch::kFloat ? 8 : 16;
std::vector<int> candidates;
for (int i = start; i <= 256; i += 16)
for (int i = start; i <= max_block_n; i += 16)
candidates.push_back(i);
return candidates;
}
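
For reference, a Python mirror of this candidate generation, showing how the new max_block_n argument caps the BLOCK_N search space (illustration only, not part of the library):

    def block_n_candidates(cd_is_fp32: bool, max_block_n: int = 256) -> list[int]:
        start = 8 if cd_is_fp32 else 16   # FP32 output starts at 8 to avoid bank conflicts
        return list(range(start, max_block_n + 1, 16))

    assert block_n_candidates(False, 128) == [16, 32, 48, 64, 80, 96, 112, 128]
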
25 changes: 18 additions & 7 deletions csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
@@ -22,7 +22,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime>
GemmConfig gemm_config;
LaunchArgs launch_args;

void *sfb, *grouped_layout;
void *sfb, *grouped_layout, *signal;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
@@ -44,7 +44,8 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{}, {},
{}, {}, {}
{}, {}, {},
{}
>);
}};
)",
@@ -57,13 +58,14 @@ static void __instantiate_kernel() {{
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
get_default_epilogue_type(args.epilogue_type));
get_default_epilogue_type(args.epilogue_type),
args.gemm_config.enable_overlap);
}

static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.sfb, args.grouped_layout,
args.sfb, args.grouped_layout, args.signal,
args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
@@ -121,6 +123,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = nullptr,
.signal = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
@@ -181,6 +184,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = m_indices.data_ptr(),
.signal = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
@@ -191,14 +195,17 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
}

static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
static std::optional<std::pair<int, int>> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const int& max_block_n,
const bool& enable_overlap,
const c10::optional<torch::Tensor>& signal) {
const auto& aligned_k = align(k, 128);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -207,7 +214,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
expected_m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
device_runtime->get_num_sms(), max_block_n, enable_overlap);

// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
@@ -242,6 +249,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = masked_m.data_ptr(),
.signal = enable_overlap ? signal.value().data_ptr() : nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
@@ -250,6 +258,9 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
return enable_overlap ?
std::optional(std::make_pair(config.block_m, config.signal_threshold)) :
std::nullopt;
}

} // namespace deep_gemm
21 changes: 16 additions & 5 deletions csrc/python_api.cpp
@@ -137,8 +137,16 @@ void m_grouped_fp8_gemm_nn_contiguous_wrapper(const torch::Tensor& a_val, const
deep_gemm::gemm::m_grouped_fp8_gemm_nn_contiguous({a_val, a_scale}, {b_val, b_scale}, d, m_indices, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast);
}

void m_grouped_fp8_gemm_nt_masked_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, const torch::Tensor& masked_m, int64_t expected_m, const c10::optional<c10::IntArrayRef>& recipe, const std::string& compiled_dims, bool disable_ue8m0_cast) {
deep_gemm::gemm::m_grouped_fp8_gemm_nt_masked({a_val, a_scale}, {b_val, b_scale}, d, masked_m, expected_m, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast);
std::tuple<c10::optional<int64_t>, c10::optional<int64_t>> m_grouped_fp8_gemm_nt_masked_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, const torch::Tensor& masked_m, int64_t expected_m, const c10::optional<c10::IntArrayRef>& recipe, const std::string& compiled_dims, bool disable_ue8m0_cast, int64_t max_block_n, bool enable_overlap, const c10::optional<torch::Tensor>& signal) {
auto result = deep_gemm::gemm::m_grouped_fp8_gemm_nt_masked({a_val, a_scale}, {b_val, b_scale}, d, masked_m, expected_m, to_recipe_tuple(recipe), compiled_dims, disable_ue8m0_cast, max_block_n, enable_overlap, signal);

if (!result) {
return std::make_tuple(c10::nullopt, c10::nullopt);
}
return std::make_tuple(
c10::optional<int64_t>(result->first),
c10::optional<int64_t>(result->second)
);
}

void k_grouped_fp8_gemm_nt_contiguous_wrapper(const torch::Tensor& a_val, const torch::Tensor& a_scale, const torch::Tensor& b_val, const torch::Tensor& b_scale, const torch::Tensor& d, c10::List<int64_t> ks, const torch::Tensor& ks_tensor, const c10::optional<torch::Tensor>& c, c10::IntArrayRef recipe, const std::string& compiled_dims) {
@@ -342,17 +350,20 @@ TORCH_LIBRARY(deep_gemm, m) {
deep_gemm_wrappers::m_grouped_fp8_gemm_nn_contiguous_wrapper(a_val, a_scale, b_val, b_scale, d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
});

m.def(R"(m_grouped_fp8_gemm_nt_masked(Any a, Any b, Tensor d, Tensor masked_m, int expected_m, int[]? recipe=None, str compiled_dims="nk", bool disable_ue8m0_cast=False) -> ())");
m.def(R"(m_grouped_fp8_gemm_nt_masked(Any a, Any b, Tensor d, Tensor masked_m, int expected_m, int[]? recipe=None, str compiled_dims="nk", bool disable_ue8m0_cast=False, int max_block_n=256, bool enable_overlap=False, Tensor? signal=None) -> (int?, int?))");
m.impl("m_grouped_fp8_gemm_nt_masked", torch::kCUDA, [](const c10::IValue& a_input, const c10::IValue& b_input,
const torch::Tensor& d,
const torch::Tensor& masked_m,
int64_t expected_m,
const c10::optional<c10::IntArrayRef>& recipe,
const std::string& compiled_dims,
bool disable_ue8m0_cast) {
bool disable_ue8m0_cast,
int64_t max_block_n,
bool enable_overlap,
const c10::optional<torch::Tensor>& signal) {
auto [a_val, a_scale] = parse_tensor_or_tuple(a_input);
auto [b_val, b_scale] = parse_tensor_or_tuple(b_input);
deep_gemm_wrappers::m_grouped_fp8_gemm_nt_masked_wrapper(a_val, a_scale, b_val, b_scale, d, masked_m, expected_m, recipe, compiled_dims, disable_ue8m0_cast);
return deep_gemm_wrappers::m_grouped_fp8_gemm_nt_masked_wrapper(a_val, a_scale, b_val, b_scale, d, masked_m, expected_m, recipe, compiled_dims, disable_ue8m0_cast, max_block_n, enable_overlap, signal);
});

m.def(R"(k_grouped_fp8_gemm_nt_contiguous(Any a, Any b, Tensor d, int[] ks, Tensor ks_tensor, Tensor? c=None, int[] recipe=[1, 1, 128], str compiled_dims="mn") -> ())");
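
A sketch of how the extended schema could be invoked. The torch.ops.deep_gemm entry point, passing (value, scale) tuples for the Any arguments, and the tensor preparation are assumptions; a_val/a_scale, b_val/b_scale, d, masked_m and expected_m are set up exactly as for the existing masked grouped GEMM, and signal is the zero-filled int32 buffer sketched earlier.

    # Returns (None, None) when enable_overlap=False (the default).
    block_m, signal_threshold = torch.ops.deep_gemm.m_grouped_fp8_gemm_nt_masked(
        (a_val, a_scale), (b_val, b_scale), d, masked_m, expected_m,
        enable_overlap=True, signal=signal)
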
10 changes: 10 additions & 0 deletions deep_gemm/include/deep_gemm/common/utils.cuh
@@ -158,6 +158,16 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) {
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
}

__device__ __forceinline__ void store_wait() {
asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory");
}

__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) {
int ret;
asm volatile ("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(value));
return ret;
}

template <uint32_t kNumBytes>
struct Vectorized {
static auto zeros() {
16 changes: 14 additions & 2 deletions deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
@@ -38,9 +38,9 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType,
typename epilogue_type_t>
typename epilogue_type_t, bool kEnableOverlap>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -395,6 +395,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
cute::tma_store_arrive();
}
__syncwarp();

if constexpr (kEnableOverlap) {
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
store_wait();
}

cutlass::arch::NamedBarrier(kNumMathThreads).sync();

if (threadIdx.x == 0) {
atomic_add_release_global(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1);
}
}
}
}
#else
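
With kEnableOverlap, the epilogue waits for its pending TMA stores, synchronizes the math threads, and then has thread 0 bump one counter per (group, m-block) row with release semantics. A Python model of just the counter indexing, to make the signal layout explicit (illustration, not the device code):

    def signal_index(group_idx: int, m_block_idx: int, shape_m: int, block_m: int) -> int:
        # Mirrors signal + group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx in the kernel.
        return group_idx * ((shape_m + block_m - 1) // block_m) + m_block_idx
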
15 changes: 15 additions & 0 deletions deep_gemm/testing/numeric.py
@@ -17,3 +17,18 @@ def count_bytes(*tensors):
elif t is not None:
total += t.numel() * t.element_size()
return total

def check_signal(num_local_expert, max_m, block_m, threshold, signal, masked_m):
ceil_div = lambda a, b: (a + b - 1) // b

expert_len = max_m // block_m
for expert in range(num_local_expert):
mask = masked_m[expert]
start = expert * expert_len
end = expert * expert_len + expert_len
valid_len = ceil_div(mask, block_m)
for i in range(start, end):
if i < start + valid_len:
assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}'
else:
assert signal[i] == 0, f'{i=}, {signal[i]=}'
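
A possible way to tie the pieces together after an overlap-enabled call, reusing the names from the earlier sketches; the explicit synchronization is an assumption, since check_signal reads the buffer element by element on the host.

    torch.cuda.synchronize()
    check_signal(num_local_expert=num_groups, max_m=max_m, block_m=block_m,
                 threshold=signal_threshold, signal=signal, masked_m=masked_m)
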
2 changes: 1 addition & 1 deletion setup.py
@@ -93,7 +93,7 @@ def prepare_includes(self):
]
},
ext_modules=[
CUDAExtension(name='deep_gemm_cpp',
CUDAExtension(name='deep_gemm.deep_gemm_cpp',
sources=sources,
include_dirs=build_include_dirs,
libraries=build_libraries,