fei_CommUtils.hpp

00001 
00002 /*--------------------------------------------------------------------*/
00003 /*    Copyright 2007 Sandia Corporation.                              */
00004 /*    Under the terms of Contract DE-AC04-94AL85000, there is a       */
00005 /*    non-exclusive license for use of this work by or on behalf      */
00006 /*    of the U.S. Government.  Export of this program may require     */
00007 /*    a license from the United States Government.                    */
00008 /*--------------------------------------------------------------------*/
00009 
00010 #ifndef _fei_CommUtils_hpp_
00011 #define _fei_CommUtils_hpp_
00012 
00013 #include <fei_macros.hpp>
00014 #include <fei_mpi.h>
00015 #include <fei_mpiTraits.hpp>
00016 #include <fei_chk_mpi.hpp>
00017 #include <fei_iostream.hpp>
00018 #include <snl_fei_RaggedTable.hpp>
00019 
00020 #include <vector>
00021 #include <set>
00022 #include <map>
00023 
00024 #include <fei_ErrMacros.hpp>
00025 #undef fei_file
00026 #define fei_file "fei_CommUtils.hpp"
00027 
00028 namespace fei {
00029 
/** Return the rank of the local processor within the communicator 'comm'. */
int localProc(MPI_Comm comm);

/** Return the number of processors in the communicator 'comm'. */
int numProcs(MPI_Comm comm);

/** Synchronization barrier across all processors in 'comm'. */
void Barrier(MPI_Comm comm);

/** Given the list of processors this processor sends to, produce the
    mirror list of processors that send to this processor.
    NOTE(review): semantics inferred from the name; the implementation
    lives in the corresponding .cpp -- confirm there.
    @return 0 on success. */
int mirrorProcs(MPI_Comm comm, std::vector<int>& toProcs, std::vector<int>& fromProcs);

/** Ragged table describing a communication pattern: maps a processor id
    to a set of (presumably index) values destined for that processor. */
typedef snl_fei::RaggedTable<std::map<int,std::set<int>*>,std::set<int> > comm_map;

/** Given a communication pattern, construct the mirrored pattern in
    'outPattern'. NOTE(review): ownership of the newly-created outPattern
    object presumably passes to the caller -- confirm against the .cpp.
    @return 0 on success. */
int mirrorCommPattern(MPI_Comm comm, comm_map* inPattern, comm_map*& outPattern);

/** Reduce a bool flag across all processors in 'comm'.
    NOTE(review): whether this is a logical AND or OR is not visible
    here -- confirm in the .cpp implementation.
    @return 0 on success. */
int Allreduce(MPI_Comm comm, bool localBool, bool& globalBool);

/** Exchange one int per processor-pair: sendData[i] goes to sendProcs[i],
    recvData[i] is filled from recvProcs[i].
    @return 0 on success. */
int exchangeIntData(MPI_Comm comm,
                    std::vector<int>& sendProcs,
                    std::vector<int>& sendData,
                    std::vector<int>& recvProcs,
                    std::vector<int>& recvData);
00077  
00078 //----------------------------------------------------------------------------
00084 template<class T>
00085 int GlobalMax(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
00086 {
00087 #ifdef FEI_SER
00088   global = local;
00089 #else
00090 
00091   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00092 
00093   try {
00094     global.resize(local.size());
00095   }
00096   catch(std::runtime_error& exc) {
00097     FEI_CERR << exc.what()<<FEI_ENDL;
00098     return(-1);
00099   }
00100 
00101   CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
00102        local.size(), mpi_dtype, MPI_MAX, comm) );
00103 #endif
00104 
00105   return(0);
00106 }
00107 
00108 //----------------------------------------------------------------------------
00114 template<class T>
00115 int GlobalMax(MPI_Comm comm, T local, T& global)
00116 {
00117 #ifdef FEI_SER
00118   global = local;
00119 #else
00120   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00121 
00122   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_MAX, comm) );
00123 #endif
00124   return(0);
00125 }
00126 
00127 //----------------------------------------------------------------------------
00133 template<class T>
00134 int GlobalSum(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
00135 {
00136 #ifdef FEI_SER
00137   global = local;
00138 #else
00139   global.resize(local.size());
00140 
00141   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00142 
00143   CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
00144                       local.size(), mpi_dtype, MPI_SUM, comm) );
00145 #endif
00146   return(0);
00147 }
00148 
00149 //----------------------------------------------------------------------------
00151 template<class T>
00152 int GlobalSum(MPI_Comm comm, T local, T& global)
00153 {
00154 #ifdef FEI_SER
00155   global = local;
00156 #else
00157   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00158 
00159   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_SUM, comm) );
00160 #endif
00161   return(0);
00162 }
00163 
00164 
00167 template<class T>
00168 int Allgatherv(MPI_Comm comm,
00169                std::vector<T>& sendbuf,
00170                std::vector<int>& recvLengths,
00171                std::vector<T>& recvbuf)
00172 {
00173 #ifdef FEI_SER
00174   //If we're in serial mode, just copy sendbuf to recvbuf and return.
00175 
00176   recvbuf = sendbuf;
00177   recvLengths.resize(1);
00178   recvLengths[0] = sendbuf.size();
00179 #else
00180   int numProcs = 1;
00181   MPI_Comm_size(comm, &numProcs);
00182 
00183   try {
00184 
00185   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00186 
00187   std::vector<int> tmpInt(numProcs, 0);
00188 
00189   int len = sendbuf.size();
00190   int* tmpBuf = &tmpInt[0];
00191 
00192   recvLengths.resize(numProcs);
00193   int* recvLenPtr = &recvLengths[0];
00194 
00195   CHK_MPI( MPI_Allgather(&len, 1, MPI_INT, recvLenPtr, 1, MPI_INT, comm) );
00196 
00197   int displ = 0;
00198   for(int i=0; i<numProcs; i++) {
00199     tmpBuf[i] = displ;
00200     displ += recvLenPtr[i];
00201   }
00202 
00203   if (displ == 0) {
00204     recvbuf.resize(0);
00205     return(0);
00206   }
00207 
00208   recvbuf.resize(displ);
00209 
00210   T* sendbufPtr = sendbuf.size()>0 ? &sendbuf[0] : NULL;
00211   
00212   CHK_MPI( MPI_Allgatherv(sendbufPtr, len, mpi_dtype,
00213       &recvbuf[0], &recvLengths[0], tmpBuf,
00214       mpi_dtype, comm) );
00215 
00216   }
00217   catch(std::runtime_error& exc) {
00218     FEI_CERR << exc.what() << FEI_ENDL;
00219     return(-1);
00220   }
00221 #endif
00222 
00223   return(0);
00224 }
00225 
00226 //------------------------------------------------------------------------
00227 template<class T>
00228 int Bcast(MPI_Comm comm, std::vector<T>& sendbuf, int sourceProc)
00229 {
00230 #ifndef FEI_SER
00231   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00232 
00233   CHK_MPI(MPI_Bcast(&sendbuf[0], sendbuf.size(), mpi_dtype,
00234                     sourceProc, comm) );
00235 #endif
00236   return(0);
00237 }
00238 
00239 //------------------------------------------------------------------------
00240 template<class T>
00241 int exchangeData(MPI_Comm comm,
00242                  std::vector<int>& sendProcs,
00243                  std::vector<std::vector<T>*>& sendData,
00244                  std::vector<int>& recvProcs,
00245                  std::vector<int>& recvLengths,
00246                  bool recvLengthsKnownOnEntry,
00247                  std::vector<T>& recvData)
00248 {
00249   if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
00250   if (sendProcs.size() != sendData.size()) return(-1);
00251 #ifndef FEI_SER
00252   std::vector<MPI_Request> mpiReqs;
00253   mpiReqs.resize(recvProcs.size());
00254   recvLengths.resize(recvProcs.size());
00255 
00256   int tag = 11119;
00257   std::vector<int> tmpIntData;
00258   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00259 
00260   if (!recvLengthsKnownOnEntry) {
00261     tmpIntData.resize(sendData.size());
00262     for(unsigned i=0; i<sendData.size(); ++i) {
00263       tmpIntData[i] = sendData[i]->size();
00264     }
00265 
00266     if ( exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLengths) != 0) {
00267       return(-1);
00268     }
00269     int totalRecvLength = 0;
00270     for(unsigned i=0; i<recvLengths.size(); ++i) {
00271       totalRecvLength += recvLengths[i];
00272     }
00273 
00274     recvData.resize(totalRecvLength);
00275   }
00276 
00277   //launch Irecv's for recvData:
00278 
00279   unsigned numRecvProcs = recvProcs.size();
00280   int recv_offset = 0;
00281   int req_offset = 0;
00282   int localProc = fei::localProc(comm);
00283   for(unsigned i=0; i<recvProcs.size(); ++i) {
00284     if (recvProcs[i] == localProc) {--numRecvProcs; continue; }
00285 
00286     int len = recvLengths[i];
00287     T* recvBuf = len>0 ? &(recvData[recv_offset]) : NULL;
00288 
00289     CHK_MPI( MPI_Irecv(recvBuf, len, mpi_dtype, recvProcs[i],
00290                        tag, comm, &mpiReqs[req_offset++]) );
00291 
00292     recv_offset += len;
00293   }
00294 
00295   //send the sendData:
00296 
00297   for(unsigned i=0; i<sendProcs.size(); ++i) {
00298     if (sendProcs[i] == localProc) continue;
00299 
00300     CHK_MPI( MPI_Send(&(*(sendData[i]))[0], sendData[i]->size(), mpi_dtype,
00301                       sendProcs[i], tag, comm) );
00302   }
00303 
00304   //complete the Irecvs:
00305   for(unsigned i=0; i<numRecvProcs; ++i) {
00306     if (recvProcs[i] == localProc) continue;
00307     int index;
00308     MPI_Status status;
00309     CHK_MPI( MPI_Waitany(numRecvProcs, &mpiReqs[0], &index, &status) );
00310   }
00311 
00312 #endif
00313   return(0);
00314 }
00315 
00316 //------------------------------------------------------------------------
00317 template<class T>
00318 int exchangeData(MPI_Comm comm,
00319                  std::vector<int>& sendProcs,
00320                  std::vector<std::vector<T>*>& sendData,
00321                  std::vector<int>& recvProcs,
00322                  bool recvLengthsKnownOnEntry,
00323                  std::vector<std::vector<T>*>& recvData)
00324 {
00325   if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
00326   if (sendProcs.size() != sendData.size()) return(-1);
00327 #ifndef FEI_SER
00328   int tag = 11115;
00329   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00330   std::vector<MPI_Request> mpiReqs;
00331 
00332   try {
00333   mpiReqs.resize(recvProcs.size());
00334 
00335   if (!recvLengthsKnownOnEntry) {
00336     std::vector<int> tmpIntData;
00337     tmpIntData.resize(sendData.size());
00338     std::vector<int> recvLens(sendData.size());
00339     for(unsigned i=0; i<sendData.size(); ++i) {
00340       tmpIntData[i] = (int)sendData[i]->size();
00341     }
00342 
00343     if (exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLens) != 0) {
00344       return(-1);
00345     }
00346 
00347     for(unsigned i=0; i<recvLens.size(); ++i) {
00348       recvData[i]->resize(recvLens[i]);
00349     }
00350   }
00351   }
00352   catch(std::runtime_error& exc) {
00353     FEI_CERR << exc.what() << FEI_ENDL;
00354     return(-1);
00355   }
00356 
00357   //launch Irecv's for recvData:
00358 
00359   size_t numRecvProcs = recvProcs.size();
00360   int req_offset = 0;
00361   int localProc = fei::localProc(comm);
00362   for(unsigned i=0; i<recvProcs.size(); ++i) {
00363     if (recvProcs[i] == localProc) {--numRecvProcs; continue;}
00364 
00365     size_t len = recvData[i]->size();
00366     std::vector<T>& rbuf = *recvData[i];
00367 
00368     CHK_MPI( MPI_Irecv(&rbuf[0], (int)len, mpi_dtype,
00369                        recvProcs[i], tag, comm, &mpiReqs[req_offset++]) );
00370   }
00371 
00372   //send the sendData:
00373 
00374   for(unsigned i=0; i<sendProcs.size(); ++i) {
00375     if (sendProcs[i] == localProc) continue;
00376 
00377     std::vector<T>& sbuf = *sendData[i];
00378     CHK_MPI( MPI_Send(&sbuf[0], (int)sbuf.size(), mpi_dtype,
00379                       sendProcs[i], tag, comm) );
00380   }
00381 
00382   //complete the Irecv's:
00383   for(unsigned i=0; i<numRecvProcs; ++i) {
00384     if (recvProcs[i] == localProc) continue;
00385     int index;
00386     MPI_Status status;
00387     CHK_MPI( MPI_Waitany((int)numRecvProcs, &mpiReqs[0], &index, &status) );
00388   }
00389 
00390 #endif
00391   return(0);
00392 }
00393 
00394 
00395 //------------------------------------------------------------------------
/** Interface (callback mechanism) by which user code describes and consumes
    point-to-point messages for the fei::exchange function. Implement this
    interface and pass it to fei::exchange, which drives the communication. */
