Tpetra Matrix/Vector Services Version of the Day
MultiPrecCG.hpp
00001 #ifndef MULTIPRECCG_HPP_
00002 #define MULTIPRECCG_HPP_
00003 
00004 #include <Teuchos_TimeMonitor.hpp>
00005 #include <Teuchos_TypeNameTraits.hpp>
00006 #include <Teuchos_ParameterList.hpp>
00007 #include <Teuchos_XMLParameterListHelpers.hpp>
00008 #include <Teuchos_FancyOStream.hpp>
00009 
00010 #include <Tpetra_CrsMatrix.hpp>
00011 #include <Tpetra_Vector.hpp>
00012 #include <Tpetra_RTI.hpp>
00013 #include <Tpetra_MatrixIO.hpp>
00014 
00015 #include <iostream>
00016 #include <functional>
00017 
00018 #ifdef HAVE_TPETRA_QD
00019 # include <qd/qd_real.h>
00020 #endif
00021 
00022 namespace Tpetra {
00023   namespace RTI {
00024     // specialization for pair
00025     template <class T1, class T2>
00026     class ZeroOp<std::pair<T1,T2>> {
00027       public:
00028       static inline std::pair<T1,T2> identity() {
00029         return std::make_pair( Teuchos::ScalarTraits<T1>::zero(), 
00030                                Teuchos::ScalarTraits<T2>::zero() );
00031       }
00032     };
00033   }
00034 }
00035 
00036 namespace TpetraExamples {
00037 
00038   using Teuchos::RCP;
00039   using Teuchos::ParameterList;
00040   using Teuchos::Time;
00041   using Teuchos::null;
00042   using std::binary_function;
00043   using std::pair;
00044   using std::make_pair;
00045   using std::plus;
00046   using std::multiplies;
00047 
00048   struct trivial_fpu_fix {
00049     void fix() {}
00050     void unfix() {}
00051   };
00052   struct nontrivial_fpu_fix {
00053     unsigned int old_cw;
00054     void fix()   {fpu_fix_start(&old_cw);}
00055     void unfix() {fpu_fix_end(&old_cw);}
00056   };
00057   // implementations
00058   template <class T> struct fpu_fix : trivial_fpu_fix {};
00059 #ifdef HAVE_TPETRA_QD
00060   template <> struct fpu_fix<qd_real> : nontrivial_fpu_fix {};
00061   template <> struct fpu_fix<dd_real> : nontrivial_fpu_fix {};
00062 #endif
00063 
00065   template <class Tout, class Tin, class LO, class GO, class Node> 
00066   struct convertHelp {
00067     static RCP<const Tpetra::CrsMatrix<Tout,LO,GO,Node>> doit(const RCP<const Tpetra::CrsMatrix<Tin,LO,GO,Node>> &A)
00068     {
00069       return A->template convert<Tout>();
00070     }
00071   };
00072 
00074   template <class T, class LO, class GO, class Node> 
00075   struct convertHelp<T,T,LO,GO,Node> {
00076     static RCP<const Tpetra::CrsMatrix<T,LO,GO,Node>> doit(const RCP<const Tpetra::CrsMatrix<T,LO,GO,Node>> &A)
00077     {
00078       return A;
00079     }
00080   };
00081 
00082 
00086   template <class T1, class T2, class Op>
00087   class pair_op : public binary_function<pair<T1,T2>,pair<T1,T2>,pair<T1,T2>> {
00088   private:
00089     Op op_;
00090   public:
00091     pair_op(Op op) : op_(op) {}
00092     inline pair<T1,T2> operator()(const pair<T1,T2>& a, const pair<T1,T2>& b) const {
00093       return make_pair(op_(a.first,b.first),op_(a.second,b.second));
00094     }
00095   };
00096 
00098   template <class T1, class T2, class Op>
00099   pair_op<T1,T2,Op> make_pair_op(Op op) { return pair_op<T1,T2,Op>(op); }
00100 
00102   template <class S, class LO, class GO, class Node>
00103   class RFPCGInit 
00104   {
00105     private:
00106     typedef Tpetra::Map<LO,GO,Node>               Map;
00107     typedef Tpetra::CrsMatrix<S,LO,GO,Node> CrsMatrix;
00108     RCP<const CrsMatrix> A;
00109 
00110     public:
00111 
00112     RFPCGInit(RCP<Tpetra::CrsMatrix<S,LO,GO,Node>> Atop) : A(Atop) {}
00113 
00114     template <class T>
00115     RCP<ParameterList> initDB(ParameterList &params) 
00116     {
00117       fpu_fix<T> ff;
00118       ff.fix();
00119       typedef Tpetra::Vector<T,LO,GO,Node>    Vector;
00120       typedef Tpetra::Operator<T,LO,GO,Node>      Op;
00121       typedef Tpetra::CrsMatrix<T,LO,GO,Node>    Mat;
00122       RCP<const Map> map = A->getDomainMap();
00123       RCP<ParameterList> db = Teuchos::parameterList();
00124       RCP<const Mat> AT = convertHelp<T,S,LO,GO,Node>::doit(A);
00125       //
00126       db->set<RCP<const Op>>("A", AT                    );
00127       db->set("numIters", params.get<int>("numIters",A->getGlobalNumRows()) );
00128       db->set("tolerance",params.get<double>("tolerance",1e-7));
00129       db->set("verbose",  params.get<int>("verbose",0) );
00130       db->set("bx",       Tpetra::createVector<T>(map)  );
00131       db->set("r",        Tpetra::createVector<T>(map)  );
00132       db->set("z",        Tpetra::createVector<T>(map)  );
00133       db->set("p",        Tpetra::createVector<T>(map)  );
00134       db->set("Ap",       Tpetra::createVector<T>(map)  );
00135       db->set("rold",     Tpetra::createVector<T>(map)  );
00136       if (params.get<bool>("Extract Diagonal",false)) {
00137         RCP<Vector> diag = Tpetra::createVector<T>(map);
00138         AT->getLocalDiagCopy(*diag);
00139         db->set("diag", diag);
00140       }
00141       ff.unfix();
00142       return db;
00143     }
00144   };
00145 
  /******************************
  *   Somewhat flexible CG
  *   Golub and Ye, 1999
  *
  *   M is the preconditioner: diagonal scaling (z = D\r) at the bottom level,
  *   otherwise an inner recursive solve at the next (lower) precision.
  *
  *   x = 0
  *   r = b
  *   z = M*r
  *   p = z
  *   do
  *     alpha = (r'*z) / (p'*A*p)
  *     x = x + alpha*p
  *     r_old = r
  *     r = r - alpha*A*p
  *     check r for convergence (each level uses its own tolerance)
  *     z = M*r
  *     beta = z'*(r - r_old) / (z_old'*r_old)
  *     p = z + beta*p
  *   enddo
  ******************************/
00163 
00165   template <class TS, class LO, class GO, class Node>      
00166   void recursiveFPCG(const RCP<Teuchos::FancyOStream> &out, ParameterList &db)
00167   {
00168     using Teuchos::as;
00169     using Tpetra::RTI::ZeroOp;
00170     typedef typename TS::type       T;
00171     typedef typename TS::next::type T2;
00172     typedef Tpetra::Vector<T ,LO,GO,Node> VectorT1;
00173     typedef Tpetra::Vector<T2,LO,GO,Node> VectorT2;
00174     typedef Tpetra::Operator<T,LO,GO,Node>    OpT1;
00175     typedef Teuchos::ScalarTraits<T>            ST;
00176     // get objects from level database
00177     const int numIters = db.get<int>("numIters");
00178     auto x     = db.get<RCP<VectorT1>>("bx");
00179     auto r     = db.get<RCP<VectorT1>>("r");
00180     auto z     = db.get<RCP<VectorT1>>("z");
00181     auto p     = db.get<RCP<VectorT1>>("p");
00182     auto Ap    = db.get<RCP<VectorT1>>("Ap");
00183     auto rold  = db.get<RCP<VectorT1>>("rold");
00184     auto A     = db.get<RCP<const OpT1>>("A");
00185     RCP<const VectorT1> diag;
00186     if (TS::bottom) {
00187       diag = db.get<RCP<VectorT1>>("diag");
00188     }
00189     const T tolerance = db.get<double>("tolerance", 0.0);
00190     const int verbose = db.get<int>("verbose",0);
00191     static RCP<Time> timer, Atimer;
00192     if (timer == null) {
00193       timer = Teuchos::TimeMonitor::getNewTimer(
00194                       "recursiveFPCG<"+Teuchos::TypeNameTraits<T>::name()+">"
00195               );
00196     }
00197     if (Atimer == null) {
00198       Atimer = Teuchos::TimeMonitor::getNewTimer(
00199                       "A<"+Teuchos::TypeNameTraits<T>::name()+">"
00200               );
00201     }
00202 
00203     fpu_fix<T> ff;
00204     ff.fix();
00205 
00206     Teuchos::OSTab tab(out);
00207 
00208     if (verbose) {
00209       *out << "Beginning recursiveFPCG<" << Teuchos::TypeNameTraits<T>::name() << ">" << std::endl;
00210     }
00211 
00212     timer->start(); 
00213     const T r2 = TPETRA_BINARY_PRETRANSFORM_REDUCE(
00214                     r, x,                                         // fused: 
00215                     x,                                            //      : r = x  
00216                     r*r, ZeroOp<T>, plus<T>() );                  //      : sum r'*r
00217     // b comes in, x goes out. now we're done with b, so zero the solution.
00218     TPETRA_UNARY_TRANSFORM( x,  ST::zero() );                     // x = 0
00219     if (TS::bottom) {
00220       TPETRA_TERTIARY_TRANSFORM( z, diag, r,    r/diag );         // z = D\r
00221     }
00222     else {
00223       ff.unfix();
00224       timer->stop(); 
00225       ParameterList &db_T2 = db.sublist("child");
00226       auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00227       TPETRA_BINARY_TRANSFORM( bx_T2, r, as<T2>(r) );             // b_T2 = (T2)r
00228       recursiveFPCG<typename TS::next,LO,GO,Node>(out,db_T2);     // x_T2 = A_T2 \ b_T2 
00229       TPETRA_BINARY_TRANSFORM( z, bx_T2, as<T>(bx_T2) );          // z    = (T)bx_T2
00230       timer->start(); 
00231       ff.fix();
00232     }
00233     T zr = TPETRA_TERTIARY_PRETRANSFORM_REDUCE( 
00234                     p, z, r,                                      // fused: 
00235                     z,                                            //      : p = z
00236                     z*r, ZeroOp<T>, plus<T>() );                  //      : z*r
00238     int k;
00239     for (k=0; k<numIters; ++k) 
00240     {
00241       Atimer->start();
00242       A->apply(*p,*Ap);                                           // Ap = A*p
00243       Atimer->stop();
00244       T pAp = TPETRA_REDUCE2( p, Ap,     
00245                               p*Ap, ZeroOp<T>, plus<T>() );       // p'*Ap
00246       const T alpha = zr / pAp;
00247       TPETRA_BINARY_TRANSFORM( x,    p,  x + alpha*p  );          // x = x + alpha*p
00248       TPETRA_BINARY_TRANSFORM( rold, r,  r            );          // rold = r
00249       T rr = TPETRA_BINARY_PRETRANSFORM_REDUCE(
00250                                r, Ap,                             // fused:
00251                                r - alpha*Ap,                      //      : r - alpha*Ap
00252                                r*r, ZeroOp<T>, plus<T>() );       //      : sum r'*r
00253       if (verbose > 1) *out << "|res|/|res_0|: " << ST::squareroot(rr/r2) 
00254                             << std::endl;
00255       if (rr/r2 < tolerance*tolerance) {
00256         if (verbose) {
00257           *out << "Convergence detected!" << std::endl;
00258         }
00259         break;
00260       }
00261       if (TS::bottom) {
00262         TPETRA_TERTIARY_TRANSFORM( z, diag, r,    r/diag );        // z = D\r
00263       }
00264       else {
00265         ff.unfix();
00266         timer->stop(); 
00267         ParameterList &db_T2 = db.sublist("child");
00268         auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00269         TPETRA_BINARY_TRANSFORM( bx_T2, r,    as<T2>(r) );         // b_T2 = (T2)r
00270         recursiveFPCG<typename TS::next,LO,GO,Node>(out,db_T2);    // x_T2 = A_T2 \ b_T2
00271         TPETRA_BINARY_TRANSFORM( z, bx_T2,    as<T>(bx_T2) );      // z    = (T)bx_T2
00272         timer->start(); 
00273         ff.fix();
00274       }
00275       const T zoro = zr;                                                         
00276       typedef ZeroOp<pair<T,T>> ZeroPTT;
00277       auto plusTT = make_pair_op<T,T>(plus<T>());
00278       pair<T,T> both = TPETRA_REDUCE3( z, r, rold,                 // fused: z'*r and z'*r_old
00279                                        make_pair(z*r, z*rold), ZeroPTT, plusTT );
00280       zr = both.first; // this is used on the next iteration as well
00281       const T znro = both.second;
00282       const T beta = (zr - znro) / zoro;
00283       TPETRA_BINARY_TRANSFORM( p, z,   z + beta*p );               // p = z + beta*p
00284     }
00285     timer->stop(); 
00286     ff.unfix();
00287     if (verbose) {
00288       *out << "Leaving recursiveFPCG<" << Teuchos::TypeNameTraits<T>::name() 
00289            << "> after " << k << " iterations." << std::endl;
00290     }
00291   }
00292 
00294   template <class TS, class LO, class GO, class Node>      
00295   void recursiveFPCGUnfused(const RCP<Teuchos::FancyOStream> &out, ParameterList &db)
00296   {
00297     using Tpetra::RTI::unary_transform;
00298     using Tpetra::RTI::binary_transform;
00299     using Tpetra::RTI::tertiary_transform;
00300     using Tpetra::RTI::reduce;
00301     using Tpetra::RTI::reductionGlob;
00302     using Tpetra::RTI::ZeroOp;
00303     using Teuchos::as;
00304     typedef typename TS::type       T;
00305     typedef typename TS::next::type T2;
00306     typedef Tpetra::Vector<T ,LO,GO,Node> VectorT1;
00307     typedef Tpetra::Vector<T2,LO,GO,Node> VectorT2;
00308     typedef Tpetra::Operator<T,LO,GO,Node>    OpT1;
00309     typedef Teuchos::ScalarTraits<T>            ST;
00310     // get objects from level database
00311     const int numIters = db.get<int>("numIters");
00312     auto x     = db.get<RCP<VectorT1>>("bx");
00313     auto r     = db.get<RCP<VectorT1>>("r");
00314     auto z     = db.get<RCP<VectorT1>>("z");
00315     auto p     = db.get<RCP<VectorT1>>("p");
00316     auto Ap    = db.get<RCP<VectorT1>>("Ap");
00317     auto rold  = db.get<RCP<VectorT1>>("rold");
00318     auto A     = db.get<RCP<const OpT1>>("A");
00319     static RCP<Time> timer;
00320     if (timer == null) {
00321       timer = Teuchos::TimeMonitor::getNewTimer(
00322                       "recursiveFPCGUnfused<"+Teuchos::TypeNameTraits<T>::name()+">"
00323               );
00324     }
00325     RCP<const VectorT1> diag;
00326     if (TS::bottom) {
00327       diag = db.get<RCP<VectorT1>>("diag");
00328     }
00329     const T tolerance = db.get<double>("tolerance", 0.0);
00330     const int verbose = db.get<int>("verbose",0);
00331 
00332     fpu_fix<T> ff;
00333     ff.fix();
00334 
00335     Teuchos::OSTab tab(out);
00336 
00337     if (verbose) {
00338       *out << "Beginning recursiveFPCGUnfused<" << Teuchos::TypeNameTraits<T>::name() << ">" << std::endl;
00339     }
00340 
00341     timer->start(); 
00342     binary_transform( *r, *x,            [](T, T bi)             {return bi;});  // r = b     (b is stored in x)
00343     const T r2 = reduce( *r, *r,      reductionGlob<ZeroOp<T>>(multiplies<T>(),  // r'*r
00344                                                                    plus<T>())); 
00345     unary_transform(  *x,                [](T)           {return ST::zero();});  // set x = 0 (now that we don't need b)
00346     if (TS::bottom) {
00347       tertiary_transform( *z, *diag, *r, [](T, T di, T ri)    {return ri/di;});  // z = D\r
00348     }
00349     else {
00350       ff.unfix();
00351       timer->stop(); 
00352       ParameterList &db_T2 = db.sublist("child");
00353       auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00354       binary_transform( *bx_T2, *r, [](T2, T ri)         {return as<T2>(ri);});  // b_T2 = (T2)r       
00355       recursiveFPCGUnfused<typename TS::next,LO,GO,Node>(out,db_T2);             // x_T2 = A_T2 \ b_T2 
00356       binary_transform( *z, *bx_T2, [](T, T2 xi)          {return as<T>(xi);});  // z    = (T)x_T2     
00357       timer->start(); 
00358       ff.fix();
00359     }
00360     binary_transform( *p, *z, [](T, T zi)                        {return zi;});  // p = z
00361     T zr = reduce( *z, *r,          reductionGlob<ZeroOp<T>>(multiplies<T>(),    // z'*r
00362                                                                    plus<T>())); 
00364     int k;
00365     for (k=0; k<numIters; ++k) 
00366     {
00367       A->apply(*p,*Ap);                                                          // Ap = A*p
00368       T pAp = reduce( *p, *Ap,      reductionGlob<ZeroOp<T>>(multiplies<T>(),    // p'*Ap
00369                                                                    plus<T>())); 
00370       const T alpha = zr / pAp;
00371       binary_transform( *x, *p, [alpha](T xi, T pi)   {return xi + alpha*pi;});  // x = x + alpha*p
00372       binary_transform( *rold, *r, [](T, T ri)                   {return ri;});  // rold = r
00373       binary_transform( *r, *Ap,[alpha](T ri, T Api) {return ri - alpha*Api;});  // r = r - alpha*Ap
00374       T rr = reduce( *r, *r,      reductionGlob<ZeroOp<T>>(multiplies<T>(),      // r'*r
00375                                                                  plus<T>())); 
00376       if (verbose > 1) *out << "|res|/|res_0|: " << ST::squareroot(rr/r2) << std::endl;
00377       if (rr/r2 < tolerance*tolerance) {
00378         if (verbose) {
00379           *out << "Convergence detected!" << std::endl;
00380         }
00381         break;
00382       }
00383       if (TS::bottom) {
00384         tertiary_transform( *z, *diag, *r, [](T, T di, T ri)  {return ri/di;});  // z = D\r
00385       }
00386       else {
00387         ff.unfix();
00388         timer->stop(); 
00389         ParameterList &db_T2 = db.sublist("child");
00390         auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00391         binary_transform( *bx_T2, *r, [](T2, T ri)       {return as<T2>(ri);});  // b_T2 = (T2)r
00392         recursiveFPCGUnfused<typename TS::next,LO,GO,Node>(out,db_T2);           // x_T2 = A_T2 \ b_T2
00393         binary_transform( *z, *bx_T2, [](T, T2 xi)        {return as<T>(xi);});  // z    = (T)x_T2
00394         timer->start(); 
00395         ff.fix();
00396       }
00397       const T zoro = zr;                                                         
00398       zr = reduce( *z, *r,          reductionGlob<ZeroOp<T>>(multiplies<T>(),    // z'*r
00399                                                              plus<T>()));        // this is loop-carried
00400       const T znro = reduce( *z, *rold, reductionGlob<ZeroOp<T>>(multiplies<T>(),// z'*r_old
00401                                                                  plus<T>())); 
00402       const T beta = (zr - znro) / zoro;
00403       binary_transform( *p, *z, [beta](T pi, T zi)     {return zi + beta*pi;});  // p = z + beta*p
00404     }
00405     timer->stop(); 
00406     ff.unfix();
00407     if (verbose) {
00408       *out << "Leaving recursiveFPCGUnfused<" << Teuchos::TypeNameTraits<T>::name() << "> after " << k << " iterations." << std::endl;
00409     }
00410   }
00411 
00412 } // end of namespace TpetraExamples
00413 
00414 #endif // MULTIPRECCG_HPP_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines