MPITridiagLinearOp.hpp

00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //    Thyra: Interfaces and Support for Abstract Numerical Algorithms
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef THYRA_MPI_TRIDIAG_LINEAR_OP_HPP
00030 #define THYRA_MPI_TRIDIAG_LINEAR_OP_HPP
00031 
00032 #include "Thyra_MPILinearOpBase.hpp"
00033 #include "Teuchos_PrimitiveTypeTraits.hpp"
00034 #include "Teuchos_RawMPITraits.hpp"
00035 
00133 template<class Scalar>
00134 class MPITridiagLinearOp : public Thyra::MPILinearOpBase<Scalar> {
00135 private:
00136 
00137   MPI_Comm             mpiComm_;
00138   int                  procRank_;
00139   int                  numProc_;
00140   Thyra::Index         localDim_;
00141   std::vector<Scalar>  lower_;   // size = ( procRank == 0         ? localDim - 1 : localDim )    
00142   std::vector<Scalar>  diag_;    // size = localDim
00143   std::vector<Scalar>  upper_;   // size = ( procRank == numProc-1 ? localDim - 1 : localDim )
00144 
00145   void communicate( const bool first, const bool last, const Scalar x[], Scalar *x_km1, Scalar *x_kp1 ) const;
00146 
00147 public:
00148 
00150   using Thyra::MPILinearOpBase<Scalar>::euclideanApply;
00151 
00153   MPITridiagLinearOp() : mpiComm_(MPI_COMM_NULL), procRank_(0), numProc_(0) {}
00154 
00156   MPITridiagLinearOp( MPI_Comm mpiComm, const Thyra::Index localDim, const Scalar lower[], const Scalar diag[], const Scalar upper[] )
00157     { this->initialize(mpiComm,localDim,lower,diag,upper);  }
00158   
00176   void initialize(
00177     MPI_Comm                        mpiComm
00178     ,const Thyra::Index             localDim // >= 2
00179     ,const Scalar                   lower[]  // size == ( procRank == 0         ? localDim - 1 : localDim )
00180     ,const Scalar                   diag[]   // size == localDim
00181     ,const Scalar                   upper[]  // size == ( procRank == numProc-1 ? localDim - 1 : localDim )
00182     )
00183     {
00184       TEST_FOR_EXCEPT( localDim < 2 );
00185       this->setLocalDimensions(mpiComm,localDim,localDim); // We must tell the base class our local dimensions to setup range() and domain()
00186       mpiComm_  = mpiComm;
00187       localDim_ = localDim;
00188       MPI_Comm_size( mpiComm, &numProc_ );
00189       MPI_Comm_rank( mpiComm, &procRank_ );
00190       const Thyra::Index
00191         lowerDim = ( procRank_ == 0          ? localDim - 1 : localDim ),
00192         upperDim = ( procRank_ == numProc_-1 ? localDim - 1 : localDim );
00193       lower_.resize(lowerDim);  for( int k = 0; k < lowerDim; ++k ) lower_[k] = lower[k];
00194       diag_.resize(localDim);   for( int k = 0; k < localDim; ++k ) diag_[k]  = diag[k];
00195       upper_.resize(upperDim);  for( int k = 0; k < upperDim; ++k ) upper_[k] = upper[k];
00196     }
00197 
00198   // Overridden form Teuchos::Describable */
00199 
00200   std::string description() const
00201     {
00202       return (std::string("MPITridiagLinearOp<") + Teuchos::ScalarTraits<Scalar>::name() + std::string(">"));
00203     }
00204 
00205 protected:
00206 
00207 
00208   // Overridden from SingleScalarEuclideanLinearOpBase
00209 
00210   bool opSupported( Thyra::ETransp M_trans ) const
00211     {
00212       typedef Teuchos::ScalarTraits<Scalar> ST;
00213       return (M_trans == Thyra::NOTRANS || (!ST::isComplex && M_trans == Thyra::CONJ) );
00214     }
00215 
00216   // Overridden from SerialLinearOpBase
00217 
00218   void euclideanApply(
00219     const Thyra::ETransp                         M_trans
00220     ,const RTOpPack::SubVectorT<Scalar>          &local_x_in
00221     ,const RTOpPack::MutableSubVectorT<Scalar>   *local_y_out
00222     ,const Scalar                                alpha
00223     ,const Scalar                                beta
00224     ) const
00225     {
00226       typedef Teuchos::ScalarTraits<Scalar> ST;
00227       TEST_FOR_EXCEPTION( M_trans != Thyra::NOTRANS, std::logic_error, "Error, can not handle transpose!" );
00228       // Get constants
00229       const Scalar zero = ST::zero();
00230       // Get raw pointers to vector data to make me feel better!
00231       const Scalar *x = local_x_in.values();
00232       Scalar       *y = local_y_out->values();
00233       // Determine what processor we are
00234       const bool first = ( procRank_ == 0 ), last = ( procRank_ == numProc_-1 );
00235       // Communicate ghost elements
00236       Scalar x_km1, x_kp1;
00237       communicate( first, last, x, &x_km1, &x_kp1 );
00238       // Perform operation (if beta==0 then we must be careful since y could be uninitialized on input!)
00239       Thyra::Index k = 0, lk = 0;
00240       if( beta == zero ) {
00241         y[k] = alpha * ( (first?zero:lower_[lk]*x_km1) + diag_[k]*x[k] + upper_[k]*x[k+1] ); if(!first) ++lk;             // First local row
00242         for( k = 1; k < localDim_ - 1; ++lk, ++k )
00243           y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + upper_[k]*x[k+1] );                                        // Middle local rows
00244         y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + (last?zero:upper_[k]*x_kp1) );                               // Last local row
00245       }
00246       else {
00247         y[k] = alpha * ( (first?zero:lower_[lk]*x_km1) + diag_[k]*x[k] + upper_[k]*x[k+1] ) + beta*y[k]; if(!first) ++lk; // First local row
00248         for( k = 1; k < localDim_ - 1; ++lk, ++k )
00249           y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + upper_[k]*x[k+1] ) + beta*y[k];                            // Middle local rows
00250         y[k] = alpha * ( lower_[lk]*x[k-1] + diag_[k]*x[k] + (last?zero:upper_[k]*x_kp1) ) + beta*y[k];                   // Last local row
00251       }
00252       //std::cout << "\ny = ["; for(k=0;k<localDim_;++k) { std::cout << y[k]; if(k<localDim_-1) std::cout << ","; } std::cout << "]\n";
00253     }
00254 
00255 };  // end class MPITridiagLinearOp
00256 
00257 // private
00258 
00259 template<class Scalar>
00260 void MPITridiagLinearOp<Scalar>::communicate(
00261   const bool first, const bool last, const Scalar x[], Scalar *x_km1, Scalar *x_kp1
00262   ) const
00263 {
00264   if(numProc_ > 1 ) {
00265     // Get the types so allow interaction with MPI
00266     typedef Teuchos::PrimitiveTypeTraits<Scalar>  PTT;
00267     typedef typename PTT::primitiveType           PT;
00268     typedef Teuchos::RawMPITraits<PT>             PRMT;
00269     MPI_Datatype primMPIType = PRMT::type();
00270     const int numPrimObjs = PTT::numPrimitiveObjs();
00271     MPI_Status status;
00272     // Setup buffer
00273     std::vector<PT> buff(numPrimObjs);
00274     // Send and receive x[localDim_-1] forward and copy into x_km1
00275     if(last) {
00276       MPI_Recv( &buff[0], PRMT::adjustCount(numPrimObjs), primMPIType, procRank_-1, 0, mpiComm_, &status );
00277       PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_km1 );
00278     }
00279     else {
00280       PTT::extractPrimitiveObjs( x[localDim_-1], numPrimObjs, &buff[0] );
00281       if(first) {
00282         MPI_Send( &buff[0], numPrimObjs, primMPIType, procRank_+1, 0, mpiComm_ );
00283       }
00284       else {
00285         MPI_Sendrecv_replace( &buff[0], numPrimObjs, primMPIType, procRank_+1, 0, procRank_-1, 0, mpiComm_, &status );
00286         PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_km1 );
00287       }
00288     }
00289     // Send and receive x[0] backward and copy into x_kp1
00290     if(first) {
00291       MPI_Recv( &buff[0], numPrimObjs, primMPIType, procRank_+1, 0, mpiComm_, &status );
00292       PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_kp1 );
00293     }
00294     else {
00295       PTT::extractPrimitiveObjs( x[0], numPrimObjs, &buff[0] );
00296       if(last) {
00297         MPI_Send( &buff[0], numPrimObjs, primMPIType, procRank_-1, 0, mpiComm_ );
00298       }
00299       else {
00300         MPI_Sendrecv_replace( &buff[0], numPrimObjs, primMPIType, procRank_-1, 0, procRank_+1, 0, mpiComm_, &status );
00301         PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_kp1 );
00302       }
00303     }
00304   }
00305 }
00306 
00307 #endif  // THYRA_MPI_TRIDIAG_LINEAR_OP_HPP

Generated on Thu Sep 18 12:39:52 2008 for Thyra ANA Operator/VectorBase Interfaces and Related Software by doxygen 1.3.9.1