00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef THYRA_MPI_TRIDIAG_LINEAR_OP_HPP
00030 #define THYRA_MPI_TRIDIAG_LINEAR_OP_HPP
00031
00032 #include "Thyra_MPILinearOpBase.hpp"
00033 #include "Teuchos_PrimitiveTypeTraits.hpp"
00034 #include "Teuchos_RawMPITraits.hpp"
00035
00133 template<class Scalar>
00134 class MPITridiagLinearOp : public Thyra::MPILinearOpBase<Scalar> {
00135 private:
00136
00137 MPI_Comm mpiComm_;
00138 int procRank_;
00139 int numProc_;
00140 Thyra::Index localDim_;
00141 std::vector<Scalar> lower_;
00142 std::vector<Scalar> diag_;
00143 std::vector<Scalar> upper_;
00144
00145 void communicate( const bool first, const bool last, const Scalar x[], Scalar *x_km1, Scalar *x_kp1 ) const;
00146
00147 public:
00148
00150 using Thyra::MPILinearOpBase<Scalar>::euclideanApply;
00151
00153 MPITridiagLinearOp() : mpiComm_(MPI_COMM_NULL), procRank_(0), numProc_(0) {}
00154
00156 MPITridiagLinearOp( MPI_Comm mpiComm, const Thyra::Index localDim, const Scalar lower[], const Scalar diag[], const Scalar upper[] )
00157 { this->initialize(mpiComm,localDim,lower,diag,upper); }
00158
00176 void initialize(
00177 MPI_Comm mpiComm
00178 ,const Thyra::Index localDim
00179 ,const Scalar lower[]
00180 ,const Scalar diag[]
00181 ,const Scalar upper[]
00182 )
00183 {
00184 TEST_FOR_EXCEPT( localDim < 2 );
00185 this->setLocalDimensions(mpiComm,localDim,localDim);
00186 mpiComm_ = mpiComm;
00187 localDim_ = localDim;
00188 MPI_Comm_size( mpiComm, &numProc_ );
00189 MPI_Comm_rank( mpiComm, &procRank_ );
00190 const Thyra::Index
00191 lowerDim = ( procRank_ == 0 ? localDim - 1 : localDim ),
00192 upperDim = ( procRank_ == numProc_-1 ? localDim - 1 : localDim );
00193 lower_.resize(lowerDim); for( int k = 0; k < lowerDim; ++k ) lower_[k] = lower[k];
00194 diag_.resize(localDim); for( int k = 0; k < localDim; ++k ) diag_[k] = diag[k];
00195 upper_.resize(upperDim); for( int k = 0; k < upperDim; ++k ) upper_[k] = upper[k];
00196 }
00197
00198
00199
00200 std::string description() const
00201 {
00202 return (std::string("MPITridiagLinearOp<") + Teuchos::ScalarTraits<Scalar>::name() + std::string(">"));
00203 }
00204
00205 protected:
00206
00207
00208
00209
00210 bool opSupported( Thyra::ETransp M_trans ) const
00211 {
00212 typedef Teuchos::ScalarTraits<Scalar> ST;
00213 return (M_trans == Thyra::NOTRANS || (!ST::isComplex && M_trans == Thyra::CONJ) );
00214 }
00215
00216
00217
00218 void euclideanApply(
00219 const Thyra::ETransp M_trans
00220 ,const RTOpPack::SubVectorT<Scalar> &local_x_in
00221 ,const RTOpPack::MutableSubVectorT<Scalar> *local_y_out
00222 ,const Scalar alpha
00223 ,const Scalar beta
00224 ) const
00225 {
00226 typedef Teuchos::ScalarTraits<Scalar> ST;
00227 TEST_FOR_EXCEPTION( M_trans != Thyra::NOTRANS, std::logic_error, "Error, can not handle transpose!" );
00228
00229 const Scalar zero = ST::zero();
00230
00231 const Scalar *x = local_x_in.values();
00232 Scalar *y = local_y_out->values();
00233
00234 const bool first = ( procRank_ == 0 ), last = ( procRank_ == numProc_-1 );
00235
00236 Scalar x_km1, x_kp1;
00237 communicate( first, last, x, &x_km1, &x_kp1 );
00238
00239 Thyra::Index k = 0, lk = 0;
00240 if( beta == zero ) {
00241 y[k] = alpha * ( (first?zero:lower_[lk]*x_km1) + diag_[k]*x[k] + upper_[k]*x[k+1] ); if(!first) ++lk;
00242 for( k = 1; k < localDim_ - 1; ++lk, ++k )
00243 y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + upper_[k]*x[k+1] );
00244 y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + (last?zero:upper_[k]*x_kp1) );
00245 }
00246 else {
00247 y[k] = alpha * ( (first?zero:lower_[lk]*x_km1) + diag_[k]*x[k] + upper_[k]*x[k+1] ) + beta*y[k]; if(!first) ++lk;
00248 for( k = 1; k < localDim_ - 1; ++lk, ++k )
00249 y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + upper_[k]*x[k+1] ) + beta*y[k];
00250 y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + (last?zero:upper_[k]*x_kp1) ) + beta*y[k];
00251 }
00252
00253 }
00254
00255 };
00256
00257
00258
00259 template<class Scalar>
00260 void MPITridiagLinearOp<Scalar>::communicate(
00261 const bool first, const bool last, const Scalar x[], Scalar *x_km1, Scalar *x_kp1
00262 ) const
00263 {
00264 if(numProc_ > 1 ) {
00265
00266 typedef Teuchos::PrimitiveTypeTraits<Scalar> PTT;
00267 typedef typename PTT::primitiveType PT;
00268 typedef Teuchos::RawMPITraits<PT> PRMT;
00269 MPI_Datatype primMPIType = PRMT::type();
00270 const int numPrimObjs = PTT::numPrimitiveObjs();
00271 MPI_Status status;
00272
00273 std::vector<PT> buff(numPrimObjs);
00274
00275 if(last) {
00276 MPI_Recv( &buff[0], PRMT::adjustCount(numPrimObjs), primMPIType, procRank_-1, 0, mpiComm_, &status );
00277 PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_km1 );
00278 }
00279 else {
00280 PTT::extractPrimitiveObjs( x[localDim_-1], numPrimObjs, &buff[0] );
00281 if(first) {
00282 MPI_Send( &buff[0], numPrimObjs, primMPIType, procRank_+1, 0, mpiComm_ );
00283 }
00284 else {
00285 MPI_Sendrecv_replace( &buff[0], numPrimObjs, primMPIType, procRank_+1, 0, procRank_-1, 0, mpiComm_, &status );
00286 PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_km1 );
00287 }
00288 }
00289
00290 if(first) {
00291 MPI_Recv( &buff[0], numPrimObjs, primMPIType, procRank_+1, 0, mpiComm_, &status );
00292 PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_kp1 );
00293 }
00294 else {
00295 PTT::extractPrimitiveObjs( x[0], numPrimObjs, &buff[0] );
00296 if(last) {
00297 MPI_Send( &buff[0], numPrimObjs, primMPIType, procRank_-1, 0, mpiComm_ );
00298 }
00299 else {
00300 MPI_Sendrecv_replace( &buff[0], numPrimObjs, primMPIType, procRank_-1, 0, procRank_+1, 0, mpiComm_, &status );
00301 PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_kp1 );
00302 }
00303 }
00304 }
00305 }
00306
00307 #endif // THYRA_MPI_TRIDIAG_LINEAR_OP_HPP