FEI Version of the Day
fei_CommUtils.hpp
00001 /*
00002 // @HEADER
00003 // ************************************************************************
00004 //             FEI: Finite Element Interface to Linear Solvers
00005 //                  Copyright (2005) Sandia Corporation.
00006 //
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, the
00008 // U.S. Government retains certain rights in this software.
00009 //
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Alan Williams (william@sandia.gov) 
00038 //
00039 // ************************************************************************
00040 // @HEADER
00041 */
00042 
00043 #ifndef _fei_CommUtils_hpp_
00044 #define _fei_CommUtils_hpp_
00045 
00046 #include <fei_macros.hpp>
00047 #include <fei_mpi.h>
00048 #include <fei_mpiTraits.hpp>
00049 #include <fei_chk_mpi.hpp>
00050 #include <fei_iostream.hpp>
00051 #include <fei_CommMap.hpp>
00052 #include <fei_TemplateUtils.hpp>
00053 #include <snl_fei_RaggedTable.hpp>
00054 
00055 #include <vector>
00056 #include <set>
00057 #include <map>
00058 
00059 #include <fei_ErrMacros.hpp>
00060 #undef fei_file
00061 #define fei_file "fei_CommUtils.hpp"
00062 
00063 namespace fei {
00064 
/** Return the rank of the local processor in the given communicator.
    (Presumably returns 0 in serial FEI_SER builds — confirm in
    fei_CommUtils.cpp.) */
int localProc(MPI_Comm comm);

/** Return the number of processors in the given communicator. */
int numProcs(MPI_Comm comm);

/** Barrier: block until all processors in comm have entered. */
void Barrier(MPI_Comm comm);

/** Given the list of processors the local processor sends to, produce the
    mirror list: the processors that send to the local processor.
    @return error-code 0 if successful */
int mirrorProcs(MPI_Comm comm, std::vector<int>& toProcs, std::vector<int>& fromProcs);

/** Ragged table mapping a processor number to a set of ids; used to
    describe a communication pattern. */
typedef snl_fei::RaggedTable<std::map<int,std::set<int>*>,std::set<int> > comm_map;

/** Given a send communication-pattern, construct the mirror
    receive-pattern in outPattern. NOTE(review): outPattern appears to be
    allocated here and owned by the caller — confirm in fei_CommUtils.cpp.
    @return error-code 0 if successful */
int mirrorCommPattern(MPI_Comm comm, comm_map* inPattern, comm_map*& outPattern);

/** Reduce localBool to globalBool across all processors in comm.
    NOTE(review): the reduction op (logical-and vs logical-or) is not
    visible in this header — confirm in fei_CommUtils.cpp.
    @return error-code 0 if successful */
int Allreduce(MPI_Comm comm, bool localBool, bool& globalBool);

/** Exchange integer data: send sendData[i] to sendProcs[i], and receive
    into recvData from the processors in recvProcs. Callers in this header
    pass one int per processor in each proc-list (lengths exchange);
    confirm the exact contract in fei_CommUtils.cpp.
    @return error-code 0 if successful */
int exchangeIntData(MPI_Comm comm,
                    const std::vector<int>& sendProcs,
                    std::vector<int>& sendData,
                    const std::vector<int>& recvProcs,
                    std::vector<int>& recvData);
00116  
00117 //----------------------------------------------------------------------------
00123 template<class T>
00124 int GlobalMax(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
00125 {
00126 #ifdef FEI_SER
00127   global = local;
00128 #else
00129 
00130   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00131 
00132   try {
00133     global.resize(local.size());
00134   }
00135   catch(std::runtime_error& exc) {
00136     fei::console_out() << exc.what()<<FEI_ENDL;
00137     return(-1);
00138   }
00139 
00140   CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
00141        local.size(), mpi_dtype, MPI_MAX, comm) );
00142 #endif
00143 
00144   return(0);
00145 }
00146 
00147 //----------------------------------------------------------------------------
00153 template<class T>
00154 int GlobalMax(MPI_Comm comm, T local, T& global)
00155 {
00156 #ifdef FEI_SER
00157   global = local;
00158 #else
00159   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00160 
00161   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_MAX, comm) );
00162 #endif
00163   return(0);
00164 }
00165 
00166 //----------------------------------------------------------------------------
00172 template<class T>
00173 int GlobalMin(MPI_Comm comm, T local, T& global)
00174 {
00175 #ifdef FEI_SER
00176   global = local;
00177 #else
00178   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00179 
00180   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_MIN, comm) );
00181 #endif
00182   return(0);
00183 }
00184 
00185 //----------------------------------------------------------------------------
00191 template<class T>
00192 int GlobalSum(MPI_Comm comm, std::vector<T>& local, std::vector<T>& global)
00193 {
00194 #ifdef FEI_SER
00195   global = local;
00196 #else
00197   global.resize(local.size());
00198 
00199   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00200 
00201   CHK_MPI( MPI_Allreduce(&(local[0]), &(global[0]),
00202                       local.size(), mpi_dtype, MPI_SUM, comm) );
00203 #endif
00204   return(0);
00205 }
00206 
00207 //----------------------------------------------------------------------------
00209 template<class T>
00210 int GlobalSum(MPI_Comm comm, T local, T& global)
00211 {
00212 #ifdef FEI_SER
00213   global = local;
00214 #else
00215   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00216 
00217   CHK_MPI( MPI_Allreduce(&local, &global, 1, mpi_dtype, MPI_SUM, comm) );
00218 #endif
00219   return(0);
00220 }
00221 
00222 
00225 template<class T>
00226 int Allgatherv(MPI_Comm comm,
00227                std::vector<T>& sendbuf,
00228                std::vector<int>& recvLengths,
00229                std::vector<T>& recvbuf)
00230 {
00231 #ifdef FEI_SER
00232   //If we're in serial mode, just copy sendbuf to recvbuf and return.
00233 
00234   recvbuf = sendbuf;
00235   recvLengths.resize(1);
00236   recvLengths[0] = sendbuf.size();
00237 #else
00238   int numProcs = 1;
00239   MPI_Comm_size(comm, &numProcs);
00240 
00241   try {
00242 
00243   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00244 
00245   std::vector<int> tmpInt(numProcs, 0);
00246 
00247   int len = sendbuf.size();
00248   int* tmpBuf = &tmpInt[0];
00249 
00250   recvLengths.resize(numProcs);
00251   int* recvLenPtr = &recvLengths[0];
00252 
00253   CHK_MPI( MPI_Allgather(&len, 1, MPI_INT, recvLenPtr, 1, MPI_INT, comm) );
00254 
00255   int displ = 0;
00256   for(int i=0; i<numProcs; i++) {
00257     tmpBuf[i] = displ;
00258     displ += recvLenPtr[i];
00259   }
00260 
00261   if (displ == 0) {
00262     recvbuf.resize(0);
00263     return(0);
00264   }
00265 
00266   recvbuf.resize(displ);
00267 
00268   T* sendbufPtr = sendbuf.size()>0 ? &sendbuf[0] : NULL;
00269   
00270   CHK_MPI( MPI_Allgatherv(sendbufPtr, len, mpi_dtype,
00271       &recvbuf[0], &recvLengths[0], tmpBuf,
00272       mpi_dtype, comm) );
00273 
00274   }
00275   catch(std::runtime_error& exc) {
00276     fei::console_out() << exc.what() << FEI_ENDL;
00277     return(-1);
00278   }
00279 #endif
00280 
00281   return(0);
00282 }
00283 
00284 //------------------------------------------------------------------------
00285 template<class T>
00286 int Bcast(MPI_Comm comm, std::vector<T>& sendbuf, int sourceProc)
00287 {
00288 #ifndef FEI_SER
00289   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00290 
00291   CHK_MPI(MPI_Bcast(&sendbuf[0], sendbuf.size(), mpi_dtype,
00292                     sourceProc, comm) );
00293 #endif
00294   return(0);
00295 }
00296 
00297 //------------------------------------------------------------------------
/** Exchange data among processors as described by CommMap objects
    (maps of proc -> vector-of-items). Each entry of sendCommMap is sent
    to the corresponding processor; recvCommMap is filled with what
    arrives from each source processor.
    @param comm MPI communicator
    @param sendCommMap input; maps destination-proc to the items to send
    @param recvCommMap output; maps source-proc to the items received
    @param recvProcsKnownOnEntry if true, recvCommMap's keys already name
           the source procs; otherwise they are discovered via mirrorProcs
           and recvCommMap is cleared first
    @param recvLengthsKnownOnEntry if true, recvCommMap's vectors are
           assumed to be pre-sized for the incoming messages; otherwise
           lengths are exchanged first and the vectors are sized here
    @return error-code 0 if successful
*/
template<typename T>
int exchangeCommMapData(MPI_Comm comm,
                        const typename CommMap<T>::Type& sendCommMap,
                        typename CommMap<T>::Type& recvCommMap,
                        bool recvProcsKnownOnEntry = false,
                        bool recvLengthsKnownOnEntry = false)
{
  if (!recvProcsKnownOnEntry) {
    recvCommMap.clear();
  }

#ifndef FEI_SER
  int tag = 11120;
  MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();

  std::vector<int> sendProcs;
  fei::copyKeysToVector(sendCommMap, sendProcs);
  std::vector<int> recvProcs;

  if (recvProcsKnownOnEntry) {
    fei::copyKeysToVector(recvCommMap, recvProcs);
  }
  else {
    //discover which procs will send to us, and create an (empty) entry
    //in recvCommMap for each of them:
    mirrorProcs(comm, sendProcs, recvProcs);
    for(size_t i=0; i<recvProcs.size(); ++i) {
      addItemsToCommMap<T>(recvProcs[i], 0, NULL, recvCommMap);
    }
  }

  if (!recvLengthsKnownOnEntry) {
    //exchange message lengths so the recv vectors can be sized before
    //the Irecvs are posted:
    std::vector<int> tmpIntData(sendProcs.size());
    std::vector<int> recvLengths(recvProcs.size());
    
    typename fei::CommMap<T>::Type::const_iterator
      s_iter = sendCommMap.begin(), s_end = sendCommMap.end();

    for(size_t i=0; s_iter != s_end; ++s_iter, ++i) {
      tmpIntData[i] = s_iter->second.size();
    }

    if ( exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLengths) != 0) {
      return(-1);
    }
    for(size_t i=0; i<recvProcs.size(); ++i) {
      std::vector<T>& rdata = recvCommMap[recvProcs[i]];
      rdata.resize(recvLengths[i]);
    }
  }

  //launch Irecv's for recv-data:
  std::vector<MPI_Request> mpiReqs;
  mpiReqs.resize(recvProcs.size());

  typename fei::CommMap<T>::Type::iterator
    r_iter = recvCommMap.begin(), r_end = recvCommMap.end();

  size_t req_offset = 0;
  for(; r_iter != r_end; ++r_iter) {
    int rproc = r_iter->first;
    std::vector<T>& recv_vec = r_iter->second;
    int len = recv_vec.size();
    //zero-length messages are posted with a NULL buffer
    T* recv_buf = len > 0 ? &recv_vec[0] : NULL;

    CHK_MPI( MPI_Irecv(recv_buf, len, mpi_dtype, rproc,
                       tag, comm, &mpiReqs[req_offset++]) );
  }

  //send the send-data:

  typename fei::CommMap<T>::Type::const_iterator
    s_iter = sendCommMap.begin(), s_end = sendCommMap.end();

  for(; s_iter != s_end; ++s_iter) {
    int sproc = s_iter->first;
    const std::vector<T>& send_vec = s_iter->second;
    int len = send_vec.size();
    //const_cast: pre-MPI-3 implementations take a non-const send buffer
    T* send_buf = len>0 ? const_cast<T*>(&send_vec[0]) : NULL;

    CHK_MPI( MPI_Send(send_buf, len, mpi_dtype, sproc, tag, comm) );
  }

  //complete the Irecvs (one Waitany per posted request):
  for(size_t i=0; i<mpiReqs.size(); ++i) {
    int index;
    MPI_Status status;
    CHK_MPI( MPI_Waitany(mpiReqs.size(), &mpiReqs[0], &index, &status) );
  }

