cleanup, add import
Signed-off-by: ElizaWszola <[email protected]>
ElizaWszola committed Jan 28, 2025
Parent: 64c2a68 · Commit: 1ea7874
Showing 5 changed files with 81 additions and 160 deletions.
15 changes: 7 additions & 8 deletions csrc/cpu/torch_bindings.cpp
@@ -119,19 +119,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
// CUTLASS w8a8 grouped GEMM // TODO complete this
// ops.def(
// "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales, "
// " Tensor b_scales, Tensor problem_sizes, "
// " Tensor out_offsets, Tensor a_offsets, "
// " Tensor b_offsets) -> ()");
// ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);
// ops.def(
// "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales,
// " " Tensor b_scales, Tensor problem_sizes, " "
// Tensor out_offsets, Tensor a_offsets, " " Tensor
// b_offsets) -> ()");
// ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);

ops.def(
"compute_expert_offsets(Tensor! trg_a_ptrs,"
" Tensor! a, Tensor topk_ids,"
" Tensor! expert_offsets, SymInt num_experts) -> ()");
ops.impl("compute_expert_offsets", torch::kCUDA,
&compute_expert_offsets);
ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets);

// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
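As a usage note (not part of the commit): once the extension is built, the `compute_expert_offsets` op registered above is reachable through `torch.ops._C`, the same namespace the test below uses for `cutlass_grouped_mm`. The sketch shows one plausible invocation; the buffer sizes and dtypes (int64 pointer/offset tensors, int32 top-k ids, FP8 activations) are assumptions inferred from the kernel signature later in this diff.

```python
import torch

# Hypothetical call into the binding registered above; assumes vLLM's _C
# extension is built with this op and a CUDA device is available.
num_experts = 8
num_tokens, topk, hidden = 16, 2, 128

# Routed FP8 activations, flattened to [num_tokens * topk, hidden].
a = torch.randn(num_tokens * topk, hidden, device="cuda").to(torch.float8_e4m3fn)
topk_ids = torch.randint(0, num_experts, (num_tokens, topk),
                         device="cuda", dtype=torch.int32)

# Written in place by the op: one A pointer per expert and cumulative
# per-expert token offsets (sizes and dtypes assumed, not from the commit).
trg_a_ptrs = torch.zeros(num_experts + 1, device="cuda", dtype=torch.int64)
expert_offsets = torch.zeros(num_experts + 1, device="cuda", dtype=torch.int64)

torch.ops._C.compute_expert_offsets(trg_a_ptrs, a, topk_ids,
                                    expert_offsets, num_experts)
```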
9 changes: 4 additions & 5 deletions csrc/ops.h
@@ -167,11 +167,10 @@ void cutlass_grouped_mm(torch::Tensor& out_tensors,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes);

void compute_expert_offsets(torch::Tensor& trg_a_ptrs,
torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts);
void compute_expert_offsets(torch::Tensor& trg_a_ptrs, torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
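For intuition about the `expert_offsets` and `problem_sizes` arguments declared above, here is a rough host-side analogue in plain PyTorch (illustration only, not code from the repository): the offsets behave like a prefix sum of per-expert token counts, and `problem_sizes` carries one (m, n, k) triple per group. The helper names and the shared-n/k layout are assumptions.

```python
import torch

def host_expert_offsets(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Count how many routed tokens land on each expert, then prefix-sum the
    # counts: rows [offsets[e], offsets[e + 1]) of the expert-sorted activation
    # buffer belong to expert e. Illustrative only; layout details assumed.
    counts = torch.bincount(topk_ids.flatten().long(), minlength=num_experts)
    offsets = torch.zeros(num_experts + 1, dtype=torch.int64)
    offsets[1:] = torch.cumsum(counts, dim=0)
    return offsets

def host_problem_sizes(offsets: torch.Tensor, n: int, k: int) -> torch.Tensor:
    # One (m, n, k) problem shape per expert group; n and k are shared across
    # groups here, as in the grouped-GEMM test later in this diff.
    m_per_group = offsets[1:] - offsets[:-1]
    return torch.stack([m_per_group,
                        torch.full_like(m_per_group, n),
                        torch.full_like(m_per_group, k)], dim=1).to(torch.int32)

topk_ids = torch.randint(0, 4, (16, 2), dtype=torch.int32)
offsets = host_expert_offsets(topk_ids, num_experts=4)
sizes = host_problem_sizes(offsets, n=256, k=128)
```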
178 changes: 55 additions & 123 deletions csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu
@@ -80,8 +80,6 @@ struct cutlass_3x_group_gemm {
const int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

using EVTCompute = typename Epilogue::EVTCompute;
// the orig hat cutlass::epilogue::fusion::LinearCombination<ElementC,
// ElementAccumulator>

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -143,88 +141,44 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
bool per_act_token = a_scales.numel() != groups;
bool per_out_ch = b_scales.numel() != groups;

// TORCH_CHECK((int)b_tensors.size() == groups,
// "Number of B tensors must match number of groups.");
// TORCH_CHECK((int)out_tensors.size() == groups,
// "Number of output tensors must match number of groups.");

// std::vector<const ElementAB*> a_ptrs_host(groups);
// std::vector<const ElementAB*> b_ptrs_host(groups);
// std::vector<const ElementC*> c_ptrs_host(groups);
// std::vector<ElementC*> d_ptrs_host(groups);
// std::vector<const ElementAccumulator*> a_scales_ptrs_host(groups);
// std::vector<const ElementAccumulator*> b_scales_ptrs_host(groups);

// std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
// problem_sizes_host.reserve(groups);

int b_single_size = k_size * n_size;
int b_scale_single_size = per_out_ch ? out_tensors.size(1) : 1;

auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
torch::Tensor a_ptrs_base = torch::full({groups + 1},
(int64_t)a_tensors.data_ptr(),
options_int);
torch::Tensor out_ptrs_base = torch::full({groups + 1},
(int64_t)out_tensors.data_ptr(),
options_int);
torch::Tensor b_ptrs_base = torch::full({groups + 1},
(int64_t)b_tensors.data_ptr(),
options_int);
torch::Tensor a_scales_base = torch::full({groups + 1},
(int64_t)a_scales.data_ptr(),
options_int);
torch::Tensor b_scales_base = torch::full({groups + 1},
(int64_t)b_scales.data_ptr(),
options_int);
torch::Tensor a_ptrs_base =
torch::full({groups + 1}, reinterpret_cast<int64_t>(a_tensors.data_ptr()),
options_int);
torch::Tensor out_ptrs_base = torch::full(
{groups + 1}, reinterpret_cast<int64_t>(out_tensors.data_ptr()),
options_int);
torch::Tensor b_ptrs_base =
torch::full({groups + 1}, reinterpret_cast<int64_t>(b_tensors.data_ptr()),
options_int);
torch::Tensor a_scales_base =
torch::full({groups + 1}, reinterpret_cast<int64_t>(a_scales.data_ptr()),
options_int);
torch::Tensor b_scales_base =
torch::full({groups + 1}, reinterpret_cast<int64_t>(b_scales.data_ptr()),
options_int);

torch::Tensor b_offsets = torch::arange(0, b_single_size * (groups + 1),
b_single_size, options_int);
torch::Tensor a_scales_offsets = torch::arange(0, groups + 1, options_int);
torch::Tensor b_scales_offsets = torch::arange(0, b_scale_single_size *
(groups + 1), b_scale_single_size,
options_int);

// multiply by offset of k 8-bit elements
torch::Tensor a_ptrs = a_ptrs_base.add(expert_offsets, a_tensors.size(1));
// multiply by offset of n 16-bit elements
torch::Tensor out_ptrs = out_ptrs_base.add(expert_offsets, 2 * out_tensors.size(1));
// multiply by offset of n 8-bit elements
torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets);

torch::Tensor a_scales_ptrs = a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets, 4);
torch::Tensor b_scales_ptrs = b_scales_base.add(b_scales_offsets, 4);

// for (int g = 0; g < groups; ++g) {
// // b_ptrs_host[g] =
// // reinterpret_cast<const ElementAB*>(b_list[g].data_ptr());
// // c_ptrs_host[g] =
// // reinterpret_cast<const ElementC*>(out_tensors[g].data_ptr());
// // d_ptrs_host[g] = reinterpret_cast<ElementC*>(out_tensors[g].data_ptr());
// // a_scales_ptrs_host[g] =
// // reinterpret_cast<const ElementAccumulator*>(a_scales[g].data_ptr());
// // b_scales_ptrs_host[g] =
// // reinterpret_cast<const ElementAccumulator*>(b_scales[g].data_ptr());

// // printf("%p %p %p %p %p %p %p\n", a_ptrs_host[g], b_ptrs_host[g],
// // c_ptrs_host[g], d_ptrs_host[g],)
// // int64_t m = out_tensors[g].size(0);
// // int64_t k = a_tensors.size(1);

// // int64_t k_b = b_tensors[g].size(0);
// // int64_t n = b_tensors[g].size(1);

// // TORCH_CHECK(k == k_b, "Dimension mismatch between A and B: A has k=", k,
// // " while B has k=", k_b);

// // // Optionally, verify output shape matches (m,n)
// // TORCH_CHECK(out_tensors[g].size(0) == m && out_tensors[g].size(1) == n,
// // "Output tensor shape does not match m,n from A,B: ", "Got ",
// // out_tensors[g].sizes(), " expected (", m, ", ", n, ")");

// // problem_sizes_host.push_back({(int)m, (int)n, (int)k});
// }
torch::Tensor b_scales_offsets = torch::arange(
0, b_scale_single_size * (groups + 1), b_scale_single_size, options_int);

torch::Tensor a_ptrs = a_ptrs_base.add(
expert_offsets, sizeof(ElementAB_Type) * a_tensors.size(1));
torch::Tensor out_ptrs = out_ptrs_base.add(
expert_offsets, sizeof(ElementC_Type) * out_tensors.size(1));
torch::Tensor b_ptrs = b_ptrs_base.add(b_offsets, sizeof(ElementAB_Type));

torch::Tensor a_scales_ptrs =
a_scales_base.add(per_act_token ? expert_offsets : a_scales_offsets,
sizeof(ElementAccumulator));
torch::Tensor b_scales_ptrs =
b_scales_base.add(b_scales_offsets, sizeof(ElementAccumulator));

using GemmKernel = typename Gemm::GemmKernel;
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
@@ -239,7 +193,6 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
int64_t lda = a_tensors.stride(0); // row-major (m x k)
int64_t ldb = a_tensors.stride(0); // column-major (k x n)
int64_t ldc = out_tensors.stride(0); // row-major (m x n)
printf("strides: %ld %ld %ld\n", lda, ldb, ldc);

a_stride_host[g] = StrideA{lda, Int<1>{}, Int<0>{}};
b_stride_host[g] = StrideB{ldb, Int<1>{}, Int<0>{}};
@@ -252,49 +205,33 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);

// auto problem_sizes_ptr = make_device_ptr(problem_sizes_host);
// ProblemShape prob_shape{groups, problem_sizes_ptr.get(),
// problem_sizes_host.data()};
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
problem_sizes.data_ptr());
ProblemShape prob_shape{groups, problem_sizes_as_shapes, nullptr};

// auto a_ptrs_ptr = make_device_ptr(a_ptrs_host);
// auto b_ptrs_ptr = make_device_ptr(b_ptrs_host);
// auto c_ptrs_ptr = make_device_ptr(c_ptrs_host);
// auto d_ptrs_ptr = make_device_ptr(d_ptrs_host);

// auto a_scales_ptrs_ptr = make_device_ptr(a_scales_ptrs_host);
// auto b_scales_ptrs_ptr = make_device_ptr(b_scales_ptrs_host);

auto a_stride_ptr = make_device_ptr(a_stride_host);
auto b_stride_ptr = make_device_ptr(b_stride_host);
auto c_stride_ptr = make_device_ptr(c_stride_host);

// auto c_ptrs_ptr = make_device_ptr(c_ptrs_host);
// auto d_ptrs_ptr = make_device_ptr(d_ptrs_host);

typename GemmKernel::MainloopArguments mainloop_args{
(const ElementAB_Type**)a_ptrs.data_ptr(), a_stride_ptr.get(),
(const ElementAB_Type**)b_ptrs.data_ptr(), b_stride_ptr.get()};
reinterpret_cast<const ElementAB_Type**>(a_ptrs.data_ptr()),
a_stride_ptr.get(),
reinterpret_cast<const ElementAB_Type**>(b_ptrs.data_ptr()),
b_stride_ptr.get()};

// Currently, we are only able to do broadcast on either all or none a_scales
// and on either all or none b_scales
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
(const ElementAccumulator**)a_scales_ptrs.data_ptr(),
(const ElementAccumulator**)b_scales_ptrs.data_ptr(),
per_act_token, per_out_ch),
(const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(),
(ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()};

// typename GemmKernel::EpilogueArguments epilogue_args{
// Gemm::Epilogue::prepare_args(
// (const ElementAccumulator**)a_scales_ptrs.data_ptr(),
// b_scales_ptrs_ptr.get(),
// per_act_token, per_out_ch),
// (const ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get(),
// (ElementC_Type**)out_ptrs.data_ptr(), c_stride_ptr.get()};
Gemm::Epilogue::prepare_args(reinterpret_cast<const ElementAccumulator**>(
a_scales_ptrs.data_ptr()),
reinterpret_cast<const ElementAccumulator**>(
b_scales_ptrs.data_ptr()),
per_act_token, per_out_ch),
reinterpret_cast<const ElementC_Type**>(out_ptrs.data_ptr()),
c_stride_ptr.get(),
reinterpret_cast<ElementC_Type**>(out_ptrs.data_ptr()),
c_stride_ptr.get()};

typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
@@ -309,11 +246,9 @@ void cutlass_group_gemm_caller(torch::Tensor& out_tensors,
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
auto workspace = torch::empty(workspace_size, workspace_options);

// printf("before: %d\n", out_tensors[0][0]);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
// printf("after: %d\n", out_tensors[0][0]);
}

template <typename InType, typename OutType,
@@ -350,7 +285,6 @@ struct sm90_fp8_config_M128 {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
// M in [1, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
@@ -394,8 +328,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out_tensors,
__global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
cutlass::float_e4m3_t* base_a_ptr,
const int* __restrict__ topk_ids,
int64_t* expert_offsets,
int topk_length) {
int64_t* expert_offsets, int topk_length) {
int expert_id = threadIdx.x;
int num_experts = blockDim.x;

@@ -419,10 +352,12 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,

// // For a given "a" of size [M,K] performs a permutation of the M rows based
// // on the given "perm" indices.
// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const* __restrict__ a_ptr,
// __global__ void permute_fp8_rows_kernel(cutlass::float_e4m3_t const*
// __restrict__ a_ptr,
// int const* __restrict__ perm_int_ptr,
// cutlass::float_e4m3_t* __restrict__ out_ptr,
// int size_m, int size_k, int block_rows) {
// cutlass::float_e4m3_t* __restrict__
// out_ptr, int size_m, int size_k, int
// block_rows) {
// int start_row = block_rows * blockIdx.x;
// int finish_row = start_row + block_rows;
// if (finish_row > size_m) {
@@ -459,17 +394,14 @@ __global__ void get_a_expert_offsets(cutlass::float_e4m3_t** trg_a_ptrs,
// };
// }

void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs,
torch::Tensor& a,
void compute_expert_offsets_caller(torch::Tensor& trg_a_ptrs, torch::Tensor& a,
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const int64_t num_experts) {
get_a_expert_offsets<<<1, num_experts>>>(
(cutlass::float_e4m3_t**)trg_a_ptrs.data_ptr(),
(cutlass::float_e4m3_t*)a.data_ptr(),
(const int*)topk_ids.data_ptr(),
(int64_t*)expert_offsets.data_ptr(),
topk_ids.numel());
(cutlass::float_e4m3_t*)a.data_ptr(), (const int*)topk_ids.data_ptr(),
(int64_t*)expert_offsets.data_ptr(), topk_ids.numel());
}

// void permute_fp8_rows(torch::Tensor& a_ptr,
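The main cleanup in this file replaces the old host-side pointer vectors with device tensors built from a base address plus scaled offsets (`torch::full` followed by `Tensor::add(other, alpha)`). Below is a rough Python analogue, for illustration only: the element sizes (1-byte FP8 inputs, 2-byte FP16 outputs) mirror the `sizeof(...)` factors in the new code, while the shapes and the helper name are assumptions.

```python
import torch

def group_ptr_array(t: torch.Tensor, row_offsets: torch.Tensor,
                    elem_bytes: int) -> torch.Tensor:
    # Per-group device pointers as int64 values: base address of `t` plus
    # row_offsets * (row length * element size). Mirrors the
    # base.add(offsets, alpha) pattern in the CUDA caller (illustrative).
    base = torch.full_like(row_offsets, t.data_ptr())
    return base.add(row_offsets, alpha=t.size(1) * elem_bytes)

# Assumed layout: FP8 A stacked as [total_m, k], FP16 output as [total_m, n],
# with expert_offsets giving the starting row of each group.
total_m, k, n = 64, 128, 256
a = torch.zeros(total_m, k, device="cuda", dtype=torch.float8_e4m3fn)
out = torch.zeros(total_m, n, device="cuda", dtype=torch.float16)
expert_offsets = torch.tensor([0, 16, 32, 48, 64], device="cuda",
                              dtype=torch.int64)

a_ptrs = group_ptr_array(a, expert_offsets, elem_bytes=1)      # 8-bit inputs
out_ptrs = group_ptr_array(out, expert_offsets, elem_bytes=2)  # 16-bit outputs
```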
3 changes: 1 addition & 2 deletions csrc/torch_bindings.cpp
@@ -337,8 +337,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"compute_expert_offsets(Tensor! trg_a_ptrs,"
" Tensor! a, Tensor topk_ids,"
" Tensor! expert_offsets, SymInt num_experts) -> ()");
ops.impl("compute_expert_offsets", torch::kCUDA,
&compute_expert_offsets);
ops.impl("compute_expert_offsets", torch::kCUDA, &compute_expert_offsets);

// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
36 changes: 14 additions & 22 deletions tests/kernels/test_cutlass.py
@@ -2,12 +2,13 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
import random
from typing import Type

import pytest
import torch

from tests.kernels.utils import opcheck, stack_and_dev
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

@@ -452,23 +453,20 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
b_tensors = []
a_scales_tensors = []
b_scales_tensors = []
out_tensors = []
baseline_tensors = []

expert_offsets = torch.zeros((num_groups + 1),
device=device,
dtype=torch.int32)

problem_sizes = torch.zeros((num_groups, 3),
device=device,
dtype=torch.int32)
device=device,
dtype=torch.int32)

alignment = 16 # 128 // 8
# For variation, each group has dimensions
# (m_g = m/(g+1), n_g = n/(g+1), k_g = k/(g+1))
n_g = alignment * random.randint(1, 64)
k_g = alignment * random.randint(1, 64)
# one_b = to_fp8(torch.randn((n_g, k_g), device=device))
for g in range(num_groups):
m_g = alignment * random.randint(1, 64)

@@ -485,8 +483,6 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
# Create group-specific A and B (FP8) and output (FP16/FP32)
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
# b_g = one_b.clone().t()
c_g = torch.zeros((m_g, n_g), device=device, dtype=out_dtype)
# Set up A/B scales
scale_a = torch.randn((m_a_scales, 1),
device=device,
@@ -497,12 +493,9 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,

a_tensors.append(a_g)
b_tensors.append(b_g)
out_tensors.append(c_g)
a_scales_tensors.append(scale_a)
b_scales_tensors.append(scale_b)

print(b_g.stride())

# Compute baseline result for this group
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
None)
@@ -517,7 +510,7 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
for g in range(num_groups):
a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
1]] = a_tensors[g]
b_tensors_stacked[g*n_g:(g+1)*n_g, :] = b_tensors[g].t()
b_tensors_stacked[g * n_g:(g + 1) * n_g, :] = b_tensors[g].t()
b_tensors_stacked = b_tensors_stacked.t()

a_scales_tensors_stacked = torch.empty(
@@ -526,28 +519,27 @@ def test_cutlass_fp8_group_gemm(num_groups: int, per_act_token: bool,
dtype=torch.float32)
if per_act_token:
for g in range(num_groups):
a_scales_tensors_stacked[expert_offsets[g]:expert_offsets[g +
1]] = a_scales_tensors[g]
a_scales_tensors_stacked[
expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
else:
for g in range(num_groups):
a_scales_tensors_stacked[g] = a_scales_tensors[g]

b_scales_tensors_stacked = torch.empty(
(num_groups, n_b_scales),
device=device,
dtype=torch.float32)
b_scales_tensors_stacked = torch.empty((num_groups, n_b_scales),
device=device,
dtype=torch.float32)
for g in range(num_groups):
b_scales_tensors_stacked[g] = b_scales_tensors[g]

out_tensors_stacked = torch.zeros((expert_offsets[num_groups], n_g),
device=device,
dtype=out_dtype)
device=device,
dtype=out_dtype)

torch.ops._C.cutlass_grouped_mm(out_tensors_stacked, a_tensors_stacked,
b_tensors_stacked,
a_scales_tensors_stacked,
b_scales_tensors_stacked,
expert_offsets, problem_sizes)
b_scales_tensors_stacked, expert_offsets,
problem_sizes)

# Validate each group's result against the baseline
for g in range(num_groups):
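The hunk ends just as the per-group validation loop starts. What follows is only a hedged sketch of how such a check could look, reusing names from the test above (`out_tensors_stacked`, `expert_offsets`, `baseline_tensors`); it is not the commit's actual loop body, and the tolerances are assumptions.

```python
# Hypothetical continuation of the loop above (test variables in scope):
# slice each group's rows out of the stacked output via expert_offsets and
# compare against that group's baseline result.
for g in range(num_groups):
    out_g = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
    torch.testing.assert_close(out_g, baseline_tensors[g],
                               rtol=1e-2, atol=5e-2)  # tolerances assumed
```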
