Tpetra Matrix/Vector Services Version of the Day
MultiPrecDriver.hpp
00001 #ifndef MULTIPREC_DRIVER_HPP
00002 #define MULTIPREC_DRIVER_HPP
00003 
00004 #include <Teuchos_TypeNameTraits.hpp>
00005 #include <Teuchos_TimeMonitor.hpp>
00006 #include <Teuchos_ParameterList.hpp>
00007 #include <Teuchos_FancyOStream.hpp>
00008 
00009 #include <Tpetra_Version.hpp>
00010 #include <Tpetra_MatrixIO.hpp>
00011 #include <Tpetra_RTI.hpp>
00012 #include <Tpetra_CrsMatrix.hpp>
00013 #include <Tpetra_Vector.hpp>
00014 
00015 #include "MultiPrecCG.hpp"
00016 
00017 template <class MPStack>
00018 class MultiPrecDriver {
00019   public:
00020   // input
00021   Teuchos::RCP<Teuchos::FancyOStream>     out; 
00022   Teuchos::RCP<Teuchos::ParameterList> params;
00023   std::string                      matrixFile;
00024   bool                             unfusedTest;
00025   // output
00026   bool                             testPassed;
00027 
00028   template <class Node> 
00029   void run(Teuchos::ParameterList &myMachPL, const Teuchos::RCP<const Teuchos::Comm<int> > &comm, const Teuchos::RCP<Node> &node) 
00030   {
00031     using std::pair;
00032     using std::make_pair;
00033     using std::plus;
00034     using std::endl;
00035     using Teuchos::RCP;
00036     using Teuchos::ParameterList;
00037     using TpetraExamples::make_pair_op;
00038     using Tpetra::RTI::reductionGlob;
00039     using Tpetra::RTI::ZeroOp;
00040     using Tpetra::RTI::binary_pre_transform_reduce;
00041     using Tpetra::RTI::binary_transform;
00042 
00043     // Static types
00044     typedef typename MPStack::type   S;
00045     typedef int                     LO;
00046     typedef int                     GO;
00047     typedef Tpetra::CrsMatrix<S,LO,GO,Node> CrsMatrix;
00048     typedef Tpetra::Vector<S,LO,GO,Node>       Vector;
00049 
00050     *out << "Running test with Node==" << Teuchos::typeName(*node) << " on rank " << comm->getRank() << "/" << comm->getSize() << std::endl;
00051 
00052     // read the matrix
00053     RCP<CrsMatrix> A;
00054     Tpetra::Utils::readHBMatrix(matrixFile,comm,node,A);
00055 
00056     // init the solver stack
00057     TpetraExamples::RFPCGInit<S,LO,GO,Node> init(A);
00058     RCP<ParameterList> db = Tpetra::Ext::initStackDB<MPStack>(*params,init);
00059 
00060     testPassed = true;
00061 
00062     // choose a solution, compute a right-hand-side
00063     auto x = Tpetra::createVector<S>(A->getRowMap()),
00064          b = Tpetra::createVector<S>(A->getRowMap());
00065     x->randomize();
00066     A->apply(*x,*b);
00067     {
00068       // init the rhs
00069       auto bx = db->get<RCP<Vector>>("bx");
00070       binary_transform( *bx, *b, [](S, S bi) {return bi;}); // bx = b
00071     }
00072 
00073     // call the solve
00074     TpetraExamples::recursiveFPCG<MPStack,LO,GO,Node>(out,*db);
00075 
00076     // check that residual is as requested
00077     {
00078       auto xhat = db->get<RCP<Vector>>("bx"),
00079            bhat = Tpetra::createVector<S>(A->getRowMap());
00080       A->apply(*xhat,*bhat);
00081       // compute bhat-b, while simultaneously computing |bhat-b|^2 and |b|^2
00082       auto nrms = binary_pre_transform_reduce(*bhat, *b, 
00083                                               reductionGlob<ZeroOp<pair<S,S>>>( 
00084                                                 [](S bhati, S bi){ return bi-bhati;}, // bhati = bi-bhat
00085                                                 [](S bhati, S bi){ return make_pair(bhati*bhati, bi*bi); },
00086                                                 make_pair_op<S,S>(plus<S>())) );
00087       const S enrm = Teuchos::ScalarTraits<S>::squareroot(nrms.first),
00088               bnrm = Teuchos::ScalarTraits<S>::squareroot(nrms.second);
00089       // check that residual is as requested
00090       *out << "|b - A*x|/|b|: " << enrm / bnrm << endl;
00091       const double tolerance = db->get<double>("tolerance");
00092       if (MPStack::bottom) {
00093         // give a little slack
00094         if (enrm / bnrm > 5*tolerance) testPassed = false;
00095       }
00096       else {
00097         if (enrm / bnrm > tolerance) testPassed = false;
00098       }
00099     }
00100 
00101     // 
00102     // solve again, with the unfused version, just for timings purposes
00103     if (unfusedTest) 
00104     {
00105       // init the rhs
00106       auto bx = db->get<RCP<Vector>>("bx");
00107       binary_transform( *bx, *b, [](S, S bi) {return bi;}); // bx = b
00108       // call the solve
00109       TpetraExamples::recursiveFPCGUnfused<MPStack,LO,GO,Node>(out,*db);
00110       //
00111       // test the result
00112       auto xhat = db->get<RCP<Vector>>("bx"),
00113            bhat = Tpetra::createVector<S>(A->getRowMap());
00114       A->apply(*xhat,*bhat);
00115       // compute bhat-b, while simultaneously computing |bhat-b|^2 and |b|^2
00116       auto nrms = binary_pre_transform_reduce(*bhat, *b, 
00117                                               reductionGlob<ZeroOp<pair<S,S>>>( 
00118                                                 [](S bhati, S bi){ return bi-bhati;}, // bhati = bi-bhat
00119                                                 [](S bhati, S bi){ return make_pair(bhati*bhati, bi*bi); },
00120                                                 make_pair_op<S,S>(plus<S>())) );
00121       const S enrm = Teuchos::ScalarTraits<S>::squareroot(nrms.first),
00122               bnrm = Teuchos::ScalarTraits<S>::squareroot(nrms.second);
00123       // check that residual is as requested
00124       *out << "|b - A*x|/|b|: " << enrm / bnrm << endl;
00125       const double tolerance = db->get<double>("tolerance");
00126       if (MPStack::bottom) {
00127         // give a little slack
00128         if (enrm / bnrm > 5*tolerance) testPassed = false;
00129       }
00130       else {
00131         if (enrm / bnrm > tolerance) testPassed = false;
00132       }
00133     }    
00134          
00135          
00136     // print timings
00137     Teuchos::TimeMonitor::summarize( *out );
00138   }
00139 };
00140 
00141 #endif // MULTIPREC_DRIVER_HPP
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines