Teuchos_MPIContainerComm.hpp

Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //                    Teuchos: Common Tools Package
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef TEUCHOS_MPICONTAINERCOMM_H
00030 #define TEUCHOS_MPICONTAINERCOMM_H
00031 
00036 #include "Teuchos_ConfigDefs.hpp"
00037 #include "Teuchos_Array.hpp"
00038 #include "Teuchos_MPIComm.hpp"
00039 #include "Teuchos_MPITraits.hpp"
00040 
00041 namespace Teuchos
00042 {
00049   template <class T> class MPIContainerComm
00050   {
00051   public:
00052 
00054     static void bcast(T& x, int src, const MPIComm& comm);
00055 
00057     static void bcast(Array<T>& x, int src, const MPIComm& comm);
00058 
00060     static void bcast(Array<Array<T> >& x,
00061                       int src, const MPIComm& comm);
00062 
00064     static void allGather(const T& outgoing,
00065                           Array<T>& incoming,
00066                           const MPIComm& comm);
00067 
00069     static void allToAll(const Array<T>& outgoing,
00070                          Array<Array<T> >& incoming,
00071                          const MPIComm& comm);
00072 
00074     static void allToAll(const Array<Array<T> >& outgoing,
00075                          Array<Array<T> >& incoming,
00076                          const MPIComm& comm);
00077 
00079     static void accumulate(const T& localValue, Array<T>& sums, T& total,
00080                            const MPIComm& comm);
00081 
00082   private:
00084     static void getBigArray(const Array<Array<T> >& x,
00085                             Array<T>& bigArray,
00086                             Array<int>& offsets);
00087 
00089     static void getSmallArrays(const Array<T>& bigArray,
00090                                const Array<int>& offsets,
00091                                Array<Array<T> >& x);
00092 
00093 
00094   };
00095 
00096 
00097 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00098 
00101   template <> class MPIContainerComm<string>
00102   {
00103   public:
00104     static void bcast(string& x, int src, const MPIComm& comm);
00105 
00107     static void bcast(Array<string>& x, int src, const MPIComm& comm);
00108 
00110     static void bcast(Array<Array<string> >& x,
00111                       int src, const MPIComm& comm);
00112 
00114     static void allGather(const string& outgoing,
00115                           Array<string>& incoming,
00116                           const MPIComm& comm);
00117 
00118   private:
00120     static void getBigArray(const Array<string>& x,
00121                             Array<char>& bigArray,
00122                             Array<int>& offsets);
00123 
00126     static void getStrings(const Array<char>& bigArray,
00127                            const Array<int>& offsets,
00128                            Array<string>& x);
00129   };
00130 
00131 #endif // DOXYGEN_SHOULD_SKIP_THIS
00132 
00133   /* --------- generic functions for primitives ------------------- */
00134 
00135   template <class T> inline void MPIContainerComm<T>::bcast(T& x, int src,
00136                                                          const MPIComm& comm)
00137   {
00138     comm.bcast((void*)&x, 1, MPITraits<T>::type(), src);
00139   }
00140 
00141 
00142   /* ----------- generic functions for arrays of primitives ----------- */
00143 
00144   template <class T>
00145   inline void MPIContainerComm<T>::bcast(Array<T>& x, int src, const MPIComm& comm)
00146   {
00147      
00148     int len = x.length();
00149     MPIContainerComm<int>::bcast(len, src, comm);
00150 
00151     if (comm.getRank() != src)
00152       {
00153         x.resize(len);
00154       }
00155     if (len==0) return;
00156 
00157     /* then broadcast the contents */
00158     comm.bcast((void*) &(x[0]), (int) len,
00159                MPITraits<T>::type(), src);
00160   }
00161 
00162 
00163 
00164   /* ---------- generic function for arrays of arrays ----------- */
00165 
00166   template <class T>
00167   inline void MPIContainerComm<T>::bcast(Array<Array<T> >& x, int src, const MPIComm& comm)
00168   {
00169     Array<T> bigArray;
00170     Array<int> offsets;
00171 
00172     if (src==comm.getRank())
00173       {
00174         getBigArray(x, bigArray, offsets);
00175       }
00176 
00177     bcast(bigArray, src, comm);
00178     MPIContainerComm<int>::bcast(offsets, src, comm);
00179 
00180     if (src != comm.getRank())
00181       {
00182         getSmallArrays(bigArray, offsets, x);
00183       }
00184   }
00185 
00186   /* ---------- generic gather and scatter ------------------------ */
00187 
00188   template <class T> inline
00189   void MPIContainerComm<T>::allToAll(const Array<T>& outgoing,
00190                                   Array<Array<T> >& incoming,
00191                                   const MPIComm& comm)
00192   {
00193     int numProcs = comm.getNProc();
00194 
00195     // catch degenerate case
00196     if (numProcs==1)
00197       {
00198         incoming.resize(1);
00199         incoming[0] = outgoing;
00200         return;
00201       }
00202 
00203     T* sendBuf = new T[numProcs * outgoing.length()];
00204     TEST_FOR_EXCEPTION(sendBuf==0, 
00205       std::runtime_error, "Comm::allToAll failed to allocate sendBuf");
00206     T* recvBuf = new T[numProcs * outgoing.length()];
00207     TEST_FOR_EXCEPTION(recvBuf==0, 
00208       std::runtime_error, "Comm::allToAll failed to allocate recvBuf");
00209 
00210     int i;
00211     for (i=0; i<numProcs; i++)
00212       {
00213         for (int j=0; j<outgoing.length(); j++)
00214           {
00215             sendBuf[i*outgoing.length() + j] = outgoing[j];
00216           }
00217       }
00218 
00219     comm.allToAll(sendBuf, outgoing.length(), MPITraits<T>::type(),
00220                   recvBuf, outgoing.length(), MPITraits<T>::type());
00221 
00222     incoming.resize(numProcs);
00223 
00224     for (i=0; i<numProcs; i++)
00225       {
00226         incoming[i].resize(outgoing.length());
00227         for (int j=0; j<outgoing.length(); j++)
00228           {
00229             incoming[i][j] = recvBuf[i*outgoing.length() + j];
00230           }
00231       }
00232 
00233     delete [] sendBuf;
00234     delete [] recvBuf;
00235   }
00236 
00237   template <class T> inline
00238   void MPIContainerComm<T>::allToAll(const Array<Array<T> >& outgoing,
00239                                   Array<Array<T> >& incoming, const MPIComm& comm)
00240   {
00241     int numProcs = comm.getNProc();
00242 
00243     // catch degenerate case
00244     if (numProcs==1)
00245       {
00246         incoming = outgoing;
00247         return;
00248       }
00249 
00250     int* sendMesgLength = new int[numProcs];
00251     TEST_FOR_EXCEPTION(sendMesgLength==0, 
00252       std::runtime_error, "failed to allocate sendMesgLength");
00253     int* recvMesgLength = new int[numProcs];
00254     TEST_FOR_EXCEPTION(recvMesgLength==0, 
00255       std::runtime_error, "failed to allocate recvMesgLength");
00256 
00257     int p = 0;
00258     for (p=0; p<numProcs; p++)
00259       {
00260         sendMesgLength[p] = outgoing[p].length();
00261       }
00262 
00263     comm.allToAll(sendMesgLength, 1, MPIComm::INT,
00264                   recvMesgLength, 1, MPIComm::INT);
00265 
00266 
00267     int totalSendLength = 0;
00268     int totalRecvLength = 0;
00269     for (p=0; p<numProcs; p++)
00270       {
00271         totalSendLength += sendMesgLength[p];
00272         totalRecvLength += recvMesgLength[p];
00273       }
00274 
00275     T* sendBuf = new T[totalSendLength];
00276     TEST_FOR_EXCEPTION(sendBuf==0, 
00277       std::runtime_error, "failed to allocate sendBuf");
00278     T* recvBuf = new T[totalRecvLength];
00279     TEST_FOR_EXCEPTION(recvBuf==0, 
00280       std::runtime_error, "failed to allocate recvBuf");
00281 
00282     int* sendDisp = new int[numProcs];
00283     TEST_FOR_EXCEPTION(sendDisp==0, 
00284       std::runtime_error, "failed to allocate sendDisp");
00285     int* recvDisp = new int[numProcs];
00286     TEST_FOR_EXCEPTION(recvDisp==0, 
00287       std::runtime_error, "failed to allocate recvDisp");
00288 
00289     int count = 0;
00290     sendDisp[0] = 0;
00291     recvDisp[0] = 0;
00292 
00293     for (p=0; p<numProcs; p++)
00294       {
00295         for (int i=0; i<outgoing[p].length(); i++)
00296           {
00297             sendBuf[count] = outgoing[p][i];
00298             count++;
00299           }
00300         if (p>0)
00301           {
00302             sendDisp[p] = sendDisp[p-1] + sendMesgLength[p-1];
00303             recvDisp[p] = recvDisp[p-1] + recvMesgLength[p-1];
00304           }
00305       }
00306 
00307     comm.allToAllv(sendBuf, sendMesgLength,
00308                    sendDisp, MPITraits<T>::type(),
00309                    recvBuf, recvMesgLength,
00310                    recvDisp, MPITraits<T>::type());
00311 
00312     incoming.resize(numProcs);
00313     for (p=0; p<numProcs; p++)
00314       {
00315         incoming[p].resize(recvMesgLength[p]);
00316         for (int i=0; i<recvMesgLength[p]; i++)
00317           {
00318             incoming[p][i] = recvBuf[recvDisp[p] + i];
00319           }
00320       }
00321 
00322     delete [] sendBuf;
00323     delete [] sendMesgLength;
00324     delete [] sendDisp;
00325     delete [] recvBuf;
00326     delete [] recvMesgLength;
00327     delete [] recvDisp;
00328   }
00329 
00330   template <class T> inline
00331   void MPIContainerComm<T>::allGather(const T& outgoing, Array<T>& incoming,
00332                                    const MPIComm& comm)
00333   {
00334     int nProc = comm.getNProc();
00335     incoming.resize(nProc);
00336 
00337     if (nProc==1)
00338       {
00339         incoming[0] = outgoing;
00340       }
00341     else
00342       {
00343         comm.allGather((void*) &outgoing, 1, MPITraits<T>::type(),
00344                        (void*) &(incoming[0]), 1, MPITraits<T>::type());
00345       }
00346   }
00347 
00348   template <class T> inline
00349   void MPIContainerComm<T>::accumulate(const T& localValue, Array<T>& sums,
00350                                        T& total,
00351                                        const MPIComm& comm)
00352   {
00353     Array<T> contributions;
00354     allGather(localValue, contributions, comm);
00355     sums.resize(comm.getNProc());
00356     sums[0] = 0;
00357     total = contributions[0];
00358 
00359     for (int i=0; i<comm.getNProc()-1; i++)
00360       {
00361         total += contributions[i+1];
00362         sums[i+1] = sums[i] + contributions[i];
00363       }
00364   }
00365 
00366 
00367 
00368 
00369   template <class T> inline
00370   void MPIContainerComm<T>::getBigArray(const Array<Array<T> >& x, Array<T>& bigArray,
00371                                      Array<int>& offsets)
00372   {
00373     offsets.resize(x.length()+1);
00374     int totalLength = 0;
00375 
00376     for (int i=0; i<x.length(); i++)
00377       {
00378         offsets[i] = totalLength;
00379         totalLength += x[i].length();
00380       }
00381     offsets[x.length()] = totalLength;
00382 
00383     bigArray.resize(totalLength);
00384 
00385     for (int i=0; i<x.length(); i++)
00386       {
00387         for (int j=0; j<x[i].length(); j++)
00388           {
00389             bigArray[offsets[i]+j] = x[i][j];
00390           }
00391       }
00392   }
00393 
00394   template <class T> inline
00395   void MPIContainerComm<T>::getSmallArrays(const Array<T>& bigArray,
00396                                         const Array<int>& offsets,
00397                                         Array<Array<T> >& x)
00398   {
00399     x.resize(offsets.length()-1);
00400     for (int i=0; i<x.length(); i++)
00401       {
00402         x[i].resize(offsets[i+1]-offsets[i]);
00403         for (int j=0; j<x[i].length(); j++)
00404           {
00405             x[i][j] = bigArray[offsets[i] + j];
00406           }
00407       }
00408   }
00409 
00410 
00411 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00412 
00413   /* --------------- string specializations --------------------- */
00414 
00415   inline void MPIContainerComm<string>::bcast(string& x,
00416                                            int src, const MPIComm& comm)
00417   {
00418     int len = x.length();
00419     MPIContainerComm<int>::bcast(len, src, comm);
00420 
00421     x.resize(len);
00422     comm.bcast((void*)&(x[0]), len, MPITraits<char>::type(), src);
00423   }
00424 
00425 
00426   inline void MPIContainerComm<string>::bcast(Array<string>& x, int src,
00427                                            const MPIComm& comm)
00428   {
00429     /* begin by packing all the data into a big char array. This will
00430      * take a little time, but will be cheaper than multiple MPI calls */
00431     Array<char> bigArray;
00432     Array<int> offsets;
00433     if (comm.getRank()==src)
00434       {
00435         getBigArray(x, bigArray, offsets);
00436       }
00437 
00438     /* now broadcast the big array and the offsets */
00439     MPIContainerComm<char>::bcast(bigArray, src, comm);
00440     MPIContainerComm<int>::bcast(offsets, src, comm);
00441 
00442     /* finally, reassemble the array of strings */
00443     if (comm.getRank() != src)
00444       {
00445         getStrings(bigArray, offsets, x);
00446       }
00447   }
00448 
00449   inline void MPIContainerComm<string>::bcast(Array<Array<string> >& x,
00450                                            int src, const MPIComm& comm)
00451   {
00452     int len = x.length();
00453     MPIContainerComm<int>::bcast(len, src, comm);
00454 
00455     x.resize(len);
00456     for (int i=0; i<len; i++)
00457       {
00458         MPIContainerComm<string>::bcast(x[i], src, comm);
00459       }
00460   }
00461 
00462 
00463   inline void MPIContainerComm<string>::allGather(const string& outgoing,
00464                                                Array<string>& incoming,
00465                                                const MPIComm& comm)
00466   {
00467     int nProc = comm.getNProc();
00468 
00469     int sendCount = outgoing.length();
00470 
00471     incoming.resize(nProc);
00472 
00473     int* recvCounts = new int[nProc];
00474     int* recvDisplacements = new int[nProc];
00475 
00476     /* share lengths with all procs */
00477     comm.allGather((void*) &sendCount, 1, MPIComm::INT,
00478                    (void*) recvCounts, 1, MPIComm::INT);
00479 
00480 
00481     int recvSize = 0;
00482     recvDisplacements[0] = 0;
00483     for (int i=0; i<nProc; i++)
00484       {
00485         recvSize += recvCounts[i];
00486         if (i < nProc-1)
00487           {
00488             recvDisplacements[i+1] = recvDisplacements[i]+recvCounts[i];
00489           }
00490       }
00491 
00492     char* recvBuf = new char[recvSize];
00493 
00494     comm.allGatherv((void*) outgoing.c_str(), sendCount, MPIComm::CHAR,
00495                     recvBuf, recvCounts, recvDisplacements, MPIComm::CHAR);
00496 
00497     for (int j=0; j<nProc; j++)
00498       {
00499         char* start = recvBuf + recvDisplacements[j];
00500         char* tmp = new char[recvCounts[j]+1];
00501         memcpy(tmp, start, recvCounts[j]);
00502         tmp[recvCounts[j]] = '\0';
00503         incoming[j] = string(tmp);
00504         delete [] tmp;
00505       }
00506 
00507     delete [] recvCounts;
00508     delete [] recvDisplacements;
00509     delete [] recvBuf;
00510   }
00511 
00512 
00513   inline void MPIContainerComm<string>::getBigArray(const Array<string>& x,
00514                                                  Array<char>& bigArray,
00515                                                  Array<int>& offsets)
00516   {
00517     offsets.resize(x.length()+1);
00518     int totalLength = 0;
00519 
00520     for (int i=0; i<x.length(); i++)
00521       {
00522         offsets[i] = totalLength;
00523         totalLength += x[i].length();
00524       }
00525     offsets[x.length()] = totalLength;
00526 
00527     bigArray.resize(totalLength);
00528 
00529     for (int i=0; i<x.length(); i++)
00530       {
00531         for (unsigned int j=0; j<x[i].length(); j++)
00532           {
00533             bigArray[offsets[i]+j] = x[i][j];
00534           }
00535       }
00536   }
00537 
00538   inline void MPIContainerComm<string>::getStrings(const Array<char>& bigArray,
00539                                                 const Array<int>& offsets,
00540                                                 Array<string>& x)
00541   {
00542     x.resize(offsets.length()-1);
00543     for (int i=0; i<x.length(); i++)
00544       {
00545         x[i].resize(offsets[i+1]-offsets[i]);
00546         for (unsigned int j=0; j<x[i].length(); j++)
00547           {
00548             x[i][j] = bigArray[offsets[i] + j];
00549           }
00550       }
00551   }
00552 #endif // DOXYGEN_SHOULD_SKIP_THIS
00553 
00554 }
00555 
00556 
00557 #endif
00558 
00559 

Generated on Thu Sep 18 12:41:17 2008 for Teuchos - Trilinos Tools Package by doxygen 1.3.9.1