Teuchos_MPIComm.cpp

// @HEADER
// ***********************************************************************
// 
//                    Teuchos: Common Tools Package
//                 Copyright (2004) Sandia Corporation
// 
// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
// license for use of this work by or on behalf of the U.S. Government.
// 
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
// 
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
// 
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
// 
// ***********************************************************************
// @HEADER

#include "Teuchos_MPIComm.hpp"
#include "Teuchos_ErrorPolling.hpp"


using namespace Teuchos;

namespace Teuchos
{
  const int MPIComm::INT = 1;
  const int MPIComm::FLOAT = 2;
  const int MPIComm::DOUBLE = 3;
  const int MPIComm::CHAR = 4;

  const int MPIComm::SUM = 5;
  const int MPIComm::MIN = 6;
  const int MPIComm::MAX = 7;
  const int MPIComm::PROD = 8;
}
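
// The integer type and operation tags above give callers an MPI-independent
// way to describe their data; when HAVE_MPI is defined they are translated
// into the corresponding MPI_Datatype and MPI_Op values by getDataType()
// and getOp() near the end of this file.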


// Construct a communicator for MPI_COMM_WORLD (or a serial stand-in when
// MPI is not enabled).
MPIComm::MPIComm()
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD),
#endif
  nProc_(0), myRank_(0)
{
  init();
}

#ifdef HAVE_MPI
// Wrap an existing MPI communicator.
MPIComm::MPIComm(MPI_Comm comm)
  : comm_(comm), nProc_(0), myRank_(0)
{
  init();
}
#endif

// Return nonzero if MPI_Init() has been called (always zero without MPI).
int MPIComm::mpiIsRunning() const
{
  int mpiStarted = 0;
#ifdef HAVE_MPI
  MPI_Initialized(&mpiStarted);
#endif
  return mpiStarted;
}

// Cache the rank and size of this communicator; fall back to a single
// serial process when MPI is absent or not yet initialized.
void MPIComm::init()
{
#ifdef HAVE_MPI

  if (mpiIsRunning())
    {
      errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
      errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
    }
  else
    {
      nProc_ = 1;
      myRank_ = 0;
    }
  
#else
  nProc_ = 1;
  myRank_ = 0;
#endif
}

#ifdef USE_MPI_GROUPS /* we're ignoring groups for now */

MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group)
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD), 
#endif
  nProc_(0), myRank_(0)
{
#ifdef HAVE_MPI
  if (group.getNProc()==0)
    {
      myRank_ = -1;
      nProc_ = 0;
    }
  else if (parent.containsMe())
    {
      MPI_Comm parentComm = parent.comm_;
      MPI_Group newGroup = group.group_;
      
      errCheck(MPI_Comm_create(parentComm, newGroup, &comm_), 
               "Comm_create");
      
      if (group.containsProc(parent.getRank()))
        {
          errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
          
          errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
        }
      else
        {
          myRank_ = -1;
          nProc_ = -1;
          return;
        }
    }
  else
    {
      myRank_ = -1;
      nProc_ = -1;
    }
#endif
}

#endif /* USE_MPI_GROUPS */

MPIComm& MPIComm::world()
{
  static MPIComm w = MPIComm();
  return w;
}
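
// Usage sketch (illustrative, not part of the original file): client code
// typically grabs the shared world communicator and queries its size and
// rank. getNProc() and getRank() are assumed to be the accessors declared
// in Teuchos_MPIComm.hpp; the helper name below is hypothetical.
static void exampleWorldUsage()
{
  MPIComm& comm = MPIComm::world();  // shared, lazily constructed instance
  int nProc = comm.getNProc();       // number of processes (1 when MPI is off)
  int rank = comm.getRank();         // rank of this process (0 when MPI is off)
  (void)nProc;
  (void)rank;
}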


void MPIComm::synchronize() const 
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Barrier(comm_), "Barrier");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::allToAll(void* sendBuf, int sendCount, int sendType,
                       void* recvBuf, int recvCount, int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);


    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType,
                                recvBuf, recvCount, mpiRecvType,
                                comm_), "Alltoall");
      }
  }
  //mutex_.unlock();
#else
  (void)sendBuf;
  (void)sendCount;
  (void)sendType;
  (void)recvBuf;
  (void)recvCount;
  (void)recvType;
#endif
}
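
// Usage sketch (illustrative, not part of the original file): an all-to-all
// exchange of one int per destination process, using the type tags defined
// at the top of this file. The helper name is hypothetical, and getNProc()
// and getRank() are assumed from the class header.
static void exampleAllToAll(const MPIComm& comm)
{
  const int n = comm.getNProc();
  int* sendBuf = new int[n];
  int* recvBuf = new int[n];
  for (int p = 0; p < n; ++p) sendBuf[p] = comm.getRank();
  // Each process sends one int to every process and receives one int back.
  comm.allToAll(sendBuf, 1, MPIComm::INT,
                recvBuf, 1, MPIComm::INT);
  delete [] sendBuf;
  delete [] recvBuf;
}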

void MPIComm::allToAllv(void* sendBuf, int* sendCount, 
                        int* sendDisplacements, int sendType,
                        void* recvBuf, int* recvCount, 
                        int* recvDisplacements, int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);

    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType,
                                 recvBuf, recvCount, recvDisplacements, mpiRecvType,
                                 comm_), "Alltoallv");
      }
  }
  //mutex_.unlock();
#else
  (void)sendBuf;
  (void)sendCount;
  (void)sendDisplacements;
  (void)sendType;
  (void)recvBuf;
  (void)recvCount;
  (void)recvDisplacements;
  (void)recvType;
#endif
}

void MPIComm::gather(void* sendBuf, int sendCount, int sendType,
                     void* recvBuf, int recvCount, int recvType,
                     int root) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);


    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType,
                              recvBuf, recvCount, mpiRecvType,
                              root, comm_), "Gather");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::gatherv(void* sendBuf, int sendCount, int sendType,
                      void* recvBuf, int* recvCount, int* displacements, int recvType,
                      int root) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);
    
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Gatherv(sendBuf, sendCount, mpiSendType,
                               recvBuf, recvCount, displacements, mpiRecvType,
                               root, comm_), "Gatherv");
      }
  }
  //mutex_.unlock();
#endif
}
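
// Usage sketch (illustrative, not part of the original file): gather a
// variable number of ints from each process onto the root. The per-process
// counts are collected first with gather(), then converted into
// displacements for gatherv(). The helper name is hypothetical, and
// getNProc() is assumed from the class header.
static void exampleGatherv(const MPIComm& comm, int* myData, int myCount, int root)
{
  const int n = comm.getNProc();
  int* counts = new int[n];
  int* displs = new int[n];
  for (int p = 0; p < n; ++p) counts[p] = 0;
  // Step 1: every process reports how many ints it will contribute.
  // (counts is only filled in on the root; it stays zero elsewhere.)
  comm.gather(&myCount, 1, MPIComm::INT, counts, 1, MPIComm::INT, root);
  // Step 2: turn the counts into displacements into the receive buffer.
  int total = 0;
  for (int p = 0; p < n; ++p) { displs[p] = total; total += counts[p]; }
  int* recvBuf = new int[total > 0 ? total : 1];
  // Step 3: the variable-length gather itself; recvBuf is meaningful on the root only.
  comm.gatherv(myData, myCount, MPIComm::INT,
               recvBuf, counts, displs, MPIComm::INT, root);
  delete [] recvBuf;
  delete [] displs;
  delete [] counts;
}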

void MPIComm::allGather(void* sendBuf, int sendCount, int sendType,
                        void* recvBuf, int recvCount, 
                        int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);
    
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType,
                                 recvBuf, recvCount, 
                                 mpiRecvType, comm_), 
                 "AllGather");
      }
  }
  //mutex_.unlock();
#endif
}


void MPIComm::allGatherv(void* sendBuf, int sendCount, int sendType,
                         void* recvBuf, int* recvCount, 
                         int* recvDisplacements,
                         int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);
    
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType,
                                  recvBuf, recvCount, recvDisplacements,
                                  mpiRecvType, 
                                  comm_), 
                 "AllGatherv");
      }
  }
  //mutex_.unlock();
#endif
}


void MPIComm::bcast(void* msg, int length, int type, int src) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        MPI_Datatype mpiType = getDataType(type);
        errCheck(::MPI_Bcast(msg, length, mpiType, src, 
                             comm_), "Bcast");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::allReduce(void* input, void* result, int inputCount, 
                        int type, int op) const
{
#ifdef HAVE_MPI

  //mutex_.lock();
  {
    MPI_Op mpiOp = getOp(op);
    MPI_Datatype mpiType = getDataType(type);
    
    if (mpiIsRunning())
      {
        errCheck(::MPI_Allreduce(input, result, inputCount, mpiType,
                                 mpiOp, comm_), 
                 "Allreduce");
      }
  }
  //mutex_.unlock();
#endif
}
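
// Usage sketch (illustrative, not part of the original file): sum a local
// double across all processes with the type and operation tags defined at
// the top of this file. The helper name is hypothetical. When MPI is not
// running, allReduce() is a no-op, so the local value is returned unchanged.
static double exampleGlobalSum(const MPIComm& comm, double localValue)
{
  double globalValue = localValue;  // correct answer for the serial case
  comm.allReduce(&localValue, &globalValue, 1, MPIComm::DOUBLE, MPIComm::SUM);
  return globalValue;
}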


#ifdef HAVE_MPI

// Translate one of the MPIComm type tags into the corresponding MPI_Datatype.
MPI_Datatype MPIComm::getDataType(int type)
{
  TEST_FOR_EXCEPTION(
    !(type == INT || type==FLOAT 
      || type==DOUBLE || type==CHAR),
    std::range_error,
    "invalid type " << type << " in MPIComm::getDataType");
  
  if (type == INT) return MPI_INT;
  if (type == FLOAT) return MPI_FLOAT;
  if (type == DOUBLE) return MPI_DOUBLE;
  
  return MPI_CHAR;
}


// Throw a std::runtime_error if an MPI call returned a nonzero error code.
void MPIComm::errCheck(int errCode, const std::string& methodName)
{
  TEST_FOR_EXCEPTION(errCode != 0, std::runtime_error,
                     "MPI function MPI_" << methodName 
                     << " returned error code=" << errCode);
}

// Translate one of the MPIComm reduction tags into the corresponding MPI_Op.
MPI_Op MPIComm::getOp(int op)
{
  TEST_FOR_EXCEPTION(
    !(op == SUM || op==MAX 
      || op==MIN || op==PROD),
    std::range_error,
    "invalid operator " 
    << op << " in MPIComm::getOp");

  if (op == SUM) return MPI_SUM;
  else if (op == MAX) return MPI_MAX;
  else if (op == MIN) return MPI_MIN;
  return MPI_PROD;
}

#endif