00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef BELOS_LINEAR_PROBLEM_HPP
00031 #define BELOS_LINEAR_PROBLEM_HPP
00032
00037 #include "BelosMultiVecTraits.hpp"
00038 #include "BelosOperatorTraits.hpp"
00039 #include "Teuchos_ParameterList.hpp"
00040 #include "Teuchos_TimeMonitor.hpp"
00041
00042 using Teuchos::RCP;
00043 using Teuchos::rcp;
00044 using Teuchos::null;
00045 using Teuchos::rcp_const_cast;
00046 using Teuchos::ParameterList;
00047
00053 namespace Belos {
00054
00056
00057
00060 class LinearProblemError : public BelosError
00061 {public: LinearProblemError(const std::string& what_arg) : BelosError(what_arg) {}};
00062
00064
00065
00066 template <class ScalarType, class MV, class OP>
00067 class LinearProblem {
00068
00069 public:
00070
00072
00073
00074
00078 LinearProblem(void);
00079
00081
00086 LinearProblem(const RCP<const OP> &A,
00087 const RCP<MV> &X,
00088 const RCP<const MV> &B
00089 );
00090
00092
00094 LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem);
00095
00097
00099 virtual ~LinearProblem(void);
00101
00103
00104
00106
00108 void setOperator(const RCP<const OP> &A) { A_ = A; isSet_=false; }
00109
00111
00113 void setLHS(const RCP<MV> &X) { X_ = X; isSet_=false; }
00114
00116
00118 void setRHS(const RCP<const MV> &B) { B_ = B; isSet_=false; }
00119
00121
00123 void setLeftPrec(const RCP<const OP> &LP) { LP_ = LP; }
00124
00126
00128 void setRightPrec(const RCP<const OP> &RP) { RP_ = RP; }
00129
00131
00136 void setCurrLS();
00137
00139
00145 void setLSIndex(std::vector<int>& index);
00146
00148
00152 void setHermitian(){ isHermitian_ = true; }
00153
00155
00158 void setLabel(const std::string& label) { label_ = label; }
00159
00161
00166 RCP<MV> updateSolution( const RCP<MV>& update = null,
00167 bool updateLP = false,
00168 ScalarType scale = Teuchos::ScalarTraits<ScalarType>::one() );
00169
00171 RCP<MV> updateSolution( const RCP<MV>& update = null,
00172 ScalarType scale = Teuchos::ScalarTraits<ScalarType>::one() ) const
00173 { return const_cast<LinearProblem<ScalarType,MV,OP> *>(this)->updateSolution( update, false, scale ); }
00174
00176
00178
00179
00181
00186 bool setProblem( const RCP<MV> &newX = null, const RCP<const MV> &newB = null );
00187
00189
00191
00192
00194 RCP<const OP> getOperator() const { return(A_); }
00195
00197 RCP<MV> getLHS() const { return(X_); }
00198
00200 RCP<const MV> getRHS() const { return(B_); }
00201
00203
00205 RCP<const MV> getInitResVec() const { return(R0_); }
00206
00208
00210 RCP<const MV> getInitPrecResVec() const { return(PR0_); }
00211
00213
00220 RCP<MV> getCurrLHSVec();
00221
00223
00230 RCP<MV> getCurrRHSVec();
00231
00233 RCP<const OP> getLeftPrec() const { return(LP_); };
00234
00236 RCP<const OP> getRightPrec() const { return(RP_); };
00237
00239
00250 const std::vector<int> getLSIndex() const { return(rhsIndex_); }
00251
00253
00254
00255
00256 int getLSNumber() const { return(lsNum_); }
00257
00259
00261
00262
00264
00268 bool isSolutionUpdated() const { return(solutionUpdated_); }
00269
00271 bool isProblemSet() const { return(isSet_); }
00272
00274 bool isHermitian() const { return(isHermitian_); }
00275
00277 bool isLeftPrec() const { return(LP_!=null); }
00278
00280 bool isRightPrec() const { return(RP_!=null); }
00281
00283
00285
00286
00288
00295 void apply( const MV& x, MV& y ) const;
00296
00298
00305 void applyOp( const MV& x, MV& y ) const;
00306
00308
00312 void applyLeftPrec( const MV& x, MV& y ) const;
00313
00315
00319 void applyRightPrec( const MV& x, MV& y ) const;
00320
00322
00326 void computeCurrResVec( MV* R , const MV* X = 0, const MV* B = 0 ) const;
00327
00329
00333 void computeCurrPrecResVec( MV* R, const MV* X = 0, const MV* B = 0 ) const;
00334
00336
00337 private:
00338
00340 RCP<const OP> A_;
00341
00343 RCP<MV> X_;
00344
00346 RCP<MV> curX_;
00347
00349 RCP<const MV> B_;
00350
00352 RCP<MV> curB_;
00353
00355 RCP<MV> R0_;
00356
00358 RCP<MV> PR0_;
00359
00361 RCP<const OP> LP_;
00362
00364 RCP<const OP> RP_;
00365
00367 mutable Teuchos::RCP<Teuchos::Time> timerOp_, timerPrec_;
00368
00370 int blocksize_;
00371
00373 int num2Solve_;
00374
00376 std::vector<int> rhsIndex_;
00377
00379 int lsNum_;
00380
00382 bool Left_Scale_;
00383 bool Right_Scale_;
00384 bool isSet_;
00385 bool isHermitian_;
00386 bool solutionUpdated_;
00387
00389 std::string label_;
00390
00391 typedef MultiVecTraits<ScalarType,MV> MVT;
00392 typedef OperatorTraits<ScalarType,MV,OP> OPT;
00393 };
00394
00395
00396
00397
00398
00399 template <class ScalarType, class MV, class OP>
00400 LinearProblem<ScalarType,MV,OP>::LinearProblem(void) :
00401 blocksize_(0),
00402 num2Solve_(0),
00403 rhsIndex_(0),
00404 lsNum_(0),
00405 Left_Scale_(false),
00406 Right_Scale_(false),
00407 isSet_(false),
00408 isHermitian_(false),
00409 solutionUpdated_(false),
00410 label_("Belos")
00411 {
00412 }
00413
00414 template <class ScalarType, class MV, class OP>
00415 LinearProblem<ScalarType,MV,OP>::LinearProblem(const RCP<const OP> &A,
00416 const RCP<MV> &X,
00417 const RCP<const MV> &B
00418 ) :
00419 A_(A),
00420 X_(X),
00421 B_(B),
00422 blocksize_(0),
00423 num2Solve_(0),
00424 rhsIndex_(0),
00425 lsNum_(0),
00426 Left_Scale_(false),
00427 Right_Scale_(false),
00428 isSet_(false),
00429 isHermitian_(false),
00430 solutionUpdated_(false),
00431 label_("Belos")
00432 {
00433 }
00434
00435 template <class ScalarType, class MV, class OP>
00436 LinearProblem<ScalarType,MV,OP>::LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem) :
00437 A_(Problem.A_),
00438 X_(Problem.X_),
00439 curX_(Problem.curX_),
00440 B_(Problem.B_),
00441 curB_(Problem.curB_),
00442 R0_(Problem.R0_),
00443 PR0_(Problem.PR0_),
00444 LP_(Problem.LP_),
00445 RP_(Problem.RP_),
00446 timerOp_(Problem.timerOp_),
00447 timerPrec_(Problem.timerPrec_),
00448 blocksize_(Problem.blocksize_),
00449 num2Solve_(Problem.num2Solve_),
00450 rhsIndex_(Problem.rhsIndex_),
00451 lsNum_(Problem.lsNum_),
00452 Left_Scale_(Problem.Left_Scale_),
00453 Right_Scale_(Problem.Right_Scale_),
00454 isSet_(Problem.isSet_),
00455 isHermitian_(Problem.isHermitian_),
00456 solutionUpdated_(Problem.solutionUpdated_),
00457 label_(Problem.label_)
00458 {
00459 }
00460
00461 template <class ScalarType, class MV, class OP>
00462 LinearProblem<ScalarType,MV,OP>::~LinearProblem(void)
00463 {}
00464
00465 template <class ScalarType, class MV, class OP>
00466 void LinearProblem<ScalarType,MV,OP>::setLSIndex(std::vector<int>& index)
00467 {
00468
00469 rhsIndex_ = index;
00470
00471
00472
00473 curB_ = null;
00474 curX_ = null;
00475
00476
00477 int validIdx = 0, ivalidIdx = 0;
00478 blocksize_ = rhsIndex_.size();
00479 std::vector<int> vldIndex( blocksize_ );
00480 std::vector<int> newIndex( blocksize_ );
00481 std::vector<int> iIndex( blocksize_ );
00482 for (int i=0; i<blocksize_; ++i) {
00483 if (rhsIndex_[i] > -1) {
00484 vldIndex[validIdx] = rhsIndex_[i];
00485 newIndex[validIdx] = i;
00486 validIdx++;
00487 }
00488 else {
00489 iIndex[ivalidIdx] = i;
00490 ivalidIdx++;
00491 }
00492 }
00493 vldIndex.resize(validIdx);
00494 newIndex.resize(validIdx);
00495 iIndex.resize(ivalidIdx);
00496 num2Solve_ = validIdx;
00497
00498
00499 if (num2Solve_ != blocksize_) {
00500 newIndex.resize(num2Solve_);
00501 vldIndex.resize(num2Solve_);
00502
00503
00504
00505 curX_ = MVT::Clone( *X_, blocksize_ );
00506 MVT::MvInit(*curX_);
00507 curB_ = MVT::Clone( *B_, blocksize_ );
00508 MVT::MvRandom(*curB_);
00509
00510
00511 RCP<const MV> tptr = MVT::CloneView( *B_, vldIndex );
00512 MVT::SetBlock( *tptr, newIndex, *curB_ );
00513
00514
00515 tptr = MVT::CloneView( *X_, vldIndex );
00516 MVT::SetBlock( *tptr, newIndex, *curX_ );
00517
00518 solutionUpdated_ = false;
00519 }
00520 else {
00521 curX_ = MVT::CloneView( *X_, rhsIndex_ );
00522 curB_ = rcp_const_cast<MV>(MVT::CloneView( *B_, rhsIndex_ ));
00523 }
00524
00525
00526
00527 lsNum_++;
00528 }
00529
00530
00531 template <class ScalarType, class MV, class OP>
00532 void LinearProblem<ScalarType,MV,OP>::setCurrLS()
00533 {
00534
00535
00536
00537
00538 if (num2Solve_ < blocksize_) {
00539
00540
00541
00542 int validIdx = 0;
00543 std::vector<int> newIndex( num2Solve_ );
00544 std::vector<int> vldIndex( num2Solve_ );
00545 for (int i=0; i<blocksize_; ++i) {
00546 if ( rhsIndex_[i] > -1 ) {
00547 vldIndex[validIdx] = rhsIndex_[i];
00548 newIndex[validIdx] = i;
00549 validIdx++;
00550 }
00551 }
00552 RCP<MV> tptr = MVT::CloneView( *curX_, newIndex );
00553 MVT::SetBlock( *tptr, vldIndex, *X_ );
00554 }
00555
00556
00557
00558
00559 curX_ = null;
00560 curB_ = null;
00561 rhsIndex_.resize(0);
00562 }
00563
00564
00565 template <class ScalarType, class MV, class OP>
00566 RCP<MV> LinearProblem<ScalarType,MV,OP>::updateSolution( const RCP<MV>& update,
00567 bool updateLP,
00568 ScalarType scale )
00569 {
00570 RCP<MV> newSoln;
00571 if (update != null) {
00572 if (updateLP == true) {
00573 if (RP_!=null) {
00574
00575
00576 RCP<MV> TrueUpdate = MVT::Clone( *update, MVT::GetNumberVecs( *update ) );
00577 {
00578 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00579 OPT::Apply( *RP_, *update, *TrueUpdate );
00580 }
00581 MVT::MvAddMv( 1.0, *curX_, scale, *TrueUpdate, *curX_ );
00582 }
00583 else {
00584 MVT::MvAddMv( 1.0, *curX_, scale, *update, *curX_ );
00585 }
00586 solutionUpdated_ = true;
00587 newSoln = curX_;
00588 }
00589 else {
00590 newSoln = MVT::Clone( *update, MVT::GetNumberVecs( *update ) );
00591 if (RP_!=null) {
00592
00593
00594 RCP<MV> trueUpdate = MVT::Clone( *update, MVT::GetNumberVecs( *update ) );
00595 {
00596 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00597 OPT::Apply( *RP_, *update, *trueUpdate );
00598 }
00599 MVT::MvAddMv( 1.0, *curX_, scale, *trueUpdate, *newSoln );
00600 }
00601 else {
00602 MVT::MvAddMv( 1.0, *curX_, scale, *update, *newSoln );
00603 }
00604 }
00605 }
00606 else {
00607 newSoln = curX_;
00608 }
00609 return newSoln;
00610 }
00611
00612
00613 template <class ScalarType, class MV, class OP>
00614 bool LinearProblem<ScalarType,MV,OP>::setProblem( const RCP<MV> &newX, const RCP<const MV> &newB )
00615 {
00616
00617 if (timerOp_ == Teuchos::null) {
00618 std::string opLabel = label_ + ": Operation Op*x";
00619 timerOp_ = Teuchos::TimeMonitor::getNewTimer( opLabel );
00620 }
00621 if (timerPrec_ == Teuchos::null) {
00622 std::string precLabel = label_ + ": Operation Prec*x";
00623 timerPrec_ = Teuchos::TimeMonitor::getNewTimer( precLabel );
00624 }
00625
00626
00627 if (newX != null)
00628 X_ = newX;
00629 if (newB != null)
00630 B_ = newB;
00631
00632
00633 rhsIndex_.resize(0);
00634 curX_ = null;
00635 curB_ = null;
00636
00637
00638
00639 if (A_ == null || X_ == null || B_ == null) {
00640 isSet_ = false;
00641 return isSet_;
00642 }
00643
00644
00645 solutionUpdated_ = false;
00646
00647
00648 if (R0_==null || MVT::GetNumberVecs( *R0_ )!=MVT::GetNumberVecs( *X_ )) {
00649 R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00650 }
00651 computeCurrResVec( &*R0_, &*X_, &*B_ );
00652
00653 if (LP_!=null) {
00654 if (PR0_==null || MVT::GetNumberVecs( *PR0_ )!=MVT::GetNumberVecs( *X_ )) {
00655 PR0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00656 }
00657 {
00658 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00659 OPT::Apply( *LP_, *R0_, *PR0_ );
00660 }
00661 }
00662 else {
00663 PR0_ = R0_;
00664 }
00665
00666
00667 isSet_ = true;
00668
00669
00670 return isSet_;
00671 }
00672
00673 template <class ScalarType, class MV, class OP>
00674 RCP<MV> LinearProblem<ScalarType,MV,OP>::getCurrLHSVec()
00675 {
00676 if (isSet_) {
00677 return curX_;
00678 }
00679 else {
00680 return Teuchos::null;
00681 }
00682 }
00683
00684 template <class ScalarType, class MV, class OP>
00685 RCP<MV> LinearProblem<ScalarType,MV,OP>::getCurrRHSVec()
00686 {
00687 if (isSet_) {
00688 return curB_;
00689 }
00690 else {
00691 return Teuchos::null;
00692 }
00693 }
00694
00695 template <class ScalarType, class MV, class OP>
00696 void LinearProblem<ScalarType,MV,OP>::apply( const MV& x, MV& y ) const
00697 {
00698 RCP<MV> ytemp = MVT::Clone( y, MVT::GetNumberVecs( y ) );
00699 bool leftPrec = LP_!=null;
00700 bool rightPrec = RP_!=null;
00701
00702
00703
00704 if (!leftPrec && !rightPrec){
00705 Teuchos::TimeMonitor OpTimer(*timerOp_);
00706 OPT::Apply( *A_, x, y );
00707 }
00708
00709
00710
00711 else if( leftPrec && rightPrec )
00712 {
00713 {
00714 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00715 OPT::Apply( *RP_, x, y );
00716 }
00717 {
00718 Teuchos::TimeMonitor OpTimer(*timerOp_);
00719 OPT::Apply( *A_, y, *ytemp );
00720 }
00721 {
00722 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00723 OPT::Apply( *LP_, *ytemp, y );
00724 }
00725 }
00726
00727
00728
00729 else if( leftPrec )
00730 {
00731 {
00732 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00733 OPT::Apply( *A_, x, *ytemp );
00734 }
00735 {
00736 Teuchos::TimeMonitor OpTimer(*timerOp_);
00737
00738 OPT::Apply( *LP_, *ytemp, y );
00739 }
00740 }
00741
00742
00743
00744 else
00745 {
00746 {
00747 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00748 OPT::Apply( *RP_, x, *ytemp );
00749 }
00750 {
00751 Teuchos::TimeMonitor OpTimer(*timerOp_);
00752 OPT::Apply( *A_, *ytemp, y );
00753 }
00754 }
00755 }
00756
00757 template <class ScalarType, class MV, class OP>
00758 void LinearProblem<ScalarType,MV,OP>::applyOp( const MV& x, MV& y ) const {
00759 if (A_.get()) {
00760 Teuchos::TimeMonitor OpTimer(*timerOp_);
00761 OPT::Apply( *A_,x, y);
00762 }
00763 else {
00764 MVT::MvAddMv( Teuchos::ScalarTraits<ScalarType>::one(), x,
00765 Teuchos::ScalarTraits<ScalarType>::zero(), x, y );
00766 }
00767 }
00768
00769 template <class ScalarType, class MV, class OP>
00770 void LinearProblem<ScalarType,MV,OP>::applyLeftPrec( const MV& x, MV& y ) const {
00771 if (LP_!=null) {
00772 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00773 return ( OPT::Apply( *LP_,x, y) );
00774 }
00775 else {
00776 MVT::MvAddMv( Teuchos::ScalarTraits<ScalarType>::one(), x,
00777 Teuchos::ScalarTraits<ScalarType>::zero(), x, y );
00778 }
00779 }
00780
00781 template <class ScalarType, class MV, class OP>
00782 void LinearProblem<ScalarType,MV,OP>::applyRightPrec( const MV& x, MV& y ) const {
00783 if (RP_!=null) {
00784 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00785 return ( OPT::Apply( *RP_,x, y) );
00786 }
00787 else {
00788 MVT::MvAddMv( Teuchos::ScalarTraits<ScalarType>::one(), x,
00789 Teuchos::ScalarTraits<ScalarType>::zero(), x, y );
00790 }
00791 }
00792
00793 template <class ScalarType, class MV, class OP>
00794 void LinearProblem<ScalarType,MV,OP>::computeCurrPrecResVec( MV* R, const MV* X, const MV* B ) const {
00795
00796 if (R) {
00797 if (X && B)
00798 {
00799 if (LP_!=null)
00800 {
00801 RCP<MV> R_temp = MVT::Clone( *X, MVT::GetNumberVecs( *X ) );
00802 {
00803 Teuchos::TimeMonitor OpTimer(*timerOp_);
00804 OPT::Apply( *A_, *X, *R_temp );
00805 }
00806 MVT::MvAddMv( -1.0, *R_temp, 1.0, *B, *R_temp );
00807 {
00808 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00809 OPT::Apply( *LP_, *R_temp, *R );
00810 }
00811 }
00812 else
00813 {
00814 {
00815 Teuchos::TimeMonitor OpTimer(*timerOp_);
00816 OPT::Apply( *A_, *X, *R );
00817 }
00818 MVT::MvAddMv( -1.0, *R, 1.0, *B, *R );
00819 }
00820 }
00821 else {
00822
00823 RCP<const MV> localB, localX;
00824 if (B)
00825 localB = rcp( B, false );
00826 else
00827 localB = curB_;
00828
00829 if (X)
00830 localX = rcp( X, false );
00831 else
00832 localX = curX_;
00833
00834 if (LP_!=null)
00835 {
00836 RCP<MV> R_temp = MVT::Clone( *localX, MVT::GetNumberVecs( *localX ) );
00837 {
00838 Teuchos::TimeMonitor OpTimer(*timerOp_);
00839 OPT::Apply( *A_, *localX, *R_temp );
00840 }
00841 MVT::MvAddMv( -1.0, *R_temp, 1.0, *localB, *R_temp );
00842 {
00843 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00844 OPT::Apply( *LP_, *R_temp, *R );
00845 }
00846 }
00847 else
00848 {
00849 {
00850 Teuchos::TimeMonitor OpTimer(*timerOp_);
00851 OPT::Apply( *A_, *localX, *R );
00852 }
00853 MVT::MvAddMv( -1.0, *R, 1.0, *localB, *R );
00854 }
00855 }
00856 }
00857 }
00858
00859
00860 template <class ScalarType, class MV, class OP>
00861 void LinearProblem<ScalarType,MV,OP>::computeCurrResVec( MV* R, const MV* X, const MV* B ) const {
00862
00863 if (R) {
00864 if (X && B)
00865 {
00866 {
00867 Teuchos::TimeMonitor OpTimer(*timerOp_);
00868 OPT::Apply( *A_, *X, *R );
00869 }
00870 MVT::MvAddMv( -1.0, *R, 1.0, *B, *R );
00871 }
00872 else {
00873
00874 RCP<const MV> localB, localX;
00875 if (B)
00876 localB = rcp( B, false );
00877 else
00878 localB = curB_;
00879
00880 if (X)
00881 localX = rcp( X, false );
00882 else
00883 localX = curX_;
00884
00885 {
00886 Teuchos::TimeMonitor OpTimer(*timerOp_);
00887 OPT::Apply( *A_, *localX, *R );
00888 }
00889 MVT::MvAddMv( -1.0, *R, 1.0, *localB, *R );
00890 }
00891 }
00892 }
00893
00894 }
00895
00896 #endif
00897
00898