00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef BELOS_LINEAR_PROBLEM_HPP
00031 #define BELOS_LINEAR_PROBLEM_HPP
00032
00037 #include "BelosMultiVecTraits.hpp"
00038 #include "BelosOperatorTraits.hpp"
00039 #include "Teuchos_ParameterList.hpp"
00040
00041 using Teuchos::RefCountPtr;
00042 using Teuchos::rcp;
00043 using Teuchos::null;
00044 using Teuchos::rcp_const_cast;
00045 using Teuchos::ParameterList;
00046
00054 namespace Belos {
00055
00056 template <class ScalarType, class MV, class OP>
00057 class LinearProblem {
00058
00059 public:
00060
00062
00063
00064
00068 LinearProblem(void);
00069
00071
00076 LinearProblem(const RefCountPtr<const OP> &A,
00077 const RefCountPtr<MV> &X,
00078 const RefCountPtr<const MV> &B
00079 );
00080
00082
00084 LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem);
00085
00087
00089 virtual ~LinearProblem(void);
00091
00093
00094
00096
00098 void SetOperator(const RefCountPtr<const OP> &A) { A_ = A; };
00099
00101
00103 void SetLHS(const RefCountPtr<MV> &X);
00104
00106
00108 void SetRHS(const RefCountPtr<const MV> &B) { B_ = B; };
00109
00111
00113 void SetLeftPrec(const RefCountPtr<const OP> &LP) { LP_ = LP; Left_Prec_ = true; };
00114
00116
00118 void SetRightPrec(const RefCountPtr<const OP> &RP) { RP_ = RP; Right_Prec_ = true; };
00119
00121 void SetParameterList(const RefCountPtr<ParameterList> &PL) { PL_ = PL; };
00122
00124 void SetBlockSize(int blocksize) { default_blocksize_ = blocksize; blocksize_ = blocksize; };
00125
00127
00132 void SetCurrLSVec();
00133
00135
00139 void AssertSymmetric(){ operatorSymmetric_ = true; };
00140
00142
00146 void SolutionUpdated( const MV* SolnUpdate = 0,
00147 ScalarType scale = Teuchos::ScalarTraits<ScalarType>::one() );
00148
00150
00152
00153
00155
00158 void Reset( const RefCountPtr<MV> &newX = null, const RefCountPtr<const MV> &newB = null );
00160
00162
00163
00165 RefCountPtr<const OP> GetOperator() const { return(A_); };
00166
00168 RefCountPtr<MV> GetLHS() const { return(X_); };
00169
00171 RefCountPtr<const MV> GetRHS() const { return(B_); };
00172
00174
00176 const MV& GetInitResVec();
00177
00179
00195 const MV& GetCurrResVec( const MV* CurrSoln = 0 );
00196
00198
00205 RefCountPtr<MV> GetCurrLHSVec();
00206
00208
00215 RefCountPtr<MV> GetCurrRHSVec();
00216
00218 RefCountPtr<const OP> GetLeftPrec() const { return(LP_); };
00219
00221 RefCountPtr<const OP> GetRightPrec() const { return(RP_); };
00222
00224 RefCountPtr<ParameterList> GetParameterList() const { return(PL_); };
00225
00227 int GetBlockSize() const { return( default_blocksize_ ); };
00228
00230
00234 int GetCurrBlockSize() const { return( blocksize_ ); };
00235
00237
00243 int GetNumToSolve() const { return( num_to_solve_ ); };
00244
00246
00253 int GetRHSIndex() const { return( rhs_index_ ); };
00254
00256
00260 bool IsSolutionUpdated() const { return(solutionUpdated_); };
00261
00263 bool IsOperatorSymmetric() const { return(operatorSymmetric_); };
00264
00266
00268
00269
00271
00278 ReturnType Apply( const MV& x, MV& y );
00279
00281
00288 ReturnType ApplyOp( const MV& x, MV& y );
00289
00291
00295 ReturnType ApplyLeftPrec( const MV& x, MV& y );
00296
00298
00302 ReturnType ApplyRightPrec( const MV& x, MV& y );
00303
00305
00309 ReturnType ComputeResVec( MV* R, const MV* X = 0, const MV* B = 0 );
00310
00312
00313 private:
00314
00316 void SetUpBlocks();
00317
00319 RefCountPtr<const OP> A_;
00320
00322 RefCountPtr<MV> X_;
00323
00325 RefCountPtr<MV> CurX_;
00326
00328 RefCountPtr<const MV> B_;
00329
00331 RefCountPtr<MV> CurB_;
00332
00334 RefCountPtr<MV> R_;
00335
00337 RefCountPtr<MV> R0_;
00338
00340 RefCountPtr<const OP> LP_;
00341
00343 RefCountPtr<const OP> RP_;
00344
00346 RefCountPtr<ParameterList> PL_;
00347
00349 int default_blocksize_;
00350
00352 int blocksize_;
00353
00355 int num_to_solve_;
00356
00358 int rhs_index_;
00359
00361 bool Left_Prec_;
00362 bool Right_Prec_;
00363 bool Left_Scale_;
00364 bool Right_Scale_;
00365 bool operatorSymmetric_;
00366 bool solutionUpdated_;
00367 bool solutionFinal_;
00368 bool initresidsComputed_;
00369
00370 typedef MultiVecTraits<ScalarType,MV> MVT;
00371 typedef OperatorTraits<ScalarType,MV,OP> OPT;
00372 };
00373
00374
00375
00376
00377
00378 template <class ScalarType, class MV, class OP>
00379 LinearProblem<ScalarType,MV,OP>::LinearProblem(void) :
00380 default_blocksize_(1),
00381 blocksize_(1),
00382 num_to_solve_(0),
00383 rhs_index_(0),
00384 Left_Prec_(false),
00385 Right_Prec_(false),
00386 Left_Scale_(false),
00387 Right_Scale_(false),
00388 operatorSymmetric_(false),
00389 solutionUpdated_(false),
00390 solutionFinal_(true),
00391 initresidsComputed_(false)
00392 {
00393 }
00394
00395 template <class ScalarType, class MV, class OP>
00396 LinearProblem<ScalarType,MV,OP>::LinearProblem(const RefCountPtr<const OP> &A,
00397 const RefCountPtr<MV> &X,
00398 const RefCountPtr<const MV> &B
00399 ) :
00400 A_(A),
00401 X_(X),
00402 B_(B),
00403 default_blocksize_(1),
00404 blocksize_(1),
00405 num_to_solve_(1),
00406 rhs_index_(0),
00407 Left_Prec_(false),
00408 Right_Prec_(false),
00409 Left_Scale_(false),
00410 Right_Scale_(false),
00411 operatorSymmetric_(false),
00412 solutionUpdated_(false),
00413 solutionFinal_(true),
00414 initresidsComputed_(false)
00415 {
00416 R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00417 }
00418
00419 template <class ScalarType, class MV, class OP>
00420 LinearProblem<ScalarType,MV,OP>::LinearProblem(const LinearProblem<ScalarType,MV,OP>& Problem) :
00421 A_(Problem.A_),
00422 X_(Problem.X_),
00423 CurX_(Problem.CurX_),
00424 B_(Problem.B_),
00425 CurB_(Problem.CurB_),
00426 R_(Problem.R_),
00427 R0_(Problem.R0_),
00428 LP_(Problem.LP_),
00429 RP_(Problem.RP_),
00430 PL_(Problem.PL_),
00431 default_blocksize_(Problem.default_blocksize_),
00432 blocksize_(Problem.blocksize_),
00433 num_to_solve_(Problem.num_to_solve_),
00434 rhs_index_(Problem.rhs_index_),
00435 Left_Prec_(Problem.Left_Prec_),
00436 Right_Prec_(Problem.Right_Prec_),
00437 Left_Scale_(Problem.Left_Scale_),
00438 Right_Scale_(Problem.Right_Scale_),
00439 operatorSymmetric_(Problem.operatorSymmetric_),
00440 solutionUpdated_(Problem.solutionUpdated_),
00441 solutionFinal_(Problem.solutionFinal_),
00442 initresidsComputed_(Problem.initresidsComputed_)
00443 {
00444 }
00445
00446 template <class ScalarType, class MV, class OP>
00447 LinearProblem<ScalarType,MV,OP>::~LinearProblem(void)
00448 {}
00449
00450 template <class ScalarType, class MV, class OP>
00451 void LinearProblem<ScalarType,MV,OP>::SetUpBlocks()
00452 {
00453
00454
00455 if (CurB_.get()) CurB_ = null;
00456 if (CurX_.get()) CurX_ = null;
00457 if (R_.get()) R_ = null;
00458
00459
00460
00461
00462
00463
00464 num_to_solve_ = MVT::GetNumberVecs(*X_) - rhs_index_;
00465
00466
00467 if ( num_to_solve_ <= 0 ) { return; }
00468
00469 int i;
00470 std::vector<int> index( num_to_solve_ );
00471 for ( i=0; i<num_to_solve_; i++ ) { index[i] = rhs_index_ + i; }
00472
00473
00474
00475
00476
00477
00478
00479 if ( num_to_solve_ < blocksize_ )
00480 {
00481 std::vector<int> index2(num_to_solve_);
00482 for (i=0; i<num_to_solve_; i++) {
00483 index2[i] = i;
00484 }
00485
00486
00487 CurX_ = MVT::Clone( *X_, blocksize_ );
00488 MVT::MvInit(*CurX_);
00489 CurB_ = MVT::Clone( *B_, blocksize_ );
00490 MVT::MvRandom(*CurB_);
00491 R_ = MVT::Clone( *X_, blocksize_);
00492
00493 RefCountPtr<const MV> tptr = MVT::CloneView( *B_, index );
00494 MVT::SetBlock( *tptr, index2, *CurB_ );
00495
00496 RefCountPtr<MV> tptr2 = MVT::CloneView( *X_, index );
00497 MVT::SetBlock( *tptr2, index2, *CurX_ );
00498 } else {
00499
00500
00501
00502
00503 num_to_solve_ = blocksize_;
00504 index.resize( num_to_solve_ );
00505 for ( i=0; i<num_to_solve_; i++ ) { index[i] = rhs_index_ + i; }
00506 CurX_ = MVT::CloneView( *X_, index );
00507 CurB_ = rcp_const_cast<MV>(MVT::CloneView( *B_, index ));
00508 R_ = MVT::Clone( *X_, num_to_solve_ );
00509
00510 }
00511
00512
00513
00514 if (R_.get()) {
00515 OPT::Apply( *A_, *CurX_, *R_ );
00516 MVT::MvAddMv( 1.0, *CurB_, -1.0, *R_, *R_ );
00517 solutionUpdated_ = false;
00518 }
00519 }
00520
00521 template <class ScalarType, class MV, class OP>
00522 void LinearProblem<ScalarType,MV,OP>::SetLHS(const RefCountPtr<MV> &X)
00523 {
00524 X_ = X;
00525 R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00526 }
00527
00528 template <class ScalarType, class MV, class OP>
00529 void LinearProblem<ScalarType,MV,OP>::SetCurrLSVec()
00530 {
00531 int i;
00532
00533
00534
00535
00536 if (num_to_solve_ < blocksize_) {
00537
00538 std::vector<int> index( num_to_solve_ );
00539
00540 RefCountPtr<MV> tptr;
00541
00542
00543
00544 for (i=0; i<num_to_solve_; i++) {
00545 index[i] = i;
00546 }
00547 tptr = MVT::CloneView( *CurX_, index );
00548
00549
00550
00551 for (i=0; i<num_to_solve_; i++) {
00552 index[i] = rhs_index_ + i;
00553 }
00554 MVT::SetBlock( *tptr, index, *X_ );
00555 }
00556
00557
00558
00559 solutionFinal_ = true;
00560 rhs_index_ += num_to_solve_;
00561 }
00562
00563 template <class ScalarType, class MV, class OP>
00564 void LinearProblem<ScalarType,MV,OP>::SolutionUpdated( const MV* SolnUpdate, ScalarType scale )
00565 {
00566 if (SolnUpdate) {
00567 if (Right_Prec_) {
00568
00569
00570 RefCountPtr<MV> TrueUpdate = MVT::Clone( *SolnUpdate, MVT::GetNumberVecs( *SolnUpdate ) );
00571 OPT::Apply( *RP_, *SolnUpdate, *TrueUpdate );
00572 MVT::MvAddMv( 1.0, *CurX_, scale, *TrueUpdate, *CurX_ );
00573 } else {
00574 MVT::MvAddMv( 1.0, *CurX_, scale, *SolnUpdate, *CurX_ );
00575 }
00576 }
00577 solutionUpdated_ = true;
00578 }
00579
00580 template <class ScalarType, class MV, class OP>
00581 void LinearProblem<ScalarType,MV,OP>::Reset( const RefCountPtr<MV> &newX, const RefCountPtr<const MV> &newB )
00582 {
00583 solutionUpdated_ = false;
00584 solutionFinal_ = true;
00585 initresidsComputed_ = false;
00586 rhs_index_ = 0;
00587
00588 X_ = newX;
00589 B_ = newB;
00590 GetInitResVec();
00591 }
00592
00593 template <class ScalarType, class MV, class OP>
00594 const MV& LinearProblem<ScalarType,MV,OP>::GetInitResVec()
00595 {
00596
00597
00598
00599
00600 if (!initresidsComputed_ && A_.get() && X_.get() && B_.get())
00601 {
00602 if (R0_.get()) R0_ = null;
00603 R0_ = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00604 OPT::Apply( *A_, *X_, *R0_ );
00605 MVT::MvAddMv( 1.0, *B_, -1.0, *R0_, *R0_ );
00606 initresidsComputed_ = true;
00607 }
00608 return (*R0_);
00609 }
00610
00611 template <class ScalarType, class MV, class OP>
00612 const MV& LinearProblem<ScalarType,MV,OP>::GetCurrResVec( const MV* CurrSoln )
00613 {
00614
00615
00616
00617
00618
00619
00620 if (solutionUpdated_)
00621 {
00622 OPT::Apply( *A_, *GetCurrLHSVec(), *R_ );
00623 MVT::MvAddMv( 1.0, *GetCurrRHSVec(), -1.0, *R_, *R_ );
00624 solutionUpdated_ = false;
00625 }
00626 else if (CurrSoln)
00627 {
00628 OPT::Apply( *A_, *CurrSoln, *R_ );
00629 MVT::MvAddMv( 1.0, *GetCurrRHSVec(), -1.0, *R_, *R_ );
00630 }
00631 return (*R_);
00632 }
00633
00634 template <class ScalarType, class MV, class OP>
00635 RefCountPtr<MV> LinearProblem<ScalarType,MV,OP>::GetCurrLHSVec()
00636 {
00637 if (solutionFinal_) {
00638 solutionFinal_ = false;
00639 SetUpBlocks();
00640 }
00641 return CurX_;
00642 }
00643
00644 template <class ScalarType, class MV, class OP>
00645 RefCountPtr<MV> LinearProblem<ScalarType,MV,OP>::GetCurrRHSVec()
00646 {
00647 if (solutionFinal_) {
00648 solutionFinal_ = false;
00649 SetUpBlocks();
00650 }
00651 return CurB_;
00652 }
00653
00654 template <class ScalarType, class MV, class OP>
00655 ReturnType LinearProblem<ScalarType,MV,OP>::Apply( const MV& x, MV& y )
00656 {
00657 RefCountPtr<MV> ytemp = MVT::Clone( y, MVT::GetNumberVecs( y ) );
00658
00659
00660
00661 if (!Left_Prec_ && !Right_Prec_){ OPT::Apply( *A_, x, y );}
00662
00663
00664
00665 else if( Left_Prec_ && Right_Prec_ )
00666 {
00667 OPT::Apply( *RP_, x, y );
00668 OPT::Apply( *A_, y, *ytemp );
00669 OPT::Apply( *LP_, *ytemp, y );
00670 }
00671
00672
00673
00674 else if( Left_Prec_ )
00675 {
00676 OPT::Apply( *A_, x, *ytemp );
00677 OPT::Apply( *LP_, *ytemp, y );
00678 }
00679
00680
00681
00682 else
00683 {
00684 OPT::Apply( *RP_, x, *ytemp );
00685 OPT::Apply( *A_, *ytemp, y );
00686 }
00687 return Ok;
00688 }
00689
00690 template <class ScalarType, class MV, class OP>
00691 ReturnType LinearProblem<ScalarType,MV,OP>::ApplyOp( const MV& x, MV& y )
00692 {
00693 if (A_.get())
00694 return ( OPT::Apply( *A_,x, y) );
00695 else
00696 return Undefined;
00697 }
00698
00699 template <class ScalarType, class MV, class OP>
00700 ReturnType LinearProblem<ScalarType,MV,OP>::ApplyLeftPrec( const MV& x, MV& y )
00701 {
00702 if (Left_Prec_)
00703 return ( OPT::Apply( *LP_,x, y) );
00704 else
00705 return Undefined;
00706 }
00707
00708 template <class ScalarType, class MV, class OP>
00709 ReturnType LinearProblem<ScalarType,MV,OP>::ApplyRightPrec( const MV& x, MV& y )
00710 {
00711 if (Right_Prec_)
00712 return ( OPT::Apply( *RP_,x, y) );
00713 else
00714 return Undefined;
00715 }
00716
00717 template <class ScalarType, class MV, class OP>
00718 ReturnType LinearProblem<ScalarType,MV,OP>::ComputeResVec( MV* R, const MV* X, const MV* B )
00719 {
00720 if (X && B)
00721 {
00722 if (Left_Prec_)
00723 {
00724 RefCountPtr<MV> R_temp = MVT::Clone( *X, MVT::GetNumberVecs( *X ) );
00725 OPT::Apply( *A_, *X, *R_temp );
00726 MVT::MvAddMv( -1.0, *R_temp, 1.0, *B, *R_temp );
00727 OPT::Apply( *LP_, *R_temp, *R );
00728 }
00729 else
00730 {
00731 OPT::Apply( *A_, *X, *R );
00732 MVT::MvAddMv( -1.0, *R, 1.0, *B, *R );
00733 }
00734 }
00735 else {
00736
00737
00738 if (Left_Prec_)
00739 {
00740 RefCountPtr<MV> R_temp = MVT::Clone( *X_, MVT::GetNumberVecs( *X_ ) );
00741 OPT::Apply( *A_, *X_, *R_temp );
00742 MVT::MvAddMv( -1.0, *R_temp, 1.0, *B_, *R_temp );
00743 OPT::Apply( *LP_, *R_temp, *R );
00744 }
00745 else
00746 {
00747 OPT::Apply( *A_, *X_, *R );
00748 MVT::MvAddMv( -1.0, *R, 1.0, *B_, *R );
00749 }
00750 }
00751 return Ok;
00752 }
00753
00754 }
00755
00756 #endif
00757
00758