Tpetra Matrix/Vector Services Version of the Day
MultiPrecCG.hpp
00001 /*
00002 // @HEADER
00003 // ***********************************************************************
00004 // 
00005 //          Tpetra: Templated Linear Algebra Services Package
00006 //                 Copyright (2008) Sandia Corporation
00007 // 
00008 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00009 // the U.S. Government retains certain rights in this software.
00010 // 
00011 // Redistribution and use in source and binary forms, with or without
00012 // modification, are permitted provided that the following conditions are
00013 // met:
00014 //
00015 // 1. Redistributions of source code must retain the above copyright
00016 // notice, this list of conditions and the following disclaimer.
00017 //
00018 // 2. Redistributions in binary form must reproduce the above copyright
00019 // notice, this list of conditions and the following disclaimer in the
00020 // documentation and/or other materials provided with the distribution.
00021 //
00022 // 3. Neither the name of the Corporation nor the names of the
00023 // contributors may be used to endorse or promote products derived from
00024 // this software without specific prior written permission.
00025 //
00026 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00027 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00028 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00029 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00030 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00031 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00032 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00033 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00034 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00035 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00036 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00037 //
00038 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00039 // 
00040 // ************************************************************************
00041 // @HEADER
00042 */
00043 
00044 #ifndef MULTIPRECCG_HPP_
00045 #define MULTIPRECCG_HPP_
00046 
00047 #include <Teuchos_TimeMonitor.hpp>
00048 #include <Teuchos_TypeNameTraits.hpp>
00049 #include <Teuchos_ParameterList.hpp>
00050 #include <Teuchos_XMLParameterListHelpers.hpp>
00051 #include <Teuchos_FancyOStream.hpp>
00052 
00053 #include <Tpetra_CrsMatrix.hpp>
00054 #include <Tpetra_Vector.hpp>
00055 #include <Tpetra_RTI.hpp>
00056 #include <Tpetra_MatrixIO.hpp>
00057 
00058 #include <iostream>
00059 #include <functional>
00060 
00061 #ifdef HAVE_TPETRA_QD
00062 # include <qd/qd_real.h>
00063 #endif
00064 
00065 namespace Tpetra {
00066   namespace RTI {
00067     // specialization for pair
00068     template <class T1, class T2>
00069     class ZeroOp<std::pair<T1,T2>> {
00070       public:
00071       static inline std::pair<T1,T2> identity() {
00072         return std::make_pair( Teuchos::ScalarTraits<T1>::zero(), 
00073                                Teuchos::ScalarTraits<T2>::zero() );
00074       }
00075     };
00076   }
00077 }
00078 
00079 namespace TpetraExamples {
00080 
00081   using Teuchos::RCP;
00082   using Teuchos::ParameterList;
00083   using Teuchos::Time;
00084   using Teuchos::null;
00085   using std::binary_function;
00086   using std::pair;
00087   using std::make_pair;
00088   using std::plus;
00089   using std::multiplies;
00090 
/// No-op FPU policy: fix() and unfix() do nothing. This is the default
/// policy for every scalar type that has no fpu_fix specialization.
struct trivial_fpu_fix {
  void fix() {}
  void unfix() {}
};

#ifdef HAVE_TPETRA_QD
/// FPU policy for the QD scalar types: fix() passes &old_cw to
/// fpu_fix_start() and unfix() passes it to fpu_fix_end().
///
/// This struct is only defined when QD is enabled. It is not a template, so
/// the calls to fpu_fix_start()/fpu_fix_end() are looked up immediately;
/// those functions are only available once <qd/qd_real.h> has been included,
/// and defining the struct unconditionally breaks builds without QD.
struct nontrivial_fpu_fix {
  unsigned int old_cw;  // storage handed to fpu_fix_start()/fpu_fix_end()
  void fix()   { fpu_fix_start(&old_cw); }
  void unfix() { fpu_fix_end(&old_cw); }
};
#endif

/// Selects the FPU policy for scalar type T: trivial unless specialized.
template <class T> struct fpu_fix : trivial_fpu_fix {};
#ifdef HAVE_TPETRA_QD
template <> struct fpu_fix<qd_real> : nontrivial_fpu_fix {};
template <> struct fpu_fix<dd_real> : nontrivial_fpu_fix {};
#endif
00106 
00108   template <class Tout, class Tin, class LO, class GO, class Node> 
00109   struct convertHelp {
00110     static RCP<const Tpetra::CrsMatrix<Tout,LO,GO,Node>> doit(const RCP<const Tpetra::CrsMatrix<Tin,LO,GO,Node>> &A)
00111     {
00112       return A->template convert<Tout>();
00113     }
00114   };
00115 
00117   template <class T, class LO, class GO, class Node> 
00118   struct convertHelp<T,T,LO,GO,Node> {
00119     static RCP<const Tpetra::CrsMatrix<T,LO,GO,Node>> doit(const RCP<const Tpetra::CrsMatrix<T,LO,GO,Node>> &A)
00120     {
00121       return A;
00122     }
00123   };
00124 
00125 
/// Binary functor on std::pair<T1,T2> that applies the wrapped operator Op
/// component-wise: the result's first is op(a.first, b.first) and its
/// second is op(a.second, b.second).
///
/// The nested argument/result typedefs are declared directly instead of
/// being inherited from std::binary_function, which was deprecated in C++11
/// and removed in C++17.
template <class T1, class T2, class Op>
class pair_op {
public:
  typedef std::pair<T1, T2> first_argument_type;
  typedef std::pair<T1, T2> second_argument_type;
  typedef std::pair<T1, T2> result_type;

  /// Stores a copy of the component operator.
  pair_op(Op op) : op_(op) {}

  /// Combines two pairs component-wise with the stored operator.
  inline std::pair<T1, T2> operator()(const std::pair<T1, T2>& a,
                                      const std::pair<T1, T2>& b) const {
    return std::pair<T1, T2>(op_(a.first, b.first), op_(a.second, b.second));
  }

private:
  Op op_;
};

/// Convenience factory: T1 and T2 are given explicitly, Op is deduced from
/// the argument.
template <class T1, class T2, class Op>
pair_op<T1, T2, Op> make_pair_op(Op op) { return pair_op<T1, T2, Op>(op); }
00143 
00145   template <class S, class LO, class GO, class Node>
00146   class RFPCGInit 
00147   {
00148     private:
00149     typedef Tpetra::Map<LO,GO,Node>               Map;
00150     typedef Tpetra::CrsMatrix<S,LO,GO,Node> CrsMatrix;
00151     RCP<const CrsMatrix> A;
00152 
00153     public:
00154 
00155     RFPCGInit(RCP<Tpetra::CrsMatrix<S,LO,GO,Node>> Atop) : A(Atop) {}
00156 
00157     template <class T>
00158     RCP<ParameterList> initDB(ParameterList &params) 
00159     {
00160       fpu_fix<T> ff;
00161       ff.fix();
00162       typedef Tpetra::Vector<T,LO,GO,Node>    Vector;
00163       typedef Tpetra::Operator<T,LO,GO,Node>      Op;
00164       typedef Tpetra::CrsMatrix<T,LO,GO,Node>    Mat;
00165       RCP<const Map> map = A->getDomainMap();
00166       RCP<ParameterList> db = Teuchos::parameterList();
00167       RCP<const Mat> AT = convertHelp<T,S,LO,GO,Node>::doit(A);
00168       //
00169       db->set<RCP<const Op>>("A", AT                    );
00170       db->set("numIters", params.get<int>("numIters",A->getGlobalNumRows()) );
00171       db->set("tolerance",params.get<double>("tolerance",1e-7));
00172       db->set("verbose",  params.get<int>("verbose",0) );
00173       db->set("bx",       Tpetra::createVector<T>(map)  );
00174       db->set("r",        Tpetra::createVector<T>(map)  );
00175       db->set("z",        Tpetra::createVector<T>(map)  );
00176       db->set("p",        Tpetra::createVector<T>(map)  );
00177       db->set("Ap",       Tpetra::createVector<T>(map)  );
00178       db->set("rold",     Tpetra::createVector<T>(map)  );
00179       if (params.get<bool>("Extract Diagonal",false)) {
00180         RCP<Vector> diag = Tpetra::createVector<T>(map);
00181         AT->getLocalDiagCopy(*diag);
00182         db->set("diag", diag);
00183       }
00184       ff.unfix();
00185       return db;
00186     }
00187   };
00188 
00189   /******************************
00190   *   Somewhat flexible CG
00191   *   Golub and Ye, 1999
00192   *
00193   *   r = b
00194   *   z = M*r
00195   *   p = z
00196   *   do
00197   *     alpha = (r'*z) / (p'*A*p)
00198   *     x = x + alpha*p
00199   *     r = r - alpha*A*p
00200   *     check |r|/|r_0| against this level's tolerance for convergence
00201   *     z = M*r
00202   *     beta = z'*(r_new - r_old) / (z_old'*r_old)
00203   *     p = z + beta*p
00204   *   enddo
00205   ******************************/
00206 
00208   template <class TS, class LO, class GO, class Node>      
        // Flexible CG solve at the precision level described by TS, written with
        // the fused TPETRA_* kernel macros.
        //
        // out : stream for the optional progress messages.
        // db  : level database. Supplies "numIters", the work vectors "bx", "r",
        //       "z", "p", "Ap", "rold", the operator "A", and optionally
        //       "tolerance" and "verbose". The right-hand side arrives in "bx"
        //       and is overwritten with the solution. When TS::bottom is true,
        //       "diag" is read and used as the preconditioner (z = D\r);
        //       otherwise the preconditioner step recurses into
        //       db.sublist("child") using TS::next.
00209   void recursiveFPCG(const RCP<Teuchos::FancyOStream> &out, ParameterList &db)
00210   {
00211     using Teuchos::as;
00212     using Tpetra::RTI::ZeroOp;
00213     typedef typename TS::type       T;
00214     typedef typename TS::next::type T2;
00215     typedef Tpetra::Vector<T ,LO,GO,Node> VectorT1;
00216     typedef Tpetra::Vector<T2,LO,GO,Node> VectorT2;
00217     typedef Tpetra::Operator<T,LO,GO,Node>    OpT1;
00218     typedef Teuchos::ScalarTraits<T>            ST;
00219     // get objects from level database
00220     const int numIters = db.get<int>("numIters");
00221     auto x     = db.get<RCP<VectorT1>>("bx");
00222     auto r     = db.get<RCP<VectorT1>>("r");
00223     auto z     = db.get<RCP<VectorT1>>("z");
00224     auto p     = db.get<RCP<VectorT1>>("p");
00225     auto Ap    = db.get<RCP<VectorT1>>("Ap");
00226     auto rold  = db.get<RCP<VectorT1>>("rold");
00227     auto A     = db.get<RCP<const OpT1>>("A");
          // the diagonal is only looked up at the bottom level, the only place it is used
00228     RCP<const VectorT1> diag;
00229     if (TS::bottom) {
00230       diag = db.get<RCP<VectorT1>>("diag");
00231     }
          // "tolerance" is read as a double (0.0 if absent) and converted to T
00232     const T tolerance = db.get<double>("tolerance", 0.0);
00233     const int verbose = db.get<int>("verbose",0);
          // function-local statics: the timers are created on the first call of each
          // template instantiation and reused on later calls
00234     static RCP<Time> timer, Atimer;
00235     if (timer == null) {
00236       timer = Teuchos::TimeMonitor::getNewTimer(
00237                       "recursiveFPCG<"+Teuchos::TypeNameTraits<T>::name()+">"
00238               );
00239     }
00240     if (Atimer == null) {
00241       Atimer = Teuchos::TimeMonitor::getNewTimer(
00242                       "A<"+Teuchos::TypeNameTraits<T>::name()+">"
00243               );
00244     }
00245 
          // apply the FPU policy for T (a no-op unless fpu_fix<T> is specialized)
00246     fpu_fix<T> ff;
00247     ff.fix();
00248 
          // NOTE(review): presumably indents output on `out` for this scope, so nested
          // levels print nested -- confirm against Teuchos::OSTab
00249     Teuchos::OSTab tab(out);
00250 
00251     if (verbose) {
00252       *out << "Beginning recursiveFPCG<" << Teuchos::TypeNameTraits<T>::name() << ">" << std::endl;
00253     }
00254 
00255     timer->start(); 
          // r2 is the squared norm of the initial residual; the convergence test below
          // is relative to it
00256     const T r2 = TPETRA_BINARY_PRETRANSFORM_REDUCE(
00257                     r, x,                                         // fused: 
00258                     x,                                            //      : r = x  
00259                     r*r, ZeroOp<T>, plus<T>() );                  //      : sum r'*r
00260     // b comes in, x goes out. now we're done with b, so zero the solution.
00261     TPETRA_UNARY_TRANSFORM( x,  ST::zero() );                     // x = 0
          // initial preconditioner application: z = M*r
00262     if (TS::bottom) {
00263       TPETRA_TERTIARY_TRANSFORM( z, diag, r,    r/diag );         // z = D\r
00264     }
00265     else {
            // undo this level's FPU fix and pause its timer while the child level runs;
            // both are restored after the child returns
00266       ff.unfix();
00267       timer->stop(); 
00268       ParameterList &db_T2 = db.sublist("child");
00269       auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00270       TPETRA_BINARY_TRANSFORM( bx_T2, r, as<T2>(r) );             // b_T2 = (T2)r
00271       recursiveFPCG<typename TS::next,LO,GO,Node>(out,db_T2);     // x_T2 = A_T2 \ b_T2 
00272       TPETRA_BINARY_TRANSFORM( z, bx_T2, as<T>(bx_T2) );          // z    = (T)bx_T2
00273       timer->start(); 
00274       ff.fix();
00275     }
          // zr is loop-carried: it is the numerator of alpha and becomes the
          // denominator of beta (zoro) one iteration later
00276     T zr = TPETRA_TERTIARY_PRETRANSFORM_REDUCE( 
00277                     p, z, r,                                      // fused: 
00278                     z,                                            //      : p = z
00279                     z*r, ZeroOp<T>, plus<T>() );                  //      : z*r
          // k is declared outside the loop so it can be reported after the loop ends
00281     int k;
00282     for (k=0; k<numIters; ++k) 
00283     {
            // the operator application is timed separately by Atimer
00284       Atimer->start();
00285       A->apply(*p,*Ap);                                           // Ap = A*p
00286       Atimer->stop();
00287       T pAp = TPETRA_REDUCE2( p, Ap,     
00288                               p*Ap, ZeroOp<T>, plus<T>() );       // p'*Ap
00289       const T alpha = zr / pAp;
00290       TPETRA_BINARY_TRANSFORM( x,    p,  x + alpha*p  );          // x = x + alpha*p
            // keep the previous residual; it is needed for beta below
00291       TPETRA_BINARY_TRANSFORM( rold, r,  r            );          // rold = r
00292       T rr = TPETRA_BINARY_PRETRANSFORM_REDUCE(
00293                                r, Ap,                             // fused:
00294                                r - alpha*Ap,                      //      : r - alpha*Ap
00295                                r*r, ZeroOp<T>, plus<T>() );       //      : sum r'*r
00296       if (verbose > 1) *out << "|res|/|res_0|: " << ST::squareroot(rr/r2) 
00297                             << std::endl;
            // convergence test on squared quantities: rr/r2 against tolerance^2
00298       if (rr/r2 < tolerance*tolerance) {
00299         if (verbose) {
00300           *out << "Convergence detected!" << std::endl;
00301         }
00302         break;
00303       }
            // preconditioner application for the new residual: z = M*r
00304       if (TS::bottom) {
00305         TPETRA_TERTIARY_TRANSFORM( z, diag, r,    r/diag );        // z = D\r
00306       }
00307       else {
              // same bracketing as before the loop: FPU fix undone and timer paused
              // for the duration of the child solve
00308         ff.unfix();
00309         timer->stop(); 
00310         ParameterList &db_T2 = db.sublist("child");
00311         auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00312         TPETRA_BINARY_TRANSFORM( bx_T2, r,    as<T2>(r) );         // b_T2 = (T2)r
00313         recursiveFPCG<typename TS::next,LO,GO,Node>(out,db_T2);    // x_T2 = A_T2 \ b_T2
00314         TPETRA_BINARY_TRANSFORM( z, bx_T2,    as<T>(bx_T2) );      // z    = (T)bx_T2
00315         timer->start(); 
00316         ff.fix();
00317       }
            // zoro: z'*r from the previous iteration, saved before zr is overwritten
00318       const T zoro = zr;                                                         
            // a single reduction over pairs yields both inner products at once
00319       typedef ZeroOp<pair<T,T>> ZeroPTT;
00320       auto plusTT = make_pair_op<T,T>(plus<T>());
00321       pair<T,T> both = TPETRA_REDUCE3( z, r, rold,                 // fused: z'*r and z'*r_old
00322                                        make_pair(z*r, z*rold), ZeroPTT, plusTT );
00323       zr = both.first; // this is used on the next iteration as well
00324       const T znro = both.second;
            // beta = z'*(r - rold) / (previous z'*r)
00325       const T beta = (zr - znro) / zoro;
00326       TPETRA_BINARY_TRANSFORM( p, z,   z + beta*p );               // p = z + beta*p
00327     }
00328     timer->stop(); 
00329     ff.unfix();
          // after a break, k is the zero-based index of the iteration that converged
00330     if (verbose) {
00331       *out << "Leaving recursiveFPCG<" << Teuchos::TypeNameTraits<T>::name() 
00332            << "> after " << k << " iterations." << std::endl;
00333     }
00334   }
00335 
00337   template <class TS, class LO, class GO, class Node>      
        // Flexible CG solve at the precision level described by TS, written with
        // separate Tpetra::RTI transform/reduce calls and lambdas; each vector
        // update and each inner product is its own call.
        //
        // out : stream for the optional progress messages.
        // db  : level database. Supplies "numIters", the work vectors "bx", "r",
        //       "z", "p", "Ap", "rold", the operator "A", and optionally
        //       "tolerance" and "verbose". The right-hand side arrives in "bx"
        //       and is overwritten with the solution. When TS::bottom is true,
        //       "diag" is read and used as the preconditioner (z = D\r);
        //       otherwise the preconditioner step recurses into
        //       db.sublist("child") using TS::next.
00338   void recursiveFPCGUnfused(const RCP<Teuchos::FancyOStream> &out, ParameterList &db)
00339   {
00340     using Tpetra::RTI::unary_transform;
00341     using Tpetra::RTI::binary_transform;
00342     using Tpetra::RTI::tertiary_transform;
00343     using Tpetra::RTI::reduce;
00344     using Tpetra::RTI::reductionGlob;
00345     using Tpetra::RTI::ZeroOp;
00346     using Teuchos::as;
00347     typedef typename TS::type       T;
00348     typedef typename TS::next::type T2;
00349     typedef Tpetra::Vector<T ,LO,GO,Node> VectorT1;
00350     typedef Tpetra::Vector<T2,LO,GO,Node> VectorT2;
00351     typedef Tpetra::Operator<T,LO,GO,Node>    OpT1;
00352     typedef Teuchos::ScalarTraits<T>            ST;
00353     // get objects from level database
00354     const int numIters = db.get<int>("numIters");
00355     auto x     = db.get<RCP<VectorT1>>("bx");
00356     auto r     = db.get<RCP<VectorT1>>("r");
00357     auto z     = db.get<RCP<VectorT1>>("z");
00358     auto p     = db.get<RCP<VectorT1>>("p");
00359     auto Ap    = db.get<RCP<VectorT1>>("Ap");
00360     auto rold  = db.get<RCP<VectorT1>>("rold");
00361     auto A     = db.get<RCP<const OpT1>>("A");
          // function-local static: the timer is created on the first call of each
          // template instantiation and reused on later calls
00362     static RCP<Time> timer;
00363     if (timer == null) {
00364       timer = Teuchos::TimeMonitor::getNewTimer(
00365                       "recursiveFPCGUnfused<"+Teuchos::TypeNameTraits<T>::name()+">"
00366               );
00367     }
          // the diagonal is only looked up at the bottom level, the only place it is used
00368     RCP<const VectorT1> diag;
00369     if (TS::bottom) {
00370       diag = db.get<RCP<VectorT1>>("diag");
00371     }
          // "tolerance" is read as a double (0.0 if absent) and converted to T
00372     const T tolerance = db.get<double>("tolerance", 0.0);
00373     const int verbose = db.get<int>("verbose",0);
00374 
          // apply the FPU policy for T (a no-op unless fpu_fix<T> is specialized)
00375     fpu_fix<T> ff;
00376     ff.fix();
00377 
          // NOTE(review): presumably indents output on `out` for this scope, so nested
          // levels print nested -- confirm against Teuchos::OSTab
00378     Teuchos::OSTab tab(out);
00379 
00380     if (verbose) {
00381       *out << "Beginning recursiveFPCGUnfused<" << Teuchos::TypeNameTraits<T>::name() << ">" << std::endl;
00382     }
00383 
00384     timer->start(); 
          // in the transform lambdas below, the first parameter is the current entry of
          // the vector being written; it is left unnamed where the old value is not needed
00385     binary_transform( *r, *x,            [](T, T bi)             {return bi;});  // r = b     (b is stored in x)
          // r2 is the squared norm of the initial residual; the convergence test below
          // is relative to it
00386     const T r2 = reduce( *r, *r,      reductionGlob<ZeroOp<T>>(multiplies<T>(),  // r'*r
00387                                                                    plus<T>())); 
00388     unary_transform(  *x,                [](T)           {return ST::zero();});  // set x = 0 (now that we don't need b)
          // initial preconditioner application: z = M*r
00389     if (TS::bottom) {
00390       tertiary_transform( *z, *diag, *r, [](T, T di, T ri)    {return ri/di;});  // z = D\r
00391     }
00392     else {
            // undo this level's FPU fix and pause its timer while the child level runs;
            // both are restored after the child returns
00393       ff.unfix();
00394       timer->stop(); 
00395       ParameterList &db_T2 = db.sublist("child");
00396       auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00397       binary_transform( *bx_T2, *r, [](T2, T ri)         {return as<T2>(ri);});  // b_T2 = (T2)r       
00398       recursiveFPCGUnfused<typename TS::next,LO,GO,Node>(out,db_T2);             // x_T2 = A_T2 \ b_T2 
00399       binary_transform( *z, *bx_T2, [](T, T2 xi)          {return as<T>(xi);});  // z    = (T)x_T2     
00400       timer->start(); 
00401       ff.fix();
00402     }
00403     binary_transform( *p, *z, [](T, T zi)                        {return zi;});  // p = z
          // zr is loop-carried: it is the numerator of alpha and becomes the
          // denominator of beta (zoro) one iteration later
00404     T zr = reduce( *z, *r,          reductionGlob<ZeroOp<T>>(multiplies<T>(),    // z'*r
00405                                                                    plus<T>())); 
          // k is declared outside the loop so it can be reported after the loop ends
00407     int k;
00408     for (k=0; k<numIters; ++k) 
00409     {
00410       A->apply(*p,*Ap);                                                          // Ap = A*p
00411       T pAp = reduce( *p, *Ap,      reductionGlob<ZeroOp<T>>(multiplies<T>(),    // p'*Ap
00412                                                                    plus<T>())); 
00413       const T alpha = zr / pAp;
00414       binary_transform( *x, *p, [alpha](T xi, T pi)   {return xi + alpha*pi;});  // x = x + alpha*p
            // keep the previous residual; it is needed for beta below
00415       binary_transform( *rold, *r, [](T, T ri)                   {return ri;});  // rold = r
00416       binary_transform( *r, *Ap,[alpha](T ri, T Api) {return ri - alpha*Api;});  // r = r - alpha*Ap
00417       T rr = reduce( *r, *r,      reductionGlob<ZeroOp<T>>(multiplies<T>(),      // r'*r
00418                                                                  plus<T>())); 
00419       if (verbose > 1) *out << "|res|/|res_0|: " << ST::squareroot(rr/r2) << std::endl;
            // convergence test on squared quantities: rr/r2 against tolerance^2
00420       if (rr/r2 < tolerance*tolerance) {
00421         if (verbose) {
00422           *out << "Convergence detected!" << std::endl;
00423         }
00424         break;
00425       }
            // preconditioner application for the new residual: z = M*r
00426       if (TS::bottom) {
00427         tertiary_transform( *z, *diag, *r, [](T, T di, T ri)  {return ri/di;});  // z = D\r
00428       }
00429       else {
              // same bracketing as before the loop: FPU fix undone and timer paused
              // for the duration of the child solve
00430         ff.unfix();
00431         timer->stop(); 
00432         ParameterList &db_T2 = db.sublist("child");
00433         auto bx_T2 = db_T2.get< RCP<VectorT2>>("bx");
00434         binary_transform( *bx_T2, *r, [](T2, T ri)       {return as<T2>(ri);});  // b_T2 = (T2)r
00435         recursiveFPCGUnfused<typename TS::next,LO,GO,Node>(out,db_T2);           // x_T2 = A_T2 \ b_T2
00436         binary_transform( *z, *bx_T2, [](T, T2 xi)        {return as<T>(xi);});  // z    = (T)x_T2
00437         timer->start(); 
00438         ff.fix();
00439       }
            // zoro: z'*r from the previous iteration, saved before zr is overwritten
00440       const T zoro = zr;                                                         
00441       zr = reduce( *z, *r,          reductionGlob<ZeroOp<T>>(multiplies<T>(),    // z'*r
00442                                                              plus<T>()));        // this is loop-carried
00443       const T znro = reduce( *z, *rold, reductionGlob<ZeroOp<T>>(multiplies<T>(),// z'*r_old
00444                                                                  plus<T>())); 
            // beta = z'*(r - rold) / (previous z'*r)
00445       const T beta = (zr - znro) / zoro;
00446       binary_transform( *p, *z, [beta](T pi, T zi)     {return zi + beta*pi;});  // p = z + beta*p
00447     }
00448     timer->stop(); 
00449     ff.unfix();
          // after a break, k is the zero-based index of the iteration that converged
00450     if (verbose) {
00451       *out << "Leaving recursiveFPCGUnfused<" << Teuchos::TypeNameTraits<T>::name() << "> after " << k << " iterations." << std::endl;
00452     }
00453   }
00454 
00455 } // end of namespace TpetraExamples
00456 
00457 #endif // MULTIPRECCG_HPP_
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines