
Commit f154e62

feat(kernels): Add dispatch macros. (microsoft#45)
* add dispatch macros.
* move unittests to the correct directory.
1 parent 5bb8f7b · commit f154e62

File tree

10 files changed: +103, -72 lines changed

.clang-format

Lines changed: 13 additions & 3 deletions
@@ -1,12 +1,19 @@
 # Run manually to reformat a file:
 # clang-format -i --style=file <file>
 BasedOnStyle: Google
+UseTab: Never
 ColumnLimit: 80
 IndentWidth: 4
+
 AccessModifierOffset: -2
+
 DerivePointerAlignment: false
-# If true, empty lines at the start of blocks are kept.
-KeepEmptyLinesAtTheStartOfBlocks: false
+PointerAlignment: Left
+
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+IndentPPDirectives: BeforeHash
+
 SortIncludes: true
 IncludeBlocks: Regroup
 IncludeCategories:
@@ -18,7 +25,10 @@ IncludeCategories:
    Priority: 2
  - Regex: '"([A-Za-z0-9.\Q/-_\E])+"'
    Priority: 1
-
+
+
+# If true, empty lines at the start of blocks are kept.
+KeepEmptyLinesAtTheStartOfBlocks: false
+
 AllowShortLoopsOnASingleLine: true
 AllowShortIfStatementsOnASingleLine: true
 Cpp11BracedListStyle: true

include/config.hpp

Lines changed: 8 additions & 8 deletions
@@ -4,19 +4,19 @@
 #pragma once
 
 #if defined(__CUDA_ARCH__)
-#define HOST_DEVICE __forceinline__ __host__ __device__
-#define DEVICE __forceinline__ __device__
-#define HOST __forceinline__ __host__
+    #define HOST_DEVICE __forceinline__ __host__ __device__
+    #define DEVICE __forceinline__ __device__
+    #define HOST __forceinline__ __host__
 #else
-#define HOST_DEVICE inline
-#define DEVICE inline
-#define HOST inline
+    #define HOST_DEVICE inline
+    #define DEVICE inline
+    #define HOST inline
 #endif
 
 #if defined(__CUDACC__)
-#define WARP_SIZE 32
+    #define WARP_SIZE 32
 #endif
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
-#define CP_ASYNC_SM80_ENABLED
+    #define CP_ASYNC_SM80_ENABLED
 #endif

include/kernels/dispatch_macros.hpp

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#define DISPATCH_TYPE_CASE(TYPE, NV_TYPE, ...) \
+    case TYPE: {                               \
+        using scalar_t = NV_TYPE;              \
+        return __VA_ARGS__();                  \
+    }
+
+#define TILEFUSION_DISPATCH_ALL_TYPES(TYPE, ...)                            \
+    c10::ScalarType _type = TYPE;                                           \
+    [&] {                                                                   \
+        switch (_type) {                                                    \
+            DISPATCH_TYPE_CASE(c10::ScalarType::Float, float, __VA_ARGS__)  \
+            DISPATCH_TYPE_CASE(c10::ScalarType::Half, __half, __VA_ARGS__)  \
+            DISPATCH_TYPE_CASE(c10::ScalarType::BFloat16, __bfloat16,       \
+                               __VA_ARGS__)                                 \
+            default:                                                        \
+                AT_ERROR("Dispatch is not implemented for type: '",         \
+                         toString(_type), "'");                             \
+        }                                                                   \
+    }();
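
For context, the macro is consumed the way the src/kernels/scatter_nd.cu change below shows: the caller passes a runtime c10::ScalarType and a lambda, and the lambda body is pasted into each switch case where scalar_t is aliased to the matching CUDA type. A minimal sketch of that pattern follows; fill_kernel and fill_op are hypothetical names used only for illustration, and traits/base.hpp is assumed to supply the half/bfloat16 type aliases, as it does for scatter_nd.cu.

#include "kernels/dispatch_macros.hpp"
#include "traits/base.hpp"  // assumed to provide the __half/__bfloat16 aliases

#include <torch/script.h>

// Hypothetical element-wise kernel, used only to illustrate the dispatch macro.
template <typename T>
__global__ void fill_kernel(T* out, size_t n, float value) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) out[idx] = static_cast<T>(value);
}

// Hypothetical host-side wrapper: the macro switches on the runtime dtype and
// expands the lambda body once per supported type, with `scalar_t` bound to
// float, __half, or the bfloat16 alias; unsupported dtypes hit AT_ERROR.
void fill_op(torch::Tensor& data, float value) {
    size_t n = data.numel();
    int64_t block = 256;
    int64_t grid = (n + block - 1) / block;

    TILEFUSION_DISPATCH_ALL_TYPES(data.scalar_type(), [&] {
        fill_kernel<scalar_t><<<grid, block>>>(
            reinterpret_cast<scalar_t*>(data.mutable_data_ptr()), n, value);
    });
}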

include/kernels/scatter_nd.hpp

Lines changed: 1 addition & 6 deletions
@@ -4,11 +4,10 @@
 #pragma once
 
 #include "cuda_utils.hpp"
+#include "dispatch_macros.hpp"
 
 #include <torch/script.h>
 
-#include <cstdint>
-
 namespace tilefusion::kernels {
 
 // reference:
@@ -35,10 +34,6 @@ __global__ void scatter_nd_kernel(const T* in, T* out, const int64_t* indices,
                                   unsigned int const* __restrict__ strides,
                                   size_t n, size_t rank, size_t slice_size);
 
-template <typename T>
-void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
-                const torch::Tensor& indices);
-
 void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
                 const torch::Tensor& indices);
 

scripts/unittests/run_all_cpp_tests.sh

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ fi
 
 cd $BUILD_DIR
 
-for file in $(find "$TESTS_DIR/cell/" -name "test_*.cu"); do
+for file in $(find "$TESTS_DIR/" -name "test_*.cu"); do
    test_name=$(basename $file .cu)
    echo "Running test: $test_name"
    ctest -R $test_name

src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ set_target_properties(
    CUDA_RESOLVE_DEVICE_SYMBOLS ON
    CUDA_SEPARABLE_COMPILATION ON)
 
+# Refer to this issue for more context:
+# https://github.com/pytorch/pytorch/issues/13541
+target_compile_definitions(${TARGET} PUBLIC _GLIBCXX_USE_CXX11_ABI=0)
+
 target_compile_options(
    ${TARGET}
    PUBLIC $<$<COMPILE_LANGUAGE:CUDA>:
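
The _GLIBCXX_USE_CXX11_ABI=0 definition keeps the extension's libstdc++ ABI in line with PyTorch builds that use the pre-C++11 ABI (see the linked issue); a mismatch typically surfaces as undefined-symbol errors involving std::string when the extension is loaded. A quick way to see which ABI a translation unit compiles with is to print the macro, as in this small standalone sketch (GCC/libstdc++ only; not part of the commit):

#include <iostream>

int main() {
    // Defined by the libstdc++ headers: 0 selects the old std::string/std::list
    // ABI, 1 selects the C++11 ABI. It must match the value PyTorch was built with.
    std::cout << "_GLIBCXX_USE_CXX11_ABI = " << _GLIBCXX_USE_CXX11_ABI << "\n";
    return 0;
}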

src/kernels/scatter_nd.cu

Lines changed: 21 additions & 30 deletions
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "kernels/scatter_nd.hpp"
+#include "traits/base.hpp"
 
 #include <torch/script.h>
 
@@ -52,8 +53,7 @@ __global__ void scatter_nd_kernel(const T* in, T* out, const int64_t* indices,
    }
 }
 
-template <typename T>
-void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
+void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
                const torch::Tensor& indices) {
    auto data_dims = data.sizes();
    auto update_dims = updates.sizes();
@@ -88,40 +88,31 @@ void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
 
    size_t data_size = data.numel();
 
-    // #ifdef DEBUG
-    //     for (int i = rank - 1; i >= 0; --i) {
-    //         std::cout << "strides[" << i << "]: " << strides[i] << std::endl;
-    //     }
-    //     for (int i = rank - 1; i >= 0; --i) {
-    //         std::cout << "data_dims[" << i << "]: " << data_dims[i] <<
-    //         std::endl;
-    //     }
-    //     std::cout << "k: " << k << ", rank: " << rank << std::endl;
-    //     std::cout << "n: " << n << ", slice_size: " << slice_size <<
-    //     std::endl; std::cout << "data_size: " << data_size << std::endl;
-    // #endif
+#ifdef DEBUG
+    for (int i = rank - 1; i >= 0; --i) {
+        std::cout << "strides[" << i << "]: " << strides[i] << std::endl;
+    }
+    for (int i = rank - 1; i >= 0; --i) {
+        std::cout << "data_dims[" << i << "]: " << data_dims[i] << std::endl;
+    }
+    std::cout << "k: " << k << ", rank: " << rank << std::endl;
+    std::cout << "n: " << n << ", slice_size: " << slice_size << std::endl;
+    std::cout << "data_size: " << data_size << std::endl;
+#endif
 
    // TODO: Add some assertion checks.
 
    int64_t block = 256;
    int64_t grid = (n + block - 1) / block;
 
-    scatter_nd_kernel<<<grid, block>>>(
-        reinterpret_cast<const T*>(indices.const_data_ptr()),
-        reinterpret_cast<T*>(data.mutable_data_ptr()),
-        reinterpret_cast<const int64_t*>(indices.const_data_ptr()),
-        reinterpret_cast<const unsigned int*>(device_strides), n, k,
-        slice_size);
-}
-
-void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
-                const torch::Tensor& indices) {
-    auto dtype = data.dtype();
-    if (dtype == torch::kFloat32) {
-        scatter_nd<float>(data, updates, indices);
-    } else if (dtype == torch::kHalf) {
-        scatter_nd<__half>(data, updates, indices);
-    }
+    TILEFUSION_DISPATCH_ALL_TYPES(data.scalar_type(), [&] {
+        scatter_nd_kernel<<<grid, block>>>(
+            reinterpret_cast<const scalar_t*>(indices.const_data_ptr()),
+            reinterpret_cast<scalar_t*>(data.mutable_data_ptr()),
+            reinterpret_cast<const int64_t*>(indices.const_data_ptr()),
+            reinterpret_cast<const unsigned int*>(device_strides), n, k,
+            slice_size);
+    });
 }
 
 }  // namespace tilefusion::kernels
File renamed without changes.

tests/python/test_scatter_nd.py

Lines changed: 31 additions & 24 deletions
@@ -11,7 +11,7 @@
 from pytilefusion import scatter_nd
 
 
-class TestGemm(unittest.TestCase):
+class TestScatterNd(unittest.TestCase):
 
    def _compute_output_shape(self, index_dims, input_dims):
        end_size = index_dims[-1]
@@ -24,38 +24,45 @@ def setUp(self):
        torch.manual_seed(1234)
 
    def test_scatter_nd(self):
-        data_shape = [7, 8, 9, 10]
-        data = torch.empty(data_shape, dtype=torch.float32,
-                           device='cuda').fill_(5.0)
-        scatter_data = data.flatten()
 
-        indices_shape = [5, 2]
-        indices = torch.empty(indices_shape, dtype=torch.int64, device='cuda')
+        for dtype in [
+            torch.float32,
+            torch.float16,
+            torch.bfloat16,
+        ]:
+            data_shape = [7, 8, 9, 10]
+            data = torch.empty(data_shape, dtype=dtype,
+                               device='cuda').fill_(5.0)
+            scatter_data = data.flatten()
 
-        for i in range(indices_shape[0]):
-            indices[i][0] = random.randint(0, data_shape[0] - 1)
-            indices[i][1] = random.randint(0, data_shape[1] - 1)
+            indices_shape = [5, 2]
+            indices = torch.empty(
+                indices_shape, dtype=torch.int64, device='cuda'
+            )
 
-        scatter_indices = indices.flatten()
+            for i in range(indices_shape[0]):
+                indices[i][0] = random.randint(0, data_shape[0] - 1)
+                indices[i][1] = random.randint(0, data_shape[1] - 1)
 
-        update_shape = self._compute_output_shape(indices_shape, data_shape)
-        updates = torch.empty(update_shape, dtype=torch.float32,
-                              device='cuda').fill_(10.0)
-        scatter_updates = updates.flatten()
+            scatter_indices = indices.flatten()
 
-        # import pytilefusion
-        scatter_nd(scatter_data, scatter_indices, scatter_updates)
+            update_shape = self._compute_output_shape(indices_shape, data_shape)
+            updates = torch.empty(update_shape, dtype=dtype,
+                                  device='cuda').fill_(10.0)
+            scatter_updates = updates.flatten()
 
-        # Implement `scatter_nd` in Python.
-        data[indices[:, 0], indices[:, 1]] = updates
+            scatter_nd(scatter_data, scatter_indices, scatter_updates)
 
-        flattened_data = data.flatten()
+            # Implement `scatter_nd` in Python.
+            data[indices[:, 0], indices[:, 1]] = updates
 
-        # Print data
-        print(scatter_data)
-        print(flattened_data)
+            flattened_data = data.flatten()
 
-        assert torch.allclose(scatter_data, flattened_data)
+            # Print data
+            print(scatter_data)
+            print(flattened_data)
+
+            assert torch.allclose(scatter_data, flattened_data)
 
 
 if __name__ == "__main__":
