Skip to content

Commit

Permalink
runs, bad result
Browse files Browse the repository at this point in the history
Signed-off-by: ElizaWszola <[email protected]>
  • Loading branch information
ElizaWszola committed Dec 9, 2024
1 parent 1825ef8 commit 49219f9
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 73 deletions.
107 changes: 48 additions & 59 deletions csrc/quantization/cutlass_w8a8/grouped_gemm_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,14 @@ using ElementAB_Type = cutlass::float_e4m3_t;
// using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand
using ElementC_Type = cutlass::half_t;

// // A matrix configuration
// using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
// constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)

// // B matrix configuration
// using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
// constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)

// // C/D matrix configuration
// using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
// constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
// using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size

// Different configs for pingpong/cooperative
// struct CooperativeConfig {
// using KernelSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
// using EpilogueSchedule = cutlass::KernelPtrArrayTmaWarpSpecializedCooperative;
// using TileShape = cute::Shape<cute::_256,cute::_128,cute::_128>;
// using ClusterShape = cute::Shape<cute::_2,cute::_2,cute::_1>;
// };

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;

template <typename ElementAB_, typename ElementC_,
template <typename, typename, typename> typename Epilogue_,
Expand All @@ -107,8 +86,8 @@ struct cutlass_3x_group_gemm {

using StrideC = cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

const int AlignmentAB = 128 / cutlass::sizeof_bits<ElementAB>::value;
const int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
const int AlignmentAB = 128 / cutlass::sizeof_bits<ElementAB>::value;
const int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

using EVTCompute = typename Epilogue::EVTCompute;
// the orig hat cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>
Expand Down Expand Up @@ -172,34 +151,25 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
std::vector<ElementC*> d_ptrs_host(groups);

for (int g = 0; g < groups; ++g) {
a_ptrs_host.at(g) = (ElementAB*)a.data_ptr();// + a_offsets[g].item<int32_t>();
b_ptrs_host.at(g) = (ElementAB*)b.data_ptr();// + b_offsets[g].item<int32_t>();
c_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item<int32_t>();
d_ptrs_host.at(g) = (ElementC*)out.data_ptr();// + out_offsets[g].item<int32_t>();
a_ptrs_host.at(g) = (ElementAB*)a.data_ptr() + a_offsets[g].item<int32_t>();
b_ptrs_host.at(g) = (ElementAB*)b.data_ptr() + b_offsets[g].item<int32_t>();
c_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item<int32_t>();
d_ptrs_host.at(g) = (ElementC*)out.data_ptr() + out_offsets[g].item<int32_t>();
}

// int32_t groups = a.size(0);
// int32_t m = a.size(1);
// int32_t n = b.size(2);
// int32_t k = a.size(2);

// int64_t lda = a.stride(1);
// int64_t ldb = b.stride(2);
// int64_t ldc = out.stride(1);

using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;

// StrideA stride_A{lda, cute::Int<1>{}, 0};
// StrideB stride_B{ldb, cute::Int<1>{}, 0};
// StrideC stride_C{ldc, cute::Int<1>{}, cute::Int<0>{}};
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);

// this should be vector of A ptrs
// auto ptr_A = static_cast<ElementAB*>(a.data_ptr());
// auto ptr_B = static_cast<ElementAB*>(b.data_ptr());
// auto ptr_C = static_cast<ElementC*>(out.data_ptr());
std::vector<StrideA> a_stride_host(groups, StrideA{lda, cute::Int<1>{}, cute::Int<0>{}});
std::vector<StrideB> b_stride_host(groups, StrideB{ldb, cute::Int<1>{}, cute::Int<0>{}});
// TODO fix
std::vector<StrideC> c_stride_host(groups, StrideC{cute::Int<1>{}, ldc, cute::Int<0>{}});

cutlass::platform::unique_ptr<StrideA> stride_A;
cutlass::platform::unique_ptr<StrideB> stride_B;
Expand All @@ -212,7 +182,7 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
cutlass::platform::unique_ptr<ElementC*> ptr_D;

using GemmKernel = typename Gemm::GemmKernel;

cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
Expand Down Expand Up @@ -241,38 +211,60 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,

