fei_CommUtils.hpp

/*--------------------------------------------------------------------*/
/*    Copyright 2007 Sandia Corporation.                              */
/*    Under the terms of Contract DE-AC04-94AL85000, there is a       */
/*    non-exclusive license for use of this work by or on behalf      */
/*    of the U.S. Government.  Export of this program may require     */
/*    a license from the United States Government.                    */
/*--------------------------------------------------------------------*/

#ifndef _fei_CommUtils_hpp_
#define _fei_CommUtils_hpp_

#include <fei_macros.hpp>
#include <fei_mpi.h>
#include <fei_mpiTraits.hpp>
#include <fei_chk_mpi.hpp>
#include <fei_iostream.hpp>
#include <fei_CommMap.hpp>
#include <fei_TemplateUtils.hpp>
#include <snl_fei_RaggedTable.hpp>

#include <vector>
#include <set>
#include <map>

#include <fei_ErrMacros.hpp>
#undef fei_file
#define fei_file "fei_CommUtils.hpp"

namespace fei {
/** Return the MPI rank of the local processor in comm. */
int localProc(MPI_Comm comm);

/** Return the number of processors in comm. */
int numProcs(MPI_Comm comm);

/** Synchronize all processors in comm (a no-op in serial builds). */
void Barrier(MPI_Comm comm);

/** Given a list of processors that we will send to, fill fromProcs with
    the mirror list of processors that we will receive from.
*/
int mirrorProcs(MPI_Comm comm, std::vector<int>& toProcs, std::vector<int>& fromProcs);
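
/* Example (illustrative sketch, not part of the original header):
   discover receive-partners from a known send-list.

     std::vector<int> toProcs(1, (fei::localProc(comm)+1) % fei::numProcs(comm));
     std::vector<int> fromProcs; // filled with the ranks that send to us
     fei::mirrorProcs(comm, toProcs, fromProcs);
*/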

/** Shorthand for the ragged-table type used to describe a communication
    pattern: a map from processor to the set of indices shared with it.
*/
typedef snl_fei::RaggedTable<std::map<int,std::set<int>*>,std::set<int> > comm_map;

/** Given a communication pattern (which processors we send to, and what),
    create and return the mirror pattern (which processors we receive from,
    and what). outPattern is allocated by this function.
*/
int mirrorCommPattern(MPI_Comm comm, comm_map* inPattern, comm_map*& outPattern);

/** All-reduce for bool: on return, globalBool is true if localBool is true
    on any processor (a logical OR across comm).
*/
int Allreduce(MPI_Comm comm, bool localBool, bool& globalBool);

/** Perform a pairwise exchange of ints: send sendData[i] to sendProcs[i],
    and receive into recvData[i] from recvProcs[i]. sendData and recvData
    must have the same lengths as sendProcs and recvProcs, respectively.
*/
int exchangeIntData(MPI_Comm comm,
                    const std::vector<int>& sendProcs,
                    std::vector<int>& sendData,
                    const std::vector<int>& recvProcs,
                    std::vector<int>& recvData);

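/* Example (illustrative sketch, not part of the original header): tell
   each receive-partner how many items we will send it.

     std::vector<int> lengths(sendProcs.size(), 42); // hypothetical lengths
     std::vector<int> recvLengths(recvProcs.size());
     fei::exchangeIntData(comm, sendProcs, lengths, recvProcs, recvLengths);
*/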
//----------------------------------------------------------------------------
/** Element-wise global maximum: on return, global[i] holds the maximum of
    local[i] across all processors in comm.
*/
template<class T>
int GlobalMax(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
{
#ifdef FEI_SER
  global = local;
#else

  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  try {
    global.resize(local.size());
  }
  catch(std::exception& exc) {
    fei::console_out() << exc.what()<<FEI_ENDL;
    return(-1);
  }

  //zero-length reduction: nothing to do (also avoids &v[0] on empty vectors)
  if (local.size() == 0) return(0);

  CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
       local.size(), mpi_dtype, MPI_MAX, comm) );
#endif

  return(0);
}

//----------------------------------------------------------------------------
/** Global maximum of a scalar: on return, global holds the maximum of
    local across all processors in comm.
*/
template<class T>
int GlobalMax(MPI_Comm comm, T local, T& global)
{
#ifdef FEI_SER
  global = local;
#else
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_MAX, comm) );
#endif
  return(0);
}

//----------------------------------------------------------------------------
/** Element-wise global minimum: on return, global[i] holds the minimum of
    local[i] across all processors in comm.
*/
template<class T>
int GlobalMin(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
{
#ifdef FEI_SER
  global = local;
#else

  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  try {
    global.resize(local.size());
  }
  catch(std::exception& exc) {
    fei::console_out() << exc.what()<<FEI_ENDL;
    return(-1);
  }

  //zero-length reduction: nothing to do (also avoids &v[0] on empty vectors)
  if (local.size() == 0) return(0);

  CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
       local.size(), mpi_dtype, MPI_MIN, comm) );
#endif

  return(0);
}

//----------------------------------------------------------------------------
/** Global minimum of a scalar: on return, global holds the minimum of
    local across all processors in comm.
*/
template<class T>
int GlobalMin(MPI_Comm comm, T local, T& global)
{
#ifdef FEI_SER
  global = local;
#else
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_MIN, comm) );
#endif
  return(0);
}

//----------------------------------------------------------------------------
/** Element-wise global sum: on return, global[i] holds the sum of local[i]
    across all processors in comm.
*/
template<class T>
int GlobalSum(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
{
#ifdef FEI_SER
  global = local;
#else
  global.resize(local.size());

  //zero-length reduction: nothing to do (also avoids &v[0] on empty vectors)
  if (local.size() == 0) return(0);

  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
                      local.size(), mpi_dtype, MPI_SUM, comm) );
#endif
  return(0);
}

//----------------------------------------------------------------------------
/** Global sum of a scalar: on return, global holds the sum of local across
    all processors in comm.
*/
template<class T>
int GlobalSum(MPI_Comm comm, T local, T& global)
{
#ifdef FEI_SER
  global = local;
#else
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_SUM, comm) );
#endif
  return(0);
}
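
/* Example (illustrative sketch, not part of the original header): sum a
   locally-computed count across all processors.

     int localCount = computeLocalCount(); // hypothetical helper
     int globalCount = 0;
     fei::GlobalSum(comm, localCount, globalCount);
*/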


/** All-gather: gather each processor's sendbuf onto every processor.
    On return, recvbuf holds the concatenation of all processors' sendbufs
    (which may have different lengths), and recvLengths[p] holds the number
    of items contributed by processor p.
*/
template<class T>
int Allgatherv(MPI_Comm comm,
               std::vector<T>& sendbuf,
               std::vector<int>& recvLengths,
               std::vector<T>& recvbuf)
{
#ifdef FEI_SER
  //If we're in serial mode, just copy sendbuf to recvbuf and return.

  recvbuf = sendbuf;
  recvLengths.resize(1);
  recvLengths[0] = sendbuf.size();
#else
  int numProcs = 1;
  MPI_Comm_size(comm, &numProcs);

  try {

  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  std::vector<int> tmpInt(numProcs, 0);

  int len = sendbuf.size();
  int* tmpBuf = &tmpInt[0];

  recvLengths.resize(numProcs);
  int* recvLenPtr = &recvLengths[0];

  CHK_MPI( MPI_Allgather(&len, 1, MPI_INT, recvLenPtr, 1, MPI_INT, comm) );

  //compute the displacement of each processor's contribution:
  int displ = 0;
  for(int i=0; i<numProcs; i++) {
    tmpBuf[i] = displ;
    displ += recvLenPtr[i];
  }

  if (displ == 0) {
    recvbuf.resize(0);
    return(0);
  }

  recvbuf.resize(displ);

  T* sendbufPtr = sendbuf.size()>0 ? &sendbuf[0] : NULL;

  CHK_MPI( MPI_Allgatherv(sendbufPtr, len, mpi_dtype,
      &recvbuf[0], &recvLengths[0], tmpBuf,
      mpi_dtype, comm) );

  }
  catch(std::exception& exc) {
    fei::console_out() << exc.what() << FEI_ENDL;
    return(-1);
  }
#endif

  return(0);
}
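
/* Example (illustrative sketch, not part of the original header): gather
   every processor's list of locally-owned IDs onto all processors.

     std::vector<int> myIDs; // filled with locally-owned IDs (hypothetical)
     std::vector<int> lengths;
     std::vector<int> allIDs;
     fei::Allgatherv(comm, myIDs, lengths, allIDs);
*/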

//------------------------------------------------------------------------
/** Broadcast sendbuf (whose length must agree on all processors) from
    sourceProc to all processors in comm.
*/
template<class T>
int Bcast(MPI_Comm comm, std::vector<T>& sendbuf, int sourceProc)
{
#ifndef FEI_SER
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  //guard against taking the address of an element of an empty vector:
  T* buf = sendbuf.empty() ? NULL : &sendbuf[0];

  CHK_MPI(MPI_Bcast(buf, sendbuf.size(), mpi_dtype,
                    sourceProc, comm) );
#endif
  return(0);
}
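
/* Example (illustrative sketch, not part of the original header):
   broadcast 10 doubles from processor 0. Note that sendbuf must already
   be sized to the same length on every processor.

     std::vector<double> vals(10);
     if (fei::localProc(comm) == 0) fillValues(vals); // hypothetical helper
     fei::Bcast(comm, vals, 0);
*/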

//------------------------------------------------------------------------
/** Exchange data described by a CommMap: for each (proc, data) pair in
    sendCommMap, send that data to proc. On return, recvCommMap maps each
    sending processor to the data received from it. Unless the
    '...KnownOnEntry' flags are true, the receive-processors and
    receive-lengths are discovered internally.
*/
template<typename T>
int exchangeCommMapData(MPI_Comm comm,
                        const typename CommMap<T>::Type& sendCommMap,
                        typename CommMap<T>::Type& recvCommMap,
                        bool recvProcsKnownOnEntry = false,
                        bool recvLengthsKnownOnEntry = false)
{
  if (!recvProcsKnownOnEntry) {
    recvCommMap.clear();
  }

#ifndef FEI_SER
  int tag = 11120;
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  std::vector<int> sendProcs;
  fei::copyKeysToVector(sendCommMap, sendProcs);
  std::vector<int> recvProcs;

  if (recvProcsKnownOnEntry) {
    fei::copyKeysToVector(recvCommMap, recvProcs);
  }
  else {
    mirrorProcs(comm, sendProcs, recvProcs);
    for(size_t i=0; i<recvProcs.size(); ++i) {
      addItemsToCommMap<T>(recvProcs[i], 0, NULL, recvCommMap);
    }
  }

  if (!recvLengthsKnownOnEntry) {
    std::vector<int> tmpIntData(sendProcs.size());
    std::vector<int> recvLengths(recvProcs.size());

    typename fei::CommMap<T>::Type::const_iterator
      s_iter = sendCommMap.begin(), s_end = sendCommMap.end();

    for(size_t i=0; s_iter != s_end; ++s_iter, ++i) {
      tmpIntData[i] = s_iter->second.size();
    }

    if ( exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLengths) != 0) {
      return(-1);
    }
    for(size_t i=0; i<recvProcs.size(); ++i) {
      std::vector<T>& rdata = recvCommMap[recvProcs[i]];
      rdata.resize(recvLengths[i]);
    }
  }

  //launch Irecvs for the recv-data:
  std::vector<MPI_Request> mpiReqs;
  mpiReqs.resize(recvProcs.size());

  typename fei::CommMap<T>::Type::iterator
    r_iter = recvCommMap.begin(), r_end = recvCommMap.end();

  size_t req_offset = 0;
  for(; r_iter != r_end; ++r_iter) {
    int rproc = r_iter->first;
    std::vector<T>& recv_vec = r_iter->second;
    int len = recv_vec.size();
    T* recv_buf = len > 0 ? &recv_vec[0] : NULL;

    CHK_MPI( MPI_Irecv(recv_buf, len, mpi_dtype, rproc,
                       tag, comm, &mpiReqs[req_offset++]) );
  }

  //send the send-data:

  typename fei::CommMap<T>::Type::const_iterator
    s_iter = sendCommMap.begin(), s_end = sendCommMap.end();

  for(; s_iter != s_end; ++s_iter) {
    int sproc = s_iter->first;
    const std::vector<T>& send_vec = s_iter->second;
    int len = send_vec.size();
    T* send_buf = len>0 ? const_cast<T*>(&send_vec[0]) : NULL;

    CHK_MPI( MPI_Send(send_buf, len, mpi_dtype, sproc, tag, comm) );
  }

  //complete the Irecvs:
  for(size_t i=0; i<mpiReqs.size(); ++i) {
    int index;
    MPI_Status status;
    CHK_MPI( MPI_Waitany(mpiReqs.size(), &mpiReqs[0], &index, &status) );
  }

#endif
  return(0);
}
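
/* Example (illustrative sketch, not part of the original header): send
   three ints to the next-higher rank and receive whatever arrives.

     fei::CommMap<int>::Type sendMap, recvMap;
     int dest = (fei::localProc(comm)+1) % fei::numProcs(comm);
     int vals[3] = {1, 2, 3};
     fei::addItemsToCommMap<int>(dest, 3, vals, sendMap);
     fei::exchangeCommMapData<int>(comm, sendMap, recvMap);
     // recvMap now maps the previous rank to the 3 ints it sent us.
*/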


//------------------------------------------------------------------------
/** Pairwise exchange of vectors: send sendData[i] to sendProcs[i] and
    receive into recvData[i] from recvProcs[i]. If
    recvDataLengthsKnownOnEntry is false, the receive-lengths are
    exchanged first and recvData is sized accordingly.
*/
template<class T>
int exchangeData(MPI_Comm comm,
                 std::vector<int>& sendProcs,
                 std::vector<std::vector<T> >& sendData,
                 std::vector<int>& recvProcs,
                 bool recvDataLengthsKnownOnEntry,
                 std::vector<std::vector<T> >& recvData)
{
  if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
  if (sendProcs.size() != sendData.size()) return(-1);
#ifndef FEI_SER
  std::vector<MPI_Request> mpiReqs;
  mpiReqs.resize(recvProcs.size());

  int tag = 11119;
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  if (!recvDataLengthsKnownOnEntry) {
    std::vector<int> tmpIntData(sendData.size());
    std::vector<int> recvLengths(recvProcs.size());
    for(unsigned i=0; i<sendData.size(); ++i) {
      tmpIntData[i] = sendData[i].size();
    }

    if ( exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLengths) != 0) {
      return(-1);
    }

    //make sure recvData has an entry for each recv-proc:
    recvData.resize(recvProcs.size());
    for(unsigned i=0; i<recvProcs.size(); ++i) {
      recvData[i].resize(recvLengths[i]);
    }
  }

  //launch Irecvs for recvData:

  size_t numRecvProcs = recvProcs.size();
  int req_offset = 0;
  int localProc = fei::localProc(comm);
  for(size_t i=0; i<recvProcs.size(); ++i) {
    if (recvProcs[i] == localProc) {--numRecvProcs; continue; }

    int len = recvData[i].size();
    std::vector<T>& recv_vec = recvData[i];
    T* recv_buf = len > 0 ? &recv_vec[0] : NULL;

    CHK_MPI( MPI_Irecv(recv_buf, len, mpi_dtype, recvProcs[i],
                       tag, comm, &mpiReqs[req_offset++]) );
  }

  //send the sendData:

  for(size_t i=0; i<sendProcs.size(); ++i) {
    if (sendProcs[i] == localProc) continue;

    std::vector<T>& send_vec = sendData[i];
    T* send_buf = send_vec.size() > 0 ? &send_vec[0] : NULL;
    CHK_MPI( MPI_Send(send_buf, send_vec.size(), mpi_dtype,
                      sendProcs[i], tag, comm) );
  }

  //complete the Irecvs (one wait per posted request; the request array
  //is indexed independently of recvProcs):
  for(size_t i=0; i<numRecvProcs; ++i) {
    int index;
    MPI_Status status;
    CHK_MPI( MPI_Waitany((int)numRecvProcs, &mpiReqs[0], &index, &status) );
  }

#endif
  return(0);
}
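
/* Example (illustrative sketch, not part of the original header): send a
   variable-length vector of doubles to each send-partner.

     std::vector<std::vector<double> > sendData(sendProcs.size());
     // ... fill sendData[i] with the values destined for sendProcs[i] ...
     std::vector<std::vector<double> > recvData;
     fei::exchangeData(comm, sendProcs, sendData,
                       recvProcs, false, recvData);
*/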

//------------------------------------------------------------------------
/** Variant of exchangeData in which the send- and recv-data are held as
    vectors of pointers-to-vectors. The caller must supply recvData as a
    vector of valid pointers, one per recv-proc.
*/
template<class T>
int exchangeData(MPI_Comm comm,
                 std::vector<int>& sendProcs,
                 std::vector<std::vector<T>*>& sendData,
                 std::vector<int>& recvProcs,
                 bool recvLengthsKnownOnEntry,
                 std::vector<std::vector<T>*>& recvData)
{
  if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
  if (sendProcs.size() != sendData.size()) return(-1);
#ifndef FEI_SER
  int tag = 11115;
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
  std::vector<MPI_Request> mpiReqs;

  try {
  mpiReqs.resize(recvProcs.size());

  if (!recvLengthsKnownOnEntry) {
    std::vector<int> tmpIntData;
    tmpIntData.resize(sendData.size());
    //one expected length per recv-proc:
    std::vector<int> recvLens(recvProcs.size());
    for(unsigned i=0; i<sendData.size(); ++i) {
      tmpIntData[i] = (int)sendData[i]->size();
    }

    if (exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLens) != 0) {
      return(-1);
    }

    for(unsigned i=0; i<recvLens.size(); ++i) {
      recvData[i]->resize(recvLens[i]);
    }
  }
  }
  catch(std::exception& exc) {
    fei::console_out() << exc.what() << FEI_ENDL;
    return(-1);
  }

  //launch Irecvs for recvData:

  size_t numRecvProcs = recvProcs.size();
  int req_offset = 0;
  int localProc = fei::localProc(comm);
  for(unsigned i=0; i<recvProcs.size(); ++i) {
    if (recvProcs[i] == localProc) {--numRecvProcs; continue;}

    size_t len = recvData[i]->size();
    std::vector<T>& rbuf = *recvData[i];
    T* recv_buf = len > 0 ? &rbuf[0] : NULL;

    CHK_MPI( MPI_Irecv(recv_buf, (int)len, mpi_dtype,
                       recvProcs[i], tag, comm, &mpiReqs[req_offset++]) );
  }

  //send the sendData:

  for(unsigned i=0; i<sendProcs.size(); ++i) {
    if (sendProcs[i] == localProc) continue;

    std::vector<T>& sbuf = *sendData[i];
    T* send_buf = sbuf.empty() ? NULL : &sbuf[0];
    CHK_MPI( MPI_Send(send_buf, (int)sbuf.size(), mpi_dtype,
                      sendProcs[i], tag, comm) );
  }

  //complete the Irecvs (one wait per posted request):
  for(unsigned i=0; i<numRecvProcs; ++i) {
    int index;
    MPI_Status status;
    CHK_MPI( MPI_Waitany((int)numRecvProcs, &mpiReqs[0], &index, &status) );
  }

#endif
  return(0);
}


//------------------------------------------------------------------------
/** Interface for sparse data exchange. An implementation supplies the
    lists of processors to send to and receive from, produces an outgoing
    message for each send-processor, and consumes each incoming message.
    Used by fei::exchange, below.
*/
template<class T>
class MessageHandler {
public:
  virtual ~MessageHandler(){}

  /** Get the list of processors to which messages will be sent. */
  virtual std::vector<int>& getSendProcs() = 0;

  /** Get the list of processors from which messages will be received. */
  virtual std::vector<int>& getRecvProcs() = 0;

  /** Get the length of the message to be sent to the specified processor. */
  virtual int getSendMessageLength(int destProc, int& messageLength) = 0;

  /** Fill 'message' with the data to be sent to the specified processor. */
  virtual int getSendMessage(int destProc, std::vector<T>& message) = 0;

  /** Process a message received from the specified processor. */
  virtual int processRecvMessage(int srcProc, std::vector<T>& message) = 0;
};//class MessageHandler

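/* Example (illustrative sketch, not part of the original header): a
   minimal MessageHandler that sends one int to the next-higher rank and
   prints whatever it receives. All names below are hypothetical.

     class RingHandler : public fei::MessageHandler<int> {
     public:
       RingHandler(MPI_Comm comm) {
         int np = fei::numProcs(comm), me = fei::localProc(comm);
         sendProcs_.push_back((me+1) % np);
         recvProcs_.push_back((me+np-1) % np);
       }
       std::vector<int>& getSendProcs() { return sendProcs_; }
       std::vector<int>& getRecvProcs() { return recvProcs_; }
       int getSendMessageLength(int destProc, int& len) { len = 1; return 0; }
       int getSendMessage(int destProc, std::vector<int>& msg)
       { msg.resize(1); msg[0] = 42; return 0; }
       int processRecvMessage(int srcProc, std::vector<int>& msg)
       { fei::console_out() << msg[0] << FEI_ENDL; return 0; }
     private:
       std::vector<int> sendProcs_, recvProcs_;
     };
*/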

//------------------------------------------------------------------------
/** Perform a complete exchange driven by a MessageHandler: obtain the
    send-messages from msgHandler, exchange them with the corresponding
    receive-processors, and hand each received message back to msgHandler.
*/
template<class T>
int exchange(MPI_Comm comm, MessageHandler<T>* msgHandler)
{
#ifdef FEI_SER
  (void)msgHandler;
#else
  int numProcs = fei::numProcs(comm);
  if (numProcs < 2) return(0);

  std::vector<int>& sendProcs = msgHandler->getSendProcs();
  int numSendProcs = sendProcs.size();
  std::vector<int>& recvProcs = msgHandler->getRecvProcs();
  int numRecvProcs = recvProcs.size();
  int i;

  if (numSendProcs < 1 && numRecvProcs < 1) {
    return(0);
  }

  std::vector<int> sendMsgLengths(numSendProcs);

  for(i=0; i<numSendProcs; ++i) {
    CHK_ERR( msgHandler->getSendMessageLength(sendProcs[i], sendMsgLengths[i]) );
  }

  std::vector<std::vector<T> > recvMsgs(numRecvProcs);

  std::vector<std::vector<T> > sendMsgs(numSendProcs);
  for(i=0; i<numSendProcs; ++i) {
    CHK_ERR( msgHandler->getSendMessage(sendProcs[i], sendMsgs[i]) );
  }

  CHK_ERR( exchangeData(comm, sendProcs, sendMsgs,
                        recvProcs, false, recvMsgs) );

  for(i=0; i<numRecvProcs; ++i) {
    std::vector<T>& recvMsg = recvMsgs[i];
    CHK_ERR( msgHandler->processRecvMessage(recvProcs[i], recvMsg ) );
  }
#endif

  return(0);
}
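
/* Example (illustrative sketch, not part of the original header): drive
   the exchange with the RingHandler sketched above.

     RingHandler handler(comm);
     fei::exchange<int>(comm, &handler);
*/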


} //namespace fei

#endif // _fei_CommUtils_hpp_
