feat(kernels): Add dispatch macros. (microsoft#45)
* add dispatch macros.

* move unittests to the correct directory.
lcy-seso authored Jan 22, 2025
1 parent 5bb8f7b commit f154e62
Showing 10 changed files with 103 additions and 72 deletions.
16 changes: 13 additions & 3 deletions .clang-format
@@ -1,12 +1,19 @@
# Run manually to reformat a file:
# clang-format -i --style=file <file>
BasedOnStyle: Google
UseTab: Never
ColumnLimit: 80
IndentWidth: 4

AccessModifierOffset: -2

DerivePointerAlignment: false
# If true, empty lines at the start of blocks are kept.
KeepEmptyLinesAtTheStartOfBlocks: false
PointerAlignment: Left

AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash

SortIncludes: true
IncludeBlocks: Regroup
IncludeCategories:
@@ -18,7 +25,10 @@ IncludeCategories:
Priority: 2
- Regex: '"([A-Za-z0-9.\Q/-_\E])+"'
Priority: 1


# If true, empty lines at the start of blocks are kept.
KeepEmptyLinesAtTheStartOfBlocks: false

AllowShortLoopsOnASingleLine: true
AllowShortIfStatementsOnASingleLine: true
Cpp11BracedListStyle: true
16 changes: 8 additions & 8 deletions include/config.hpp
@@ -4,19 +4,19 @@
#pragma once

#if defined(__CUDA_ARCH__)
#define HOST_DEVICE __forceinline__ __host__ __device__
#define DEVICE __forceinline__ __device__
#define HOST __forceinline__ __host__
#define HOST_DEVICE __forceinline__ __host__ __device__
#define DEVICE __forceinline__ __device__
#define HOST __forceinline__ __host__
#else
#define HOST_DEVICE inline
#define DEVICE inline
#define HOST inline
#define HOST_DEVICE inline
#define DEVICE inline
#define HOST inline
#endif

#if defined(__CUDACC__)
#define WARP_SIZE 32
#define WARP_SIZE 32
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define CP_ASYNC_SM80_ENABLED
#define CP_ASYNC_SM80_ENABLED
#endif
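
For context (not part of the diff): these qualifiers let header-only helpers compile both under nvcc and a plain host compiler. A minimal illustrative use, assuming only config.hpp is included; the helper name is made up for this sketch:

#include "config.hpp"

// Illustrative helper: usable in host and device code when compiled as
// CUDA, and a plain inline function otherwise.
HOST_DEVICE int ceil_div(int a, int b) { return (a + b - 1) / b; }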
24 changes: 24 additions & 0 deletions include/kernels/dispatch_macros.hpp
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#define DISPATCH_TYPE_CASE(TYPE, NV_TYPE, ...) \
case TYPE: { \
using scalar_t = NV_TYPE; \
return __VA_ARGS__(); \
}

#define TILEFUSION_DISPATCH_ALL_TYPES(TYPE, ...) \
c10::ScalarType _type = TYPE; \
[&] { \
switch (_type) { \
DISPATCH_TYPE_CASE(c10::ScalarType::Float, float, __VA_ARGS__) \
DISPATCH_TYPE_CASE(c10::ScalarType::Half, __half, __VA_ARGS__) \
DISPATCH_TYPE_CASE(c10::ScalarType::BFloat16, __bfloat16, \
__VA_ARGS__) \
default: \
AT_ERROR("Dispatch is not implemented for type: '", \
toString(_type), "'"); \
} \
}();
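
For context (this snippet is not part of the commit): the macro is invoked with a runtime dtype and a lambda; inside the lambda, scalar_t is bound to the matching CUDA type, so one templated kernel can serve float, __half, and __bfloat16. A minimal sketch, compiled as a .cu file, under these assumptions: copy_kernel and copy_tensor are illustrative names, and traits/base.hpp is assumed to provide the __half/__bfloat16 types, as it does for scatter_nd.cu below.

#include "kernels/dispatch_macros.hpp"
#include "traits/base.hpp"  // assumed to define __half / __bfloat16

#include <torch/script.h>

namespace tilefusion::kernels {

// Illustrative element-wise copy kernel, templated on the element type.
template <typename T>
__global__ void copy_kernel(const T* src, T* dst, size_t n) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) dst[i] = src[i];
}

// Illustrative host-side launcher: dispatches on the tensor's runtime dtype.
void copy_tensor(torch::Tensor& dst, const torch::Tensor& src) {
    int64_t block = 256;
    int64_t grid = (dst.numel() + block - 1) / block;
    TILEFUSION_DISPATCH_ALL_TYPES(dst.scalar_type(), [&] {
        copy_kernel<scalar_t><<<grid, block>>>(
            reinterpret_cast<const scalar_t*>(src.const_data_ptr()),
            reinterpret_cast<scalar_t*>(dst.mutable_data_ptr()), dst.numel());
    });
}

}  // namespace tilefusion::kernels

Dtypes without a case fall through to the AT_ERROR default at runtime, in contrast to the previous hand-written if/else over kFloat32 and kHalf in scatter_nd.cu, which silently did nothing for other dtypes.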
7 changes: 1 addition & 6 deletions include/kernels/scatter_nd.hpp
@@ -4,11 +4,10 @@
#pragma once

#include "cuda_utils.hpp"
#include "dispatch_macros.hpp"

#include <torch/script.h>

#include <cstdint>

namespace tilefusion::kernels {

// reference:
@@ -35,10 +34,6 @@ __global__ void scatter_nd_kernel(const T* in, T* out, const int64_t* indices,
unsigned int const* __restrict__ strides,
size_t n, size_t rank, size_t slice_size);

template <typename T>
void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices);

void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices);

2 changes: 1 addition & 1 deletion scripts/unittests/run_all_cpp_tests.sh
@@ -19,7 +19,7 @@ fi

cd $BUILD_DIR

for file in $(find "$TESTS_DIR/cell/" -name "test_*.cu"); do
for file in $(find "$TESTS_DIR/" -name "test_*.cu"); do
test_name=$(basename $file .cu)
echo "Running test: $test_name"
ctest -R $test_name
4 changes: 4 additions & 0 deletions src/CMakeLists.txt
@@ -20,6 +20,10 @@ set_target_properties(
CUDA_RESOLVE_DEVICE_SYMBOLS ON
CUDA_SEPARABLE_COMPILATION ON)

# Refer to this issue for more context:
# https://github.com/pytorch/pytorch/issues/13541
target_compile_definitions(${TARGET} PUBLIC _GLIBCXX_USE_CXX11_ABI=0)

target_compile_options(
${TARGET}
PUBLIC $<$<COMPILE_LANGUAGE:CUDA>:
51 changes: 21 additions & 30 deletions src/kernels/scatter_nd.cu
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "kernels/scatter_nd.hpp"
#include "traits/base.hpp"

#include <torch/script.h>

@@ -52,8 +53,7 @@ __global__ void scatter_nd_kernel(const T* in, T* out, const int64_t* indices,
}
}

template <typename T>
void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices) {
auto data_dims = data.sizes();
auto update_dims = updates.sizes();
@@ -88,40 +88,31 @@ void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,

size_t data_size = data.numel();

// #ifdef DEBUG
// for (int i = rank - 1; i >= 0; --i) {
// std::cout << "strides[" << i << "]: " << strides[i] << std::endl;
// }
// for (int i = rank - 1; i >= 0; --i) {
// std::cout << "data_dims[" << i << "]: " << data_dims[i] <<
// std::endl;
// }
// std::cout << "k: " << k << ", rank: " << rank << std::endl;
// std::cout << "n: " << n << ", slice_size: " << slice_size <<
// std::endl; std::cout << "data_size: " << data_size << std::endl;
// #endif
#ifdef DEBUG
for (int i = rank - 1; i >= 0; --i) {
std::cout << "strides[" << i << "]: " << strides[i] << std::endl;
}
for (int i = rank - 1; i >= 0; --i) {
std::cout << "data_dims[" << i << "]: " << data_dims[i] << std::endl;
}
std::cout << "k: " << k << ", rank: " << rank << std::endl;
std::cout << "n: " << n << ", slice_size: " << slice_size << std::endl;
std::cout << "data_size: " << data_size << std::endl;
#endif

// TODO: Add some assertion checks.

int64_t block = 256;
int64_t grid = (n + block - 1) / block;

scatter_nd_kernel<<<grid, block>>>(
reinterpret_cast<const T*>(indices.const_data_ptr()),
reinterpret_cast<T*>(data.mutable_data_ptr()),
reinterpret_cast<const int64_t*>(indices.const_data_ptr()),
reinterpret_cast<const unsigned int*>(device_strides), n, k,
slice_size);
}

void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices) {
auto dtype = data.dtype();
if (dtype == torch::kFloat32) {
scatter_nd<float>(data, updates, indices);
} else if (dtype == torch::kHalf) {
scatter_nd<__half>(data, updates, indices);
}
TILEFUSION_DISPATCH_ALL_TYPES(data.scalar_type(), [&] {
scatter_nd_kernel<<<grid, block>>>(
reinterpret_cast<const scalar_t*>(indices.const_data_ptr()),
reinterpret_cast<scalar_t*>(data.mutable_data_ptr()),
reinterpret_cast<const int64_t*>(indices.const_data_ptr()),
reinterpret_cast<const unsigned int*>(device_strides), n, k,
slice_size);
});
}

} // namespace tilefusion::kernels
File renamed without changes.
File renamed without changes.
55 changes: 31 additions & 24 deletions tests/python/test_scatter_nd.py
@@ -11,7 +11,7 @@
from pytilefusion import scatter_nd


class TestGemm(unittest.TestCase):
class TestScatterNd(unittest.TestCase):

def _compute_output_shape(self, index_dims, input_dims):
end_size = index_dims[-1]
@@ -24,38 +24,45 @@ def setUp(self):
torch.manual_seed(1234)

def test_scatter_nd(self):
data_shape = [7, 8, 9, 10]
data = torch.empty(data_shape, dtype=torch.float32,
device='cuda').fill_(5.0)
scatter_data = data.flatten()

indices_shape = [5, 2]
indices = torch.empty(indices_shape, dtype=torch.int64, device='cuda')
for dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
]:
data_shape = [7, 8, 9, 10]
data = torch.empty(data_shape, dtype=dtype,
device='cuda').fill_(5.0)
scatter_data = data.flatten()

for i in range(indices_shape[0]):
indices[i][0] = random.randint(0, data_shape[0] - 1)
indices[i][1] = random.randint(0, data_shape[1] - 1)
indices_shape = [5, 2]
indices = torch.empty(
indices_shape, dtype=torch.int64, device='cuda'
)

scatter_indices = indices.flatten()
for i in range(indices_shape[0]):
indices[i][0] = random.randint(0, data_shape[0] - 1)
indices[i][1] = random.randint(0, data_shape[1] - 1)

update_shape = self._compute_output_shape(indices_shape, data_shape)
updates = torch.empty(update_shape, dtype=torch.float32,
device='cuda').fill_(10.0)
scatter_updates = updates.flatten()
scatter_indices = indices.flatten()

# import pytilefusion
scatter_nd(scatter_data, scatter_indices, scatter_updates)
update_shape = self._compute_output_shape(indices_shape, data_shape)
updates = torch.empty(update_shape, dtype=dtype,
device='cuda').fill_(10.0)
scatter_updates = updates.flatten()

# Implement `scatter_nd` in Python.
data[indices[:, 0], indices[:, 1]] = updates
scatter_nd(scatter_data, scatter_indices, scatter_updates)

flattened_data = data.flatten()
# Implement `scatter_nd` in Python.
data[indices[:, 0], indices[:, 1]] = updates

# Print data
print(scatter_data)
print(flattened_data)
flattened_data = data.flatten()

assert torch.allclose(scatter_data, flattened_data)
# Print data
print(scatter_data)
print(flattened_data)

assert torch.allclose(scatter_data, flattened_data)


if __name__ == "__main__":
