Tsqr_MpiMessenger.hpp
#ifndef __TSQR_MpiMessenger_hpp
#define __TSQR_MpiMessenger_hpp

#include <Teuchos_ConfigDefs.hpp> // HAVE_MPI

#ifdef HAVE_MPI
#  include <mpi.h> // MPI_Comm, MPI_Send, MPI_Allreduce, etc.
#  include <Tsqr_MessengerBase.hpp>
#  include <Tsqr_MpiDatatype.hpp>
#  include <stdexcept>

namespace TSQR {
  namespace MPI {

    /// \brief MessengerBase< Datum > implementation in terms of raw MPI.
    ///
    /// Each method maps to the corresponding MPI call over the wrapped
    /// communicator, and throws std::runtime_error if that call does
    /// not return MPI_SUCCESS.
    template< class Datum >
    class MpiMessenger : public TSQR::MessengerBase< Datum > {
    public:
      MpiMessenger (MPI_Comm comm) : comm_ (comm) {}
      virtual ~MpiMessenger () {}

      virtual void
      send (const Datum sendData[],
            const int sendCount,
            const int destProc,
            const int tag)
      {
        // MPI_Send takes a non-const buffer pointer, so cast away const.
        const int err =
          MPI_Send (const_cast< Datum* > (sendData), sendCount,
                    mpiType_.get(), destProc, tag, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Send failed");
      }

      virtual void
      recv (Datum recvData[],
            const int recvCount,
            const int srcProc,
            const int tag)
      {
        MPI_Status status;
        const int err = MPI_Recv (recvData, recvCount, mpiType_.get(),
                                  srcProc, tag, comm_, &status);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Recv failed");
      }

      virtual void
      swapData (const Datum sendData[],
                Datum recvData[],
                const int sendRecvCount,
                const int destProc,
                const int tag)
      {
        MPI_Status status;
        const int err =
          MPI_Sendrecv (const_cast< Datum* > (sendData), sendRecvCount,
                        mpiType_.get(), destProc, tag,
                        recvData, sendRecvCount, mpiType_.get(), destProc, tag,
                        comm_, &status);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Sendrecv failed");
      }

      virtual Datum
      globalSum (const Datum& inDatum)
      {
        // Preserve const semantics of inDatum, by copying it and
        // using the copy in MPI_Allreduce().
        Datum input (inDatum);
        Datum output;

        const int count = 1;
        const int err = MPI_Allreduce (&input, &output, count,
                                       mpiType_.get(), MPI_SUM, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce (MPI_SUM) failed");
        return output;
      }

      virtual Datum
      globalMin (const Datum& inDatum)
      {
        // Preserve const semantics of inDatum, by copying it and
        // using the copy in MPI_Allreduce().
        Datum input (inDatum);
        Datum output;

        const int count = 1;
        const int err = MPI_Allreduce (&input, &output, count,
                                       mpiType_.get(), MPI_MIN, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce (MPI_MIN) failed");
        return output;
      }

      virtual Datum
      globalMax (const Datum& inDatum)
      {
        // Preserve const semantics of inDatum, by copying it and
        // using the copy in MPI_Allreduce().
        Datum input (inDatum);
        Datum output;

        const int count = 1;
        const int err = MPI_Allreduce (&input, &output, count,
                                       mpiType_.get(), MPI_MAX, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce (MPI_MAX) failed");
        return output;
      }

      virtual void
      globalVectorSum (const Datum inData[],
                       Datum outData[],
                       const int count)
      {
        const int err =
          MPI_Allreduce (const_cast< Datum* > (inData), outData,
                         count, mpiType_.get(), MPI_SUM, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce failed");
      }

      virtual void
      broadcast (Datum data[],
                 const int count,
                 const int root)
      {
        const int err = MPI_Bcast (data, count, mpiType_.get(), root, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Bcast failed");
      }

      virtual int
      rank() const
      {
        int my_rank = 0;
        const int err = MPI_Comm_rank (comm_, &my_rank);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Comm_rank failed");
        return my_rank;
      }

      virtual int
      size() const
      {
        int nprocs = 0;
        const int err = MPI_Comm_size (comm_, &nprocs);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Comm_size failed");
        else if (nprocs <= 0)
          // We want to make sure that there is always at least one
          // valid rank (at least rank() == 0).  The messenger can't
          // do anything useful with MPI_COMM_NULL.
          throw std::runtime_error ("MPI_Comm_size returned # processors <= 0");
        return nprocs;
      }

      virtual void
      barrier () const
      {
        const int err = MPI_Barrier (comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Barrier failed");
      }

    private:
      /// \brief The MPI communicator over which to communicate.
      ///
      /// Declared mutable so that const member functions (rank(),
      /// size(), barrier()) may pass it to the MPI routines.
      mutable MPI_Comm comm_;

      /// \brief Maps the template parameter Datum to the
      ///   corresponding MPI_Datatype (see Tsqr_MpiDatatype.hpp).
      MpiDatatype< Datum > mpiType_;
    };
  } // namespace MPI
} // namespace TSQR

#endif // HAVE_MPI
#endif // __TSQR_MpiMessenger_hpp
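
A minimal usage sketch follows. It is illustrative only: the driver program below (the choice of Datum = double, the MPI_Init / MPI_Finalize calls, and main() itself) is not part of the header, and it assumes an MPI build with HAVE_MPI defined and the TSQR headers on the include path.

// Hypothetical driver, not part of Tsqr_MpiMessenger.hpp.
#include <Tsqr_MpiMessenger.hpp>
#include <iostream>

int main (int argc, char* argv[])
{
  MPI_Init (&argc, &argv);
  {
    // Messenger over MPI_COMM_WORLD, communicating doubles.
    TSQR::MPI::MpiMessenger< double > messenger (MPI_COMM_WORLD);

    // Each rank contributes its rank number; with P ranks the
    // all-reduce returns (P-1)*P/2 on every rank.
    const double sumOfRanks =
      messenger.globalSum (static_cast< double > (messenger.rank ()));
    if (messenger.rank () == 0)
      std::cout << "Sum of ranks: " << sumOfRanks << std::endl;

    messenger.barrier ();
  } // Destroy the messenger (and its MpiDatatype) before MPI_Finalize.
  MPI_Finalize ();
  return 0;
}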