Anasazi Version of the Day
Tsqr_MpiMessenger.hpp
00001 #ifndef __TSQR_MpiMessenger_hpp
00002 #define __TSQR_MpiMessenger_hpp
00003 
00004 #include <mpi.h>
00005 #include <Tsqr_MessengerBase.hpp>
00006 #include <Tsqr_MpiDatatype.hpp>
00007 #include <stdexcept>
00008 
00011 
00012 namespace TSQR {
00013   namespace MPI {
00014 
00015     template< class Datum >
00016     class MpiMessenger : public TSQR::MessengerBase< Datum > {
00017     public:
00018       MpiMessenger (MPI_Comm comm) : comm_ (comm) {}
00019       virtual ~MpiMessenger () {}
00020 
00021       virtual void 
00022       send (const Datum sendData[], 
00023       const int sendCount, 
00024       const int destProc, 
00025       const int tag)
00026       {
00027   const int err = 
00028     MPI_Send (const_cast< Datum* const >(sendData), sendCount, 
00029         mpiType_.get(), destProc, tag, comm_);
00030   if (err != MPI_SUCCESS)
00031     throw std::runtime_error ("MPI_Send failed");
00032       }
00033 
00034       virtual void 
00035       recv (Datum recvData[], 
00036       const int recvCount, 
00037       const int srcProc, 
00038       const int tag)
00039       {
00040   MPI_Status status;
00041   const int err = MPI_Recv (recvData, recvCount, mpiType_.get(), 
00042           srcProc, tag, comm_, &status);
00043   if (err != MPI_SUCCESS)
00044     throw std::runtime_error ("MPI_Recv failed");
00045       }
00046 
00047       virtual void 
00048       swapData (const Datum sendData[], 
00049     Datum recvData[], 
00050     const int sendRecvCount, 
00051     const int destProc, 
00052     const int tag)
00053       {
00054   MPI_Status status;
00055   const int err = 
00056     MPI_Sendrecv (const_cast< Datum* const >(sendData), sendRecvCount, 
00057       mpiType_.get(), destProc, tag,
00058       recvData, sendRecvCount, mpiType_.get(), destProc, tag,
00059       comm_, &status);
00060   if (err != MPI_SUCCESS)
00061     throw std::runtime_error ("MPI_Sendrecv failed");
00062       }
00063 
00064       virtual Datum 
00065       globalSum (const Datum& inDatum)
00066       {
00067   Datum input (inDatum);
00068   Datum output;
00069 
00070   int count = 1;
00071   const int err = MPI_Allreduce (&input, &output, count, 
00072                mpiType_.get(), MPI_SUM, comm_);
00073   if (err != MPI_SUCCESS)
00074     throw std::runtime_error ("MPI_Allreduce (MPI_SUM) failed");
00075   return output;
00076       }
00077 
00078       virtual Datum 
00079       globalMin (const Datum& inDatum)
00080       {
00081   Datum input (inDatum);
00082   Datum output;
00083 
00084   int count = 1;
00085   const int err = MPI_Allreduce (&input, &output, count, 
00086                mpiType_.get(), MPI_MIN, comm_);
00087   if (err != MPI_SUCCESS)
00088     throw std::runtime_error ("MPI_Allreduce (MPI_MIN) failed");
00089   return output;
00090       }
00091 
00092       virtual Datum 
00093       globalMax (const Datum& inDatum)
00094       {
00095   Datum input (inDatum);
00096   Datum output;
00097 
00098   int count = 1;
00099   const int err = MPI_Allreduce (&input, &output, count, 
00100                mpiType_.get(), MPI_MAX, comm_);
00101   if (err != MPI_SUCCESS)
00102     throw std::runtime_error ("MPI_Allreduce (MPI_MAX) failed");
00103   return output;
00104       }
00105 
00106       virtual void
00107       globalVectorSum (const Datum inData[], 
00108            Datum outData[], 
00109            const int count)
00110       {
00111   const int err = 
00112     MPI_Allreduce (const_cast< Datum* const > (inData), outData, 
00113        count, mpiType_.get(), MPI_SUM, comm_);
00114   if (err != MPI_SUCCESS)
00115     throw std::runtime_error ("MPI_Allreduce failed");
00116       }
00117 
00118       virtual void
00119       broadcast (Datum data[], 
00120      const int count,
00121      const int root)
00122       {
00123   const int err = MPI_Bcast (data, count, mpiType_.get(), root, comm_);
00124   if (err != MPI_SUCCESS)
00125     throw std::runtime_error ("MPI_Bcast failed");
00126       }
00127 
00128       virtual int 
00129       size() const
00130       {
00131   int nprocs = 0;
00132   const int err = MPI_Comm_size (comm_, &nprocs);
00133   if (err != MPI_SUCCESS)
00134     throw std::runtime_error ("MPI_Comm_size failed");
00135   else if (nprocs <= 0)
00136     throw std::runtime_error ("MPI_Comm_size returned # processors <= 0");
00137   return nprocs;
00138       }
00139 
00140       virtual int 
00141       rank() const
00142       {
00143   int my_rank = 0;
00144   const int err = MPI_Comm_rank (comm_, &my_rank);
00145   if (err != MPI_SUCCESS)
00146     throw std::runtime_error ("MPI_Comm_rank failed");
00147   return my_rank;
00148       }
00149 
00150       virtual void 
00151       barrier () const 
00152       {
00153   const int err = MPI_Barrier (comm_);
00154   if (err != MPI_SUCCESS)
00155     throw std::runtime_error ("MPI_Barrier failed");
00156       }
00157 
00158     private:
00166       mutable MPI_Comm comm_; 
00167 
00168       MpiDatatype< Datum > mpiType_;
00169     };
00170   } // namespace MPI
00171 } // namespace TSQR
00172 
00173 #endif // __TSQR_MpiMessenger_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends