diff --git a/src/CUDA2HIP_SPARSE_API_functions.cpp b/src/CUDA2HIP_SPARSE_API_functions.cpp index e1fa2bde..b3b32810 100644 --- a/src/CUDA2HIP_SPARSE_API_functions.cpp +++ b/src/CUDA2HIP_SPARSE_API_functions.cpp @@ -1365,10 +1365,10 @@ const std::map CUDA_SPARSE_FUNCTION_VER_MAP { {"cusparseCcsrilu02_numericBoost", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 {"cusparseZcsrilu02_numericBoost", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 {"cusparseXcsrilu02_zeroPivot", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 - {"cusparseScsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 - {"cusparseDcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 - {"cusparseCcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 - {"cusparseZcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 + {"cusparseScsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102 + {"cusparseDcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102 + {"cusparseCcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102 + {"cusparseZcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102 {"cusparseScsrilu02_bufferSizeExt", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 {"cusparseDcsrilu02_bufferSizeExt", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 {"cusparseCcsrilu02_bufferSizeExt", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120 diff --git a/src/HipifyAction.cpp b/src/HipifyAction.cpp index b02f7005..8ff806dd 100644 --- a/src/HipifyAction.cpp +++ b/src/HipifyAction.cpp @@ -195,6 +195,10 @@ const std::string sCusparseZcsrgemm2 = "cusparseZcsrgemm2"; const std::string sCusparseCcsrgemm2 = "cusparseCcsrgemm2"; const std::string sCusparseDcsrgemm2 = "cusparseDcsrgemm2"; const std::string sCusparseScsrgemm2 = "cusparseScsrgemm2"; +const std::string 
sCusparseZcsrilu02_bufferSize = "cusparseZcsrilu02_bufferSize"; +const std::string sCusparseCcsrilu02_bufferSize = "cusparseCcsrilu02_bufferSize"; +const std::string sCusparseDcsrilu02_bufferSize = "cusparseDcsrilu02_bufferSize"; +const std::string sCusparseScsrilu02_bufferSize = "cusparseScsrilu02_bufferSize"; // CUDA_OVERLOADED const std::string sCudaEventCreate = "cudaEventCreate"; @@ -1535,6 +1539,42 @@ std::map FuncArgCasts { false } }, + {sCusparseZcsrilu02_bufferSize, + { + { + {8, {e_reinterpret_cast_size_t, cw_None}} + }, + true, + false + } + }, + {sCusparseCcsrilu02_bufferSize, + { + { + {8, {e_reinterpret_cast_size_t, cw_None}} + }, + true, + false + } + }, + {sCusparseDcsrilu02_bufferSize, + { + { + {8, {e_reinterpret_cast_size_t, cw_None}} + }, + true, + false + } + }, + {sCusparseScsrilu02_bufferSize, + { + { + {8, {e_reinterpret_cast_size_t, cw_None}} + }, + true, + false + } + }, }; void HipifyAction::RewriteString(StringRef s, clang::SourceLocation start) { @@ -2370,7 +2410,11 @@ std::unique_ptr HipifyAction::CreateASTConsumer(clang::Compi sCusparseZcsrgemm2, sCusparseCcsrgemm2, sCusparseDcsrgemm2, - sCusparseScsrgemm2 + sCusparseScsrgemm2, + sCusparseZcsrilu02_bufferSize, + sCusparseCcsrilu02_bufferSize, + sCusparseDcsrilu02_bufferSize, + sCusparseScsrilu02_bufferSize ) ) ) diff --git a/tests/unit_tests/synthetic/libraries/cusparse2rocsparse.cu b/tests/unit_tests/synthetic/libraries/cusparse2rocsparse.cu index 1f0a6c1b..48b77f76 100644 --- a/tests/unit_tests/synthetic/libraries/cusparse2rocsparse.cu +++ b/tests/unit_tests/synthetic/libraries/cusparse2rocsparse.cu @@ -752,22 +752,22 @@ int main() { // CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, cuDoubleComplex* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes); // ROC: ROCSPARSE_EXPORT rocsparse_status 
rocsparse_zcsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const rocsparse_double_complex* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size); - // CHECK: status_t = rocsparse_zcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dComplexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); + // CHECK: status_t = rocsparse_zcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dComplexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes)); status_t = cusparseZcsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &dComplexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); // CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, cuComplex* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes); // ROC: ROCSPARSE_EXPORT rocsparse_status rocsparse_ccsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const rocsparse_float_complex* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size); - // CHECK: status_t = rocsparse_ccsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &complexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); + // CHECK: status_t = rocsparse_ccsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &complexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes)); status_t = cusparseCcsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &complexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); // CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI 
cusparseDcsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, double* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes); // ROC: ROCSPARSE_EXPORT rocsparse_status rocsparse_dcsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const double* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size); - // CHECK: status_t = rocsparse_dcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); + // CHECK: status_t = rocsparse_dcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes)); status_t = cusparseDcsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &dcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); // CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, float* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes); // ROC: ROCSPARSE_EXPORT rocsparse_status rocsparse_scsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const float* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size); - // CHECK: status_t = rocsparse_scsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &csrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); + // CHECK: status_t = rocsparse_scsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &csrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes)); status_t = 
cusparseScsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &csrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes); // CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_numericBoost(cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, cuDoubleComplex* boost_val);