FEI Version of the Day
fei_CommUtils.hpp
00001 
00002 /*--------------------------------------------------------------------*/
00003 /*    Copyright 2007 Sandia Corporation.                              */
00004 /*    Under the terms of Contract DE-AC04-94AL85000, there is a       */
00005 /*    non-exclusive license for use of this work by or on behalf      */
00006 /*    of the U.S. Government.  Export of this program may require     */
00007 /*    a license from the United States Government.                    */
00008 /*--------------------------------------------------------------------*/
00009 
00010 #ifndef _fei_CommUtils_hpp_
00011 #define _fei_CommUtils_hpp_
00012 
00013 #include <fei_macros.hpp>
00014 #include <fei_mpi.h>
00015 #include <fei_mpiTraits.hpp>
00016 #include <fei_chk_mpi.hpp>
00017 #include <fei_iostream.hpp>
00018 #include <fei_CommMap.hpp>
00019 #include <fei_TemplateUtils.hpp>
00020 #include <snl_fei_RaggedTable.hpp>
00021 
00022 #include <vector>
00023 #include <set>
00024 #include <map>
00025 
00026 #include <fei_ErrMacros.hpp>
00027 #undef fei_file
00028 #define fei_file "fei_CommUtils.hpp"
00029 
00030 namespace fei {
00031 
// NOTE(review): only the declarations are visible in this header; the
// descriptions below are inferred from how the functions are used further
// down in this file and should be confirmed against the implementation.

// Used below to obtain the calling process's rank, which is then compared
// against entries of the send/recv proc lists.
00035 int localProc(MPI_Comm comm);
00036 
// Used below as the number of processes in 'comm' (exchange() returns
// immediately when the value is less than 2).
00040 int numProcs(MPI_Comm comm);
00041 
// Presumably a wrapper around MPI_Barrier -- confirm.
00042 void Barrier(MPI_Comm comm);
00043 
// Called below with a list of destination procs as 'toProcs'; the resulting
// 'fromProcs' is then used as the list of procs to receive from.
// Presumably returns 0 on success -- confirm.
00049 int mirrorProcs(MPI_Comm comm, std::vector<int>& toProcs, std::vector<int>& fromProcs);
00050 
// Table type mapping an int key to a std::set<int>.
// NOTE(review): presumably the key is a processor rank -- confirm.
00051 typedef snl_fei::RaggedTable<std::map<int,std::set<int>*>,std::set<int> > comm_map;
00052 
// Presumably the comm_map analogue of mirrorProcs: 'outPattern' is passed as
// a reference-to-pointer, so the callee can set the pointer. Ownership of the
// pointed-to object is not visible here -- confirm.
00055 int mirrorCommPattern(MPI_Comm comm, comm_map* inPattern, comm_map*& outPattern);
00056 
// Boolean reduction across 'comm'; which logical operation is applied is not
// visible from this header -- confirm.
00061 int Allreduce(MPI_Comm comm, bool localBool, bool& globalBool);
00062 
// Used below to exchange message lengths: sendData holds one int per entry of
// sendProcs, recvData is consumed with one int per entry of recvProcs.
// Callers in this file treat a non-zero return value as failure.
00078 int exchangeIntData(MPI_Comm comm,
00079                     const std::vector<int>& sendProcs,
00080                     std::vector<int>& sendData,
00081                     const std::vector<int>& recvProcs,
00082                     std::vector<int>& recvData);
00083  
00084 //----------------------------------------------------------------------------
00090 template<class T>
00091 int GlobalMax(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
00092 {
00093 #ifdef FEI_SER
00094   global = local;
00095 #else
00096 
00097   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00098 
00099   try {
00100     global.resize(local.size());
00101   }
00102   catch(std::runtime_error& exc) {
00103     fei::console_out() << exc.what()<<FEI_ENDL;
00104     return(-1);
00105   }
00106 
00107   CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
00108        local.size(), mpi_dtype, MPI_MAX, comm) );
00109 #endif
00110 
00111   return(0);
00112 }
00113 
00114 //----------------------------------------------------------------------------
00120 template<class T>
00121 int GlobalMax(MPI_Comm comm, T local, T& global)
00122 {
00123 #ifdef FEI_SER
00124   global = local;
00125 #else
00126   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00127 
00128   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_MAX, comm) );
00129 #endif
00130   return(0);
00131 }
00132 
00133 //----------------------------------------------------------------------------
00139 template<class T>
00140 int GlobalMin(MPI_Comm comm, T local, T& global)
00141 {
00142 #ifdef FEI_SER
00143   global = local;
00144 #else
00145   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00146 
00147   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_MIN, comm) );
00148 #endif
00149   return(0);
00150 }
00151 
00152 //----------------------------------------------------------------------------
00158 template<class T>
00159 int GlobalSum(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
00160 {
00161 #ifdef FEI_SER
00162   global = local;
00163 #else
00164   global.resize(local.size());
00165 
00166   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00167 
00168   CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
00169                       local.size(), mpi_dtype, MPI_SUM, comm) );
00170 #endif
00171   return(0);
00172 }
00173 
00174 //----------------------------------------------------------------------------
00176 template<class T>
00177 int GlobalSum(MPI_Comm comm, T local, T& global)
00178 {
00179 #ifdef FEI_SER
00180   global = local;
00181 #else
00182   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00183 
00184   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_SUM, comm) );
00185 #endif
00186   return(0);
00187 }
00188 
00189 
00192 template<class T>
00193 int Allgatherv(MPI_Comm comm,
00194                std::vector<T>& sendbuf,
00195                std::vector<int>& recvLengths,
00196                std::vector<T>& recvbuf)
00197 {
00198 #ifdef FEI_SER
00199   //If we're in serial mode, just copy sendbuf to recvbuf and return.
00200 
00201   recvbuf = sendbuf;
00202   recvLengths.resize(1);
00203   recvLengths[0] = sendbuf.size();
00204 #else
00205   int numProcs = 1;
00206   MPI_Comm_size(comm, &numProcs);
00207 
00208   try {
00209 
00210   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00211 
00212   std::vector<int> tmpInt(numProcs, 0);
00213 
00214   int len = sendbuf.size();
00215   int* tmpBuf = &tmpInt[0];
00216 
00217   recvLengths.resize(numProcs);
00218   int* recvLenPtr = &recvLengths[0];
00219 
00220   CHK_MPI( MPI_Allgather(&len, 1, MPI_INT, recvLenPtr, 1, MPI_INT, comm) );
00221 
00222   int displ = 0;
00223   for(int i=0; i<numProcs; i++) {
00224     tmpBuf[i] = displ;
00225     displ += recvLenPtr[i];
00226   }
00227 
00228   if (displ == 0) {
00229     recvbuf.resize(0);
00230     return(0);
00231   }
00232 
00233   recvbuf.resize(displ);
00234 
00235   T* sendbufPtr = sendbuf.size()>0 ? &sendbuf[0] : NULL;
00236   
00237   CHK_MPI( MPI_Allgatherv(sendbufPtr, len, mpi_dtype,
00238       &recvbuf[0], &recvLengths[0], tmpBuf,
00239       mpi_dtype, comm) );
00240 
00241   }
00242   catch(std::runtime_error& exc) {
00243     fei::console_out() << exc.what() << FEI_ENDL;
00244     return(-1);
00245   }
00246 #endif
00247 
00248   return(0);
00249 }
00250 
00251 //------------------------------------------------------------------------
00252 template<class T>
00253 int Bcast(MPI_Comm comm, std::vector<T>& sendbuf, int sourceProc)
00254 {
00255 #ifndef FEI_SER
00256   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00257 
00258   CHK_MPI(MPI_Bcast(&sendbuf[0], sendbuf.size(), mpi_dtype,
00259                     sourceProc, comm) );
00260 #endif
00261   return(0);
00262 }
00263 
00264 //------------------------------------------------------------------------
00272 template<typename T>
00273 int exchangeCommMapData(MPI_Comm comm,
00274                         const typename CommMap<T>::Type& sendCommMap,
00275                         typename CommMap<T>::Type& recvCommMap,
00276                         bool recvProcsKnownOnEntry = false,
00277                         bool recvLengthsKnownOnEntry = false)
00278 {
00279   if (!recvProcsKnownOnEntry) {
00280     recvCommMap.clear();
00281   }
00282 
00283 #ifndef FEI_SER
00284   int tag = 11120;
00285   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00286 
00287   std::vector<int> sendProcs;
00288   fei::copyKeysToVector(sendCommMap, sendProcs);
00289   std::vector<int> recvProcs;
00290 
00291   if (recvProcsKnownOnEntry) {
00292     fei::copyKeysToVector(recvCommMap, recvProcs);
00293   }
00294   else {
00295     mirrorProcs(comm, sendProcs, recvProcs);
00296     for(size_t i=0; i<recvProcs.size(); ++i) {
00297       addItemsToCommMap<T>(recvProcs[i], 0, NULL, recvCommMap);
00298     }
00299   }
00300 
00301   if (!recvLengthsKnownOnEntry) {
00302     std::vector<int> tmpIntData(sendProcs.size());
00303     std::vector<int> recvLengths(recvProcs.size());
00304     
00305     typename fei::CommMap<T>::Type::const_iterator
00306       s_iter = sendCommMap.begin(), s_end = sendCommMap.end();
00307 
00308     for(size_t i=0; s_iter != s_end; ++s_iter, ++i) {
00309       tmpIntData[i] = s_iter->second.size();
00310     }
00311 
00312     if ( exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLengths) != 0) {
00313       return(-1);
00314     }
00315     for(size_t i=0; i<recvProcs.size(); ++i) {
00316       std::vector<T>& rdata = recvCommMap[recvProcs[i]];
00317       rdata.resize(recvLengths[i]);
00318     }
00319   }
00320 
00321   //launch Irecv's for recv-data:
00322   std::vector<MPI_Request> mpiReqs;
00323   mpiReqs.resize(recvProcs.size());
00324 
00325   typename fei::CommMap<T>::Type::iterator
00326     r_iter = recvCommMap.begin(), r_end = recvCommMap.end();
00327 
00328   size_t req_offset = 0;
00329   for(; r_iter != r_end; ++r_iter) {
00330     int rproc = r_iter->first;
00331     std::vector<T>& recv_vec = r_iter->second;
00332     int len = recv_vec.size();
00333     T* recv_buf = len > 0 ? &recv_vec[0] : NULL;
00334 
00335     CHK_MPI( MPI_Irecv(recv_buf, len, mpi_dtype, rproc,
00336                        tag, comm, &mpiReqs[req_offset++]) );
00337   }
00338 
00339   //send the send-data:
00340 
00341   typename fei::CommMap<T>::Type::const_iterator
00342     s_iter = sendCommMap.begin(), s_end = sendCommMap.end();
00343 
00344   for(; s_iter != s_end; ++s_iter) {
00345     int sproc = s_iter->first;
00346     const std::vector<T>& send_vec = s_iter->second;
00347     int len = send_vec.size();
00348     T* send_buf = len>0 ? const_cast<T*>(&send_vec[0]) : NULL;
00349 
00350     CHK_MPI( MPI_Send(send_buf, len, mpi_dtype, sproc, tag, comm) );
00351   }
00352 
00353   //complete the Irecvs:
00354   for(size_t i=0; i<mpiReqs.size(); ++i) {
00355     int index;
00356     MPI_Status status;
00357     CHK_MPI( MPI_Waitany(mpiReqs.size(), &mpiReqs[0], &index, &status) );
00358   }
00359 
00360 #endif
00361   return(0);
00362 }
00363 
00364 
00365 //------------------------------------------------------------------------
00366 template<class T>
00367 int exchangeData(MPI_Comm comm,
00368                  std::vector<int>& sendProcs,
00369                  std::vector<std::vector<T> >& sendData,
00370                  std::vector<int>& recvProcs,
00371                  bool recvDataLengthsKnownOnEntry,
00372                  std::vector<std::vector<T> >& recvData)
00373 {
00374   if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
00375   if (sendProcs.size() != sendData.size()) return(-1);
00376 #ifndef FEI_SER
00377   std::vector<MPI_Request> mpiReqs;
00378   mpiReqs.resize(recvProcs.size());
00379 
00380   int tag = 11119;
00381   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00382 
00383   if (!recvDataLengthsKnownOnEntry) {
00384     std::vector<int> tmpIntData(sendData.size());
00385     std::vector<int> recvLengths(recvProcs.size());
00386     for(unsigned i=0; i<sendData.size(); ++i) {
00387       tmpIntData[i] = sendData[i].size();
00388     }
00389 
00390     if ( exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLengths) != 0) {
00391       return(-1);
00392     }
00393     for(unsigned i=0; i<recvProcs.size(); ++i) {
00394       recvData[i].resize(recvLengths[i]);
00395     }
00396   }
00397 
00398   //launch Irecv's for recvData:
00399 
00400   size_t numRecvProcs = recvProcs.size();
00401   int req_offset = 0;
00402   int localProc = fei::localProc(comm);
00403   for(size_t i=0; i<recvProcs.size(); ++i) {
00404     if (recvProcs[i] == localProc) {--numRecvProcs; continue; }
00405 
00406     int len = recvData[i].size();
00407     std::vector<T>& recv_vec = recvData[i];
00408     T* recv_buf = len > 0 ? &recv_vec[0] : NULL;
00409 
00410     CHK_MPI( MPI_Irecv(recv_buf, len, mpi_dtype, recvProcs[i],
00411                        tag, comm, &mpiReqs[req_offset++]) );
00412   }
00413 
00414   //send the sendData:
00415 
00416   for(size_t i=0; i<sendProcs.size(); ++i) {
00417     if (sendProcs[i] == localProc) continue;
00418 
00419     std::vector<T>& send_buf = sendData[i];
00420     CHK_MPI( MPI_Send(&send_buf[0], sendData[i].size(), mpi_dtype,
00421                       sendProcs[i], tag, comm) );
00422   }
00423 
00424   //complete the Irecvs:
00425   for(size_t i=0; i<numRecvProcs; ++i) {
00426     if (recvProcs[i] == localProc) continue;
00427     int index;
00428     MPI_Status status;
00429     CHK_MPI( MPI_Waitany(numRecvProcs, &mpiReqs[0], &index, &status) );
00430   }
00431 
00432 #endif
00433   return(0);
00434 }
00435 
00436 //------------------------------------------------------------------------
00437 template<class T>
00438 int exchangeData(MPI_Comm comm,
00439                  std::vector<int>& sendProcs,
00440                  std::vector<std::vector<T>*>& sendData,
00441                  std::vector<int>& recvProcs,
00442                  bool recvLengthsKnownOnEntry,
00443                  std::vector<std::vector<T>*>& recvData)
00444 {
00445   if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
00446   if (sendProcs.size() != sendData.size()) return(-1);
00447 #ifndef FEI_SER
00448   int tag = 11115;
00449   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00450   std::vector<MPI_Request> mpiReqs;
00451 
00452   try {
00453   mpiReqs.resize(recvProcs.size());
00454 
00455   if (!recvLengthsKnownOnEntry) {
00456     std::vector<int> tmpIntData;
00457     tmpIntData.resize(sendData.size());
00458     std::vector<int> recvLens(sendData.size());
00459     for(unsigned i=0; i<sendData.size(); ++i) {
00460       tmpIntData[i] = (int)sendData[i]->size();
00461     }
00462 
00463     if (exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLens) != 0) {
00464       return(-1);
00465     }
00466 
00467     for(unsigned i=0; i<recvLens.size(); ++i) {
00468       recvData[i]->resize(recvLens[i]);
00469     }
00470   }
00471   }
00472   catch(std::runtime_error& exc) {
00473     fei::console_out() << exc.what() << FEI_ENDL;
00474     return(-1);
00475   }
00476 
00477   //launch Irecv's for recvData:
00478 
00479   size_t numRecvProcs = recvProcs.size();
00480   int req_offset = 0;
00481   int localProc = fei::localProc(comm);
00482   for(unsigned i=0; i<recvProcs.size(); ++i) {
00483     if (recvProcs[i] == localProc) {--numRecvProcs; continue;}
00484 
00485     size_t len = recvData[i]->size();
00486     std::vector<T>& rbuf = *recvData[i];
00487 
00488     CHK_MPI( MPI_Irecv(&rbuf[0], (int)len, mpi_dtype,
00489                        recvProcs[i], tag, comm, &mpiReqs[req_offset++]) );
00490   }
00491 
00492   //send the sendData:
00493 
00494   for(unsigned i=0; i<sendProcs.size(); ++i) {
00495     if (sendProcs[i] == localProc) continue;
00496 
00497     std::vector<T>& sbuf = *sendData[i];
00498     CHK_MPI( MPI_Send(&sbuf[0], (int)sbuf.size(), mpi_dtype,
00499                       sendProcs[i], tag, comm) );
00500   }
00501 
00502   //complete the Irecv's:
00503   for(unsigned i=0; i<numRecvProcs; ++i) {
00504     if (recvProcs[i] == localProc) continue;
00505     int index;
00506     MPI_Status status;
00507     CHK_MPI( MPI_Waitany((int)numRecvProcs, &mpiReqs[0], &index, &status) );
00508   }
00509 
00510 #endif
00511   return(0);
00512 }
00513 
00514 
00515 //------------------------------------------------------------------------
// Abstract interface through which fei::exchange() obtains the messages to
// send and hands back the messages it received. T is the type of the items
// that make up a message.
// exchange() passes the int returned by each of the last three methods to
// CHK_ERR; presumably a non-zero value signals an error -- confirm against
// fei_ErrMacros.hpp.
00533 template<class T>
00534 class MessageHandler {
00535 public:
00536   virtual ~MessageHandler(){}
00537 
  // List of procs to which a message will be sent.
00541   virtual std::vector<int>& getSendProcs() = 0;
00542 
  // List of procs from which a message will be received.
00546   virtual std::vector<int>& getRecvProcs() = 0;
00547 
  // Output 'messageLength' for the message destined for 'destProc'.
  // exchange() calls this once per send-proc, before any getSendMessage call.
00551   virtual int getSendMessageLength(int destProc, int& messageLength) = 0;
00552 
  // Fill 'message' with the data to be sent to 'destProc'. exchange() passes
  // in an empty vector.
00556   virtual int getSendMessage(int destProc, std::vector<T>& message) = 0;
00557 
  // Called by exchange() once per recv-proc, with the message that was
  // received from 'srcProc'.
00561   virtual int processRecvMessage(int srcProc, std::vector<T>& message) = 0;
00562 };//class MessageHandler
00563 
00564 
00565 //------------------------------------------------------------------------
00566 template<class T>
00567 int exchange(MPI_Comm comm, MessageHandler<T>* msgHandler)
00568 {
00569 #ifdef FEI_SER
00570   (void)msgHandler;
00571 #else
00572   int numProcs = fei::numProcs(comm);
00573   if (numProcs < 2) return(0);
00574 
00575   std::vector<int>& sendProcs = msgHandler->getSendProcs();
00576   int numSendProcs = sendProcs.size();
00577   std::vector<int>& recvProcs = msgHandler->getRecvProcs();
00578   int numRecvProcs = recvProcs.size();
00579   int i;
00580 
00581   if (numSendProcs < 1 && numRecvProcs < 1) {
00582     return(0);
00583   }
00584 
00585   std::vector<int> sendMsgLengths(numSendProcs);
00586 
00587   for(i=0; i<numSendProcs; ++i) {
00588     CHK_ERR( msgHandler->getSendMessageLength(sendProcs[i], sendMsgLengths[i]) );
00589   }
00590 
00591   std::vector<std::vector<T> > recvMsgs(numRecvProcs);
00592 
00593   std::vector<std::vector<T> > sendMsgs(numSendProcs);
00594   for(i=0; i<numSendProcs; ++i) {
00595     CHK_ERR( msgHandler->getSendMessage(sendProcs[i], sendMsgs[i]) );
00596   }
00597 
00598   CHK_ERR( exchangeData(comm, sendProcs, sendMsgs,
00599                         recvProcs, false, recvMsgs) );
00600 
00601   for(i=0; i<numRecvProcs; ++i) {
00602     std::vector<T>& recvMsg = recvMsgs[i];
00603     CHK_ERR( msgHandler->processRecvMessage(recvProcs[i], recvMsg ) );
00604   }
00605 #endif
00606 
00607   return(0);
00608 }
00609 
00610 
00611 } //namespace fei
00612 
00613 #endif // _fei_CommUtils_hpp_
00614 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends