Skip to content

Commit

Permalink
Batched - QR: making test more generic for all shapes
Browse files Browse the repository at this point in the history
Signed-off-by: Luc Berger-Vergiat <[email protected]>
  • Loading branch information
lucbv committed Jan 21, 2025
1 parent 01029cb commit e2771fd
Showing 1 changed file with 30 additions and 37 deletions.
67 changes: 30 additions & 37 deletions batched/dense/unit_test/Test_Batched_SerialQR.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ void test_QR_rectangular() {
}

template <class Device, class Scalar, class AlgoTagType>
void test_QR_batch() {
void test_QR_batch(const int numMat, const int numRows, const int numCols) {
// Generate a batch of matrices
// Compute QR factorization
// Verify that R is triangular
Expand All @@ -349,40 +349,8 @@ void test_QR_batch() {

using ExecutionSpace = typename Device::execution_space;

{ // Square matrix case
constexpr int numMat = 314;
constexpr int numRows = 36;
Kokkos::View<Scalar**, ExecutionSpace> tau("tau", numMat, numRows);
Kokkos::View<Scalar**, ExecutionSpace> tmp("work buffer", numMat, numRows);
Kokkos::View<Scalar***, ExecutionSpace> As("A matrices", numMat, numRows, numRows);
Kokkos::View<Scalar***, ExecutionSpace> Bs("B matrices", numMat, numRows, numRows);
Kokkos::View<Scalar***, ExecutionSpace> Qs("Q matrices", numMat, numRows, numRows);
Kokkos::View<int*, ExecutionSpace> error("global number of error", numMat);

Kokkos::Random_XorShift64_Pool<ExecutionSpace> rand_pool(2718);
constexpr double max_val = 1000;
{
Scalar randStart, randEnd;
Test::getRandomBounds(max_val, randStart, randEnd);
Kokkos::fill_random(ExecutionSpace{}, As, rand_pool, randStart, randEnd);
}
Kokkos::deep_copy(Bs, As);

qrFunctor myFunc(As, tau, tmp, Qs, Bs, error);
Kokkos::parallel_for("KokkosBatched::test_QR_batch", Kokkos::RangePolicy<ExecutionSpace>(0, numMat), myFunc);
Kokkos::fence();

typename Kokkos::View<int*, ExecutionSpace>::HostMirror error_h = Kokkos::create_mirror_view(error);
Kokkos::deep_copy(error_h, error);
int global_error = 0;
for(int matIdx = 0; matIdx < numMat; ++matIdx) { global_error += error_h(matIdx); }
EXPECT_EQ(global_error, 0);
}

{ // Rectangular matrix case
constexpr int numMat = 25; // 314
constexpr int numRows = 42; // 42
constexpr int numCols = 36; // 36
std::cout << "batched QR, running (" << numMat << ", " << numRows << ", " << numCols << ") case" << std::endl;
{
Kokkos::View<Scalar**, ExecutionSpace> tau("tau", numMat, numCols);
Kokkos::View<Scalar**, ExecutionSpace> tmp("work buffer", numMat, numCols);
Kokkos::View<Scalar***, ExecutionSpace> As("A matrices", numMat, numRows, numCols);
Expand All @@ -401,21 +369,35 @@ void test_QR_batch() {

qrFunctor myFunc(As, tau, tmp, Qs, Bs, error);
Kokkos::parallel_for("KokkosBatched::test_QR_batch", Kokkos::RangePolicy<ExecutionSpace>(0, numMat), myFunc);
Kokkos::fence();

typename Kokkos::View<int*, ExecutionSpace>::HostMirror error_h = Kokkos::create_mirror_view(error);
Kokkos::deep_copy(error_h, error);

bool first_error = true;
int global_error = 0;
for(int matIdx = 0; matIdx < numMat; ++matIdx) {
if(0 < error_h(matIdx)) {
std::cout << "Errors found for matrix " << matIdx << ": " << error_h(matIdx) << std::endl;
if(first_error == true) {
auto mat_view = Kokkos::subview(As, matIdx, Kokkos::ALL, Kokkos::ALL);
auto mat_host = Kokkos::create_mirror_view(mat_view);

first_error = false;
}
}
global_error += error_h(matIdx);
}
EXPECT_EQ(global_error, 0);
}
}

template <class Device, class Scalar, class AlgoTagType>
void test_QR_batch(const int numMat, const int numRows) {
test_QR_batch<Device, Scalar, AlgoTagType>(numMat, numRows, numRows);
}


#if defined(KOKKOSKERNELS_INST_FLOAT)
TEST_F(TestCategory, serial_qr_square_analytic_float) {
typedef KokkosBlas::Algo::QR::Unblocked AlgoTagType;
Expand All @@ -427,7 +409,12 @@ TEST_F(TestCategory, serial_qr_rectangular_analytic_float) {
}
TEST_F(TestCategory, serial_qr_batch_float) {
typedef KokkosBlas::Algo::QR::Unblocked AlgoTagType;
test_QR_batch<TestDevice, float, AlgoTagType>();
test_QR_batch<TestDevice, float, AlgoTagType>(314, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(10, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(100, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(200, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(250, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(300, 42, 36);
}
#endif

Expand All @@ -442,7 +429,13 @@ TEST_F(TestCategory, serial_qr_rectangular_analytic_double) {
}
TEST_F(TestCategory, serial_qr_batch_double) {
typedef KokkosBlas::Algo::QR::Unblocked AlgoTagType;
test_QR_batch<TestDevice, double, AlgoTagType>();
test_QR_batch<TestDevice, double, AlgoTagType>(314, 36);

test_QR_batch<TestDevice, float, AlgoTagType>(10, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(100, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(200, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(250, 42, 36);
test_QR_batch<TestDevice, float, AlgoTagType>(300, 42, 36);
}
#endif

Expand Down

0 comments on commit e2771fd

Please sign in to comment.