Tsqr_MpiMessenger.hpp
/*
//@HEADER
// ************************************************************************
//
//          Kokkos: Node API and Parallel Node Kernels
//              Copyright (2008) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

#ifndef __TSQR_MpiMessenger_hpp
#define __TSQR_MpiMessenger_hpp

#include <Teuchos_ConfigDefs.hpp> // HAVE_MPI

#ifdef HAVE_MPI
#  include <mpi.h> // MPI_Comm, MPI_Send, MPI_Recv, etc.
#  include <Tsqr_MessengerBase.hpp>
#  include <Tsqr_MpiDatatype.hpp>
#  include <stdexcept>

namespace TSQR {
  namespace MPI {
    template< class Datum >
    class MpiMessenger : public TSQR::MessengerBase< Datum > {
    public:
      MpiMessenger (MPI_Comm comm) : comm_ (comm) {}
      virtual ~MpiMessenger () {}

      virtual void
      send (const Datum sendData[],
            const int sendCount,
            const int destProc,
            const int tag)
      {
        // Older MPI interfaces are not const-correct, so we must cast
        // away the const-ness of sendData.
        const int err =
          MPI_Send (const_cast< Datum* > (sendData), sendCount,
                    mpiType_.get(), destProc, tag, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Send failed");
      }

      virtual void
      recv (Datum recvData[],
            const int recvCount,
            const int srcProc,
            const int tag)
      {
        MPI_Status status;
        const int err = MPI_Recv (recvData, recvCount, mpiType_.get(),
                                  srcProc, tag, comm_, &status);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Recv failed");
      }

      virtual void
      swapData (const Datum sendData[],
                Datum recvData[],
                const int sendRecvCount,
                const int destProc,
                const int tag)
      {
        // MPI_Sendrecv exchanges data with destProc in a single call,
        // avoiding the deadlock risk of a hand-rolled send + receive.
        MPI_Status status;
        const int err =
          MPI_Sendrecv (const_cast< Datum* > (sendData), sendRecvCount,
                        mpiType_.get(), destProc, tag,
                        recvData, sendRecvCount, mpiType_.get(), destProc, tag,
                        comm_, &status);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Sendrecv failed");
      }

      virtual Datum
      globalSum (const Datum& inDatum)
      {
        // Preserve const semantics of inDatum by copying it and
        // using the copy in MPI_Allreduce().
        Datum input (inDatum);
        Datum output;

        int count = 1;
        const int err = MPI_Allreduce (&input, &output, count,
                                       mpiType_.get(), MPI_SUM, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce (MPI_SUM) failed");
        return output;
      }

      virtual Datum
      globalMin (const Datum& inDatum)
      {
        // Preserve const semantics of inDatum by copying it and
        // using the copy in MPI_Allreduce().
        Datum input (inDatum);
        Datum output;

        int count = 1;
        const int err = MPI_Allreduce (&input, &output, count,
                                       mpiType_.get(), MPI_MIN, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce (MPI_MIN) failed");
        return output;
      }

      virtual Datum
      globalMax (const Datum& inDatum)
      {
        // Preserve const semantics of inDatum by copying it and
        // using the copy in MPI_Allreduce().
        Datum input (inDatum);
        Datum output;

        int count = 1;
        const int err = MPI_Allreduce (&input, &output, count,
                                       mpiType_.get(), MPI_MAX, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce (MPI_MAX) failed");
        return output;
      }

      virtual void
      globalVectorSum (const Datum inData[],
                       Datum outData[],
                       const int count)
      {
        const int err =
          MPI_Allreduce (const_cast< Datum* > (inData), outData,
                         count, mpiType_.get(), MPI_SUM, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Allreduce (MPI_SUM) failed");
      }

      virtual void
      broadcast (Datum data[],
                 const int count,
                 const int root)
      {
        const int err = MPI_Bcast (data, count, mpiType_.get(), root, comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Bcast failed");
      }

      virtual int
      rank () const
      {
        int my_rank = 0;
        const int err = MPI_Comm_rank (comm_, &my_rank);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Comm_rank failed");
        return my_rank;
      }

      virtual int
      size () const
      {
        int nprocs = 0;
        const int err = MPI_Comm_size (comm_, &nprocs);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Comm_size failed");
        else if (nprocs <= 0)
          // We want to make sure that there is always at least one
          // valid rank (at least rank() == 0).  The messenger can't
          // do anything useful with MPI_COMM_NULL.
          throw std::runtime_error ("MPI_Comm_size returned a number of processes <= 0");
        return nprocs;
      }

      virtual void
      barrier () const
      {
        const int err = MPI_Barrier (comm_);
        if (err != MPI_SUCCESS)
          throw std::runtime_error ("MPI_Barrier failed");
      }

    private:
      /// The MPI communicator over which to communicate.  Marked
      /// mutable because MPI's C interface is not const-correct, yet
      /// rank(), size(), and barrier() must be const.
      mutable MPI_Comm comm_;

      /// Maps the C++ type Datum to its corresponding MPI_Datatype.
      MpiDatatype< Datum > mpiType_;
    };
  } // namespace MPI
} // namespace TSQR

#endif // HAVE_MPI
#endif // __TSQR_MpiMessenger_hpp
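
For reference, a minimal sketch of how this class might be used, assuming
an MPI-enabled build of TSQR and that MpiDatatype supports double; the
program below is illustrative and not part of the library:

#include <Tsqr_MpiMessenger.hpp>
#include <iostream>

int
main (int argc, char* argv[])
{
  // MpiMessenger does not initialize or finalize MPI; the caller must.
  MPI_Init (&argc, &argv);
  {
    TSQR::MPI::MpiMessenger< double > messenger (MPI_COMM_WORLD);

    // Each rank contributes (rank + 1); every rank receives
    // 1 + 2 + ... + P, where P is the number of processes.
    const double localValue = static_cast<double> (messenger.rank() + 1);
    const double sum = messenger.globalSum (localValue);

    if (messenger.rank() == 0)
      std::cout << "Global sum: " << sum << std::endl;
  } // Destroy the messenger before calling MPI_Finalize.
  MPI_Finalize ();
  return 0;
}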