Teuchos_MPIComm.cpp

// @HEADER
// ***********************************************************************
// 
//                    Teuchos: Common Tools Package
//                 Copyright (2004) Sandia Corporation
// 
// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
// license for use of this work by or on behalf of the U.S. Government.
// 
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
// 
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
// 
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
// 
// ***********************************************************************
// @HEADER

#include "Teuchos_MPIComm.hpp"
#include "Teuchos_ErrorPolling.hpp"


using namespace Teuchos;

namespace Teuchos
{
  const int MPIComm::INT = 1;
  const int MPIComm::FLOAT = 2;
  const int MPIComm::DOUBLE = 3;
  const int MPIComm::CHAR = 4;

  const int MPIComm::SUM = 5;
  const int MPIComm::MIN = 6;
  const int MPIComm::MAX = 7;
  const int MPIComm::PROD = 8;
}
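
/* Illustrative sketch, not part of the original Teuchos source: the integer
 * tags defined above are the values callers pass as the 'type' and 'op'
 * arguments of the communication methods in this file; getDataType() and
 * getOp() (defined at the bottom of the file) translate them into the
 * corresponding MPI_Datatype and MPI_Op handles when MPI is enabled.
 * Assuming the MPIComm interface declared in Teuchos_MPIComm.hpp, a global
 * integer sum might look like this: */
namespace
{
  int exampleGlobalSum(const MPIComm& comm, int localValue)
  {
    int globalSum = 0;
    /* one int per processor, combined with MPI_SUM; in a serial build, or
     * before MPI_Init() has been called, allReduce() is a no-op and
     * globalSum is left at zero */
    comm.allReduce((void*) &localValue, (void*) &globalSum, 1,
                   MPIComm::INT, MPIComm::SUM);
    return globalSum;
  }
}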


MPIComm::MPIComm()
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD),
#endif
  nProc_(0), myRank_(0)
{
  init();
}

#ifdef HAVE_MPI
MPIComm::MPIComm(MPI_Comm comm)
  : comm_(comm), nProc_(0), myRank_(0)
{
  init();
}
#endif

int MPIComm::mpiIsRunning() const
{
  int mpiStarted = 0;
#ifdef HAVE_MPI
  MPI_Initialized(&mpiStarted);
#endif
  return mpiStarted;
}

void MPIComm::init()
{
#ifdef HAVE_MPI

  if (mpiIsRunning())
    {
      errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
      errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
    }
  else
    {
      nProc_ = 1;
      myRank_ = 0;
    }

#else
  nProc_ = 1;
  myRank_ = 0;
#endif
}

#ifdef USE_MPI_GROUPS /* we're ignoring groups for now */

MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group)
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD), 
#endif
  nProc_(0), myRank_(0)
{
#ifdef HAVE_MPI
  if (group.getNProc()==0)
    {
      myRank_ = -1;
      nProc_ = 0;
    }
  else if (parent.containsMe())
    {
      MPI_Comm parentComm = parent.comm_;
      MPI_Group newGroup = group.group_;
      
      errCheck(MPI_Comm_create(parentComm, newGroup, &comm_), 
               "Comm_create");
      
      if (group.containsProc(parent.getRank()))
        {
          errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
          
          errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
        }
      else
        {
          myRank_ = -1;
          nProc_ = -1;
          return;
        }
    }
  else
    {
      myRank_ = -1;
      nProc_ = -1;
    }
#endif
}

#endif /* USE_MPI_GROUPS */

MPIComm& MPIComm::world()
{
  static MPIComm w = MPIComm();
  return w;
}


void MPIComm::synchronize() const 
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Barrier(comm_), "Barrier");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::allToAll(void* sendBuf, int sendCount, int sendType,
                       void* recvBuf, int recvCount, int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);


    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType,
                                recvBuf, recvCount, mpiRecvType,
                                comm_), "Alltoall");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::allToAllv(void* sendBuf, int* sendCount, 
                        int* sendDisplacements, int sendType,
                        void* recvBuf, int* recvCount, 
                        int* recvDisplacements, int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);

    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType,
                                 recvBuf, recvCount, recvDisplacements, mpiRecvType,
                                 comm_), "Alltoallv");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::gather(void* sendBuf, int sendCount, int sendType,
                     void* recvBuf, int recvCount, int recvType,
                     int root) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);


    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType,
                              recvBuf, recvCount, mpiRecvType,
                              root, comm_), "Gather");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::gatherv(void* sendBuf, int sendCount, int sendType,
                      void* recvBuf, int* recvCount, int* displacements, int recvType,
                      int root) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);
    
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Gatherv(sendBuf, sendCount, mpiSendType,
                               recvBuf, recvCount, displacements, mpiRecvType,
                               root, comm_), "Gatherv");
      }
  }
  //mutex_.unlock();
#endif
}

void MPIComm::allGather(void* sendBuf, int sendCount, int sendType,
                        void* recvBuf, int recvCount, 
                        int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);
    
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType,
                                 recvBuf, recvCount, 
                                 mpiRecvType, comm_), 
                 "AllGather");
      }
  }
  //mutex_.unlock();
#endif
}


void MPIComm::allGatherv(void* sendBuf, int sendCount, int sendType,
                         void* recvBuf, int* recvCount, 
                         int* recvDisplacements,
                         int recvType) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    MPI_Datatype mpiSendType = getDataType(sendType);
    MPI_Datatype mpiRecvType = getDataType(recvType);
    
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType,
                                  recvBuf, recvCount, recvDisplacements,
                                  mpiRecvType, 
                                  comm_), 
                 "AllGatherv");
      }
  }
  //mutex_.unlock();
#endif
}


void MPIComm::bcast(void* msg, int length, int type, int src) const
{
#ifdef HAVE_MPI
  //mutex_.lock();
  {
    if (mpiIsRunning())
      {
        /* test whether errors have been detected on another proc before
         * doing the collective operation. */
        TEUCHOS_POLL_FOR_FAILURES(*this);
        /* if we've made it to this point, all processors are OK */
        
        MPI_Datatype mpiType = getDataType(type);
        errCheck(::MPI_Bcast(msg, length, mpiType, src, 
                             comm_), "Bcast");
      }
  }
  //mutex_.unlock();
#endif
}
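
/* Illustrative sketch, not part of the original Teuchos source: bcast() takes
 * a raw buffer, an element count, one of the MPIComm type tags, and the rank
 * of the broadcasting processor.  Assuming the getRank() accessor declared in
 * Teuchos_MPIComm.hpp (it is also used by the group constructor above), a
 * root-to-all broadcast of a small array of doubles might look like this: */
namespace
{
  void exampleBroadcast(const MPIComm& comm)
  {
    double params[3] = {0.0, 0.0, 0.0};
    if (comm.getRank() == 0)
      {
        /* only the root fills in the values to be shared */
        params[0] = 1.0; params[1] = 2.0; params[2] = 3.0;
      }
    /* after the call every processor holds the root's values; in a serial
     * build bcast() is a no-op, which is consistent because rank 0 is then
     * the only processor */
    comm.bcast((void*) params, 3, MPIComm::DOUBLE, 0);
  }
}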

void MPIComm::allReduce(void* input, void* result, int inputCount, 
                        int type, int op) const
{
#ifdef HAVE_MPI

  //mutex_.lock();
  {
    MPI_Op mpiOp = getOp(op);
    MPI_Datatype mpiType = getDataType(type);
    
    if (mpiIsRunning())
      {
        errCheck(::MPI_Allreduce(input, result, inputCount, mpiType,
                                 mpiOp, comm_), 
                 "Allreduce");
      }
  }
  //mutex_.unlock();
#endif
}


#ifdef HAVE_MPI

MPI_Datatype MPIComm::getDataType(int type)
{
  TEST_FOR_EXCEPTION(
    !(type == INT || type==FLOAT 
      || type==DOUBLE || type==CHAR),
    std::range_error,
    "invalid type " << type << " in MPIComm::getDataType");
  
  if (type == INT) return MPI_INT;
  if (type == FLOAT) return MPI_FLOAT;
  if (type == DOUBLE) return MPI_DOUBLE;
  
  return MPI_CHAR;
}


void MPIComm::errCheck(int errCode, const std::string& methodName)
{
  TEST_FOR_EXCEPTION(errCode != 0, std::runtime_error,
                     "MPI function MPI_" << methodName 
                     << " returned error code=" << errCode);
}

MPI_Op MPIComm::getOp(int op)
{
  TEST_FOR_EXCEPTION(
    !(op == SUM || op==MAX 
      || op==MIN || op==PROD),
    std::range_error,
    "invalid operator " 
    << op << " in MPIComm::getOp");

  if (op == SUM) return MPI_SUM;
  else if (op == MAX) return MPI_MAX;
  else if (op == MIN) return MPI_MIN;
  return MPI_PROD;
}

#endif
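
/* Illustrative sketch, not part of the original Teuchos source: a root-side
 * gather of one int per processor on the world communicator, using the
 * getRank() and getNProc() accessors assumed to be declared in
 * Teuchos_MPIComm.hpp. */
namespace
{
  void exampleGatherRanks()
  {
    MPIComm& comm = MPIComm::world();
    int myRank = comm.getRank();

    /* every processor must supply a receive buffer argument, but only the
     * root's buffer is filled; value-initialize it so a serial build (where
     * gather() is a no-op) still sees defined contents */
    int* allRanks = new int[comm.getNProc()]();

    comm.gather((void*) &myRank, 1, MPIComm::INT,
                (void*) allRanks, 1, MPIComm::INT,
                0 /* root */);

    /* wait until all processors have reached this point */
    comm.synchronize();

    delete [] allRanks;
  }
}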
