Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(cell): Relax Global-Shared loader/storer to support all possible tile shape. #47

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
"thread": "cpp",
"typeindex": "cpp",
"typeinfo": "cpp",
"variant": "cpp"
"variant": "cpp",
"csetjmp": "cpp",
"shared_mutex": "cpp"
},
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
Expand Down
19 changes: 19 additions & 0 deletions benchmarks/cpp/g2s_copy/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
# MIT License.
# --------------------------------------------------------------------------

# Minimum CMake version required to configure this benchmark.
cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
# Standalone project that enables the C, C++ and CUDA languages.
project(bench_g2s_copy LANGUAGES C CXX CUDA)

# Add the cmake/ directory located three levels above this one to the module
# search path, and record the location of the third-party sources.
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
"${PROJECT_SOURCE_DIR}/../../../cmake")
set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party")

# Load the "generic" module; presumably it is resolved from the cmake/
# directory added above (its contents are not visible from this file).
include(generic)

# Header search paths: the project's include/ tree, the shared benchmark
# utilities, and the cutlass headers under the third-party directory.
include_directories("${PROJECT_SOURCE_DIR}/../../../include")
include_directories("${PROJECT_SOURCE_DIR}/../../utils/cpp")
include_directories("${THIRD_PARTY_DIR}/cutlass/include")

# The benchmark is a single-source executable.
add_executable(bench_g2s_copy main.cu)
16 changes: 16 additions & 0 deletions benchmarks/cpp/g2s_copy/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------


BUILD_DIR := build

# Number of parallel compile jobs. Defaults to the number of available
# processors and can be overridden on the command line, e.g.
# `make build JOBS=4`. (The previous `$(proc)` was never defined, so it
# expanded to nothing and the build ran `make -j` with unlimited jobs.)
JOBS ?= $(shell nproc)

.PHONY: build clean

# Configure with CMake inside $(BUILD_DIR) and compile the benchmark.
build:
	@mkdir -p $(BUILD_DIR)
	@cd $(BUILD_DIR) && cmake .. && make -j$(JOBS)

# Remove all build artifacts.
clean:
	@rm -rf $(BUILD_DIR)
28 changes: 28 additions & 0 deletions benchmarks/cpp/g2s_copy/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
## Data Tile Transfer between Global and Shared Memory

### Overview
This preliminary test evaluates the performance of transferring a row-major data tile containing half-precision floating-point values between global memory and shared memory. The transfer process involves loading the data tile into shared memory and subsequently storing it back to global memory. This cycle is repeated 100 times to measure performance.

### Performance Evaluation
Performance is reported as the execution time of a single kernel launch, which performs the 100 load/store cycles described above, averaged over 100 timed launches that follow 20 warm-up launches.

### Implementations
The test includes two implementations, one using TileFusion and one using cutlass. No bank conflicts were observed when profiling with the NVIDIA compute profiler (Nsight Compute). The cutlass implementation uses a copy plan that maximizes global memory coalescing so that global memory bandwidth is used as efficiently as possible.

### Test Environment
- **GPU**: NVIDIA Tesla A100
- **CUDA Version**: 12.6

### Results

| Shape | Warp Layout | TileFusion (ms) | cutlass (ms) | Ratio |
|:--------------------|:-----------:|:---------------:|:------------:|:------:|
| RowMajor (64, 64) | (1, 1) | 0.05044 | 0.05058 | 0.9974 |
| RowMajor (64, 64) | (2, 2) | 0.05309 | 0.05085 | 1.044 |
| RowMajor (64, 64) | (2, 4) | 0.07196 | 0.05199 | 1.384 |
| RowMajor (128, 128) | (1, 1) | 0.1396 | 0.1539 | 0.907 |
| RowMajor (128, 128) | (2, 2) | 0.1353 | 0.1339 | 1.010 |
| RowMajor (128, 128) | (2, 4) | 0.1434 | 0.1381 | 1.038 |
| RowMajor (128, 256) | (1, 1) | 0.2401 | 0.3693 | 0.6501 |
| RowMajor (128, 256) | (2, 2) | 0.2467 | 0.2462 | 1.002 |
| RowMajor (128, 256) | (2, 4) | 0.2528 | 0.2514 | 1.005 |
140 changes: 140 additions & 0 deletions benchmarks/cpp/g2s_copy/cutlass_copy.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "cutlass/copy.cuh"

#include <cute/swizzle.hpp>
#include <cute/tensor.hpp>

using namespace cute;
using namespace benchmarks;

namespace {
// NOTE: The current implementation of Loader/Storer supports only
// half-precision (FP16) RowMajor data tiles. It is not implemented for other
// data types or memory layouts. Be cautious when using it for other cases.
/// Functor that copies a kRows x kCols row-major tile from global memory into
/// a swizzled shared-memory tile through a cute tiled copy built on the
/// SM80_CP_ASYNC_CACHEGLOBAL 128-bit copy atom.
///
/// The thread layout is sized for all kWarpRows * kWarpCols warps of the
/// block, and the arithmetic below hard-codes 16-bit elements.
template <typename Element,                  //
          const int kRows, const int kCols,  //
          const int kWarpRows, const int kWarpCols>
struct Loader {
    /// @param src_ pointer to the source tile (wrapped as a gmem pointer).
    /// @param dst_ pointer to the destination tile (wrapped as a smem
    ///             pointer).
    DEVICE void operator()(const Element* src_, Element* dst_) {
        const int thread_id = threadIdx.x;

        auto global_tile = make_tensor(make_gmem_ptr(src_), src_layout_);
        auto shared_tile = make_tensor(make_smem_ptr(dst_), dst_layout_);

        // Per-thread view of the tiled copy and of both tensors.
        auto thread_copy = tiled_copy_.get_thread_slice(thread_id);
        auto src_part = thread_copy.partition_S(global_tile);
        auto dst_part = thread_copy.partition_D(shared_tile);

        // Walk modes 1 and 2 of the partition; mode 0 is handed to
        // cute::copy as a whole.
#pragma unroll
        for (int outer = 0; outer < int(size<1>(src_part)); ++outer) {
#pragma unroll
            for (int inner = 0; inner < int(size<2>(src_part)); ++inner) {
                cute::copy(tiled_copy_, src_part(cute::_, outer, inner),
                           dst_part(cute::_, outer, inner));
            }
        }
    }

  private:
    // Source: dense row-major tile.
    using GlobalLayout =
        cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
    GlobalLayout src_layout_;

    // Destination: a swizzled 4 x 64 atom tiled over the full tile shape.
    using LayoutAtom =
        decltype(composition(cute::Swizzle<2, 3, 3>{},
                             cute::Layout<Shape<_4, _64>, Stride<_64, _1>>{}));
    using SharedLayout = decltype(tile_to_shape(
        LayoutAtom{}, Shape<Int<kRows>, Int<kCols>>{}, cute::Step<_2, _1>{}));
    SharedLayout dst_layout_;

    // Thread arrangement: each thread moves one 128-bit chunk, i.e. eight
    // 16-bit elements, so a row of kCols elements needs kCols * 16 / 128
    // threads; the remaining threads of the block are stacked along rows.
    static constexpr int kElementBits = 16;
    static constexpr int kAccessBits = 128;
    static constexpr int kThreadCols = kCols * kElementBits / kAccessBits;
    static constexpr int kThreadRows =
        kWarpRows * kWarpCols * 32 / kThreadCols;

    using ThreadLayout =
        cute::Layout<Shape<Int<kThreadRows>, Int<kThreadCols>>,
                     Stride<Int<kThreadCols>, _1>>;
    using ValueLayout = cute::Layout<Shape<_1, _8>>;

    using CopyInst =
        Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>;
    using TiledCopy =
        decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{}));
    TiledCopy tiled_copy_;
};

/// Functor that copies a kRows x kCols tile from its swizzled shared-memory
/// layout back to a dense row-major tile in global memory, using a cute tiled
/// copy built on the DefaultCopy atom.
///
/// The thread layout is sized for all kWarpRows * kWarpCols warps of the
/// block, and the arithmetic below hard-codes 16-bit elements.
template <typename Element,                  //
          const int kRows, const int kCols,  //
          const int kWarpRows, const int kWarpCols>
