BelosLinearProblem.hpp

Go to the documentation of this file.
00001 
00002 // @HEADER
00003 // ***********************************************************************
00004 //
00005 //                 Belos: Block Linear Solvers Package
00006 //                 Copyright (2004) Sandia Corporation
00007 //
00008 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00009 // license for use of this work by or on behalf of the U.S. Government.
00010 //
00011 // This library is free software; you can redistribute it and/or modify
00012 // it under the terms of the GNU Lesser General Public License as
00013 // published by the Free Software Foundation; either version 2.1 of the
00014 // License, or (at your option) any later version.
00015 //
00016 // This library is distributed in the hope that it will be useful, but
00017 // WITHOUT ANY WARRANTY; without even the implied warranty of
00018 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00019 // Lesser General Public License for more details.
00020 //
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License along with this library; if not, write to the Free Software
00023 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00024 // USA
00025 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00026 //
00027 // ***********************************************************************
00028 // @HEADER
00029 
00030 #ifndef BELOS_LINEAR_PROBLEM_HPP
00031 #define BELOS_LINEAR_PROBLEM_HPP
00032 
00037 #include "BelosMultiVecTraits.hpp"
00038 #include "BelosOperatorTraits.hpp"
00039 #include "Teuchos_ParameterList.hpp"
00040 
00041 using Teuchos::RefCountPtr;
00042 using Teuchos::rcp;
00043 using Teuchos::null;
00044 using Teuchos::rcp_const_cast;
00045 using Teuchos::ParameterList;
00046 
00054 namespace Belos {
00055   
00056   template <class ScalarType, class MV, class OP>
00057   class LinearProblem {
00058    
00059   public:
00060     
00062 
00063 
00064 
00068     LinearProblem(void);
00069     
00071 
00076     LinearProblem(const RefCountPtr<const OP> &A, 
00077       const RefCountPtr<MV> &X, 
00078       const RefCountPtr<const MV> &B
00079       );
00080     
00082 
00084     LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem);
00085     
00087 
00089     virtual ~LinearProblem(void);
00091     
00093 
00094     
00096 
00098     void SetOperator(const RefCountPtr<const OP> &A) { A_ = A; };
00099     
00101 
00103     void SetLHS(const RefCountPtr<MV> &X);
00104     
00106 
00108     void SetRHS(const RefCountPtr<const MV> &B) { B_ = B; };
00109     
00111 
00113     void SetLeftPrec(const RefCountPtr<const OP> &LP) {  LP_ = LP; Left_Prec_ = true; };
00114     
00116 
00118     void SetRightPrec(const RefCountPtr<const OP> &RP) { RP_ = RP; Right_Prec_ = true; };
00119     
00121     void SetParameterList(const RefCountPtr<ParameterList> &PL) { PL_ = PL; };
00122 
00124     void SetBlockSize(int blocksize) { default_blocksize_ = blocksize; blocksize_ = blocksize; };
00125     
00127 
00132     void SetCurrLSVec();
00133     
00135 
00139     void AssertSymmetric(){ operatorSymmetric_ = true; };
00140     
00142 
00146     void SolutionUpdated( const MV* SolnUpdate = 0,
00147                           ScalarType scale = Teuchos::ScalarTraits<ScalarType>::one() );
00148     
00150     
00152 
00153     
00155 
00158     void Reset( const RefCountPtr<MV> &newX = null, const RefCountPtr<const MV> &newB = null );
00160     
00162 
00163     
00165     RefCountPtr<const OP> GetOperator() const { return(A_); };
00166     
00168     RefCountPtr<MV> GetLHS() const { return(X_); };
00169     
00171     RefCountPtr<const MV> GetRHS() const { return(B_); };
00172     
00174 
00176     const MV& GetInitResVec();
00177     
00179 
00195     const MV& GetCurrResVec( const MV* CurrSoln = 0 );
00196     
00198 
00205     RefCountPtr<MV> GetCurrLHSVec();
00206     
00208 
00215     RefCountPtr<MV> GetCurrRHSVec();
00216     
00218     RefCountPtr<const OP> GetLeftPrec() const { return(LP_); };
00219     
00221     RefCountPtr<const OP> GetRightPrec() const { return(RP_); };
00222     
00224     RefCountPtr<ParameterList> GetParameterList() const { return(PL_); };
00225 
00227     int GetBlockSize() const { return( default_blocksize_ ); };
00228     
00230 
00234     int GetCurrBlockSize() const { return( blocksize_ ); };
00235 
00237 
00243     int GetNumToSolve() const { return( num_to_solve_ ); };
00244     
00246 
00253     int GetRHSIndex() const { return( rhs_index_ ); };
00254     
00256 
00260     bool IsSolutionUpdated() const { return(solutionUpdated_); };
00261     
00263     bool IsOperatorSymmetric() const { return(operatorSymmetric_); };
00264     
00266     
00268 
00269     
00271 
00278     ReturnType Apply( const MV& x, MV& y );
00279     
00281 
00288     ReturnType ApplyOp( const MV& x, MV& y );
00289     
00291 
00295     ReturnType ApplyLeftPrec( const MV& x, MV& y );
00296     
00298 
00302     ReturnType ApplyRightPrec( const MV& x, MV& y );
00303     
00305 
00309     ReturnType ComputeResVec( MV* R, const MV* X = 0, const MV* B = 0 );
00310     
00312     
00313   private:
00314     
00316     void SetUpBlocks();
00317     
00319     RefCountPtr<const OP> A_;
00320     
00322     RefCountPtr<MV> X_;
00323     
00325     RefCountPtr<MV> CurX_;
00326     
00328     RefCountPtr<const MV> B_;
00329     
00331     RefCountPtr<MV> CurB_;
00332     
00334     RefCountPtr<MV> R_;
00335     
00337     RefCountPtr<MV> R0_;
00338     
00340     RefCountPtr<const OP> LP_;  
00341     
00343     RefCountPtr<const OP> RP_;
00344     
00346     RefCountPtr<ParameterList> PL_;
00347 
00349     int default_blocksize_;
00350     
00352     int blocksize_;
00353 
00355     int num_to_solve_;
00356     
00358     int rhs_index_;
00359     
00361     bool Left_Prec_;
00362     bool Right_Prec_;
00363     bool Left_Scale_;
00364     bool Right_Scale_;
00365     bool operatorSymmetric_;
00366     bool solutionUpdated_;    
00367     bool solutionFinal_;
00368     bool initresidsComputed_;
00369     
00370     typedef MultiVecTraits<ScalarType,MV>  MVT;
00371     typedef OperatorTraits<ScalarType,MV,OP>  OPT;
00372   };
00373   
00374   //--------------------------------------------
00375   //  Constructor Implementations
00376   //--------------------------------------------
00377   
00378   template <class ScalarType, class MV, class OP>
00379   LinearProblem<ScalarType,MV,OP>::LinearProblem(void) : 
00380     default_blocksize_(1),
00381     blocksize_(1),
00382     num_to_solve_(0),
00383     rhs_index_(0),  
00384     Left_Prec_(false),
00385     Right_Prec_(false),
00386     Left_Scale_(false),
00387     Right_Scale_(false),
00388     operatorSymmetric_(false),
00389     solutionUpdated_(false),
00390     solutionFinal_(true),
00391     initresidsComputed_(false)
00392   {
00393   }
00394   
00395   template <class ScalarType, class MV, class OP>
00396   LinearProblem<ScalarType,MV,OP>::LinearProblem(const RefCountPtr<const OP> &A, 
00397              const RefCountPtr<MV> &X, 
00398              const RefCountPtr<const MV> &B
00399              ) :
00400     A_(A),
00401     X_(X),
00402     B_(B),
00403     default_blocksize_(1),
00404     blocksize_(1),
00405     num_to_solve_(1),
00406     rhs_index_(0),
00407     Left_Prec_(false),
00408     Right_Prec_(false),
00409     Left_Scale_(false),
00410     Right_Scale_(false),
00411     operatorSymmetric_(false),
00412     solutionUpdated_(false),
00413     solutionFinal_(true),
00414     initresidsComputed_(false)
00415   {
00416     R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00417   }
00418   
00419   template <class ScalarType, class MV, class OP>
00420   LinearProblem<ScalarType,MV,OP>::LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem) :
00421     A_(Problem.A_),
00422     X_(Problem.X_),
00423     CurX_(Problem.CurX_),
00424     B_(Problem.B_),
00425     CurB_(Problem.CurB_),
00426     R_(Problem.R_),
00427     R0_(Problem.R0_),
00428     LP_(Problem.LP_),
00429     RP_(Problem.RP_),
00430     PL_(Problem.PL_),
00431     default_blocksize_(Problem.default_blocksize_),
00432     blocksize_(Problem.blocksize_),
00433     num_to_solve_(Problem.num_to_solve_),
00434     rhs_index_(Problem.rhs_index_),
00435     Left_Prec_(Problem.Left_Prec_),
00436     Right_Prec_(Problem.Right_Prec_),
00437     Left_Scale_(Problem.Left_Scale_),
00438     Right_Scale_(Problem.Right_Scale_),
00439     operatorSymmetric_(Problem.operatorSymmetric_),
00440     solutionUpdated_(Problem.solutionUpdated_),
00441     solutionFinal_(Problem.solutionFinal_),
00442     initresidsComputed_(Problem.initresidsComputed_)
00443   {
00444   }
00445   
00446   template <class ScalarType, class MV, class OP>
00447   LinearProblem<ScalarType,MV,OP>::~LinearProblem(void)
00448   {}
00449   
00450   template <class ScalarType, class MV, class OP>
00451   void LinearProblem<ScalarType,MV,OP>::SetUpBlocks()
00452   {
00453     // Compute the new block linear system.
00454     // ( first clean up old linear system )
00455     if (CurB_.get()) CurB_ = null;
00456     if (CurX_.get()) CurX_ = null;
00457     if (R_.get()) R_ = null;
00458     //
00459     // Determine how many linear systems are left to solve for and populate LHS and RHS vector.
00460     // If the number of linear systems left are less than the current blocksize, then
00461     // we create a multivector and copy the left over LHS and RHS vectors into them.
00462     // The rest of the multivector is populated with random vectors (RHS) or zero vectors (LHS).
00463     //
00464     num_to_solve_ = MVT::GetNumberVecs(*X_) - rhs_index_;
00465     //
00466     // Return the NULL pointer if we don't have any more systems to solve for.
00467     if ( num_to_solve_ <= 0 ) { return; }  
00468     //
00469     int i;
00470     std::vector<int> index( num_to_solve_ );
00471     for ( i=0; i<num_to_solve_; i++ ) { index[i] = rhs_index_ + i; }
00472     //
00473 /*    if ( num_to_solve_ < default_blocksize_ )
00474       blocksize_ = num_to_solve_;
00475     else
00476       blocksize_ = default_blocksize_;
00477 */
00478     //
00479     if ( num_to_solve_ < blocksize_ ) 
00480       {
00481   std::vector<int> index2(num_to_solve_);
00482   for (i=0; i<num_to_solve_; i++) {
00483     index2[i] = i;
00484   }
00485   //
00486   // First create multivectors of blocksize and fill the RHS with random vectors LHS with zero vectors.
00487   CurX_ = MVT::Clone( *X_, blocksize_ );
00488   MVT::MvInit(*CurX_);
00489   CurB_ = MVT::Clone( *B_, blocksize_ );
00490   MVT::MvRandom(*CurB_);
00491   R_ = MVT::Clone( *X_, blocksize_);
00492   //
00493   RefCountPtr<const MV> tptr = MVT::CloneView( *B_, index );
00494   MVT::SetBlock( *tptr, index2, *CurB_ );
00495   //
00496   RefCountPtr<MV> tptr2 = MVT::CloneView( *X_, index );
00497   MVT::SetBlock( *tptr2, index2, *CurX_ );
00498       } else { 
00499   //
00500   // If the number of linear systems left are more than or equal to the current blocksize, then
00501   // we create a view into the LHS and RHS.
00502   //
00503   num_to_solve_ = blocksize_;
00504   index.resize( num_to_solve_ );
00505   for ( i=0; i<num_to_solve_; i++ ) { index[i] = rhs_index_ + i; }
00506   CurX_ = MVT::CloneView( *X_, index );
00507   CurB_ = rcp_const_cast<MV>(MVT::CloneView( *B_, index ));
00508   R_ = MVT::Clone( *X_, num_to_solve_ );
00509   //
00510       }
00511     //
00512     // Compute the current residual.
00513     // 
00514     if (R_.get()) {
00515       OPT::Apply( *A_, *CurX_, *R_ );
00516       MVT::MvAddMv( 1.0, *CurB_, -1.0, *R_, *R_ );
00517       solutionUpdated_ = false;
00518     }
00519   }
00520   
00521   template <class ScalarType, class MV, class OP>
00522   void LinearProblem<ScalarType,MV,OP>::SetLHS(const RefCountPtr<MV> &X)
00523   {
00524     X_ = X; 
00525     R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) ); 
00526   }
00527   
00528   template <class ScalarType, class MV, class OP>
00529   void LinearProblem<ScalarType,MV,OP>::SetCurrLSVec() 
00530   { 
00531     int i;
00532     //
00533     // We only need to copy the solutions back if the linear systems of
00534     // interest are less than the block size.
00535     //
00536     if (num_to_solve_ < blocksize_) {
00537       //
00538       std::vector<int> index( num_to_solve_ );
00539       //
00540       RefCountPtr<MV> tptr;
00541       //
00542       // Get a view of the current solutions and correction vector.
00543       //
00544       for (i=0; i<num_to_solve_; i++) { 
00545   index[i] = i; 
00546       }
00547       tptr = MVT::CloneView( *CurX_, index );
00548       //
00549       // Copy the correction vector to the solution vector.
00550       //
00551       for (i=0; i<num_to_solve_; i++) { 
00552   index[i] = rhs_index_ + i; 
00553       }
00554       MVT::SetBlock( *tptr, index, *X_ );
00555     }
00556     //
00557     // Get the linear problem ready to determine the next linear system.
00558     //
00559     solutionFinal_ = true; 
00560     rhs_index_ += num_to_solve_; 
00561   }
00562   
00563   template <class ScalarType, class MV, class OP>
00564   void LinearProblem<ScalarType,MV,OP>::SolutionUpdated( const MV* SolnUpdate, ScalarType scale )
00565   { 
00566     if (SolnUpdate) {
00567       if (Right_Prec_) {
00568   //
00569   // Apply the right preconditioner before computing the current solution.
00570   RefCountPtr<MV> TrueUpdate = MVT::Clone( *SolnUpdate, MVT::GetNumberVecs( *SolnUpdate ) );
00571   OPT::Apply( *RP_, *SolnUpdate, *TrueUpdate ); 
00572   MVT::MvAddMv( 1.0, *CurX_, scale, *TrueUpdate, *CurX_ ); 
00573       } else {
00574   MVT::MvAddMv( 1.0, *CurX_, scale, *SolnUpdate, *CurX_ ); 
00575       }
00576     }
00577     solutionUpdated_ = true; 
00578   }
00579   
00580   template <class ScalarType, class MV, class OP>
00581   void LinearProblem<ScalarType,MV,OP>::Reset( const RefCountPtr<MV> &newX, const RefCountPtr<const MV> &newB )
00582   {
00583     solutionUpdated_ = false;
00584     solutionFinal_ = true;
00585     initresidsComputed_ = false;
00586     rhs_index_ = 0;
00587     
00588     X_ = newX;
00589     B_ = newB;
00590     GetInitResVec();
00591   }
00592   
00593   template <class ScalarType, class MV, class OP>
00594   const MV& LinearProblem<ScalarType,MV,OP>::GetInitResVec() 
00595   {
00596     // Compute the initial residual if it hasn't been computed
00597     // and all the components of the linear system are there.
00598     // The left preconditioner will be applied if it exists, resulting
00599     // in a preconditioned residual.
00600     if (!initresidsComputed_ && A_.get() && X_.get() && B_.get()) 
00601       {
00602   if (R0_.get()) R0_ = null;
00603   R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00604   OPT::Apply( *A_, *X_, *R0_ );
00605   MVT::MvAddMv( 1.0, *B_, -1.0, *R0_, *R0_ );
00606   initresidsComputed_ = true;
00607       }
00608     return (*R0_);
00609   }
00610   
00611   template <class ScalarType, class MV, class OP>
00612   const MV& LinearProblem<ScalarType,MV,OP>::GetCurrResVec( const MV* CurrSoln ) 
00613   {
00614     // Compute the residual of the current linear system.
00615     // This should be used if the solution has been updated.
00616     // Alternatively, if the current solution has been computed by GMRES
00617     // this can be passed in and the current residual will be updated using
00618     // it.
00619     //
00620     if (solutionUpdated_) 
00621       {
00622   OPT::Apply( *A_, *GetCurrLHSVec(), *R_ );
00623   MVT::MvAddMv( 1.0, *GetCurrRHSVec(), -1.0, *R_, *R_ ); 
00624   solutionUpdated_ = false;
00625       }
00626     else if (CurrSoln) 
00627       {
00628   OPT::Apply( *A_, *CurrSoln, *R_ );
00629   MVT::MvAddMv( 1.0, *GetCurrRHSVec(), -1.0, *R_, *R_ ); 
00630       }
00631     return (*R_);
00632   }
00633   
00634   template <class ScalarType, class MV, class OP>
00635   RefCountPtr<MV> LinearProblem<ScalarType,MV,OP>::GetCurrLHSVec()
00636   {
00637     if (solutionFinal_) {
00638       solutionFinal_ = false; // make sure we don't populate the current linear system again.
00639       SetUpBlocks();
00640     }
00641     return CurX_; 
00642   }
00643   
00644   template <class ScalarType, class MV, class OP>
00645   RefCountPtr<MV> LinearProblem<ScalarType,MV,OP>::GetCurrRHSVec()
00646   {
00647     if (solutionFinal_) {
00648       solutionFinal_ = false; // make sure we don't populate the current linear system again.
00649       SetUpBlocks();
00650     }
00651     return CurB_;
00652   }
00653   
00654   template <class ScalarType, class MV, class OP>
00655   ReturnType LinearProblem<ScalarType,MV,OP>::Apply( const MV& x, MV& y )
00656   {
00657     RefCountPtr<MV> ytemp = MVT::Clone( y, MVT::GetNumberVecs( y ) );
00658     //
00659     // No preconditioning.
00660     // 
00661     if (!Left_Prec_ && !Right_Prec_){ OPT::Apply( *A_, x, y );}
00662     //
00663     // Preconditioning is being done on both sides
00664     //
00665     else if( Left_Prec_ && Right_Prec_ ) 
00666       {
00667   OPT::Apply( *RP_, x, y );   
00668   OPT::Apply( *A_, y, *ytemp );
00669   OPT::Apply( *LP_, *ytemp, y );
00670       }
00671     //
00672     // Preconditioning is only being done on the left side
00673     //
00674     else if( Left_Prec_ ) 
00675       {
00676   OPT::Apply( *A_, x, *ytemp );
00677   OPT::Apply( *LP_, *ytemp, y );
00678       }
00679     //
00680     // Preconditioning is only being done on the right side
00681     //
00682     else 
00683       {
00684   OPT::Apply( *RP_, x, *ytemp );
00685   OPT::Apply( *A_, *ytemp, y );
00686       }  
00687     return Ok;
00688   }
00689   
00690   template <class ScalarType, class MV, class OP>
00691   ReturnType LinearProblem<ScalarType,MV,OP>::ApplyOp( const MV& x, MV& y )
00692   {
00693     if (A_.get())
00694       return ( OPT::Apply( *A_,x, y) );   
00695     else
00696       return Undefined;
00697   }
00698   
00699   template <class ScalarType, class MV, class OP>
00700   ReturnType LinearProblem<ScalarType,MV,OP>::ApplyLeftPrec( const MV& x, MV& y )
00701   {
00702     if (Left_Prec_)
00703       return ( OPT::Apply( *LP_,x, y) );
00704     else 
00705       return Undefined;
00706   }
00707   
00708   template <class ScalarType, class MV, class OP>
00709   ReturnType LinearProblem<ScalarType,MV,OP>::ApplyRightPrec( const MV& x, MV& y )
00710   {
00711     if (Right_Prec_)
00712       return ( OPT::Apply( *RP_,x, y) );
00713     else
00714       return Undefined;
00715   }
00716   
00717   template <class ScalarType, class MV, class OP>
00718   ReturnType LinearProblem<ScalarType,MV,OP>::ComputeResVec( MV* R, const MV* X, const MV* B )
00719   {
00720     if (X && B) // The entries are specified, so compute the residual of Op(A)X = B
00721       {
00722   if (Left_Prec_)
00723     {
00724       RefCountPtr<MV> R_temp = MVT::Clone( *X, MVT::GetNumberVecs( *X ) );
00725       OPT::Apply( *A_, *X, *R_temp );
00726       MVT::MvAddMv( -1.0, *R_temp, 1.0, *B, *R_temp );
00727       OPT::Apply( *LP_, *R_temp, *R );
00728     }
00729   else 
00730     {
00731       OPT::Apply( *A_, *X, *R );
00732       MVT::MvAddMv( -1.0, *R, 1.0, *B, *R );
00733     }
00734       }
00735     else { 
00736       // One of the entries is not specified, so just use the linear system information we have.
00737       // Later we may want to check to see which multivec is not specified, and use what is specified.
00738       if (Left_Prec_)
00739   {
00740     RefCountPtr<MV> R_temp = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00741     OPT::Apply( *A_, *X_, *R_temp );
00742     MVT::MvAddMv( -1.0, *R_temp, 1.0, *B_, *R_temp );
00743     OPT::Apply( *LP_, *R_temp, *R );
00744   }
00745       else 
00746   {
00747     OPT::Apply( *A_, *X_, *R );
00748     MVT::MvAddMv( -1.0, *R, 1.0, *B_, *R );
00749   }
00750     }    
00751     return Ok;
00752   }
00753   
00754 } // end Belos namespace
00755 
00756 #endif /* BELOS_LINEAR_PROBLEM_HPP */
00757 
00758 

Generated on Thu Sep 18 12:30:12 2008 for Belos by doxygen 1.3.9.1