#include "MPITridiagLinearOp.hpp"
#include "sillyCgSolve.hpp"
#include "Thyra_VectorStdOps.hpp"
#include "Thyra_TestingTools.hpp"
#include "Thyra_LinearOpTester.hpp"
#include "Teuchos_CommandLineProcessor.hpp"
#include "Teuchos_Time.hpp"
#include "Teuchos_oblackholestream.hpp"
/*
 * Runs one instance of the silly CG example for the scalar type Scalar:
 *   (A) builds the distributed tridiagonal operator A = tridiag(-1, 2*diagScale, -1),
 *   (B) checks A with Thyra::LinearOpTester (including a symmetry check),
 *   (C) solves A*x = b with sillyCgSolve() for a random b (x starts at zero),
 *   (D) verifies the relative residual ||b-A*x||/||b|| <= 10*tolerance.
 *
 * mpiComm, procRank, numProc : communicator plus this process's rank and size.
 * localDim    : number of rows owned by this process. Must be >= 2 (main()
 *               enforces this); smaller values would leave coefficient arrays
 *               empty, and &v[0] on an empty vector is invalid.
 * diagScale   : diagonal entries are 2*diagScale; off-diagonal entries are -1.
 * verbose     : print progress; only rank 0 actually writes to std::cout.
 * dumpAll     : forwarded to LinearOpTester::dump_all().
 * tolerance   : error tolerance for the operator tests and the CG solve.
 * maxNumIters : maximum number of CG iterations passed to sillyCgSolve().
 *
 * Returns true only if the operator tests, the CG solve, and the residual
 * check all pass.
 */
template<class Scalar>
bool runCgSolveExample(
  MPI_Comm mpiComm
  ,const int procRank
  ,const int numProc
  ,const int localDim
  ,const Scalar diagScale
  ,const bool verbose
  ,const bool dumpAll
  ,const typename Teuchos::ScalarTraits<Scalar>::magnitudeType tolerance
  ,const int maxNumIters
  )
{
  using Teuchos::RefCountPtr;
  using Teuchos::rcp;
  typedef Teuchos::ScalarTraits<Scalar> ST;
  typedef typename ST::magnitudeType ScalarMag;
  bool success = true;
  bool result;

  // Only rank 0 prints; every other rank writes into a black-hole stream.
  Teuchos::oblackholestream black_hole_out;
  std::ostream &out = ( procRank == 0 ? std::cout : black_hole_out );

  if(verbose)
    out << "\n***\n*** Running silly CG solver using scalar type = \'" << ST::name() << "\' ...\n***\n";

  Teuchos::Time timer("");
  timer.start(true);

  //
  // (A) Build the tridiagonal operator A
  //
  if(verbose) out << "\nConstructing tridiagonal matrix A of local dimension = " << localDim
                  << " and diagonal multiplier = " << diagScale << " ...\n";
  // Rank 0 stores one fewer sub-diagonal entry (none for its first row) and
  // the last rank stores one fewer super-diagonal entry (none for its last row).
  const Thyra::Index
    lowerDim = ( procRank == 0 ? localDim - 1 : localDim ),
    upperDim = ( procRank == numProc-1 ? localDim - 1 : localDim );
  const Scalar one = ST::one(), diagTerm = Scalar(2)*diagScale*ST::one();
  // Every stored off-diagonal coefficient is -1 and every diagonal coefficient
  // is diagTerm. Fill-construct the arrays instead of tracking separate
  // row/sub-diagonal indices by hand.
  std::vector<Scalar> lower(lowerDim,-one), diag(localDim,diagTerm), upper(upperDim,-one);
  RefCountPtr<const Thyra::LinearOpBase<Scalar> >
    A = rcp(new MPITridiagLinearOp<Scalar>(mpiComm,localDim,&lower[0],&diag[0],&upper[0]));

  if(verbose) out << "\nGlobal dimension of A = " << A->domain()->dim() << std::endl;

  //
  // (B) Test the operator (adjoint check disabled, symmetry check enabled)
  //
  if(verbose) out << "\nTesting the constructed linear operator A ...\n";
  Thyra::LinearOpTester<Scalar> linearOpTester;
  linearOpTester.dump_all(dumpAll);
  linearOpTester.set_all_error_tol(tolerance);
  // Warnings are reported at 1% of the error tolerance.
  linearOpTester.set_all_warning_tol(ScalarMag(ScalarMag(1e-2)*tolerance));
  linearOpTester.show_all_tests(true);
  linearOpTester.check_adjoint(false);
  linearOpTester.check_for_symmetry(true);
  result = linearOpTester.check(*A,verbose?&out:0);
  if(!result) success = false;

  //
  // (C) Solve A*x = b for a random right-hand side, starting from x = 0
  //
  RefCountPtr<Thyra::VectorBase<Scalar> > b = createMember(A->range());
  Thyra::seed_randomize<Scalar>(0);
  Thyra::randomize( Scalar(-ST::one()), Scalar(+ST::one()), &*b );
  RefCountPtr<Thyra::VectorBase<Scalar> > x = createMember(A->domain());
  Thyra::assign( &*x, ST::zero() );
  result = sillyCgSolve(*A,*b,maxNumIters,tolerance,&*x,verbose?&out:0);
  if(!result) success = false;

  //
  // (D) Independently check the residual r = b - A*x
  //
  RefCountPtr<Thyra::VectorBase<Scalar> > r = createMember(A->range());
  Thyra::assign(&*r,*b);
  Thyra::apply(*A,Thyra::NOTRANS,*x,&*r,Scalar(-ST::one()),ST::one());
  const ScalarMag r_nrm = Thyra::norm(*r), b_nrm = Thyra::norm(*b);
  const ScalarMag rel_err = r_nrm/b_nrm, relaxTol = ScalarMag(10.0)*tolerance;
  result = rel_err <= relaxTol;
  if(!result) success = false;
  if(verbose)
    out
      << "\n||b-A*x||/||b|| = "<<r_nrm<<"/"<<b_nrm<<" = "<<rel_err<<(result?" <= ":" > ")
      <<"10.0*tolerance = "<<relaxTol<<": "<<(result?"passed":"failed")<<std::endl;

  timer.stop();
  if(verbose) out << "\nTotal time = " << timer.totalElapsedTime() << " sec\n";

  return success;
}
/*
 * Driver program: initializes MPI, parses the command line, and runs the CG
 * example for float and double. When HAVE_COMPLEX and HAVE_TEUCHOS_COMPLEX
 * are defined, it also runs std::complex<float> and std::complex<double>.
 *
 * Exit status:
 *   - 0 if every run succeeded;
 *   - 1 if any run failed or an exception was caught;
 *   - the parser's return code if command-line parsing did not report
 *     PARSE_SUCCESSFUL.
 * MPI_Finalize() is called on every one of these paths.
 */
int main(int argc, char *argv[])
{
  using Teuchos::CommandLineProcessor;
  bool success = true;
  bool verbose = true;
  bool result;
  MPI_Init(&argc,&argv);
  MPI_Comm mpiComm = MPI_COMM_WORLD;
  int procRank, numProc;
  MPI_Comm_size( mpiComm, &numProc );
  MPI_Comm_rank( mpiComm, &procRank );
  try {
    // Defaults; each can be overridden from the command line below.
    int localDim = 500;
    double diagScale = 1.001;
    double tolerance = 1e-4;
    int maxNumIters = 300;
    bool dumpAll = false;
    // false: parse errors are reported via the return value rather than thrown.
    CommandLineProcessor clp(false);
    clp.setOption( "verbose", "quiet", &verbose, "Determines if any output is printed or not." );
    clp.setOption( "local-dim", &localDim, "Local dimension of the linear system." );
    clp.setOption( "diag-scale", &diagScale, "Scaling of the diagonal to improve conditioning." );
    clp.setOption( "tol", &tolerance, "Relative tolerance for linear system solve." );
    clp.setOption( "max-num-iters", &maxNumIters, "Maximum of CG iterations." );
    clp.setOption( "dump-all", "no-dump-all", &dumpAll, "Determines if vectors are printed or not." );
    CommandLineProcessor::EParseCommandLineReturn parse_return = clp.parse(argc,argv);
    if( parse_return != CommandLineProcessor::PARSE_SUCCESSFUL ) {
      // Returning from main() without MPI_Finalize() is erroneous in MPI, so
      // finalize on this early-exit path as well.
      MPI_Finalize();
      return parse_return;
    }
    // runCgSolveExample() needs at least two local rows per process.
    TEST_FOR_EXCEPTION( localDim < 2, std::logic_error, "Error, localDim=" << localDim << " < 2 is not allowed!" );
    result = runCgSolveExample<float>(mpiComm,procRank,numProc,localDim,diagScale,verbose,dumpAll,tolerance,maxNumIters);
    if(!result) success = false;
    result = runCgSolveExample<double>(mpiComm,procRank,numProc,localDim,diagScale,verbose,dumpAll,tolerance,maxNumIters);
    if(!result) success = false;
#if defined(HAVE_COMPLEX) && defined(HAVE_TEUCHOS_COMPLEX)
    result = runCgSolveExample<std::complex<float> >(mpiComm,procRank,numProc,localDim,diagScale,verbose,dumpAll,tolerance,maxNumIters);
    if(!result) success = false;
    result = runCgSolveExample<std::complex<double> >(mpiComm,procRank,numProc,localDim,diagScale,verbose,dumpAll,tolerance,maxNumIters);
    if(!result) success = false;
#endif
  }
  catch( const std::exception &excpt ) {
    std::cerr << "*** p="<<procRank<<": Caught standard exception : " << excpt.what() << std::endl;
    success = false;
  }
  catch( ... ) {
    std::cerr << "*** p="<<procRank<<":Caught an unknown exception\n";
    success = false;
  }
  if( verbose && procRank==0 ) {
    if(success) std::cout << "\nAll of the tests seem to have run successfully!\n";
    else std::cout << "\nOh no! at least one of the tests failed!\n";
  }
  MPI_Finalize();
  return success ? 0 : 1;
}