diff --git a/rocsolver/clients/gtest/CMakeLists.txt b/rocsolver/clients/gtest/CMakeLists.txt index 0104ce31e..c1ad549c0 100755 --- a/rocsolver/clients/gtest/CMakeLists.txt +++ b/rocsolver/clients/gtest/CMakeLists.txt @@ -46,27 +46,34 @@ set( THREADS_PREFER_PTHREAD_FLAG ON ) find_package( Threads REQUIRED ) set(roclapack_test_source + # vector & matrix manipulations lacgv_gtest.cpp laswp_gtest.cpp + # householder reflections larfg_gtest.cpp larf_gtest.cpp larft_gtest.cpp larfb_gtest.cpp - labrd_gtest.cpp - bdsqr_gtest.cpp + # orthonormal/unitary matrices orgxr_ungxr_gtest.cpp orglx_unglx_gtest.cpp ormxr_unmxr_gtest.cpp ormlx_unmlx_gtest.cpp orgbr_ungbr_gtest.cpp ormbr_unmbr_gtest.cpp - getf2_getrf_gtest.cpp + # bidiagonal matrices and svd + labrd_gtest.cpp + bdsqr_gtest.cpp + # triangular factorizations and linear solvers potf2_potrf_gtest.cpp - getrs_gtest.cpp + getf2_getrf_gtest.cpp getri_gtest.cpp + getrs_gtest.cpp + # orthogonal factorizations geqr2_geqrf_gtest.cpp geql2_geqlf_gtest.cpp gelq2_gelqf_gtest.cpp + # bidiagonalization and svd gebd2_gebrd_gtest.cpp gesvd_gtest.cpp ) diff --git a/rocsolver/clients/include/testing_getrs.hpp b/rocsolver/clients/include/testing_getrs.hpp index fa99aec22..9da596080 100644 --- a/rocsolver/clients/include/testing_getrs.hpp +++ b/rocsolver/clients/include/testing_getrs.hpp @@ -235,10 +235,10 @@ void getrs_getPerfData(const rocblas_handle handle, template void testing_getrs(Arguments argus) { rocblas_local_handle handle; - /* Set handle memory size to a large enough value for all tests to pass. - (TODO: A more definitive solution could be implemented once - the handle memory model APIs are enabled in rocsolver)*/ - rocblas_set_device_memory_size(handle, 20000000); + // /* Set handle memory size to a large enough value for all tests to pass. + // (TODO: A more definitive solution could be implemented once + // the handle memory model APIs are enabled in rocsolver)*/ + // rocblas_set_device_memory_size(handle, 20000000); // get arguments rocblas_int m = argus.M; diff --git a/rocsolver/library/src/CMakeLists.txt b/rocsolver/library/src/CMakeLists.txt index cfc6bffd4..bec353203 100755 --- a/rocsolver/library/src/CMakeLists.txt +++ b/rocsolver/library/src/CMakeLists.txt @@ -27,14 +27,16 @@ set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) set( rocsolver_auxiliary_source + # vector & matrix manipulations auxiliary/rocauxiliary_aliases.cpp auxiliary/rocauxiliary_lacgv.cpp auxiliary/rocauxiliary_laswp.cpp + # householder reflections auxiliary/rocauxiliary_larfg.cpp auxiliary/rocauxiliary_larf.cpp auxiliary/rocauxiliary_larft.cpp auxiliary/rocauxiliary_larfb.cpp - auxiliary/rocauxiliary_labrd.cpp + # orthonormal/unitary matrices auxiliary/rocauxiliary_org2r_ung2r.cpp auxiliary/rocauxiliary_orgqr_ungqr.cpp auxiliary/rocauxiliary_orgl2_ungl2.cpp @@ -45,10 +47,13 @@ set( rocsolver_auxiliary_source auxiliary/rocauxiliary_orml2_unml2.cpp auxiliary/rocauxiliary_ormlq_unmlq.cpp auxiliary/rocauxiliary_ormbr_unmbr.cpp + # bidiagonal matrices and svd auxiliary/rocauxiliary_bdsqr.cpp + auxiliary/rocauxiliary_labrd.cpp ) set( rocsolver_lapack_source + # triangular factorizations and linear solvers lapack/roclapack_getf2.cpp lapack/roclapack_getf2_batched.cpp lapack/roclapack_getf2_strided_batched.cpp @@ -68,6 +73,7 @@ set( rocsolver_lapack_source lapack/roclapack_potrf.cpp lapack/roclapack_potrf_batched.cpp lapack/roclapack_potrf_strided_batched.cpp + # orthogonal factorizations lapack/roclapack_geqr2.cpp lapack/roclapack_geqr2_batched.cpp lapack/roclapack_geqr2_strided_batched.cpp @@ -87,6 +93,7 @@ set( rocsolver_lapack_source lapack/roclapack_gelqf.cpp lapack/roclapack_gelqf_batched.cpp lapack/roclapack_gelqf_strided_batched.cpp + # bidiagonalization and svd lapack/roclapack_gebd2.cpp lapack/roclapack_gebd2_batched.cpp lapack/roclapack_gebd2_strided_batched.cpp @@ -100,7 +107,6 @@ set( rocsolver_lapack_source set( auxiliaries buildinfo.cpp - rocblas.cpp ) prepend_path( ".." rocsolver_headers_public relative_rocsolver_headers_public ) diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_bdsqr.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_bdsqr.cpp index 4e30da1d8..6c2c52c4e 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_bdsqr.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_bdsqr.cpp @@ -61,36 +61,40 @@ rocsolver_bdsqr_impl(rocblas_handle handle, const rocblas_fill uplo, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sbdsqr( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, float *D, - float *E, float *V, const rocblas_int ldv, float *U, const rocblas_int ldu, - float *C, const rocblas_int ldc, rocblas_int *info) { +rocblas_status rocsolver_sbdsqr(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, const rocblas_int nv, + const rocblas_int nu, const rocblas_int nc, + float *D, float *E, float *V, + const rocblas_int ldv, float *U, + const rocblas_int ldu, float *C, + const rocblas_int ldc, rocblas_int *info) { return rocsolver_bdsqr_impl(handle, uplo, n, nv, nu, nc, D, E, V, ldv, U, ldu, C, ldc, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dbdsqr( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, double *D, - double *E, double *V, const rocblas_int ldv, double *U, - const rocblas_int ldu, double *C, const rocblas_int ldc, - rocblas_int *info) { +rocblas_status rocsolver_dbdsqr(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, const rocblas_int nv, + const rocblas_int nu, const rocblas_int nc, + double *D, double *E, double *V, + const rocblas_int ldv, double *U, + const rocblas_int ldu, double *C, + const rocblas_int ldc, rocblas_int *info) { return rocsolver_bdsqr_impl(handle, uplo, n, nv, nu, nc, D, E, V, ldv, U, ldu, C, ldc, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cbdsqr( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, float *D, - float *E, rocblas_float_complex *V, const rocblas_int ldv, - rocblas_float_complex *U, const rocblas_int ldu, rocblas_float_complex *C, - const rocblas_int ldc, rocblas_int *info) { +rocblas_status rocsolver_cbdsqr(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, const rocblas_int nv, + const rocblas_int nu, const rocblas_int nc, + float *D, float *E, rocblas_float_complex *V, + const rocblas_int ldv, rocblas_float_complex *U, + const rocblas_int ldu, rocblas_float_complex *C, + const rocblas_int ldc, rocblas_int *info) { return rocsolver_bdsqr_impl( handle, uplo, n, nv, nu, nc, D, E, V, ldv, U, ldu, C, ldc, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zbdsqr( +rocblas_status rocsolver_zbdsqr( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, double *D, double *E, rocblas_double_complex *V, const rocblas_int ldv, diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_labrd.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_labrd.cpp index 35a8628d2..1c5d409f2 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_labrd.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_labrd.cpp @@ -76,35 +76,39 @@ rocsolver_labrd_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_slabrd( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *D, float *E, - float *tauq, float *taup, float *X, const rocblas_int ldx, float *Y, - const rocblas_int ldy) { +rocblas_status rocsolver_slabrd(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + float *A, const rocblas_int lda, float *D, + float *E, float *tauq, float *taup, float *X, + const rocblas_int ldx, float *Y, + const rocblas_int ldy) { return rocsolver_labrd_impl(handle, m, n, k, A, lda, D, E, tauq, taup, X, ldx, Y, ldy); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dlabrd( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *D, double *E, - double *tauq, double *taup, double *X, const rocblas_int ldx, double *Y, - const rocblas_int ldy) { +rocblas_status rocsolver_dlabrd(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + double *A, const rocblas_int lda, double *D, + double *E, double *tauq, double *taup, + double *X, const rocblas_int ldx, double *Y, + const rocblas_int ldy) { return rocsolver_labrd_impl(handle, m, n, k, A, lda, D, E, tauq, taup, X, ldx, Y, ldy); } -ROCSOLVER_EXPORT rocblas_status rocsolver_clabrd( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - float *D, float *E, rocblas_float_complex *tauq, - rocblas_float_complex *taup, rocblas_float_complex *X, - const rocblas_int ldx, rocblas_float_complex *Y, const rocblas_int ldy) { +rocblas_status rocsolver_clabrd(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_float_complex *A, const rocblas_int lda, + float *D, float *E, rocblas_float_complex *tauq, + rocblas_float_complex *taup, + rocblas_float_complex *X, const rocblas_int ldx, + rocblas_float_complex *Y, + const rocblas_int ldy) { return rocsolver_labrd_impl( handle, m, n, k, A, lda, D, E, tauq, taup, X, ldx, Y, ldy); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zlabrd( +rocblas_status rocsolver_zlabrd( rocblas_handle handle, const rocblas_int m, const rocblas_int n, const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, double *D, double *E, rocblas_double_complex *tauq, diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_lacgv.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_lacgv.cpp index 7a25f85f8..2a95a419a 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_lacgv.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_lacgv.cpp @@ -38,17 +38,15 @@ rocblas_status rocsolver_lacgv_impl(rocblas_handle handle, const rocblas_int n, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_clacgv(rocblas_handle handle, - const rocblas_int n, - rocblas_float_complex *x, - const rocblas_int incx) { +rocblas_status rocsolver_clacgv(rocblas_handle handle, const rocblas_int n, + rocblas_float_complex *x, + const rocblas_int incx) { return rocsolver_lacgv_impl(handle, n, x, incx); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zlacgv(rocblas_handle handle, - const rocblas_int n, - rocblas_double_complex *x, - const rocblas_int incx) { +rocblas_status rocsolver_zlacgv(rocblas_handle handle, const rocblas_int n, + rocblas_double_complex *x, + const rocblas_int incx) { return rocsolver_lacgv_impl(handle, n, x, incx); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_larf.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_larf.cpp index e373fb9e9..443610129 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_larf.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_larf.cpp @@ -66,35 +66,40 @@ rocblas_status rocsolver_larf_impl(rocblas_handle handle, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_slarf( - rocblas_handle handle, const rocblas_side side, const rocblas_int m, - const rocblas_int n, float *x, const rocblas_int incx, const float *alpha, - float *A, const rocblas_int lda) { +rocblas_status rocsolver_slarf(rocblas_handle handle, const rocblas_side side, + const rocblas_int m, const rocblas_int n, + float *x, const rocblas_int incx, + const float *alpha, float *A, + const rocblas_int lda) { return rocsolver_larf_impl(handle, side, m, n, x, incx, alpha, A, lda); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dlarf( - rocblas_handle handle, const rocblas_side side, const rocblas_int m, - const rocblas_int n, double *x, const rocblas_int incx, const double *alpha, - double *A, const rocblas_int lda) { +rocblas_status rocsolver_dlarf(rocblas_handle handle, const rocblas_side side, + const rocblas_int m, const rocblas_int n, + double *x, const rocblas_int incx, + const double *alpha, double *A, + const rocblas_int lda) { return rocsolver_larf_impl(handle, side, m, n, x, incx, alpha, A, lda); } -ROCSOLVER_EXPORT rocblas_status rocsolver_clarf( - rocblas_handle handle, const rocblas_side side, const rocblas_int m, - const rocblas_int n, rocblas_float_complex *x, const rocblas_int incx, - const rocblas_float_complex *alpha, rocblas_float_complex *A, - const rocblas_int lda) { +rocblas_status rocsolver_clarf(rocblas_handle handle, const rocblas_side side, + const rocblas_int m, const rocblas_int n, + rocblas_float_complex *x, const rocblas_int incx, + const rocblas_float_complex *alpha, + rocblas_float_complex *A, + const rocblas_int lda) { return rocsolver_larf_impl(handle, side, m, n, x, incx, alpha, A, lda); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zlarf( - rocblas_handle handle, const rocblas_side side, const rocblas_int m, - const rocblas_int n, rocblas_double_complex *x, const rocblas_int incx, - const rocblas_double_complex *alpha, rocblas_double_complex *A, - const rocblas_int lda) { +rocblas_status rocsolver_zlarf(rocblas_handle handle, const rocblas_side side, + const rocblas_int m, const rocblas_int n, + rocblas_double_complex *x, + const rocblas_int incx, + const rocblas_double_complex *alpha, + rocblas_double_complex *A, + const rocblas_int lda) { return rocsolver_larf_impl(handle, side, m, n, x, incx, alpha, A, lda); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_larfb.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_larfb.cpp index b082c7aa3..2b1bbe76d 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_larfb.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_larfb.cpp @@ -27,29 +27,34 @@ rocsolver_larfb_impl(rocblas_handle handle, const rocblas_side side, rocblas_stride stridea = 0; rocblas_stride stridef = 0; rocblas_int batch_count = 1; + rocblas_int shiftV = 0; + rocblas_int shiftF = 0; + rocblas_int shiftA = 0; // memory managment size_t size_1; // size of workspace size_t size_2; // size of array of pointers to workspace + size_t size_3; // size of worksapce for TRMM calls rocsolver_larfb_getMemorySize(side, m, n, k, batch_count, &size_1, - &size_2); + &size_2, &size_3); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *work, *workArr; + void *work, *workArr, *workTrmm; hipMalloc(&work, size_1); hipMalloc(&workArr, size_2); - if ((size_1 && !work) || (size_2 && !workArr)) + hipMalloc(&workTrmm, size_3); + if ((size_1 && !work) || (size_2 && !workArr) || (size_3 && !workTrmm)) return rocblas_status_memory_error; // execution rocblas_status status = rocsolver_larfb_template( - handle, side, trans, direct, storev, m, n, k, V, 0, // shifted 0 entries - ldv, stridev, F, 0, // shifted 0 entries - ldf, stridef, A, 0, // shifted 0 entries - lda, stridea, batch_count, (T *)work, (T **)workArr); + handle, side, trans, direct, storev, m, n, k, V, shiftV, ldv, stridev, F, + shiftF, ldf, stridef, A, shiftA, lda, stridea, batch_count, (T *)work, + (T **)workArr, (T *)workTrmm); hipFree(work); hipFree(workArr); + hipFree(workTrmm); return status; } @@ -61,44 +66,49 @@ rocsolver_larfb_impl(rocblas_handle handle, const rocblas_side side, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_slarfb( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_direct direct, - const rocblas_storev storev, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *V, const rocblas_int ldv, float *T, - const rocblas_int ldt, float *A, const rocblas_int lda) { +rocblas_status +rocsolver_slarfb(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, const rocblas_direct direct, + const rocblas_storev storev, const rocblas_int m, + const rocblas_int n, const rocblas_int k, float *V, + const rocblas_int ldv, float *T, const rocblas_int ldt, + float *A, const rocblas_int lda) { return rocsolver_larfb_impl(handle, side, trans, direct, storev, m, n, k, V, ldv, T, ldt, A, lda); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dlarfb( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_direct direct, - const rocblas_storev storev, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *V, const rocblas_int ldv, double *T, - const rocblas_int ldt, double *A, const rocblas_int lda) { +rocblas_status +rocsolver_dlarfb(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, const rocblas_direct direct, + const rocblas_storev storev, const rocblas_int m, + const rocblas_int n, const rocblas_int k, double *V, + const rocblas_int ldv, double *T, const rocblas_int ldt, + double *A, const rocblas_int lda) { return rocsolver_larfb_impl(handle, side, trans, direct, storev, m, n, k, V, ldv, T, ldt, A, lda); } -ROCSOLVER_EXPORT rocblas_status rocsolver_clarfb( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_direct direct, - const rocblas_storev storev, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *V, const rocblas_int ldv, - rocblas_float_complex *T, const rocblas_int ldt, rocblas_float_complex *A, - const rocblas_int lda) { +rocblas_status rocsolver_clarfb(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_direct direct, + const rocblas_storev storev, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_float_complex *V, + const rocblas_int ldv, rocblas_float_complex *T, + const rocblas_int ldt, rocblas_float_complex *A, + const rocblas_int lda) { return rocsolver_larfb_impl( handle, side, trans, direct, storev, m, n, k, V, ldv, T, ldt, A, lda); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zlarfb( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_direct direct, - const rocblas_storev storev, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *V, const rocblas_int ldv, - rocblas_double_complex *T, const rocblas_int ldt, rocblas_double_complex *A, - const rocblas_int lda) { +rocblas_status +rocsolver_zlarfb(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, const rocblas_direct direct, + const rocblas_storev storev, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_double_complex *V, const rocblas_int ldv, + rocblas_double_complex *T, const rocblas_int ldt, + rocblas_double_complex *A, const rocblas_int lda) { return rocsolver_larfb_impl( handle, side, trans, direct, storev, m, n, k, V, ldv, T, ldt, A, lda); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_larfb.hpp b/rocsolver/library/src/auxiliary/rocauxiliary_larfb.hpp index 73e85738d..c8f195d4c 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_larfb.hpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_larfb.hpp @@ -57,7 +57,8 @@ template void rocsolver_larfb_getMemorySize(const rocblas_side side, const rocblas_int m, const rocblas_int n, const rocblas_int k, const rocblas_int batch_count, - size_t *size_1, size_t *size_2) { + size_t *size_1, size_t *size_2, + size_t *size_3) { // size of workspace if (side == rocblas_side_left) *size_1 = n; @@ -70,19 +71,9 @@ void rocsolver_larfb_getMemorySize(const rocblas_side side, const rocblas_int m, *size_2 = sizeof(T *) * batch_count; else *size_2 = 0; -} -template -void rocsolver_larfb_getMemorySize(const rocblas_side side, const rocblas_int m, - const rocblas_int n, const rocblas_int k, - const rocblas_int batch_count, - size_t *size) { - // size of workspace - if (side == rocblas_side_left) - *size = n; - else - *size = m; - *size *= sizeof(T) * k * batch_count; + // size of workspace for TRMM calls + *size_3 = 2 * ROCBLAS_TRMM_NB * ROCBLAS_TRMM_NB * sizeof(T) * batch_count; } template @@ -134,7 +125,7 @@ rocblas_status rocsolver_larfb_template( const rocblas_int ldf, const rocblas_stride strideF, U A, const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, const rocblas_int batch_count, T *work, - T **workArr) { + T **workArr, T *workTrmm) { // quick return if (m == 0 || n == 0 || batch_count == 0) return rocblas_status_success; @@ -152,15 +143,6 @@ rocblas_status rocsolver_larfb_template( T minone = -1; T one = 1; -// **** THIS SYNCHRONIZATION WILL BE REQUIRED UNTIL -// TRMM_BATCH FUNCTIONALITY IS ENABLED. **** -#ifdef batched - T *VV[batch_count]; - hipMemcpy(VV, V, batch_count * sizeof(T *), hipMemcpyDeviceToHost); -#else - T *VV = V; -#endif - // determine the side, size of workspace // and whether V is trapezoidal bool trap; @@ -237,9 +219,6 @@ rocblas_status rocsolver_larfb_template( rocblas_stride strideW = rocblas_stride(ldw) * order; uploT = (forward ? rocblas_fill_upper : rocblas_fill_lower); - // **** TRMM_BATCH IS EXECUTED IN A FOR-LOOP UNTIL - // FUNCITONALITY IS ENABLED **** - // copy A1 to work rocblas_int blocksx = (order - 1) / 32 + 1; rocblas_int blocksy = (ldw - 1) / 32 + 1; @@ -249,11 +228,10 @@ rocblas_status rocsolver_larfb_template( // compute: V1' * A1 // or A1 * V1 - for (int b = 0; b < batch_count; ++b) { - Vp = load_ptr_batch(VV, b, offsetV1, strideV); - rocblas_trmm(handle, side, uploV, transp, rocblas_diagonal_unit, ldw, order, - &one, Vp, ldv, (work + b * strideW), ldw); - } + rocblasCall_trmm( + handle, side, uploV, transp, rocblas_diagonal_unit, ldw, order, &one, V, + offsetV1, ldv, strideV, work, 0, ldw, strideW, batch_count, workTrmm, + workArr); // compute: V1' * A1 + V2' * A2 // or A1 * V1 + A2 * V2 @@ -272,11 +250,10 @@ rocblas_status rocsolver_larfb_template( // compute: trans(T) * (V1' * A1 + V2' * A2) // or (A1 * V1 + A2 * V2) * trans(T) - for (int b = 0; b < batch_count; ++b) { - Fp = load_ptr_batch(F, b, shiftF, strideF); - rocblas_trmm(handle, side, uploT, transt, rocblas_diagonal_non_unit, ldw, - order, &one, Fp, ldf, (work + b * strideW), ldw); - } + rocblasCall_trmm( + handle, side, uploT, transt, rocblas_diagonal_non_unit, ldw, order, &one, + F, shiftF, ldf, strideF, work, 0, ldw, strideW, batch_count, workTrmm, + workArr); // compute: A2 - V2 * trans(T) * (V1' * A1 + V2' * A2) // or A2 - (A1 * V1 + A2 * V2) * trans(T) * V2' @@ -300,11 +277,10 @@ rocblas_status rocsolver_larfb_template( // compute: V1 * trans(T) * (V1' * A1 + V2' * A2) // or (A1 * V1 + A2 * V2) * trans(T) * V1' - for (int b = 0; b < batch_count; ++b) { - Vp = load_ptr_batch(VV, b, offsetV1, strideV); - rocblas_trmm(handle, side, uploV, transp, rocblas_diagonal_unit, ldw, order, - &one, Vp, ldv, (work + b * strideW), ldw); - } + rocblasCall_trmm( + handle, side, uploV, transp, rocblas_diagonal_unit, ldw, order, &one, V, + offsetV1, ldv, strideV, work, 0, ldw, strideW, batch_count, workTrmm, + workArr); // compute: A1 - V1 * trans(T) * (V1' * A1 + V2' * A2) // or A1 - (A1 * V1 + A2 * V2) * trans(T) * V1' diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_larfg.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_larfg.cpp index 8be97d21d..892b7ded0 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_larfg.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_larfg.cpp @@ -53,38 +53,32 @@ rocblas_status rocsolver_larfg_impl(rocblas_handle handle, const rocblas_int n, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_slarfg(rocblas_handle handle, - const rocblas_int n, - float *alpha, float *x, - const rocblas_int incx, - float *tau) { +rocblas_status rocsolver_slarfg(rocblas_handle handle, const rocblas_int n, + float *alpha, float *x, const rocblas_int incx, + float *tau) { return rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dlarfg(rocblas_handle handle, - const rocblas_int n, - double *alpha, double *x, - const rocblas_int incx, - double *tau) { +rocblas_status rocsolver_dlarfg(rocblas_handle handle, const rocblas_int n, + double *alpha, double *x, + const rocblas_int incx, double *tau) { return rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); } -ROCSOLVER_EXPORT rocblas_status rocsolver_clarfg(rocblas_handle handle, - const rocblas_int n, - rocblas_float_complex *alpha, - rocblas_float_complex *x, - const rocblas_int incx, - rocblas_float_complex *tau) { +rocblas_status rocsolver_clarfg(rocblas_handle handle, const rocblas_int n, + rocblas_float_complex *alpha, + rocblas_float_complex *x, + const rocblas_int incx, + rocblas_float_complex *tau) { return rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zlarfg(rocblas_handle handle, - const rocblas_int n, - rocblas_double_complex *alpha, - rocblas_double_complex *x, - const rocblas_int incx, - rocblas_double_complex *tau) { +rocblas_status rocsolver_zlarfg(rocblas_handle handle, const rocblas_int n, + rocblas_double_complex *alpha, + rocblas_double_complex *x, + const rocblas_int incx, + rocblas_double_complex *tau) { return rocsolver_larfg_impl(handle, n, alpha, x, incx, tau); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_larft.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_larft.cpp index d324c4de6..f49f228d6 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_larft.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_larft.cpp @@ -66,34 +66,37 @@ rocsolver_larft_impl(rocblas_handle handle, const rocblas_direct direct, extern "C" { -ROCSOLVER_EXPORT rocblas_status -rocsolver_slarft(rocblas_handle handle, const rocblas_direct direct, - const rocblas_storev storev, const rocblas_int n, - const rocblas_int k, float *V, const rocblas_int ldv, - float *tau, float *T, const rocblas_int ldt) { +rocblas_status rocsolver_slarft(rocblas_handle handle, + const rocblas_direct direct, + const rocblas_storev storev, + const rocblas_int n, const rocblas_int k, + float *V, const rocblas_int ldv, float *tau, + float *T, const rocblas_int ldt) { return rocsolver_larft_impl(handle, direct, storev, n, k, V, ldv, tau, T, ldt); } -ROCSOLVER_EXPORT rocblas_status -rocsolver_dlarft(rocblas_handle handle, const rocblas_direct direct, - const rocblas_storev storev, const rocblas_int n, - const rocblas_int k, double *V, const rocblas_int ldv, - double *tau, double *T, const rocblas_int ldt) { +rocblas_status rocsolver_dlarft(rocblas_handle handle, + const rocblas_direct direct, + const rocblas_storev storev, + const rocblas_int n, const rocblas_int k, + double *V, const rocblas_int ldv, double *tau, + double *T, const rocblas_int ldt) { return rocsolver_larft_impl(handle, direct, storev, n, k, V, ldv, tau, T, ldt); } -ROCSOLVER_EXPORT rocblas_status rocsolver_clarft( - rocblas_handle handle, const rocblas_direct direct, - const rocblas_storev storev, const rocblas_int n, const rocblas_int k, - rocblas_float_complex *V, const rocblas_int ldv, rocblas_float_complex *tau, - rocblas_float_complex *T, const rocblas_int ldt) { +rocblas_status +rocsolver_clarft(rocblas_handle handle, const rocblas_direct direct, + const rocblas_storev storev, const rocblas_int n, + const rocblas_int k, rocblas_float_complex *V, + const rocblas_int ldv, rocblas_float_complex *tau, + rocblas_float_complex *T, const rocblas_int ldt) { return rocsolver_larft_impl(handle, direct, storev, n, k, V, ldv, tau, T, ldt); } -ROCSOLVER_EXPORT rocblas_status +rocblas_status rocsolver_zlarft(rocblas_handle handle, const rocblas_direct direct, const rocblas_storev storev, const rocblas_int n, const rocblas_int k, rocblas_double_complex *V, diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_laswp.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_laswp.cpp index 78c6c6818..ee7be4ade 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_laswp.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_laswp.cpp @@ -45,32 +45,36 @@ rocblas_status rocsolver_laswp_impl(rocblas_handle handle, const rocblas_int n, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_slaswp( - rocblas_handle handle, const rocblas_int n, float *A, const rocblas_int lda, - const rocblas_int k1, const rocblas_int k2, const rocblas_int *ipiv, - const rocblas_int incx) { +rocblas_status rocsolver_slaswp(rocblas_handle handle, const rocblas_int n, + float *A, const rocblas_int lda, + const rocblas_int k1, const rocblas_int k2, + const rocblas_int *ipiv, + const rocblas_int incx) { return rocsolver_laswp_impl(handle, n, A, lda, k1, k2, ipiv, incx); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dlaswp( - rocblas_handle handle, const rocblas_int n, double *A, - const rocblas_int lda, const rocblas_int k1, const rocblas_int k2, - const rocblas_int *ipiv, const rocblas_int incx) { +rocblas_status rocsolver_dlaswp(rocblas_handle handle, const rocblas_int n, + double *A, const rocblas_int lda, + const rocblas_int k1, const rocblas_int k2, + const rocblas_int *ipiv, + const rocblas_int incx) { return rocsolver_laswp_impl(handle, n, A, lda, k1, k2, ipiv, incx); } -ROCSOLVER_EXPORT rocblas_status rocsolver_claswp( - rocblas_handle handle, const rocblas_int n, rocblas_float_complex *A, - const rocblas_int lda, const rocblas_int k1, const rocblas_int k2, - const rocblas_int *ipiv, const rocblas_int incx) { +rocblas_status rocsolver_claswp(rocblas_handle handle, const rocblas_int n, + rocblas_float_complex *A, const rocblas_int lda, + const rocblas_int k1, const rocblas_int k2, + const rocblas_int *ipiv, + const rocblas_int incx) { return rocsolver_laswp_impl(handle, n, A, lda, k1, k2, ipiv, incx); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zlaswp( - rocblas_handle handle, const rocblas_int n, rocblas_double_complex *A, - const rocblas_int lda, const rocblas_int k1, const rocblas_int k2, - const rocblas_int *ipiv, const rocblas_int incx) { +rocblas_status rocsolver_zlaswp(rocblas_handle handle, const rocblas_int n, + rocblas_double_complex *A, + const rocblas_int lda, const rocblas_int k1, + const rocblas_int k2, const rocblas_int *ipiv, + const rocblas_int incx) { return rocsolver_laswp_impl(handle, n, A, lda, k1, k2, ipiv, incx); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_org2r_ung2r.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_org2r_ung2r.cpp index ba8c21b5c..3abf8b4fc 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_org2r_ung2r.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_org2r_ung2r.cpp @@ -63,30 +63,32 @@ rocsolver_org2r_ung2r_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sorg2r( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv) { +rocblas_status rocsolver_sorg2r(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + float *A, const rocblas_int lda, float *ipiv) { return rocsolver_org2r_ung2r_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dorg2r( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv) { +rocblas_status rocsolver_dorg2r(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + double *A, const rocblas_int lda, + double *ipiv) { return rocsolver_org2r_ung2r_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cung2r( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cung2r(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_float_complex *A, const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_org2r_ung2r_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zung2r( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zung2r(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_org2r_ung2r_impl(handle, m, n, k, A, lda, ipiv); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.cpp index f793f5009..e2b5d647d 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.cpp @@ -29,17 +29,20 @@ rocsolver_orgbr_ungbr_impl(rocblas_handle handle, const rocblas_storev storev, size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace size_t size_4; // size of temporary array for triangular factor - rocsolver_orgbr_ungbr_getMemorySize( - storev, m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4); + size_t size_5; // worksapce for TRMM calls + rocsolver_orgbr_ungbr_getMemorySize(storev, m, n, k, batch_count, + &size_1, &size_2, &size_3, + &size_4, &size_5); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *workArr, *trfact; + void *scalars, *work, *workArr, *trfact, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&workArr, size_3); hipMalloc(&trfact, size_4); + hipMalloc(&workTrmm, size_5); if (!scalars || (size_2 && !work) || (size_3 && !workArr) || - (size_4 && !trfact)) + (size_4 && !trfact) || (size_5 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -51,12 +54,13 @@ rocsolver_orgbr_ungbr_impl(rocblas_handle handle, const rocblas_storev storev, rocblas_status status = rocsolver_orgbr_ungbr_template( handle, storev, m, n, k, A, 0, // shifted 0 entries lda, strideA, ipiv, strideP, batch_count, (T *)scalars, (T *)work, - (T **)workArr, (T *)trfact); + (T **)workArr, (T *)trfact, (T *)workTrmm); hipFree(scalars); hipFree(work); hipFree(workArr); hipFree(trfact); + hipFree(workTrmm); return status; } @@ -68,34 +72,40 @@ rocsolver_orgbr_ungbr_impl(rocblas_handle handle, const rocblas_storev storev, extern "C" { -ROCSOLVER_EXPORT rocblas_status -rocsolver_sorgbr(rocblas_handle handle, const rocblas_storev storev, - const rocblas_int m, const rocblas_int n, const rocblas_int k, - float *A, const rocblas_int lda, float *ipiv) { +rocblas_status rocsolver_sorgbr(rocblas_handle handle, + const rocblas_storev storev, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, float *A, + const rocblas_int lda, float *ipiv) { return rocsolver_orgbr_ungbr_impl(handle, storev, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status -rocsolver_dorgbr(rocblas_handle handle, const rocblas_storev storev, - const rocblas_int m, const rocblas_int n, const rocblas_int k, - double *A, const rocblas_int lda, double *ipiv) { +rocblas_status rocsolver_dorgbr(rocblas_handle handle, + const rocblas_storev storev, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, double *A, + const rocblas_int lda, double *ipiv) { return rocsolver_orgbr_ungbr_impl(handle, storev, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cungbr( - rocblas_handle handle, const rocblas_storev storev, const rocblas_int m, - const rocblas_int n, const rocblas_int k, rocblas_float_complex *A, - const rocblas_int lda, rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cungbr(rocblas_handle handle, + const rocblas_storev storev, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_orgbr_ungbr_impl(handle, storev, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zungbr( - rocblas_handle handle, const rocblas_storev storev, const rocblas_int m, - const rocblas_int n, const rocblas_int k, rocblas_double_complex *A, - const rocblas_int lda, rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zungbr(rocblas_handle handle, + const rocblas_storev storev, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_orgbr_ungbr_impl(handle, storev, m, n, k, A, lda, ipiv); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.hpp b/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.hpp index b942a883d..d07fbb96d 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.hpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orgbr_ungbr.hpp @@ -88,27 +88,29 @@ template void rocsolver_orgbr_ungbr_getMemorySize( const rocblas_storev storev, const rocblas_int m, const rocblas_int n, const rocblas_int k, const rocblas_int batch_count, size_t *size_1, - size_t *size_2, size_t *size_3, size_t *size_4) { + size_t *size_2, size_t *size_3, size_t *size_4, size_t *size_5) { if (storev == rocblas_column_wise) { if (m >= k) { rocsolver_orgqr_ungqr_getMemorySize( - m, n, k, batch_count, size_1, size_2, size_3, size_4); + m, n, k, batch_count, size_1, size_2, size_3, size_4, size_5); } else { size_t s1 = sizeof(T) * batch_count * (m - 1) * m / 2; size_t s2; - rocsolver_orgqr_ungqr_getMemorySize( - m - 1, m - 1, m - 1, batch_count, size_1, &s2, size_3, size_4); + rocsolver_orgqr_ungqr_getMemorySize(m - 1, m - 1, m - 1, + batch_count, size_1, &s2, + size_3, size_4, size_5); *size_2 = max(s1, s2); } } else { if (n > k) { rocsolver_orglq_unglq_getMemorySize( - m, n, k, batch_count, size_1, size_2, size_3, size_4); + m, n, k, batch_count, size_1, size_2, size_3, size_4, size_5); } else { size_t s1 = sizeof(T) * batch_count * (n - 1) * n / 2; size_t s2; - rocsolver_orglq_unglq_getMemorySize( - n - 1, n - 1, n - 1, batch_count, size_1, &s2, size_3, size_4); + rocsolver_orglq_unglq_getMemorySize(n - 1, n - 1, n - 1, + batch_count, size_1, &s2, + size_3, size_4, size_5); *size_2 = max(s1, s2); } } @@ -148,7 +150,7 @@ rocblas_status rocsolver_orgbr_ungbr_template( const rocblas_int n, const rocblas_int k, U A, const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, T *ipiv, const rocblas_stride strideP, const rocblas_int batch_count, T *scalars, - T *work, T **workArr, T *trfact) { + T *work, T **workArr, T *trfact, T *workTrmm) { // quick return if (!n || !m || !batch_count) return rocblas_status_success; @@ -162,7 +164,7 @@ rocblas_status rocsolver_orgbr_ungbr_template( if (m >= k) { rocsolver_orgqr_ungqr_template( handle, m, n, k, A, shiftA, lda, strideA, ipiv, strideP, batch_count, - scalars, work, workArr, trfact); + scalars, work, workArr, trfact, workTrmm); } else { // shift the householder vectors provided by gebrd as they come below the // first subdiagonal @@ -184,7 +186,8 @@ rocblas_status rocsolver_orgbr_ungbr_template( // result rocsolver_orgqr_ungqr_template( handle, m - 1, m - 1, m - 1, A, shiftA + idx2D(1, 1, lda), lda, - strideA, ipiv, strideP, batch_count, scalars, work, workArr, trfact); + strideA, ipiv, strideP, batch_count, scalars, work, workArr, trfact, + workTrmm); } } @@ -194,7 +197,7 @@ rocblas_status rocsolver_orgbr_ungbr_template( if (n > k) { rocsolver_orglq_unglq_template( handle, m, n, k, A, shiftA, lda, strideA, ipiv, strideP, batch_count, - scalars, work, workArr, trfact); + scalars, work, workArr, trfact, workTrmm); } else { // shift the householder vectors provided by gebrd as they come above the // first superdiagonal @@ -216,7 +219,8 @@ rocblas_status rocsolver_orgbr_ungbr_template( // result rocsolver_orglq_unglq_template( handle, n - 1, n - 1, n - 1, A, shiftA + idx2D(1, 1, lda), lda, - strideA, ipiv, strideP, batch_count, scalars, work, workArr, trfact); + strideA, ipiv, strideP, batch_count, scalars, work, workArr, trfact, + workTrmm); } } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orgl2_ungl2.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_orgl2_ungl2.cpp index 4b1794098..c68a04c4b 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orgl2_ungl2.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orgl2_ungl2.cpp @@ -63,30 +63,32 @@ rocsolver_orgl2_ungl2_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sorgl2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv) { +rocblas_status rocsolver_sorgl2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + float *A, const rocblas_int lda, float *ipiv) { return rocsolver_orgl2_ungl2_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dorgl2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv) { +rocblas_status rocsolver_dorgl2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + double *A, const rocblas_int lda, + double *ipiv) { return rocsolver_orgl2_ungl2_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cungl2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cungl2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_float_complex *A, const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_orgl2_ungl2_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zungl2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zungl2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_orgl2_ungl2_impl(handle, m, n, k, A, lda, ipiv); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.cpp index bc72f16c2..9728e7c9d 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.cpp @@ -28,17 +28,19 @@ rocsolver_orglq_unglq_impl(rocblas_handle handle, const rocblas_int m, size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace size_t size_4; // size of temporary array for triangular factor - rocsolver_orglq_unglq_getMemorySize(m, n, k, batch_count, &size_1, - &size_2, &size_3, &size_4); + size_t size_5; // size of workspace for TRMM calls + rocsolver_orglq_unglq_getMemorySize( + m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *workArr, *trfact; + void *scalars, *work, *workArr, *trfact, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&workArr, size_3); hipMalloc(&trfact, size_4); + hipMalloc(&workTrmm, size_5); if (!scalars || (size_2 && !work) || (size_3 && !workArr) || - (size_4 && !trfact)) + (size_4 && !trfact) || (size_5 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -50,12 +52,13 @@ rocsolver_orglq_unglq_impl(rocblas_handle handle, const rocblas_int m, rocblas_status status = rocsolver_orglq_unglq_template( handle, m, n, k, A, 0, // shifted 0 entries lda, strideA, ipiv, strideP, batch_count, (T *)scalars, (T *)work, - (T **)workArr, (T *)trfact); + (T **)workArr, (T *)trfact, (T *)workTrmm); hipFree(scalars); hipFree(work); hipFree(workArr); hipFree(trfact); + hipFree(workTrmm); return status; } @@ -67,30 +70,32 @@ rocsolver_orglq_unglq_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sorglq( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv) { +rocblas_status rocsolver_sorglq(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + float *A, const rocblas_int lda, float *ipiv) { return rocsolver_orglq_unglq_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dorglq( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv) { +rocblas_status rocsolver_dorglq(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + double *A, const rocblas_int lda, + double *ipiv) { return rocsolver_orglq_unglq_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cunglq( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cunglq(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_float_complex *A, const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_orglq_unglq_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zunglq( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zunglq(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_orglq_unglq_impl(handle, m, n, k, A, lda, ipiv); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.hpp b/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.hpp index c05fc0f0b..33782a673 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.hpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orglq_unglq.hpp @@ -34,18 +34,17 @@ __global__ void set_zero_row(const rocblas_int m, const rocblas_int kk, U A, } template -void rocsolver_orglq_unglq_getMemorySize(const rocblas_int m, - const rocblas_int n, - const rocblas_int k, - const rocblas_int batch_count, - size_t *size_1, size_t *size_2, - size_t *size_3, size_t *size_4) { - size_t s1, s2, s3; +void rocsolver_orglq_unglq_getMemorySize( + const rocblas_int m, const rocblas_int n, const rocblas_int k, + const rocblas_int batch_count, size_t *size_1, size_t *size_2, + size_t *size_3, size_t *size_4, size_t *size_5) { + size_t s1, s2, s3, unused; rocsolver_orgl2_ungl2_getMemorySize(m, n, batch_count, size_1, size_2, size_3); if (k <= GEQRF_GEQR2_SWITCHSIZE) { *size_4 = 0; + *size_5 = 0; } else { // size of workspace // maximum of what is needed by org2r, larft and larfb @@ -55,8 +54,8 @@ void rocsolver_orglq_unglq_getMemorySize(const rocblas_int m, rocsolver_orgl2_ungl2_getMemorySize(max(m - kk, jb), n, batch_count, &s1); rocsolver_larft_getMemorySize(jb, batch_count, &s2); - rocsolver_larfb_getMemorySize(rocblas_side_left, m - jb, n, jb, - batch_count, &s3); + rocsolver_larfb_getMemorySize( + rocblas_side_left, m - jb, n, jb, batch_count, &s3, &unused, size_5); *size_2 = max(max(s1, s2), s3); @@ -70,8 +69,8 @@ rocblas_status rocsolver_orglq_unglq_template( rocblas_handle handle, const rocblas_int m, const rocblas_int n, const rocblas_int k, U A, const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, T *ipiv, const rocblas_stride strideP, - const rocblas_int batch_count, T *scalars, T *work, T **workArr, - T *trfact) { + const rocblas_int batch_count, T *scalars, T *work, T **workArr, T *trfact, + T *workTrmm) { // quick return if (!n || !m || !batch_count) return rocblas_status_success; @@ -126,7 +125,7 @@ rocblas_status rocsolver_orglq_unglq_template( rocblas_forward_direction, rocblas_row_wise, m - j - jb, n - j, jb, A, shiftA + idx2D(j, j, lda), lda, strideA, trfact, 0, ldw, strideW, A, shiftA + idx2D(j + jb, j, lda), lda, strideA, batch_count, work, - workArr); + workArr, workTrmm); } // now compute the current block and set to zero diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.cpp index f7f7407ad..3d4cd23d0 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.cpp @@ -28,17 +28,19 @@ rocsolver_orgqr_ungqr_impl(rocblas_handle handle, const rocblas_int m, size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace size_t size_4; // size of temporary array for triangular factor - rocsolver_orgqr_ungqr_getMemorySize(m, n, k, batch_count, &size_1, - &size_2, &size_3, &size_4); + size_t size_5; // size of worksapce for TRMM calls + rocsolver_orgqr_ungqr_getMemorySize( + m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *workArr, *trfact; + void *scalars, *work, *workArr, *trfact, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&workArr, size_3); hipMalloc(&trfact, size_4); + hipMalloc(&workTrmm, size_5); if (!scalars || (size_2 && !work) || (size_3 && !workArr) || - (size_4 && !trfact)) + (size_4 && !trfact) || (size_5 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -50,12 +52,13 @@ rocsolver_orgqr_ungqr_impl(rocblas_handle handle, const rocblas_int m, rocblas_status status = rocsolver_orgqr_ungqr_template( handle, m, n, k, A, 0, // shifted 0 entries lda, strideA, ipiv, strideP, batch_count, (T *)scalars, (T *)work, - (T **)workArr, (T *)trfact); + (T **)workArr, (T *)trfact, (T *)workTrmm); hipFree(scalars); hipFree(work); hipFree(workArr); hipFree(trfact); + hipFree(workTrmm); return status; } @@ -67,30 +70,32 @@ rocsolver_orgqr_ungqr_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sorgqr( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv) { +rocblas_status rocsolver_sorgqr(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + float *A, const rocblas_int lda, float *ipiv) { return rocsolver_orgqr_ungqr_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dorgqr( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv) { +rocblas_status rocsolver_dorgqr(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + double *A, const rocblas_int lda, + double *ipiv) { return rocsolver_orgqr_ungqr_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cungqr( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cungqr(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_float_complex *A, const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_orgqr_ungqr_impl(handle, m, n, k, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zungqr( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zungqr(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, const rocblas_int k, + rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_orgqr_ungqr_impl(handle, m, n, k, A, lda, ipiv); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.hpp b/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.hpp index e4fe88fb5..52f965c80 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.hpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orgqr_ungqr.hpp @@ -34,18 +34,17 @@ __global__ void set_zero_col(const rocblas_int n, const rocblas_int kk, U A, } template -void rocsolver_orgqr_ungqr_getMemorySize(const rocblas_int m, - const rocblas_int n, - const rocblas_int k, - const rocblas_int batch_count, - size_t *size_1, size_t *size_2, - size_t *size_3, size_t *size_4) { - size_t s1, s2, s3; +void rocsolver_orgqr_ungqr_getMemorySize( + const rocblas_int m, const rocblas_int n, const rocblas_int k, + const rocblas_int batch_count, size_t *size_1, size_t *size_2, + size_t *size_3, size_t *size_4, size_t *size_5) { + size_t s1, s2, s3, unused; rocsolver_org2r_ung2r_getMemorySize(m, n, batch_count, size_1, size_2, size_3); if (k <= GEQRF_GEQR2_SWITCHSIZE) { *size_4 = 0; + *size_5 = 0; } else { // size of workspace // maximum of what is needed by org2r, larft and larfb @@ -55,8 +54,8 @@ void rocsolver_orgqr_ungqr_getMemorySize(const rocblas_int m, rocsolver_org2r_ung2r_getMemorySize(m, max(n - kk, jb), batch_count, &s1); rocsolver_larft_getMemorySize(jb, batch_count, &s2); - rocsolver_larfb_getMemorySize(rocblas_side_left, m, n - jb, jb, - batch_count, &s3); + rocsolver_larfb_getMemorySize( + rocblas_side_left, m, n - jb, jb, batch_count, &s3, &unused, size_5); *size_2 = max(max(s1, s2), s3); @@ -70,8 +69,8 @@ rocblas_status rocsolver_orgqr_ungqr_template( rocblas_handle handle, const rocblas_int m, const rocblas_int n, const rocblas_int k, U A, const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, T *ipiv, const rocblas_stride strideP, - const rocblas_int batch_count, T *scalars, T *work, T **workArr, - T *trfact) { + const rocblas_int batch_count, T *scalars, T *work, T **workArr, T *trfact, + T *workTrmm) { // quick return if (!n || !m || !batch_count) return rocblas_status_success; @@ -126,7 +125,7 @@ rocblas_status rocsolver_orgqr_ungqr_template( rocblas_forward_direction, rocblas_column_wise, m - j, n - j - jb, jb, A, shiftA + idx2D(j, j, lda), lda, strideA, trfact, 0, ldw, strideW, A, shiftA + idx2D(j, j + jb, lda), lda, strideA, batch_count, work, - workArr); + workArr, workTrmm); } // now compute the current block and set to zero diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orm2r_unm2r.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_orm2r_unm2r.cpp index 6996c6b8b..782624fd8 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orm2r_unm2r.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orm2r_unm2r.cpp @@ -71,40 +71,46 @@ rocsolver_orm2r_unm2r_impl(rocblas_handle handle, const rocblas_side side, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sorm2r( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv, float *C, - const rocblas_int ldc) { +rocblas_status rocsolver_sorm2r(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, float *A, + const rocblas_int lda, float *ipiv, float *C, + const rocblas_int ldc) { return rocsolver_orm2r_unm2r_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dorm2r( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv, - double *C, const rocblas_int ldc) { +rocblas_status rocsolver_dorm2r(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, double *A, + const rocblas_int lda, double *ipiv, double *C, + const rocblas_int ldc) { return rocsolver_orm2r_unm2r_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cunm2r( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv, rocblas_float_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_cunm2r(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv, + rocblas_float_complex *C, + const rocblas_int ldc) { return rocsolver_orm2r_unm2r_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zunm2r( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv, rocblas_double_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_zunm2r(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv, + rocblas_double_complex *C, + const rocblas_int ldc) { return rocsolver_orm2r_unm2r_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.cpp index 9b4964d96..efae7fe47 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.cpp @@ -31,17 +31,20 @@ rocblas_status rocsolver_ormbr_unmbr_impl( size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace size_t size_4; // size of temporary array for triangular factor - rocsolver_ormbr_unmbr_getMemorySize( - storev, side, m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4); + size_t size_5; // workspace for TRMM calls + rocsolver_ormbr_unmbr_getMemorySize(storev, side, m, n, k, + batch_count, &size_1, &size_2, + &size_3, &size_4, &size_5); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *workArr, *trfact; + void *scalars, *work, *workArr, *trfact, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&workArr, size_3); hipMalloc(&trfact, size_4); + hipMalloc(&workTrmm, size_5); if (!scalars || (size_2 && !work) || (size_3 && !workArr) || - (size_4 && !trfact)) + (size_4 && !trfact) || (size_5 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -53,12 +56,13 @@ rocblas_status rocsolver_ormbr_unmbr_impl( rocblas_status status = rocsolver_ormbr_unmbr_template( handle, storev, side, trans, m, n, k, A, 0, // shifted 0 entries lda, strideA, ipiv, strideP, C, 0, ldc, strideC, batch_count, - (T *)scalars, (T *)work, (T **)workArr, (T *)trfact); + (T *)scalars, (T *)work, (T **)workArr, (T *)trfact, (T *)workTrmm); hipFree(scalars); hipFree(work); hipFree(workArr); hipFree(trfact); + hipFree(workTrmm); return status; } @@ -70,40 +74,44 @@ rocblas_status rocsolver_ormbr_unmbr_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sormbr( - rocblas_handle handle, const rocblas_storev storev, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv, float *C, - const rocblas_int ldc) { +rocblas_status +rocsolver_sormbr(rocblas_handle handle, const rocblas_storev storev, + const rocblas_side side, const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, const rocblas_int k, + float *A, const rocblas_int lda, float *ipiv, float *C, + const rocblas_int ldc) { return rocsolver_ormbr_unmbr_impl(handle, storev, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dormbr( - rocblas_handle handle, const rocblas_storev storev, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv, - double *C, const rocblas_int ldc) { +rocblas_status +rocsolver_dormbr(rocblas_handle handle, const rocblas_storev storev, + const rocblas_side side, const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, const rocblas_int k, + double *A, const rocblas_int lda, double *ipiv, double *C, + const rocblas_int ldc) { return rocsolver_ormbr_unmbr_impl(handle, storev, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cunmbr( - rocblas_handle handle, const rocblas_storev storev, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv, rocblas_float_complex *C, - const rocblas_int ldc) { +rocblas_status +rocsolver_cunmbr(rocblas_handle handle, const rocblas_storev storev, + const rocblas_side side, const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, const rocblas_int k, + rocblas_float_complex *A, const rocblas_int lda, + rocblas_float_complex *ipiv, rocblas_float_complex *C, + const rocblas_int ldc) { return rocsolver_ormbr_unmbr_impl( handle, storev, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zunmbr( - rocblas_handle handle, const rocblas_storev storev, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv, rocblas_double_complex *C, - const rocblas_int ldc) { +rocblas_status +rocsolver_zunmbr(rocblas_handle handle, const rocblas_storev storev, + const rocblas_side side, const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, const rocblas_int k, + rocblas_double_complex *A, const rocblas_int lda, + rocblas_double_complex *ipiv, rocblas_double_complex *C, + const rocblas_int ldc) { return rocsolver_ormbr_unmbr_impl( handle, storev, side, trans, m, n, k, A, lda, ipiv, C, ldc); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.hpp b/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.hpp index fce30d73c..fc50d1dba 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.hpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_ormbr_unmbr.hpp @@ -19,14 +19,17 @@ template void rocsolver_ormbr_unmbr_getMemorySize( const rocblas_storev storev, const rocblas_side side, const rocblas_int m, const rocblas_int n, const rocblas_int k, const rocblas_int batch_count, - size_t *size_1, size_t *size_2, size_t *size_3, size_t *size_4) { + size_t *size_1, size_t *size_2, size_t *size_3, size_t *size_4, + size_t *size_5) { rocblas_int nq = side == rocblas_side_left ? m : n; if (storev == rocblas_column_wise) - rocsolver_ormqr_unmqr_getMemorySize( - side, m, n, min(nq, k), batch_count, size_1, size_2, size_3, size_4); + rocsolver_ormqr_unmqr_getMemorySize(side, m, n, min(nq, k), + batch_count, size_1, size_2, + size_3, size_4, size_5); else - rocsolver_ormlq_unmlq_getMemorySize( - side, m, n, min(nq, k), batch_count, size_1, size_2, size_3, size_4); + rocsolver_ormlq_unmlq_getMemorySize(side, m, n, min(nq, k), + batch_count, size_1, size_2, + size_3, size_4, size_5); } template @@ -77,7 +80,7 @@ rocblas_status rocsolver_ormbr_unmbr_template( const rocblas_stride strideA, T *ipiv, const rocblas_stride strideP, U C, const rocblas_int shiftC, const rocblas_int ldc, const rocblas_stride strideC, const rocblas_int batch_count, T *scalars, - T *work, T **workArr, T *trfact) { + T *work, T **workArr, T *trfact, T *workTrmm) { // quick return if (!n || !m || !k || !batch_count) return rocblas_status_success; @@ -105,14 +108,15 @@ rocblas_status rocsolver_ormbr_unmbr_template( if (nq >= k) { rocsolver_ormqr_unmqr_template( handle, side, trans, m, n, k, A, shiftA, lda, strideA, ipiv, strideP, - C, shiftC, ldc, strideC, batch_count, scalars, work, workArr, trfact); + C, shiftC, ldc, strideC, batch_count, scalars, work, workArr, trfact, + workTrmm); } else { // shift the householder vectors provided by gebrd as they come below the // first subdiagonal rocsolver_ormqr_unmqr_template( handle, side, trans, rows, cols, nq - 1, A, shiftA + idx2D(1, 0, lda), lda, strideA, ipiv, strideP, C, shiftC + idx2D(rowC, colC, ldc), ldc, - strideC, batch_count, scalars, work, workArr, trfact); + strideC, batch_count, scalars, work, workArr, trfact, workTrmm); } } @@ -128,7 +132,8 @@ rocblas_status rocsolver_ormbr_unmbr_template( if (nq > k) { rocsolver_ormlq_unmlq_template( handle, side, transP, m, n, k, A, shiftA, lda, strideA, ipiv, strideP, - C, shiftC, ldc, strideC, batch_count, scalars, work, workArr, trfact); + C, shiftC, ldc, strideC, batch_count, scalars, work, workArr, trfact, + workTrmm); } else { // shift the householder vectors provided by gebrd as they come above the // first superdiagonal @@ -136,7 +141,7 @@ rocblas_status rocsolver_ormbr_unmbr_template( handle, side, transP, rows, cols, nq - 1, A, shiftA + idx2D(0, 1, lda), lda, strideA, ipiv, strideP, C, shiftC + idx2D(rowC, colC, ldc), ldc, strideC, batch_count, scalars, - work, workArr, trfact); + work, workArr, trfact, workTrmm); } } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_orml2_unml2.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_orml2_unml2.cpp index d6c7f3a7d..c885c6836 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_orml2_unml2.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_orml2_unml2.cpp @@ -71,40 +71,46 @@ rocsolver_orml2_unml2_impl(rocblas_handle handle, const rocblas_side side, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sorml2( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv, float *C, - const rocblas_int ldc) { +rocblas_status rocsolver_sorml2(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, float *A, + const rocblas_int lda, float *ipiv, float *C, + const rocblas_int ldc) { return rocsolver_orml2_unml2_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dorml2( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv, - double *C, const rocblas_int ldc) { +rocblas_status rocsolver_dorml2(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, double *A, + const rocblas_int lda, double *ipiv, double *C, + const rocblas_int ldc) { return rocsolver_orml2_unml2_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cunml2( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv, rocblas_float_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_cunml2(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv, + rocblas_float_complex *C, + const rocblas_int ldc) { return rocsolver_orml2_unml2_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zunml2( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv, rocblas_double_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_zunml2(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv, + rocblas_double_complex *C, + const rocblas_int ldc) { return rocsolver_orml2_unml2_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.cpp b/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.cpp index a2120879f..ebce87803 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.cpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.cpp @@ -31,19 +31,20 @@ rocsolver_ormlq_unmlq_impl(rocblas_handle handle, const rocblas_side side, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of temporary array for triangular factor or diagonal - // elements + size_t size_4; // size of triangular factor or diagonal elements + size_t size_5; // size of workspace for TRMM calls rocsolver_ormlq_unmlq_getMemorySize( - side, m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4); + side, m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *workArr, *trfact; + void *scalars, *work, *workArr, *trfact, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&workArr, size_3); hipMalloc(&trfact, size_4); + hipMalloc(&workTrmm, size_5); if (!scalars || (size_2 && !work) || (size_3 && !workArr) || - (size_4 && !trfact)) + (size_4 && !trfact) || (size_5 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -55,12 +56,13 @@ rocsolver_ormlq_unmlq_impl(rocblas_handle handle, const rocblas_side side, rocblas_status status = rocsolver_ormlq_unmlq_template( handle, side, trans, m, n, k, A, 0, // shifted 0 entries lda, strideA, ipiv, strideP, C, 0, ldc, strideC, batch_count, - (T *)scalars, (T *)work, (T **)workArr, (T *)trfact); + (T *)scalars, (T *)work, (T **)workArr, (T *)trfact, (T *)workTrmm); hipFree(scalars); hipFree(work); hipFree(workArr); hipFree(trfact); + hipFree(workTrmm); return status; } @@ -72,40 +74,46 @@ rocsolver_ormlq_unmlq_impl(rocblas_handle handle, const rocblas_side side, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sormlq( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv, float *C, - const rocblas_int ldc) { +rocblas_status rocsolver_sormlq(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, float *A, + const rocblas_int lda, float *ipiv, float *C, + const rocblas_int ldc) { return rocsolver_ormlq_unmlq_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dormlq( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv, - double *C, const rocblas_int ldc) { +rocblas_status rocsolver_dormlq(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, double *A, + const rocblas_int lda, double *ipiv, double *C, + const rocblas_int ldc) { return rocsolver_ormlq_unmlq_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cunmlq( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv, rocblas_float_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_cunmlq(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv, + rocblas_float_complex *C, + const rocblas_int ldc) { return rocsolver_ormlq_unmlq_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zunmlq( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv, rocblas_double_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_zunmlq(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv, + rocblas_double_complex *C, + const rocblas_int ldc) { return rocsolver_ormlq_unmlq_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.hpp b/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.hpp index 84a0805b3..4bccbbc28 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.hpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_ormlq_unmlq.hpp @@ -20,8 +20,8 @@ template void rocsolver_ormlq_unmlq_getMemorySize( const rocblas_side side, const rocblas_int m, const rocblas_int n, const rocblas_int k, const rocblas_int batch_count, size_t *size_1, - size_t *size_2, size_t *size_3, size_t *size_4) { - size_t s1, s2; + size_t *size_2, size_t *size_3, size_t *size_4, size_t *size_5) { + size_t s1, s2, unused; rocsolver_orml2_unml2_getMemorySize( side, m, n, batch_count, size_1, size_2, size_3, size_4); @@ -30,13 +30,15 @@ void rocsolver_ormlq_unmlq_getMemorySize( // maximum of what is needed by larft and larfb rocblas_int jb = ORMLQ_ORML2_BLOCKSIZE; rocsolver_larft_getMemorySize(min(jb, k), batch_count, &s1); - rocsolver_larfb_getMemorySize(side, m, n, min(jb, k), batch_count, &s2); + rocsolver_larfb_getMemorySize( + side, m, n, min(jb, k), batch_count, &s2, &unused, size_5); *size_2 = max(s1, s2); // size of temporary array for triangular factor *size_4 = sizeof(T) * jb * jb * batch_count; - } + } else + *size_5 = 0; } template ( - side, m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4); + side, m, n, k, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *workArr, *trfact; + void *scalars, *work, *workArr, *trfact, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&workArr, size_3); hipMalloc(&trfact, size_4); + hipMalloc(&workTrmm, size_5); if (!scalars || (size_2 && !work) || (size_3 && !workArr) || - (size_4 && !trfact)) + (size_4 && !trfact) || (size_5 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -55,12 +56,13 @@ rocsolver_ormqr_unmqr_impl(rocblas_handle handle, const rocblas_side side, rocblas_status status = rocsolver_ormqr_unmqr_template( handle, side, trans, m, n, k, A, 0, // shifted 0 entries lda, strideA, ipiv, strideP, C, 0, ldc, strideC, batch_count, - (T *)scalars, (T *)work, (T **)workArr, (T *)trfact); + (T *)scalars, (T *)work, (T **)workArr, (T *)trfact, (T *)workTrmm); hipFree(scalars); hipFree(work); hipFree(workArr); hipFree(trfact); + hipFree(workTrmm); return status; } @@ -72,40 +74,46 @@ rocsolver_ormqr_unmqr_impl(rocblas_handle handle, const rocblas_side side, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sormqr( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, float *A, const rocblas_int lda, float *ipiv, float *C, - const rocblas_int ldc) { +rocblas_status rocsolver_sormqr(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, float *A, + const rocblas_int lda, float *ipiv, float *C, + const rocblas_int ldc) { return rocsolver_ormqr_unmqr_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dormqr( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, double *A, const rocblas_int lda, double *ipiv, - double *C, const rocblas_int ldc) { +rocblas_status rocsolver_dormqr(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, double *A, + const rocblas_int lda, double *ipiv, double *C, + const rocblas_int ldc) { return rocsolver_ormqr_unmqr_impl(handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cunmqr( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda, - rocblas_float_complex *ipiv, rocblas_float_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_cunmqr(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv, + rocblas_float_complex *C, + const rocblas_int ldc) { return rocsolver_ormqr_unmqr_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zunmqr( - rocblas_handle handle, const rocblas_side side, - const rocblas_operation trans, const rocblas_int m, const rocblas_int n, - const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda, - rocblas_double_complex *ipiv, rocblas_double_complex *C, - const rocblas_int ldc) { +rocblas_status rocsolver_zunmqr(rocblas_handle handle, const rocblas_side side, + const rocblas_operation trans, + const rocblas_int m, const rocblas_int n, + const rocblas_int k, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv, + rocblas_double_complex *C, + const rocblas_int ldc) { return rocsolver_ormqr_unmqr_impl( handle, side, trans, m, n, k, A, lda, ipiv, C, ldc); } diff --git a/rocsolver/library/src/auxiliary/rocauxiliary_ormqr_unmqr.hpp b/rocsolver/library/src/auxiliary/rocauxiliary_ormqr_unmqr.hpp index 79ba33e5e..64620cf51 100644 --- a/rocsolver/library/src/auxiliary/rocauxiliary_ormqr_unmqr.hpp +++ b/rocsolver/library/src/auxiliary/rocauxiliary_ormqr_unmqr.hpp @@ -20,8 +20,8 @@ template void rocsolver_ormqr_unmqr_getMemorySize( const rocblas_side side, const rocblas_int m, const rocblas_int n, const rocblas_int k, const rocblas_int batch_count, size_t *size_1, - size_t *size_2, size_t *size_3, size_t *size_4) { - size_t s1, s2; + size_t *size_2, size_t *size_3, size_t *size_4, size_t *size_5) { + size_t s1, s2, unused; rocsolver_orm2r_unm2r_getMemorySize( side, m, n, batch_count, size_1, size_2, size_3, size_4); @@ -30,13 +30,15 @@ void rocsolver_ormqr_unmqr_getMemorySize( // maximum of what is needed by larft and larfb rocblas_int jb = ORMQR_ORM2R_BLOCKSIZE; rocsolver_larft_getMemorySize(min(jb, k), batch_count, &s1); - rocsolver_larfb_getMemorySize(side, m, n, min(jb, k), batch_count, &s2); + rocsolver_larfb_getMemorySize( + side, m, n, min(jb, k), batch_count, &s2, &unused, size_5); *size_2 = max(s1, s2); // size of temporary array for triangular factor *size_4 = sizeof(T) * jb * jb * batch_count; - } + } else + *size_5 = 0; } template @@ -47,7 +49,7 @@ rocblas_status rocsolver_ormqr_unmqr_template( const rocblas_stride strideA, T *ipiv, const rocblas_stride strideP, U C, const rocblas_int shiftC, const rocblas_int ldc, const rocblas_stride strideC, const rocblas_int batch_count, T *scalars, - T *work, T **workArr, T *trfact) { + T *work, T **workArr, T *trfact, T *workTrmm) { // quick return if (!n || !m || !k || !batch_count) return rocblas_status_success; @@ -114,7 +116,7 @@ rocblas_status rocsolver_ormqr_unmqr_template( handle, side, trans, rocblas_forward_direction, rocblas_column_wise, nrow, ncol, min(ldw, k - i), A, shiftA + idx2D(i, i, lda), lda, strideA, trfact, 0, ldw, strideW, C, shiftC + idx2D(ic, jc, ldc), ldc, strideC, - batch_count, work, workArr); + batch_count, work, workArr, workTrmm); } return rocblas_status_success; diff --git a/rocsolver/library/src/include/rocblas.hpp b/rocsolver/library/src/include/rocblas.hpp index a8bbc0db3..8edebac75 100644 --- a/rocsolver/library/src/include/rocblas.hpp +++ b/rocsolver/library/src/include/rocblas.hpp @@ -270,17 +270,19 @@ rocblasCall_gemm(rocblas_handle handle, rocblas_operation trans_a, } // trmm -template = 0> +template rocblas_status rocblasCall_trmm(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, rocblas_int m, - rocblas_int n, U alpha, V A, rocblas_int offsetA, - rocblas_int lda, rocblas_stride strideA, V B, + rocblas_int n, U alpha, T *A, rocblas_int offsetA, + rocblas_int lda, rocblas_stride strideA, T *B, rocblas_int offsetB, rocblas_int ldb, rocblas_stride strideB, rocblas_int batch_count, T *work, T **workArr) { constexpr rocblas_int nb = ROCBLAS_TRMM_NB; constexpr rocblas_stride strideW = 2 * ROCBLAS_TRMM_NB * ROCBLAS_TRMM_NB; + + // adding offsets directly to the arrays A and B until rocblas_trmm + // supports offset arguments return rocblas_trmm_template( handle, side, uplo, transA, diag, m, n, cast2constType(alpha), cast2constType(A + offsetA), lda, strideA, B + offsetB, ldb, strideB, @@ -288,38 +290,84 @@ rocblasCall_trmm(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, } // trmm overload -template = 0> +template rocblas_status rocblasCall_trmm(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, rocblas_int m, - rocblas_int n, U alpha, V A, rocblas_int offsetA, - rocblas_int lda, rocblas_stride strideA, V B, + rocblas_int n, U alpha, T *const *A, rocblas_int offsetA, + rocblas_int lda, rocblas_stride strideA, T *const *B, rocblas_int offsetB, rocblas_int ldb, rocblas_stride strideB, rocblas_int batch_count, T *work, T **workArr) { constexpr rocblas_int nb = ROCBLAS_TRMM_NB; constexpr rocblas_stride strideW = 2 * ROCBLAS_TRMM_NB * ROCBLAS_TRMM_NB; - // since trmm doesn't have offset arguments, we need to manually offset A and - // B (and store in workArr) - V AA = (V)workArr + batch_count; - V BB = (V)workArr + 2 * batch_count; - hipStream_t stream; rocblas_get_stream(handle, &stream); - rocblas_int blocks = (batch_count - 1) / 256 + 1; hipLaunchKernelGGL(get_array, dim3(blocks), dim3(256), 0, stream, workArr, work, strideW, batch_count); - hipLaunchKernelGGL(shift_array, dim3(blocks), dim3(256), 0, stream, - workArr + batch_count, A, offsetA, batch_count); - hipLaunchKernelGGL(shift_array, dim3(blocks), dim3(256), 0, stream, - workArr + 2 * batch_count, B, offsetB, batch_count); - return rocblas_trmm_template( + // until rocblas_trmm support offset arguments, + // we need to manually offset A and B and store in temporary arrays AA and BB + T **AA, **BB; + hipMalloc(&AA, sizeof(T *) * batch_count); + hipMalloc(&BB, sizeof(T *) * batch_count); + hipLaunchKernelGGL(shift_array, dim3(blocks), dim3(256), 0, stream, AA, A, + offsetA, batch_count); + hipLaunchKernelGGL(shift_array, dim3(blocks), dim3(256), 0, stream, BB, B, + offsetB, batch_count); + + rocblas_status status = rocblas_trmm_template( handle, side, uplo, transA, diag, m, n, cast2constType(alpha), - cast2constType(AA), lda, strideA, BB, ldb, strideB, batch_count, - (V)workArr, strideW); + cast2constType(cast2constPointer(AA)), lda, strideA, + cast2constPointer(BB), ldb, strideB, batch_count, + cast2constPointer(workArr), strideW); + + hipFree(AA); + hipFree(BB); + + return status; +} + +// trmm overload +template +rocblas_status +rocblasCall_trmm(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, + rocblas_operation transA, rocblas_diagonal diag, rocblas_int m, + rocblas_int n, U alpha, T *const *A, rocblas_int offsetA, + rocblas_int lda, rocblas_stride strideA, T *B, + rocblas_int offsetB, rocblas_int ldb, rocblas_stride strideB, + rocblas_int batch_count, T *work, T **workArr) { + constexpr rocblas_int nb = ROCBLAS_TRMM_NB; + constexpr rocblas_stride strideW = 2 * ROCBLAS_TRMM_NB * ROCBLAS_TRMM_NB; + + hipStream_t stream; + rocblas_get_stream(handle, &stream); + rocblas_int blocks = (batch_count - 1) / 256 + 1; + + // adding offsets directly to the array B until rocblas_trmm + // supports offset arguments + hipLaunchKernelGGL(get_array, dim3(blocks), dim3(256), 0, stream, workArr, + B + offsetB, strideB, batch_count); + hipLaunchKernelGGL(get_array, dim3(blocks), dim3(256), 0, stream, + workArr + batch_count, work, strideW, batch_count); + + // until rocblas_trmm support offset arguments, + // we need to manually offset A and store in temporary array AA + T **AA; + hipMalloc(&AA, sizeof(T *) * batch_count); + hipLaunchKernelGGL(shift_array, dim3(blocks), dim3(256), 0, stream, AA, A, + offsetA, batch_count); + + rocblas_status status = rocblas_trmm_template( + handle, side, uplo, transA, diag, m, n, cast2constType(alpha), + cast2constType(cast2constPointer(AA)), lda, strideA, + cast2constPointer(workArr), ldb, strideB, batch_count, + cast2constPointer(workArr + batch_count), strideW); + + hipFree(AA); + + return status; } // syrk @@ -436,130 +484,4 @@ rocblasCall_trsm(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, cast2constType(supplied_invA), 0); } -///////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////// -// THESE SHOULD BE SUBTITUTED BY THEIR CORRESPONDING -// ROCBLAS TEMPLATE FUNCTIONS ONCE THEY ARE EXPORTED -// (ROCBLAS.CPP CAN BE ELIMINATED THEN) - -// nrm2 -template -rocblas_status rocblas_nrm2(rocblas_handle handle, rocblas_int n, const T1 *x, - rocblas_int incx, T2 *result); -/*template <> -rocblas_status rocblas_nrm2(rocblas_handle handle, rocblas_int n, - const float* x, const rocblas_int incx, float* -result) { return rocblas_snrm2(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_nrm2(rocblas_handle handle, rocblas_int n, - const double* x, const rocblas_int incx, double* -result) { return rocblas_dnrm2(handle, n, x, incx, result); -}*/ - -// iamax -// template -// rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, const T -// *x, -// rocblas_int incx, rocblas_int *result); -/*template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const float *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_isamax(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const double *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_idamax(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const rocblas_float_complex *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_icamax(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const rocblas_double_complex *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_izamax(handle, n, x, incx, result); -}*/ - -// trsm -// (Do not remove yet, some functions still use it) -template -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const T *alpha, T *A, rocblas_int lda, T *B, - rocblas_int ldb); -/*template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const float *alpha, float *A, rocblas_int lda, - float *B, rocblas_int ldb) { - return rocblas_strsm(handle, side, uplo, transA, diag, m, n, alpha, A, lda, -B,ldb); -} -template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const double *alpha, double *A, rocblas_int lda, - double *B, rocblas_int ldb) { - return rocblas_dtrsm(handle, side, uplo, transA, diag, m, n, alpha, A, lda, -B,ldb); -} -template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const rocblas_float_complex *alpha, -rocblas_float_complex *A, rocblas_int lda, rocblas_float_complex *B, rocblas_int -ldb) { return rocblas_ctrsm(handle, side, uplo, transA, diag, m, n, alpha, A, -lda, B, ldb); -} -template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const rocblas_double_complex *alpha, -rocblas_double_complex *A, rocblas_int lda, rocblas_double_complex *B, -rocblas_int ldb) { return rocblas_ztrsm(handle, side, uplo, transA, diag, m, n, -alpha, A, lda, B, ldb); -}*/ - -// trmm -template -rocblas_status rocblas_trmm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation trans, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - T *alpha, T *A, rocblas_int lda, T *B, - rocblas_int ldb); -/*template <> -rocblas_status rocblas_trmm(rocblas_handle handle, rocblas_side side, -rocblas_fill uplo, rocblas_operation trans, rocblas_diagonal diag, rocblas_int -m, rocblas_int n, float *alpha, float *A, rocblas_int lda, float* B, rocblas_int -ldb) -{ - return rocblas_strmm(handle,side,uplo,trans,diag,m,n,alpha,A,lda,B,ldb); -} -template <> -rocblas_status rocblas_trmm(rocblas_handle handle, rocblas_side side, -rocblas_fill uplo, rocblas_operation trans, rocblas_diagonal diag, rocblas_int -m, rocblas_int n, double *alpha, double *A, rocblas_int lda, double* B, -rocblas_int ldb) -{ - return rocblas_dtrmm(handle,side,uplo,trans,diag,m,n,alpha,A,lda,B,ldb); -}*/ - -// trtri -template -rocblas_status rocblas_trtri(rocblas_handle handle, rocblas_fill uplo, - rocblas_diagonal diag, rocblas_int n, const T *A, - rocblas_int lda, T *invA, rocblas_int ldinvA); - #endif // _ROCBLAS_HPP_ diff --git a/rocsolver/library/src/lapack/roclapack_gebd2.cpp b/rocsolver/library/src/lapack/roclapack_gebd2.cpp index 04539932a..1e411d205 100644 --- a/rocsolver/library/src/lapack/roclapack_gebd2.cpp +++ b/rocsolver/library/src/lapack/roclapack_gebd2.cpp @@ -72,32 +72,36 @@ rocblas_status rocsolver_gebd2_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgebd2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, - const rocblas_int lda, float *D, float *E, float *tauq, float *taup) { +rocblas_status rocsolver_sgebd2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *D, float *E, + float *tauq, float *taup) { return rocsolver_gebd2_impl(handle, m, n, A, lda, D, E, tauq, taup); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgebd2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, - const rocblas_int lda, double *D, double *E, double *tauq, double *taup) { +rocblas_status rocsolver_dgebd2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *D, double *E, + double *tauq, double *taup) { return rocsolver_gebd2_impl(handle, m, n, A, lda, D, E, tauq, taup); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgebd2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *A, const rocblas_int lda, float *D, float *E, - rocblas_float_complex *tauq, rocblas_float_complex *taup) { +rocblas_status rocsolver_cgebd2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, float *D, float *E, + rocblas_float_complex *tauq, + rocblas_float_complex *taup) { return rocsolver_gebd2_impl( handle, m, n, A, lda, D, E, tauq, taup); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgebd2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *A, const rocblas_int lda, double *D, double *E, - rocblas_double_complex *tauq, rocblas_double_complex *taup) { +rocblas_status rocsolver_zgebd2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, double *D, double *E, + rocblas_double_complex *tauq, + rocblas_double_complex *taup) { return rocsolver_gebd2_impl( handle, m, n, A, lda, D, E, tauq, taup); } diff --git a/rocsolver/library/src/lapack/roclapack_gebd2_batched.cpp b/rocsolver/library/src/lapack/roclapack_gebd2_batched.cpp index af7d1431a..e3f0ec9e4 100644 --- a/rocsolver/library/src/lapack/roclapack_gebd2_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gebd2_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_gebd2.hpp" template @@ -69,7 +68,7 @@ rocblas_status rocsolver_gebd2_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgebd2_batched( +rocblas_status rocsolver_sgebd2_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *const A[], const rocblas_int lda, float *D, const rocblas_stride strideD, float *E, const rocblas_stride strideE, @@ -80,7 +79,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgebd2_batched( strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgebd2_batched( +rocblas_status rocsolver_dgebd2_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *const A[], const rocblas_int lda, double *D, const rocblas_stride strideD, double *E, const rocblas_stride strideE, @@ -91,7 +90,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgebd2_batched( strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgebd2_batched( +rocblas_status rocsolver_cgebd2_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *const A[], const rocblas_int lda, float *D, const rocblas_stride strideD, float *E, const rocblas_stride strideE, @@ -103,7 +102,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgebd2_batched( strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgebd2_batched( +rocblas_status rocsolver_zgebd2_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *const A[], const rocblas_int lda, double *D, const rocblas_stride strideD, double *E, const rocblas_stride strideE, @@ -116,4 +115,3 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgebd2_batched( } } // extern C -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_gebd2_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_gebd2_strided_batched.cpp index 6e7358187..3019d4a54 100644 --- a/rocsolver/library/src/lapack/roclapack_gebd2_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gebd2_strided_batched.cpp @@ -67,7 +67,7 @@ rocblas_status rocsolver_gebd2_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgebd2_strided_batched( +rocblas_status rocsolver_sgebd2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *D, const rocblas_stride strideD, float *E, const rocblas_stride strideE, @@ -78,7 +78,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgebd2_strided_batched( taup, strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgebd2_strided_batched( +rocblas_status rocsolver_dgebd2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *D, const rocblas_stride strideD, double *E, const rocblas_stride strideE, @@ -89,7 +89,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgebd2_strided_batched( taup, strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgebd2_strided_batched( +rocblas_status rocsolver_cgebd2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, float *D, const rocblas_stride strideD, @@ -101,7 +101,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgebd2_strided_batched( taup, strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgebd2_strided_batched( +rocblas_status rocsolver_zgebd2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, double *D, const rocblas_stride strideD, diff --git a/rocsolver/library/src/lapack/roclapack_gebrd.cpp b/rocsolver/library/src/lapack/roclapack_gebrd.cpp index 8c48b94f8..85a5d784b 100644 --- a/rocsolver/library/src/lapack/roclapack_gebrd.cpp +++ b/rocsolver/library/src/lapack/roclapack_gebrd.cpp @@ -84,32 +84,36 @@ rocblas_status rocsolver_gebrd_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgebrd( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, - const rocblas_int lda, float *D, float *E, float *tauq, float *taup) { +rocblas_status rocsolver_sgebrd(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *D, float *E, + float *tauq, float *taup) { return rocsolver_gebrd_impl(handle, m, n, A, lda, D, E, tauq, taup); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgebrd( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, - const rocblas_int lda, double *D, double *E, double *tauq, double *taup) { +rocblas_status rocsolver_dgebrd(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *D, double *E, + double *tauq, double *taup) { return rocsolver_gebrd_impl(handle, m, n, A, lda, D, E, tauq, taup); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgebrd( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *A, const rocblas_int lda, float *D, float *E, - rocblas_float_complex *tauq, rocblas_float_complex *taup) { +rocblas_status rocsolver_cgebrd(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, float *D, float *E, + rocblas_float_complex *tauq, + rocblas_float_complex *taup) { return rocsolver_gebrd_impl( handle, m, n, A, lda, D, E, tauq, taup); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgebrd( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *A, const rocblas_int lda, double *D, double *E, - rocblas_double_complex *tauq, rocblas_double_complex *taup) { +rocblas_status rocsolver_zgebrd(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, double *D, double *E, + rocblas_double_complex *tauq, + rocblas_double_complex *taup) { return rocsolver_gebrd_impl( handle, m, n, A, lda, D, E, tauq, taup); } diff --git a/rocsolver/library/src/lapack/roclapack_gebrd_batched.cpp b/rocsolver/library/src/lapack/roclapack_gebrd_batched.cpp index ee0709c50..3223c1403 100644 --- a/rocsolver/library/src/lapack/roclapack_gebrd_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gebrd_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_gebrd.hpp" template @@ -95,7 +94,7 @@ rocblas_status rocsolver_gebrd_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgebrd_batched( +rocblas_status rocsolver_sgebrd_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *const A[], const rocblas_int lda, float *D, const rocblas_stride strideD, float *E, const rocblas_stride strideE, @@ -106,7 +105,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgebrd_batched( strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgebrd_batched( +rocblas_status rocsolver_dgebrd_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *const A[], const rocblas_int lda, double *D, const rocblas_stride strideD, double *E, const rocblas_stride strideE, @@ -117,7 +116,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgebrd_batched( strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgebrd_batched( +rocblas_status rocsolver_cgebrd_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *const A[], const rocblas_int lda, float *D, const rocblas_stride strideD, float *E, const rocblas_stride strideE, @@ -129,7 +128,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgebrd_batched( strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgebrd_batched( +rocblas_status rocsolver_zgebrd_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *const A[], const rocblas_int lda, double *D, const rocblas_stride strideD, double *E, const rocblas_stride strideE, @@ -142,4 +141,3 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgebrd_batched( } } // extern C -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_gebrd_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_gebrd_strided_batched.cpp index ab5cf2da3..2aaeead7c 100644 --- a/rocsolver/library/src/lapack/roclapack_gebrd_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gebrd_strided_batched.cpp @@ -80,7 +80,7 @@ rocblas_status rocsolver_gebrd_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgebrd_strided_batched( +rocblas_status rocsolver_sgebrd_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *D, const rocblas_stride strideD, float *E, const rocblas_stride strideE, @@ -91,7 +91,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgebrd_strided_batched( taup, strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgebrd_strided_batched( +rocblas_status rocsolver_dgebrd_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *D, const rocblas_stride strideD, double *E, const rocblas_stride strideE, @@ -102,7 +102,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgebrd_strided_batched( taup, strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgebrd_strided_batched( +rocblas_status rocsolver_cgebrd_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, float *D, const rocblas_stride strideD, @@ -114,7 +114,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgebrd_strided_batched( taup, strideP, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgebrd_strided_batched( +rocblas_status rocsolver_zgebrd_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, double *D, const rocblas_stride strideD, diff --git a/rocsolver/library/src/lapack/roclapack_gelq2.cpp b/rocsolver/library/src/lapack/roclapack_gelq2.cpp index 470ae108d..ed7a224cb 100644 --- a/rocsolver/library/src/lapack/roclapack_gelq2.cpp +++ b/rocsolver/library/src/lapack/roclapack_gelq2.cpp @@ -67,38 +67,30 @@ rocblas_status rocsolver_gelq2_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgelq2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, float *A, - const rocblas_int lda, - float *ipiv) { +rocblas_status rocsolver_sgelq2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *ipiv) { return rocsolver_gelq2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgelq2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, double *A, - const rocblas_int lda, - double *ipiv) { +rocblas_status rocsolver_dgelq2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *ipiv) { return rocsolver_gelq2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgelq2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_float_complex *A, - const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cgelq2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_gelq2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgelq2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_double_complex *A, - const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zgelq2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_gelq2_impl(handle, m, n, A, lda, ipiv); } diff --git a/rocsolver/library/src/lapack/roclapack_gelq2_batched.cpp b/rocsolver/library/src/lapack/roclapack_gelq2_batched.cpp index 928686dde..39ff4d8d8 100644 --- a/rocsolver/library/src/lapack/roclapack_gelq2_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gelq2_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_gelq2.hpp" template @@ -69,39 +68,44 @@ rocsolver_gelq2_batched_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgelq2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, float *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_sgelq2_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, float *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelq2_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgelq2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, double *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_dgelq2_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, double *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelq2_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgelq2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, - rocblas_float_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgelq2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_float_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelq2_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgelq2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, - rocblas_double_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgelq2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_double_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelq2_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } } // extern C -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_gelq2_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_gelq2_strided_batched.cpp index 8805b0c56..1e793ae2c 100644 --- a/rocsolver/library/src/lapack/roclapack_gelq2_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gelq2_strided_batched.cpp @@ -65,7 +65,7 @@ rocblas_status rocsolver_gelq2_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgelq2_strided_batched( +rocblas_status rocsolver_sgelq2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -73,7 +73,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgelq2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgelq2_strided_batched( +rocblas_status rocsolver_dgelq2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -81,7 +81,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgelq2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgelq2_strided_batched( +rocblas_status rocsolver_cgelq2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_float_complex *ipiv, @@ -90,7 +90,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgelq2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgelq2_strided_batched( +rocblas_status rocsolver_zgelq2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_double_complex *ipiv, diff --git a/rocsolver/library/src/lapack/roclapack_gelqf.cpp b/rocsolver/library/src/lapack/roclapack_gelqf.cpp index 4f569a3d5..d2c988fc5 100644 --- a/rocsolver/library/src/lapack/roclapack_gelqf.cpp +++ b/rocsolver/library/src/lapack/roclapack_gelqf.cpp @@ -26,7 +26,7 @@ rocblas_status rocsolver_gelqf_impl(rocblas_handle handle, const rocblas_int m, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_gelqf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -70,38 +70,30 @@ rocblas_status rocsolver_gelqf_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgelqf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, float *A, - const rocblas_int lda, - float *ipiv) { +rocblas_status rocsolver_sgelqf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *ipiv) { return rocsolver_gelqf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgelqf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, double *A, - const rocblas_int lda, - double *ipiv) { +rocblas_status rocsolver_dgelqf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *ipiv) { return rocsolver_gelqf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgelqf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_float_complex *A, - const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cgelqf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_gelqf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgelqf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_double_complex *A, - const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zgelqf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_gelqf_impl(handle, m, n, A, lda, ipiv); } diff --git a/rocsolver/library/src/lapack/roclapack_gelqf.hpp b/rocsolver/library/src/lapack/roclapack_gelqf.hpp index 65d21b202..8dd08f7e2 100644 --- a/rocsolver/library/src/lapack/roclapack_gelqf.hpp +++ b/rocsolver/library/src/lapack/roclapack_gelqf.hpp @@ -22,7 +22,7 @@ void rocsolver_gelqf_getMemorySize(const rocblas_int m, const rocblas_int n, size_t *size_1, size_t *size_2, size_t *size_3, size_t *size_4, size_t *size_5) { - size_t s1, s2, s3; + size_t s1, s2, s3, unused, s4 = 0; rocsolver_gelq2_getMemorySize(m, n, batch_count, size_1, &s1, size_3, size_4); if (m <= GEQRF_GEQR2_SWITCHSIZE || n <= GEQRF_GEQR2_SWITCHSIZE) { @@ -31,11 +31,17 @@ void rocsolver_gelqf_getMemorySize(const rocblas_int m, const rocblas_int n, } else { rocblas_int jb = GEQRF_GEQR2_BLOCKSIZE; rocsolver_larft_getMemorySize(jb, batch_count, &s2); - rocsolver_larfb_getMemorySize(rocblas_side_right, m - jb, n, jb, - batch_count, &s3); + rocsolver_larfb_getMemorySize(rocblas_side_right, m - jb, n, jb, + batch_count, &s3, &unused, &s4); *size_2 = max(s1, max(s2, s3)); *size_5 = sizeof(T) * jb * jb * batch_count; } + *size_4 = max(*size_4, s4); + + // size of workArr is double to accomodate + // the TRMM calls in the batched case + if (BATCHED) + *size_3 *= 2; } template @@ -88,7 +94,7 @@ rocsolver_gelqf_template(rocblas_handle handle, const rocblas_int m, rocblas_forward_direction, rocblas_row_wise, m - j - jb, n - j, jb, A, shiftA + idx2D(j, j, lda), lda, strideA, trfact, 0, ldw, strideW, A, shiftA + idx2D(j + jb, j, lda), lda, strideA, batch_count, work, - workArr); + workArr, diag); } j += GEQRF_GEQR2_BLOCKSIZE; } diff --git a/rocsolver/library/src/lapack/roclapack_gelqf_batched.cpp b/rocsolver/library/src/lapack/roclapack_gelqf_batched.cpp index def9d35e8..35004e361 100644 --- a/rocsolver/library/src/lapack/roclapack_gelqf_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gelqf_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_gelqf.hpp" template @@ -28,7 +27,7 @@ rocsolver_gelqf_batched_impl(rocblas_handle handle, const rocblas_int m, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_gelqf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -72,40 +71,44 @@ rocsolver_gelqf_batched_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgelqf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, float *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_sgelqf_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, float *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelqf_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgelqf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, double *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_dgelqf_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, double *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelqf_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgelqf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, - rocblas_float_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgelqf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_float_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelqf_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgelqf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, - rocblas_double_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgelqf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_double_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_gelqf_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_gelqf_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_gelqf_strided_batched.cpp index 85a6bbe66..bead2ecfd 100644 --- a/rocsolver/library/src/lapack/roclapack_gelqf_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gelqf_strided_batched.cpp @@ -24,7 +24,7 @@ rocblas_status rocsolver_gelqf_strided_batched_impl( size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_gelqf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -68,7 +68,7 @@ rocblas_status rocsolver_gelqf_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgelqf_strided_batched( +rocblas_status rocsolver_sgelqf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -76,7 +76,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgelqf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgelqf_strided_batched( +rocblas_status rocsolver_dgelqf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -84,7 +84,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgelqf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgelqf_strided_batched( +rocblas_status rocsolver_cgelqf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_float_complex *ipiv, @@ -93,7 +93,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgelqf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgelqf_strided_batched( +rocblas_status rocsolver_zgelqf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_double_complex *ipiv, diff --git a/rocsolver/library/src/lapack/roclapack_geql2.cpp b/rocsolver/library/src/lapack/roclapack_geql2.cpp index b05c785d9..54e3166ea 100644 --- a/rocsolver/library/src/lapack/roclapack_geql2.cpp +++ b/rocsolver/library/src/lapack/roclapack_geql2.cpp @@ -67,38 +67,30 @@ rocblas_status rocsolver_geql2_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeql2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, float *A, - const rocblas_int lda, - float *ipiv) { +rocblas_status rocsolver_sgeql2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *ipiv) { return rocsolver_geql2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeql2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, double *A, - const rocblas_int lda, - double *ipiv) { +rocblas_status rocsolver_dgeql2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *ipiv) { return rocsolver_geql2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeql2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_float_complex *A, - const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cgeql2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_geql2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeql2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_double_complex *A, - const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zgeql2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_geql2_impl(handle, m, n, A, lda, ipiv); } diff --git a/rocsolver/library/src/lapack/roclapack_geql2_batched.cpp b/rocsolver/library/src/lapack/roclapack_geql2_batched.cpp index be2f46c5f..c2962c040 100644 --- a/rocsolver/library/src/lapack/roclapack_geql2_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geql2_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_geql2.hpp" template @@ -69,39 +68,44 @@ rocsolver_geql2_batched_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeql2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, float *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_sgeql2_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, float *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geql2_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeql2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, double *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_dgeql2_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, double *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geql2_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeql2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, - rocblas_float_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgeql2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_float_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geql2_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeql2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, - rocblas_double_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgeql2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_double_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geql2_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } } // extern C -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_geql2_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_geql2_strided_batched.cpp index 1cd1115e9..2306f9b6e 100644 --- a/rocsolver/library/src/lapack/roclapack_geql2_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geql2_strided_batched.cpp @@ -65,7 +65,7 @@ rocblas_status rocsolver_geql2_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeql2_strided_batched( +rocblas_status rocsolver_sgeql2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -73,7 +73,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgeql2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeql2_strided_batched( +rocblas_status rocsolver_dgeql2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -81,7 +81,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgeql2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeql2_strided_batched( +rocblas_status rocsolver_cgeql2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_float_complex *ipiv, @@ -90,7 +90,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgeql2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeql2_strided_batched( +rocblas_status rocsolver_zgeql2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_double_complex *ipiv, diff --git a/rocsolver/library/src/lapack/roclapack_geqlf.cpp b/rocsolver/library/src/lapack/roclapack_geqlf.cpp index e35286566..3634f0887 100644 --- a/rocsolver/library/src/lapack/roclapack_geqlf.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqlf.cpp @@ -26,7 +26,7 @@ rocblas_status rocsolver_geqlf_impl(rocblas_handle handle, const rocblas_int m, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_geqlf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -70,38 +70,30 @@ rocblas_status rocsolver_geqlf_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqlf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, float *A, - const rocblas_int lda, - float *ipiv) { +rocblas_status rocsolver_sgeqlf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *ipiv) { return rocsolver_geqlf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqlf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, double *A, - const rocblas_int lda, - double *ipiv) { +rocblas_status rocsolver_dgeqlf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *ipiv) { return rocsolver_geqlf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqlf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_float_complex *A, - const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cgeqlf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_geqlf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqlf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_double_complex *A, - const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zgeqlf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_geqlf_impl(handle, m, n, A, lda, ipiv); } diff --git a/rocsolver/library/src/lapack/roclapack_geqlf.hpp b/rocsolver/library/src/lapack/roclapack_geqlf.hpp index 97ac749ac..766dedb42 100644 --- a/rocsolver/library/src/lapack/roclapack_geqlf.hpp +++ b/rocsolver/library/src/lapack/roclapack_geqlf.hpp @@ -22,7 +22,7 @@ void rocsolver_geqlf_getMemorySize(const rocblas_int m, const rocblas_int n, size_t *size_1, size_t *size_2, size_t *size_3, size_t *size_4, size_t *size_5) { - size_t s1, s2, s3; + size_t s1, s2, s3, unused, s4 = 0; rocsolver_geql2_getMemorySize(m, n, batch_count, size_1, &s1, size_3, size_4); if (m <= GEQLF_GEQL2_SWITCHSIZE || n <= GEQLF_GEQL2_SWITCHSIZE) { @@ -31,11 +31,17 @@ void rocsolver_geqlf_getMemorySize(const rocblas_int m, const rocblas_int n, } else { rocblas_int jb = GEQLF_GEQL2_BLOCKSIZE; rocsolver_larft_getMemorySize(jb, batch_count, &s2); - rocsolver_larfb_getMemorySize(rocblas_side_left, m, n - jb, jb, - batch_count, &s3); + rocsolver_larfb_getMemorySize(rocblas_side_left, m, n - jb, jb, + batch_count, &s3, &unused, &s4); *size_2 = max(s1, max(s2, s3)); *size_5 = sizeof(T) * jb * jb * batch_count; } + *size_4 = max(*size_4, s4); + + // size of workArr is double to accomodate + // the TRMM calls in the batched case + if (BATCHED) + *size_3 *= 2; } template @@ -94,7 +100,7 @@ rocsolver_geqlf_template(rocblas_handle handle, const rocblas_int m, rocblas_backward_direction, rocblas_column_wise, m - k + j + jb, n - k + j, jb, A, shiftA + idx2D(0, n - k + j, lda), lda, strideA, trfact, 0, ldw, strideW, A, shiftA, lda, strideA, batch_count, work, - workArr); + workArr, diag); } j -= nb; mu = m - k + j + jb; diff --git a/rocsolver/library/src/lapack/roclapack_geqlf_batched.cpp b/rocsolver/library/src/lapack/roclapack_geqlf_batched.cpp index 55876e340..a518b9d5a 100644 --- a/rocsolver/library/src/lapack/roclapack_geqlf_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqlf_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_geqlf.hpp" template @@ -28,7 +27,7 @@ rocsolver_geqlf_batched_impl(rocblas_handle handle, const rocblas_int m, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_geqlf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -72,40 +71,44 @@ rocsolver_geqlf_batched_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqlf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, float *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_sgeqlf_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, float *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqlf_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqlf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, double *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_dgeqlf_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, double *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqlf_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqlf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, - rocblas_float_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgeqlf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_float_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqlf_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqlf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, - rocblas_double_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgeqlf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_double_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqlf_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_geqlf_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_geqlf_strided_batched.cpp index ba8b25aab..5c6a3bcae 100644 --- a/rocsolver/library/src/lapack/roclapack_geqlf_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqlf_strided_batched.cpp @@ -24,7 +24,7 @@ rocblas_status rocsolver_geqlf_strided_batched_impl( size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_geqlf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -68,7 +68,7 @@ rocblas_status rocsolver_geqlf_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqlf_strided_batched( +rocblas_status rocsolver_sgeqlf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -76,7 +76,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqlf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqlf_strided_batched( +rocblas_status rocsolver_dgeqlf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -84,7 +84,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqlf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqlf_strided_batched( +rocblas_status rocsolver_cgeqlf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_float_complex *ipiv, @@ -93,7 +93,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqlf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqlf_strided_batched( +rocblas_status rocsolver_zgeqlf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_double_complex *ipiv, diff --git a/rocsolver/library/src/lapack/roclapack_geqr2.cpp b/rocsolver/library/src/lapack/roclapack_geqr2.cpp index 0acf045a2..87a472b35 100644 --- a/rocsolver/library/src/lapack/roclapack_geqr2.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqr2.cpp @@ -67,38 +67,30 @@ rocblas_status rocsolver_geqr2_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqr2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, float *A, - const rocblas_int lda, - float *ipiv) { +rocblas_status rocsolver_sgeqr2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *ipiv) { return rocsolver_geqr2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqr2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, double *A, - const rocblas_int lda, - double *ipiv) { +rocblas_status rocsolver_dgeqr2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *ipiv) { return rocsolver_geqr2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqr2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_float_complex *A, - const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cgeqr2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_geqr2_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqr2(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_double_complex *A, - const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zgeqr2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_geqr2_impl(handle, m, n, A, lda, ipiv); } diff --git a/rocsolver/library/src/lapack/roclapack_geqr2_batched.cpp b/rocsolver/library/src/lapack/roclapack_geqr2_batched.cpp index d66ce96df..170447b76 100644 --- a/rocsolver/library/src/lapack/roclapack_geqr2_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqr2_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_geqr2.hpp" template @@ -69,39 +68,44 @@ rocsolver_geqr2_batched_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqr2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, float *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_sgeqr2_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, float *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqr2_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqr2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, double *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_dgeqr2_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, double *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqr2_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqr2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, - rocblas_float_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgeqr2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_float_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqr2_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqr2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, - rocblas_double_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgeqr2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_double_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqr2_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } } // extern C -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_geqr2_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_geqr2_strided_batched.cpp index 06897801d..e6a49f56d 100644 --- a/rocsolver/library/src/lapack/roclapack_geqr2_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqr2_strided_batched.cpp @@ -65,7 +65,7 @@ rocblas_status rocsolver_geqr2_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqr2_strided_batched( +rocblas_status rocsolver_sgeqr2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -73,7 +73,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqr2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqr2_strided_batched( +rocblas_status rocsolver_dgeqr2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -81,7 +81,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqr2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqr2_strided_batched( +rocblas_status rocsolver_cgeqr2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_float_complex *ipiv, @@ -90,7 +90,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqr2_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqr2_strided_batched( +rocblas_status rocsolver_zgeqr2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_double_complex *ipiv, diff --git a/rocsolver/library/src/lapack/roclapack_geqrf.cpp b/rocsolver/library/src/lapack/roclapack_geqrf.cpp index bf6841c39..34b0dec50 100644 --- a/rocsolver/library/src/lapack/roclapack_geqrf.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqrf.cpp @@ -26,7 +26,7 @@ rocblas_status rocsolver_geqrf_impl(rocblas_handle handle, const rocblas_int m, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_geqrf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -70,38 +70,30 @@ rocblas_status rocsolver_geqrf_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqrf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, float *A, - const rocblas_int lda, - float *ipiv) { +rocblas_status rocsolver_sgeqrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, float *ipiv) { return rocsolver_geqrf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqrf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, double *A, - const rocblas_int lda, - double *ipiv) { +rocblas_status rocsolver_dgeqrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, double *ipiv) { return rocsolver_geqrf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqrf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_float_complex *A, - const rocblas_int lda, - rocblas_float_complex *ipiv) { +rocblas_status rocsolver_cgeqrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, + rocblas_float_complex *ipiv) { return rocsolver_geqrf_impl(handle, m, n, A, lda, ipiv); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqrf(rocblas_handle handle, - const rocblas_int m, - const rocblas_int n, - rocblas_double_complex *A, - const rocblas_int lda, - rocblas_double_complex *ipiv) { +rocblas_status rocsolver_zgeqrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, + rocblas_double_complex *ipiv) { return rocsolver_geqrf_impl(handle, m, n, A, lda, ipiv); } diff --git a/rocsolver/library/src/lapack/roclapack_geqrf.hpp b/rocsolver/library/src/lapack/roclapack_geqrf.hpp index 71fd8232c..7df69c9f1 100644 --- a/rocsolver/library/src/lapack/roclapack_geqrf.hpp +++ b/rocsolver/library/src/lapack/roclapack_geqrf.hpp @@ -22,7 +22,7 @@ void rocsolver_geqrf_getMemorySize(const rocblas_int m, const rocblas_int n, size_t *size_1, size_t *size_2, size_t *size_3, size_t *size_4, size_t *size_5) { - size_t s1, s2, s3; + size_t s1, s2, s3, unused, s4 = 0; rocsolver_geqr2_getMemorySize(m, n, batch_count, size_1, &s1, size_3, size_4); if (m <= GEQRF_GEQR2_SWITCHSIZE || n <= GEQRF_GEQR2_SWITCHSIZE) { @@ -31,11 +31,17 @@ void rocsolver_geqrf_getMemorySize(const rocblas_int m, const rocblas_int n, } else { rocblas_int jb = GEQRF_GEQR2_BLOCKSIZE; rocsolver_larft_getMemorySize(jb, batch_count, &s2); - rocsolver_larfb_getMemorySize(rocblas_side_left, m, n - jb, jb, - batch_count, &s3); + rocsolver_larfb_getMemorySize(rocblas_side_left, m, n - jb, jb, + batch_count, &s3, &unused, &s4); *size_2 = max(s1, max(s2, s3)); *size_5 = sizeof(T) * jb * jb * batch_count; } + *size_4 = max(*size_4, s4); + + // size of workArr is double to accomodate + // the TRMM calls in the batched case + if (BATCHED) + *size_3 *= 2; } template @@ -88,7 +94,7 @@ rocsolver_geqrf_template(rocblas_handle handle, const rocblas_int m, rocblas_forward_direction, rocblas_column_wise, m - j, n - j - jb, jb, A, shiftA + idx2D(j, j, lda), lda, strideA, trfact, 0, ldw, strideW, A, shiftA + idx2D(j, j + jb, lda), lda, strideA, batch_count, work, - workArr); + workArr, diag); } j += GEQRF_GEQR2_BLOCKSIZE; } diff --git a/rocsolver/library/src/lapack/roclapack_geqrf_batched.cpp b/rocsolver/library/src/lapack/roclapack_geqrf_batched.cpp index a10c84ff7..dfb8e4e81 100644 --- a/rocsolver/library/src/lapack/roclapack_geqrf_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqrf_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_geqrf.hpp" template @@ -28,7 +27,7 @@ rocsolver_geqrf_batched_impl(rocblas_handle handle, const rocblas_int m, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_geqrf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -72,40 +71,44 @@ rocsolver_geqrf_batched_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, float *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_sgeqrf_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, float *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqrf_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, double *ipiv, - const rocblas_stride stridep, const rocblas_int batch_count) { +rocblas_status rocsolver_dgeqrf_batched(rocblas_handle handle, + const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, double *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqrf_batched_impl(handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, - rocblas_float_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgeqrf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_float_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqrf_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, - rocblas_double_complex *ipiv, const rocblas_stride stridep, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgeqrf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_double_complex *ipiv, + const rocblas_stride stridep, + const rocblas_int batch_count) { return rocsolver_geqrf_batched_impl( handle, m, n, A, lda, ipiv, stridep, batch_count); } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_geqrf_ptr_batched.cpp b/rocsolver/library/src/lapack/roclapack_geqrf_ptr_batched.cpp index 4156a6d47..4a315bcef 100644 --- a/rocsolver/library/src/lapack/roclapack_geqrf_ptr_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqrf_ptr_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_geqrf.hpp" /* @@ -50,7 +49,7 @@ rocblas_status rocsolver_geqrf_ptr_batched_impl(rocblas_handle handle, size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls size_t size_5; // size of triangular factor for block reflector size_t size_6 = sizeof(T) * strideP * batch_count; rocsolver_geqrf_getMemorySize(m, n, batch_count, &size_1, &size_2, @@ -137,5 +136,3 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqrf_ptr_batched( } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_geqrf_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_geqrf_strided_batched.cpp index b925eec2f..5e8214100 100644 --- a/rocsolver/library/src/lapack/roclapack_geqrf_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_geqrf_strided_batched.cpp @@ -24,7 +24,7 @@ rocblas_status rocsolver_geqrf_strided_batched_impl( size_t size_1; // size of constants size_t size_2; // size of workspace size_t size_3; // size of array of pointers to workspace - size_t size_4; // size of diagonal entry cache + size_t size_4; // size of diagonal entry cache and TRMM calls workspace size_t size_5; // size of triangular factor for block reflector rocsolver_geqrf_getMemorySize(m, n, batch_count, &size_1, &size_2, &size_3, &size_4, &size_5); @@ -68,7 +68,7 @@ rocblas_status rocsolver_geqrf_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqrf_strided_batched( +rocblas_status rocsolver_sgeqrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, float *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -76,7 +76,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgeqrf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqrf_strided_batched( +rocblas_status rocsolver_dgeqrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, double *ipiv, const rocblas_stride stridep, const rocblas_int batch_count) { @@ -84,7 +84,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgeqrf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqrf_strided_batched( +rocblas_status rocsolver_cgeqrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_float_complex *ipiv, @@ -93,7 +93,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgeqrf_strided_batched( handle, m, n, A, lda, strideA, ipiv, stridep, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgeqrf_strided_batched( +rocblas_status rocsolver_zgeqrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_double_complex *ipiv, diff --git a/rocsolver/library/src/lapack/roclapack_gesvd.cpp b/rocsolver/library/src/lapack/roclapack_gesvd.cpp index 61b7ac3cc..d8f6b076f 100644 --- a/rocsolver/library/src/lapack/roclapack_gesvd.cpp +++ b/rocsolver/library/src/lapack/roclapack_gesvd.cpp @@ -42,19 +42,22 @@ rocsolver_gesvd_impl(rocblas_handle handle, const rocblas_svect left_svect, size_t size_4; // size of the array for the householder scalars size_t size_5; - rocsolver_gesvd_getMemorySize(left_svect, right_svect, m, n, - batch_count, &size_1, &size_2, - &size_3, &size_4, &size_5); + // size of workspace for TRMM calls + size_t size_6; + rocsolver_gesvd_getMemorySize( + left_svect, right_svect, m, n, batch_count, &size_1, &size_2, &size_3, + &size_4, &size_5, &size_6); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *workgral, *workArr, *workfunc, *tau; + void *scalars, *workgral, *workArr, *workfunc, *tau, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&workgral, size_2); hipMalloc(&workArr, size_3); hipMalloc(&workfunc, size_4); hipMalloc(&tau, size_5); + hipMalloc(&workTrmm, size_6); if ((size_1 && !scalars) || (size_2 && !workgral) || (size_3 && !workArr) || - (size_4 && !workfunc) || (size_5 && !tau)) + (size_4 && !workfunc) || (size_5 && !tau) || (size_6 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -66,13 +69,15 @@ rocsolver_gesvd_impl(rocblas_handle handle, const rocblas_svect left_svect, rocblas_status status = rocsolver_gesvd_template( handle, left_svect, right_svect, m, n, A, 0, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, E, strideE, fast_alg, info, batch_count, - (T *)scalars, workgral, (T **)workArr, (T *)workfunc, (T *)tau); + (T *)scalars, workgral, (T **)workArr, (T *)workfunc, (T *)tau, + (T *)workTrmm); hipFree(scalars); hipFree(workgral); hipFree(workArr); hipFree(workfunc); hipFree(tau); + hipFree(workTrmm); return status; } diff --git a/rocsolver/library/src/lapack/roclapack_gesvd.hpp b/rocsolver/library/src/lapack/roclapack_gesvd.hpp index d643d4530..28605a80e 100644 --- a/rocsolver/library/src/lapack/roclapack_gesvd.hpp +++ b/rocsolver/library/src/lapack/roclapack_gesvd.hpp @@ -159,7 +159,7 @@ void rocsolver_gesvd_getMemorySize(const rocblas_svect left_svect, const rocblas_int batch_count, size_t *size_1, size_t *size_2, size_t *size_3, size_t *size_4, - size_t *size_5) { + size_t *size_5, size_t *size_6) { // if quick return, set workspace to zero if (n == 0 || m == 0 || batch_count == 0) { *size_1 = 0; @@ -167,6 +167,7 @@ void rocsolver_gesvd_getMemorySize(const rocblas_svect left_svect, *size_3 = 0; *size_4 = 0; *size_5 = 0; + *size_6 = 0; return; } @@ -199,8 +200,9 @@ void rocsolver_gesvd_getMemorySize(const rocblas_svect left_svect, k = n; else k = m; - rocsolver_orgbr_ungbr_getMemorySize( - rocblas_column_wise, m, k, n, batch_count, size_1, &s2, size_3, &s4); + rocsolver_orgbr_ungbr_getMemorySize(rocblas_column_wise, m, k, + n, batch_count, size_1, &s2, + size_3, &s4, size_6); *size_2 = (s2 >= *size_2) ? s2 : *size_2; *size_4 = (s4 >= *size_4) ? s4 : *size_4; } @@ -211,8 +213,9 @@ void rocsolver_gesvd_getMemorySize(const rocblas_svect left_svect, k = m; else k = n; - rocsolver_orgbr_ungbr_getMemorySize( - rocblas_row_wise, k, n, m, batch_count, size_1, &s2, size_3, &s4); + rocsolver_orgbr_ungbr_getMemorySize(rocblas_row_wise, k, n, m, + batch_count, size_1, &s2, + size_3, &s4, size_6); *size_2 = (s2 >= *size_2) ? s2 : *size_2; *size_4 = (s4 >= *size_4) ? s4 : *size_4; } @@ -232,7 +235,7 @@ rocblas_status rocsolver_gesvd_template( const rocblas_int ldv, const rocblas_stride strideV, TT *E, const rocblas_stride strideE, const rocblas_workmode fast_alg, rocblas_int *info, const rocblas_int batch_count, T *scalars, - void *workgral, T **workArr, T *workfunc, T *tau) { + void *workgral, T **workArr, T *workfunc, T *tau, T *workTrmm) { constexpr bool COMPLEX = is_complex; // quick return @@ -305,7 +308,7 @@ rocblas_status rocsolver_gesvd_template( A, shiftA, lda, strideA, U, 0, ldu, strideU); rocsolver_orgbr_ungbr_template( handle, rocblas_column_wise, m, mn, n, U, 0, ldu, strideU, tau, k, - batch_count, scalars, (T *)workgral, workArr, workfunc); + batch_count, scalars, (T *)workgral, workArr, workfunc, workTrmm); } if (rightvS || rightvA) { @@ -316,20 +319,20 @@ rocblas_status rocsolver_gesvd_template( rocsolver_orgbr_ungbr_template( handle, rocblas_row_wise, mn, n, m, V, 0, ldv, strideV, (tau + k * batch_count), k, batch_count, scalars, (T *)workgral, - workArr, workfunc); + workArr, workfunc, workTrmm); } if (leftvO) { rocsolver_orgbr_ungbr_template( handle, rocblas_column_wise, m, k, n, A, shiftA, lda, strideA, tau, k, - batch_count, scalars, (T *)workgral, workArr, workfunc); + batch_count, scalars, (T *)workgral, workArr, workfunc, workTrmm); } if (rightvO) { rocsolver_orgbr_ungbr_template( handle, rocblas_row_wise, k, n, m, A, shiftA, lda, strideA, (tau + k * batch_count), k, batch_count, scalars, (T *)workgral, - workArr, workfunc); + workArr, workfunc, workTrmm); } // 3. compute singular values (and vectors if required) using the diff --git a/rocsolver/library/src/lapack/roclapack_gesvd_batched.cpp b/rocsolver/library/src/lapack/roclapack_gesvd_batched.cpp index f13b05beb..643794aee 100644 --- a/rocsolver/library/src/lapack/roclapack_gesvd_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gesvd_batched.cpp @@ -2,8 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched - #include "roclapack_gesvd.hpp" template @@ -41,19 +39,22 @@ rocblas_status rocsolver_gesvd_batched_impl( size_t size_4; // size of the array for the householder scalars size_t size_5; - rocsolver_gesvd_getMemorySize(left_svect, right_svect, m, n, - batch_count, &size_1, &size_2, - &size_3, &size_4, &size_5); + // size of workspace for TRMM calls + size_t size_6; + rocsolver_gesvd_getMemorySize( + left_svect, right_svect, m, n, batch_count, &size_1, &size_2, &size_3, + &size_4, &size_5, &size_6); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *workgral, *workArr, *workfunc, *tau; + void *scalars, *workgral, *workArr, *workfunc, *tau, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&workgral, size_2); hipMalloc(&workArr, size_3); hipMalloc(&workfunc, size_4); hipMalloc(&tau, size_5); + hipMalloc(&workTrmm, size_6); if ((size_1 && !scalars) || (size_2 && !workgral) || (size_3 && !workArr) || - (size_4 && !workfunc) || (size_5 && !tau)) + (size_4 && !workfunc) || (size_5 && !tau) || (size_6 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -65,13 +66,15 @@ rocblas_status rocsolver_gesvd_batched_impl( rocblas_status status = rocsolver_gesvd_template( handle, left_svect, right_svect, m, n, A, 0, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, E, strideE, fast_alg, info, batch_count, - (T *)scalars, workgral, (T **)workArr, (T *)workfunc, (T *)tau); + (T *)scalars, workgral, (T **)workArr, (T *)workfunc, (T *)tau, + (T *)workTrmm); hipFree(scalars); hipFree(workgral); hipFree(workArr); hipFree(workfunc); hipFree(tau); + hipFree(workTrmm); return status; } @@ -143,5 +146,3 @@ rocblas_status rocsolver_zgesvd_batched( } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_gesvd_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_gesvd_strided_batched.cpp index b98db264e..e76f1749c 100644 --- a/rocsolver/library/src/lapack/roclapack_gesvd_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_gesvd_strided_batched.cpp @@ -38,19 +38,22 @@ rocblas_status rocsolver_gesvd_strided_batched_impl( size_t size_4; // size of the array for the householder scalars size_t size_5; - rocsolver_gesvd_getMemorySize(left_svect, right_svect, m, n, - batch_count, &size_1, &size_2, - &size_3, &size_4, &size_5); + // size of workspace for TRMM calls + size_t size_6; + rocsolver_gesvd_getMemorySize( + left_svect, right_svect, m, n, batch_count, &size_1, &size_2, &size_3, + &size_4, &size_5, &size_6); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *workgral, *workArr, *workfunc, *tau; + void *scalars, *workgral, *workArr, *workfunc, *tau, *workTrmm; hipMalloc(&scalars, size_1); hipMalloc(&workgral, size_2); hipMalloc(&workArr, size_3); hipMalloc(&workfunc, size_4); hipMalloc(&tau, size_5); + hipMalloc(&workTrmm, size_6); if ((size_1 && !scalars) || (size_2 && !workgral) || (size_3 && !workArr) || - (size_4 && !workfunc) || (size_5 && !tau)) + (size_4 && !workfunc) || (size_5 && !tau) || (size_6 && !workTrmm)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -62,13 +65,15 @@ rocblas_status rocsolver_gesvd_strided_batched_impl( rocblas_status status = rocsolver_gesvd_template( handle, left_svect, right_svect, m, n, A, 0, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, E, strideE, fast_alg, info, batch_count, - (T *)scalars, workgral, (T **)workArr, (T *)workfunc, (T *)tau); + (T *)scalars, workgral, (T **)workArr, (T *)workfunc, (T *)tau, + (T *)workTrmm); hipFree(scalars); hipFree(workgral); hipFree(workArr); hipFree(workfunc); hipFree(tau); + hipFree(workTrmm); return status; } diff --git a/rocsolver/library/src/lapack/roclapack_getf2.cpp b/rocsolver/library/src/lapack/roclapack_getf2.cpp index afb030321..1fa86371f 100644 --- a/rocsolver/library/src/lapack/roclapack_getf2.cpp +++ b/rocsolver/library/src/lapack/roclapack_getf2.cpp @@ -71,59 +71,63 @@ rocblas_status rocsolver_getf2_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_sgetf2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_dgetf2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status -rocsolver_cgetf2(rocblas_handle handle, const rocblas_int m, - const rocblas_int n, rocblas_float_complex *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_cgetf2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status -rocsolver_zgetf2(rocblas_handle handle, const rocblas_int m, - const rocblas_int n, rocblas_double_complex *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_zgetf2(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, - const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_sgetf2_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, - const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_dgetf2_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_cgetf2_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, + rocblas_float_complex *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetf2_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_zgetf2_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, + rocblas_double_complex *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); diff --git a/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp b/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp index 2bbf9cd80..4986e7d8e 100644 --- a/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_getf2.hpp" template @@ -72,61 +71,67 @@ rocblas_status rocsolver_getf2_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_sgetf2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_dgetf2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgetf2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getf2_batched_impl( handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetf2_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgetf2_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getf2_batched_impl( handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2_npvt_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_sgetf2_npvt_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { rocblas_int *ipiv; return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2_npvt_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_dgetf2_npvt_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { rocblas_int *ipiv; return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_npvt_batched( +rocblas_status rocsolver_cgetf2_npvt_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *const A[], const rocblas_int lda, rocblas_int *info, const rocblas_int batch_count) { @@ -135,7 +140,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_npvt_batched( handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetf2_npvt_batched( +rocblas_status rocsolver_zgetf2_npvt_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *const A[], const rocblas_int lda, rocblas_int *info, const rocblas_int batch_count) { @@ -145,5 +150,3 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgetf2_npvt_batched( } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp index 816cc4847..0e856bbe6 100644 --- a/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp @@ -70,7 +70,7 @@ rocblas_status rocsolver_getf2_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2_strided_batched( +rocblas_status rocsolver_sgetf2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, @@ -79,7 +79,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2_strided_batched( +rocblas_status rocsolver_dgetf2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, @@ -88,7 +88,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_strided_batched( +rocblas_status rocsolver_cgetf2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, @@ -98,7 +98,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetf2_strided_batched( +rocblas_status rocsolver_zgetf2_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, @@ -108,7 +108,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgetf2_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2_npvt_strided_batched( +rocblas_status rocsolver_sgetf2_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -117,7 +117,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgetf2_npvt_strided_batched( handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2_npvt_strided_batched( +rocblas_status rocsolver_dgetf2_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -126,7 +126,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgetf2_npvt_strided_batched( handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_npvt_strided_batched( +rocblas_status rocsolver_cgetf2_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, @@ -136,7 +136,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgetf2_npvt_strided_batched( handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetf2_npvt_strided_batched( +rocblas_status rocsolver_zgetf2_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, diff --git a/rocsolver/library/src/lapack/roclapack_getrf.cpp b/rocsolver/library/src/lapack/roclapack_getrf.cpp index 5dd665e90..a30e1d04b 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrf.cpp @@ -41,8 +41,8 @@ rocblas_status rocsolver_getrf_impl(rocblas_handle handle, const rocblas_int m, // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE void *scalars, *pivot_val, *pivot_idx, *iinfo, *work, *x_temp, *x_temp_arr, *invA, *invA_arr; - bool optim_mem = - true; // always allocate all required memory for TRSM optimal performance + // always allocate all required memory for TRSM optimal performance + bool optim_mem = true; hipMalloc(&scalars, size_1); hipMalloc(&pivot_val, size_2); @@ -94,59 +94,63 @@ rocblas_status rocsolver_getrf_impl(rocblas_handle handle, const rocblas_int m, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_sgetrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_dgetrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status -rocsolver_cgetrf(rocblas_handle handle, const rocblas_int m, - const rocblas_int n, rocblas_float_complex *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_cgetrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status -rocsolver_zgetrf(rocblas_handle handle, const rocblas_int m, - const rocblas_int n, rocblas_double_complex *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_zgetrf(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, - const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_sgetrf_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, - const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_dgetrf_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_cgetrf_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, + rocblas_float_complex *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrf_npvt( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_zgetrf_npvt(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, + rocblas_double_complex *A, + const rocblas_int lda, rocblas_int *info) { rocblas_int *ipiv; return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); diff --git a/rocsolver/library/src/lapack/roclapack_getrf.hpp b/rocsolver/library/src/lapack/roclapack_getrf.hpp index 5489a15d3..2b7b60a0b 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf.hpp +++ b/rocsolver/library/src/lapack/roclapack_getrf.hpp @@ -47,12 +47,16 @@ void rocsolver_getrf_getMemorySize(const rocblas_int m, const rocblas_int n, size_5); if (m < GETRF_GETF2_SWITCHSIZE || n < GETRF_GETF2_SWITCHSIZE) { *size_4 = 0; + *size_6 = 0; + *size_7 = 0; + *size_8 = 0; + *size_9 = 0; } else { *size_4 = sizeof(rocblas_int) * batch_count; + rocblas_int jb = GETRF_GETF2_SWITCHSIZE; + rocblasCall_trsm_mem(rocblas_side_left, jb, n - jb, batch_count, + size_6, size_7, size_8, size_9); } - - rocblasCall_trsm_mem(rocblas_side_left, m, n, batch_count, size_6, - size_7, size_8, size_9); } template diff --git a/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp b/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp index eb37f4261..6395c5d97 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_getrf.hpp" template @@ -94,61 +93,67 @@ rocblas_status rocsolver_getrf_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_sgetrf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_dgetrf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgetrf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getrf_batched_impl( handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrf_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgetrf_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getrf_batched_impl( handle, m, n, A, lda, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf_npvt_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - float *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_sgetrf_npvt_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, float *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { rocblas_int *ipiv; return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf_npvt_batched( - rocblas_handle handle, const rocblas_int m, const rocblas_int n, - double *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_dgetrf_npvt_batched(rocblas_handle handle, const rocblas_int m, + const rocblas_int n, double *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { rocblas_int *ipiv; return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_npvt_batched( +rocblas_status rocsolver_cgetrf_npvt_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *const A[], const rocblas_int lda, rocblas_int *info, const rocblas_int batch_count) { @@ -157,7 +162,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_npvt_batched( handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrf_npvt_batched( +rocblas_status rocsolver_zgetrf_npvt_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *const A[], const rocblas_int lda, rocblas_int *info, const rocblas_int batch_count) { @@ -167,5 +172,3 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrf_npvt_batched( } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp index 4357c9c23..3c4f0c163 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp @@ -92,7 +92,7 @@ rocblas_status rocsolver_getrf_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf_strided_batched( +rocblas_status rocsolver_sgetrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, @@ -101,7 +101,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf_strided_batched( +rocblas_status rocsolver_dgetrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, @@ -110,7 +110,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_strided_batched( +rocblas_status rocsolver_cgetrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, @@ -120,7 +120,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrf_strided_batched( +rocblas_status rocsolver_zgetrf_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, @@ -130,7 +130,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrf_strided_batched( handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); } -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf_npvt_strided_batched( +rocblas_status rocsolver_sgetrf_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -139,7 +139,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrf_npvt_strided_batched( handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf_npvt_strided_batched( +rocblas_status rocsolver_dgetrf_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -148,7 +148,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrf_npvt_strided_batched( handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_npvt_strided_batched( +rocblas_status rocsolver_cgetrf_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, @@ -158,7 +158,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrf_npvt_strided_batched( handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrf_npvt_strided_batched( +rocblas_status rocsolver_zgetrf_npvt_strided_batched( rocblas_handle handle, const rocblas_int m, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, diff --git a/rocsolver/library/src/lapack/roclapack_getri.cpp b/rocsolver/library/src/lapack/roclapack_getri.cpp index 8875fe095..8f75d1b20 100644 --- a/rocsolver/library/src/lapack/roclapack_getri.cpp +++ b/rocsolver/library/src/lapack/roclapack_getri.cpp @@ -64,32 +64,29 @@ rocblas_status rocsolver_getri_impl(rocblas_handle handle, const rocblas_int n, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetri(rocblas_handle handle, - const rocblas_int n, float *A, - const rocblas_int lda, - rocblas_int *ipiv, - rocblas_int *info) { +rocblas_status rocsolver_sgetri(rocblas_handle handle, const rocblas_int n, + float *A, const rocblas_int lda, + rocblas_int *ipiv, rocblas_int *info) { return rocsolver_getri_impl(handle, n, A, lda, ipiv, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetri(rocblas_handle handle, - const rocblas_int n, double *A, - const rocblas_int lda, - rocblas_int *ipiv, - rocblas_int *info) { +rocblas_status rocsolver_dgetri(rocblas_handle handle, const rocblas_int n, + double *A, const rocblas_int lda, + rocblas_int *ipiv, rocblas_int *info) { return rocsolver_getri_impl(handle, n, A, lda, ipiv, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetri( - rocblas_handle handle, const rocblas_int n, rocblas_float_complex *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_cgetri(rocblas_handle handle, const rocblas_int n, + rocblas_float_complex *A, const rocblas_int lda, + rocblas_int *ipiv, rocblas_int *info) { return rocsolver_getri_impl(handle, n, A, lda, ipiv, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetri( - rocblas_handle handle, const rocblas_int n, rocblas_double_complex *A, - const rocblas_int lda, rocblas_int *ipiv, rocblas_int *info) { +rocblas_status rocsolver_zgetri(rocblas_handle handle, const rocblas_int n, + rocblas_double_complex *A, + const rocblas_int lda, rocblas_int *ipiv, + rocblas_int *info) { return rocsolver_getri_impl(handle, n, A, lda, ipiv, info); } diff --git a/rocsolver/library/src/lapack/roclapack_getri_batched.cpp b/rocsolver/library/src/lapack/roclapack_getri_batched.cpp index 5f9f204e2..9d6112cd4 100644 --- a/rocsolver/library/src/lapack/roclapack_getri_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getri_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_getri.hpp" template @@ -66,40 +65,42 @@ rocsolver_getri_batched_impl(rocblas_handle handle, const rocblas_int n, U A, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetri_batched( - rocblas_handle handle, const rocblas_int n, float *const A[], - const rocblas_int lda, rocblas_int *ipiv, const rocblas_stride strideP, - rocblas_int *info, const rocblas_int batch_count) { +rocblas_status +rocsolver_sgetri_batched(rocblas_handle handle, const rocblas_int n, + float *const A[], const rocblas_int lda, + rocblas_int *ipiv, const rocblas_stride strideP, + rocblas_int *info, const rocblas_int batch_count) { return rocsolver_getri_batched_impl(handle, n, A, lda, ipiv, strideP, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetri_batched( - rocblas_handle handle, const rocblas_int n, double *const A[], - const rocblas_int lda, rocblas_int *ipiv, const rocblas_stride strideP, - rocblas_int *info, const rocblas_int batch_count) { +rocblas_status +rocsolver_dgetri_batched(rocblas_handle handle, const rocblas_int n, + double *const A[], const rocblas_int lda, + rocblas_int *ipiv, const rocblas_stride strideP, + rocblas_int *info, const rocblas_int batch_count) { return rocsolver_getri_batched_impl(handle, n, A, lda, ipiv, strideP, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetri_batched( - rocblas_handle handle, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cgetri_batched(rocblas_handle handle, const rocblas_int n, + rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getri_batched_impl( handle, n, A, lda, ipiv, strideP, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetri_batched( - rocblas_handle handle, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, rocblas_int *ipiv, - const rocblas_stride strideP, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zgetri_batched(rocblas_handle handle, const rocblas_int n, + rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_int *ipiv, + const rocblas_stride strideP, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_getri_batched_impl( handle, n, A, lda, ipiv, strideP, info, batch_count); } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_getri_outofplace_batched.cpp b/rocsolver/library/src/lapack/roclapack_getri_outofplace_batched.cpp index 3d7648a3a..a721d1cde 100644 --- a/rocsolver/library/src/lapack/roclapack_getri_outofplace_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getri_outofplace_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_getri.hpp" /* @@ -136,5 +135,3 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_zgetri_outofplace_batched( } } // extern C - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_getri_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_getri_strided_batched.cpp index 4325eb062..4e120884d 100644 --- a/rocsolver/library/src/lapack/roclapack_getri_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getri_strided_batched.cpp @@ -63,7 +63,7 @@ rocblas_status rocsolver_getri_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_sgetri_strided_batched( +rocblas_status rocsolver_sgetri_strided_batched( rocblas_handle handle, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, @@ -72,7 +72,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_sgetri_strided_batched( handle, n, A, lda, strideA, ipiv, strideP, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dgetri_strided_batched( +rocblas_status rocsolver_dgetri_strided_batched( rocblas_handle handle, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, @@ -81,7 +81,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dgetri_strided_batched( handle, n, A, lda, strideA, ipiv, strideP, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cgetri_strided_batched( +rocblas_status rocsolver_cgetri_strided_batched( rocblas_handle handle, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, @@ -90,7 +90,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cgetri_strided_batched( handle, n, A, lda, strideA, ipiv, strideP, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zgetri_strided_batched( +rocblas_status rocsolver_zgetri_strided_batched( rocblas_handle handle, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *ipiv, const rocblas_stride strideP, rocblas_int *info, diff --git a/rocsolver/library/src/lapack/roclapack_getrs.cpp b/rocsolver/library/src/lapack/roclapack_getrs.cpp index 4bd2798a2..c003e8cae 100644 --- a/rocsolver/library/src/lapack/roclapack_getrs.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrs.cpp @@ -27,13 +27,36 @@ rocsolver_getrs_impl(rocblas_handle handle, const rocblas_operation trans, rocblas_int batch_count = 1; // memory managment - // this function does not requiere memory work space + size_t size_1; // for TRSM x_temp + size_t size_2; // for TRSM x_temp_arr + size_t size_3; // for TRSM invA + size_t size_4; // for TRSM invA_arr + rocsolver_getrs_getMemorySize(n, nrhs, batch_count, &size_1, + &size_2, &size_3, &size_4); + // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE + void *x_temp, *x_temp_arr, *invA, *invA_arr; + // always allocate all required memory for TRSM optimal performance + bool optim_mem = true; + + hipMalloc(&x_temp, size_1); + hipMalloc(&x_temp_arr, size_2); + hipMalloc(&invA, size_3); + hipMalloc(&invA_arr, size_4); + if ((size_1 && !x_temp) || (size_2 && !x_temp_arr) || (size_3 && !invA) || + (size_4 && !invA_arr)) + return rocblas_status_memory_error; // execution - return rocsolver_getrs_template(handle, trans, n, nrhs, A, 0, lda, strideA, - ipiv, strideP, B, 0, ldb, strideB, - batch_count); + rocblas_status status = rocsolver_getrs_template( + handle, trans, n, nrhs, A, 0, lda, strideA, ipiv, strideP, B, 0, ldb, + strideB, batch_count, x_temp, x_temp_arr, invA, invA_arr, optim_mem); + + hipFree(x_temp); + hipFree(x_temp_arr); + hipFree(invA); + hipFree(invA_arr); + return status; } /* @@ -42,23 +65,25 @@ rocsolver_getrs_impl(rocblas_handle handle, const rocblas_operation trans, * =========================================================================== */ -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrs( - rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, - const rocblas_int nrhs, float *A, const rocblas_int lda, - const rocblas_int *ipiv, float *B, const rocblas_int ldb) { +extern "C" rocblas_status +rocsolver_sgetrs(rocblas_handle handle, const rocblas_operation trans, + const rocblas_int n, const rocblas_int nrhs, float *A, + const rocblas_int lda, const rocblas_int *ipiv, float *B, + const rocblas_int ldb) { return rocsolver_getrs_impl(handle, trans, n, nrhs, A, lda, ipiv, B, ldb); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrs( - rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, - const rocblas_int nrhs, double *A, const rocblas_int lda, - const rocblas_int *ipiv, double *B, const rocblas_int ldb) { +extern "C" rocblas_status +rocsolver_dgetrs(rocblas_handle handle, const rocblas_operation trans, + const rocblas_int n, const rocblas_int nrhs, double *A, + const rocblas_int lda, const rocblas_int *ipiv, double *B, + const rocblas_int ldb) { return rocsolver_getrs_impl(handle, trans, n, nrhs, A, lda, ipiv, B, ldb); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrs( +extern "C" rocblas_status rocsolver_cgetrs( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, rocblas_float_complex *A, const rocblas_int lda, const rocblas_int *ipiv, rocblas_float_complex *B, const rocblas_int ldb) { @@ -66,7 +91,7 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrs( lda, ipiv, B, ldb); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrs( +extern "C" rocblas_status rocsolver_zgetrs( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, rocblas_double_complex *A, const rocblas_int lda, const rocblas_int *ipiv, rocblas_double_complex *B, const rocblas_int ldb) { diff --git a/rocsolver/library/src/lapack/roclapack_getrs.hpp b/rocsolver/library/src/lapack/roclapack_getrs.hpp index a2ccee7a2..b16fa6c31 100644 --- a/rocsolver/library/src/lapack/roclapack_getrs.hpp +++ b/rocsolver/library/src/lapack/roclapack_getrs.hpp @@ -37,14 +37,24 @@ rocblas_status rocsolver_getrs_argCheck( return rocblas_status_continue; } -template +template +void rocsolver_getrs_getMemorySize(const rocblas_int n, const rocblas_int nrhs, + const rocblas_int batch_count, + size_t *size_1, size_t *size_2, + size_t *size_3, size_t *size_4) { + rocblasCall_trsm_mem(rocblas_side_left, n, nrhs, batch_count, + size_1, size_2, size_3, size_4); +} + +template rocblas_status rocsolver_getrs_template( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, U A, const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, const rocblas_int *ipiv, const rocblas_stride strideP, U B, const rocblas_int shiftB, const rocblas_int ldb, - const rocblas_stride strideB, const rocblas_int batch_count) { + const rocblas_stride strideB, const rocblas_int batch_count, void *x_temp, + void *x_temp_arr, void *invA, void *invA_arr, bool optim_mem) { // quick return if (n == 0 || nrhs == 0 || batch_count == 0) { return rocblas_status_success; @@ -58,61 +68,44 @@ rocblas_status rocsolver_getrs_template( rocblas_get_pointer_mode(handle, &old_mode); rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); -// **** THIS SYNCHRONIZATION WILL BE REQUIRED UNTIL -// TRSM_BATCH FUNCTIONALITY IS ENABLED. **** -#ifdef batched - T *AA[batch_count]; - T *BB[batch_count]; - hipMemcpy(AA, A, batch_count * sizeof(T *), hipMemcpyDeviceToHost); - hipMemcpy(BB, B, batch_count * sizeof(T *), hipMemcpyDeviceToHost); -#else - T *AA = A; - T *BB = B; -#endif - // constants to use when calling rocablas functions T one = 1; // constant 1 in host - T *Ap, *Bp; - - // **** TRSM_BATCH IS EXECUTED IN A FOR-LOOP UNTIL - // FUNCITONALITY IS ENABLED. **** - if (trans == rocblas_operation_none) { // first apply row interchanges to the right hand sides rocsolver_laswp_template(handle, nrhs, B, shiftB, ldb, strideB, 1, n, ipiv, 0, strideP, 1, batch_count); - for (int b = 0; b < batch_count; ++b) { - Ap = load_ptr_batch(AA, b, shiftA, strideA); - Bp = load_ptr_batch(BB, b, shiftB, strideB); - - // solve L*X = B, overwriting B with X - rocblas_trsm(handle, rocblas_side_left, rocblas_fill_lower, trans, - rocblas_diagonal_unit, n, nrhs, &one, Ap, lda, Bp, ldb); + // solve L*X = B, overwriting B with X + rocblasCall_trsm(handle, rocblas_side_left, rocblas_fill_lower, + trans, rocblas_diagonal_unit, n, nrhs, &one, A, + shiftA, lda, strideA, B, shiftB, ldb, strideB, + batch_count, optim_mem, x_temp, x_temp_arr, + invA, invA_arr); - // solve U*X = B, overwriting B with X - rocblas_trsm(handle, rocblas_side_left, rocblas_fill_upper, trans, - rocblas_diagonal_non_unit, n, nrhs, &one, Ap, lda, Bp, - ldb); - } + // solve U*X = B, overwriting B with X + rocblasCall_trsm(handle, rocblas_side_left, rocblas_fill_upper, + trans, rocblas_diagonal_non_unit, n, nrhs, + &one, A, shiftA, lda, strideA, B, shiftB, ldb, + strideB, batch_count, optim_mem, x_temp, + x_temp_arr, invA, invA_arr); } else { - for (int b = 0; b < batch_count; ++b) { - Ap = load_ptr_batch(AA, b, shiftA, strideA); - Bp = load_ptr_batch(BB, b, shiftB, strideB); - - // solve U**T *X = B or U**H *X = B, overwriting B with X - rocblas_trsm(handle, rocblas_side_left, rocblas_fill_upper, trans, - rocblas_diagonal_non_unit, n, nrhs, &one, Ap, lda, Bp, - ldb); - - // solve L**T *X = B, or L**H *X = B overwriting B with X - rocblas_trsm(handle, rocblas_side_left, rocblas_fill_lower, trans, - rocblas_diagonal_unit, n, nrhs, &one, Ap, lda, Bp, ldb); - } + // solve U**T *X = B or U**H *X = B, overwriting B with X + rocblasCall_trsm(handle, rocblas_side_left, rocblas_fill_upper, + trans, rocblas_diagonal_non_unit, n, nrhs, + &one, A, shiftA, lda, strideA, B, shiftB, ldb, + strideB, batch_count, optim_mem, x_temp, + x_temp_arr, invA, invA_arr); + + // solve L**T *X = B, or L**H *X = B overwriting B with X + rocblasCall_trsm(handle, rocblas_side_left, rocblas_fill_lower, + trans, rocblas_diagonal_unit, n, nrhs, &one, A, + shiftA, lda, strideA, B, shiftB, ldb, strideB, + batch_count, optim_mem, x_temp, x_temp_arr, + invA, invA_arr); // then apply row interchanges to the solution vectors rocsolver_laswp_template(handle, nrhs, B, shiftB, ldb, strideB, 1, n, diff --git a/rocsolver/library/src/lapack/roclapack_getrs_batched.cpp b/rocsolver/library/src/lapack/roclapack_getrs_batched.cpp index 3ec6ab257..013a9d575 100644 --- a/rocsolver/library/src/lapack/roclapack_getrs_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrs_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_getrs.hpp" template @@ -26,13 +25,36 @@ rocblas_status rocsolver_getrs_batched_impl( rocblas_stride strideB = 0; // memory managment - // this function does not requiere memory work space + size_t size_1; // for TRSM x_temp + size_t size_2; // for TRSM x_temp_arr + size_t size_3; // for TRSM invA + size_t size_4; // for TRSM invA_arr + rocsolver_getrs_getMemorySize(n, nrhs, batch_count, &size_1, &size_2, + &size_3, &size_4); + // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE + void *x_temp, *x_temp_arr, *invA, *invA_arr; + // always allocate all required memory for TRSM optimal performance + bool optim_mem = true; + + hipMalloc(&x_temp, size_1); + hipMalloc(&x_temp_arr, size_2); + hipMalloc(&invA, size_3); + hipMalloc(&invA_arr, size_4); + if ((size_1 && !x_temp) || (size_2 && !x_temp_arr) || (size_3 && !invA) || + (size_4 && !invA_arr)) + return rocblas_status_memory_error; // execution - return rocsolver_getrs_template(handle, trans, n, nrhs, A, 0, lda, strideA, - ipiv, strideP, B, 0, ldb, strideB, - batch_count); + rocblas_status status = rocsolver_getrs_template( + handle, trans, n, nrhs, A, 0, lda, strideA, ipiv, strideP, B, 0, ldb, + strideB, batch_count, x_temp, x_temp_arr, invA, invA_arr, optim_mem); + + hipFree(x_temp); + hipFree(x_temp_arr); + hipFree(invA); + hipFree(invA_arr); + return status; } /* @@ -41,7 +63,7 @@ rocblas_status rocsolver_getrs_batched_impl( * =========================================================================== */ -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrs_batched( +extern "C" rocblas_status rocsolver_sgetrs_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, float *const A[], const rocblas_int lda, const rocblas_int *ipiv, const rocblas_stride strideP, float *const B[], @@ -50,7 +72,7 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrs_batched( handle, trans, n, nrhs, A, lda, ipiv, strideP, B, ldb, batch_count); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrs_batched( +extern "C" rocblas_status rocsolver_dgetrs_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, double *const A[], const rocblas_int lda, const rocblas_int *ipiv, const rocblas_stride strideP, double *const B[], @@ -59,7 +81,7 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrs_batched( handle, trans, n, nrhs, A, lda, ipiv, strideP, B, ldb, batch_count); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrs_batched( +extern "C" rocblas_status rocsolver_cgetrs_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, rocblas_float_complex *const A[], const rocblas_int lda, const rocblas_int *ipiv, @@ -69,7 +91,7 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrs_batched( handle, trans, n, nrhs, A, lda, ipiv, strideP, B, ldb, batch_count); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrs_batched( +extern "C" rocblas_status rocsolver_zgetrs_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, rocblas_double_complex *const A[], const rocblas_int lda, const rocblas_int *ipiv, @@ -78,5 +100,3 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrs_batched( return rocsolver_getrs_batched_impl( handle, trans, n, nrhs, A, lda, ipiv, strideP, B, ldb, batch_count); } - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_getrs_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_getrs_strided_batched.cpp index 888fb3ca0..83cc2fbab 100644 --- a/rocsolver/library/src/lapack/roclapack_getrs_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrs_strided_batched.cpp @@ -23,13 +23,36 @@ rocblas_status rocsolver_getrs_strided_batched_impl( return st; // memory managment - // this function does not requiere memory work space + size_t size_1; // for TRSM x_temp + size_t size_2; // for TRSM x_temp_arr + size_t size_3; // for TRSM invA + size_t size_4; // for TRSM invA_arr + rocsolver_getrs_getMemorySize(n, nrhs, batch_count, &size_1, + &size_2, &size_3, &size_4); + // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE + void *x_temp, *x_temp_arr, *invA, *invA_arr; + // always allocate all required memory for TRSM optimal performance + bool optim_mem = true; + + hipMalloc(&x_temp, size_1); + hipMalloc(&x_temp_arr, size_2); + hipMalloc(&invA, size_3); + hipMalloc(&invA_arr, size_4); + if ((size_1 && !x_temp) || (size_2 && !x_temp_arr) || (size_3 && !invA) || + (size_4 && !invA_arr)) + return rocblas_status_memory_error; // execution - return rocsolver_getrs_template(handle, trans, n, nrhs, A, 0, lda, strideA, - ipiv, strideP, B, 0, ldb, strideB, - batch_count); + rocblas_status status = rocsolver_getrs_template( + handle, trans, n, nrhs, A, 0, lda, strideA, ipiv, strideP, B, 0, ldb, + strideB, batch_count, x_temp, x_temp_arr, invA, invA_arr, optim_mem); + + hipFree(x_temp); + hipFree(x_temp_arr); + hipFree(invA); + hipFree(invA_arr); + return status; } /* @@ -38,7 +61,7 @@ rocblas_status rocsolver_getrs_strided_batched_impl( * =========================================================================== */ -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrs_strided_batched( +extern "C" rocblas_status rocsolver_sgetrs_strided_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, float *A, const rocblas_int lda, const rocblas_stride strideA, const rocblas_int *ipiv, @@ -49,7 +72,7 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_sgetrs_strided_batched( batch_count); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrs_strided_batched( +extern "C" rocblas_status rocsolver_dgetrs_strided_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, double *A, const rocblas_int lda, const rocblas_stride strideA, const rocblas_int *ipiv, @@ -60,7 +83,7 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_dgetrs_strided_batched( batch_count); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrs_strided_batched( +extern "C" rocblas_status rocsolver_cgetrs_strided_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, const rocblas_int *ipiv, @@ -72,7 +95,7 @@ extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_cgetrs_strided_batched( batch_count); } -extern "C" ROCSOLVER_EXPORT rocblas_status rocsolver_zgetrs_strided_batched( +extern "C" rocblas_status rocsolver_zgetrs_strided_batched( rocblas_handle handle, const rocblas_operation trans, const rocblas_int n, const rocblas_int nrhs, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, const rocblas_int *ipiv, diff --git a/rocsolver/library/src/lapack/roclapack_potf2.cpp b/rocsolver/library/src/lapack/roclapack_potf2.cpp index 2da9d78b9..34f459dbf 100644 --- a/rocsolver/library/src/lapack/roclapack_potf2.cpp +++ b/rocsolver/library/src/lapack/roclapack_potf2.cpp @@ -61,32 +61,28 @@ rocblas_status rocsolver_potf2_impl(rocblas_handle handle, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_spotf2(rocblas_handle handle, - const rocblas_fill uplo, - const rocblas_int n, float *A, - const rocblas_int lda, - rocblas_int *info) { +rocblas_status rocsolver_spotf2(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, float *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potf2_impl(handle, uplo, n, A, lda, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dpotf2(rocblas_handle handle, - const rocblas_fill uplo, - const rocblas_int n, double *A, - const rocblas_int lda, - rocblas_int *info) { +rocblas_status rocsolver_dpotf2(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, double *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potf2_impl(handle, uplo, n, A, lda, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cpotf2( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_float_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_cpotf2(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potf2_impl(handle, uplo, n, A, lda, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zpotf2( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_double_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_zpotf2(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potf2_impl(handle, uplo, n, A, lda, info); } diff --git a/rocsolver/library/src/lapack/roclapack_potf2_batched.cpp b/rocsolver/library/src/lapack/roclapack_potf2_batched.cpp index b9f8d4ec4..a11af9335 100644 --- a/rocsolver/library/src/lapack/roclapack_potf2_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_potf2_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_potf2.hpp" template @@ -62,37 +61,41 @@ rocsolver_potf2_batched_impl(rocblas_handle handle, const rocblas_fill uplo, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_spotf2_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - float *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status rocsolver_spotf2_batched(rocblas_handle handle, + const rocblas_fill uplo, + const rocblas_int n, float *const A[], + const rocblas_int lda, + rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potf2_batched_impl(handle, uplo, n, A, lda, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dpotf2_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - double *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status rocsolver_dpotf2_batched(rocblas_handle handle, + const rocblas_fill uplo, + const rocblas_int n, double *const A[], + const rocblas_int lda, + rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potf2_batched_impl(handle, uplo, n, A, lda, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cpotf2_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cpotf2_batched(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potf2_batched_impl( handle, uplo, n, A, lda, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zpotf2_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zpotf2_batched(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potf2_batched_impl( handle, uplo, n, A, lda, info, batch_count); } } - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_potf2_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_potf2_strided_batched.cpp index 43934b52c..ea44c1507 100644 --- a/rocsolver/library/src/lapack/roclapack_potf2_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_potf2_strided_batched.cpp @@ -59,7 +59,7 @@ rocblas_status rocsolver_potf2_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_spotf2_strided_batched( +rocblas_status rocsolver_spotf2_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -67,7 +67,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_spotf2_strided_batched( handle, uplo, n, A, lda, strideA, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dpotf2_strided_batched( +rocblas_status rocsolver_dpotf2_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -75,7 +75,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dpotf2_strided_batched( handle, uplo, n, A, lda, strideA, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cpotf2_strided_batched( +rocblas_status rocsolver_cpotf2_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, @@ -84,7 +84,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cpotf2_strided_batched( handle, uplo, n, A, lda, strideA, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zpotf2_strided_batched( +rocblas_status rocsolver_zpotf2_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, diff --git a/rocsolver/library/src/lapack/roclapack_potrf.cpp b/rocsolver/library/src/lapack/roclapack_potrf.cpp index 621933290..ced0e5a76 100644 --- a/rocsolver/library/src/lapack/roclapack_potrf.cpp +++ b/rocsolver/library/src/lapack/roclapack_potrf.cpp @@ -27,17 +27,31 @@ rocblas_status rocsolver_potrf_impl(rocblas_handle handle, size_t size_2; // size of workspace size_t size_3; size_t size_4; - rocsolver_potrf_getMemorySize(n, batch_count, &size_1, &size_2, &size_3, - &size_4); + size_t size_5; // for TRSM + size_t size_6; // for TRSM + size_t size_7; // for TRSM + size_t size_8; // for TRSM + rocsolver_potrf_getMemorySize(n, uplo, batch_count, &size_1, + &size_2, &size_3, &size_4, &size_5, + &size_6, &size_7, &size_8); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *pivotGPU, *iinfo; + void *scalars, *work, *pivotGPU, *iinfo, *x_temp, *x_temp_arr, *invA, + *invA_arr; + // always allocate all required memory for TRSM optimal performance + bool optim_mem = true; + hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&pivotGPU, size_3); hipMalloc(&iinfo, size_4); + hipMalloc(&x_temp, size_5); + hipMalloc(&x_temp_arr, size_6); + hipMalloc(&invA, size_7); + hipMalloc(&invA_arr, size_8); if (!scalars || (size_2 && !work) || (size_3 && !pivotGPU) || - (size_4 && !iinfo)) + (size_4 && !iinfo) || (size_5 && !x_temp) || (size_6 && !x_temp_arr) || + (size_7 && !invA) || (size_8 && !invA_arr)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -46,16 +60,20 @@ rocblas_status rocsolver_potrf_impl(rocblas_handle handle, RETURN_IF_HIP_ERROR(hipMemcpy(scalars, sca, size_1, hipMemcpyHostToDevice)); // execution - rocblas_status status = rocsolver_potrf_template( + rocblas_status status = rocsolver_potrf_template( handle, uplo, n, A, 0, // the matrix is shifted 0 entries (will work on the entire matrix) lda, strideA, info, batch_count, (T *)scalars, (T *)work, (T *)pivotGPU, - (rocblas_int *)iinfo); + (rocblas_int *)iinfo, x_temp, x_temp_arr, invA, invA_arr, optim_mem); hipFree(scalars); hipFree(work); hipFree(pivotGPU); hipFree(iinfo); + hipFree(x_temp); + hipFree(x_temp_arr); + hipFree(invA); + hipFree(invA_arr); return status; } @@ -67,32 +85,28 @@ rocblas_status rocsolver_potrf_impl(rocblas_handle handle, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_spotrf(rocblas_handle handle, - const rocblas_fill uplo, - const rocblas_int n, float *A, - const rocblas_int lda, - rocblas_int *info) { +rocblas_status rocsolver_spotrf(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, float *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potrf_impl(handle, uplo, n, A, lda, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dpotrf(rocblas_handle handle, - const rocblas_fill uplo, - const rocblas_int n, double *A, - const rocblas_int lda, - rocblas_int *info) { +rocblas_status rocsolver_dpotrf(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, double *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potrf_impl(handle, uplo, n, A, lda, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cpotrf( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_float_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_cpotrf(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_float_complex *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potrf_impl(handle, uplo, n, A, lda, info); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zpotrf( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_double_complex *A, const rocblas_int lda, rocblas_int *info) { +rocblas_status rocsolver_zpotrf(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_double_complex *A, + const rocblas_int lda, rocblas_int *info) { return rocsolver_potrf_impl(handle, uplo, n, A, lda, info); } diff --git a/rocsolver/library/src/lapack/roclapack_potrf.hpp b/rocsolver/library/src/lapack/roclapack_potrf.hpp index 7d992af4d..e813682fd 100644 --- a/rocsolver/library/src/lapack/roclapack_potrf.hpp +++ b/rocsolver/library/src/lapack/roclapack_potrf.hpp @@ -22,28 +22,47 @@ __global__ void chk_positive(rocblas_int *iinfo, rocblas_int *info, int j) { info[id] = iinfo[id] + j; } -template -void rocsolver_potrf_getMemorySize(const rocblas_int n, +template +void rocsolver_potrf_getMemorySize(const rocblas_int n, const rocblas_fill uplo, const rocblas_int batch_count, size_t *size_1, size_t *size_2, - size_t *size_3, size_t *size_4) { + size_t *size_3, size_t *size_4, + size_t *size_5, size_t *size_6, + size_t *size_7, size_t *size_8) { if (n < POTRF_POTF2_SWITCHSIZE) { rocsolver_potf2_getMemorySize(n, batch_count, size_1, size_2, size_3); *size_4 = 0; + *size_5 = 0; + *size_6 = 0; + *size_7 = 0; + *size_8 = 0; } else { + *size_4 = sizeof(rocblas_int) * batch_count; rocsolver_potf2_getMemorySize(POTRF_POTF2_SWITCHSIZE, batch_count, size_1, size_2, size_3); - *size_4 = sizeof(rocblas_int) * batch_count; + + rocblas_int jb = POTRF_POTF2_SWITCHSIZE; + if (uplo == rocblas_fill_upper) + rocblasCall_trsm_mem(rocblas_side_left, jb, n - jb, + batch_count, size_5, size_6, size_7, + size_8); + else + rocblasCall_trsm_mem(rocblas_side_right, n - jb, jb, + batch_count, size_5, size_6, size_7, + size_8); } } -template > +template > rocblas_status rocsolver_potrf_template(rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, U A, const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count, - T *scalars, T *work, T *pivotGPU, rocblas_int *iinfo) { + T *scalars, T *work, T *pivotGPU, rocblas_int *iinfo, + void *x_temp, void *x_temp_arr, void *invA, + void *invA_arr, bool optim_mem) { // quick return if (n == 0 || batch_count == 0) return rocblas_status_success; @@ -63,15 +82,6 @@ rocsolver_potrf_template(rocblas_handle handle, const rocblas_fill uplo, info, batch_count, scalars, work, pivotGPU); -// **** THIS SYNCHRONIZATION WILL BE REQUIRED UNTIL -// TRSM_BATCH FUNCTIONALITY IS ENABLED. **** -#ifdef batched - T *AA[batch_count]; - hipMemcpy(AA, A, batch_count * sizeof(T *), hipMemcpyDeviceToHost); -#else - T *AA = A; -#endif - // constants for rocblas functions calls T t_one = 1; S s_one = 1; @@ -80,18 +90,14 @@ rocsolver_potrf_template(rocblas_handle handle, const rocblas_fill uplo, rocblas_int blocksReset = (batch_count - 1) / BLOCKSIZE + 1; dim3 gridReset(blocksReset, 1, 1); dim3 threads(BLOCKSIZE, 1, 1); - T *M; rocblas_int jb; // info=0 (starting with a positive definite matrix) hipLaunchKernelGGL(reset_info, gridReset, threads, 0, stream, info, batch_count, 0); - // **** TRSM_BATCH IS EXECUTED IN A FOR-LOOP UNTIL - // FUNCITONALITY IS ENABLED. **** - - if (uplo == - rocblas_fill_upper) { // Compute the Cholesky factorization A = U'*U. + if (uplo == rocblas_fill_upper) { + // Compute the Cholesky factorization A = U'*U. for (rocblas_int j = 0; j < n; j += POTRF_POTF2_SWITCHSIZE) { // Factor diagonal and subdiagonal blocks jb = min(n - j, POTRF_POTF2_SWITCHSIZE); // number of columns in the block @@ -107,14 +113,12 @@ rocsolver_potrf_template(rocblas_handle handle, const rocblas_fill uplo, if (j + jb < n) { // update trailing submatrix - for (int b = 0; b < batch_count; ++b) { - M = load_ptr_batch(AA, b, shiftA, strideA); - rocblas_trsm(handle, rocblas_side_left, uplo, - rocblas_operation_conjugate_transpose, - rocblas_diagonal_non_unit, jb, (n - j - jb), &t_one, - (M + idx2D(j, j, lda)), lda, (M + idx2D(j, j + jb, lda)), - lda); - } + rocblasCall_trsm( + handle, rocblas_side_left, uplo, + rocblas_operation_conjugate_transpose, rocblas_diagonal_non_unit, + jb, (n - j - jb), &t_one, A, shiftA + idx2D(j, j, lda), lda, + strideA, A, shiftA + idx2D(j, j + jb, lda), lda, strideA, + batch_count, optim_mem, x_temp, x_temp_arr, invA, invA_arr); rocblasCall_herk( handle, uplo, rocblas_operation_conjugate_transpose, n - j - jb, jb, @@ -123,7 +127,8 @@ rocsolver_potrf_template(rocblas_handle handle, const rocblas_fill uplo, } } - } else { // Compute the Cholesky factorization A = L'*L. + } else { + // Compute the Cholesky factorization A = L*L'. for (rocblas_int j = 0; j < n; j += POTRF_POTF2_SWITCHSIZE) { // Factor diagonal and subdiagonal blocks jb = min(n - j, POTRF_POTF2_SWITCHSIZE); // number of columns in the block @@ -139,14 +144,12 @@ rocsolver_potrf_template(rocblas_handle handle, const rocblas_fill uplo, if (j + jb < n) { // update trailing submatrix - for (int b = 0; b < batch_count; ++b) { - M = load_ptr_batch(AA, b, shiftA, strideA); - rocblas_trsm(handle, rocblas_side_right, uplo, - rocblas_operation_conjugate_transpose, - rocblas_diagonal_non_unit, (n - j - jb), jb, &t_one, - (M + idx2D(j, j, lda)), lda, (M + idx2D(j + jb, j, lda)), - lda); - } + rocblasCall_trsm( + handle, rocblas_side_right, uplo, + rocblas_operation_conjugate_transpose, rocblas_diagonal_non_unit, + (n - j - jb), jb, &t_one, A, shiftA + idx2D(j, j, lda), lda, + strideA, A, shiftA + idx2D(j + jb, j, lda), lda, strideA, + batch_count, optim_mem, x_temp, x_temp_arr, invA, invA_arr); rocblasCall_herk( handle, uplo, rocblas_operation_none, n - j - jb, jb, &s_minone, A, diff --git a/rocsolver/library/src/lapack/roclapack_potrf_batched.cpp b/rocsolver/library/src/lapack/roclapack_potrf_batched.cpp index 0b34b5c5c..6a513f625 100644 --- a/rocsolver/library/src/lapack/roclapack_potrf_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_potrf_batched.cpp @@ -2,7 +2,6 @@ * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. * ************************************************************************ */ -#define batched #include "roclapack_potrf.hpp" template @@ -28,17 +27,31 @@ rocsolver_potrf_batched_impl(rocblas_handle handle, const rocblas_fill uplo, size_t size_2; // size of workspace size_t size_3; size_t size_4; - rocsolver_potrf_getMemorySize(n, batch_count, &size_1, &size_2, &size_3, - &size_4); + size_t size_5; // for TRSM + size_t size_6; // for TRSM + size_t size_7; // for TRSM + size_t size_8; // for TRSM + rocsolver_potrf_getMemorySize(n, uplo, batch_count, &size_1, &size_2, + &size_3, &size_4, &size_5, &size_6, + &size_7, &size_8); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *pivotGPU, *iinfo; + void *scalars, *work, *pivotGPU, *iinfo, *x_temp, *x_temp_arr, *invA, + *invA_arr; + // always allocate all required memory for TRSM optimal performance + bool optim_mem = true; + hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&pivotGPU, size_3); hipMalloc(&iinfo, size_4); + hipMalloc(&x_temp, size_5); + hipMalloc(&x_temp_arr, size_6); + hipMalloc(&invA, size_7); + hipMalloc(&invA_arr, size_8); if (!scalars || (size_2 && !work) || (size_3 && !pivotGPU) || - (size_4 && !iinfo)) + (size_4 && !iinfo) || (size_5 && !x_temp) || (size_6 && !x_temp_arr) || + (size_7 && !invA) || (size_8 && !invA_arr)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -47,16 +60,20 @@ rocsolver_potrf_batched_impl(rocblas_handle handle, const rocblas_fill uplo, RETURN_IF_HIP_ERROR(hipMemcpy(scalars, sca, size_1, hipMemcpyHostToDevice)); // execution - rocblas_status status = rocsolver_potrf_template( + rocblas_status status = rocsolver_potrf_template( handle, uplo, n, A, 0, // the matrix is shifted 0 entries (will work on the entire matrix) lda, strideA, info, batch_count, (T *)scalars, (T *)work, (T *)pivotGPU, - (rocblas_int *)iinfo); + (rocblas_int *)iinfo, x_temp, x_temp_arr, invA, invA_arr, optim_mem); hipFree(scalars); hipFree(work); hipFree(pivotGPU); hipFree(iinfo); + hipFree(x_temp); + hipFree(x_temp_arr); + hipFree(invA); + hipFree(invA_arr); return status; } @@ -68,37 +85,41 @@ rocsolver_potrf_batched_impl(rocblas_handle handle, const rocblas_fill uplo, extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_spotrf_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - float *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status rocsolver_spotrf_batched(rocblas_handle handle, + const rocblas_fill uplo, + const rocblas_int n, float *const A[], + const rocblas_int lda, + rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potrf_batched_impl(handle, uplo, n, A, lda, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dpotrf_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - double *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status rocsolver_dpotrf_batched(rocblas_handle handle, + const rocblas_fill uplo, + const rocblas_int n, double *const A[], + const rocblas_int lda, + rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potrf_batched_impl(handle, uplo, n, A, lda, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cpotrf_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_float_complex *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_cpotrf_batched(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_float_complex *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potrf_batched_impl( handle, uplo, n, A, lda, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zpotrf_batched( - rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, - rocblas_double_complex *const A[], const rocblas_int lda, rocblas_int *info, - const rocblas_int batch_count) { +rocblas_status +rocsolver_zpotrf_batched(rocblas_handle handle, const rocblas_fill uplo, + const rocblas_int n, rocblas_double_complex *const A[], + const rocblas_int lda, rocblas_int *info, + const rocblas_int batch_count) { return rocsolver_potrf_batched_impl( handle, uplo, n, A, lda, info, batch_count); } } - -#undef batched diff --git a/rocsolver/library/src/lapack/roclapack_potrf_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_potrf_strided_batched.cpp index 69aadf577..10ff77cf2 100644 --- a/rocsolver/library/src/lapack/roclapack_potrf_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_potrf_strided_batched.cpp @@ -25,17 +25,31 @@ rocblas_status rocsolver_potrf_strided_batched_impl( size_t size_2; // size of workspace size_t size_3; size_t size_4; - rocsolver_potrf_getMemorySize(n, batch_count, &size_1, &size_2, &size_3, - &size_4); + size_t size_5; // for TRSM + size_t size_6; // for TRSM + size_t size_7; // for TRSM + size_t size_8; // for TRSM + rocsolver_potrf_getMemorySize(n, uplo, batch_count, &size_1, + &size_2, &size_3, &size_4, &size_5, + &size_6, &size_7, &size_8); // (TODO) MEMORY SIZE QUERIES AND ALLOCATIONS TO BE DONE WITH ROCBLAS HANDLE - void *scalars, *work, *pivotGPU, *iinfo; + void *scalars, *work, *pivotGPU, *iinfo, *x_temp, *x_temp_arr, *invA, + *invA_arr; + // always allocate all required memory for TRSM optimal performance + bool optim_mem = true; + hipMalloc(&scalars, size_1); hipMalloc(&work, size_2); hipMalloc(&pivotGPU, size_3); hipMalloc(&iinfo, size_4); + hipMalloc(&x_temp, size_5); + hipMalloc(&x_temp_arr, size_6); + hipMalloc(&invA, size_7); + hipMalloc(&invA_arr, size_8); if (!scalars || (size_2 && !work) || (size_3 && !pivotGPU) || - (size_4 && !iinfo)) + (size_4 && !iinfo) || (size_5 && !x_temp) || (size_6 && !x_temp_arr) || + (size_7 && !invA) || (size_8 && !invA_arr)) return rocblas_status_memory_error; // scalar constants for rocblas functions calls @@ -44,16 +58,20 @@ rocblas_status rocsolver_potrf_strided_batched_impl( RETURN_IF_HIP_ERROR(hipMemcpy(scalars, sca, size_1, hipMemcpyHostToDevice)); // execution - rocblas_status status = rocsolver_potrf_template( + rocblas_status status = rocsolver_potrf_template( handle, uplo, n, A, 0, // the matrix is shifted 0 entries (will work on the entire matrix) lda, strideA, info, batch_count, (T *)scalars, (T *)work, (T *)pivotGPU, - (rocblas_int *)iinfo); + (rocblas_int *)iinfo, x_temp, x_temp_arr, invA, invA_arr, optim_mem); hipFree(scalars); hipFree(work); hipFree(pivotGPU); hipFree(iinfo); + hipFree(x_temp); + hipFree(x_temp_arr); + hipFree(invA); + hipFree(invA_arr); return status; } @@ -65,7 +83,7 @@ rocblas_status rocsolver_potrf_strided_batched_impl( extern "C" { -ROCSOLVER_EXPORT rocblas_status rocsolver_spotrf_strided_batched( +rocblas_status rocsolver_spotrf_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, float *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -73,7 +91,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_spotrf_strided_batched( handle, uplo, n, A, lda, strideA, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_dpotrf_strided_batched( +rocblas_status rocsolver_dpotrf_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, double *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, const rocblas_int batch_count) { @@ -81,7 +99,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_dpotrf_strided_batched( handle, uplo, n, A, lda, strideA, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_cpotrf_strided_batched( +rocblas_status rocsolver_cpotrf_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, rocblas_float_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, @@ -90,7 +108,7 @@ ROCSOLVER_EXPORT rocblas_status rocsolver_cpotrf_strided_batched( handle, uplo, n, A, lda, strideA, info, batch_count); } -ROCSOLVER_EXPORT rocblas_status rocsolver_zpotrf_strided_batched( +rocblas_status rocsolver_zpotrf_strided_batched( rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n, rocblas_double_complex *A, const rocblas_int lda, const rocblas_stride strideA, rocblas_int *info, diff --git a/rocsolver/library/src/rocblas.cpp b/rocsolver/library/src/rocblas.cpp deleted file mode 100644 index a4983f7ee..000000000 --- a/rocsolver/library/src/rocblas.cpp +++ /dev/null @@ -1,155 +0,0 @@ -/* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. - * ************************************************************************ */ - -#include "rocblas.hpp" - -template <> -rocblas_status rocblas_nrm2(rocblas_handle handle, rocblas_int n, - const float *x, const rocblas_int incx, - float *result) { - return rocblas_snrm2(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_nrm2(rocblas_handle handle, rocblas_int n, - const double *x, const rocblas_int incx, - double *result) { - return rocblas_dnrm2(handle, n, x, incx, result); -} - -/*template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const float *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_isamax(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const double *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_idamax(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const rocblas_float_complex *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_icamax(handle, n, x, incx, result); -} -template <> -rocblas_status rocblas_iamax(rocblas_handle handle, rocblas_int n, - const rocblas_double_complex *x, rocblas_int incx, - rocblas_int *result) { - return rocblas_izamax(handle, n, x, incx, result); -}*/ - -template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const float *alpha, float *A, rocblas_int lda, - float *B, rocblas_int ldb) { - return rocblas_strsm(handle, side, uplo, transA, diag, m, n, alpha, A, lda, B, - ldb); -} -template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const double *alpha, double *A, rocblas_int lda, - double *B, rocblas_int ldb) { - return rocblas_dtrsm(handle, side, uplo, transA, diag, m, n, alpha, A, lda, B, - ldb); -} -template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const rocblas_float_complex *alpha, - rocblas_float_complex *A, rocblas_int lda, - rocblas_float_complex *B, rocblas_int ldb) { - return rocblas_ctrsm(handle, side, uplo, transA, diag, m, n, alpha, A, lda, B, - ldb); -} -template <> -rocblas_status rocblas_trsm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation transA, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - const rocblas_double_complex *alpha, - rocblas_double_complex *A, rocblas_int lda, - rocblas_double_complex *B, rocblas_int ldb) { - return rocblas_ztrsm(handle, side, uplo, transA, diag, m, n, alpha, A, lda, B, - ldb); -} - -template <> -rocblas_status rocblas_trmm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation trans, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - float *alpha, float *A, rocblas_int lda, float *B, - rocblas_int ldb) { - return rocblas_strmm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, - ldb); -} -template <> -rocblas_status rocblas_trmm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation trans, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - double *alpha, double *A, rocblas_int lda, - double *B, rocblas_int ldb) { - return rocblas_dtrmm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, - ldb); -} - -template <> -rocblas_status rocblas_trmm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation trans, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - rocblas_float_complex *alpha, - rocblas_float_complex *A, rocblas_int lda, - rocblas_float_complex *B, rocblas_int ldb) { - return rocblas_ctrmm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, - ldb); -} -template <> -rocblas_status rocblas_trmm(rocblas_handle handle, rocblas_side side, - rocblas_fill uplo, rocblas_operation trans, - rocblas_diagonal diag, rocblas_int m, rocblas_int n, - rocblas_double_complex *alpha, - rocblas_double_complex *A, rocblas_int lda, - rocblas_double_complex *B, rocblas_int ldb) { - return rocblas_ztrmm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, - ldb); -} - -// trtri -template <> -rocblas_status rocblas_trtri(rocblas_handle handle, rocblas_fill uplo, - rocblas_diagonal diag, rocblas_int n, - const float *A, rocblas_int lda, float *invA, - rocblas_int ldinvA) { - return rocblas_strtri(handle, uplo, diag, n, A, lda, invA, ldinvA); -} - -template <> -rocblas_status rocblas_trtri(rocblas_handle handle, rocblas_fill uplo, - rocblas_diagonal diag, rocblas_int n, - const double *A, rocblas_int lda, double *invA, - rocblas_int ldinvA) { - return rocblas_dtrtri(handle, uplo, diag, n, A, lda, invA, ldinvA); -} - -template <> -rocblas_status rocblas_trtri(rocblas_handle handle, rocblas_fill uplo, - rocblas_diagonal diag, rocblas_int n, - const rocblas_float_complex *A, rocblas_int lda, - rocblas_float_complex *invA, rocblas_int ldinvA) { - return rocblas_ctrtri(handle, uplo, diag, n, A, lda, invA, ldinvA); -} - -template <> -rocblas_status rocblas_trtri(rocblas_handle handle, rocblas_fill uplo, - rocblas_diagonal diag, rocblas_int n, - const rocblas_double_complex *A, rocblas_int lda, - rocblas_double_complex *invA, rocblas_int ldinvA) { - return rocblas_ztrtri(handle, uplo, diag, n, A, lda, invA, ldinvA); -}