Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fp8 tuning upstream #1380

Merged
merged 24 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tensilelite/Tensile/Utilities/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def mergeLogic(oriData, incData, forceMerge, trimSize=True, addSolutionTags=Fals
origNumSizes = len(oriData[7])
origNumSolutions = len(oriData[5])

incData[7] = incData[7] or []
incNumSizes = len(incData[7])
incNumSolutions = len(incData[5])

Expand All @@ -320,7 +321,7 @@ def mergeLogic(oriData, incData, forceMerge, trimSize=True, addSolutionTags=Fals
incTaggedSizes = addSolutionTagToKeys(incData[7], incData[5])
if addSolutionTags:
oriData[7] = origTaggedSizes
incData[7] = incTaggedSizes
incData[7] = incTaggedSizes
# Print warning if addSolutionTags=False results in removed sizes
else:
origSet = {tuple(size) for size, [_, _] in oriData[7]}
Expand Down
10 changes: 5 additions & 5 deletions tensilelite/Tensile/Utilities/tensile_generator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ To use the `tensile_config_generator.py` script, follow these steps:
| `--iters ITERS` | Max tuning iterations |
| `--fast BOOL` | If enabled, only tune the matrix instruction with min tile sizes, else, tune full matrix instructions |
| `--gridbase_config GRIDBASE_CONFIG` | Path to gridbase config file |
| `--full_mfma BOOL` | If enabled, will search for all mfma instructions |
| `--full_stage BOOL` | If enabled, will search for all stages instructions |
| `--num_stages STAGES` | How many times to divide matrix |

Equality tuning example:
```
Expand All @@ -48,7 +51,7 @@ To use the `tensile_config_generator.py` script, follow these steps:

3. Install hipBLASLt and Tensile (change the path to the hipBLASLt repo):
```
bash ./install.sh -idc -a $(/opt/rocm/llvm/bin/offload-arch) --keep-build-tmp
bash ./install.sh -idc -a $(/opt/rocm/llvm/bin/offload-arch) --cpu_ref_lib=lapack
```

4. Tune GEMM kernels using the generated YAML files:
Expand All @@ -60,22 +63,19 @@ To use the `tensile_config_generator.py` script, follow these steps:

MI308:

Modify yamls under ```<tune result directory>/3_LibraryLogic/```. ```- gfx942 -> - {Architecture: gfx942, CUCount: {20|80}}```

For cpx, use the gfx942_20cu folder; for spx, use the gfx942_80cu folder.
```
python3 ./tensilelite/Tensile/Utilities/merge.py --no_eff library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/{gfx942_20cu|gfx942_80cu}/{Equality|GridBased}/ <tune result directory>/3_LibraryLogic/ library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aquavanjaram/{gfx942_20cu|gfx942_80cu}/{Equality|GridBased}/
```
MI210:

Modify yamls under ```<tune result directory>/3_LibraryLogic/```. ```- gfx90a -> - {Architecture: gfx90a, CUCount: 104}```
```
python3 ./tensilelite/Tensile/Utilities/merge.py --no_eff library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aldebaran/104CU/{Equality|GridBased}/ <tune result directory>/3_LibraryLogic/ library/src/amd_detail/rocblaslt/src/Tensile/Logic/asm_full/aldebaran/104CU/{Equality|GridBased}/
```

6. Rebuild hipBLASLt with the merged results:
```
bash ./install.sh -idc -a $(/opt/rocm/llvm/bin/offload-arch) --keep-build-tmp
bash ./install.sh -idc -a $(/opt/rocm/llvm/bin/offload-arch) --cpu_ref_lib=lapack
```

For more detailed information on the script's functionality and advanced usage, please refer to the comments within the `tensile_config_generator.py` file.
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
hipblaslt-bench --api_method c -m 1 -n 200 -k 24 --lda 24 --ldb 24 --ldc 1 --ldd 1 --stride_a 24 --stride_b 4800 --stride_c 200 --stride_d 200 --alpha 1.000000 --beta 0.000000 --transA T --transB N --batch_count 800 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 1 -n 24 -k 200 --lda 1 --ldb 24 --ldc 1 --ldd 1 --stride_a 200 --stride_b 4800 --stride_c 24 --stride_d 24 --alpha 1.000000 --beta 0.000000 --transA N --transB T --batch_count 800 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 1024 -n 200 -k 5244 --lda 1024 --ldb 5244 --ldc 1024 --ldd 1024 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 128 -n 150 -k 128 --lda 128 --ldb 128 --ldc 128 --ldd 128 --stride_a 16384 --stride_b 19200 --stride_c 19200 --stride_d 19200 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 2 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 128 -n 200 -k 256 --lda 128 --ldb 256 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 150 -n 200 -k 32 --lda 32 --ldb 32 --ldc 150 --ldd 150 --stride_a 4800 --stride_b 6400 --stride_c 30000 --stride_d 30000 --alpha 1.000000 --beta 0.000000 --transA T --transB N --batch_count 8 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 2 -n 200 -k 64 --lda 2 --ldb 64 --ldc 2 --ldd 2 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 200 -n 200 -k 24 --lda 24 --ldb 24 --ldc 200 --ldd 200 --stride_a 4800 --stride_b 4800 --stride_c 40000 --stride_d 40000 --alpha 1.000000 --beta 0.000000 --transA T --transB N --batch_count 4 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 24 -n 200 -k 200 --lda 24 --ldb 200 --ldc 24 --ldd 24 --stride_a 4800 --stride_b 40000 --stride_c 4800 --stride_d 4800 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 4 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 24 -n 200 -k 50 --lda 24 --ldb 50 --ldc 24 --ldd 24 --stride_a 1200 --stride_b 10000 --stride_c 4800 --stride_d 4800 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 8 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 256 -n 200 -k 512 --lda 256 --ldb 512 --ldc 256 --ldd 256 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 256 -n 200 -k 96 --lda 256 --ldb 96 --ldc 256 --ldd 256 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 32 -n 200 -k 150 --lda 32 --ldb 150 --ldc 32 --ldd 32 --stride_a 4800 --stride_b 30000 --stride_c 6400 --stride_d 6400 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 8 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 384 -n 200 -k 96 --lda 384 --ldb 96 --ldc 384 --ldd 384 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 50 -n 200 -k 24 --lda 24 --ldb 24 --ldc 50 --ldd 50 --stride_a 1200 --stride_b 4800 --stride_c 10000 --stride_d 10000 --alpha 1.000000 --beta 0.000000 --transA T --transB N --batch_count 8 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 512 -n 200 -k 1024 --lda 512 --ldb 1024 --ldc 512 --ldd 512 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 64 -n 200 -k 128 --lda 64 --ldb 128 --ldc 64 --ldd 64 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 96 -n 200 -k 96 --lda 96 --ldb 96 --ldc 96 --ldd 96 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 96 -n 50 -k 96 --lda 96 --ldb 96 --ldc 96 --ldd 96 --stride_a 9216 --stride_b 4800 --stride_c 4800 --stride_d 4800 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 2 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 96 -n 6000 -k 96 --lda 96 --ldb 96 --ldc 96 --ldd 96 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --scale_type f32_r --bias_type f32_r --compute_type f32_r
hipblaslt-bench --api_method c -m 10 -n 1000 -k 128 --lda 128 --ldb 128 --ldc 10 --ldd 10 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA T --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 156482 --activation_type none
hipblaslt-bench --api_method c -m 10 -n 32 -k 128 --lda 128 --ldb 128 --ldc 10 --ldd 10 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA T --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 156117 --activation_type none
hipblaslt-bench --api_method c -m 10 -n 64 -k 128 --lda 128 --ldb 128 --ldc 10 --ldd 10 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA T --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 156482 --activation_type none
hipblaslt-bench --api_method c -m 128 -n 10 -k 32 --lda 128 --ldb 10 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB T --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 13713 --activation_type none
hipblaslt-bench --api_method c -m 128 -n 10 -k 64 --lda 128 --ldb 10 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB T --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 13713 --activation_type none
hipblaslt-bench --api_method c -m 128 -n 1000 -k 9216 --lda 9216 --ldb 9216 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA T --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 156486 --activation_type none
hipblaslt-bench --api_method c -m 128 -n 32 -k 10 --lda 128 --ldb 10 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 19256 --activation_type none
hipblaslt-bench --api_method c -m 128 -n 32 -k 9216 --lda 9216 --ldb 9216 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA T --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 156426 --activation_type none
hipblaslt-bench --api_method c -m 128 -n 64 -k 10 --lda 128 --ldb 10 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 19234 --activation_type none
hipblaslt-bench --api_method c -m 128 -n 64 -k 9216 --lda 9216 --ldb 9216 --ldc 128 --ldd 128 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 1.000000 --transA T --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 156490 --activation_type none
hipblaslt-bench --api_method c -m 9216 -n 128 -k 32 --lda 9216 --ldb 128 --ldc 9216 --ldd 9216 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB T --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 13717 --activation_type none
hipblaslt-bench --api_method c -m 9216 -n 128 -k 64 --lda 9216 --ldb 128 --ldc 9216 --ldd 9216 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB T --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 13717 --activation_type none
hipblaslt-bench --api_method c -m 9216 -n 32 -k 128 --lda 9216 --ldb 128 --ldc 9216 --ldd 9216 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 19272 --activation_type none
hipblaslt-bench --api_method c -m 9216 -n 64 -k 128 --lda 9216 --ldb 128 --ldc 9216 --ldd 9216 --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA N --transB N --batch_count 1 --scaleA 0 --scaleB 0 --a_type f32_r --b_type f32_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --algo_method index --solution_index 19260 --activation_type none
Loading