From be030feb91fff8d6d2b4409153fe549b81237580 Mon Sep 17 00:00:00 2001 From: NaveenElumalaiAMD <69491510+NaveenElumalaiAMD@users.noreply.github.com> Date: Wed, 11 May 2022 16:26:10 -0600 Subject: [PATCH] Add numerical checking helper to Level 3 functions (#1238) * Add numerical checking helper to Level 3 rocBLAS * Added check to see if the input is const * Enclosed the kernel function of TRSM with brackets to invoke the destructor and release the handle memory * Addressed the comments --- library/src/blas3/rocblas_dgmm.cpp | 93 ++++++++-- library/src/blas3/rocblas_dgmm.hpp | 21 +++ library/src/blas3/rocblas_dgmm_batched.cpp | 92 ++++++++-- library/src/blas3/rocblas_dgmm_kernels.cpp | 114 ++++++++++++ .../blas3/rocblas_dgmm_strided_batched.cpp | 92 ++++++++-- library/src/blas3/rocblas_geam.hpp | 1 - library/src/blas3/rocblas_hemm.cpp | 102 ++++++++--- library/src/blas3/rocblas_hemm_batched.cpp | 100 ++++++++--- .../blas3/rocblas_hemm_strided_batched.cpp | 99 ++++++++--- library/src/blas3/rocblas_her2k.cpp | 103 ++++++++--- library/src/blas3/rocblas_her2k_batched.cpp | 102 ++++++++--- .../blas3/rocblas_her2k_strided_batched.cpp | 102 ++++++++--- library/src/blas3/rocblas_herk.cpp | 88 ++++++++-- library/src/blas3/rocblas_herk_batched.cpp | 86 +++++++-- .../blas3/rocblas_herk_strided_batched.cpp | 87 +++++++-- library/src/blas3/rocblas_herkx.cpp | 103 ++++++++--- library/src/blas3/rocblas_herkx_batched.cpp | 102 ++++++++--- .../blas3/rocblas_herkx_strided_batched.cpp | 102 ++++++++--- library/src/blas3/rocblas_symm.cpp | 103 ++++++++--- library/src/blas3/rocblas_symm_batched.cpp | 103 ++++++++--- library/src/blas3/rocblas_symm_hemm.hpp | 21 +++ library/src/blas3/rocblas_symm_kernels.cpp | 129 +++++++++++++- .../blas3/rocblas_symm_strided_batched.cpp | 102 ++++++++--- library/src/blas3/rocblas_syr2k.cpp | 103 ++++++++--- library/src/blas3/rocblas_syr2k_batched.cpp | 102 ++++++++--- library/src/blas3/rocblas_syr2k_her2k.hpp | 21 +++ .../src/blas3/rocblas_syr2k_her2k_kernels.cpp | 124 +++++++++++++ .../blas3/rocblas_syr2k_strided_batched.cpp | 103 ++++++++--- library/src/blas3/rocblas_syrk.cpp | 89 ++++++++-- library/src/blas3/rocblas_syrk_batched.cpp | 88 ++++++++-- library/src/blas3/rocblas_syrk_herk.hpp | 18 ++ .../src/blas3/rocblas_syrk_herk_kernels.cpp | 97 ++++++++++ .../blas3/rocblas_syrk_strided_batched.cpp | 88 ++++++++-- library/src/blas3/rocblas_syrkx.cpp | 102 ++++++++--- library/src/blas3/rocblas_syrkx_batched.cpp | 102 ++++++++--- .../blas3/rocblas_syrkx_strided_batched.cpp | 101 ++++++++--- library/src/blas3/rocblas_trmm.cpp | 102 ++++++++--- library/src/blas3/rocblas_trmm.hpp | 19 ++ library/src/blas3/rocblas_trmm_batched.cpp | 102 ++++++++--- library/src/blas3/rocblas_trmm_kernels.cpp | 98 +++++++++++ .../blas3/rocblas_trmm_strided_batched.cpp | 102 ++++++++--- library/src/blas3/rocblas_trsm.cpp | 163 +++++++++++------ library/src/blas3/rocblas_trsm.hpp | 1 + library/src/blas3/rocblas_trsm_batched.cpp | 157 +++++++++++------ .../blas3/rocblas_trsm_strided_batched.cpp | 165 ++++++++++++------ library/src/blas3/rocblas_trtri.cpp | 86 +++++++-- library/src/blas3/rocblas_trtri.hpp | 56 ++++++ library/src/blas3/rocblas_trtri_batched.cpp | 52 +++++- .../blas3/rocblas_trtri_strided_batched.cpp | 51 +++++- 49 files changed, 3545 insertions(+), 794 deletions(-) diff --git a/library/src/blas3/rocblas_dgmm.cpp b/library/src/blas3/rocblas_dgmm.cpp index f40642594..994c4dcbc 100644 --- a/library/src/blas3/rocblas_dgmm.cpp +++ b/library/src/blas3/rocblas_dgmm.cpp @@ -55,7 +55,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench @@ -113,25 +114,79 @@ namespace static constexpr rocblas_int batch_count = 1; static constexpr rocblas_stride stride_A = 0, stride_x = 0, stride_C = 0; - return rocblas_dgmm_template(handle, - side, - m, - n, - A, - offset_A, - lda, - stride_A, - x, - offset_x, - incx, - stride_x, - C, - offset_C, - ldc, - stride_C, - batch_count); - } + if(check_numerics) + { + bool is_input = true; + rocblas_status dgmm_check_numerics_status + = rocblas_dgmm_check_numerics(rocblas_dgmm_name, + handle, + side, + m, + n, + A, + lda, + stride_A, + x, + incx, + stride_x, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + if(dgmm_check_numerics_status != rocblas_status_success) + return dgmm_check_numerics_status; + } + rocblas_status status = rocblas_status_success; + status = rocblas_dgmm_template(handle, + side, + m, + n, + A, + offset_A, + lda, + stride_A, + x, + offset_x, + incx, + stride_x, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status dgmm_check_numerics_status + = rocblas_dgmm_check_numerics(rocblas_dgmm_name, + handle, + side, + m, + n, + A, + lda, + stride_A, + x, + incx, + stride_x, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + if(dgmm_check_numerics_status != rocblas_status_success) + return dgmm_check_numerics_status; + } + return status; + } } // namespace /* diff --git a/library/src/blas3/rocblas_dgmm.hpp b/library/src/blas3/rocblas_dgmm.hpp index 48ac393b8..de28fedd3 100644 --- a/library/src/blas3/rocblas_dgmm.hpp +++ b/library/src/blas3/rocblas_dgmm.hpp @@ -21,6 +21,8 @@ * ************************************************************************ */ #pragma once +#include "check_numerics_matrix.hpp" +#include "check_numerics_vector.hpp" #include "handle.hpp" /** @@ -47,3 +49,22 @@ rocblas_status rocblas_dgmm_template(rocblas_handle handle, rocblas_int ldc, rocblas_stride stride_c, rocblas_int batch_count); + +template +rocblas_status rocblas_dgmm_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_side side, + rocblas_int m, + rocblas_int n, + TConstPtr A, + rocblas_int lda, + rocblas_stride stride_A, + TConstPtr x, + rocblas_int incx, + rocblas_stride stride_x, + TPtr C, + rocblas_int ldc, + rocblas_stride stride_c, + rocblas_int batch_count, + const int check_numerics, + bool is_input); diff --git a/library/src/blas3/rocblas_dgmm_batched.cpp b/library/src/blas3/rocblas_dgmm_batched.cpp index 147516637..655fd460e 100644 --- a/library/src/blas3/rocblas_dgmm_batched.cpp +++ b/library/src/blas3/rocblas_dgmm_batched.cpp @@ -57,7 +57,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench @@ -129,23 +130,78 @@ namespace static constexpr rocblas_stride offset_a = 0, offset_x = 0, offset_c = 0; static constexpr rocblas_stride stride_a = 0, stride_x = 0, stride_c = 0; - return rocblas_dgmm_template(handle, - side, - m, - n, - A, - offset_a, - lda, - stride_a, - x, - offset_x, - incx, - stride_x, - C, - offset_c, - ldc, - stride_c, - batch_count); + if(check_numerics) + { + bool is_input = true; + rocblas_status dgmm_check_numerics_status + = rocblas_dgmm_check_numerics(rocblas_dgmm_batched_name, + handle, + side, + m, + n, + A, + lda, + stride_a, + x, + incx, + stride_x, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + if(dgmm_check_numerics_status != rocblas_status_success) + return dgmm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_dgmm_template(handle, + side, + m, + n, + A, + offset_a, + lda, + stride_a, + x, + offset_x, + incx, + stride_x, + C, + offset_c, + ldc, + stride_c, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status dgmm_check_numerics_status + = rocblas_dgmm_check_numerics(rocblas_dgmm_batched_name, + handle, + side, + m, + n, + A, + lda, + stride_a, + x, + incx, + stride_x, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + if(dgmm_check_numerics_status != rocblas_status_success) + return dgmm_check_numerics_status; + } + return status; } } // namespace diff --git a/library/src/blas3/rocblas_dgmm_kernels.cpp b/library/src/blas3/rocblas_dgmm_kernels.cpp index 35b158bfc..fc91e8501 100644 --- a/library/src/blas3/rocblas_dgmm_kernels.cpp +++ b/library/src/blas3/rocblas_dgmm_kernels.cpp @@ -158,6 +158,81 @@ rocblas_status rocblas_dgmm_template(rocblas_handle handle, } return rocblas_status_success; } + +template +rocblas_status rocblas_dgmm_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_side side, + rocblas_int m, + rocblas_int n, + TConstPtr A, + rocblas_int lda, + rocblas_stride stride_a, + TConstPtr x, + rocblas_int incx, + rocblas_stride stride_x, + TPtr C, + rocblas_int ldc, + rocblas_stride stride_c, + rocblas_int batch_count, + const int check_numerics, + bool is_input) +{ + + rocblas_status check_numerics_status = rocblas_status_success; + if(is_input) + { + rocblas_int dim_x = (side == rocblas_side_left) ? m : n; + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + rocblas_operation_none, + rocblas_fill_full, + rocblas_client_general_matrix, + m, + n, + A, + 0, + lda, + stride_a, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + + check_numerics_status = rocblas_internal_check_numerics_vector_template(function_name, + handle, + dim_x, + x, + 0, + incx, + stride_x, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + } + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + rocblas_operation_none, + rocblas_fill_full, + rocblas_client_general_matrix, + m, + n, + C, + 0, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + return check_numerics_status; +} + // Instantiations below will need to be manually updated to match any change in // template parameters in the files dgmm*.cpp @@ -198,4 +273,43 @@ INSTANTIATE_DGMM_TEMPLATE(double const* const*, double* const*) INSTANTIATE_DGMM_TEMPLATE( rocblas_float_complex const* const*, rocblas_float_complex* const*) INSTANTIATE_DGMM_TEMPLATE(rocblas_double_complex const* const*, rocblas_double_complex* const*) #undef INSTANTIATE_DGMM_TEMPLATE + + +#ifdef INSTANTIATE_DGMM_NUMERICS +#error INSTANTIATE_DGMM_NUMERICS already defined +#endif + +#define INSTANTIATE_DGMM_NUMERICS(TConstPtr_, TPtr_) \ +template rocblas_status rocblas_dgmm_check_numerics \ + (const char* function_name, \ + rocblas_handle handle, \ + rocblas_side side, \ + rocblas_int m, \ + rocblas_int n, \ + TConstPtr_ A, \ + rocblas_int lda, \ + rocblas_stride stride_a, \ + TConstPtr_ x, \ + rocblas_int inc, \ + rocblas_stride stride_x, \ + TPtr_ C, \ + rocblas_int ldc, \ + rocblas_stride stride_c, \ + rocblas_int batch_count, \ + const int check_numerics, \ + bool is_input); + +// instantiate for rocblas_Xdgmm and rocblas_Xdgmm_strided_batched +INSTANTIATE_DGMM_NUMERICS(float const*, float*) +INSTANTIATE_DGMM_NUMERICS(double const*, double*) +INSTANTIATE_DGMM_NUMERICS(rocblas_float_complex const*, rocblas_float_complex*) +INSTANTIATE_DGMM_NUMERICS(rocblas_double_complex const*, rocblas_double_complex*) + +// instantiate for rocblas_Xdgmm_batched +INSTANTIATE_DGMM_NUMERICS(float const* const*, float* const*) +INSTANTIATE_DGMM_NUMERICS(double const* const*, double* const*) +INSTANTIATE_DGMM_NUMERICS(rocblas_float_complex const* const*, rocblas_float_complex* const*) +INSTANTIATE_DGMM_NUMERICS(rocblas_double_complex const* const*, rocblas_double_complex* const*) + +#undef INSTANTIATE_DGMM_NUMERICS // clang-format on diff --git a/library/src/blas3/rocblas_dgmm_strided_batched.cpp b/library/src/blas3/rocblas_dgmm_strided_batched.cpp index 9f763d976..0fce6a862 100644 --- a/library/src/blas3/rocblas_dgmm_strided_batched.cpp +++ b/library/src/blas3/rocblas_dgmm_strided_batched.cpp @@ -61,7 +61,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench @@ -149,25 +150,78 @@ namespace static constexpr rocblas_stride offset_a = 0, offset_x = 0, offset_c = 0; - return rocblas_dgmm_template(handle, - side, - m, - n, - A, - offset_a, - lda, - stride_a, - x, - offset_x, - incx, - stride_x, - C, - offset_c, - ldc, - stride_c, - batch_count); - } + if(check_numerics) + { + bool is_input = true; + rocblas_status dgmm_check_numerics_status + = rocblas_dgmm_check_numerics(rocblas_dgmm_strided_batched_name, + handle, + side, + m, + n, + A, + lda, + stride_a, + x, + incx, + stride_x, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + if(dgmm_check_numerics_status != rocblas_status_success) + return dgmm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_dgmm_template(handle, + side, + m, + n, + A, + offset_a, + lda, + stride_a, + x, + offset_x, + incx, + stride_x, + C, + offset_c, + ldc, + stride_c, + batch_count); + if(status != rocblas_status_success) + return status; + if(check_numerics) + { + bool is_input = false; + rocblas_status dgmm_check_numerics_status + = rocblas_dgmm_check_numerics(rocblas_dgmm_strided_batched_name, + handle, + side, + m, + n, + A, + lda, + stride_a, + x, + incx, + stride_x, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + if(dgmm_check_numerics_status != rocblas_status_success) + return dgmm_check_numerics_status; + } + return status; + } } // namespace /* diff --git a/library/src/blas3/rocblas_geam.hpp b/library/src/blas3/rocblas_geam.hpp index 292526a7a..b8e533fed 100644 --- a/library/src/blas3/rocblas_geam.hpp +++ b/library/src/blas3/rocblas_geam.hpp @@ -22,7 +22,6 @@ #pragma once #include "check_numerics_matrix.hpp" -#include "check_numerics_vector.hpp" #include "handle.hpp" /** diff --git a/library/src/blas3/rocblas_hemm.cpp b/library/src/blas3/rocblas_hemm.cpp index 5b45a24f2..8f358e792 100644 --- a/library/src/blas3/rocblas_hemm.cpp +++ b/library/src/blas3/rocblas_hemm.cpp @@ -52,7 +52,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -145,26 +146,85 @@ namespace return arg_status; static constexpr bool Hermetian = true; - return rocblas_internal_symm_template(handle, - side, - uplo, - m, - n, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + + if(check_numerics) + { + bool is_input = true; + rocblas_status hemm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_hemm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(hemm_check_numerics_status != rocblas_status_success) + return hemm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_symm_template(handle, + side, + uplo, + m, + n, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status hemm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_hemm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(hemm_check_numerics_status != rocblas_status_success) + return hemm_check_numerics_status; + } + return status; } } /* diff --git a/library/src/blas3/rocblas_hemm_batched.cpp b/library/src/blas3/rocblas_hemm_batched.cpp index c0f3d1b02..c213fd652 100644 --- a/library/src/blas3/rocblas_hemm_batched.cpp +++ b/library/src/blas3/rocblas_hemm_batched.cpp @@ -53,7 +53,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -150,26 +151,83 @@ namespace return arg_status; static constexpr bool Hermetian = true; - return rocblas_internal_symm_template(handle, - side, - uplo, - m, - n, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + if(check_numerics) + { + bool is_input = true; + rocblas_status hemm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_hemm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + if(hemm_check_numerics_status != rocblas_status_success) + return hemm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_symm_template(handle, + side, + uplo, + m, + n, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status hemm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_hemm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + if(hemm_check_numerics_status != rocblas_status_success) + return hemm_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_hemm_strided_batched.cpp b/library/src/blas3/rocblas_hemm_strided_batched.cpp index a5815fea5..ebc124f96 100644 --- a/library/src/blas3/rocblas_hemm_strided_batched.cpp +++ b/library/src/blas3/rocblas_hemm_strided_batched.cpp @@ -56,7 +56,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -167,26 +168,82 @@ namespace return arg_status; static constexpr bool Hermetian = true; - return rocblas_internal_symm_template(handle, - side, - uplo, - m, - n, - alpha, - A, - offset_A, - lda, - stride_a, - B, - offset_B, - ldb, - stride_b, - beta, - C, - offset_C, - ldc, - stride_c, - batch_count); + if(check_numerics) + { + bool is_input = true; + rocblas_status hemm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_hemm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + if(hemm_check_numerics_status != rocblas_status_success) + return hemm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_symm_template(handle, + side, + uplo, + m, + n, + alpha, + A, + offset_A, + lda, + stride_a, + B, + offset_B, + ldb, + stride_b, + beta, + C, + offset_C, + ldc, + stride_c, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status hemm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_hemm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + if(hemm_check_numerics_status != rocblas_status_success) + return hemm_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_her2k.cpp b/library/src/blas3/rocblas_her2k.cpp index 3660a1aba..9515d8187 100644 --- a/library/src/blas3/rocblas_her2k.cpp +++ b/library/src/blas3/rocblas_her2k.cpp @@ -52,7 +52,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -143,28 +144,88 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status her2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_her2k_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(her2k_check_numerics_status != rocblas_status_success) + return her2k_check_numerics_status; + } + static constexpr bool is2K = true; static constexpr bool BATCHED = false; - return rocblas_internal_her2k_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_her2k_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status her2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_her2k_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(her2k_check_numerics_status != rocblas_status_success) + return her2k_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_her2k_batched.cpp b/library/src/blas3/rocblas_her2k_batched.cpp index 6bf58800a..1703212df 100644 --- a/library/src/blas3/rocblas_her2k_batched.cpp +++ b/library/src/blas3/rocblas_her2k_batched.cpp @@ -53,7 +53,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -148,28 +149,87 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status her2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_her2k_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(her2k_check_numerics_status != rocblas_status_success) + return her2k_check_numerics_status; + } + static constexpr bool is2K = true; static constexpr bool BATCHED = true; - return rocblas_internal_her2k_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_her2k_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status her2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_her2k_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(her2k_check_numerics_status != rocblas_status_success) + return her2k_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_her2k_strided_batched.cpp b/library/src/blas3/rocblas_her2k_strided_batched.cpp index 37bd6c57b..100b573d1 100644 --- a/library/src/blas3/rocblas_her2k_strided_batched.cpp +++ b/library/src/blas3/rocblas_her2k_strided_batched.cpp @@ -56,7 +56,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -165,28 +166,87 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status her2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_her2k_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(her2k_check_numerics_status != rocblas_status_success) + return her2k_check_numerics_status; + } + static constexpr bool is2K = true; static constexpr bool BATCHED = false; - return rocblas_internal_her2k_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_A, - lda, - stride_a, - B, - offset_B, - ldb, - stride_b, - beta, - C, - offset_C, - ldc, - stride_c, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_her2k_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_A, + lda, + stride_a, + B, + offset_B, + ldb, + stride_b, + beta, + C, + offset_C, + ldc, + stride_c, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status her2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_her2k_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(her2k_check_numerics_status != rocblas_status_success) + return her2k_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_herk.cpp b/library/src/blas3/rocblas_herk.cpp index a3b298db9..f4adf3f2b 100644 --- a/library/src/blas3/rocblas_herk.cpp +++ b/library/src/blas3/rocblas_herk.cpp @@ -50,7 +50,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -131,22 +132,75 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_herk_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status herk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_herk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herk_check_numerics_status != rocblas_status_success) + return herk_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_herk_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status herk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_herk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herk_check_numerics_status != rocblas_status_success) + return herk_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_herk_batched.cpp b/library/src/blas3/rocblas_herk_batched.cpp index 5a83316ce..ed65e78ca 100644 --- a/library/src/blas3/rocblas_herk_batched.cpp +++ b/library/src/blas3/rocblas_herk_batched.cpp @@ -51,7 +51,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -136,22 +137,73 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_herk_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status herk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_herk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herk_check_numerics_status != rocblas_status_success) + return herk_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_herk_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(check_numerics) + { + bool is_input = false; + rocblas_status herk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_herk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herk_check_numerics_status != rocblas_status_success) + return herk_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_herk_strided_batched.cpp b/library/src/blas3/rocblas_herk_strided_batched.cpp index 4fcd522ab..0f31c3b82 100644 --- a/library/src/blas3/rocblas_herk_strided_batched.cpp +++ b/library/src/blas3/rocblas_herk_strided_batched.cpp @@ -53,7 +53,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -147,24 +148,74 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_herk_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_a, - beta, - C, - offset_C, - ldc, - stride_c, - batch_count); - } + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status herk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_herk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_a, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(herk_check_numerics_status != rocblas_status_success) + return herk_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_herk_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_a, + beta, + C, + offset_C, + ldc, + stride_c, + batch_count); + if(check_numerics) + { + bool is_input = false; + rocblas_status herk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_herk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_a, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(herk_check_numerics_status != rocblas_status_success) + return herk_check_numerics_status; + } + return status; + } } /* * =========================================================================== diff --git a/library/src/blas3/rocblas_herkx.cpp b/library/src/blas3/rocblas_herkx.cpp index cc54e8fec..fd6390d0e 100644 --- a/library/src/blas3/rocblas_herkx.cpp +++ b/library/src/blas3/rocblas_herkx.cpp @@ -52,7 +52,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -143,28 +144,88 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status herkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_herkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herkx_check_numerics_status != rocblas_status_success) + return herkx_check_numerics_status; + } + static constexpr bool is2K = false; // herkx static constexpr bool BATCHED = false; - return rocblas_internal_her2k_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_her2k_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status herkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_herkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herkx_check_numerics_status != rocblas_status_success) + return herkx_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_herkx_batched.cpp b/library/src/blas3/rocblas_herkx_batched.cpp index 6749daeed..392de4ab9 100644 --- a/library/src/blas3/rocblas_herkx_batched.cpp +++ b/library/src/blas3/rocblas_herkx_batched.cpp @@ -53,7 +53,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -149,28 +150,87 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status herkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_herkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herkx_check_numerics_status != rocblas_status_success) + return herkx_check_numerics_status; + } + static constexpr bool is2K = false; // herkx static constexpr bool BATCHED = true; - return rocblas_internal_her2k_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_her2k_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status herkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_herkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(herkx_check_numerics_status != rocblas_status_success) + return herkx_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_herkx_strided_batched.cpp b/library/src/blas3/rocblas_herkx_strided_batched.cpp index 31d214e00..a70eadb9c 100644 --- a/library/src/blas3/rocblas_herkx_strided_batched.cpp +++ b/library/src/blas3/rocblas_herkx_strided_batched.cpp @@ -56,7 +56,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -166,28 +167,87 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = true; + if(check_numerics) + { + bool is_input = true; + rocblas_status herkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_herkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(herkx_check_numerics_status != rocblas_status_success) + return herkx_check_numerics_status; + } + static constexpr bool is2K = false; // herkx static constexpr bool BATCHED = false; - return rocblas_internal_her2k_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_A, - lda, - stride_a, - B, - offset_B, - ldb, - stride_b, - beta, - C, - offset_C, - ldc, - stride_c, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_her2k_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_A, + lda, + stride_a, + B, + offset_B, + ldb, + stride_b, + beta, + C, + offset_C, + ldc, + stride_c, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status herkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_herkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(herkx_check_numerics_status != rocblas_status_success) + return herkx_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_symm.cpp b/library/src/blas3/rocblas_symm.cpp index a85ecc514..8b9ae0eba 100644 --- a/library/src/blas3/rocblas_symm.cpp +++ b/library/src/blas3/rocblas_symm.cpp @@ -56,7 +56,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -147,26 +148,86 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_symm_template(handle, - side, - uplo, - m, - n, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status symm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_symm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(symm_check_numerics_status != rocblas_status_success) + return symm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_symm_template(handle, + side, + uplo, + m, + n, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status symm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_symm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(symm_check_numerics_status != rocblas_status_success) + return symm_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_symm_batched.cpp b/library/src/blas3/rocblas_symm_batched.cpp index 9f36a7160..337e500d2 100644 --- a/library/src/blas3/rocblas_symm_batched.cpp +++ b/library/src/blas3/rocblas_symm_batched.cpp @@ -57,7 +57,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -152,26 +153,86 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_symm_template(handle, - side, - uplo, - m, - n, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status symm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_symm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(symm_check_numerics_status != rocblas_status_success) + return symm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_symm_template(handle, + side, + uplo, + m, + n, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status symm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_symm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(symm_check_numerics_status != rocblas_status_success) + return symm_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_symm_hemm.hpp b/library/src/blas3/rocblas_symm_hemm.hpp index dabf16780..564dccf3c 100644 --- a/library/src/blas3/rocblas_symm_hemm.hpp +++ b/library/src/blas3/rocblas_symm_hemm.hpp @@ -22,6 +22,7 @@ #pragma once +#include "check_numerics_matrix.hpp" #include "handle.hpp" template @@ -88,3 +89,23 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status rocblas_int ldc, rocblas_stride strideC, rocblas_int batch_count); + +template +rocblas_status rocblas_hemm_symm_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_side side, + rocblas_fill uplo, + rocblas_int m, + rocblas_int n, + TConstPtr A, + rocblas_int lda, + rocblas_stride strideA, + TConstPtr B, + rocblas_int ldb, + rocblas_stride strideB, + TPtr C, + rocblas_int ldc, + rocblas_stride strideC, + rocblas_int batch_count, + const int check_numerics, + bool is_input); diff --git a/library/src/blas3/rocblas_symm_kernels.cpp b/library/src/blas3/rocblas_symm_kernels.cpp index 4ce51bec5..6b78cdbb6 100644 --- a/library/src/blas3/rocblas_symm_kernels.cpp +++ b/library/src/blas3/rocblas_symm_kernels.cpp @@ -413,6 +413,88 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status return rocblas_status_success; } +template +rocblas_status rocblas_hemm_symm_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_side side, + rocblas_fill uplo, + rocblas_int m, + rocblas_int n, + TConstPtr A, + rocblas_int lda, + rocblas_stride stride_a, + TConstPtr B, + rocblas_int ldb, + rocblas_stride stride_b, + TPtr C, + rocblas_int ldc, + rocblas_stride stride_c, + rocblas_int batch_count, + const int check_numerics, + bool is_input) +{ + rocblas_status check_numerics_status = rocblas_status_success; + if(is_input) + { + rocblas_int rows = (side == rocblas_side_left ? m : n); + rocblas_int cols = (side == rocblas_side_left ? m : n); + + check_numerics_status = rocblas_internal_check_numerics_matrix_template( + function_name, + handle, + rocblas_operation_none, + uplo, + HERM ? rocblas_client_hermitian_matrix : rocblas_client_symmetric_matrix, + rows, + cols, + A, + 0, + lda, + stride_a, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + rocblas_operation_none, + rocblas_fill_full, + rocblas_client_general_matrix, + m, + n, + B, + 0, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + } + + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + rocblas_operation_none, + rocblas_fill_full, + rocblas_client_general_matrix, + m, + n, + C, + 0, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + return check_numerics_status; +} + // Instantiations below will need to be manually updated to match any change in // template parameters in the files symm*.cpp @@ -461,5 +543,50 @@ INSTANTIATE_SYMM_TEMPLATE( true, rocblas_float_complex const*, rocblas_float_co INSTANTIATE_SYMM_TEMPLATE(false, rocblas_double_complex const*, rocblas_double_complex const* const*, rocblas_double_complex* const*) INSTANTIATE_SYMM_TEMPLATE( true, rocblas_double_complex const*, rocblas_double_complex const* const*, rocblas_double_complex* const*) -#undef INSTANTIATE_SYMM_TEMPLATE + +#undef INSTANTIATE_HEMM_SYMM_NUMERICS + +#ifdef INSTANTIATE_HEMM_SYMM_NUMERICS +#error INSTANTIATE_HEMM_SYMM_NUMERICS already defined +#endif + +#define INSTANTIATE_HEMM_SYMM_NUMERICS(HERM_, TConstPtr_, TPtr_) \ +template rocblas_status rocblas_hemm_symm_check_numerics \ + \ + (const char* function_name, \ + rocblas_handle handle, \ + rocblas_side side, \ + rocblas_fill uplo, \ + rocblas_int m, \ + rocblas_int n, \ + TConstPtr_ A, \ + rocblas_int lda, \ + rocblas_stride strideA, \ + TConstPtr_ B, \ + rocblas_int ldb, \ + rocblas_stride strideB, \ + TPtr_ C, \ + rocblas_int ldc, \ + rocblas_stride strideC, \ + rocblas_int batch_count, \ + const int check_numerics, \ + bool is_input); + +// instantiate for rocblas_Xhemm_Xsymm and rocblas_Xhemm_Xsymm_strided_batched +INSTANTIATE_HEMM_SYMM_NUMERICS(false, float const*, float*) +INSTANTIATE_HEMM_SYMM_NUMERICS(false, double const*, double*) +INSTANTIATE_HEMM_SYMM_NUMERICS(false, rocblas_float_complex const*, rocblas_float_complex*) +INSTANTIATE_HEMM_SYMM_NUMERICS( true, rocblas_float_complex const*, rocblas_float_complex*) +INSTANTIATE_HEMM_SYMM_NUMERICS(false, rocblas_double_complex const*, rocblas_double_complex*) +INSTANTIATE_HEMM_SYMM_NUMERICS( true, rocblas_double_complex const*, rocblas_double_complex*) + +// instantiate for rocblas_Xhemm_Xsymm_batched +INSTANTIATE_HEMM_SYMM_NUMERICS(false, float const* const*, float* const*) +INSTANTIATE_HEMM_SYMM_NUMERICS(false, double const* const*, double* const*) +INSTANTIATE_HEMM_SYMM_NUMERICS(false, rocblas_float_complex const* const*, rocblas_float_complex* const*) +INSTANTIATE_HEMM_SYMM_NUMERICS( true, rocblas_float_complex const* const*, rocblas_float_complex* const*) +INSTANTIATE_HEMM_SYMM_NUMERICS(false, rocblas_double_complex const* const*, rocblas_double_complex* const*) +INSTANTIATE_HEMM_SYMM_NUMERICS( true, rocblas_double_complex const* const*, rocblas_double_complex* const*) + +#undef INSTANTIATE_HEMM_SYMM_NUMERICS // clang-format on diff --git a/library/src/blas3/rocblas_symm_strided_batched.cpp b/library/src/blas3/rocblas_symm_strided_batched.cpp index 3cd4aa8a5..1fe6d72af 100644 --- a/library/src/blas3/rocblas_symm_strided_batched.cpp +++ b/library/src/blas3/rocblas_symm_strided_batched.cpp @@ -60,7 +60,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -169,26 +170,85 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_symm_template(handle, - side, - uplo, - m, - n, - alpha, - A, - offset_A, - lda, - stride_a, - B, - offset_B, - ldb, - stride_b, - beta, - C, - offset_C, - ldc, - stride_c, - batch_count); + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status symm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_symm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(symm_check_numerics_status != rocblas_status_success) + return symm_check_numerics_status; + } + rocblas_status status = rocblas_status_success; + status = rocblas_internal_symm_template(handle, + side, + uplo, + m, + n, + alpha, + A, + offset_A, + lda, + stride_a, + B, + offset_B, + ldb, + stride_b, + beta, + C, + offset_C, + ldc, + stride_c, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status symm_check_numerics_status + = rocblas_hemm_symm_check_numerics(rocblas_symm_name, + handle, + side, + uplo, + m, + n, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(symm_check_numerics_status != rocblas_status_success) + return symm_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_syr2k.cpp b/library/src/blas3/rocblas_syr2k.cpp index 0626d7ce3..aac297d0c 100644 --- a/library/src/blas3/rocblas_syr2k.cpp +++ b/library/src/blas3/rocblas_syr2k.cpp @@ -56,7 +56,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -147,28 +148,88 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syr2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syr2k_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syr2k_check_numerics_status != rocblas_status_success) + return syr2k_check_numerics_status; + } + static constexpr bool is2K = true; static constexpr bool BATCHED = false; - return rocblas_internal_syr2k_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syr2k_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syr2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syr2k_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syr2k_check_numerics_status != rocblas_status_success) + return syr2k_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_syr2k_batched.cpp b/library/src/blas3/rocblas_syr2k_batched.cpp index 432868f10..5a683ec19 100644 --- a/library/src/blas3/rocblas_syr2k_batched.cpp +++ b/library/src/blas3/rocblas_syr2k_batched.cpp @@ -57,7 +57,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -152,28 +153,87 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syr2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syr2k_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syr2k_check_numerics_status != rocblas_status_success) + return syr2k_check_numerics_status; + } + static constexpr bool is2K = true; static constexpr bool BATCHED = true; - return rocblas_internal_syr2k_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - B, - offset_B, - ldb, - stride_B, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syr2k_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + B, + offset_B, + ldb, + stride_B, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syr2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syr2k_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + B, + ldb, + stride_B, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syr2k_check_numerics_status != rocblas_status_success) + return syr2k_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_syr2k_her2k.hpp b/library/src/blas3/rocblas_syr2k_her2k.hpp index 8ced22f99..8728d3bd5 100644 --- a/library/src/blas3/rocblas_syr2k_her2k.hpp +++ b/library/src/blas3/rocblas_syr2k_her2k.hpp @@ -22,6 +22,7 @@ #pragma once +#include "check_numerics_matrix.hpp" #include "handle.hpp" #include "herk_scale_device.hpp" @@ -206,3 +207,23 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status rocblas_int ldc, rocblas_stride strideC, rocblas_int batch_count); + +template +rocblas_status rocblas_her2k_syr2k_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_fill uplo, + rocblas_operation trans, + rocblas_int n, + rocblas_int k, + TConstPtr A, + rocblas_int lda, + rocblas_stride strideA, + TConstPtr B, + rocblas_int ldb, + rocblas_stride strideB, + TPtr C, + rocblas_int ldc, + rocblas_stride strideC, + rocblas_int batch_count, + const int check_numerics, + bool is_input); diff --git a/library/src/blas3/rocblas_syr2k_her2k_kernels.cpp b/library/src/blas3/rocblas_syr2k_her2k_kernels.cpp index 1fc798fff..e7722d7b3 100644 --- a/library/src/blas3/rocblas_syr2k_her2k_kernels.cpp +++ b/library/src/blas3/rocblas_syr2k_her2k_kernels.cpp @@ -689,6 +689,86 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status return rocblas_status_success; } +template +rocblas_status rocblas_her2k_syr2k_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_fill uplo, + rocblas_operation trans, + rocblas_int n, + rocblas_int k, + TConstPtr A, + rocblas_int lda, + rocblas_stride strideA, + TConstPtr B, + rocblas_int ldb, + rocblas_stride strideB, + TPtr C, + rocblas_int ldc, + rocblas_stride strideC, + rocblas_int batch_count, + const int check_numerics, + bool is_input) +{ + rocblas_status check_numerics_status = rocblas_status_success; + + if(is_input) + { + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + trans, + rocblas_fill_full, + rocblas_client_general_matrix, + n, + k, + A, + 0, + lda, + strideA, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + trans, + rocblas_fill_full, + rocblas_client_general_matrix, + n, + k, + B, + 0, + ldb, + strideB, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + } + + check_numerics_status = rocblas_internal_check_numerics_matrix_template( + function_name, + handle, + rocblas_operation_none, + uplo, + HERM ? rocblas_client_hermitian_matrix : rocblas_client_symmetric_matrix, + n, + n, + C, + 0, + ldc, + strideC, + batch_count, + check_numerics, + is_input); + + return check_numerics_status; +} + // Instantiations below will need to be manually updated to match any change in // template parameters in the files syr2k*.cpp or her2k*.cpp @@ -780,4 +860,48 @@ INSTANTIATE_HER2K_TEMPLATE( true, false, rocblas_double_complex const*, rocblas_ #undef INSTANTIATE_HER2K_TEMPLATE + +#ifdef INSTANTIATE_HER2K_SYR2K_NUMERICS +#error INSTANTIATE_HER2K_SYR2K_NUMERICS already defined +#endif + +#define INSTANTIATE_HER2K_SYR2K_NUMERICS(HERM_, TConstPtr_, TPtr_) \ +template rocblas_status rocblas_her2k_syr2k_check_numerics \ + \ + (const char* function_name, \ + rocblas_handle handle, \ + rocblas_fill uplo, \ + rocblas_operation trans, \ + rocblas_int n, \ + rocblas_int k, \ + TConstPtr_ A, \ + rocblas_int lda, \ + rocblas_stride strideA, \ + TConstPtr_ B, \ + rocblas_int ldb, \ + rocblas_stride strideB, \ + TPtr_ C, \ + rocblas_int ldc, \ + rocblas_stride strideC, \ + rocblas_int batch_count, \ + const int check_numerics, \ + bool is_input); + +// instantiate for rocblas_Xher2k_Xsyr2k and rocblas_Xher2k_Xsyr2k_strided_batched +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, float const*, float*) +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, double const*, double*) +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, rocblas_float_complex const*, rocblas_float_complex*) +INSTANTIATE_HER2K_SYR2K_NUMERICS( true, rocblas_float_complex const*, rocblas_float_complex*) +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, rocblas_double_complex const*, rocblas_double_complex*) +INSTANTIATE_HER2K_SYR2K_NUMERICS( true, rocblas_double_complex const*, rocblas_double_complex*) + +// instantiate for rocblas_Xher2k_Xsyr2k_batched +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, float const* const*, float* const*) +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, double const* const*, double* const*) +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, rocblas_float_complex const* const*, rocblas_float_complex* const*) +INSTANTIATE_HER2K_SYR2K_NUMERICS( true, rocblas_float_complex const* const*, rocblas_float_complex* const*) +INSTANTIATE_HER2K_SYR2K_NUMERICS(false, rocblas_double_complex const* const*, rocblas_double_complex* const*) +INSTANTIATE_HER2K_SYR2K_NUMERICS( true, rocblas_double_complex const* const*, rocblas_double_complex* const*) + +#undef INSTANTIATE_HER2K_SYR2K_NUMERICS // clang-format on diff --git a/library/src/blas3/rocblas_syr2k_strided_batched.cpp b/library/src/blas3/rocblas_syr2k_strided_batched.cpp index b360717c3..c7f27b393 100644 --- a/library/src/blas3/rocblas_syr2k_strided_batched.cpp +++ b/library/src/blas3/rocblas_syr2k_strided_batched.cpp @@ -60,7 +60,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -169,28 +170,88 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syr2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syr2k_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syr2k_check_numerics_status != rocblas_status_success) + return syr2k_check_numerics_status; + } + static constexpr bool is2K = true; static constexpr bool BATCHED = false; - return rocblas_internal_syr2k_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_a, - B, - offset_B, - ldb, - stride_b, - beta, - C, - offset_C, - ldc, - stride_c, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syr2k_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_a, + B, + offset_B, + ldb, + stride_b, + beta, + C, + offset_C, + ldc, + stride_c, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syr2k_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syr2k_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syr2k_check_numerics_status != rocblas_status_success) + return syr2k_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_syrk.cpp b/library/src/blas3/rocblas_syrk.cpp index 405c466c0..14ad9cf7c 100644 --- a/library/src/blas3/rocblas_syrk.cpp +++ b/library/src/blas3/rocblas_syrk.cpp @@ -54,7 +54,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -135,22 +136,76 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_syrk_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syrk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_syrk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syrk_check_numerics_status != rocblas_status_success) + return syrk_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syrk_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syrk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_syrk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syrk_check_numerics_status != rocblas_status_success) + return syrk_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_syrk_batched.cpp b/library/src/blas3/rocblas_syrk_batched.cpp index 46e5cffaf..d97e4aa0b 100644 --- a/library/src/blas3/rocblas_syrk_batched.cpp +++ b/library/src/blas3/rocblas_syrk_batched.cpp @@ -55,7 +55,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -140,22 +141,75 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_syrk_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_A, - beta, - C, - offset_C, - ldc, - stride_C, - batch_count); + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syrk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_syrk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syrk_check_numerics_status != rocblas_status_success) + return syrk_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syrk_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_A, + beta, + C, + offset_C, + ldc, + stride_C, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syrk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_syrk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_A, + C, + ldc, + stride_C, + batch_count, + check_numerics, + is_input); + + if(syrk_check_numerics_status != rocblas_status_success) + return syrk_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_syrk_herk.hpp b/library/src/blas3/rocblas_syrk_herk.hpp index 6206860e0..e1cbbb476 100644 --- a/library/src/blas3/rocblas_syrk_herk.hpp +++ b/library/src/blas3/rocblas_syrk_herk.hpp @@ -22,6 +22,7 @@ #pragma once +#include "check_numerics_matrix.hpp" #include "handle.hpp" template @@ -150,3 +151,20 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status rocblas_int ldc, rocblas_stride strideC, rocblas_int batch_count); + +template +rocblas_status rocblas_herk_syrk_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_fill uplo, + rocblas_operation trans, + rocblas_int n, + rocblas_int k, + TConstPtr A, + rocblas_int lda, + rocblas_stride strideA, + TPtr C, + rocblas_int ldc, + rocblas_stride strideC, + rocblas_int batch_count, + const int check_numerics, + bool is_input); diff --git a/library/src/blas3/rocblas_syrk_herk_kernels.cpp b/library/src/blas3/rocblas_syrk_herk_kernels.cpp index 0c49167eb..1abe27ac2 100644 --- a/library/src/blas3/rocblas_syrk_herk_kernels.cpp +++ b/library/src/blas3/rocblas_syrk_herk_kernels.cpp @@ -550,6 +550,63 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status return rocblas_status_success; } +template +rocblas_status rocblas_herk_syrk_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_fill uplo, + rocblas_operation trans, + rocblas_int n, + rocblas_int k, + TConstPtr A, + rocblas_int lda, + rocblas_stride stride_a, + TPtr C, + rocblas_int ldc, + rocblas_stride stride_c, + rocblas_int batch_count, + const int check_numerics, + bool is_input) +{ + rocblas_status check_numerics_status = rocblas_status_success; + if(is_input) + { + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + trans, + rocblas_fill_full, + rocblas_client_general_matrix, + n, + k, + A, + 0, + lda, + stride_a, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + } + + check_numerics_status = rocblas_internal_check_numerics_matrix_template( + function_name, + handle, + rocblas_operation_none, + uplo, + HERM ? rocblas_client_hermitian_matrix : rocblas_client_symmetric_matrix, + n, + n, + C, + 0, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + return check_numerics_status; +} // Instantiations below will need to be manually updated to match any change in // template parameters in the files syrk*.cpp or herk*.cpp @@ -621,4 +678,44 @@ INSTANTIATE_HERK_TEMPLATE( double const*, rocblas_double_complex const* const*, #undef INSTANTIATE_HERK_TEMPLATE +#ifdef INSTANTIATE_HERK_SYRK_NUMERICS +#error INSTANTIATE_HERK_SYRK_NUMERICS already defined +#endif + +#define INSTANTIATE_HERK_SYRK_NUMERICS(HERM_, TConstPtr_, TPtr_) \ +template rocblas_status rocblas_herk_syrk_check_numerics \ + \ + (const char* function_name, \ + rocblas_handle handle, \ + rocblas_fill uplo, \ + rocblas_operation trans, \ + rocblas_int n, \ + rocblas_int k, \ + TConstPtr_ A, \ + rocblas_int lda, \ + rocblas_stride strideA, \ + TPtr_ C, \ + rocblas_int ldc, \ + rocblas_stride strideC, \ + rocblas_int batch_count, \ + const int check_numerics, \ + bool is_input); + +// instantiate for rocblas_Xherk_Xsyrk and rocblas_Xherk_Xsyrk_strided_batched +INSTANTIATE_HERK_SYRK_NUMERICS(false, float const*, float*) +INSTANTIATE_HERK_SYRK_NUMERICS(false, double const*, double*) +INSTANTIATE_HERK_SYRK_NUMERICS(false, rocblas_float_complex const*, rocblas_float_complex*) +INSTANTIATE_HERK_SYRK_NUMERICS( true, rocblas_float_complex const*, rocblas_float_complex*) +INSTANTIATE_HERK_SYRK_NUMERICS(false, rocblas_double_complex const*, rocblas_double_complex*) +INSTANTIATE_HERK_SYRK_NUMERICS( true, rocblas_double_complex const*, rocblas_double_complex*) + +// instantiate for rocblas_Xherk_Xsyrk_batched +INSTANTIATE_HERK_SYRK_NUMERICS(false, float const* const*, float* const*) +INSTANTIATE_HERK_SYRK_NUMERICS(false, double const* const*, double* const*) +INSTANTIATE_HERK_SYRK_NUMERICS(false, rocblas_float_complex const* const*, rocblas_float_complex* const*) +INSTANTIATE_HERK_SYRK_NUMERICS( true, rocblas_float_complex const* const*, rocblas_float_complex* const*) +INSTANTIATE_HERK_SYRK_NUMERICS(false, rocblas_double_complex const* const*, rocblas_double_complex* const*) +INSTANTIATE_HERK_SYRK_NUMERICS( true, rocblas_double_complex const* const*, rocblas_double_complex* const*) + +#undef INSTANTIATE_HERK_SYRK_NUMERICS // clang-format on diff --git a/library/src/blas3/rocblas_syrk_strided_batched.cpp b/library/src/blas3/rocblas_syrk_strided_batched.cpp index 19bff5300..5b69966da 100644 --- a/library/src/blas3/rocblas_syrk_strided_batched.cpp +++ b/library/src/blas3/rocblas_syrk_strided_batched.cpp @@ -57,7 +57,8 @@ namespace RETURN_ZERO_DEVICE_MEMORY_SIZE_IF_QUERIED(handle); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -151,22 +152,75 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - return rocblas_internal_syrk_template(handle, - uplo, - transA, - n, - k, - alpha, - A, - offset_A, - lda, - stride_a, - beta, - C, - offset_C, - ldc, - stride_c, - batch_count); + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syrk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_syrk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_a, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrk_check_numerics_status != rocblas_status_success) + return syrk_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syrk_template(handle, + uplo, + transA, + n, + k, + alpha, + A, + offset_A, + lda, + stride_a, + beta, + C, + offset_C, + ldc, + stride_c, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syrk_check_numerics_status + = rocblas_herk_syrk_check_numerics(rocblas_syrk_name, + handle, + uplo, + transA, + n, + k, + A, + lda, + stride_a, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrk_check_numerics_status != rocblas_status_success) + return syrk_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_syrkx.cpp b/library/src/blas3/rocblas_syrkx.cpp index c60ecfddb..f94074271 100644 --- a/library/src/blas3/rocblas_syrkx.cpp +++ b/library/src/blas3/rocblas_syrkx.cpp @@ -68,7 +68,8 @@ namespace copy_alpha_beta_to_host_if_on_device(handle, alpha, beta, alpha_h, beta_h, k)); auto saved_pointer_mode = handle->push_pointer_mode(rocblas_pointer_mode_host); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -160,28 +161,87 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syrkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syrkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrkx_check_numerics_status != rocblas_status_success) + return syrkx_check_numerics_status; + } static constexpr bool BATCHED = false; - return rocblas_internal_syrkx_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_a, - lda, - stride_a, - B, - offset_b, - ldb, - stride_b, - beta, - C, - offset_c, - ldc, - stride_c, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syrkx_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_a, + lda, + stride_a, + B, + offset_b, + ldb, + stride_b, + beta, + C, + offset_c, + ldc, + stride_c, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syrkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syrkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrkx_check_numerics_status != rocblas_status_success) + return syrkx_check_numerics_status; + } + return status; } } /* diff --git a/library/src/blas3/rocblas_syrkx_batched.cpp b/library/src/blas3/rocblas_syrkx_batched.cpp index 318383388..5895d3653 100644 --- a/library/src/blas3/rocblas_syrkx_batched.cpp +++ b/library/src/blas3/rocblas_syrkx_batched.cpp @@ -69,7 +69,8 @@ namespace copy_alpha_beta_to_host_if_on_device(handle, alpha, beta, alpha_h, beta_h, k)); auto saved_pointer_mode = handle->push_pointer_mode(rocblas_pointer_mode_host); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -165,28 +166,87 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syrkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syrkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrkx_check_numerics_status != rocblas_status_success) + return syrkx_check_numerics_status; + } + static constexpr bool BATCHED = true; - return rocblas_internal_syrkx_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_a, - lda, - stride_a, - B, - offset_b, - ldb, - stride_b, - beta, - C, - offset_c, - ldc, - stride_c, - batch_count); + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syrkx_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_a, + lda, + stride_a, + B, + offset_b, + ldb, + stride_b, + beta, + C, + offset_c, + ldc, + stride_c, + batch_count); + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status syrkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syrkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrkx_check_numerics_status != rocblas_status_success) + return syrkx_check_numerics_status; + } + return status; } } /* diff --git a/library/src/blas3/rocblas_syrkx_strided_batched.cpp b/library/src/blas3/rocblas_syrkx_strided_batched.cpp index 553e4b0f5..0d4f9d1f4 100644 --- a/library/src/blas3/rocblas_syrkx_strided_batched.cpp +++ b/library/src/blas3/rocblas_syrkx_strided_batched.cpp @@ -72,7 +72,8 @@ namespace copy_alpha_beta_to_host_if_on_device(handle, alpha, beta, alpha_h, beta_h, k)); auto saved_pointer_mode = handle->push_pointer_mode(rocblas_pointer_mode_host); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile)) @@ -182,28 +183,86 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; + static constexpr bool Hermetian = false; + if(check_numerics) + { + bool is_input = true; + rocblas_status syrkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syrkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrkx_check_numerics_status != rocblas_status_success) + return syrkx_check_numerics_status; + } + static constexpr bool BATCHED = false; + rocblas_status status = rocblas_status_success; + status = rocblas_internal_syrkx_template(handle, + uplo, + trans, + n, + k, + alpha, + A, + offset_a, + lda, + stride_a, + B, + offset_b, + ldb, + stride_b, + beta, + C, + offset_c, + ldc, + stride_c, + batch_count); + if(status != rocblas_status_success) + return status; - return rocblas_internal_syrkx_template(handle, - uplo, - trans, - n, - k, - alpha, - A, - offset_a, - lda, - stride_a, - B, - offset_b, - ldb, - stride_b, - beta, - C, - offset_c, - ldc, - stride_c, - batch_count); + if(check_numerics) + { + bool is_input = false; + rocblas_status syrkx_check_numerics_status + = rocblas_her2k_syr2k_check_numerics(rocblas_syrkx_name, + handle, + uplo, + trans, + n, + k, + A, + lda, + stride_a, + B, + ldb, + stride_b, + C, + ldc, + stride_c, + batch_count, + check_numerics, + is_input); + + if(syrkx_check_numerics_status != rocblas_status_success) + return syrkx_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_trmm.cpp b/library/src/blas3/rocblas_trmm.cpp index d82e3370d..4f83fab53 100644 --- a/library/src/blas3/rocblas_trmm.cpp +++ b/library/src/blas3/rocblas_trmm.cpp @@ -90,7 +90,9 @@ namespace copy_alpha_beta_to_host_if_on_device(handle, alpha, beta, alpha_h, beta_h, m && n)); auto saved_pointer_mode = handle->push_pointer_mode(rocblas_pointer_mode_host); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; + if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile) @@ -192,30 +194,84 @@ namespace if(rocblas_pointer_mode_host == handle->pointer_mode && !a) return rocblas_status_invalid_pointer; + if(check_numerics) + { + bool is_input = true; + rocblas_status trmm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trmm_name, + handle, + side, + uplo, + transa, + m, + n, + a, + lda, + stride_a, + b, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + if(trmm_check_numerics_status != rocblas_status_success) + return trmm_check_numerics_status; + } + constexpr bool BATCHED = false; - return rocblas_internal_trmm_template(handle, - side, - uplo, - transa, - diag, - m, - n, - alpha, - stride_alpha, - a, - offset_a, - lda, - stride_a, - (const T*)b, - offset_b, - ldb, - stride_b, - b, - offset_b, - ldb, - stride_b, - batch_count); + rocblas_status status = rocblas_status_success; + + status = rocblas_internal_trmm_template(handle, + side, + uplo, + transa, + diag, + m, + n, + alpha, + stride_alpha, + a, + offset_a, + lda, + stride_a, + (const T*)b, + offset_b, + ldb, + stride_b, + b, + offset_b, + ldb, + stride_b, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status trmm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trmm_name, + handle, + side, + uplo, + transa, + m, + n, + a, + lda, + stride_a, + b, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + if(trmm_check_numerics_status != rocblas_status_success) + return trmm_check_numerics_status; + } + return status; } } // namespace diff --git a/library/src/blas3/rocblas_trmm.hpp b/library/src/blas3/rocblas_trmm.hpp index dedcff32f..1401a5523 100644 --- a/library/src/blas3/rocblas_trmm.hpp +++ b/library/src/blas3/rocblas_trmm.hpp @@ -23,6 +23,7 @@ #pragma once #include "Tensile/gemm.hpp" +#include "check_numerics_matrix.hpp" #include "definitions.hpp" template @@ -216,3 +217,21 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status rocblas_int lddc, rocblas_stride stride_c, rocblas_int batch_count); + +template +rocblas_status rocblas_trmm_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_side side, + rocblas_fill uplo, + rocblas_operation trans_a, + rocblas_int m, + rocblas_int n, + TConstPtr* A, + rocblas_int lda, + rocblas_stride stride_a, + TPtr* B, + rocblas_int ldb, + rocblas_stride stride_b, + rocblas_int batch_count, + const int check_numerics, + bool is_input); diff --git a/library/src/blas3/rocblas_trmm_batched.cpp b/library/src/blas3/rocblas_trmm_batched.cpp index 084ac9a71..719e6abb2 100644 --- a/library/src/blas3/rocblas_trmm_batched.cpp +++ b/library/src/blas3/rocblas_trmm_batched.cpp @@ -71,7 +71,9 @@ namespace copy_alpha_beta_to_host_if_on_device(handle, alpha, beta, alpha_h, beta_h, m && n)); auto saved_pointer_mode = handle->push_pointer_mode(rocblas_pointer_mode_host); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; + if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile) @@ -175,30 +177,84 @@ namespace if(rocblas_pointer_mode_host == handle->pointer_mode && !a) return rocblas_status_invalid_pointer; + if(check_numerics) + { + bool is_input = true; + rocblas_status trmm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trmm_batched_name, + handle, + side, + uplo, + transa, + m, + n, + a, + lda, + stride_a, + b, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + if(trmm_check_numerics_status != rocblas_status_success) + return trmm_check_numerics_status; + } + constexpr bool BATCHED = true; - return rocblas_internal_trmm_template(handle, - side, - uplo, - transa, - diag, - m, - n, - alpha, - stride_alpha, - a, - offset_a, - lda, - stride_a, - (const T* const*)b, - offset_b, - ldb, - stride_b, - b, - offset_b, - ldb, - stride_b, - batch_count); + rocblas_status status = rocblas_status_success; + + status = rocblas_internal_trmm_template(handle, + side, + uplo, + transa, + diag, + m, + n, + alpha, + stride_alpha, + a, + offset_a, + lda, + stride_a, + (const T* const*)b, + offset_b, + ldb, + stride_b, + b, + offset_b, + ldb, + stride_b, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status trmm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trmm_batched_name, + handle, + side, + uplo, + transa, + m, + n, + a, + lda, + stride_a, + b, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + if(trmm_check_numerics_status != rocblas_status_success) + return trmm_check_numerics_status; + } + return status; } } // namespace diff --git a/library/src/blas3/rocblas_trmm_kernels.cpp b/library/src/blas3/rocblas_trmm_kernels.cpp index fe78390a7..2f901974f 100644 --- a/library/src/blas3/rocblas_trmm_kernels.cpp +++ b/library/src/blas3/rocblas_trmm_kernels.cpp @@ -1556,6 +1556,66 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status return rocblas_status_success; } +template +rocblas_status rocblas_trmm_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_side side, + rocblas_fill uplo, + rocblas_operation trans_a, + rocblas_int m, + rocblas_int n, + TConstPtr* A, + rocblas_int lda, + rocblas_stride stride_a, + TPtr* B, + rocblas_int ldb, + rocblas_stride stride_b, + rocblas_int batch_count, + const int check_numerics, + bool is_input) +{ + rocblas_status check_numerics_status = rocblas_status_success; + if(is_input) + { + rocblas_int rows = (side == rocblas_side_left ? m : n); + rocblas_int cols = (side == rocblas_side_left ? m : n); + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + trans_a, + uplo, + rocblas_client_triangular_matrix, + rows, + cols, + A, + 0, + lda, + stride_a, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + } + + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + rocblas_operation_none, + rocblas_fill_full, + rocblas_client_general_matrix, + m, + n, + B, + 0, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + return check_numerics_status; +} + // Instantiations below will need to be manually updated to match any change in // template parameters in the files trmm*.cpp @@ -1639,4 +1699,42 @@ INSTANTIATE_SET_MATRIX_ZERO_TEMPLATE(rocblas_double_complex const*, rocblas_doub #undef INSTANTIATE_SET_MATRIX_ZERO_TEMPLATE + +#ifdef INSTANTIATE_TRMM_NUMERICS +#error INSTANTIATE_TRMM_NUMERICS already defined +#endif + +#define INSTANTIATE_TRMM_NUMERICS(TConstPtr_, TPtr_) \ +template rocblas_status rocblas_trmm_check_numerics \ + \ + (const char* function_name, \ + rocblas_handle handle, \ + rocblas_side side, \ + rocblas_fill uplo, \ + rocblas_operation trans_a, \ + rocblas_int m, \ + rocblas_int n, \ + TConstPtr_* dA, \ + rocblas_int lda, \ + rocblas_stride stride_a, \ + TPtr_* dB, \ + rocblas_int ldb, \ + rocblas_stride stride_b, \ + rocblas_int batch_count, \ + const int check_numerics, \ + bool is_input); + +// instantiate for rocblas_Xtrmm and rocblas_Xtrmm_strided_batched +INSTANTIATE_TRMM_NUMERICS(float const, float) +INSTANTIATE_TRMM_NUMERICS(double const, double) +INSTANTIATE_TRMM_NUMERICS(rocblas_float_complex const, rocblas_float_complex) +INSTANTIATE_TRMM_NUMERICS(rocblas_double_complex const, rocblas_double_complex) + +// instantiate for rocblas_Xtrmm_batched +INSTANTIATE_TRMM_NUMERICS(float const* const, float* const) +INSTANTIATE_TRMM_NUMERICS(double const* const, double* const) +INSTANTIATE_TRMM_NUMERICS(rocblas_float_complex const* const, rocblas_float_complex* const) +INSTANTIATE_TRMM_NUMERICS(rocblas_double_complex const* const, rocblas_double_complex* const) + +#undef INSTANTIATE_TRMM_NUMERICS // clang-format on diff --git a/library/src/blas3/rocblas_trmm_strided_batched.cpp b/library/src/blas3/rocblas_trmm_strided_batched.cpp index c76ad7efa..998312da3 100644 --- a/library/src/blas3/rocblas_trmm_strided_batched.cpp +++ b/library/src/blas3/rocblas_trmm_strided_batched.cpp @@ -75,7 +75,9 @@ namespace copy_alpha_beta_to_host_if_on_device(handle, alpha, beta, alpha_h, beta_h, m && n)); auto saved_pointer_mode = handle->push_pointer_mode(rocblas_pointer_mode_host); - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; + if(layer_mode & (rocblas_layer_mode_log_trace | rocblas_layer_mode_log_bench | rocblas_layer_mode_log_profile) @@ -190,28 +192,82 @@ namespace constexpr bool BATCHED = false; - return rocblas_internal_trmm_template(handle, - side, - uplo, - transa, - diag, - m, - n, - alpha, - stride_alpha, - a, - offset_a, - lda, - stride_a, - (const T*)b, - offset_b, - ldb, - stride_b, - b, - offset_b, - ldb, - stride_b, - batch_count); + if(check_numerics) + { + bool is_input = true; + rocblas_status trmm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trmm_strided_batched_name, + handle, + side, + uplo, + transa, + m, + n, + a, + lda, + stride_a, + b, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + if(trmm_check_numerics_status != rocblas_status_success) + return trmm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + + status = rocblas_internal_trmm_template(handle, + side, + uplo, + transa, + diag, + m, + n, + alpha, + stride_alpha, + a, + offset_a, + lda, + stride_a, + (const T*)b, + offset_b, + ldb, + stride_b, + b, + offset_b, + ldb, + stride_b, + batch_count); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status trmm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trmm_strided_batched_name, + handle, + side, + uplo, + transa, + m, + n, + a, + lda, + stride_a, + b, + ldb, + stride_b, + batch_count, + check_numerics, + is_input); + if(trmm_check_numerics_status != rocblas_status_success) + return trmm_check_numerics_status; + } + return status; } } // namespace diff --git a/library/src/blas3/rocblas_trsm.cpp b/library/src/blas3/rocblas_trsm.cpp index 49322a93d..980525a83 100644 --- a/library/src/blas3/rocblas_trsm.cpp +++ b/library/src/blas3/rocblas_trsm.cpp @@ -72,6 +72,7 @@ namespace if(!handle) return rocblas_status_invalid_handle; + auto check_numerics = handle->check_numerics; ///////////// // LOGGING // ///////////// @@ -158,66 +159,118 @@ namespace return rocblas_status_success; } + if(check_numerics) + { + bool is_input = true; + rocblas_status trsm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trsm_name, + handle, + side, + uplo, + transA, + m, + n, + A, + lda, + 0, + B, + ldb, + 0, + 1, + check_numerics, + is_input); + if(trsm_check_numerics_status != rocblas_status_success) + return trsm_check_numerics_status; + } + ////////////////////// // MEMORY MANAGEMENT// ////////////////////// + rocblas_status status = rocblas_status_success; + //kernel function is enclosed inside the brackets so that the handle device memory used by the kernel is released after the computation. + { + // Proxy object holds the allocation. It must stay alive as long as mem_* pointers below are alive. + auto w_mem = handle->device_malloc(0); + void* w_mem_x_temp; + void* w_mem_x_temp_arr; + void* w_mem_invA; + void* w_mem_invA_arr; + rocblas_status perf_status + = rocblas_internal_trsm_template_mem(handle, + side, + transA, + m, + n, + 1, + w_mem, + w_mem_x_temp, + w_mem_x_temp_arr, + w_mem_invA, + w_mem_invA_arr, + supplied_invA, + supplied_invA_size); + + // If this was a device memory query or an error occurred, return status + if(perf_status != rocblas_status_success && perf_status != rocblas_status_perf_degraded) + return perf_status; + + bool optimal_mem = perf_status == rocblas_status_success; + + status = rocblas_internal_trsm_template(handle, + side, + uplo, + transA, + diag, + m, + n, + alpha, + A, + 0, + lda, + 0, + B, + 0, + ldb, + 0, + 1, + optimal_mem, + w_mem_x_temp, + w_mem_x_temp_arr, + w_mem_invA, + w_mem_invA_arr, + supplied_invA, + supplied_invA_size); + + status = (status != rocblas_status_success) ? status : perf_status; + if(status != rocblas_status_success) + return status; + } - // Proxy object holds the allocation. It must stay alive as long as mem_* pointers below are alive. - auto w_mem = handle->device_malloc(0); - void* w_mem_x_temp; - void* w_mem_x_temp_arr; - void* w_mem_invA; - void* w_mem_invA_arr; - rocblas_status perf_status - = rocblas_internal_trsm_template_mem(handle, - side, - transA, - m, - n, - 1, - w_mem, - w_mem_x_temp, - w_mem_x_temp_arr, - w_mem_invA, - w_mem_invA_arr, - supplied_invA, - supplied_invA_size); - - // If this was a device memory query or an error occurred, return status - if(perf_status != rocblas_status_success && perf_status != rocblas_status_perf_degraded) - return perf_status; - - bool optimal_mem = perf_status == rocblas_status_success; - - rocblas_status status - = rocblas_internal_trsm_template(handle, - side, - uplo, - transA, - diag, - m, - n, - alpha, - A, - 0, - lda, - 0, - B, - 0, - ldb, - 0, - 1, - optimal_mem, - w_mem_x_temp, - w_mem_x_temp_arr, - w_mem_invA, - w_mem_invA_arr, - supplied_invA, - supplied_invA_size); - - return status != rocblas_status_success ? status : perf_status; + if(check_numerics) + { + bool is_input = false; + rocblas_status trsm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trsm_name, + handle, + side, + uplo, + transA, + m, + n, + A, + lda, + 0, + B, + ldb, + 0, + 1, + check_numerics, + is_input); + if(trsm_check_numerics_status != rocblas_status_success) + return trsm_check_numerics_status; + } + return status; } - } /* diff --git a/library/src/blas3/rocblas_trsm.hpp b/library/src/blas3/rocblas_trsm.hpp index 2b5089853..8f4887792 100644 --- a/library/src/blas3/rocblas_trsm.hpp +++ b/library/src/blas3/rocblas_trsm.hpp @@ -24,6 +24,7 @@ #include "../blas2/rocblas_trsv.hpp" #include "../blas_ex/rocblas_gemm_ex.hpp" +#include "rocblas_trmm.hpp" #include "trtri_trsm.hpp" template diff --git a/library/src/blas3/rocblas_trsm_batched.cpp b/library/src/blas3/rocblas_trsm_batched.cpp index a93fd137b..15e911a6c 100644 --- a/library/src/blas3/rocblas_trsm_batched.cpp +++ b/library/src/blas3/rocblas_trsm_batched.cpp @@ -69,6 +69,7 @@ namespace if(!handle) return rocblas_status_invalid_handle; + auto check_numerics = handle->check_numerics; ///////////// // LOGGING // ///////////// @@ -164,66 +165,118 @@ namespace return rocblas_status_success; } + if(check_numerics) + { + bool is_input = true; + rocblas_status trsm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trsm_name, + handle, + side, + uplo, + transA, + m, + n, + A, + lda, + 0, + B, + ldb, + 0, + batch_count, + check_numerics, + is_input); + if(trsm_check_numerics_status != rocblas_status_success) + return trsm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; ////////////////////// // MEMORY MANAGEMENT// ////////////////////// + //kernel function is enclosed inside the brackets so that the handle device memory used by the kernel is released after the computation. + { + // Proxy object holds the allocation. It must stay alive as long as mem_* pointers below are alive. + auto w_mem = handle->device_malloc(0); + void* w_mem_x_temp; + void* w_mem_x_temp_arr; + void* w_mem_invA; + void* w_mem_invA_arr; - // Proxy object holds the allocation. It must stay alive as long as mem_* pointers below are alive. - auto w_mem = handle->device_malloc(0); - void* w_mem_x_temp; - void* w_mem_x_temp_arr; - void* w_mem_invA; - void* w_mem_invA_arr; - - rocblas_status perf_status - = rocblas_internal_trsm_template_mem(handle, - side, - transA, - m, - n, - batch_count, - w_mem, - w_mem_x_temp, - w_mem_x_temp_arr, - w_mem_invA, - w_mem_invA_arr, - supplied_invA, - supplied_invA_size); + rocblas_status perf_status + = rocblas_internal_trsm_template_mem(handle, + side, + transA, + m, + n, + batch_count, + w_mem, + w_mem_x_temp, + w_mem_x_temp_arr, + w_mem_invA, + w_mem_invA_arr, + supplied_invA, + supplied_invA_size); - if(perf_status != rocblas_status_success && perf_status != rocblas_status_perf_degraded) - return perf_status; + if(perf_status != rocblas_status_success && perf_status != rocblas_status_perf_degraded) + return perf_status; - bool optimal_mem = perf_status == rocblas_status_success; + bool optimal_mem = perf_status == rocblas_status_success; - rocblas_status status - = rocblas_internal_trsm_template(handle, - side, - uplo, - transA, - diag, - m, - n, - alpha, - A, - 0, - lda, - 0, - B, - 0, - ldb, - 0, - batch_count, - optimal_mem, - w_mem_x_temp, - w_mem_x_temp_arr, - w_mem_invA, - w_mem_invA_arr, - supplied_invA, - supplied_invA_size, - 0, - 0); + status = rocblas_internal_trsm_template(handle, + side, + uplo, + transA, + diag, + m, + n, + alpha, + A, + 0, + lda, + 0, + B, + 0, + ldb, + 0, + batch_count, + optimal_mem, + w_mem_x_temp, + w_mem_x_temp_arr, + w_mem_invA, + w_mem_invA_arr, + supplied_invA, + supplied_invA_size, + 0, + 0); + status = (status != rocblas_status_success) ? status : perf_status; + if(status != rocblas_status_success) + return status; + } - return status != rocblas_status_success ? status : perf_status; + if(check_numerics) + { + bool is_input = false; + rocblas_status trsm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trsm_name, + handle, + side, + uplo, + transA, + m, + n, + A, + lda, + 0, + B, + ldb, + 0, + batch_count, + check_numerics, + is_input); + if(trsm_check_numerics_status != rocblas_status_success) + return trsm_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_trsm_strided_batched.cpp b/library/src/blas3/rocblas_trsm_strided_batched.cpp index 9e7b02fa9..8f72de6d1 100644 --- a/library/src/blas3/rocblas_trsm_strided_batched.cpp +++ b/library/src/blas3/rocblas_trsm_strided_batched.cpp @@ -76,6 +76,7 @@ namespace if(!handle) return rocblas_status_invalid_handle; + auto check_numerics = handle->check_numerics; ///////////// // LOGGING // ///////////// @@ -181,66 +182,118 @@ namespace return rocblas_status_success; } + if(check_numerics) + { + bool is_input = true; + rocblas_status trsm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trsm_name, + handle, + side, + uplo, + transA, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + batch_count, + check_numerics, + is_input); + if(trsm_check_numerics_status != rocblas_status_success) + return trsm_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; ////////////////////// // MEMORY MANAGEMENT// ////////////////////// + //kernel function is enclosed inside the brackets so that the handle device memory used by the kernel is released after the computation. + { + // Proxy object holds the allocation. It must stay alive as long as mem_* pointers below are alive. + auto w_mem = handle->device_malloc(0); + void* w_mem_x_temp; + void* w_mem_x_temp_arr; + void* w_mem_invA; + void* w_mem_invA_arr; + + rocblas_status perf_status + = rocblas_internal_trsm_template_mem(handle, + side, + transA, + m, + n, + batch_count, + w_mem, + w_mem_x_temp, + w_mem_x_temp_arr, + w_mem_invA, + w_mem_invA_arr, + supplied_invA, + supplied_invA_size); + + if(perf_status != rocblas_status_success && perf_status != rocblas_status_perf_degraded) + return perf_status; + + bool optimal_mem = perf_status == rocblas_status_success; + + status = rocblas_internal_trsm_template(handle, + side, + uplo, + transA, + diag, + m, + n, + alpha, + (const T*)A, + 0, + lda, + stride_A, + (T*)B, + 0, + ldb, + stride_B, + batch_count, + optimal_mem, + w_mem_x_temp, + w_mem_x_temp_arr, + w_mem_invA, + w_mem_invA_arr, + (const T*)supplied_invA, + supplied_invA_size, + 0, + stride_invA); + status = (status != rocblas_status_success) ? status : perf_status; + if(status != rocblas_status_success) + return status; + } - // Proxy object holds the allocation. It must stay alive as long as mem_* pointers below are alive. - auto w_mem = handle->device_malloc(0); - void* w_mem_x_temp; - void* w_mem_x_temp_arr; - void* w_mem_invA; - void* w_mem_invA_arr; - - rocblas_status perf_status - = rocblas_internal_trsm_template_mem(handle, - side, - transA, - m, - n, - batch_count, - w_mem, - w_mem_x_temp, - w_mem_x_temp_arr, - w_mem_invA, - w_mem_invA_arr, - supplied_invA, - supplied_invA_size); - - if(perf_status != rocblas_status_success && perf_status != rocblas_status_perf_degraded) - return perf_status; - - bool optimal_mem = perf_status == rocblas_status_success; - - rocblas_status status - = rocblas_internal_trsm_template(handle, - side, - uplo, - transA, - diag, - m, - n, - alpha, - (const T*)A, - 0, - lda, - stride_A, - (T*)B, - 0, - ldb, - stride_B, - batch_count, - optimal_mem, - w_mem_x_temp, - w_mem_x_temp_arr, - w_mem_invA, - w_mem_invA_arr, - (const T*)supplied_invA, - supplied_invA_size, - 0, - stride_invA); - - return status != rocblas_status_success ? status : perf_status; + if(check_numerics) + { + bool is_input = false; + rocblas_status trsm_check_numerics_status + = rocblas_trmm_check_numerics(rocblas_trsm_name, + handle, + side, + uplo, + transA, + m, + n, + A, + lda, + stride_A, + B, + ldb, + stride_B, + batch_count, + check_numerics, + is_input); + if(trsm_check_numerics_status != rocblas_status_success) + return trsm_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_trtri.cpp b/library/src/blas3/rocblas_trtri.cpp index ae5008036..70a8d4d84 100644 --- a/library/src/blas3/rocblas_trtri.cpp +++ b/library/src/blas3/rocblas_trtri.cpp @@ -58,7 +58,9 @@ namespace return handle->set_optimal_device_memory_size(size); } - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; + if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, rocblas_trtri_name, uplo, diag, n, A, lda, invA, ldinvA); @@ -85,23 +87,71 @@ namespace if(!w_mem) return rocblas_status_memory_error; - return rocblas_internal_trtri_template(handle, - uplo, - diag, - n, - A, - 0, - lda, - lda * n, - 0, - invA, - 0, - ldinvA, - ldinvA * n, - 0, - 1, - 1, - (T*)w_mem); + if(check_numerics) + { + bool is_input = true; + rocblas_status trtri_check_numerics_status + = rocblas_trtri_check_numerics(rocblas_trtri_name, + handle, + uplo, + n, + A, + lda, + 0, + invA, + ldinvA, + 0, + 1, + check_numerics, + is_input); + if(trtri_check_numerics_status != rocblas_status_success) + return trtri_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; + + status = rocblas_internal_trtri_template(handle, + uplo, + diag, + n, + A, + 0, + lda, + lda * n, + 0, + invA, + 0, + ldinvA, + ldinvA * n, + 0, + 1, + 1, + (T*)w_mem); + + if(status != rocblas_status_success) + return status; + + if(check_numerics) + { + bool is_input = false; + rocblas_status trtri_check_numerics_status + = rocblas_trtri_check_numerics(rocblas_trtri_name, + handle, + uplo, + n, + A, + lda, + 0, + invA, + ldinvA, + 0, + 1, + check_numerics, + is_input); + if(trtri_check_numerics_status != rocblas_status_success) + return trtri_check_numerics_status; + } + return status; } } diff --git a/library/src/blas3/rocblas_trtri.hpp b/library/src/blas3/rocblas_trtri.hpp index fdbee857c..e828fa773 100644 --- a/library/src/blas3/rocblas_trtri.hpp +++ b/library/src/blas3/rocblas_trtri.hpp @@ -22,6 +22,7 @@ #pragma once +#include "check_numerics_matrix.hpp" #include "gemm.hpp" template @@ -1158,3 +1159,58 @@ ROCBLAS_INTERNAL_EXPORT_NOINLINE rocblas_status w_C_tmp); } } + +template +rocblas_status rocblas_trtri_check_numerics(const char* function_name, + rocblas_handle handle, + rocblas_fill uplo, + rocblas_int n, + TConstPtr* A, + rocblas_int lda, + rocblas_stride stride_a, + TPtr* invA, + rocblas_int ldinvA, + rocblas_stride stride_invA, + rocblas_int batch_count, + const int check_numerics, + bool is_input) +{ + rocblas_status check_numerics_status = rocblas_status_success; + if(is_input) + { + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + rocblas_operation_none, + uplo, + rocblas_client_triangular_matrix, + n, + n, + A, + 0, + lda, + stride_a, + batch_count, + check_numerics, + is_input); + if(check_numerics_status != rocblas_status_success) + return check_numerics_status; + } + + check_numerics_status + = rocblas_internal_check_numerics_matrix_template(function_name, + handle, + rocblas_operation_none, + rocblas_fill_full, + rocblas_client_general_matrix, + n, + n, + invA, + 0, + ldinvA, + stride_invA, + batch_count, + check_numerics, + is_input); + return check_numerics_status; +} diff --git a/library/src/blas3/rocblas_trtri_batched.cpp b/library/src/blas3/rocblas_trtri_batched.cpp index f7670055c..00c351e85 100644 --- a/library/src/blas3/rocblas_trtri_batched.cpp +++ b/library/src/blas3/rocblas_trtri_batched.cpp @@ -63,7 +63,9 @@ namespace return handle->set_optimal_device_memory_size(size, sizep); } - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; + if(layer_mode & rocblas_layer_mode_log_trace) log_trace( handle, rocblas_trtri_name, uplo, diag, n, A, lda, invA, ldinvA, batch_count); @@ -89,11 +91,34 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - rocblas_status status; + if(check_numerics) + { + bool is_input = true; + rocblas_status trtri_check_numerics_status + = rocblas_trtri_check_numerics(rocblas_trtri_name, + handle, + uplo, + n, + A, + lda, + 0, + invA, + ldinvA, + 0, + batch_count, + check_numerics, + is_input); + if(trtri_check_numerics_status != rocblas_status_success) + return trtri_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; if(n <= NB) { status = rocblas_trtri_small( handle, uplo, diag, n, A, 0, lda, 0, 0, invA, 0, ldinvA, 0, 0, batch_count, 1); + if(status != rocblas_status_success) + return status; } else { @@ -135,8 +160,29 @@ namespace batch_count, 1, (T* const*)w_C_tmp_arr); + if(status != rocblas_status_success) + return status; + } + if(check_numerics) + { + bool is_input = false; + rocblas_status trtri_check_numerics_status + = rocblas_trtri_check_numerics(rocblas_trtri_name, + handle, + uplo, + n, + A, + lda, + 0, + invA, + ldinvA, + 0, + batch_count, + check_numerics, + is_input); + if(trtri_check_numerics_status != rocblas_status_success) + return trtri_check_numerics_status; } - return status; } diff --git a/library/src/blas3/rocblas_trtri_strided_batched.cpp b/library/src/blas3/rocblas_trtri_strided_batched.cpp index 67bde3a57..15df21aab 100644 --- a/library/src/blas3/rocblas_trtri_strided_batched.cpp +++ b/library/src/blas3/rocblas_trtri_strided_batched.cpp @@ -63,7 +63,9 @@ namespace return handle->set_optimal_device_memory_size(size); } - auto layer_mode = handle->layer_mode; + auto layer_mode = handle->layer_mode; + auto check_numerics = handle->check_numerics; + if(layer_mode & rocblas_layer_mode_log_trace) log_trace(handle, rocblas_trtri_name, @@ -103,7 +105,28 @@ namespace if(arg_status != rocblas_status_continue) return arg_status; - rocblas_status status; + if(check_numerics) + { + bool is_input = true; + rocblas_status trtri_check_numerics_status + = rocblas_trtri_check_numerics(rocblas_trtri_name, + handle, + uplo, + n, + A, + lda, + bsa, + invA, + ldinvA, + bsinvA, + batch_count, + check_numerics, + is_input); + if(trtri_check_numerics_status != rocblas_status_success) + return trtri_check_numerics_status; + } + + rocblas_status status = rocblas_status_success; if(n <= NB) { status = rocblas_trtri_small(handle, @@ -122,6 +145,8 @@ namespace 0, batch_count, 1); + if(status != rocblas_status_success) + return status; } else { @@ -147,8 +172,30 @@ namespace batch_count, 1, (T*)w_C_tmp); + if(status != rocblas_status_success) + return status; } + if(check_numerics) + { + bool is_input = false; + rocblas_status trtri_check_numerics_status + = rocblas_trtri_check_numerics(rocblas_trtri_name, + handle, + uplo, + n, + A, + lda, + bsa, + invA, + ldinvA, + bsinvA, + batch_count, + check_numerics, + is_input); + if(trtri_check_numerics_status != rocblas_status_success) + return trtri_check_numerics_status; + } return status; }