template<class T>
class MessageHandler {
public:
  virtual ~MessageHandler(){}

  /** Get the list of processors this processor will send messages to. */
  virtual std::vector<int>& getSendProcs() = 0;

  /** Get the list of processors this processor will receive messages from. */
  virtual std::vector<int>& getRecvProcs() = 0;

  /** Get the length (number of items of type T) of the message that is to
      be sent to the specified destination processor.
      @return error-code 0 if successful */
  virtual int getSendMessageLength(int destProc, int& messageLength) = 0;

  /** Fill 'message' with the data to be sent to the specified destination
      processor.
      @return error-code 0 if successful */
  virtual int getSendMessage(int destProc, std::vector<T>& message) = 0;

  /** Process a message that has been received from the specified source
      processor.
      @return error-code 0 if successful */
  virtual int processRecvMessage(int srcProc, std::vector<T>& message) = 0;
};//class MessageHandler
00443 
00444 
00445 //------------------------------------------------------------------------
00446 template<class T>
00447 int exchange(MPI_Comm comm, MessageHandler<T>* msgHandler)
00448 {
00449 #ifdef FEI_SER
00450   (void)msgHandler;
00451 #else
00452   int numProcs = fei::numProcs(comm);
00453   if (numProcs < 2) return(0);
00454 
00455   std::vector<int>& sendProcs = msgHandler->getSendProcs();
00456   int numSendProcs = sendProcs.size();
00457   std::vector<int>& recvProcs = msgHandler->getRecvProcs();
00458   int numRecvProcs = recvProcs.size();
00459   int i;
00460 
00461   if (numSendProcs < 1 && numRecvProcs < 1) {
00462     return(0);
00463   }
00464 
00465   std::vector<int> sendMsgLengths(numSendProcs), recvMsgLengths(numRecvProcs);
00466 
00467   for(i=0; i<numSendProcs; ++i) {
00468     CHK_ERR( msgHandler->getSendMessageLength(sendProcs[i], sendMsgLengths[i]) );
00469   }
00470 
00471   std::vector<T> recvMsgs;
00472 
00473   std::vector<std::vector<T>* > sendMsgs(numSendProcs);
00474   for(i=0; i<numSendProcs; ++i) {
00475     sendMsgs[i] = new std::vector<T>;
00476     CHK_ERR( msgHandler->getSendMessage(sendProcs[i], *(sendMsgs[i])) );
00477   }
00478 
00479   CHK_ERR( exchangeData(comm, sendProcs, sendMsgs,
00480                         recvProcs, recvMsgLengths, false, recvMsgs) );
00481 
00482   int offset = 0;
00483   for(i=0; i<numRecvProcs; ++i) {
00484     int msgLen = recvMsgLengths[i];
00485     T* mdPtr = &(recvMsgs[offset]);
00486     std::vector<T> recvMsg(mdPtr, mdPtr+msgLen);
00487     CHK_ERR( msgHandler->processRecvMessage(recvProcs[i], recvMsg ) );
00488     offset += msgLen;
00489   }
00490 
00491   for(i=0; i<numSendProcs; ++i) {
00492     delete sendMsgs[i];
00493   }
00494 #endif
00495 
00496   return(0);
00497 }
00498 
00499 
00500 } //namespace fei
00501 
00502 #endif // _fei_CommUtils_hpp_
00503 

Generated on Wed May 12 21:30:40 2010 for FEI by  doxygen 1.4.7