struct Storer {
    /// @param src_ pointer to the source tile (wrapped as a smem pointer).
    /// @param dst_ pointer to the destination tile (wrapped as a gmem
    ///             pointer).
    DEVICE void operator()(const Element* src_, Element* dst_) {
        const int thread_id = threadIdx.x;

        auto shared_tile = make_tensor(make_smem_ptr(src_), src_layout_);
        auto global_tile = make_tensor(make_gmem_ptr(dst_), dst_layout_);

        // Per-thread view of the tiled copy and of both tensors.
        auto thread_copy = tiled_copy_.get_thread_slice(thread_id);
        auto src_part = thread_copy.partition_S(shared_tile);
        auto dst_part = thread_copy.partition_D(global_tile);

        // Walk modes 1 and 2 of the partition; mode 0 is handed to
        // cute::copy as a whole.
#pragma unroll
        for (int outer = 0; outer < int(size<1>(src_part)); ++outer) {
#pragma unroll
            for (int inner = 0; inner < int(size<2>(src_part)); ++inner) {
                cute::copy(tiled_copy_, src_part(cute::_, outer, inner),
                           dst_part(cute::_, outer, inner));
            }
        }
    }

  private:
    // Source: a swizzled 4 x 64 atom tiled over the full tile shape.
    using LayoutAtom =
        decltype(composition(cute::Swizzle<2, 3, 3>{},
                             cute::Layout<Shape<_4, _64>, Stride<_64, _1>>{}));
    using SharedLayout = decltype(tile_to_shape(
        LayoutAtom{}, Shape<Int<kRows>, Int<kCols>>{}, cute::Step<_2, _1>{}));
    SharedLayout src_layout_;

    // Destination: dense row-major tile.
    using GlobalLayout =
        cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
    GlobalLayout dst_layout_;

    // Thread arrangement: each thread handles eight 16-bit elements (128
    // bits), so a row of kCols elements needs kCols * 16 / 128 threads; the
    // remaining threads of the block are stacked along rows.
    static constexpr int kElementBits = 16;
    static constexpr int kAccessBits = 128;
    static constexpr int kThreadCols = kCols * kElementBits / kAccessBits;
    static constexpr int kThreadRows =
        kWarpRows * kWarpCols * 32 / kThreadCols;

    using ThreadLayout =
        cute::Layout<Shape<Int<kThreadRows>, Int<kThreadCols>>,
                     Stride<Int<kThreadCols>, _1>>;
    using ValueLayout = cute::Layout<Shape<_1, _8>>;

    using CopyInst = Copy_Atom<DefaultCopy, Element>;
    using TiledCopy =
        decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{}));
    TiledCopy tiled_copy_;
};
} // namespace

/// Benchmark kernel: moves one kRows x kCols tile from `src` into dynamic
/// shared memory and back out to `dst`, kRepeat times in a row.
///
/// @param src source tile in global memory.
/// @param dst destination tile in global memory.
///
/// The dynamic shared-memory buffer must hold the whole tile; its size is
/// supplied by the launch configuration.
template <typename Element, const int kRows, const int kCols,
          const int kWarpRow, const int kWarpCol, const int kRepeat>
__global__ void cutlass_g2s_data_transfer(const Element* src, Element* dst) {
    extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
    Element* smem_tile = reinterpret_cast<Element*>(buf_);

    Loader<Element, kRows, kCols, kWarpRow, kWarpCol> load_tile;
    Storer<Element, kRows, kCols, kWarpRow, kWarpCol> store_tile;

    for (int round = 0; round < kRepeat; ++round) {
        load_tile(src, smem_tile);

        // NOTE(review): presumably commits and waits for the outstanding
        // asynchronous copies issued by the loader -- confirm against
        // cutlass_wrapper.
        cutlass_wrapper::__copy_async();
        __syncthreads();

        store_tile(smem_tile, dst);
        // Keep the next round's load from overwriting shared memory while
        // other threads are still storing.
        __syncthreads();
    }
}
164 changes: 164 additions & 0 deletions benchmarks/cpp/g2s_copy/main.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "cell/copy/mod.hpp"
#include "cutlass_copy.cuh"
#include "tilefusion_copy.cuh"
#include "types/mod.hpp"
#include "util/cuda_timer.hpp"

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

using namespace tilefusion;
using namespace tilefusion::cell;
using namespace tilefusion::cell::copy;

// Number of untimed kernel launches issued before the measurement starts.
int warmup = 20;
// Number of timed kernel launches; the reported time is averaged over them.
int iters = 100;
// Number of load/store round trips performed inside a single kernel launch
// (passed to the benchmark kernels as a template argument).
const int kRepeat = 100;

/// Element-wise comparison of two host buffers.
///
/// @param dst1  first buffer, holding at least `numel` elements.
/// @param dst2  second buffer, holding at least `numel` elements.
/// @param numel number of elements to compare.
/// @return true when every pair of elements differs by at most 1e-3 in
///         absolute value; false at the first mismatch, which is reported on
///         std::cerr.
///
/// The difference is taken on the raw (signed) values and is symmetric, so a
/// value that is too large, too small, or of the wrong sign is detected. A
/// NaN on either side is reported as a mismatch.
template <typename Element>
bool check_results(const Element* dst1, const Element* dst2, int64_t numel) {
    const float epsilon = 1e-3f;
    for (int64_t i = 0; i < numel; ++i) {
        const float v1 = static_cast<float>(dst1[i]);
        const float v2 = static_cast<float>(dst2[i]);

        // |v1 - v2| without relying on a particular abs overload.
        const float diff = v1 > v2 ? v1 - v2 : v2 - v1;

        // `v1 != v2` lets identical infinities pass (their difference is
        // NaN); `!(diff <= epsilon)` is also true when diff is NaN.
        if (v1 != v2 && !(diff <= epsilon)) {
            std::cerr << "Mismatch at " << i << ": " << v1 << " vs " << v2
                      << std::endl;
            return false;
        }
    }
    return true;
}

template <typename Element, typename Layout, typename WarpLayout,
const int kRepeat>
float test_tilefusion(const Element* src, Element* dst) {
using Global = GlobalTile<Element, Layout>;
using Shared = SharedTile<Element, Layout, true /*kSwizzled*/>;

using Loader = GlobalToSharedLoader<Shared, WarpLayout>;
Loader loader;

using Storer = SharedToGlobalStorer<Shared, WarpLayout>;
Storer storer;

auto kernel =
&g2s_data_transfer<Element, Global, Shared, Loader, Storer, kRepeat>;

static const int kThreads = WarpLayout::kNumel * 32;
int shm_size = Shared::kNumel * sizeof(Element);

if (shm_size > 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
}

dim3 grids(1, 1, 1);
dim3 blocks(kThreads);

for (int i = 0; i < warmup; ++i) // warm up
kernel<<<grids, blocks, shm_size>>>(src, dst, loader, storer);
cudaDeviceSynchronize();

CudaTimer timer;
timer.start();
for (int i = 0; i < iters; ++i)
kernel<<<grids, blocks, shm_size>>>(src, dst, loader, storer);
cudaDeviceSynchronize();
return timer.stop() / iters;
}

/// Times the cutlass/cute global->shared->global round-trip kernel.
///
/// Launches the kernel on a single thread block of WarpLayout::kNumel warps,
/// first `warmup` times untimed and then `iters` times under a CudaTimer.
///
/// @param src source tile in device memory.
/// @param dst destination tile in device memory.
/// @return the value of CudaTimer::stop() divided by `iters`.
template <typename Element, typename Layout, typename WarpLayout,
          const int kRepeat>
float test_cutlass(const Element* src, Element* dst) {
    auto kernel = &cutlass_g2s_data_transfer<Element,            //
                                             Layout::kRows,      //
                                             Layout::kCols,      //
                                             WarpLayout::kRows,  //
                                             WarpLayout::kCols,  //
                                             kRepeat>;

    int shm_size = Layout::kNumel * sizeof(Element);
    int threads_per_block = WarpLayout::kNumel * 32;

    // Dynamic shared memory above 48 KB has to be requested explicitly.
    const int default_smem_limit = 48 * 1024;
    if (shm_size > default_smem_limit) {
        cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
    }

    dim3 grids(1, 1, 1);
    dim3 blocks(threads_per_block);

    // Untimed warm-up launches.
    for (int n = 0; n < warmup; ++n) {
        kernel<<<grids, blocks, shm_size>>>(src, dst);
    }
    cudaDeviceSynchronize();

    // Timed launches.
    CudaTimer timer;
    timer.start();
    for (int n = 0; n < iters; ++n) {
        kernel<<<grids, blocks, shm_size>>>(src, dst);
    }
    cudaDeviceSynchronize();

    const float elapsed = timer.stop();
    return elapsed / iters;
}

