#ifndef THYRA_MPI_TRIDIAG_LINEAR_OP_HPP
#define THYRA_MPI_TRIDIAG_LINEAR_OP_HPP
#include "Thyra_MPILinearOpBase.hpp"
#include "Teuchos_PrimitiveTypeTraits.hpp"
#include "Teuchos_RawMPITraits.hpp"
/** \brief Distributed-memory tridiagonal linear operator built on
 * <tt>Thyra::MPILinearOpBase</tt>.
 *
 * Each process owns <tt>localDim</tt> consecutive rows of the operator and
 * keeps private copies of the three diagonals for those rows:
 *
 * <ul>
 * <li><tt>diag_</tt> holds <tt>localDim</tt> entries, one per local row.
 * <li><tt>lower_</tt> holds one entry per local row, except on rank 0 where
 *     the first local row has no sub-diagonal entry (<tt>localDim-1</tt>
 *     entries).
 * <li><tt>upper_</tt> holds one entry per local row, except on the last rank
 *     where the last local row has no super-diagonal entry
 *     (<tt>localDim-1</tt> entries).
 * </ul>
 *
 * Applying the operator exchanges one boundary element of the local input
 * vector with each neighboring rank (see <tt>communicate()</tt>), so every
 * process in the communicator must take part in each apply.
 */
template<class Scalar>
class MPITridiagLinearOp : public Thyra::MPILinearOpBase<Scalar> {
private:

  MPI_Comm             mpiComm_;   // Communicator the rows are distributed over
  int                  procRank_;  // Rank of this process in mpiComm_
  int                  numProc_;   // Number of processes in mpiComm_
  Thyra::Index         localDim_;  // Number of rows owned by this process
  std::vector<Scalar>  lower_;     // Local sub-diagonal entries
  std::vector<Scalar>  diag_;      // Local diagonal entries
  std::vector<Scalar>  upper_;     // Local super-diagonal entries

  // Fetch the neighbor elements needed by the first and last local rows.
  void communicate( const bool first, const bool last, const Scalar x[], Scalar *x_km1, Scalar *x_kp1 ) const;

public:

  // Keep the base-class overloads of euclideanApply() visible.
  using Thyra::MPILinearOpBase<Scalar>::euclideanApply;

  /** \brief Construct to an uninitialized state; call <tt>initialize()</tt>
   * before use. */
  MPITridiagLinearOp() : mpiComm_(MPI_COMM_NULL), procRank_(0), numProc_(0), localDim_(0) {}

  /** \brief Construct and initialize; forwards to <tt>initialize()</tt>. */
  MPITridiagLinearOp( MPI_Comm mpiComm, const Thyra::Index localDim, const Scalar lower[], const Scalar diag[], const Scalar upper[] )
    { this->initialize(mpiComm,localDim,lower,diag,upper); }

  /** \brief Set the communicator and copy in the local diagonals.
   *
   * \param  mpiComm   Communicator over which the rows are distributed.
   * \param  localDim  Number of rows owned by this process (at least 2).
   * \param  lower     Local sub-diagonal: <tt>localDim-1</tt> entries on
   *                   rank 0, <tt>localDim</tt> entries elsewhere.
   * \param  diag      Local diagonal: <tt>localDim</tt> entries.
   * \param  upper     Local super-diagonal: <tt>localDim-1</tt> entries on
   *                   the last rank, <tt>localDim</tt> entries elsewhere.
   *
   * The arrays are copied, so the caller keeps ownership of them.  An
   * exception is raised (through <tt>TEST_FOR_EXCEPT</tt>) if
   * <tt>localDim < 2</tt> or if any array pointer is null.
   */
  void initialize(
    MPI_Comm                      mpiComm
    ,const Thyra::Index           localDim
    ,const Scalar                 lower[]
    ,const Scalar                 diag[]
    ,const Scalar                 upper[]
    )
    {
      // Validate everything before touching any state.
      TEST_FOR_EXCEPT( localDim < 2 );
      TEST_FOR_EXCEPT( lower == 0 || diag == 0 || upper == 0 );
      this->setLocalDimensions(mpiComm,localDim,localDim);
      mpiComm_  = mpiComm;
      localDim_ = localDim;
      MPI_Comm_size( mpiComm, &numProc_ );
      MPI_Comm_rank( mpiComm, &procRank_ );
      // The global first row has no sub-diagonal entry and the global last
      // row has no super-diagonal entry, so the boundary ranks store one
      // fewer element in the corresponding array.
      const Thyra::Index lowerDim = ( procRank_ == 0          ? localDim - 1 : localDim );
      const Thyra::Index upperDim = ( procRank_ == numProc_-1 ? localDim - 1 : localDim );
      lower_.assign( lower, lower + lowerDim );
      diag_.assign(  diag,  diag  + localDim );
      upper_.assign( upper, upper + upperDim );
    }

  /** \brief Return <tt>"MPITridiagLinearOp<ScalarName>"</tt>. */
  std::string description() const
    {
      return (std::string("MPITridiagLinearOp<") + Teuchos::ScalarTraits<Scalar>::name() + std::string(">"));
    }

protected:

  /** \brief Only the non-transposed apply is supported; for real scalar
   * types <tt>CONJ</tt> is equivalent and therefore also supported. */
  bool opSupported( Thyra::ETransp M_trans ) const
    {
      typedef Teuchos::ScalarTraits<Scalar> ST;
      return (M_trans == Thyra::NOTRANS || (!ST::isComplex && M_trans == Thyra::CONJ) );
    }

  /** \brief Compute <tt>y = alpha*M*x + beta*y</tt> on the local rows.
   *
   * \param  M_trans      Must be a mode for which <tt>opSupported()</tt>
   *                      returns true.
   * \param  local_x_in   Local part of the input vector.
   * \param  local_y_out  Local part of the output vector (must be non-null).
   * \param  alpha        Scaling applied to <tt>M*x</tt>.
   * \param  beta         Scaling applied to the incoming <tt>y</tt>; when it
   *                      is zero the incoming <tt>y</tt> is never read.
   *
   * NOTE(review): elements are addressed as <tt>values()[k]</tt>, which
   * presumably requires unit stride and <tt>localDim</tt> elements in both
   * sub-vectors -- confirm against the caller.
   */
  void euclideanApply(
    const Thyra::ETransp                         M_trans
    ,const RTOpPack::SubVectorT<Scalar>          &local_x_in
    ,const RTOpPack::MutableSubVectorT<Scalar>   *local_y_out
    ,const Scalar                                alpha
    ,const Scalar                                beta
    ) const
    {
      typedef Teuchos::ScalarTraits<Scalar> ST;
      // Accept exactly the modes advertised by opSupported().  For a real
      // scalar type CONJ leaves the matrix unchanged, so it shares the
      // NOTRANS arithmetic below.
      const bool supported =
        ( M_trans == Thyra::NOTRANS || (!ST::isComplex && M_trans == Thyra::CONJ) );
      TEST_FOR_EXCEPTION( !supported, std::logic_error, "Error, can not handle transpose!" );
      TEST_FOR_EXCEPT( local_y_out == 0 );
      // Guards against use of a default-constructed (uninitialized) operator.
      TEST_FOR_EXCEPT( localDim_ < 2 );

      const Scalar zero = ST::zero();
      const Scalar *x = local_x_in.values();
      Scalar       *y = local_y_out->values();

      const bool first = ( procRank_ == 0 );
      const bool last  = ( procRank_ == numProc_-1 );

      // Neighbor elements: x_km1 is the last element owned by the previous
      // rank, x_kp1 the first element owned by the next rank.  They are only
      // used when the corresponding neighbor exists.
      Scalar x_km1 = zero;
      Scalar x_kp1 = zero;
      communicate( first, last, x, &x_km1, &x_kp1 );

      // On rank 0 lower_ has no entry for local row 0, so the entry for
      // local row k lives at lower_[k-1]; elsewhere it lives at lower_[k].
      const Thyra::Index lowerSkip = ( first ? 1 : 0 );
      const Thyra::Index kLast     = localDim_ - 1;
      Thyra::Index k = 0;

      if( beta == zero ) {
        // Pure overwrite: do not read y so stale contents cannot leak in.
        y[0] = alpha * ( (first?zero:lower_[0]*x_km1) + diag_[0]*x[0] + upper_[0]*x[1] );
        for( k = 1; k < kLast; ++k )
          y[k] = alpha * ( lower_[k-lowerSkip]*x[k-1] + diag_[k]*x[k] + upper_[k]*x[k+1] );
        y[kLast] = alpha * ( lower_[kLast-lowerSkip]*x[kLast-1] + diag_[kLast]*x[kLast] + (last?zero:upper_[kLast]*x_kp1) );
      }
      else {
        y[0] = alpha * ( (first?zero:lower_[0]*x_km1) + diag_[0]*x[0] + upper_[0]*x[1] ) + beta*y[0];
        for( k = 1; k < kLast; ++k )
          y[k] = alpha * ( lower_[k-lowerSkip]*x[k-1] + diag_[k]*x[k] + upper_[k]*x[k+1] ) + beta*y[k];
        y[kLast] = alpha * ( lower_[kLast-lowerSkip]*x[kLast-1] + diag_[kLast]*x[kLast] + (last?zero:upper_[kLast]*x_kp1) ) + beta*y[kLast];
      }
    }

};
/** \brief Exchange boundary elements of the local vector with the
 * neighboring ranks.
 *
 * \param  first  True on rank 0 (there is no previous rank).
 * \param  last   True on the highest rank (there is no next rank).
 * \param  x      Local vector elements; <tt>x[0]</tt> and
 *                <tt>x[localDim_-1]</tt> are the values sent out.
 * \param  x_km1  Receives the last element of the previous rank; left
 *                untouched when <tt>first</tt> is true.
 * \param  x_kp1  Receives the first element of the next rank; left
 *                untouched when <tt>last</tt> is true.
 *
 * Nothing is done when there is only one process.  Each scalar is packed
 * into a buffer of its primitive type before it is handed to MPI, and every
 * MPI call uses the same (adjusted) element count so matching sends and
 * receives always agree.  The calls are blocking, so all ranks of
 * <tt>mpiComm_</tt> must enter this function together.
 */
