Skip to content

Commit af601fe

Browse files
authored
Merge pull request #1486 from emankov/HIPIFY
[HIPIFY][#936][BLASLT] `cublasLt` -> `hipblalLt` hipification support - Step 11
2 parents d16e995 + 3d10b35 commit af601fe

File tree

7 files changed

+94
-1
lines changed

7 files changed

+94
-1
lines changed

bin/hipify-perl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3869,6 +3869,8 @@ sub simpleSubstitutions {
38693869
subst("cublasIzamin_v2_64", "hipblasIzamin_v2_64", "library");
38703870
subst("cublasLtCreate", "hipblasLtCreate", "library");
38713871
subst("cublasLtDestroy", "hipblasLtDestroy", "library");
3872+
subst("cublasLtMatmul", "hipblasLtMatmul", "library");
3873+
subst("cublasLtMatrixTransform", "hipblasLtMatrixTransform", "library");
38723874
subst("cublasNrm2Ex", "hipblasNrm2Ex_v2", "library");
38733875
subst("cublasRotEx", "hipblasRotEx_v2", "library");
38743876
subst("cublasSasum", "hipblasSasum", "library");
@@ -6900,6 +6902,8 @@ sub simpleSubstitutions {
69006902
subst("cudaSuccess", "hipSuccess", "numeric_literal");
69016903
subst("cudaUserObjectNoDestructorSync", "hipUserObjectNoDestructorSync", "numeric_literal");
69026904
subst("cusolver_int_t", "int", "numeric_literal");
6905+
subst("CUBLASLT_ORDER_COL", "HIPBLASLT_ORDER_COL", "define");
6906+
subst("CUBLASLT_ORDER_ROW", "HIPBLASLT_ORDER_ROW", "define");
69036907
subst("CUB_MAX", "CUB_MAX", "define");
69046908
subst("CUB_MIN", "CUB_MIN", "define");
69056909
subst("CUB_NAMESPACE_BEGIN", "BEGIN_HIPCUB_NAMESPACE", "define");
@@ -6940,6 +6944,7 @@ sub simpleSubstitutions {
69406944
subst("_CubLog", "_HipcubLog", "define");
69416945
subst("__CUB_ALIGN_BYTES", "__HIPCUB_ALIGN_BYTES", "define");
69426946
subst("__CUDACC__", "__HIPCC__", "define");
6947+
subst("cublasLtOrder_t", "hipblasLtOrder_t", "define");
69436948
subst("cudaArrayCubemap", "hipArrayCubemap", "define");
69446949
subst("cudaArrayDefault", "hipArrayDefault", "define");
69456950
subst("cudaArrayLayered", "hipArrayLayered", "define");
@@ -11509,6 +11514,9 @@ sub warnHipOnlyUnsupportedFunctions {
1150911514
"CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_HOST",
1151011515
"CUBLASLT_POINTER_MODE_DEVICE_VECTOR",
1151111516
"CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO",
11517+
"CUBLASLT_ORDER_COL4_4R2_8C",
11518+
"CUBLASLT_ORDER_COL32_2R_4R4",
11519+
"CUBLASLT_ORDER_COL32",
1151211520
"CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK",
1151311521
"CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK",
1151411522
"CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK",
@@ -11842,10 +11850,13 @@ sub warnRocOnlyUnsupportedFunctions {
1184211850
"cublasMigrateComputeType",
1184311851
"cublasLtPointerMode_t",
1184411852
"cublasLtPointerModeMask_t",
11853+
"cublasLtOrder_t",
1184511854
"cublasLtNumericalImplFlags_t",
11855+
"cublasLtMatrixTransform",
1184611856
"cublasLtMatmulTile_t",
1184711857
"cublasLtMatmulStages_t",
1184811858
"cublasLtMatmulInnerShape_t",
11859+
"cublasLtMatmul",
1184911860
"cublasLtHeuristicsCacheSetCapacity",
1185011861
"cublasLtHeuristicsCacheGetCapacity",
1185111862
"cublasLtGetVersion",
@@ -12128,6 +12139,11 @@ sub warnRocOnlyUnsupportedFunctions {
1212812139
"CUBLASLT_POINTER_MODE_DEVICE_VECTOR",
1212912140
"CUBLASLT_POINTER_MODE_DEVICE",
1213012141
"CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO",
12142+
"CUBLASLT_ORDER_ROW",
12143+
"CUBLASLT_ORDER_COL4_4R2_8C",
12144+
"CUBLASLT_ORDER_COL32_2R_4R4",
12145+
"CUBLASLT_ORDER_COL32",
12146+
"CUBLASLT_ORDER_COL",
1213112147
"CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK",
1213212148
"CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK",
1213312149
"CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK",

docs/tables/CUBLAS_API_supported_by_HIP.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,11 @@
298298
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK`|11.0| | | | | | | | | |
299299
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK`|11.0| | | | | | | | | |
300300
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK`|11.0| | | | | | | | | |
301+
|`CUBLASLT_ORDER_COL`|10.1| | | |`HIPBLASLT_ORDER_COL`|6.0.0| | | | |
302+
|`CUBLASLT_ORDER_COL32`|10.1| | | | | | | | | |
303+
|`CUBLASLT_ORDER_COL32_2R_4R4`|11.0| | | | | | | | | |
304+
|`CUBLASLT_ORDER_COL4_4R2_8C`|10.1| | | | | | | | | |
305+
|`CUBLASLT_ORDER_ROW`|10.1| | | |`HIPBLASLT_ORDER_ROW`|6.0.0| | | | |
301306
|`CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST`|11.4| | | |`HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST`|6.0.0| | | | |
302307
|`CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO`|10.1| | | | | | | | | |
303308
|`CUBLASLT_POINTER_MODE_DEVICE`| | | | |`HIPBLASLT_POINTER_MODE_DEVICE`|6.1.0| | | | |
@@ -325,6 +330,7 @@
325330
|`cublasLtMatrixTransformDescOpaque_t`|11.0| | | |`hipblasLtMatrixTransformDescOpaque_t`|6.0.0| | | | |
326331
|`cublasLtMatrixTransformDesc_t`|10.1| | | |`hipblasLtMatrixTransformDesc_t`|6.0.0| | | | |
327332
|`cublasLtNumericalImplFlags_t`|11.0| | | | | | | | | |
333+
|`cublasLtOrder_t`|10.1| | | |`hipblasLtOrder_t`|6.0.0| | | | |
328334
|`cublasLtPointerModeMask_t`|10.1| | | | | | | | | |
329335
|`cublasLtPointerMode_t`|10.1| | | |`hipblasLtPointerMode_t`|6.0.0| | | | |
330336

@@ -1212,6 +1218,8 @@
12121218
|`cublasLtGetVersion`|10.1| | | | | | | | | |
12131219
|`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | |
12141220
|`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | |
1221+
|`cublasLtMatmul`|10.1| | | |`hipblasLtMatmul`|5.5.0| | | | |
1222+
|`cublasLtMatrixTransform`|10.1| | | |`hipblasLtMatrixTransform`|6.0.0| | | | |
12151223

12161224

12171225
\*A - Added; D - Deprecated; C - Changed; R - Removed; E - Experimental

docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,11 @@
298298
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK`|11.0| | | | | | | | | | | | | | | |
299299
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK`|11.0| | | | | | | | | | | | | | | |
300300
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK`|11.0| | | | | | | | | | | | | | | |
301+
|`CUBLASLT_ORDER_COL`|10.1| | | |`HIPBLASLT_ORDER_COL`|6.0.0| | | | | | | | | | |
302+
|`CUBLASLT_ORDER_COL32`|10.1| | | | | | | | | | | | | | | |
303+
|`CUBLASLT_ORDER_COL32_2R_4R4`|11.0| | | | | | | | | | | | | | | |
304+
|`CUBLASLT_ORDER_COL4_4R2_8C`|10.1| | | | | | | | | | | | | | | |
305+
|`CUBLASLT_ORDER_ROW`|10.1| | | |`HIPBLASLT_ORDER_ROW`|6.0.0| | | | | | | | | | |
301306
|`CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST`|11.4| | | |`HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST`|6.0.0| | | | | | | | | | |
302307
|`CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO`|10.1| | | | | | | | | | | | | | | |
303308
|`CUBLASLT_POINTER_MODE_DEVICE`| | | | |`HIPBLASLT_POINTER_MODE_DEVICE`|6.1.0| | | | | | | | | | |
@@ -325,6 +330,7 @@
325330
|`cublasLtMatrixTransformDescOpaque_t`|11.0| | | |`hipblasLtMatrixTransformDescOpaque_t`|6.0.0| | | | | | | | | | |
326331
|`cublasLtMatrixTransformDesc_t`|10.1| | | |`hipblasLtMatrixTransformDesc_t`|6.0.0| | | | | | | | | | |
327332
|`cublasLtNumericalImplFlags_t`|11.0| | | | | | | | | | | | | | | |
333+
|`cublasLtOrder_t`|10.1| | | |`hipblasLtOrder_t`|6.0.0| | | | | | | | | | |
328334
|`cublasLtPointerModeMask_t`|10.1| | | | | | | | | | | | | | | |
329335
|`cublasLtPointerMode_t`|10.1| | | |`hipblasLtPointerMode_t`|6.0.0| | | | | | | | | | |
330336

@@ -1212,6 +1218,8 @@
12121218
|`cublasLtGetVersion`|10.1| | | | | | | | | | | | | | | |
12131219
|`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | | | | | | | |
12141220
|`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | | | | | | | |
1221+
|`cublasLtMatmul`|10.1| | | |`hipblasLtMatmul`|5.5.0| | | | | | | | | | |
1222+
|`cublasLtMatrixTransform`|10.1| | | |`hipblasLtMatrixTransform`|6.0.0| | | | | | | | | | |
12151223

12161224

12171225
\*A - Added; D - Deprecated; C - Changed; R - Removed; E - Experimental

docs/tables/CUBLAS_API_supported_by_ROC.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,11 @@
298298
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK`|11.0| | | | | | | | | |
299299
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK`|11.0| | | | | | | | | |
300300
|`CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK`|11.0| | | | | | | | | |
301+
|`CUBLASLT_ORDER_COL`|10.1| | | | | | | | | |
302+
|`CUBLASLT_ORDER_COL32`|10.1| | | | | | | | | |
303+
|`CUBLASLT_ORDER_COL32_2R_4R4`|11.0| | | | | | | | | |
304+
|`CUBLASLT_ORDER_COL4_4R2_8C`|10.1| | | | | | | | | |
305+
|`CUBLASLT_ORDER_ROW`|10.1| | | | | | | | | |
301306
|`CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST`|11.4| | | | | | | | | |
302307
|`CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO`|10.1| | | | | | | | | |
303308
|`CUBLASLT_POINTER_MODE_DEVICE`| | | | | | | | | | |
@@ -325,6 +330,7 @@
325330
|`cublasLtMatrixTransformDescOpaque_t`|11.0| | | | | | | | | |
326331
|`cublasLtMatrixTransformDesc_t`|10.1| | | | | | | | | |
327332
|`cublasLtNumericalImplFlags_t`|11.0| | | | | | | | | |
333+
|`cublasLtOrder_t`|10.1| | | | | | | | | |
328334
|`cublasLtPointerModeMask_t`|10.1| | | | | | | | | |
329335
|`cublasLtPointerMode_t`|10.1| | | | | | | | | |
330336

@@ -1212,6 +1218,8 @@
12121218
|`cublasLtGetVersion`|10.1| | | | | | | | | |
12131219
|`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | |
12141220
|`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | |
1221+
|`cublasLtMatmul`|10.1| | | | | | | | | |
1222+
|`cublasLtMatrixTransform`|10.1| | | | | | | | | |
12151223

12161224

12171225
\*A - Added; D - Deprecated; C - Changed; R - Removed; E - Experimental

src/CUDA2HIP_BLAS_API_functions.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,8 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_FUNCTION_MAP {
10901090
{"cublasLtHeuristicsCacheGetCapacity", {"hipblasLtHeuristicsCacheGetCapacity", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED}},
10911091
{"cublasLtHeuristicsCacheSetCapacity", {"hipblasLtHeuristicsCacheSetCapacity", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED}},
10921092
{"cublasLtDisableCpuInstructionsSetMask", {"hipblasLtDisableCpuInstructionsSetMask", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED}},
1093+
{"cublasLtMatmul", {"hipblasLtMatmul", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, ROC_UNSUPPORTED}},
1094+
{"cublasLtMatrixTransform", {"hipblasLtMatrixTransform", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, ROC_UNSUPPORTED}},
10931095
};
10941096

10951097
const std::map<llvm::StringRef, cudaAPIversions> CUDA_BLAS_FUNCTION_VER_MAP {
@@ -1550,6 +1552,8 @@ const std::map<llvm::StringRef, cudaAPIversions> CUDA_BLAS_FUNCTION_VER_MAP {
15501552
{"cublasLtHeuristicsCacheGetCapacity", {CUDA_118, CUDA_0, CUDA_0 }},
15511553
{"cublasLtHeuristicsCacheSetCapacity", {CUDA_118, CUDA_0, CUDA_0 }},
15521554
{"cublasLtDisableCpuInstructionsSetMask", {CUDA_121, CUDA_0, CUDA_0 }}, // A: CUDA_VERSION 12011, CUBLAS_VERSION 120103, CUBLAS_VER_MAJOR 12 CUBLAS_VER_MINOR 3
1555+
{"cublasLtMatmul", {CUDA_101, CUDA_0, CUDA_0 }},
1556+
{"cublasLtMatrixTransform", {CUDA_101, CUDA_0, CUDA_0 }},
15531557
};
15541558

15551559
const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_FUNCTION_VER_MAP {
@@ -1954,6 +1958,8 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_FUNCTION_VER_MAP {
19541958
{"hipblasZswap_v2_64", {HIP_6010, HIP_0, HIP_0 }},
19551959
{"hipblasLtCreate", {HIP_5050, HIP_0, HIP_0 }},
19561960
{"hipblasLtDestroy", {HIP_5050, HIP_0, HIP_0 }},
1961+
{"hipblasLtMatmul", {HIP_5050, HIP_0, HIP_0 }},
1962+
{"hipblasLtMatrixTransform", {HIP_6000, HIP_0, HIP_0 }},
19571963

19581964
{"rocblas_status_to_string", {HIP_3050, HIP_0, HIP_0 }},
19591965
{"rocblas_sscal", {HIP_1050, HIP_0, HIP_0 }},

src/CUDA2HIP_BLAS_API_types.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,12 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_TYPE_NAME_MAP {
372372
{"CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK", {"HIPBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, UNSUPPORTED}},
373373
{"CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN", {"HIPBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, UNSUPPORTED}},
374374
{"cublasLtNumericalImplFlags_t", {"hipblasLtNumericalImplFlags_t", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, UNSUPPORTED}},
375+
{"cublasLtOrder_t", {"hipblasLtOrder_t", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, ROC_UNSUPPORTED}},
376+
{"CUBLASLT_ORDER_COL", {"HIPBLASLT_ORDER_COL", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, ROC_UNSUPPORTED}},
377+
{"CUBLASLT_ORDER_ROW", {"HIPBLASLT_ORDER_ROW", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, ROC_UNSUPPORTED}},
378+
{"CUBLASLT_ORDER_COL32", {"HIPBLASLT_ORDER_COL32", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, UNSUPPORTED}},
379+
{"CUBLASLT_ORDER_COL4_4R2_8C", {"HIPBLASLT_ORDER_COL4_4R2_8C", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, UNSUPPORTED}},
380+
{"CUBLASLT_ORDER_COL32_2R_4R4", {"HIPBLASLT_ORDER_COL32_2R_4R4", "", CONV_DEFINE, API_BLAS, SEC::BLAS_LT_DATA_TYPES, UNSUPPORTED}},
375381
};
376382

377383
const std::map<llvm::StringRef, cudaAPIversions> CUDA_BLAS_TYPE_NAME_VER_MAP {
@@ -652,6 +658,12 @@ const std::map<llvm::StringRef, cudaAPIversions> CUDA_BLAS_TYPE_NAME_VER_MAP {
652658
{"CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK", {CUDA_110, CUDA_0, CUDA_0 }}, // A: CUDA_VERSION 11001, CUBLAS_VERSION 11000, CUBLAS_VER_MAJOR 11 CUBLAS_VER_MINOR 0
653659
{"CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN", {CUDA_110, CUDA_0, CUDA_0 }}, // A: CUDA_VERSION 11001, CUBLAS_VERSION 11000, CUBLAS_VER_MAJOR 11 CUBLAS_VER_MINOR 0
654660
{"cublasLtNumericalImplFlags_t", {CUDA_110, CUDA_0, CUDA_0 }}, // A: CUDA_VERSION 11001, CUBLAS_VERSION 11000, CUBLAS_VER_MAJOR 11 CUBLAS_VER_MINOR 0
661+
{"cublasLtOrder_t", {CUDA_101, CUDA_0, CUDA_0 }},
662+
{"CUBLASLT_ORDER_COL", {CUDA_101, CUDA_0, CUDA_0 }},
663+
{"CUBLASLT_ORDER_ROW", {CUDA_101, CUDA_0, CUDA_0 }},
664+
{"CUBLASLT_ORDER_COL32", {CUDA_101, CUDA_0, CUDA_0 }},
665+
{"CUBLASLT_ORDER_COL4_4R2_8C", {CUDA_101, CUDA_0, CUDA_0 }},
666+
{"CUBLASLT_ORDER_COL32_2R_4R4", {CUDA_110, CUDA_0, CUDA_0 }}, // A: CUDA_VERSION 11001, CUBLAS_VERSION 11000, CUBLAS_VER_MAJOR 11 CUBLAS_VER_MINOR 0
655667
};
656668

657669
const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_TYPE_NAME_VER_MAP {
@@ -752,6 +764,9 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_TYPE_NAME_VER_MAP {
752764
{"HIPBLASLT_POINTER_MODE_HOST", {HIP_6000, HIP_0, HIP_0 }},
753765
{"HIPBLASLT_POINTER_MODE_DEVICE", {HIP_6010, HIP_0, HIP_0 }},
754766
{"HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST", {HIP_6000, HIP_0, HIP_0 }},
767+
{"hipblasLtOrder_t", {HIP_6000, HIP_0, HIP_0 }},
768+
{"HIPBLASLT_ORDER_COL", {HIP_6000, HIP_0, HIP_0 }},
769+
{"HIPBLASLT_ORDER_ROW", {HIP_6000, HIP_0, HIP_0 }},
755770

756771
{"rocblas_handle", {HIP_1050, HIP_0, HIP_0 }},
757772
{"_rocblas_handle", {HIP_1050, HIP_0, HIP_0 }},

tests/unit_tests/synthetic/libraries/cublaslt2hipblaslt.cu

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,20 @@ int main() {
1616
// CHECK: hipblasStatus_t status;
1717
cublasStatus_t status;
1818

19+
// CHECK: hipStream_t stream;
20+
cudaStream_t stream;
21+
22+
void *A = nullptr;
23+
void *B = nullptr;
24+
void *C = nullptr;
25+
void *D = nullptr;
26+
void *alpha = nullptr;
27+
void *beta = nullptr;
28+
void *workspace = nullptr;
1929
const char *const_ch = nullptr;
2030

31+
size_t workspaceSizeInBytes = 0;
32+
2133
#if CUDA_VERSION >= 10010
2234
// CHECK: hipblasLtMatmulAlgo_t blasLtMatmulAlgo;
2335
cublasLtMatmulAlgo_t blasLtMatmulAlgo;
@@ -31,6 +43,16 @@ int main() {
3143
// CHECK: hipblasLtMatmulPreference_t blasLtMatmulPreference;
3244
cublasLtMatmulPreference_t blasLtMatmulPreference;
3345

46+
// CHECK: hipblasLtMatrixLayout_t blasLtMatrixLayout, Adesc, Bdesc, Cdesc, Ddesc;
47+
cublasLtMatrixLayout_t blasLtMatrixLayout, Adesc, Bdesc, Cdesc, Ddesc;
48+
49+
// CHECK: hipblasLtOrder_t blasLtOrder;
50+
// CHECK-NEXT: hipblasLtOrder_t BLASLT_ORDER_COL = HIPBLASLT_ORDER_COL;
51+
// CHECK-NEXT: hipblasLtOrder_t BLASLT_ORDER_ROW = HIPBLASLT_ORDER_ROW;
52+
cublasLtOrder_t blasLtOrder;
53+
cublasLtOrder_t BLASLT_ORDER_COL = CUBLASLT_ORDER_COL;
54+
cublasLtOrder_t BLASLT_ORDER_ROW = CUBLASLT_ORDER_ROW;
55+
3456
// CUDA: cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t* lightHandle);
3557
// HIP: HIPBLASLT_EXPORT hipblasStatus_t hipblasLtCreate(hipblasLtHandle_t* handle);
3658
// CHECK: status = hipblasLtCreate(&blasLtHandle);
@@ -41,6 +63,17 @@ int main() {
4163
// CHECK: status = hipblasLtDestroy(blasLtHandle);
4264
status = cublasLtDestroy(blasLtHandle);
4365

66+
// CUDA: cublasStatus_t CUBLASWINAPI cublasLtMatmul(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc, const void* beta, const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, cudaStream_t stream);
67+
// HIP: HIPBLASLT_EXPORT hipblasStatus_t hipblasLtMatmul(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t matmulDesc, const void* alpha, const void* A, hipblasLtMatrixLayout_t Adesc, const void* B, hipblasLtMatrixLayout_t Bdesc, const void* beta, const void* C, hipblasLtMatrixLayout_t Cdesc, void* D, hipblasLtMatrixLayout_t Ddesc, const hipblasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, hipStream_t stream);
68+
// CHECK: status = hipblasLtMatmul(blasLtHandle, blasLtMatmulDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, &blasLtMatmulAlgo, workspace, workspaceSizeInBytes, stream);
69+
status = cublasLtMatmul(blasLtHandle, blasLtMatmulDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, &blasLtMatmulAlgo, workspace, workspaceSizeInBytes, stream);
70+
71+
// CUDA: cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(cublasLtHandle_t lightHandle, cublasLtMatrixTransformDesc_t transformDesc, const void* alpha, const void* A, cublasLtMatrixLayout_t Adesc, const void* beta, const void* B, cublasLtMatrixLayout_t Bdesc, void* C, cublasLtMatrixLayout_t Cdesc, cudaStream_t stream);
72+
// HIP: HIPBLASLT_EXPORT hipblasStatus_t hipblasLtMatrixTransform(hipblasLtHandle_t lightHandle, hipblasLtMatrixTransformDesc_t transformDesc, const void* alpha, const void* A, hipblasLtMatrixLayout_t Adesc, const void* beta, const void* B, hipblasLtMatrixLayout_t Bdesc, void* C, hipblasLtMatrixLayout_t Cdesc, hipStream_t stream);
73+
// CHECK: status = hipblasLtMatrixTransform(blasLtHandle, blasLtMatrixTransformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream);
74+
status = cublasLtMatrixTransform(blasLtHandle, blasLtMatrixTransformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream);
75+
#endif
76+
4477
#if CUBLAS_VERSION >= 10200
4578
// CHECK: hipblasLtPointerMode_t blasLtPointerMode;
4679
// CHECK-NEXT: hipblasLtPointerMode_t BLASLT_POINTER_MODE_HOST = HIPBLASLT_POINTER_MODE_HOST;
@@ -49,7 +82,6 @@ int main() {
4982
cublasLtPointerMode_t BLASLT_POINTER_MODE_HOST = CUBLASLT_POINTER_MODE_HOST;
5083
cublasLtPointerMode_t BLASLT_POINTER_MODE_DEVICE = CUBLASLT_POINTER_MODE_DEVICE;
5184
#endif
52-
#endif
5385

5486
#if CUDA_VERSION >= 11000 && CUBLAS_VERSION >= 11000
5587
// CHECK: hipblasLtMatrixLayoutOpaque_t blasLtMatrixLayoutOpaque;

0 commit comments

Comments
 (0)