Teuchos_MPIContainerComm.hpp

Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //                    Teuchos: Common Tools Package
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef TEUCHOS_MPICONTAINERCOMM_H
00030 #define TEUCHOS_MPICONTAINERCOMM_H
00031 
00036 #include "Teuchos_ConfigDefs.hpp"
00037 #include "Teuchos_Array.hpp"
00038 #include "Teuchos_MPIComm.hpp"
00039 #include "Teuchos_MPITraits.hpp"
00040 
00041 namespace Teuchos
00042 {
00049   template <class T> class MPIContainerComm
00050   {
00051   public:
00052 
00054     static void bcast(T& x, int src, const MPIComm& comm);
00055 
00057     static void bcast(Array<T>& x, int src, const MPIComm& comm);
00058 
00060     static void bcast(Array<Array<T> >& x,
00061                       int src, const MPIComm& comm);
00062 
00064     static void allGather(const T& outgoing,
00065                           Array<T>& incoming,
00066                           const MPIComm& comm);
00067 
00069     static void allToAll(const Array<T>& outgoing,
00070                          Array<Array<T> >& incoming,
00071                          const MPIComm& comm);
00072 
00074     static void allToAll(const Array<Array<T> >& outgoing,
00075                          Array<Array<T> >& incoming,
00076                          const MPIComm& comm);
00077 
00079     static void accumulate(const T& localValue, Array<T>& sums,
00080                            const MPIComm& comm);
00081 
00082   private:
00084     static void getBigArray(const Array<Array<T> >& x,
00085                             Array<T>& bigArray,
00086                             Array<int>& offsets);
00087 
00089     static void getSmallArrays(const Array<T>& bigArray,
00090                                const Array<int>& offsets,
00091                                Array<Array<T> >& x);
00092 
00093 
00094   };
00095 
00096 
00097 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00098 
00101   template <> class MPIContainerComm<string>
00102   {
00103   public:
00104     static void bcast(string& x, int src, const MPIComm& comm);
00105 
00107     static void bcast(Array<string>& x, int src, const MPIComm& comm);
00108 
00110     static void bcast(Array<Array<string> >& x,
00111                       int src, const MPIComm& comm);
00112 
00114     static void allGather(const string& outgoing,
00115                           Array<string>& incoming,
00116                           const MPIComm& comm);
00117 
00118   private:
00120     static void getBigArray(const Array<string>& x,
00121                             Array<char>& bigArray,
00122                             Array<int>& offsets);
00123 
00126     static void getStrings(const Array<char>& bigArray,
00127                            const Array<int>& offsets,
00128                            Array<string>& x);
00129   };
00130 
00131 #endif // DOXYGEN_SHOULD_SKIP_THIS
00132 
00133   /* --------- generic functions for primitives ------------------- */
00134 
00135   template <class T> inline void MPIContainerComm<T>::bcast(T& x, int src,
00136                                                          const MPIComm& comm)
00137   {
00138     comm.bcast((void*)&x, 1, MPITraits<T>::type(), src);
00139   }
00140 
00141 
00142   /* ----------- generic functions for arrays of primitives ----------- */
00143 
00144   template <class T>
00145   inline void MPIContainerComm<T>::bcast(Array<T>& x, int src, const MPIComm& comm)
00146   {
00147      
00148     int len = x.length();
00149     MPIContainerComm<int>::bcast(len, src, comm);
00150 
00151     if (comm.getRank() != src)
00152       {
00153         x.resize(len);
00154       }
00155     if (len==0) return;
00156 
00157     /* then broadcast the contents */
00158     comm.bcast((void*) &(x[0]), (int) len,
00159                MPITraits<T>::type(), src);
00160   }
00161 
00162 
00163 
00164   /* ---------- generic function for arrays of arrays ----------- */
00165 
00166   template <class T>
00167   inline void MPIContainerComm<T>::bcast(Array<Array<T> >& x, int src, const MPIComm& comm)
00168   {
00169     Array<T> bigArray;
00170     Array<int> offsets;
00171 
00172     if (src==comm.getRank())
00173       {
00174         getBigArray(x, bigArray, offsets);
00175       }
00176 
00177     bcast(bigArray, src, comm);
00178     MPIContainerComm<int>::bcast(offsets, src, comm);
00179 
00180     if (src != comm.getRank())
00181       {
00182         getSmallArrays(bigArray, offsets, x);
00183       }
00184   }
00185 
00186   /* ---------- generic gather and scatter ------------------------ */
00187 
00188   template <class T> inline
00189   void MPIContainerComm<T>::allToAll(const Array<T>& outgoing,
00190                                   Array<Array<T> >& incoming,
00191                                   const MPIComm& comm)
00192   {
00193     int numProcs = comm.getNProc();
00194 
00195     // catch degenerate case
00196     if (numProcs==1)
00197       {
00198         incoming.resize(1);
00199         incoming[0] = outgoing;
00200         return;
00201       }
00202 
00203     T* sendBuf = new T[numProcs * outgoing.length()];
00204     TEST_FOR_EXCEPTION(sendBuf==0, 
00205       std::runtime_error, "Comm::allToAll failed to allocate sendBuf");
00206     T* recvBuf = new T[numProcs * outgoing.length()];
00207     TEST_FOR_EXCEPTION(recvBuf==0, 
00208       std::runtime_error, "Comm::allToAll failed to allocate recvBuf");
00209 
00210     int i;
00211     for (i=0; i<numProcs; i++)
00212       {
00213         for (int j=0; j<outgoing.length(); j++)
00214           {
00215             sendBuf[i*outgoing.length() + j] = outgoing[j];
00216           }
00217       }
00218 
00219     comm.allToAll(sendBuf, outgoing.length(), MPITraits<T>::type(),
00220                   recvBuf, outgoing.length(), MPITraits<T>::type());
00221 
00222     incoming.resize(numProcs);
00223 
00224     for (i=0; i<numProcs; i++)
00225       {
00226         incoming[i].resize(outgoing.length());
00227         for (int j=0; j<outgoing.length(); j++)
00228           {
00229             incoming[i][j] = recvBuf[i*outgoing.length() + j];
00230           }
00231       }
00232 
00233     delete [] sendBuf;
00234     delete [] recvBuf;
00235   }
00236 
00237   template <class T> inline
00238   void MPIContainerComm<T>::allToAll(const Array<Array<T> >& outgoing,
00239                                   Array<Array<T> >& incoming, const MPIComm& comm)
00240   {
00241     int numProcs = comm.getNProc();
00242 
00243     // catch degenerate case
00244     if (numProcs==1)
00245       {
00246         incoming = outgoing;
00247         return;
00248       }
00249 
00250     int* sendMesgLength = new int[numProcs];
00251     TEST_FOR_EXCEPTION(sendMesgLength==0, 
00252       std::runtime_error, "failed to allocate sendMesgLength");
00253     int* recvMesgLength = new int[numProcs];
00254     TEST_FOR_EXCEPTION(recvMesgLength==0, 
00255       std::runtime_error, "failed to allocate recvMesgLength");
00256 
00257     int p = 0;
00258     for (p=0; p<numProcs; p++)
00259       {
00260         sendMesgLength[p] = outgoing[p].length();
00261       }
00262 
00263     comm.allToAll(sendMesgLength, 1, MPIComm::INT,
00264                   recvMesgLength, 1, MPIComm::INT);
00265 
00266 
00267     int totalSendLength = 0;
00268     int totalRecvLength = 0;
00269     for (p=0; p<numProcs; p++)
00270       {
00271         totalSendLength += sendMesgLength[p];
00272         totalRecvLength += recvMesgLength[p];
00273       }
00274 
00275     T* sendBuf = new T[totalSendLength];
00276     TEST_FOR_EXCEPTION(sendBuf==0, 
00277       std::runtime_error, "failed to allocate sendBuf");
00278     T* recvBuf = new T[totalRecvLength];
00279     TEST_FOR_EXCEPTION(recvBuf==0, 
00280       std::runtime_error, "failed to allocate recvBuf");
00281 
00282     int* sendDisp = new int[numProcs];
00283     TEST_FOR_EXCEPTION(sendDisp==0, 
00284       std::runtime_error, "failed to allocate sendDisp");
00285     int* recvDisp = new int[numProcs];
00286     TEST_FOR_EXCEPTION(recvDisp==0, 
00287       std::runtime_error, "failed to allocate recvDisp");
00288 
00289     int count = 0;
00290     sendDisp[0] = 0;
00291     recvDisp[0] = 0;
00292 
00293     for (p=0; p<numProcs; p++)
00294       {
00295         for (int i=0; i<outgoing[p].length(); i++)
00296           {
00297             sendBuf[count] = outgoing[p][i];
00298             count++;
00299           }
00300         if (p>0)
00301           {
00302             sendDisp[p] = sendDisp[p-1] + sendMesgLength[p-1];
00303             recvDisp[p] = recvDisp[p-1] + recvMesgLength[p-1];
00304           }
00305       }
00306 
00307     comm.allToAllv(sendBuf, sendMesgLength,
00308                    sendDisp, MPITraits<T>::type(),
00309                    recvBuf, recvMesgLength,
00310                    recvDisp, MPITraits<T>::type());
00311 
00312     incoming.resize(numProcs);
00313     for (p=0; p<numProcs; p++)
00314       {
00315         incoming[p].resize(recvMesgLength[p]);
00316         for (int i=0; i<recvMesgLength[p]; i++)
00317           {
00318             incoming[p][i] = recvBuf[recvDisp[p] + i];
00319           }
00320       }
00321 
00322     delete [] sendBuf;
00323     delete [] sendMesgLength;
00324     delete [] sendDisp;
00325     delete [] recvBuf;
00326     delete [] recvMesgLength;
00327     delete [] recvDisp;
00328   }
00329 
00330   template <class T> inline
00331   void MPIContainerComm<T>::allGather(const T& outgoing, Array<T>& incoming,
00332                                    const MPIComm& comm)
00333   {
00334     int nProc = comm.getNProc();
00335     incoming.resize(nProc);
00336 
00337     if (nProc==1)
00338       {
00339         incoming[0] = outgoing;
00340       }
00341     else
00342       {
00343         comm.allGather((void*) &outgoing, 1, MPITraits<T>::type(),
00344                        (void*) &(incoming[0]), 1, MPITraits<T>::type());
00345       }
00346   }
00347 
00348   template <class T> inline
00349   void MPIContainerComm<T>::accumulate(const T& localValue, Array<T>& sums,
00350                                     const MPIComm& comm)
00351   {
00352     Array<T> contributions;
00353     allGather(localValue, contributions, comm);
00354     sums.resize(comm.getNProc());
00355     sums[0] = 0;
00356 
00357     for (int i=0; i<comm.getNProc()-1; i++)
00358       {
00359         sums[i+1] = sums[i] + contributions[i];
00360       }
00361   }
00362 
00363 
00364 
00365 
00366   template <class T> inline
00367   void MPIContainerComm<T>::getBigArray(const Array<Array<T> >& x, Array<T>& bigArray,
00368                                      Array<int>& offsets)
00369   {
00370     offsets.resize(x.length()+1);
00371     int totalLength = 0;
00372 
00373     for (int i=0; i<x.length(); i++)
00374       {
00375         offsets[i] = totalLength;
00376         totalLength += x[i].length();
00377       }
00378     offsets[x.length()] = totalLength;
00379 
00380     bigArray.resize(totalLength);
00381 
00382     for (int i=0; i<x.length(); i++)
00383       {
00384         for (int j=0; j<x[i].length(); j++)
00385           {
00386             bigArray[offsets[i]+j] = x[i][j];
00387           }
00388       }
00389   }
00390 
00391   template <class T> inline
00392   void MPIContainerComm<T>::getSmallArrays(const Array<T>& bigArray,
00393                                         const Array<int>& offsets,
00394                                         Array<Array<T> >& x)
00395   {
00396     x.resize(offsets.length()-1);
00397     for (int i=0; i<x.length(); i++)
00398       {
00399         x[i].resize(offsets[i+1]-offsets[i]);
00400         for (int j=0; j<x[i].length(); j++)
00401           {
00402             x[i][j] = bigArray[offsets[i] + j];
00403           }
00404       }
00405   }
00406 
00407 
00408 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00409 
00410   /* --------------- string specializations --------------------- */
00411 
00412   inline void MPIContainerComm<string>::bcast(string& x,
00413                                            int src, const MPIComm& comm)
00414   {
00415     int len = x.length();
00416     MPIContainerComm<int>::bcast(len, src, comm);
00417 
00418     x.resize(len);
00419     comm.bcast((void*)&(x[0]), len, MPITraits<char>::type(), src);
00420   }
00421 
00422 
00423   inline void MPIContainerComm<string>::bcast(Array<string>& x, int src,
00424                                            const MPIComm& comm)
00425   {
00426     /* begin by packing all the data into a big char array. This will
00427      * take a little time, but will be cheaper than multiple MPI calls */
00428     Array<char> bigArray;
00429     Array<int> offsets;
00430     if (comm.getRank()==src)
00431       {
00432         getBigArray(x, bigArray, offsets);
00433       }
00434 
00435     /* now broadcast the big array and the offsets */
00436     MPIContainerComm<char>::bcast(bigArray, src, comm);
00437     MPIContainerComm<int>::bcast(offsets, src, comm);
00438 
00439     /* finally, reassemble the array of strings */
00440     if (comm.getRank() != src)
00441       {
00442         getStrings(bigArray, offsets, x);
00443       }
00444   }
00445 
00446   inline void MPIContainerComm<string>::bcast(Array<Array<string> >& x,
00447                                            int src, const MPIComm& comm)
00448   {
00449     int len = x.length();
00450     MPIContainerComm<int>::bcast(len, src, comm);
00451 
00452     x.resize(len);
00453     for (int i=0; i<len; i++)
00454       {
00455         MPIContainerComm<string>::bcast(x[i], src, comm);
00456       }
00457   }
00458 
00459 
00460   inline void MPIContainerComm<string>::allGather(const string& outgoing,
00461                                                Array<string>& incoming,
00462                                                const MPIComm& comm)
00463   {
00464     int nProc = comm.getNProc();
00465 
00466     int sendCount = outgoing.length();
00467 
00468     incoming.resize(nProc);
00469 
00470     int* recvCounts = new int[nProc];
00471     int* recvDisplacements = new int[nProc];
00472 
00473     /* share lengths with all procs */
00474     comm.allGather((void*) &sendCount, 1, MPIComm::INT,
00475                    (void*) recvCounts, 1, MPIComm::INT);
00476 
00477 
00478     int recvSize = 0;
00479     recvDisplacements[0] = 0;
00480     for (int i=0; i<nProc; i++)
00481       {
00482         recvSize += recvCounts[i];
00483         if (i < nProc-1)
00484           {
00485             recvDisplacements[i+1] = recvDisplacements[i]+recvCounts[i];
00486           }
00487       }
00488 
00489     char* recvBuf = new char[recvSize];
00490 
00491     comm.allGatherv((void*) outgoing.c_str(), sendCount, MPIComm::CHAR,
00492                     recvBuf, recvCounts, recvDisplacements, MPIComm::CHAR);
00493 
00494     for (int j=0; j<nProc; j++)
00495       {
00496         char* start = recvBuf + recvDisplacements[j];
00497         char* tmp = new char[recvCounts[j]+1];
00498         memcpy(tmp, start, recvCounts[j]);
00499         tmp[recvCounts[j]] = '\0';
00500         incoming[j] = string(tmp);
00501         delete [] tmp;
00502       }
00503 
00504     delete [] recvCounts;
00505     delete [] recvDisplacements;
00506     delete [] recvBuf;
00507   }
00508 
00509 
00510   inline void MPIContainerComm<string>::getBigArray(const Array<string>& x,
00511                                                  Array<char>& bigArray,
00512                                                  Array<int>& offsets)
00513   {
00514     offsets.resize(x.length()+1);
00515     int totalLength = 0;
00516 
00517     for (int i=0; i<x.length(); i++)
00518       {
00519         offsets[i] = totalLength;
00520         totalLength += x[i].length();
00521       }
00522     offsets[x.length()] = totalLength;
00523 
00524     bigArray.resize(totalLength);
00525 
00526     for (int i=0; i<x.length(); i++)
00527       {
00528         for (unsigned int j=0; j<x[i].length(); j++)
00529           {
00530             bigArray[offsets[i]+j] = x[i][j];
00531           }
00532       }
00533   }
00534 
00535   inline void MPIContainerComm<string>::getStrings(const Array<char>& bigArray,
00536                                                 const Array<int>& offsets,
00537                                                 Array<string>& x)
00538   {
00539     x.resize(offsets.length()-1);
00540     for (int i=0; i<x.length(); i++)
00541       {
00542         x[i].resize(offsets[i+1]-offsets[i]);
00543         for (unsigned int j=0; j<x[i].length(); j++)
00544           {
00545             x[i][j] = bigArray[offsets[i] + j];
00546           }
00547       }
00548   }
00549 #endif // DOXYGEN_SHOULD_SKIP_THIS
00550 
00551 }
00552 
00553 
00554 #endif
00555 
00556 

Generated on Thu Sep 18 12:42:50 2008 for Teuchos - Trilinos Tools Package by doxygen 1.3.9.1