Skip to content

Commit

Permalink
Extending LUMod test driver to support low-rank case
Browse files Browse the repository at this point in the history
  • Loading branch information
poulson committed Dec 15, 2016
1 parent 51fe6de commit 8a1a42c
Showing 1 changed file with 84 additions and 80 deletions.
164 changes: 84 additions & 80 deletions tests/lapack_like/LUMod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,87 +9,88 @@
#include <El.hpp>
using namespace El;

template<typename F>
template<typename Field>
void TestCorrectness
( bool print,
const Matrix<F>& A,
const Matrix<Field>& A,
const Permutation& P,
const Matrix<F>& AOrig,
const Matrix<Field>& AOrig,
Int numRHS=100 )
{
typedef Base<F> Real;
typedef Base<Field> Real;
const Int n = AOrig.Width();
const Real eps = limits::Epsilon<Real>();
const Real oneNormA = OneNorm( AOrig );

Output("Testing error...");

// Generate random right-hand sides
Matrix<F> X;
Matrix<Field> X;
Uniform( X, n, numRHS );
auto Y( X );
const Real oneNormY = OneNorm( Y );
P.PermuteRows( Y );
lu::SolveAfter( NORMAL, A, Y );

// Now investigate the residual, ||AOrig Y - X||_oo
Gemm( NORMAL, NORMAL, F(-1), AOrig, Y, F(1), X );
Gemm( NORMAL, NORMAL, Field(-1), AOrig, Y, Field(1), X );
const Real infError = InfinityNorm( X );
const Real relError = infError / (eps*n*Max(oneNormA,oneNormY));

// TODO: Use a rigorous failure condition
// TODO(poulson): Use a rigorous failure condition
Output("||A X - Y||_oo / (eps n Max(||A||_1,||Y||_1)) = ",relError);
if( relError > Real(1000) )
LogicError("Unacceptably large relative error");
}

template<typename F>
template<typename Field>
void TestCorrectness
( bool print,
const DistMatrix<F>& A,
const DistMatrix<Field>& A,
const DistPermutation& P,
const DistMatrix<F>& AOrig,
const DistMatrix<Field>& AOrig,
Int numRHS=100 )
{
typedef Base<F> Real;
const Grid& g = A.Grid();
typedef Base<Field> Real;
const Grid& grid = A.Grid();
const Int n = AOrig.Width();
const Real eps = limits::Epsilon<Real>();
const Real oneNormA = OneNorm( AOrig );

OutputFromRoot(g.Comm(),"Testing error...");
OutputFromRoot(grid.Comm(),"Testing error...");

// Generate random right-hand sides
DistMatrix<F> X(g);
DistMatrix<Field> X(grid);
Uniform( X, n, numRHS );
auto Y( X );
const Real oneNormY = OneNorm( Y );
P.PermuteRows( Y );
lu::SolveAfter( NORMAL, A, Y );

// Now investigate the residual, ||AOrig Y - X||_oo
Gemm( NORMAL, NORMAL, F(-1), AOrig, Y, F(1), X );
Gemm( NORMAL, NORMAL, Field(-1), AOrig, Y, Field(1), X );
const Real infError = InfinityNorm( X );
const Real relError = infError / (eps*n*Max(oneNormA,oneNormY));

// TODO: Use a rigorous failure condition
// TODO(poulson): Use a rigorous failure condition
OutputFromRoot
(g.Comm(),"||A X - Y||_oo / (eps n Max(||A||_1,||Y||_1)) = ",relError);
(grid.Comm(),"||A X - Y||_oo / (eps n Max(||A||_1,||Y||_1)) = ",relError);
if( relError > Real(1000) )
LogicError("Unacceptably large relative error");
}

template<typename F>
template<typename Field>
void TestLUMod
( Int m,
Int updateRank,
bool conjugate,
Base<F> tau,
Base<Field> tau,
bool correctness,
bool print )
{
Output("Testing with ",TypeName<F>());
Output("Testing with ",TypeName<Field>());
PushIndent();
Matrix<F> A, AOrig;
Matrix<Field> A, AOrig;
Permutation P;

Uniform( A, m, m );
Expand All @@ -102,58 +103,60 @@ void TestLUMod
Output("Starting LU factorization...");
Timer timer;
timer.Start();
P.ReserveSwaps( m+2*m-1 );
//P.ReserveSwaps( m+2*m-1 );
LU( A, P );
const double runTime = timer.Stop();
const double realGFlops = 2./3.*Pow(double(m),3.)/(1.e9*runTime);
const double gFlops =
( IsComplex<F>::value ? 4*realGFlops : realGFlops );
( IsComplex<Field>::value ? 4*realGFlops : realGFlops );
Output(runTime," seconds (",gFlops," GFlop/s)");
}

// TODO: Print permutation
// TODO(poulson): Print permutation

// Generate random vectors u and v
Matrix<F> u, v;
Uniform( u, m, 1 );
Uniform( v, m, 1 );
Matrix<Field> U, V;
Uniform( U, m, updateRank );
Uniform( V, m, updateRank );
if( correctness )
{
if( conjugate )
Ger( F(1), u, v, AOrig );
Gemm( NORMAL, ADJOINT, Field(1), U, V, Field(1), AOrig );
else
Geru( F(1), u, v, AOrig );
Gemm( NORMAL, TRANSPOSE, Field(1), U, V, Field(1), AOrig );
}

{
Output("Starting rank-one LU modification...");
Output("Starting low-rank LU modification...");
Timer timer;
timer.Start();
LUMod( A, P, u, v, conjugate, tau );
LUMod( A, P, U, V, conjugate, tau );
const double runTime = timer.Stop();
Output(runTime," seconds");
}

// TODO: Print permutation
// TODO(poulson): Print permutation

if( correctness )
TestCorrectness( print, A, P, AOrig );

PopIndent();
}

template<typename F>
template<typename Field>
void TestLUMod
( const Grid& g,
( const Grid& grid,
Int m,
Int updateRank,
bool conjugate,
Base<F> tau,
Base<Field> tau,
bool correctness,
bool print )
{
OutputFromRoot(g.Comm(),"Testing with ",TypeName<F>());
OutputFromRoot(grid.Comm(),"Testing with ",TypeName<Field>());
PushIndent();
DistMatrix<F> A(g), AOrig(g);
DistPermutation P(g);
DistMatrix<Field> A(grid), AOrig(grid);
DistPermutation P(grid);

Uniform( A, m, m );
if( correctness )
Expand All @@ -162,46 +165,46 @@ void TestLUMod
Print( A, "A" );

{
OutputFromRoot(g.Comm(),"Starting LU factorization...");
mpi::Barrier( g.Comm() );
OutputFromRoot(grid.Comm(),"Starting LU factorization...");
mpi::Barrier( grid.Comm() );
Timer timer;
timer.Start();
P.ReserveSwaps( m+2*m-1 );
LU( A, P );
mpi::Barrier( g.Comm() );
mpi::Barrier( grid.Comm() );
const double runTime = timer.Stop();
const double realGFlops = 2./3.*Pow(double(m),3.)/(1.e9*runTime);
const double gFlops =
( IsComplex<F>::value ? 4*realGFlops : realGFlops );
OutputFromRoot(g.Comm(),runTime," seconds (",gFlops," GFlop/s)");
( IsComplex<Field>::value ? 4*realGFlops : realGFlops );
OutputFromRoot(grid.Comm(),runTime," seconds (",gFlops," GFlop/s)");
}

// TODO: Print permutation
// TODO(poulson): Print permutation

// Generate random vectors u and v
DistMatrix<F> u(g), v(g);
Uniform( u, m, 1 );
Uniform( v, m, 1 );
DistMatrix<Field> U(grid), V(grid);
Uniform( U, m, updateRank );
Uniform( V, m, updateRank );
if( correctness )
{
if( conjugate )
Ger( F(1), u, v, AOrig );
Gemm( NORMAL, ADJOINT, Field(1), U, V, Field(1), AOrig );
else
Geru( F(1), u, v, AOrig );
Gemm( NORMAL, TRANSPOSE, Field(1), U, V, Field(1), AOrig );
}

{
OutputFromRoot(g.Comm(),"Starting rank-one LU modification...");
mpi::Barrier( g.Comm() );
OutputFromRoot(grid.Comm(),"Starting low-rank LU modification...");
mpi::Barrier( grid.Comm() );
Timer timer;
timer.Start();
LUMod( A, P, u, v, conjugate, tau );
mpi::Barrier( g.Comm() );
LUMod( A, P, U, V, conjugate, tau );
mpi::Barrier( grid.Comm() );
const double runTime = timer.Stop();
OutputFromRoot(g.Comm(),runTime," seconds");
OutputFromRoot(grid.Comm(),runTime," seconds");
}

// TODO: Print permutation
// TODO(poulson): Print permutation

if( correctness )
TestCorrectness( print, A, P, AOrig );
Expand All @@ -219,6 +222,7 @@ main( int argc, char* argv[] )
int gridHeight = Input("--gridHeight","height of process grid",0);
const bool colMajor = Input("--colMajor","column-major ordering?",true);
const Int m = Input("--height","height of matrix",100);
const Int updateRank = Input("--updateRank","rank of LU update",10);
const Int nb = Input("--nb","algorithmic blocksize",96);
const double tau = Input("--tau","pivot threshold",0.1);
const bool conjugate = Input("--conjugate","conjugate v?",true);
Expand All @@ -239,83 +243,83 @@ main( int argc, char* argv[] )
if( gridHeight == 0 )
gridHeight = Grid::FindFactor( mpi::Size(comm) );
const GridOrder order = ( colMajor ? COLUMN_MAJOR : ROW_MAJOR );
const Grid g( comm, gridHeight, order );
const Grid grid( comm, gridHeight, order );
SetBlocksize( nb );
ComplainIfDebug();

if( sequential && mpi::Rank() == 0 )
{
TestLUMod<float>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<float>>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );

TestLUMod<double>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<double>>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );

#ifdef EL_HAVE_QD
TestLUMod<DoubleDouble>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
TestLUMod<QuadDouble>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );

TestLUMod<Complex<DoubleDouble>>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<QuadDouble>>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
#endif

#ifdef EL_HAVE_QUAD
TestLUMod<Quad>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<Quad>>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
#endif

#ifdef EL_HAVE_MPC
TestLUMod<BigFloat>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<BigFloat>>
( m, conjugate, tau, correctness, print );
( m, updateRank, conjugate, tau, correctness, print );
#endif
}

TestLUMod<float>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<float>>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );

TestLUMod<double>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<double>>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );

#ifdef EL_HAVE_QD
TestLUMod<DoubleDouble>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
TestLUMod<QuadDouble>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );

TestLUMod<Complex<DoubleDouble>>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<QuadDouble>>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
#endif

#ifdef EL_HAVE_QUAD
TestLUMod<Quad>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<Quad>>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
#endif

#ifdef EL_HAVE_MPC
TestLUMod<BigFloat>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
TestLUMod<Complex<BigFloat>>
( g, m, conjugate, tau, correctness, print );
( grid, m, updateRank, conjugate, tau, correctness, print );
#endif
}
catch( exception& e ) { ReportException(e); }
Expand Down

0 comments on commit 8a1a42c

Please sign in to comment.