00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef BELOS_LINEAR_PROBLEM_HPP
00031 #define BELOS_LINEAR_PROBLEM_HPP
00032
00037 #include "BelosMultiVecTraits.hpp"
00038 #include "BelosOperatorTraits.hpp"
00039 #include "Teuchos_ParameterList.hpp"
00040 #include "Teuchos_TimeMonitor.hpp"
00041
00042 using Teuchos::RCP;
00043 using Teuchos::rcp;
00044 using Teuchos::null;
00045 using Teuchos::rcp_const_cast;
00046 using Teuchos::ParameterList;
00047
00053 namespace Belos {
00054
00056
00057
00060 class LinearProblemError : public BelosError
00061 {public: LinearProblemError(const std::string& what_arg) : BelosError(what_arg) {}};
00062
00064
00065
00066 template <class ScalarType, class MV, class OP>
00067 class LinearProblem {
00068
00069 public:
00070
00072
00073
00074
00078 LinearProblem(void);
00079
00081
00086 LinearProblem(const RCP<const OP> &A,
00087 const RCP<MV> &X,
00088 const RCP<const MV> &B
00089 );
00090
00092
00094 LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem);
00095
00097
00099 virtual ~LinearProblem(void);
00101
00103
00104
00106
00108 void setOperator(const RCP<const OP> &A) { A_ = A; isSet_=false; }
00109
00111
00113 void setLHS(const RCP<MV> &X) { X_ = X; isSet_=false; }
00114
00116
00118 void setRHS(const RCP<const MV> &B) { B_ = B; isSet_=false; }
00119
00121
00123 void setLeftPrec(const RCP<const OP> &LP) { LP_ = LP; }
00124
00126
00128 void setRightPrec(const RCP<const OP> &RP) { RP_ = RP; }
00129
00131
00136 void setCurrLS();
00137
00139
00145 void setLSIndex(std::vector<int>& index);
00146
00148
00152 void setHermitian(){ isHermitian_ = true; }
00153
00155
00158 void setLabel(const std::string& label) { label_ = label; }
00159
00161
00166 RCP<MV> updateSolution( const RCP<MV>& update = null,
00167 bool updateLP = false,
00168 ScalarType scale = Teuchos::ScalarTraits<ScalarType>::one() );
00169
00171 RCP<MV> updateSolution( const RCP<MV>& update = null,
00172 ScalarType scale = Teuchos::ScalarTraits<ScalarType>::one() ) const
00173 { return const_cast<LinearProblem<ScalarType,MV,OP> *>(this)->updateSolution( update, false, scale ); }
00174
00176
00178
00179
00181
00186 bool setProblem( const RCP<MV> &newX = null, const RCP<const MV> &newB = null );
00187
00189
00191
00192
00194 RCP<const OP> getOperator() const { return(A_); }
00195
00197 RCP<MV> getLHS() const { return(X_); }
00198
00200 RCP<const MV> getRHS() const { return(B_); }
00201
00203
00205 RCP<const MV> getInitResVec() const { return(R0_); }
00206
00208
00210 RCP<const MV> getInitPrecResVec() const { return(PR0_); }
00211
00213
00220 RCP<MV> getCurrLHSVec();
00221
00223
00230 RCP<MV> getCurrRHSVec();
00231
00233 RCP<const OP> getLeftPrec() const { return(LP_); };
00234
00236 RCP<const OP> getRightPrec() const { return(RP_); };
00237
00239
00250 const std::vector<int> getLSIndex() const { return(rhsIndex_); }
00251
00253
00254
00255
00256 int getLSNumber() const { return(lsNum_); }
00257
00264 Teuchos::Array<Teuchos::RCP<Teuchos::Time> > getTimers() const {
00265 return tuple(timerOp_,timerPrec_);
00266 }
00267
00268
00270
00272
00273
00275
00279 bool isSolutionUpdated() const { return(solutionUpdated_); }
00280
00282 bool isProblemSet() const { return(isSet_); }
00283
00285 bool isHermitian() const { return(isHermitian_); }
00286
00288 bool isLeftPrec() const { return(LP_!=null); }
00289
00291 bool isRightPrec() const { return(RP_!=null); }
00292
00294
00296
00297
00299
00306 void apply( const MV& x, MV& y ) const;
00307
00309
00316 void applyOp( const MV& x, MV& y ) const;
00317
00319
00323 void applyLeftPrec( const MV& x, MV& y ) const;
00324
00326
00330 void applyRightPrec( const MV& x, MV& y ) const;
00331
00333
00337 void computeCurrResVec( MV* R , const MV* X = 0, const MV* B = 0 ) const;
00338
00340
00344 void computeCurrPrecResVec( MV* R, const MV* X = 0, const MV* B = 0 ) const;
00345
00347
00348 private:
00349
00351 RCP<const OP> A_;
00352
00354 RCP<MV> X_;
00355
00357 RCP<MV> curX_;
00358
00360 RCP<const MV> B_;
00361
00363 RCP<MV> curB_;
00364
00366 RCP<MV> R0_;
00367
00369 RCP<MV> PR0_;
00370
00372 RCP<const OP> LP_;
00373
00375 RCP<const OP> RP_;
00376
00378 mutable Teuchos::RCP<Teuchos::Time> timerOp_, timerPrec_;
00379
00381 int blocksize_;
00382
00384 int num2Solve_;
00385
00387 std::vector<int> rhsIndex_;
00388
00390 int lsNum_;
00391
00393 bool Left_Scale_;
00394 bool Right_Scale_;
00395 bool isSet_;
00396 bool isHermitian_;
00397 bool solutionUpdated_;
00398
00400 std::string label_;
00401
00402 typedef MultiVecTraits<ScalarType,MV> MVT;
00403 typedef OperatorTraits<ScalarType,MV,OP> OPT;
00404 };
00405
00406
00407
00408
00409
00410 template <class ScalarType, class MV, class OP>
00411 LinearProblem<ScalarType,MV,OP>::LinearProblem(void) :
00412 blocksize_(0),
00413 num2Solve_(0),
00414 rhsIndex_(0),
00415 lsNum_(0),
00416 Left_Scale_(false),
00417 Right_Scale_(false),
00418 isSet_(false),
00419 isHermitian_(false),
00420 solutionUpdated_(false),
00421 label_("Belos")
00422 {
00423 }
00424
00425 template <class ScalarType, class MV, class OP>
00426 LinearProblem<ScalarType,MV,OP>::LinearProblem(const RCP<const OP> &A,
00427 const RCP<MV> &X,
00428 const RCP<const MV> &B
00429 ) :
00430 A_(A),
00431 X_(X),
00432 B_(B),
00433 blocksize_(0),
00434 num2Solve_(0),
00435 rhsIndex_(0),
00436 lsNum_(0),
00437 Left_Scale_(false),
00438 Right_Scale_(false),
00439 isSet_(false),
00440 isHermitian_(false),
00441 solutionUpdated_(false),
00442 label_("Belos")
00443 {
00444 }
00445
00446 template <class ScalarType, class MV, class OP>
00447 LinearProblem<ScalarType,MV,OP>::LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem) :
00448 A_(Problem.A_),
00449 X_(Problem.X_),
00450 curX_(Problem.curX_),
00451 B_(Problem.B_),
00452 curB_(Problem.curB_),
00453 R0_(Problem.R0_),
00454 PR0_(Problem.PR0_),
00455 LP_(Problem.LP_),
00456 RP_(Problem.RP_),
00457 timerOp_(Problem.timerOp_),
00458 timerPrec_(Problem.timerPrec_),
00459 blocksize_(Problem.blocksize_),
00460 num2Solve_(Problem.num2Solve_),
00461 rhsIndex_(Problem.rhsIndex_),
00462 lsNum_(Problem.lsNum_),
00463 Left_Scale_(Problem.Left_Scale_),
00464 Right_Scale_(Problem.Right_Scale_),
00465 isSet_(Problem.isSet_),
00466 isHermitian_(Problem.isHermitian_),
00467 solutionUpdated_(Problem.solutionUpdated_),
00468 label_(Problem.label_)
00469 {
00470 }
00471
00472 template <class ScalarType, class MV, class OP>
00473 LinearProblem<ScalarType,MV,OP>::~LinearProblem(void)
00474 {}
00475
00476 template <class ScalarType, class MV, class OP>
00477 void LinearProblem<ScalarType,MV,OP>::setLSIndex(std::vector<int>& index)
00478 {
00479
00480 rhsIndex_ = index;
00481
00482
00483
00484 curB_ = null;
00485 curX_ = null;
00486
00487
00488 int validIdx = 0, ivalidIdx = 0;
00489 blocksize_ = rhsIndex_.size();
00490 std::vector<int> vldIndex( blocksize_ );
00491 std::vector<int> newIndex( blocksize_ );
00492 std::vector<int> iIndex( blocksize_ );
00493 for (int i=0; i<blocksize_; ++i) {
00494 if (rhsIndex_[i] > -1) {
00495 vldIndex[validIdx] = rhsIndex_[i];
00496 newIndex[validIdx] = i;
00497 validIdx++;
00498 }
00499 else {
00500 iIndex[ivalidIdx] = i;
00501 ivalidIdx++;
00502 }
00503 }
00504 vldIndex.resize(validIdx);
00505 newIndex.resize(validIdx);
00506 iIndex.resize(ivalidIdx);
00507 num2Solve_ = validIdx;
00508
00509
00510 if (num2Solve_ != blocksize_) {
00511 newIndex.resize(num2Solve_);
00512 vldIndex.resize(num2Solve_);
00513
00514
00515
00516 curX_ = MVT::Clone( *X_, blocksize_ );
00517 MVT::MvInit(*curX_);
00518 curB_ = MVT::Clone( *B_, blocksize_ );
00519 MVT::MvRandom(*curB_);
00520
00521
00522 RCP<const MV> tptr = MVT::CloneView( *B_, vldIndex );
00523 MVT::SetBlock( *tptr, newIndex, *curB_ );
00524
00525
00526 tptr = MVT::CloneView( *X_, vldIndex );
00527 MVT::SetBlock( *tptr, newIndex, *curX_ );
00528
00529 solutionUpdated_ = false;
00530 }
00531 else {
00532 curX_ = MVT::CloneView( *X_, rhsIndex_ );
00533 curB_ = rcp_const_cast<MV>(MVT::CloneView( *B_, rhsIndex_ ));
00534 }
00535
00536
00537
00538 lsNum_++;
00539 }
00540
00541
00542 template <class ScalarType, class MV, class OP>
00543 void LinearProblem<ScalarType,MV,OP>::setCurrLS()
00544 {
00545
00546
00547
00548
00549 if (num2Solve_ < blocksize_) {
00550
00551
00552
00553 int validIdx = 0;
00554 std::vector<int> newIndex( num2Solve_ );
00555 std::vector<int> vldIndex( num2Solve_ );
00556 for (int i=0; i<blocksize_; ++i) {
00557 if ( rhsIndex_[i] > -1 ) {
00558 vldIndex[validIdx] = rhsIndex_[i];
00559 newIndex[validIdx] = i;
00560 validIdx++;
00561 }
00562 }
00563 RCP<MV> tptr = MVT::CloneView( *curX_, newIndex );
00564 MVT::SetBlock( *tptr, vldIndex, *X_ );
00565 }
00566
00567
00568
00569
00570 curX_ = null;
00571 curB_ = null;
00572 rhsIndex_.resize(0);
00573 }
00574
00575
00576 template <class ScalarType, class MV, class OP>
00577 RCP<MV> LinearProblem<ScalarType,MV,OP>::updateSolution( const RCP<MV>& update,
00578 bool updateLP,
00579 ScalarType scale )
00580 {
00581 RCP<MV> newSoln;
00582 if (update != null) {
00583 if (updateLP == true) {
00584 if (RP_!=null) {
00585
00586
00587 RCP<MV> TrueUpdate = MVT::Clone( *update, MVT::GetNumberVecs( *update ) );
00588 {
00589 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00590 OPT::Apply( *RP_, *update, *TrueUpdate );
00591 }
00592 MVT::MvAddMv( 1.0, *curX_, scale, *TrueUpdate, *curX_ );
00593 }
00594 else {
00595 MVT::MvAddMv( 1.0, *curX_, scale, *update, *curX_ );
00596 }
00597 solutionUpdated_ = true;
00598 newSoln = curX_;
00599 }
00600 else {
00601 newSoln = MVT::Clone( *update, MVT::GetNumberVecs( *update ) );
00602 if (RP_!=null) {
00603
00604
00605 RCP<MV> trueUpdate = MVT::Clone( *update, MVT::GetNumberVecs( *update ) );
00606 {
00607 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00608 OPT::Apply( *RP_, *update, *trueUpdate );
00609 }
00610 MVT::MvAddMv( 1.0, *curX_, scale, *trueUpdate, *newSoln );
00611 }
00612 else {
00613 MVT::MvAddMv( 1.0, *curX_, scale, *update, *newSoln );
00614 }
00615 }
00616 }
00617 else {
00618 newSoln = curX_;
00619 }
00620 return newSoln;
00621 }
00622
00623
00624 template <class ScalarType, class MV, class OP>
00625 bool LinearProblem<ScalarType,MV,OP>::setProblem( const RCP<MV> &newX, const RCP<const MV> &newB )
00626 {
00627
00628 if (timerOp_ == Teuchos::null) {
00629 std::string opLabel = label_ + ": Operation Op*x";
00630 timerOp_ = Teuchos::TimeMonitor::getNewTimer( opLabel );
00631 }
00632 if (timerPrec_ == Teuchos::null) {
00633 std::string precLabel = label_ + ": Operation Prec*x";
00634 timerPrec_ = Teuchos::TimeMonitor::getNewTimer( precLabel );
00635 }
00636
00637
00638 if (newX != null)
00639 X_ = newX;
00640 if (newB != null)
00641 B_ = newB;
00642
00643
00644 rhsIndex_.resize(0);
00645 curX_ = null;
00646 curB_ = null;
00647
00648
00649
00650 if (A_ == null || X_ == null || B_ == null) {
00651 isSet_ = false;
00652 return isSet_;
00653 }
00654
00655
00656 solutionUpdated_ = false;
00657
00658
00659 if (R0_==null || MVT::GetNumberVecs( *R0_ )!=MVT::GetNumberVecs( *X_ )) {
00660 R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00661 }
00662 computeCurrResVec( &*R0_, &*X_, &*B_ );
00663
00664 if (LP_!=null) {
00665 if (PR0_==null || MVT::GetNumberVecs( *PR0_ )!=MVT::GetNumberVecs( *X_ )) {
00666 PR0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00667 }
00668 {
00669 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00670 OPT::Apply( *LP_, *R0_, *PR0_ );
00671 }
00672 }
00673 else {
00674 PR0_ = R0_;
00675 }
00676
00677
00678 isSet_ = true;
00679
00680
00681 return isSet_;
00682 }
00683
00684 template <class ScalarType, class MV, class OP>
00685 RCP<MV> LinearProblem<ScalarType,MV,OP>::getCurrLHSVec()
00686 {
00687 if (isSet_) {
00688 return curX_;
00689 }
00690 else {
00691 return Teuchos::null;
00692 }
00693 }
00694
00695 template <class ScalarType, class MV, class OP>
00696 RCP<MV> LinearProblem<ScalarType,MV,OP>::getCurrRHSVec()
00697 {
00698 if (isSet_) {
00699 return curB_;
00700 }
00701 else {
00702 return Teuchos::null;
00703 }
00704 }
00705
00706 template <class ScalarType, class MV, class OP>
00707 void LinearProblem<ScalarType,MV,OP>::apply( const MV& x, MV& y ) const
00708 {
00709 RCP<MV> ytemp = MVT::Clone( y, MVT::GetNumberVecs( y ) );
00710 bool leftPrec = LP_!=null;
00711 bool rightPrec = RP_!=null;
00712
00713
00714
00715 if (!leftPrec && !rightPrec){
00716 Teuchos::TimeMonitor OpTimer(*timerOp_);
00717 OPT::Apply( *A_, x, y );
00718 }
00719
00720
00721
00722 else if( leftPrec && rightPrec )
00723 {
00724 {
00725 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00726 OPT::Apply( *RP_, x, y );
00727 }
00728 {
00729 Teuchos::TimeMonitor OpTimer(*timerOp_);
00730 OPT::Apply( *A_, y, *ytemp );
00731 }
00732 {
00733 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00734 OPT::Apply( *LP_, *ytemp, y );
00735 }
00736 }
00737
00738
00739
00740 else if( leftPrec )
00741 {
00742 {
00743 Teuchos::TimeMonitor OpTimer(*timerOp_);
00744 OPT::Apply( *A_, x, *ytemp );
00745 }
00746 {
00747 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00748 OPT::Apply( *LP_, *ytemp, y );
00749 }
00750 }
00751
00752
00753
00754 else
00755 {
00756 {
00757 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00758 OPT::Apply( *RP_, x, *ytemp );
00759 }
00760 {
00761 Teuchos::TimeMonitor OpTimer(*timerOp_);
00762 OPT::Apply( *A_, *ytemp, y );
00763 }
00764 }
00765 }
00766
00767 template <class ScalarType, class MV, class OP>
00768 void LinearProblem<ScalarType,MV,OP>::applyOp( const MV& x, MV& y ) const {
00769 if (A_.get()) {
00770 Teuchos::TimeMonitor OpTimer(*timerOp_);
00771 OPT::Apply( *A_,x, y);
00772 }
00773 else {
00774 MVT::MvAddMv( Teuchos::ScalarTraits<ScalarType>::one(), x,
00775 Teuchos::ScalarTraits<ScalarType>::zero(), x, y );
00776 }
00777 }
00778
00779 template <class ScalarType, class MV, class OP>
00780 void LinearProblem<ScalarType,MV,OP>::applyLeftPrec( const MV& x, MV& y ) const {
00781 if (LP_!=null) {
00782 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00783 return ( OPT::Apply( *LP_,x, y) );
00784 }
00785 else {
00786 MVT::MvAddMv( Teuchos::ScalarTraits<ScalarType>::one(), x,
00787 Teuchos::ScalarTraits<ScalarType>::zero(), x, y );
00788 }
00789 }
00790
00791 template <class ScalarType, class MV, class OP>
00792 void LinearProblem<ScalarType,MV,OP>::applyRightPrec( const MV& x, MV& y ) const {
00793 if (RP_!=null) {
00794 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00795 return ( OPT::Apply( *RP_,x, y) );
00796 }
00797 else {
00798 MVT::MvAddMv( Teuchos::ScalarTraits<ScalarType>::one(), x,
00799 Teuchos::ScalarTraits<ScalarType>::zero(), x, y );
00800 }
00801 }
00802
00803 template <class ScalarType, class MV, class OP>
00804 void LinearProblem<ScalarType,MV,OP>::computeCurrPrecResVec( MV* R, const MV* X, const MV* B ) const {
00805
00806 if (R) {
00807 if (X && B)
00808 {
00809 if (LP_!=null)
00810 {
00811 RCP<MV> R_temp = MVT::Clone( *X, MVT::GetNumberVecs( *X ) );
00812 {
00813 Teuchos::TimeMonitor OpTimer(*timerOp_);
00814 OPT::Apply( *A_, *X, *R_temp );
00815 }
00816 MVT::MvAddMv( -1.0, *R_temp, 1.0, *B, *R_temp );
00817 {
00818 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00819 OPT::Apply( *LP_, *R_temp, *R );
00820 }
00821 }
00822 else
00823 {
00824 {
00825 Teuchos::TimeMonitor OpTimer(*timerOp_);
00826 OPT::Apply( *A_, *X, *R );
00827 }
00828 MVT::MvAddMv( -1.0, *R, 1.0, *B, *R );
00829 }
00830 }
00831 else {
00832
00833 RCP<const MV> localB, localX;
00834 if (B)
00835 localB = rcp( B, false );
00836 else
00837 localB = curB_;
00838
00839 if (X)
00840 localX = rcp( X, false );
00841 else
00842 localX = curX_;
00843
00844 if (LP_!=null)
00845 {
00846 RCP<MV> R_temp = MVT::Clone( *localX, MVT::GetNumberVecs( *localX ) );
00847 {
00848 Teuchos::TimeMonitor OpTimer(*timerOp_);
00849 OPT::Apply( *A_, *localX, *R_temp );
00850 }
00851 MVT::MvAddMv( -1.0, *R_temp, 1.0, *localB, *R_temp );
00852 {
00853 Teuchos::TimeMonitor PrecTimer(*timerPrec_);
00854 OPT::Apply( *LP_, *R_temp, *R );
00855 }
00856 }
00857 else
00858 {
00859 {
00860 Teuchos::TimeMonitor OpTimer(*timerOp_);
00861 OPT::Apply( *A_, *localX, *R );
00862 }
00863 MVT::MvAddMv( -1.0, *R, 1.0, *localB, *R );
00864 }
00865 }
00866 }
00867 }
00868
00869
00870 template <class ScalarType, class MV, class OP>
00871 void LinearProblem<ScalarType,MV,OP>::computeCurrResVec( MV* R, const MV* X, const MV* B ) const {
00872
00873 if (R) {
00874 if (X && B)
00875 {
00876 {
00877 Teuchos::TimeMonitor OpTimer(*timerOp_);
00878 OPT::Apply( *A_, *X, *R );
00879 }
00880 MVT::MvAddMv( -1.0, *R, 1.0, *B, *R );
00881 }
00882 else {
00883
00884 RCP<const MV> localB, localX;
00885 if (B)
00886 localB = rcp( B, false );
00887 else
00888 localB = curB_;
00889
00890 if (X)
00891 localX = rcp( X, false );
00892 else
00893 localX = curX_;
00894
00895 {
00896 Teuchos::TimeMonitor OpTimer(*timerOp_);
00897 OPT::Apply( *A_, *localX, *R );
00898 }
00899 MVT::MvAddMv( -1.0, *R, 1.0, *localB, *R );
00900 }
00901 }
00902 }
00903
00904 }
00905
00906 #endif
00907
00908