Tpetra Matrix/Vector Services Version of the Day
MultiPrecCG.hpp
00001 /*
00002 // @HEADER
00003 // ***********************************************************************
00004 // 
00005 //          Tpetra: Templated Linear Algebra Services Package
00006 //                 Copyright (2008) Sandia Corporation
00007 // 
00008 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00009 // the U.S. Government retains certain rights in this software.
00010 // 
00011 // Redistribution and use in source and binary forms, with or without
00012 // modification, are permitted provided that the following conditions are
00013 // met:
00014 //
00015 // 1. Redistributions of source code must retain the above copyright
00016 // notice, this list of conditions and the following disclaimer.
00017 //
00018 // 2. Redistributions in binary form must reproduce the above copyright
00019 // notice, this list of conditions and the following disclaimer in the
00020 // documentation and/or other materials provided with the distribution.
00021 //
00022 // 3. Neither the name of the Corporation nor the names of the
00023 // contributors may be used to endorse or promote products derived from
00024 // this software without specific prior written permission.
00025 //
00026 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00027 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00028 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00029 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00030 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00031 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00032 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00033 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00034 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00035 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00036 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00037 //
00038 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00039 // 
00040 // ************************************************************************
00041 // @HEADER
00042 */
00043 
00044 #ifndef MULTIPRECCG_HPP_
00045 #define MULTIPRECCG_HPP_
00046 
00047 #include <Teuchos_TimeMonitor.hpp>
00048 #include <Teuchos_TypeNameTraits.hpp>
00049 #include <Teuchos_ParameterList.hpp>
00050 #include <Teuchos_XMLParameterListHelpers.hpp>
00051 #include <Teuchos_FancyOStream.hpp>
00052 
00053 #include <Tpetra_CrsMatrix.hpp>
00054 #include <Tpetra_Vector.hpp>
00055 #include <Tpetra_RTI.hpp>
00056 #include <Tpetra_MatrixIO.hpp>
00057 
00058 #include <iostream>
00059 #include <functional>
00060 
00061 #ifdef HAVE_TPETRA_QD
00062 # include <qd/qd_real.h>
00063 #endif
00064 
00065 namespace Tpetra {
00066   namespace RTI {
00067     // specialization for pair
00068     template <class T1, class T2>
00069     class ZeroOp<std::pair<T1,T2>> {
00070       public:
00071       static inline std::pair<T1,T2> identity() {
00072         return std::make_pair( Teuchos::ScalarTraits<T1>::zero(), 
00073                                Teuchos::ScalarTraits<T2>::zero() );
00074       }
00075     };
00076   }
00077 }
00078 
00079 namespace TpetraExamples {
00080 
00081   using Teuchos::RCP;
00082   using Teuchos::ParameterList;
00083   using Teuchos::Time;
00084   using Teuchos::null;
00085   using std::binary_function;
00086   using std::pair;
00087   using std::make_pair;
00088   using std::plus;
00089   using std::multiplies;
00090 
/// \brief No-op FPU guard: fix() and unfix() do nothing.
struct trivial_fpu_fix {
  void fix()   { /* nothing to set up */ }
  void unfix() { /* nothing to restore */ }
};

/// \brief Per-scalar-type FPU guard selector.
///
/// The primary template is the no-op guard. When HAVE_TPETRA_QD is defined,
/// the qd_real and dd_real specializations below select the guard that calls
/// fpu_fix_start()/fpu_fix_end().
template <class T>
struct fpu_fix : trivial_fpu_fix {};

#ifdef HAVE_TPETRA_QD
/// \brief FPU guard that brackets a region with fpu_fix_start()/fpu_fix_end().
///
/// Both functions are declared outside this file (presumably by the QD header
/// included above -- confirm). old_cw is the storage handed to both calls;
/// fix() must be called before unfix() so that old_cw has been written.
struct nontrivial_fpu_fix {
  unsigned int old_cw;
  void fix() {
    fpu_fix_start(&old_cw);
  }
  void unfix() {
    fpu_fix_end(&old_cw);
  }
};

template <>
struct fpu_fix<qd_real> : nontrivial_fpu_fix {};

template <>
struct fpu_fix<dd_real> : nontrivial_fpu_fix {};
#endif
00108 
00110   template <class Tout, class Tin, class LO, class GO, class Node> 
00111   struct convertHelp {
00112     static RCP<const Tpetra::CrsMatrix<Tout,LO,GO,Node>> doit(const RCP<const Tpetra::CrsMatrix<Tin,LO,GO,Node>> &A)
00113     {
00114       return A->template convert<Tout>();
00115     }
00116   };
00117 
00119   template <class T, class LO, class GO, class Node> 
00120   struct convertHelp<T,T,LO,GO,Node> {
00121     static RCP<const Tpetra::CrsMatrix<T,LO,GO,Node>> doit(const RCP<const Tpetra::CrsMatrix<T,LO,GO,Node>> &A)
00122     {
00123       return A;
00124     }
00125   };
00126 
00127 
/// \brief Binary functor that applies a scalar binary operation member-wise
///        to two std::pair values.
///
/// For pairs a and b the result is (op(a.first,b.first), op(a.second,b.second)).
///
/// The adaptor typedefs (first_argument_type, second_argument_type,
/// result_type) are declared directly rather than inherited from
/// std::binary_function, which was deprecated in C++11 and removed in C++17.
///
/// \tparam T1  type of the first pair member
/// \tparam T2  type of the second pair member
/// \tparam Op  copyable binary functor, callable on (T1,T1) and on (T2,T2)
template <class T1, class T2, class Op>
class pair_op {
private:
  Op op_;
public:
  typedef std::pair<T1,T2> first_argument_type;
  typedef std::pair<T1,T2> second_argument_type;
  typedef std::pair<T1,T2> result_type;

  /// Stores a copy of \c op. Intentionally left non-explicit so existing
  /// implicit conversions from Op keep compiling.
  pair_op(Op op) : op_(op) {}

  /// Applies the stored operation to the first members and to the second members.
  inline std::pair<T1,T2> operator()(const std::pair<T1,T2>& a, const std::pair<T1,T2>& b) const {
    return std::make_pair(op_(a.first,b.first), op_(a.second,b.second));
  }
};
00141 
00143   template <class T1, class T2, class Op>
00144   pair_op<T1,T2,Op> make_pair_op(Op op) { return pair_op<T1,T2,Op>(op); }
00145 
00147   template <class S, class LO, class GO, class Node>
00148   class RFPCGInit 
00149   {
00150     private:
00151     typedef Tpetra::Map<LO,GO,Node>               Map;
00152     typedef Tpetra::CrsMatrix<S,LO,GO,Node> CrsMatrix;
00153     RCP<const CrsMatrix> A;
00154 
00155     public:
00156 
00157     RFPCGInit(RCP<Tpetra::CrsMatrix<S,LO,GO,Node>> Atop) : A(Atop) {}
00158 
00159     template <class T>
00160     RCP<ParameterList> initDB(ParameterList &params) 
00161     {
00162       fpu_fix<T> ff;
00163       ff.fix();
00164       typedef Tpetra::Vector<T,LO,GO,Node>    Vector;
00165       typedef Tpetra::Operator<T,LO,GO,Node>      Op;
00166       typedef Tpetra::CrsMatrix<T,LO,GO,Node>    Mat;
00167       RCP<const Map> map = A->getDomainMap();
00168       RCP<ParameterList> db = Teuchos::parameterList();
00169       RCP<const Mat> AT = convertHelp<T,S,LO,GO,Node>::doit(A);
00170       //
00171       db->set<RCP<const Op>>("A", AT                    );
00172       db->set("numIters", params.get<int>("numIters",A->getGlobalNumRows()) );
00173       db->set("tolerance",params.get<double>("tolerance",1e-7));
00174       db->set("verbose",  params.get<int>("verbose",0) );
00175       db->set("bx",       Tpetra::createVector<T>(map)  );
00176       db->set("r",        Tpetra::createVector<T>(map)  );
00177       db->set("z",        Tpetra::createVector<T>(map)  );
00178       db->set("p",        Tpetra::createVector<T>(map)  );
00179       db->set("Ap",       Tpetra::createVector<T>(map)  );
00180       db->set("rold",     Tpetra::createVector<T>(map)  );
00181       if (params.get<bool>("Extract Diagonal",false)) {
00182         RCP<Vector> diag = Tpetra::createVector<T>(map);
00183         AT->getLocalDiagCopy(*diag);
00184         db->set("diag", diag);
00185       }
00186       ff.unfix();
00187       return db;
00188     }
00189   };
00190 
/******************************
*   Somewhat flexible CG
*   Golub and Ye, 1999
*
*   r = b
*   z = M*r
*   p = z
*   do
*     alpha = (r'*z) / (p'*A*p)
*     x = x + alpha*p
*     r = r - alpha*A*p
*     if outermost, check r for convergence
*     z = M*r
*     beta = z'*(r_new - r_old) / (z_old'*r_old)   (denominator: z'*r of the previous iteration)
*     p = z + beta*p
*   enddo
******************************/
00208 
00210   template <class TS, class LO, class GO, class Node>      
00211   void recursiveFPCG(const RCP<Teuchos::FancyOStream> &out, ParameterList &db)
00212   {
00213     using Teuchos::as;
00214     using Tpetra::RTI::ZeroOp;
00215     typedef typename TS::type       T;
00216     typedef typename TS::next::type T2;
00217     typedef Tpetra::Vector<T ,LO,GO,Node> VectorT1;
00218     typedef Tpetra::Vector<T2,LO,GO,Node> VectorT2;
00219     typedef Tpetra::Operator<T,LO,GO,Node>    OpT1;
00220     typedef Teuchos::ScalarTraits<T>            ST;
00221     // get objects from level database
00222     const int numIters = db.get<int>("numIters");
00223     auto x     = db.get<RCP<VectorT1>>("bx");
00224     auto r     = db.get<RCP<VectorT1>>("r");
00225     auto z     = db.get<RCP<VectorT1>>("z");
00226     auto p     = db.get<RCP<VectorT1>>("p");
00227     auto Ap    = db.get<RCP<VectorT1>>("Ap");
00228     auto rold  = db.get<RCP<VectorT1>>("rold");
00229     auto A     = db.get<RCP<const OpT1>>("A");
00230     RCP<const VectorT1> diag;
00231     if (TS::bottom) {
00232       diag = db.get<RCP<VectorT1>>("diag");
00233     }
00234     const T tolerance = db.get<double>("tolerance", 0.0);
00235     const int verbose = db.get<int>("verbose",0);
00236     static RCP<Time> timer, Atimer;
00237     if (timer == null) {
00238       timer = Teuchos::TimeMonitor::getNewTimer(
00239                       "recursiveFPCG<"+Teuchos::TypeNameTraits<T>::name()+">"
00240               );
00241     }
00242     if (Atimer == null) {
00243       Atimer = Teuchos::TimeMonitor::getNewTimer(
00244                       "A<"+Teuchos::TypeNameTraits<T>::name()+">"
00245               );
00246     }
00247 
00248     fpu_fix<T> ff;
00249     ff.fix();
00250 
00251     Teuchos::OSTab tab(out);
00252 
00253     if (verbose) {
00254       *out << "Beginning recursiveFPCG<" << Teuchos::TypeNameTraits<T>::name() << ">" << std::endl;
00255     }
00256 
00257     timer->start(); 
00258     const T r2 = TPETRA_BINARY_PRETRANSFORM_REDUCE(
00259                     r, x,                                         // fused: 
00260                     x,                                            //      : r = x  
00261                     r*r, ZeroOp<T>, plus<T>() );                  //      : sum r'*r
00262     // b comes in, x goes out. now we're done with b, so zero the solution.
00263     TPETRA_UNARY_TRANSFORM( x,  ST::zero() );                     // x = 0
00264     if (TS::bottom) {
00265       TPETRA_TERTIARY_TRANSFORM( z, diag, r,    r/diag );         // z = D\r
00266     }
00267     else {
00268       ff.unfix();
00269       timer->stop(); 
00270       ParameterList &db_T2 = db.sublist("child");
00271       auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00272       TPETRA_BINARY_TRANSFORM( bx_T2, r, as<T2>(r) );             // b_T2 = (T2)r
00273       recursiveFPCG<typename TS::next,LO,GO,Node>(out,db_T2);     // x_T2 = A_T2 \ b_T2 
00274       TPETRA_BINARY_TRANSFORM( z, bx_T2, as<T>(bx_T2) );          // z    = (T)bx_T2
00275       timer->start(); 
00276       ff.fix();
00277     }
00278     T zr = TPETRA_TERTIARY_PRETRANSFORM_REDUCE( 
00279                     p, z, r,                                      // fused: 
00280                     z,                                            //      : p = z
00281                     z*r, ZeroOp<T>, plus<T>() );                  //      : z*r
00283     int k;
00284     for (k=0; k<numIters; ++k) 
00285     {
00286       Atimer->start();
00287       A->apply(*p,*Ap);                                           // Ap = A*p
00288       Atimer->stop();
00289       T pAp = TPETRA_REDUCE2( p, Ap,     
00290                               p*Ap, ZeroOp<T>, plus<T>() );       // p'*Ap
00291       const T alpha = zr / pAp;
00292       TPETRA_BINARY_TRANSFORM( x,    p,  x + alpha*p  );          // x = x + alpha*p
00293       TPETRA_BINARY_TRANSFORM( rold, r,  r            );          // rold = r
00294       T rr = TPETRA_BINARY_PRETRANSFORM_REDUCE(
00295                                r, Ap,                             // fused:
00296                                r - alpha*Ap,                      //      : r - alpha*Ap
00297                                r*r, ZeroOp<T>, plus<T>() );       //      : sum r'*r
00298       if (verbose > 1) *out << "|res|/|res_0|: " << ST::squareroot(rr/r2) 
00299                             << std::endl;
00300       if (rr/r2 < tolerance*tolerance) {
00301         if (verbose) {
00302           *out << "Convergence detected!" << std::endl;
00303         }
00304         break;
00305       }
00306       if (TS::bottom) {
00307         TPETRA_TERTIARY_TRANSFORM( z, diag, r,    r/diag );        // z = D\r
00308       }
00309       else {
00310         ff.unfix();
00311         timer->stop(); 
00312         ParameterList &db_T2 = db.sublist("child");
00313         auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00314         TPETRA_BINARY_TRANSFORM( bx_T2, r,    as<T2>(r) );         // b_T2 = (T2)r
00315         recursiveFPCG<typename TS::next,LO,GO,Node>(out,db_T2);    // x_T2 = A_T2 \ b_T2
00316         TPETRA_BINARY_TRANSFORM( z, bx_T2,    as<T>(bx_T2) );      // z    = (T)bx_T2
00317         timer->start(); 
00318         ff.fix();
00319       }
00320       const T zoro = zr;                                                         
00321       typedef ZeroOp<pair<T,T>> ZeroPTT;
00322       auto plusTT = make_pair_op<T,T>(plus<T>());
00323       pair<T,T> both = TPETRA_REDUCE3( z, r, rold,                 // fused: z'*r and z'*r_old
00324                                        make_pair(z*r, z*rold), ZeroPTT, plusTT );
00325       zr = both.first; // this is used on the next iteration as well
00326       const T znro = both.second;
00327       const T beta = (zr - znro) / zoro;
00328       TPETRA_BINARY_TRANSFORM( p, z,   z + beta*p );               // p = z + beta*p
00329     }
00330     timer->stop(); 
00331     ff.unfix();
00332     if (verbose) {
00333       *out << "Leaving recursiveFPCG<" << Teuchos::TypeNameTraits<T>::name() 
00334            << "> after " << k << " iterations." << std::endl;
00335     }
00336   }
00337 
00339   template <class TS, class LO, class GO, class Node>      
00340   void recursiveFPCGUnfused(const RCP<Teuchos::FancyOStream> &out, ParameterList &db)
00341   {
00342     using Tpetra::RTI::unary_transform;
00343     using Tpetra::RTI::binary_transform;
00344     using Tpetra::RTI::tertiary_transform;
00345     using Tpetra::RTI::reduce;
00346     using Tpetra::RTI::reductionGlob;
00347     using Tpetra::RTI::ZeroOp;
00348     using Teuchos::as;
00349     typedef typename TS::type       T;
00350     typedef typename TS::next::type T2;
00351     typedef Tpetra::Vector<T ,LO,GO,Node> VectorT1;
00352     typedef Tpetra::Vector<T2,LO,GO,Node> VectorT2;
00353     typedef Tpetra::Operator<T,LO,GO,Node>    OpT1;
00354     typedef Teuchos::ScalarTraits<T>            ST;
00355     // get objects from level database
00356     const int numIters = db.get<int>("numIters");
00357     auto x     = db.get<RCP<VectorT1>>("bx");
00358     auto r     = db.get<RCP<VectorT1>>("r");
00359     auto z     = db.get<RCP<VectorT1>>("z");
00360     auto p     = db.get<RCP<VectorT1>>("p");
00361     auto Ap    = db.get<RCP<VectorT1>>("Ap");
00362     auto rold  = db.get<RCP<VectorT1>>("rold");
00363     auto A     = db.get<RCP<const OpT1>>("A");
00364     static RCP<Time> timer;
00365     if (timer == null) {
00366       timer = Teuchos::TimeMonitor::getNewTimer(
00367                       "recursiveFPCGUnfused<"+Teuchos::TypeNameTraits<T>::name()+">"
00368               );
00369     }
00370     RCP<const VectorT1> diag;
00371     if (TS::bottom) {
00372       diag = db.get<RCP<VectorT1>>("diag");
00373     }
00374     const T tolerance = db.get<double>("tolerance", 0.0);
00375     const int verbose = db.get<int>("verbose",0);
00376 
00377     fpu_fix<T> ff;
00378     ff.fix();
00379 
00380     Teuchos::OSTab tab(out);
00381 
00382     if (verbose) {
00383       *out << "Beginning recursiveFPCGUnfused<" << Teuchos::TypeNameTraits<T>::name() << ">" << std::endl;
00384     }
00385 
00386     timer->start(); 
00387     binary_transform( *r, *x,            [](T, T bi)             {return bi;});  // r = b     (b is stored in x)
00388     const T r2 = reduce( *r, *r,      reductionGlob<ZeroOp<T>>(multiplies<T>(),  // r'*r
00389                                                                    plus<T>())); 
00390     unary_transform(  *x,                [](T)           {return ST::zero();});  // set x = 0 (now that we don't need b)
00391     if (TS::bottom) {
00392       tertiary_transform( *z, *diag, *r, [](T, T di, T ri)    {return ri/di;});  // z = D\r
00393     }
00394     else {
00395       ff.unfix();
00396       timer->stop(); 
00397       ParameterList &db_T2 = db.sublist("child");
00398       auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00399       binary_transform( *bx_T2, *r, [](T2, T ri)         {return as<T2>(ri);});  // b_T2 = (T2)r       
00400       recursiveFPCGUnfused<typename TS::next,LO,GO,Node>(out,db_T2);             // x_T2 = A_T2 \ b_T2 
00401       binary_transform( *z, *bx_T2, [](T, T2 xi)          {return as<T>(xi);});  // z    = (T)x_T2     
00402       timer->start(); 
00403       ff.fix();
00404     }
00405     binary_transform( *p, *z, [](T, T zi)                        {return zi;});  // p = z
00406     T zr = reduce( *z, *r,          reductionGlob<ZeroOp<T>>(multiplies<T>(),    // z'*r
00407                                                                    plus<T>())); 
00409     int k;
00410     for (k=0; k<numIters; ++k) 
00411     {
00412       A->apply(*p,*Ap);                                                          // Ap = A*p
00413       T pAp = reduce( *p, *Ap,      reductionGlob<ZeroOp<T>>(multiplies<T>(),    // p'*Ap
00414                                                                    plus<T>())); 
00415       const T alpha = zr / pAp;
00416       binary_transform( *x, *p, [alpha](T xi, T pi)   {return xi + alpha*pi;});  // x = x + alpha*p
00417       binary_transform( *rold, *r, [](T, T ri)                   {return ri;});  // rold = r
00418       binary_transform( *r, *Ap,[alpha](T ri, T Api) {return ri - alpha*Api;});  // r = r - alpha*Ap
00419       T rr = reduce( *r, *r,      reductionGlob<ZeroOp<T>>(multiplies<T>(),      // r'*r
00420                                                                  plus<T>())); 
00421       if (verbose > 1) *out << "|res|/|res_0|: " << ST::squareroot(rr/r2) << std::endl;
00422       if (rr/r2 < tolerance*tolerance) {
00423         if (verbose) {
00424           *out << "Convergence detected!" << std::endl;
00425         }
00426         break;
00427       }
00428       if (TS::bottom) {
00429         tertiary_transform( *z, *diag, *r, [](T, T di, T ri)  {return ri/di;});  // z = D\r
00430       }
00431       else {
00432         ff.unfix();
00433         timer->stop(); 
00434         ParameterList &db_T2 = db.sublist("child");
00435         auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00436         binary_transform( *bx_T2, *r, [](T2, T ri)       {return as<T2>(ri);});  // b_T2 = (T2)r
00437         recursiveFPCGUnfused<typename TS::next,LO,GO,Node>(out,db_T2);           // x_T2 = A_T2 \ b_T2
00438         binary_transform( *z, *bx_T2, [](T, T2 xi)        {return as<T>(xi);});  // z    = (T)x_T2
00439         timer->start(); 
00440         ff.fix();
00441       }
00442       const T zoro = zr;                                                         
00443       zr = reduce( *z, *r,          reductionGlob<ZeroOp<T>>(multiplies<T>(),    // z'*r
00444                                                              plus<T>()));        // this is loop-carried
00445       const T znro = reduce( *z, *rold, reductionGlob<ZeroOp<T>>(multiplies<T>(),// z'*r_old
00446                                                                  plus<T>())); 
00447       const T beta = (zr - znro) / zoro;
00448       binary_transform( *p, *z, [beta](T pi, T zi)     {return zi + beta*pi;});  // p = z + beta*p
00449     }
00450     timer->stop(); 
00451     ff.unfix();
00452     if (verbose) {
00453       *out << "Leaving recursiveFPCGUnfused<" << Teuchos::TypeNameTraits<T>::name() << "> after " << k << " iterations." << std::endl;
00454     }
00455   }
00456 
00457 } // end of namespace TpetraExamples
00458 
00459 #endif // MULTIPRECCG_HPP_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines