Skip to content

Commit

Permalink
[HIPIFY][rocSPARSE][fix] Fix hipification of `rocsparse_(s|d|c|z)csri…
Browse files Browse the repository at this point in the history
…lu0_buffer_size`

[Reason] Their hipification from `cusparse(S|C|D|Z)csrilu02_bufferSize` needs a `reinterpret_cast<size_t*>` for the last argument
  • Loading branch information
emankov committed Jan 10, 2024
1 parent a86e930 commit 03747f3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/CUDA2HIP_SPARSE_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1365,10 +1365,10 @@ const std::map<llvm::StringRef, cudaAPIversions> CUDA_SPARSE_FUNCTION_VER_MAP {
{"cusparseCcsrilu02_numericBoost", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseZcsrilu02_numericBoost", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseXcsrilu02_zeroPivot", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseScsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseDcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseCcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseZcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseScsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102
{"cusparseDcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102
{"cusparseCcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102
{"cusparseZcsrilu02_bufferSize", {CUDA_0, CUDA_122, CUDA_0 }}, // D: CUSPARSE_VERSION 12102
{"cusparseScsrilu02_bufferSizeExt", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseDcsrilu02_bufferSizeExt", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
{"cusparseCcsrilu02_bufferSizeExt", {CUDA_0, CUDA_122, CUDA_0 }}, // CUSPARSE_VERSION 12120
Expand Down
46 changes: 45 additions & 1 deletion src/HipifyAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ const std::string sCusparseZcsrgemm2 = "cusparseZcsrgemm2";
const std::string sCusparseCcsrgemm2 = "cusparseCcsrgemm2";
const std::string sCusparseDcsrgemm2 = "cusparseDcsrgemm2";
const std::string sCusparseScsrgemm2 = "cusparseScsrgemm2";
const std::string sCusparseZcsrilu02_bufferSize = "cusparseZcsrilu02_bufferSize";
const std::string sCusparseCcsrilu02_bufferSize = "cusparseCcsrilu02_bufferSize";
const std::string sCusparseDcsrilu02_bufferSize = "cusparseDcsrilu02_bufferSize";
const std::string sCusparseScsrilu02_bufferSize = "cusparseScsrilu02_bufferSize";

// CUDA_OVERLOADED
const std::string sCudaEventCreate = "cudaEventCreate";
Expand Down Expand Up @@ -1535,6 +1539,42 @@ std::map<std::string, ArgCastStruct> FuncArgCasts {
false
}
},
{sCusparseZcsrilu02_bufferSize,
{
{
{8, {e_reinterpret_cast_size_t, cw_None}}
},
true,
false
}
},
{sCusparseCcsrilu02_bufferSize,
{
{
{8, {e_reinterpret_cast_size_t, cw_None}}
},
true,
false
}
},
{sCusparseDcsrilu02_bufferSize,
{
{
{8, {e_reinterpret_cast_size_t, cw_None}}
},
true,
false
}
},
{sCusparseScsrilu02_bufferSize,
{
{
{8, {e_reinterpret_cast_size_t, cw_None}}
},
true,
false
}
},
};

void HipifyAction::RewriteString(StringRef s, clang::SourceLocation start) {
Expand Down Expand Up @@ -2370,7 +2410,11 @@ std::unique_ptr<clang::ASTConsumer> HipifyAction::CreateASTConsumer(clang::Compi
sCusparseZcsrgemm2,
sCusparseCcsrgemm2,
sCusparseDcsrgemm2,
sCusparseScsrgemm2
sCusparseScsrgemm2,
sCusparseZcsrilu02_bufferSize,
sCusparseCcsrilu02_bufferSize,
sCusparseDcsrilu02_bufferSize,
sCusparseScsrilu02_bufferSize
)
)
)
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/synthetic/libraries/cusparse2rocsparse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -752,22 +752,22 @@ int main() {

// CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, cuDoubleComplex* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes);
// ROC: ROCSPARSE_EXPORT rocsparse_status rocsparse_zcsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const rocsparse_double_complex* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size);
// CHECK: status_t = rocsparse_zcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dComplexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);
// CHECK: status_t = rocsparse_zcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dComplexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes));
status_t = cusparseZcsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &dComplexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);

// CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, cuComplex* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes);
// ROC: ROCSPARSE_EXPORT rocsparse_status rocsparse_ccsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const rocsparse_float_complex* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size);
// CHECK: status_t = rocsparse_ccsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &complexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);
// CHECK: status_t = rocsparse_ccsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &complexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes));
status_t = cusparseCcsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &complexcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);

// CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, double* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes);
// ROC: ROCSPARSE_EXPORT rocsparse_status rocsparse_dcsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const double* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size);
// CHECK: status_t = rocsparse_dcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);
// CHECK: status_t = rocsparse_dcsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &dcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes));
status_t = cusparseDcsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &dcsrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);

// CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_bufferSize(cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, float* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, csrilu02Info_t info, int* pBufferSizeInBytes);
// ROC: ROCSPARSE_EXPORT rocsparse_status rocsparse_scsrilu0_buffer_size(rocsparse_handle handle, rocsparse_int m, rocsparse_int nnz, const rocsparse_mat_descr descr, const float* csr_val, const rocsparse_int* csr_row_ptr, const rocsparse_int* csr_col_ind, rocsparse_mat_info info, size_t* buffer_size);
// CHECK: status_t = rocsparse_scsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &csrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);
// CHECK: status_t = rocsparse_scsrilu0_buffer_size(handle_t, m, innz, matDescr_A, &csrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, reinterpret_cast<size_t*>(&bufferSizeInBytes));
status_t = cusparseScsrilu02_bufferSize(handle_t, m, innz, matDescr_A, &csrSortedValA, &csrRowPtrA, &csrColIndA, csrilu02_info, &bufferSizeInBytes);

// CUDA: CUSPARSE_DEPRECATED cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_numericBoost(cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, cuDoubleComplex* boost_val);
Expand Down

0 comments on commit 03747f3

Please sign in to comment.