
Commit 0c9166d

[Misc][Benchmark] optimize benchmarks (xlite-dev#53)

1 parent 499c39e

20 files changed (+1028, -241 lines)

README.md

Lines changed: 49 additions & 50 deletions
```diff
@@ -42,64 +42,63 @@
 | ✔️ [relu_f16x2](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
 | ✔️ [relu_f16x8](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
 | ✔️ [relu_f16x8_pack](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️⭐️|
-| ✔️ [warp_reduce_f16/bf16/f32/f8/i8](./reduce/block_all_reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_reduce_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f32_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f32x4_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f16_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f16_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f16x2_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f16x2_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f16x8_pack_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_f16x8_pack_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_bf16_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_bf16_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_bf16x2_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_bf16x2_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_bf16x8_pack_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_bf16x8_pack_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_fp8_e4m3_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_fp8_e5m2_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_i8_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
-| ✔️ [block_all_reduce_i8x16_pack_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f16_f16](./reduce/reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f16_f32](./reduce/reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f16x2_f16](./reduce/reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f16x2_f32](./reduce/reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f16x8_pack_f16](./reduce/reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_f16x8_pack_f32](./reduce/reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_bf16_bf16](./reduce/reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_bf16_f32](./reduce/reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_bf16x2_bf16](./reduce/reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_bf16x2_f32](./reduce/reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_bf16x8_pack_bf16](./reduce/reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_bf16x8_pack_f32](./reduce/reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_fp8_e4m3_f16](./reduce/reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_fp8_e5m2_f16](./reduce/reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_fp8_e4m3x16_pack_f16](./reduce/reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_fp8_e5m2x16_pack_f16](./reduce/reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_i8_i32](./reduce/reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [reduce_i8x16_pack_i32](./reduce/reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [dot_product_f32](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f32x4](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f16_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f16x2_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f16x8_pack_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
-| ✔️ [softmax_f32(memory fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [softmax_f32x4(memory fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [safe_softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [safe_softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [safe_softmax_f16_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [safe_softmax_f16x2_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [safe_softmax_f16x8_pack_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
-| ✔️ [layer_norm_f32(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [layer_norm_f32x4(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [layer_norm_f16x2_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [layer_norm_f16x8_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [layer_norm_f16x8_pack_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [layer_norm_f16x8_pack_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [layer_norm_f16_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f32(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f32x4(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x2_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x8_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x8_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x8_pack_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x8_pack_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [softmax_f32(fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [softmax_f32x4(fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [softmax_f32x4](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f32x4](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x2_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x8_pack_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [layer_norm_f32](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f32x4](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16x2_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16x8_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16x8_pack_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16x8_pack_f32](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16_f32](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f32](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f32x4](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x2_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_f32](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_pack_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_pack_f32](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16_f32](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [sgemm_naive_f32](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
-| ✔️ [sgemm_t_8x8_sliced_k_f32x4_bcf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
-| ✔️ [sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
+| ✔️ [sgemm_t_8x8_sliced_k_..._bcf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
+| ✔️ [sgemm_t_8x8_sliced_k_..._dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_t_8x8_sliced_k_f16x4](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
```
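Each kernel in this table is driven from a small PyTorch test script (see dot_product.py below for the dot-product case). As a minimal sketch of that binding pattern, where the module name, source path, build flags, and exported function name are illustrative assumptions rather than taken from this commit:

```python
# Hypothetical binding sketch; names and flags are assumptions, not from
# this commit. torch.utils.cpp_extension.load JIT-compiles the CUDA source
# and exposes its pybind11-registered functions as module attributes.
import torch
from torch.utils.cpp_extension import load

lib = load(
    name="reduce_lib",
    sources=["./reduce/reduce.cu"],  # assumed path, matching the table above
    extra_cuda_cflags=["-O3"],
    verbose=False,
)

x = torch.randn(1024 * 1024, device="cuda", dtype=torch.float32)
out = lib.reduce_f32_f32(x)  # assumed exported name
print(out, x.sum().item())   # sanity-check against PyTorch's own reduction
```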

dot-product/README.md

Lines changed: 97 additions & 8 deletions
````diff
@@ -14,7 +14,7 @@
 ## Testing

 ```bash
-# Build only for the Ada architecture; if unset, all architectures are compiled by default, which takes much longer
+# Build only for the Ada architecture; if unset, all architectures are compiled by default, which takes much longer: Volta, Ampere, Ada, Hopper, ...
 export TORCH_CUDA_ARCH_LIST=Ada
 python3 dot_product.py
 ```
@@ -23,13 +23,102 @@ python3 dot_product.py

 ```bash
 --------------------------------------------------------------------------------
-out_f32f32: -1534.59301758 , time:0.17350578ms
-out_f32x4f32: -1534.61364746 , time:0.18058038ms
-out_f32f32_th: -1534.61157227 , time:0.18307972ms
+S=1024, K=1024
+out_f32f32: -670.21264648 , time:0.08947158ms
+out_f32x4f32: -670.21435547 , time:0.02821302ms
+out_f32f32_th: -670.21374512 , time:0.09709382ms
 --------------------------------------------------------------------------------
-out_f16f32: -1538.26318359 , time:0.10106802ms
-out_f16x2f32: -1537.58288574 , time:0.05217433ms
-out_f16x8packf32: -1536.44006348 , time:0.02096844ms
-out_f16f16_th: -1536.00000000 , time:0.02491832ms
+out_f16f32: -670.32208252 , time:0.04000235ms
+out_f16x2f32: -670.15814209 , time:0.05491829ms
+out_f16x8packf32: -669.90997314 , time:0.01669478ms
+out_f16f16_th: -670.50000000 , time:0.02021313ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=1024, K=2048
+out_f32f32: 1040.51086426 , time:0.04557490ms
+out_f32x4f32: 1040.50720215 , time:0.06275582ms
+out_f32f32_th: 1040.50842285 , time:0.04762864ms
+--------------------------------------------------------------------------------
+out_f16f32: 1041.44299316 , time:0.03214121ms
+out_f16x2f32: 1041.79589844 , time:0.03448486ms
+out_f16x8packf32: 1042.22717285 , time:0.02689457ms
+out_f16f16_th: 1041.00000000 , time:0.02859521ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=1024, K=4096
+out_f32f32: -1859.81457520 , time:0.08664179ms
+out_f32x4f32: -1859.81628418 , time:0.08621526ms
+out_f32f32_th: -1859.81933594 , time:0.08647323ms
+--------------------------------------------------------------------------------
+out_f16f32: -1860.23291016 , time:0.05826116ms
+out_f16x2f32: -1860.91186523 , time:0.04677963ms
+out_f16x8packf32: -1860.25988770 , time:0.04591107ms
+out_f16f16_th: -1861.00000000 , time:0.04904127ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=2048, K=1024
+out_f32f32: 858.98229980 , time:0.04499865ms
+out_f32x4f32: 858.98461914 , time:0.04623890ms
+out_f32f32_th: 858.98376465 , time:0.06848693ms
+--------------------------------------------------------------------------------
+out_f16f32: 858.85339355 , time:0.03274632ms
+out_f16x2f32: 858.94274902 , time:0.02831578ms
+out_f16x8packf32: 859.46844482 , time:0.02884459ms
+out_f16f16_th: 859.00000000 , time:0.03692698ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=2048, K=2048
+out_f32f32: -1205.77990723 , time:0.08356524ms
+out_f32x4f32: -1205.77624512 , time:0.08583307ms
+out_f32f32_th: -1205.77807617 , time:0.08613133ms
+--------------------------------------------------------------------------------
+out_f16f32: -1205.40588379 , time:0.06001544ms
+out_f16x2f32: -1205.29028320 , time:0.04738235ms
+out_f16x8packf32: -1205.72924805 , time:0.04624581ms
+out_f16f16_th: -1205.00000000 , time:0.04907203ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=2048, K=4096
+out_f32f32: -893.49169922 , time:0.16136765ms
+out_f32x4f32: -893.48596191 , time:0.16174912ms
+out_f32f32_th: -893.48901367 , time:0.16518927ms
+--------------------------------------------------------------------------------
+out_f16f32: -894.42169189 , time:0.11468077ms
+out_f16x2f32: -894.61779785 , time:0.08950567ms
+out_f16x8packf32: -895.26538086 , time:0.08448958ms
+out_f16f16_th: -894.00000000 , time:0.09156108ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=4096, K=1024
+out_f32f32: 141.78890991 , time:0.08385873ms
+out_f32x4f32: 141.78639221 , time:0.08500123ms
+out_f32f32_th: 141.78683472 , time:0.08647728ms
+--------------------------------------------------------------------------------
+out_f16f32: 141.80113220 , time:0.05876780ms
+out_f16x2f32: 141.62113953 , time:0.04708385ms
+out_f16x8packf32: 141.15240479 , time:0.04586506ms
+out_f16f16_th: 141.50000000 , time:0.04933500ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=4096, K=2048
+out_f32f32: -1238.80456543 , time:0.16236329ms
+out_f32x4f32: -1238.80737305 , time:0.16246724ms
+out_f32f32_th: -1238.80859375 , time:0.16496468ms
+--------------------------------------------------------------------------------
+out_f16f32: -1238.78466797 , time:0.11416745ms
+out_f16x2f32: -1239.28540039 , time:0.08488607ms
+out_f16x8packf32: -1238.85302734 , time:0.08867455ms
+out_f16f16_th: -1239.00000000 , time:0.09029007ms
+--------------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+S=4096, K=4096
+out_f32f32: 556.32690430 , time:0.31692672ms
+out_f32x4f32: 556.33087158 , time:0.31752276ms
+out_f32f32_th: 556.32879639 , time:0.32040811ms
+--------------------------------------------------------------------------------
+out_f16f32: 554.45031738 , time:0.23417449ms
+out_f16x2f32: 553.61444092 , time:0.16469955ms
+out_f16x8packf32: 554.04040527 , time:0.16465998ms
+out_f16f16_th: 554.50000000 , time:0.17046404ms
 --------------------------------------------------------------------------------
 ```
````
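One detail worth reading off the log above: the out_f16f16_th rows land on coarse values such as -670.50000000 and 1041.00000000. torch.dot on half inputs returns a half scalar, and fp16 spacing near |x| of about 1000 is 0.5 to 1.0, so the f16 reference is itself quantized, while the custom f16 kernels accumulate in f32 (per the Acc column in the README table). A standalone snippet (hypothetical, not part of this commit) that reproduces the effect:

```python
# Standalone sketch, not part of this commit: show how an f16 dot product
# is quantized relative to an f32 one. fp16 has 11 significant bits, so
# near |x| ~ 1e3 adjacent representable values are 0.5 to 1.0 apart.
import torch

torch.manual_seed(0)
a = torch.randn(1024 * 1024, device="cuda")
b = torch.randn(1024 * 1024, device="cuda")

ref = torch.dot(a, b).item()                        # f32 reference
f16 = torch.dot(a.half(), b.half()).float().item()  # f16 in, f16 out
print(f"f32: {ref:.8f}  f16: {f16:.8f}  "
      f"rel.err: {abs(f16 - ref) / abs(ref):.2e}")
```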

dot-product/dot_product.py

Lines changed: 21 additions & 16 deletions
```diff
@@ -42,19 +42,24 @@ def run_benchmark(perf_func: callable, a: torch.Tensor, b: torch.Tensor, tag: st
     return out, mean_time
 
 
-print("-" * 80)
-S, K = 4096, 4096
-a = torch.randn((S*K)).cuda().float()
-b = torch.randn((S*K)).cuda().float()
-run_benchmark(lib.dot_prod_f32_f32, a, b, "f32f32")
-run_benchmark(lib.dot_prod_f32x4_f32, a, b, "f32x4f32")
-run_benchmark(torch.dot, a, b, "f32f32_th")
-
-print("-" * 80)
-a_f16 = a.half()
-b_f16 = b.half()
-run_benchmark(lib.dot_prod_f16_f32, a_f16, b_f16, "f16f32")
-run_benchmark(lib.dot_prod_f16x2_f32, a_f16, b_f16, "f16x2f32")
-run_benchmark(lib.dot_prod_f16x8_pack_f32, a_f16, b_f16, "f16x8packf32")
-run_benchmark(torch.dot, a_f16, b_f16, "f16f16_th")
-print("-" * 80)
+Ss = [1024, 2048, 4096]
+Ks = [1024, 2048, 4096]
+SKs = [(S, K) for S in Ss for K in Ks]
+
+for (S, K) in SKs:
+    print("-" * 80)
+    print(" " * 25 + f"S={S}, K={K}")
+    a = torch.randn((S*K)).cuda().float()
+    b = torch.randn((S*K)).cuda().float()
+    run_benchmark(lib.dot_prod_f32_f32, a, b, "f32f32")
+    run_benchmark(lib.dot_prod_f32x4_f32, a, b, "f32x4f32")
+    run_benchmark(torch.dot, a, b, "f32f32_th")
+
+    print("-" * 80)
+    a_f16 = a.half()
+    b_f16 = b.half()
+    run_benchmark(lib.dot_prod_f16_f32, a_f16, b_f16, "f16f32")
+    run_benchmark(lib.dot_prod_f16x2_f32, a_f16, b_f16, "f16x2f32")
+    run_benchmark(lib.dot_prod_f16x8_pack_f32, a_f16, b_f16, "f16x8packf32")
+    run_benchmark(torch.dot, a_f16, b_f16, "f16f16_th")
+    print("-" * 80)
```
