Teuchos_MPIComm.cpp
// @HEADER
// ***********************************************************************
//
//                    Teuchos: Common Tools Package
//                 Copyright (2004) Sandia Corporation
//
// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
// license for use of this work by or on behalf of the U.S. Government.
//
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
//
// ***********************************************************************
// @HEADER

#include "Teuchos_MPIComm.hpp"
#include "Teuchos_ErrorPolling.hpp"


using namespace Teuchos;

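/* Integer tags for the data types and reduction operations supported by
 * this wrapper; getDataType() and getOp() translate them into the
 * corresponding MPI_Datatype and MPI_Op values. */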
namespace Teuchos
{
  const int MPIComm::INT = 1;
  const int MPIComm::FLOAT = 2;
  const int MPIComm::DOUBLE = 3;
  const int MPIComm::CHAR = 4;

  const int MPIComm::SUM = 5;
  const int MPIComm::MIN = 6;
  const int MPIComm::MAX = 7;
  const int MPIComm::PROD = 8;
}


MPIComm::MPIComm()
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD),
#endif
  nProc_(0), myRank_(0)
{
  init();
}

#ifdef HAVE_MPI
MPIComm::MPIComm(MPI_Comm comm)
  : comm_(comm), nProc_(0), myRank_(0)
{
  init();
}
#endif

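/* Return nonzero if MPI has been initialized (via MPI_Initialized);
 * always returns 0 when built without MPI. */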
int MPIComm::mpiIsRunning() const
{
  int mpiStarted = 0;
#ifdef HAVE_MPI
  MPI_Initialized(&mpiStarted);
#endif
  return mpiStarted;
}

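/* Cache the rank and number of processes for this communicator. When MPI
 * is unavailable or not yet initialized, fall back to a serial view:
 * one process with rank 0. */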
void MPIComm::init()
{
#ifdef HAVE_MPI

  if (mpiIsRunning())
    {
      errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
      errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
    }
  else
    {
      nProc_ = 1;
      myRank_ = 0;
    }

#else
  nProc_ = 1;
  myRank_ = 0;
#endif
}

#ifdef USE_MPI_GROUPS /* we're ignoring groups for now */

MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group)
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD), 
#endif
  nProc_(0), myRank_(0)
{
#ifdef HAVE_MPI
  if (group.getNProc()==0)
    {
      myRank_ = -1;
      nProc_ = 0;
    }
  else if (parent.containsMe())
    {
      MPI_Comm parentComm = parent.comm_;
      MPI_Group newGroup = group.group_;

      errCheck(MPI_Comm_create(parentComm, newGroup, &comm_), 
               "Comm_create");

      if (group.containsProc(parent.getRank()))
        {
          errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");

          errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
        }
      else
        {
          myRank_ = -1;
          nProc_ = -1;
          return;
        }
    }
  else
    {
      myRank_ = -1;
      nProc_ = -1;
    }
#endif
}

#endif /* USE_MPI_GROUPS */

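/* Shared accessors for the standard communicators: world() wraps
 * MPI_COMM_WORLD (via the default constructor) and self() wraps
 * MPI_COMM_SELF, each held in a function-local static. */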
MPIComm& MPIComm::world()
{
  static MPIComm w = MPIComm();
  return w;
}


MPIComm& MPIComm::self()
{
#ifdef HAVE_MPI
  static MPIComm w = MPIComm(MPI_COMM_SELF);
#else
  static MPIComm w = MPIComm();
#endif
  return w;
}


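/* Barrier over this communicator. As with the other collectives below, the
 * call first polls for failures reported by other processes and does
 * nothing if MPI is absent or has not been initialized. */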
void MPIComm::synchronize() const 
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        errCheck(::MPI_Barrier(comm_), "Barrier");
      }
  }
  //mutex_.unlock();
#endif
}

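/* All-to-all exchange of fixed-size blocks between every pair of
 * processes; the type tags are translated by getDataType(). Without MPI
 * the call only silences unused-argument warnings. */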
void MPIComm::allToAll(void* sendBuf, int sendCount, int sendType,
                       void* recvBuf, int recvCount, int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);


    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType,
                                recvBuf, recvCount, mpiRecvType,
                                comm_), "Alltoall");
      }
  }
  //mutex_.unlock();
#else
  (void)sendBuf;
  (void)sendCount;
  (void)sendType;
  (void)recvBuf;
  (void)recvCount;
  (void)recvType;
#endif
}

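/* All-to-all exchange with per-process counts and displacements (the
 * variable-size counterpart of allToAll above). */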
void MPIComm::allToAllv(void* sendBuf, int* sendCount, 
                        int* sendDisplacements, int sendType,
                        void* recvBuf, int* recvCount, 
                        int* recvDisplacements, int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);

    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType,
                                 recvBuf, recvCount, recvDisplacements, mpiRecvType,
                                 comm_), "Alltoallv");
      }
  }
  //mutex_.unlock();
#else
  (void)sendBuf;
  (void)sendCount;
  (void)sendDisplacements;
  (void)sendType;
  (void)recvBuf;
  (void)recvCount;
  (void)recvDisplacements;
  (void)recvType;
#endif
}

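/* Gather equal-size contributions from every process onto the root rank. */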
void MPIComm::gather(void* sendBuf, int sendCount, int sendType,
                     void* recvBuf, int recvCount, int recvType,
                     int root) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);


    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType,
                              recvBuf, recvCount, mpiRecvType,
                              root, comm_), "Gather");
      }
  }
  //mutex_.unlock();
#endif
}

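/* Gather variable-size contributions onto the root rank, using per-process
 * receive counts and displacements. */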
void MPIComm::gatherv(void* sendBuf, int sendCount, int sendType,
                      void* recvBuf, int* recvCount, int* displacements, int recvType,
                      int root) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);

    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        errCheck(::MPI_Gatherv(sendBuf, sendCount, mpiSendType,
                               recvBuf, recvCount, displacements, mpiRecvType,
                               root, comm_), "Gatherv");
      }
  }
  //mutex_.unlock();
#endif
}

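/* Gather equal-size contributions and replicate the result on every
 * process. */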
void MPIComm::allGather(void* sendBuf, int sendCount, int sendType,
                        void* recvBuf, int recvCount, 
                        int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);

    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType,
                                 recvBuf, recvCount, 
                                 mpiRecvType, comm_), 
                 "AllGather");
      }
  }
  //mutex_.unlock();
#endif
}


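/* Variable-size allGather: per-process receive counts and displacements,
 * with the result replicated on every process. */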
void MPIComm::allGatherv(void* sendBuf, int sendCount, int sendType,
                         void* recvBuf, int* recvCount, 
                         int* recvDisplacements,
                         int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);

    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType,
                                  recvBuf, recvCount, recvDisplacements,
                                  mpiRecvType, 
                                  comm_), 
                 "AllGatherv");
      }
  }
  //mutex_.unlock();
#endif
}


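/* Broadcast length items of the given type tag from rank src to all
 * processes. */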
void MPIComm::bcast(void* msg, int length, int type, int src) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we're to this point, all processors are OK */

        MPI_Datatype mpiType = getDataType(type);
        errCheck(::MPI_Bcast(msg, length, mpiType, src, 
                             comm_), "Bcast");
      }
  }
  //mutex_.unlock();
#endif
}

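/* Element-wise reduction over all processes using one of the op tags
 * (SUM, MIN, MAX, PROD); the result is available on every process. Unlike
 * the collectives above, this one does not poll for remote failures first. */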
void MPIComm::allReduce(void* input, void* result, int inputCount, 
                        int type, int op) const
{
#ifdef HAVE_MPI

  //mutex_.lock();
  {
    MPI_Op mpiOp = getOp(op);
    MPI_Datatype mpiType = getDataType(type);

    if (mpiIsRunning())
      {
        errCheck(::MPI_Allreduce(input, result, inputCount, mpiType,
                                 mpiOp, comm_), 
                 "Allreduce");
      }
  }
  //mutex_.unlock();
#endif
}


#ifdef HAVE_MPI

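/* Map one of the type tags defined above to its MPI_Datatype; an
 * unrecognized tag raises std::range_error. */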
MPI_Datatype MPIComm::getDataType(int type)
{
  TEST_FOR_EXCEPTION(
    !(type == INT || type==FLOAT 
      || type==DOUBLE || type==CHAR),
    std::range_error,
    "invalid type " << type << " in MPIComm::getDataType");

  if(type == INT) return MPI_INT;
  if(type == FLOAT) return MPI_FLOAT;
  if(type == DOUBLE) return MPI_DOUBLE;

  return MPI_CHAR;
}


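/* Convert a nonzero MPI error code into a std::runtime_error whose message
 * names the offending MPI_<methodName> call. */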
void MPIComm::errCheck(int errCode, const std::string& methodName)
{
  TEST_FOR_EXCEPTION(errCode != 0, std::runtime_error,
                     "MPI function MPI_" << methodName 
                     << " returned error code=" << errCode);
}

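/* Map one of the reduction tags defined above to its MPI_Op; an
 * unrecognized tag raises std::range_error. */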
MPI_Op MPIComm::getOp(int op)
{

  TEST_FOR_EXCEPTION(
    !(op == SUM || op==MAX 
      || op==MIN || op==PROD),
    std::range_error,
    "invalid operator " 
    << op << " in MPIComm::getOp");

  if( op == SUM) return MPI_SUM;
  else if( op == MAX) return MPI_MAX;
  else if( op == MIN) return MPI_MIN;
  return MPI_PROD;
}

#endif
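A minimal usage sketch, not part of the file above, showing how the MPIComm wrapper might be driven. The surrounding main(), the MPI_Init/MPI_Finalize calls, and the getNProc() accessor (assumed to be declared in Teuchos_MPIComm.hpp alongside the getRank() accessor used in this file) are assumptions about the calling program, as is a build in which HAVE_MPI is defined consistently with the Teuchos configuration.

#include <iostream>
#include "Teuchos_MPIComm.hpp"

int main(int argc, char** argv)
{
#ifdef HAVE_MPI
  MPI_Init(&argc, &argv);
#endif
  {
    Teuchos::MPIComm& comm = Teuchos::MPIComm::world();

    /* each process contributes rank+1; the SUM reduction therefore yields
     * nProc*(nProc+1)/2 on every process */
    int myValue = comm.getRank() + 1;
    int total = 0;
    comm.allReduce(&myValue, &total, 1,
                   Teuchos::MPIComm::INT, Teuchos::MPIComm::SUM);

    comm.synchronize();
    if (comm.getRank() == 0)
      {
        /* getNProc() is assumed here; it is not defined in this file */
        std::cout << "sum over " << comm.getNProc()
                  << " processes = " << total << std::endl;
      }
  }
#ifdef HAVE_MPI
  MPI_Finalize();
#endif
  return 0;
}

Note that in a build without HAVE_MPI the collectives in this file are no-ops, so total would remain 0 in the sketch above; the wrapper degrades to a single-process, rank-0 view rather than emulating the reductions.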