template<class Scalar>
void MPITridiagLinearOp<Scalar>::communicate(
  const bool first, const bool last, const Scalar x[], Scalar *x_km1, Scalar *x_kp1
  ) const
{
  if( numProc_ <= 1 )
    return; // No neighbors to talk to

  typedef Teuchos::PrimitiveTypeTraits<Scalar> PTT;
  typedef typename PTT::primitiveType          PT;
  typedef Teuchos::RawMPITraits<PT>            PRMT;

  MPI_Datatype primMPIType = PRMT::type();
  const int numPrimObjs = PTT::numPrimitiveObjs();
  // Count handed to MPI; computed once so every send/receive pair matches.
  const int mpiCount = PRMT::adjustCount(numPrimObjs);
  const int prevRank = procRank_ - 1;
  const int nextRank = procRank_ + 1;
  MPI_Status status;
  std::vector<PT> buff(numPrimObjs);

  //
  // Phase 1: pass x[localDim_-1] forward to the next rank and receive the
  // previous rank's value into *x_km1.
  //
  if(last) {
    // Nothing to send forward; only receive from the previous rank.
    MPI_Recv( &buff[0], mpiCount, primMPIType, prevRank, 0, mpiComm_, &status );
    PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_km1 );
  }
  else {
    PTT::extractPrimitiveObjs( x[localDim_-1], numPrimObjs, &buff[0] );
    if(first) {
      // Nothing to receive; only send to the next rank.
      MPI_Send( &buff[0], mpiCount, primMPIType, nextRank, 0, mpiComm_ );
    }
    else {
      // Interior rank: send forward and receive from behind in one call,
      // reusing the same buffer.
      MPI_Sendrecv_replace( &buff[0], mpiCount, primMPIType, nextRank, 0, prevRank, 0, mpiComm_, &status );
      PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_km1 );
    }
  }

  //
  // Phase 2: pass x[0] backward to the previous rank and receive the next
  // rank's value into *x_kp1.
  //
  if(first) {
    // Nothing to send backward; only receive from the next rank.
    MPI_Recv( &buff[0], mpiCount, primMPIType, nextRank, 0, mpiComm_, &status );
    PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_kp1 );
  }
  else {
    PTT::extractPrimitiveObjs( x[0], numPrimObjs, &buff[0] );
    if(last) {
      // Nothing to receive; only send to the previous rank.
      MPI_Send( &buff[0], mpiCount, primMPIType, prevRank, 0, mpiComm_ );
    }
    else {
      // Interior rank: send backward and receive from ahead in one call.
      MPI_Sendrecv_replace( &buff[0], mpiCount, primMPIType, prevRank, 0, nextRank, 0, mpiComm_, &status );
      PTT::loadPrimitiveObjs( numPrimObjs, &buff[0], x_kp1 );
    }
  }
}
#endif // THYRA_MPI_TRIDIAG_LINEAR_OP_HPP