/// Runs both implementations on one tile shape / warp layout combination,
/// compares their outputs, and prints one Markdown table row with the two
/// timings and their ratio (tilefusion / cutlass).
///
/// If the outputs differ, "Test failed" is written to std::cerr and no row is
/// printed.
template <typename Element, typename Layout, typename WarpLayout>
void run_test_rowmajor() {
    const int numel = Layout::kNumel;

    // Source data: 0, 1, ..., 2047, repeating.
    thrust::host_vector<Element> h_src(numel);
    for (int idx = 0; idx < numel; ++idx) {
        h_src[idx] = static_cast<Element>(idx % 2048);
    }

    thrust::device_vector<Element> d_src = h_src;
    const Element* src = thrust::raw_pointer_cast(d_src.data());

    // One zero-filled destination buffer per implementation.
    thrust::device_vector<Element> d_dst1(numel);
    thrust::fill(d_dst1.begin(), d_dst1.end(), static_cast<Element>(0.));
    Element* dst1 = thrust::raw_pointer_cast(d_dst1.data());

    thrust::device_vector<Element> d_dst2(numel);
    thrust::fill(d_dst2.begin(), d_dst2.end(), static_cast<Element>(0.));
    Element* dst2 = thrust::raw_pointer_cast(d_dst2.data());

    const float tilefusion_ms =
        test_tilefusion<Element, Layout, WarpLayout, kRepeat>(src, dst1);
    thrust::host_vector<Element> h_dst1 = d_dst1;

    const float cutlass_ms =
        test_cutlass<Element, Layout, WarpLayout, kRepeat>(src, dst2);
    thrust::host_vector<Element> h_dst2 = d_dst2;

    const Element* result1 = thrust::raw_pointer_cast(h_dst1.data());
    const Element* result2 = thrust::raw_pointer_cast(h_dst2.data());
    if (!check_results(result1, result2, numel)) {
        std::cerr << "Test failed" << std::endl;
        return;
    }

    std::cout << "|RowMajor(" << Layout::kRows << ", " << Layout::kCols
              << ")|(";
    std::cout << WarpLayout::kRows << ", " << WarpLayout::kCols << ")|";
    std::cout << tilefusion_ms << "|" << cutlass_ms << "|"
              << tilefusion_ms / cutlass_ms << "|" << std::endl;
}

int main() {
std::cout << std::setprecision(4)
<< "|Shape|Warp Layout|tilefusion(ms)|cutlass(ms)|Ratio|"
<< std::endl
<< "|:---|:---:|:---:|:---:|:---:|" << std::endl;

using DType = __half;

run_test_rowmajor<DType, tl::RowMajor<64, 64>, tl::RowMajor<1, 1>>();
run_test_rowmajor<DType, tl::RowMajor<64, 64>, tl::RowMajor<2, 2>>();
run_test_rowmajor<DType, tl::RowMajor<64, 64>, tl::RowMajor<2, 4>>();

run_test_rowmajor<DType, tl::RowMajor<128, 128>, tl::RowMajor<1, 1>>();
run_test_rowmajor<DType, tl::RowMajor<128, 128>, tl::RowMajor<2, 2>>();
run_test_rowmajor<DType, tl::RowMajor<128, 128>, tl::RowMajor<2, 4>>();

run_test_rowmajor<DType, tl::RowMajor<128, 256>, tl::RowMajor<1, 1>>();
run_test_rowmajor<DType, tl::RowMajor<128, 256>, tl::RowMajor<2, 2>>();
run_test_rowmajor<DType, tl::RowMajor<128, 256>, tl::RowMajor<2, 4>>();

return 0;
}
Loading