Commit bba5c48
[RMSNorm][Half] support fp16x8 packed RMSNorm (xlite-dev#22)
* Update elementwise.cu
* Update relu.cu
* Update elementwise.cu
* Update rms_norm.cu
* Update rms_norm.py
* Update README.md
* Update README.md
1 parent ff05f78 · commit bba5c48

File tree: 6 files changed, +372 -15 lines

README.md (4 additions, 2 deletions)

@@ -66,8 +66,10 @@
 | ✔️ [layer_norm_f16_f32_kernel(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f32x4_kernel(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
-|[rms_norm_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16||⭐️⭐️|
-|[rms_norm_f16x2_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16||⭐️⭐️|
+| ✔️ [rms_norm_f16_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x2_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_f16_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 |[hgemm_sliced_k_f16_f32_kernel](./hgemm)|f16|f32||⭐️⭐️⭐️|

elementwise/elementwise.cu (6 additions, 5 deletions)

@@ -13,6 +13,8 @@
 #define WARP_SIZE 32
 #define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
 
 // -------------------------------------- FP32 --------------------------------------
 // ElementWise Add
@@ -40,7 +42,6 @@ __global__ void elementwise_add_f32x4_kernel(float* a, float* b, float* c, int N
   }
 }
 
-
 // -------------------------------------- FP16 --------------------------------------
 // ElementWise Add
 // grid(N/256), block(256)
@@ -54,12 +55,12 @@ __global__ void elementwise_add_f16_kernel(half* a, half* b, half* c, int N) {
 __global__ void elementwise_add_f16x2_kernel(half* a, half* b, half* c, int N) {
   int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
   if (idx < N) {
-    half2 reg_a = (reinterpret_cast<half2*>(&(a[idx]))[0]);
-    half2 reg_b = (reinterpret_cast<half2*>(&(b[idx]))[0]);
+    half2 reg_a = HALF2(a[idx]);
+    half2 reg_b = HALF2(b[idx]);
     half2 reg_c;
     reg_c.x = __hadd(reg_a.x, reg_b.x);
     reg_c.y = __hadd(reg_a.y, reg_b.y);
-    (reinterpret_cast<half2*>(&(c[idx]))[0]) = reg_c;
+    HALF2(c[idx]) = reg_c;
   }
 }
 
@@ -137,4 +138,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32x4_v2)
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16_v2)
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x2_v2)
-}
+}
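The point of the HALF2 macro is that one 32-bit load or store moves two packed fp16 values, halving the number of memory transactions per element relative to the scalar f16 kernel. Below is a self-contained sketch of the pattern; the host harness is added purely for illustration and is not part of the commit, and it assumes an even N so no half2 access runs past the end of a buffer:

```cuda
#include <cuda_fp16.h>
#include <cstdio>

#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])

// Each thread handles two contiguous fp16 elements via one half2 load/store.
__global__ void elementwise_add_f16x2_kernel(half* a, half* b, half* c, int N) {
  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    half2 reg_a = HALF2(a[idx]);
    half2 reg_b = HALF2(b[idx]);
    half2 reg_c;
    reg_c.x = __hadd(reg_a.x, reg_b.x);
    reg_c.y = __hadd(reg_a.y, reg_b.y);
    HALF2(c[idx]) = reg_c;
  }
}

// Hypothetical harness for illustration only (not from the commit).
int main() {
  const int N = 4096;  // even, so every half2 access stays in bounds
  half *a, *b, *c;
  cudaMallocManaged(&a, N * sizeof(half));
  cudaMallocManaged(&b, N * sizeof(half));
  cudaMallocManaged(&c, N * sizeof(half));
  for (int i = 0; i < N; ++i) { a[i] = __float2half(1.0f); b[i] = __float2half(2.0f); }
  // Each thread covers 2 elements, so launch N/2 threads in total.
  elementwise_add_f16x2_kernel<<<(N / 2 + 255) / 256, 256>>>(a, b, c, N);
  cudaDeviceSynchronize();
  printf("c[0] = %f\n", __half2float(c[0]));  // expect 3.0
  cudaFree(a); cudaFree(b); cudaFree(c);
  return 0;
}
```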

relu/relu.cu (5 additions, 3 deletions)

@@ -11,6 +11,8 @@
 #define WARP_SIZE 32
 #define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
 
 // -------------------------------------- FP32 --------------------------------------
 // Relu x: N, y: N y=max(0,x)
@@ -44,11 +46,11 @@ __global__ void relu_f16_kernel(half* x, half* y, int N) {
 __global__ void relu_f16x2_kernel(half* x, half* y, int N) {
   int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
   if (idx < N) {
-    half2 reg_x = (reinterpret_cast<half2*>(&(x[idx]))[0]);
-    half2 reg_y = (reinterpret_cast<half2*>(&(y[idx]))[0]);
+    half2 reg_x = HALF2(x[idx]);
+    half2 reg_y = HALF2(y[idx]);
    reg_y.x = __hmax(__float2half(0.0f), reg_x.x);
    reg_y.y = __hmax(__float2half(0.0f), reg_x.y);
-    (reinterpret_cast<half2*>(&(y[idx]))[0]) = reg_y;
+    HALF2(y[idx]) = reg_y;
   }
 }
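The same two-lanes-per-thread pattern applies to ReLU. Two details worth noting: the `if (idx < N)` guard checks only the first lane, so the x2 kernels as written assume an even N, and the initial read of `y` is unnecessary since `y` is write-only here. A hedged sketch of a tail-safe variant (my illustration, not code from this commit):

```cuda
#include <cuda_fp16.h>

#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])

// Variant of relu_f16x2_kernel that tolerates odd N: a thread whose pair
// would straddle the end of the buffer falls back to a scalar fp16 path.
__global__ void relu_f16x2_safe_kernel(half* x, half* y, int N) {
  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx + 1 < N) {             // both lanes in bounds: packed path
    half2 reg_x = HALF2(x[idx]);
    half2 reg_y;
    reg_y.x = __hmax(__float2half(0.0f), reg_x.x);
    reg_y.y = __hmax(__float2half(0.0f), reg_x.y);
    HALF2(y[idx]) = reg_y;
  } else if (idx < N) {          // odd N: last element, scalar path
    y[idx] = __hmax(__float2half(0.0f), x[idx]);
  }
}
```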

rms-norm/README.md (15 additions, 5 deletions)

@@ -5,7 +5,11 @@
 Contains the following:
 
 - [X] rms_norm_f32_kernel
-- [X] rms_norm_f32x4_kernel (float4 vectorized version)
+- [X] rms_norm_f32x4_kernel
+- [X] rms_norm_f16_f16_kernel
+- [X] rms_norm_f16x2_f16_kernel
+- [X] rms_norm_f16x8_f16_kernel
+- [X] rms_norm_f16_f32_kernel
 - [X] PyTorch bindings
 
 ## Testing
@@ -20,8 +24,14 @@ python3 rms_norm.py
 
 ```bash
 --------------------------------------------------------------------------------
-out_f32: [0.58164781, 0.37353456, 0.12243245], time:0.01189494ms
-out_f32x4: [0.58164781, 0.37353456, 0.12243245], time:0.00514793ms
-out_f32_th: [0.58165121, 0.37353677, 0.12243317], time:0.04241681ms
+out_f32: [0.66361254, 0.69628561, 0.51101440], time:0.01188707ms
+out_f32x4: [0.66361260, 0.69628561, 0.51101440], time:0.00833464ms
+out_f32_th: [0.66361588, 0.69628906, 0.51101691], time:0.04334593ms
 --------------------------------------------------------------------------------
-```
+out_f16f16: [0.66357422, 0.69580078, 0.51074219], time:0.01201081ms
+out_f16x2f16: [0.66357422, 0.69580078, 0.51074219], time:0.00668955ms
+out_f16x8f16: [0.66650391, 0.69921875, 0.51318359], time:0.00398421ms
+out_f16f32: [0.66357422, 0.69628906, 0.51123047], time:0.01176858ms
+out_f16_th: [0.66357422, 0.69580078, 0.51074219], time:0.04448509ms
+--------------------------------------------------------------------------------
+```
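The changes to rms_norm.cu itself are not reproduced in this view, but the kernel names and benchmark output imply the scheme: one block per token, each thread loading eight fp16 values (16 bytes, as four half2 packs), accumulating a sum of squares in fp32, block-reducing, then scaling by rsqrtf(mean(x^2) + eps). Below is a sketch of what an f16x8 per-token kernel of that shape might look like; it is a reconstruction under those assumptions, not the commit's actual code, and the scalar scale `g` and `1e-5f` epsilon stand in for whatever the real kernel uses:

```cuda
#include <cuda_fp16.h>

#define WARP_SIZE 32
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])

__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask);
  return val;
}

// One block per token (row). Hidden size K; assumes K % 8 == 0 and
// blockDim.x == K / 8. g is a scalar scale (assumed; real RMSNorm often
// uses a per-channel weight vector).
template <const int NUM_THREADS = 256>
__global__ void rms_norm_f16x8_f16_sketch(half* x, half* y, float g, int K) {
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  __shared__ float warp_sums[NUM_WARPS];
  __shared__ float s_inv_rms;           // broadcast slot for rsqrt factor

  int tid = threadIdx.x;
  int idx = blockIdx.x * K + tid * 8;   // first of this thread's 8 halves

  // Load 8 fp16 values as four half2 packs; accumulate sum of squares in fp32.
  half2 packs[4];
  float sum_sq = 0.0f;
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    packs[i] = HALF2(x[idx + 2 * i]);
    float vx = __half2float(packs[i].x);
    float vy = __half2float(packs[i].y);
    sum_sq += vx * vx + vy * vy;
  }

  // Block-wide reduction: intra-warp shuffles, then one warp combines partials.
  sum_sq = warp_reduce_sum_f32(sum_sq);
  if ((tid % WARP_SIZE) == 0) warp_sums[tid / WARP_SIZE] = sum_sq;
  __syncthreads();
  if (tid < WARP_SIZE) {
    float v = (tid < NUM_WARPS) ? warp_sums[tid] : 0.0f;
    v = warp_reduce_sum_f32(v);
    if (tid == 0) s_inv_rms = rsqrtf(v / (float)K + 1e-5f);  // eps assumed
  }
  __syncthreads();

  // Scale each pack by rsqrt(mean(x^2) + eps) * g and store packed.
  half2 scale = __float2half2_rn(s_inv_rms * g);
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    HALF2(y[idx + 2 * i]) = __hmul2(packs[i], scale);
  }
}
```

Keeping the accumulation in fp32 while storing fp16 matches the benchmark's behavior, where out_f16x8f16 is the fastest variant but drifts slightly more from the PyTorch reference than the f16/f16x2 kernels.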
