diff --git a/batched/dense/unit_test/Test_Batched_SerialQR.hpp b/batched/dense/unit_test/Test_Batched_SerialQR.hpp index 643c6e63fc..465a3747cb 100644 --- a/batched/dense/unit_test/Test_Batched_SerialQR.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialQR.hpp @@ -340,7 +340,7 @@ void test_QR_rectangular() { } template -void test_QR_batch() { +void test_QR_batch(const int numMat, const int numRows, const int numCols) { // Generate a batch of matrices // Compute QR factorization // Verify that R is triangular @@ -349,40 +349,8 @@ void test_QR_batch() { using ExecutionSpace = typename Device::execution_space; - { // Square matrix case - constexpr int numMat = 314; - constexpr int numRows = 36; - Kokkos::View tau("tau", numMat, numRows); - Kokkos::View tmp("work buffer", numMat, numRows); - Kokkos::View As("A matrices", numMat, numRows, numRows); - Kokkos::View Bs("B matrices", numMat, numRows, numRows); - Kokkos::View Qs("Q matrices", numMat, numRows, numRows); - Kokkos::View error("global number of error", numMat); - - Kokkos::Random_XorShift64_Pool rand_pool(2718); - constexpr double max_val = 1000; - { - Scalar randStart, randEnd; - Test::getRandomBounds(max_val, randStart, randEnd); - Kokkos::fill_random(ExecutionSpace{}, As, rand_pool, randStart, randEnd); - } - Kokkos::deep_copy(Bs, As); - - qrFunctor myFunc(As, tau, tmp, Qs, Bs, error); - Kokkos::parallel_for("KokkosBatched::test_QR_batch", Kokkos::RangePolicy(0, numMat), myFunc); - Kokkos::fence(); - - typename Kokkos::View::HostMirror error_h = Kokkos::create_mirror_view(error); - Kokkos::deep_copy(error_h, error); - int global_error = 0; - for(int matIdx = 0; matIdx < numMat; ++matIdx) { global_error += error_h(matIdx); } - EXPECT_EQ(global_error, 0); - } - - { // Rectangular matrix case - constexpr int numMat = 25; // 314 - constexpr int numRows = 42; // 42 - constexpr int numCols = 36; // 36 + std::cout << "batched QR, running (" << numMat << ", " << numRows << ", " << numCols << ") case" << std::endl; + { Kokkos::View tau("tau", numMat, numCols); Kokkos::View tmp("work buffer", numMat, numCols); Kokkos::View As("A matrices", numMat, numRows, numCols); @@ -401,14 +369,22 @@ void test_QR_batch() { qrFunctor myFunc(As, tau, tmp, Qs, Bs, error); Kokkos::parallel_for("KokkosBatched::test_QR_batch", Kokkos::RangePolicy(0, numMat), myFunc); + Kokkos::fence(); typename Kokkos::View::HostMirror error_h = Kokkos::create_mirror_view(error); Kokkos::deep_copy(error_h, error); + bool first_error = true; int global_error = 0; for(int matIdx = 0; matIdx < numMat; ++matIdx) { if(0 < error_h(matIdx)) { std::cout << "Errors found for matrix " << matIdx << ": " << error_h(matIdx) << std::endl; + if(first_error == true) { + auto mat_view = Kokkos::subview(As, matIdx, Kokkos::ALL, Kokkos::ALL); + auto mat_host = Kokkos::create_mirror_view(mat_view); + + first_error = false; + } } global_error += error_h(matIdx); } @@ -416,6 +392,12 @@ void test_QR_batch() { } } +template +void test_QR_batch(const int numMat, const int numRows) { + test_QR_batch(numMat, numRows, numRows); +} + + #if defined(KOKKOSKERNELS_INST_FLOAT) TEST_F(TestCategory, serial_qr_square_analytic_float) { typedef KokkosBlas::Algo::QR::Unblocked AlgoTagType; @@ -427,7 +409,12 @@ TEST_F(TestCategory, serial_qr_rectangular_analytic_float) { } TEST_F(TestCategory, serial_qr_batch_float) { typedef KokkosBlas::Algo::QR::Unblocked AlgoTagType; - test_QR_batch(); + test_QR_batch(314, 36); + test_QR_batch(10, 42, 36); + test_QR_batch(100, 42, 36); + test_QR_batch(200, 42, 36); + test_QR_batch(250, 42, 36); + test_QR_batch(300, 42, 36); } #endif @@ -442,7 +429,13 @@ TEST_F(TestCategory, serial_qr_rectangular_analytic_double) { } TEST_F(TestCategory, serial_qr_batch_double) { typedef KokkosBlas::Algo::QR::Unblocked AlgoTagType; - test_QR_batch(); + test_QR_batch(314, 36); + + test_QR_batch(10, 42, 36); + test_QR_batch(100, 42, 36); + test_QR_batch(200, 42, 36); + test_QR_batch(250, 42, 36); + test_QR_batch(300, 42, 36); } #endif