#endif
  return(0);
}
00396 
00397 
00398 //------------------------------------------------------------------------
00399 template<class T>
00400 int exchangeData(MPI_Comm comm,
00401                  std::vector<int>& sendProcs,
00402                  std::vector<std::vector<T> >& sendData,
00403                  std::vector<int>& recvProcs,
00404                  bool recvDataLengthsKnownOnEntry,
00405                  std::vector<std::vector<T> >& recvData)
00406 {
00407   if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
00408   if (sendProcs.size() != sendData.size()) return(-1);
00409 #ifndef FEI_SER
00410   std::vector<MPI_Request> mpiReqs;
00411   mpiReqs.resize(recvProcs.size());
00412 
00413   int tag = 11119;
00414   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00415 
00416   if (!recvDataLengthsKnownOnEntry) {
00417     std::vector<int> tmpIntData(sendData.size());
00418     std::vector<int> recvLengths(recvProcs.size());
00419     for(unsigned i=0; i<sendData.size(); ++i) {
00420       tmpIntData[i] = sendData[i].size();
00421     }
00422 
00423     if ( exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLengths) != 0) {
00424       return(-1);
00425     }
00426     for(unsigned i=0; i<recvProcs.size(); ++i) {
00427       recvData[i].resize(recvLengths[i]);
00428     }
00429   }
00430 
00431   //launch Irecv's for recvData:
00432 
00433   size_t numRecvProcs = recvProcs.size();
00434   int req_offset = 0;
00435   int localProc = fei::localProc(comm);
00436   for(size_t i=0; i<recvProcs.size(); ++i) {
00437     if (recvProcs[i] == localProc) {--numRecvProcs; continue; }
00438 
00439     int len = recvData[i].size();
00440     std::vector<T>& recv_vec = recvData[i];
00441     T* recv_buf = len > 0 ? &recv_vec[0] : NULL;
00442 
00443     CHK_MPI( MPI_Irecv(recv_buf, len, mpi_dtype, recvProcs[i],
00444                        tag, comm, &mpiReqs[req_offset++]) );
00445   }
00446 
00447   //send the sendData:
00448 
00449   for(size_t i=0; i<sendProcs.size(); ++i) {
00450     if (sendProcs[i] == localProc) continue;
00451 
00452     std::vector<T>& send_buf = sendData[i];
00453     CHK_MPI( MPI_Send(&send_buf[0], sendData[i].size(), mpi_dtype,
00454                       sendProcs[i], tag, comm) );
00455   }
00456 
00457   //complete the Irecvs:
00458   for(size_t i=0; i<numRecvProcs; ++i) {
00459     if (recvProcs[i] == localProc) continue;
00460     int index;
00461     MPI_Status status;
00462     CHK_MPI( MPI_Waitany(numRecvProcs, &mpiReqs[0], &index, &status) );
00463   }
00464 
00465 #endif
00466   return(0);
00467 }
00468 
00469 //------------------------------------------------------------------------
00470 template<class T>
00471 int exchangeData(MPI_Comm comm,
00472                  std::vector<int>& sendProcs,
00473                  std::vector<std::vector<T>*>& sendData,
00474                  std::vector<int>& recvProcs,
00475                  bool recvLengthsKnownOnEntry,
00476                  std::vector<std::vector<T>*>& recvData)
00477 {
00478   if (sendProcs.size() == 0 && recvProcs.size() == 0) return(0);
00479   if (sendProcs.size() != sendData.size()) return(-1);
00480 #ifndef FEI_SER
00481   int tag = 11115;
00482   MPI_Datatype mpi_dtype = fei::mpiTraits<T>::mpi_type();
00483   std::vector<MPI_Request> mpiReqs;
00484 
00485   try {
00486   mpiReqs.resize(recvProcs.size());
00487 
00488   if (!recvLengthsKnownOnEntry) {
00489     std::vector<int> tmpIntData;
00490     tmpIntData.resize(sendData.size());
00491     std::vector<int> recvLens(sendData.size());
00492     for(unsigned i=0; i<sendData.size(); ++i) {
00493       tmpIntData[i] = (int)sendData[i]->size();
00494     }
00495 
00496     if (exchangeIntData(comm, sendProcs, tmpIntData, recvProcs, recvLens) != 0) {
00497       return(-1);
00498     }
00499 
00500     for(unsigned i=0; i<recvLens.size(); ++i) {
00501       recvData[i]->resize(recvLens[i]);
00502     }
00503   }
00504   }
00505   catch(std::runtime_error& exc) {
00506     fei::console_out() << exc.what() << FEI_ENDL;
00507     return(-1);
00508   }
00509 
00510   //launch Irecv's for recvData:
00511 
00512   size_t numRecvProcs = recvProcs.size();
00513   int req_offset = 0;
00514   int localProc = fei::localProc(comm);
00515   for(unsigned i=0; i<recvProcs.size(); ++i) {
00516     if (recvProcs[i] == localProc) {--numRecvProcs; continue;}
00517 
00518     size_t len = recvData[i]->size();
00519     std::vector<T>& rbuf = *recvData[i];
00520 
00521     CHK_MPI( MPI_Irecv(&rbuf[0], (int)len, mpi_dtype,
00522                        recvProcs[i], tag, comm, &mpiReqs[req_offset++]) );
00523   }
00524 
00525   //send the sendData:
00526 
00527   for(unsigned i=0; i<sendProcs.size(); ++i) {
00528     if (sendProcs[i] == localProc) continue;
00529 
00530     std::vector<T>& sbuf = *sendData[i];
00531     CHK_MPI( MPI_Send(&sbuf[0], (int)sbuf.size(), mpi_dtype,
00532                       sendProcs[i], tag, comm) );
00533   }
00534 
00535   //complete the Irecv's:
00536   for(unsigned i=0; i<numRecvProcs; ++i) {
00537     if (recvProcs[i] == localProc) continue;
00538     int index;
00539     MPI_Status status;
00540     CHK_MPI( MPI_Waitany((int)numRecvProcs, &mpiReqs[0], &index, &status) );
00541   }
00542 
00543 #endif
00544   return(0);
00545 }
00546 
00547 
00548 //------------------------------------------------------------------------
/** Abstract interface for exchanging messages among processors. Users
    implement this interface and pass it to fei::exchange (below), which
    drives the communication: it asks the handler which processors to
    send to / receive from, obtains each outgoing message, and hands each
    received message back for processing. */
template<class T>
class MessageHandler {
public:
  virtual ~MessageHandler(){}

  /** Get the list of processors this processor will send messages to. */
  virtual std::vector<int>& getSendProcs() = 0;

  /** Get the list of processors this processor will receive from. */
  virtual std::vector<int>& getRecvProcs() = 0;

  /** Get the length of the message to be sent to the given destination
      processor. @return error-code 0 if successful */
  virtual int getSendMessageLength(int destProc, int& messageLength) = 0;

  /** Fill 'message' with the data to send to the given destination
      processor. @return error-code 0 if successful */
  virtual int getSendMessage(int destProc, std::vector<T>& message) = 0;

  /** Process a message received from the given source processor.
      @return error-code 0 if successful */
  virtual int processRecvMessage(int srcProc, std::vector<T>& message) = 0;
};//class MessageHandler
00596 
00597 
00598 //------------------------------------------------------------------------
00599 template<class T>
00600 int exchange(MPI_Comm comm, MessageHandler<T>* msgHandler)
00601 {
00602 #ifdef FEI_SER
00603   (void)msgHandler;
00604 #else
00605   int numProcs = fei::numProcs(comm);
00606   if (numProcs < 2) return(0);
00607 
00608   std::vector<int>& sendProcs = msgHandler->getSendProcs();
00609   int numSendProcs = sendProcs.size();
00610   std::vector<int>& recvProcs = msgHandler->getRecvProcs();
00611   int numRecvProcs = recvProcs.size();
00612   int i;
00613 
00614   if (numSendProcs < 1 && numRecvProcs < 1) {
00615     return(0);
00616   }
00617 
00618   std::vector<int> sendMsgLengths(numSendProcs);
00619 
00620   for(i=0; i<numSendProcs; ++i) {
00621     CHK_ERR( msgHandler->getSendMessageLength(sendProcs[i], sendMsgLengths[i]) );
00622   }
00623 
00624   std::vector<std::vector<T> > recvMsgs(numRecvProcs);
00625 
00626   std::vector<std::vector<T> > sendMsgs(numSendProcs);
00627   for(i=0; i<numSendProcs; ++i) {
00628     CHK_ERR( msgHandler->getSendMessage(sendProcs[i], sendMsgs[i]) );
00629   }
00630 
00631   CHK_ERR( exchangeData(comm, sendProcs, sendMsgs,
00632                         recvProcs, false, recvMsgs) );
00633 
00634   for(i=0; i<numRecvProcs; ++i) {
00635     std::vector<T>& recvMsg = recvMsgs[i];
00636     CHK_ERR( msgHandler->processRecvMessage(recvProcs[i], recvMsg ) );
00637   }
00638 #endif
00639 
00640   return(0);
00641 }
00642 
00643 
00644 } //namespace fei
00645 
00646 #endif // _fei_CommUtils_hpp_
00647 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends