Skip to content

Commit 7bfebb9

Browse files
authored
[Bugfix][Kernel] fixed some kernel block-count calculation errors (xlite-dev#23)
* Update elementwise.cu * Update relu.cu * Update elementwise.cu * Update rms_norm.cu * Update rms_norm.py * Update README.md * Update README.md * Update sigmoid.cu * Update relu.cu * Update softmax.cu * Update softmax.py * Update block_all_reduce.cu * Update dot_product.cu * Update elementwise.cu * Update histogram.cu
1 parent bba5c48 commit 7bfebb9

File tree

8 files changed

+26
-22
lines changed

8 files changed

+26
-22
lines changed

dot-product/dot_product.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#define WARP_SIZE 32
1414
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
1515
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
16+
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
17+
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
1618

1719
// -------------------------------------- FP32 --------------------------------------
1820
// Warp Reduce Sum
@@ -129,8 +131,8 @@ __global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
129131
__shared__ float reduce_smem[NUM_WARPS];
130132

131133
// keeping the data in registers is enough for the warp operation.
132-
half2 reg_a = (reinterpret_cast<half2*>(&(a[idx]))[0]);
133-
half2 reg_b = (reinterpret_cast<half2*>(&(b[idx]))[0]);
134+
half2 reg_a = HALF2(a[idx]);
135+
half2 reg_b = HALF2(b[idx]);
134136
half prod_f16 = (idx < N) ? __hadd(__hmul(reg_a.x, reg_b.x),
135137
__hmul(reg_a.y, reg_b.y)) : __float2half(0.0f);
136138
int warp = tid / WARP_SIZE;
@@ -170,7 +172,7 @@ torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor
170172
const int N = a.size(0); \
171173
CHECK_TORCH_TENSOR_SHAPE(b, N) \
172174
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
173-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK; \
175+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
174176
dim3 block(NUM_THREADS_PER_BLOCK); \
175177
dim3 grid(NUM_BLOCKS); \
176178
dot_prod_##packed_type##_##acc_type##_kernel< \

elementwise/elementwise.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ torch::Tensor elementwise_add_##packed_type(torch::Tensor a, torch::Tensor b) {
8989
CHECK_TORCH_TENSOR_SHAPE(b, N) \
9090
auto c = torch::zeros({N}, options); \
9191
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
92-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
92+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
9393
dim3 block(NUM_THREADS_PER_BLOCK); \
9494
dim3 grid(NUM_BLOCKS); \
9595
elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
@@ -109,7 +109,7 @@ void elementwise_add_##packed_type##_v2(
109109
CHECK_TORCH_TENSOR_SHAPE(b, N) \
110110
CHECK_TORCH_TENSOR_SHAPE(c, N) \
111111
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
112-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
112+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
113113
dim3 block(NUM_THREADS_PER_BLOCK); \
114114
dim3 grid(NUM_BLOCKS); \
115115
elementwise_add_##packed_type##_kernel<<<grid, block>>>( \

histogram/histogram.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ torch::Tensor histogram_##packed_type(torch::Tensor a) {
5959
const int M = max_val.item().to<int>(); \
6060
auto y = torch::zeros({M+1}, options); \
6161
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
62-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
62+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
6363
dim3 block(NUM_THREADS_PER_BLOCK); \
6464
dim3 grid(NUM_BLOCKS); \
6565
histogram_##packed_type##_kernel<<<grid, block>>>( \
@@ -74,4 +74,4 @@ TORCH_BINDING_HIST(i32x4, torch::kInt32, int, 4)
7474
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7575
TORCH_BINDING_COMMON_EXTENSION(histogram_i32)
7676
TORCH_BINDING_COMMON_EXTENSION(histogram_i32x4)
77-
}
77+
}

reduce/block_all_reduce.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ torch::Tensor block_all_reduce_sum_##packed_type##_##acc_type(torch::Tensor a) {
477477
auto sum = torch::zeros({1}, options); \
478478
const int N = a.size(0); \
479479
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
480-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK; \
480+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
481481
dim3 block(NUM_THREADS_PER_BLOCK); \
482482
dim3 grid(NUM_BLOCKS); \
483483
block_all_reduce_sum_##packed_type##_##acc_type##_kernel< \
@@ -494,7 +494,7 @@ torch::Tensor block_all_reduce_sum_##packed_type##_##acc_type(torch::Tensor a) {
494494
auto sum = torch::zeros({1}, options); \
495495
const int N = a.size(0); \
496496
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
497-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK; \
497+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
498498
dim3 block(NUM_THREADS_PER_BLOCK); \
499499
dim3 grid(NUM_BLOCKS); \
500500
block_all_reduce_sum_##packed_type##_##acc_type##_kernel< \

relu/relu.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ torch::Tensor relu_##packed_type(torch::Tensor x) {
7676
const int N = x.size(0); \
7777
auto y = torch::zeros({N}, options); \
7878
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
79-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
79+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
8080
dim3 block(NUM_THREADS_PER_BLOCK); \
8181
dim3 grid(NUM_BLOCKS); \
8282
relu_##packed_type##_kernel<<<grid, block>>>( \
@@ -92,18 +92,18 @@ void relu_##packed_type##_v2(torch::Tensor x, torch::Tensor y) {
9292
const int N = x.size(0); \
9393
CHECK_TORCH_TENSOR_SHAPE(y, N) \
9494
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
95-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
95+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
9696
dim3 block(NUM_THREADS_PER_BLOCK); \
9797
dim3 grid(NUM_BLOCKS); \
9898
relu_##packed_type##_kernel<<<grid, block>>>( \
9999
reinterpret_cast<element_type*>(x.data_ptr()), \
100100
reinterpret_cast<element_type*>(y.data_ptr()), N); \
101101
}
102102

103-
TORCH_BINDING_RELU(f32, torch::kFloat32, float, 1)
104-
TORCH_BINDING_RELU(f32x4, torch::kFloat32, float, 4)
105-
TORCH_BINDING_RELU(f16, torch::kHalf, half, 1)
106-
TORCH_BINDING_RELU(f16x2, torch::kHalf, half, 2)
103+
TORCH_BINDING_RELU(f32, torch::kFloat32, float, 1)
104+
TORCH_BINDING_RELU(f32x4, torch::kFloat32, float, 4)
105+
TORCH_BINDING_RELU(f16, torch::kHalf, half, 1)
106+
TORCH_BINDING_RELU(f16x2, torch::kHalf, half, 2)
107107
TORCH_BINDING_RELU_V2(f32, torch::kFloat32, float, 1)
108108
TORCH_BINDING_RELU_V2(f32x4, torch::kFloat32, float, 4)
109109
TORCH_BINDING_RELU_V2(f16, torch::kHalf, half, 1)

sigmoid/sigmoid.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ torch::Tensor sigmoid_##packed_type(torch::Tensor x) {
5656
const int N = x.size(0); \
5757
auto y = torch::zeros({N}, options); \
5858
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
59-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
59+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
6060
dim3 block(NUM_THREADS_PER_BLOCK); \
6161
dim3 grid(NUM_BLOCKS); \
6262
sigmoid_##packed_type##_kernel<<<grid, block>>>( \
@@ -72,16 +72,16 @@ void sigmoid_##packed_type##_v2(torch::Tensor x, torch::Tensor y) {
7272
const int N = x.size(0); \
7373
CHECK_TORCH_TENSOR_SHAPE(y, N) \
7474
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
75-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
75+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
7676
dim3 block(NUM_THREADS_PER_BLOCK); \
7777
dim3 grid(NUM_BLOCKS); \
7878
sigmoid_##packed_type##_kernel<<<grid, block>>>( \
7979
reinterpret_cast<element_type*>(x.data_ptr()), \
8080
reinterpret_cast<element_type*>(y.data_ptr()), N); \
8181
}
8282

83-
TORCH_BINDING_SIGMOID(f32, torch::kFloat32, float, 1)
84-
TORCH_BINDING_SIGMOID(f32x4, torch::kFloat32, float, 4)
83+
TORCH_BINDING_SIGMOID(f32, torch::kFloat32, float, 1)
84+
TORCH_BINDING_SIGMOID(f32x4, torch::kFloat32, float, 4)
8585
TORCH_BINDING_SIGMOID_V2(f32, torch::kFloat32, float, 1)
8686
TORCH_BINDING_SIGMOID_V2(f32x4, torch::kFloat32, float, 4)
8787

softmax/softmax.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ torch::Tensor softmax_##packed_type(torch::Tensor x) {
230230
auto y = torch::zeros({N}, options); \
231231
auto total = torch::zeros({1}, options); \
232232
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
233-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
233+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
234234
dim3 block(NUM_THREADS_PER_BLOCK); \
235235
dim3 grid(NUM_BLOCKS); \
236236
softmax_##packed_type##_kernel<NUM_THREADS_PER_BLOCK><<<grid, block>>>( \
@@ -252,7 +252,7 @@ void softmax_##packed_type##_v2(torch::Tensor x, torch::Tensor y) {
252252
if (y.size(0) != N) {throw std::runtime_error("y size mismatch!"); } \
253253
auto total = torch::zeros({1}, options); \
254254
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
255-
const int NUM_BLOCKS = (N + NUM_THREADS_PER_BLOCK - 1) / NUM_THREADS_PER_BLOCK;\
255+
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
256256
dim3 block(NUM_THREADS_PER_BLOCK); \
257257
dim3 grid(NUM_BLOCKS); \
258258
softmax_##packed_type##_kernel<NUM_THREADS_PER_BLOCK><<<grid, block>>>( \

softmax/softmax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424

2525
def run_benchmark(perf_func: callable, x: torch.Tensor,
2626
tag: str, out: Optional[torch.Tensor] = None,
27-
warmup: int = 10, iters: int = 1000):
27+
warmup: int = 10, iters: int = 1000,
28+
show_all: bool = False):
2829
if out is not None:
2930
out.fill_(0)
3031
if out is not None:
@@ -50,6 +51,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
5051
out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
5152
out_val = [round(v, 8) for v in out_val]
5253
print(f"{out_info:>20}: {out_val}, time:{mean_time:.8f}ms")
54+
if show_all: print(out)
5355
return out, mean_time
5456

5557

0 commit comments

Comments
 (0)