const ElementAB** a_ptrs_device;
cudaMalloc(&a_ptrs_device, groups * sizeof(ElementAB*));
cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
cudaMemcpy(a_ptrs_device, a_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<const ElementAB*, ItemDeleter<const ElementAB*>> a_ptrs_ptr(
a_ptrs_device
);

const ElementAB** b_ptrs_device;
cudaMalloc(&b_ptrs_device, groups * sizeof(ElementAB*));
cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
cudaMemcpy(b_ptrs_device, b_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<const ElementAB*, ItemDeleter<const ElementAB*>> b_ptrs_ptr(
b_ptrs_device
);

const ElementC** c_ptrs_device;
cudaMalloc(&c_ptrs_device, groups * sizeof(ElementC*));
cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
cudaMemcpy(c_ptrs_device, c_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<const ElementC*, ItemDeleter<const ElementC*>> c_ptrs_ptr(
c_ptrs_device
);

// TODO if we start with empty values here, no need to copy
ElementC** d_ptrs_device;
cudaMalloc(&d_ptrs_device, groups * sizeof(ElementC*));
cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups,cudaMemcpyHostToDevice);
cudaMemcpy(d_ptrs_device, d_ptrs_host.data(), groups, cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<ElementC*, ItemDeleter<ElementC*>> d_ptrs_ptr(
d_ptrs_device
);

StrideA* a_stride_device;
cudaMalloc(&a_stride_device, groups * sizeof(StrideA*));
cudaMemcpy(a_stride_device, a_stride_host.data(), groups, cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<StrideA, ItemDeleter<StrideA>> a_stride_ptr(
a_stride_device
);

StrideB* b_stride_device;
cudaMalloc(&b_stride_device, groups * sizeof(StrideB*));
cudaMemcpy(b_stride_device, b_stride_host.data(), groups, cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<StrideB, ItemDeleter<StrideB>> b_stride_ptr(
b_stride_device
);

StrideC* c_stride_device;
cudaMalloc(&c_stride_device, groups * sizeof(StrideC*));
cudaMemcpy(c_stride_device, c_stride_host.data(), groups, cudaMemcpyHostToDevice);
cutlass::platform::unique_ptr<StrideC, ItemDeleter<StrideC>> c_stride_ptr(
c_stride_device
);

typename GemmKernel::MainloopArguments mainloop_args{
a_ptrs_ptr.get(), stride_A.get(), b_ptrs_ptr.get(), stride_B.get()};
a_ptrs_ptr.get(), a_stride_ptr.get(), b_ptrs_ptr.get(), b_stride_ptr.get()};
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptrs_ptr.get(), stride_C.get(), d_ptrs_ptr.get(), stride_D.get()};
c_ptrs_ptr.get(), c_stride_ptr.get(), d_ptrs_ptr.get(), c_stride_ptr.get()};

typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
Expand All @@ -296,11 +288,8 @@ void cutlass_group_gemm_caller(torch::Tensor& out, torch::Tensor const& a,

CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr()));

// #if defined(ENABLE_SM90_KERNEL_LEVEL)
// printf("did run through\n");
cutlass::Status status = gemm_op.run();
CUTLASS_CHECK(status);
// #endif
cutlass::Status status = gemm_op.run();
CUTLASS_CHECK(status);

}

Expand Down Expand Up @@ -367,7 +356,7 @@ void cutlass_grouped_mm_sm90(torch::Tensor& out, torch::Tensor const& a,

TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
// int32_t m = a.size(1);
// int32_t m = a.size(1);

using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<ElementAB_Type, ElementC_Type,
Expand Down
34 changes: 20 additions & 14 deletions tests/kernels/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,31 +484,37 @@ def test_cutlass_fp8_group_gemm(m: int, n: int, k: int, num_groups: int,
m = alignment * random.randint(1, 64)
n = alignment * random.randint(1, 64)
k = alignment * random.randint(1, 64)
tot_a += m * k
tot_b += k * n
tot_c += m * n
tot_a += m
tot_b += k
tot_c += m
offsets_a[g] = m * k
offsets_b[g] = k * n
offsets_c[g] = m * n
problem_sizes[g][0] = m
problem_sizes[g][1] = n
problem_sizes[g][2] = k

a = to_fp8(torch.randn((tot_a), device=device))
b = to_fp8(torch.randn((tot_b), device=device).t())
c = torch.zeros((tot_c), device=device).to(out_dtype)
a = to_fp8(torch.randn((tot_a, k), device=device))
b = to_fp8(torch.randn((tot_b, n), device=device).t())
c = torch.zeros((tot_c, n), device=device).to(out_dtype)

m_a_scales = m if per_act_token else 1
n_b_scales = n if per_out_ch else 1
print(tot_a, tot_b, tot_c)

scale_a = (torch.randn((m_a_scales, 1), device=device,
print(a.stride(), b.stride(), c.stride())

# m_a_scales = m if per_act_token else 1
# n_b_scales = n if per_out_ch else 1

scale_a = (torch.randn((tot_a if per_act_token else num_groups),
device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
scale_b = (torch.randn((tot_b if per_act_token else num_groups),
device=device,
dtype=torch.float32))
if use_bias:
bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10
else:
bias = None
# if use_bias:
# bias = torch.rand((n, 1), device=device, dtype=out_dtype) * 10
# else:
# bias = None

# TODO strides we can get later the same way as in scaled_mm_c3x.cu
torch.ops._C.cutlass_grouped_mm(c, a, b, scale_a, scale_b, problem_sizes,
Expand Down

0 comments on commit 49219f9

Please sign in to comment.