Zoltan 2 Version 0.5
Zoltan2_AlltoAll.hpp
Go to the documentation of this file.
00001 // @HEADER
00002 //
00003 // ***********************************************************************
00004 //
00005 //   Zoltan2: A package of combinatorial algorithms for scientific computing
00006 //                  Copyright 2012 Sandia Corporation
00007 //
00008 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00009 // the U.S. Government retains certain rights in this software.
00010 //
00011 // Redistribution and use in source and binary forms, with or without
00012 // modification, are permitted provided that the following conditions are
00013 // met:
00014 //
00015 // 1. Redistributions of source code must retain the above copyright
00016 // notice, this list of conditions and the following disclaimer.
00017 //
00018 // 2. Redistributions in binary form must reproduce the above copyright
00019 // notice, this list of conditions and the following disclaimer in the
00020 // documentation and/or other materials provided with the distribution.
00021 //
00022 // 3. Neither the name of the Corporation nor the names of the
00023 // contributors may be used to endorse or promote products derived from
00024 // this software without specific prior written permission.
00025 //
00026 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00027 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00028 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00029 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00030 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00031 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00032 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00033 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00034 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00035 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00036 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00037 //
00038 // Questions? Contact Karen Devine      (kddevin@sandia.gov)
00039 //                    Erik Boman        (egboman@sandia.gov)
00040 //                    Siva Rajamanickam (srajama@sandia.gov)
00041 //
00042 // ***********************************************************************
00043 //
00044 // @HEADER
00045 
00050 #ifndef _ZOLTAN2_ALLTOALL_HPP_
00051 #define _ZOLTAN2_ALLTOALL_HPP_
00052 
#include <Zoltan2_Standards.hpp>
#include <Zoltan2_Environment.hpp>

#include <climits>
#include <cstring>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>
00058 
00059 namespace Zoltan2
00060 {
00061 
00070 void AlltoAllCount(const Comm<int> &comm, const Environment &env,
00071  const ArrayView<const int> &sendCount, ArrayRCP<int> &recvCount)
00072 {
00073   int nprocs = comm.getSize();
00074   int rank = comm.getRank();
00075 
00076   RCP<const int> *messages = new RCP<const int> [nprocs];
00077   for (int p=0; p < nprocs; p++)
00078     messages[p] = rcp(sendCount.getRawPtr()+p, false);
00079 
00080   ArrayRCP<RCP<const int> > messageArray(messages, 0, nprocs, true);
00081 
00082   int *counts = new int [nprocs];
00083   recvCount = arcp(counts, 0, nprocs, true);
00084 
00085   counts[rank] = sendCount[rank];
00086 
00087 #ifdef HAVE_ZOLTAN2_MPI
00088 
00089   // I was getting hangs in Teuchos::waitAll, so I do
00090   // blocking receives below.
00091 
00092   for (int p=1; p < nprocs; p++){
00093     int recvFrom = (rank + nprocs - p) % nprocs;
00094     int sendTo = (rank + p) % nprocs;
00095 
00096     try{  // non blocking send
00097       Teuchos::isend<int, int>(comm, messageArray[sendTo], sendTo);
00098     }
00099     Z2_THROW_OUTSIDE_ERROR(env);
00100 
00101     try{  // blocking receive for message just sent to me
00102       Teuchos::receive<int, int>(comm, recvFrom, counts + recvFrom);
00103     }
00104     Z2_THROW_OUTSIDE_ERROR(env);
00105   }
00106 
00107   comm.barrier();
00108 #endif
00109 }
00110 
00139 template <typename T>
00140 void AlltoAllv(const Comm<int> &comm,
00141               const Environment &env,  
00142               const ArrayView<const T> &sendBuf,
00143               const ArrayView<const int> &sendCount,
00144               ArrayRCP<T> &recvBuf,      // output, allocated here
00145               ArrayRCP<int> &recvCount,   // output, allocated here
00146               bool countsAreUniform=false)
00147 {
00148   int nprocs = comm.getSize();
00149   int rank = comm.getRank();
00150 
00151   if (countsAreUniform){
00152     int *counts = new int [nprocs];
00153     for (int i=0; i < nprocs; i++)
00154       counts[i] = sendCount[0];
00155     recvCount = arcp(counts, 0, nprocs, true);
00156   }
00157   else{
00158     try{
00159       AlltoAllCount(comm, env, sendCount, recvCount);
00160     }
00161     Z2_FORWARD_EXCEPTIONS;
00162   }
00163 
00164   size_t *offsetIn = new size_t [nprocs+1];
00165   size_t *offsetOut = new size_t [nprocs+1];
00166 
00167   ArrayRCP<size_t> offArray1(offsetIn, 0, nprocs+1, true);
00168   ArrayRCP<size_t> offArray2(offsetOut, 0, nprocs+1, true);
00169   
00170   offsetIn[0] = offsetOut[0] = 0;
00171 
00172   int maxMsg=0;
00173   bool offProc = false;
00174 
00175   for (int i=0; i < nprocs; i++){
00176     offsetIn[i+1] = offsetIn[i] + recvCount[i];
00177     offsetOut[i+1] = offsetOut[i] + sendCount[i];
00178     if (recvCount[i] > maxMsg)
00179       maxMsg = recvCount[i];
00180     if (sendCount[i] > maxMsg)
00181       maxMsg = sendCount[i];
00182 
00183     if (!offProc && (i != rank) && (recvCount[i] > 0 || sendCount[i] > 0))
00184       offProc = true;
00185   }
00186 
00187   env.globalInputAssertion(__FILE__, __LINE__,
00188     "message size exceeds MPI limit (sizes, offsets, counts are ints) ",
00189     maxMsg*sizeof(T) <= INT_MAX, BASIC_ASSERTION, rcp(&comm, false));
00190 
00191   size_t totalIn = offsetIn[nprocs];
00192 
00193   T *rptr = NULL;
00194 
00195   if (totalIn)
00196     rptr = new T [totalIn]; 
00197   
00198   env.globalMemoryAssertion(__FILE__, __LINE__, totalIn, !totalIn||rptr, 
00199     rcp(&comm, false));
00200 
00201   recvBuf = Teuchos::arcp<T>(rptr, 0, totalIn, true);
00202 
00203   const T *sptr = sendBuf.getRawPtr();
00204 
00205   // Copy self messages
00206 
00207   if (recvCount[rank] > 0)
00208     memcpy(rptr + offsetIn[rank], sptr + offsetOut[rank], 
00209       recvCount[rank]*sizeof(T));
00210 
00211   if (nprocs < 2)
00212     return;
00213 
00214 #ifdef HAVE_ZOLTAN2_MPI
00215 
00216   // I was getting hangs in Teuchos::waitAll, so I do
00217   // blocking receives below.
00218 
00219   if (offProc){
00220     Array<ArrayRCP<const T> > sendArray(nprocs);
00221     for (int p=0; p < nprocs; p++){
00222       if (p != rank && sendCount[p] > 0)
00223         sendArray[p] = arcp(sptr + offsetOut[p], 0, sendCount[p], false);
00224     }
00225   
00226     for (int p=1; p < nprocs; p++){
00227       int recvFrom = (rank + nprocs - p) % nprocs;
00228       int sendTo = (rank + p) % nprocs;
00229   
00230       if (sendCount[sendTo] > 0){
00231         try{  // non blocking send
00232           Teuchos::isend<int, T>(comm, sendArray[sendTo], sendTo);
00233         }
00234         Z2_THROW_OUTSIDE_ERROR(env);
00235       }
00236   
00237       if (recvCount[recvFrom] > 0){
00238         try{  // blocking receive for message just sent to me
00239           Teuchos::receive<int, T>(comm, recvFrom, recvCount[recvFrom],
00240              rptr + offsetIn[recvFrom]);
00241         }
00242         Z2_THROW_OUTSIDE_ERROR(env);
00243       }
00244     }
00245   }
00246 
00247   comm.barrier();
00248 
00249 #endif
00250 }
00251 
00252 /* \brief Specialization for std::string.
00253   
00254     For string of char. Number of chars in a string limited to SCHAR_MAX.
00255     Send as chars: 1 char for length of string, then chars in string,
00256      1 char for length of next string, and so on.
00257     \todo error checking
00258  */
00259 template <>
00260 void AlltoAllv(const Comm<int> &comm,
00261               const Environment &env,  
00262               const ArrayView<const string> &sendBuf,
00263               const ArrayView<const int> &sendCount,
00264               ArrayRCP<string> &recvBuf,  // output, allocated here
00265               ArrayRCP<int> &recvCount,  // output, allocated here
00266               bool countsAreUniform)
00267 {
00268   int nprocs = comm.getSize();
00269   int *newCount = new int [nprocs];
00270   memset(newCount, 0, sizeof(int) * nprocs);
00271   ArrayView<const int> newSendCount(newCount, nprocs);
00272 
00273 
00274   size_t numStrings = sendBuf.size();
00275   size_t numChars = 0;
00276   bool fail=false;
00277 
00278   for (int p=0, i=0; !fail && p < nprocs; p++){
00279     for (int c=0; !fail && c < sendCount[p]; c++, i++){
00280       size_t nchars = sendBuf[i].size();
00281       if (nchars > SCHAR_MAX)
00282         fail = true;
00283       else
00284         newCount[p] += nchars;
00285     }
00286     newCount[p] += sendCount[p];
00287     numChars += newCount[p];
00288   }
00289 
00290   if (fail)
00291     throw std::runtime_error("id string length exceeds SCHAR_MAX");
00292 
00293   char *sbuf = NULL;
00294   if (numChars > 0)
00295     sbuf = new char [numChars];
00296   char *sbufptr = sbuf;
00297 
00298   ArrayView<const char> newSendBuf(sbuf, numChars);
00299 
00300   for (size_t i=0; i < numStrings; i++){
00301     size_t nchars = sendBuf[i].size();
00302     *sbufptr++ = static_cast<char>(nchars);
00303     for (size_t j=0; j < nchars; j++)
00304       *sbufptr++ = sendBuf[i][j];
00305   }
00306 
00307   ArrayRCP<char> newRecvBuf;
00308   ArrayRCP<int> newRecvCount;
00309 
00310   AlltoAllv<char>(comm, env, newSendBuf, newSendCount, 
00311     newRecvBuf, newRecvCount, countsAreUniform);
00312 
00313   delete [] sbuf;
00314   delete [] newCount;
00315 
00316   char *inBuf = newRecvBuf.getRawPtr();
00317 
00318   int numNewStrings = 0;
00319   char *buf = inBuf;
00320   char *endChar = inBuf + newRecvBuf.size();
00321   while (buf < endChar){
00322     int slen = static_cast<int>(*buf++);
00323     buf += slen;
00324     numNewStrings++;
00325   }
00326 
00327   // Counts to return
00328   int *numStringsRecv = new int [nprocs];
00329   memset(numStringsRecv, 0, sizeof(int) * nprocs);
00330 
00331   // Data to return
00332   string *newStrings = new string [numNewStrings];
00333 
00334   buf = inBuf;
00335   int next = 0;
00336 
00337   for (int p=0; p < nprocs; p++){
00338     int nchars = newRecvCount[p];
00339     endChar = buf + nchars;
00340     while (buf < endChar){
00341       int slen = *buf++;
00342       string nextString;
00343       for (int i=0; i < slen; i++)
00344         nextString.push_back(*buf++);
00345       newStrings[next++] = nextString;
00346       numStringsRecv[p]++;
00347     }
00348   }
00349 
00350   recvBuf = arcp<string>(newStrings, 0, numNewStrings, true);
00351   recvCount = arcp<int>(numStringsRecv, 0, nprocs, true);
00352 }
00353 
00354 #ifdef HAVE_ZOLTAN2_LONG_LONG
00355 
00356 /* \brief Specialization for unsigned long long 
00357  */
00358 template <>
00359 void AlltoAllv(const Comm<int> &comm,
00360               const Environment &env,  
00361               const ArrayView<const unsigned long long> &sendBuf,
00362               const ArrayView<const int> &sendCount,
00363               ArrayRCP<unsigned long long> &recvBuf,  // output, allocated here
00364               ArrayRCP<int> &recvCount,  // output, allocated here
00365               bool countsAreUniform)
00366 {
00367   const long long *sbuf = 
00368     reinterpret_cast<const long long *>(sendBuf.getRawPtr());
00369   ArrayView<const long long> newSendBuf(sbuf, sendBuf.size());
00370   ArrayRCP<long long> newRecvBuf;
00371 
00372   AlltoAllv<long long>(comm, env, newSendBuf, sendCount, 
00373     newRecvBuf, recvCount, countsAreUniform);
00374 
00375   recvBuf = arcp_reinterpret_cast<unsigned long long>(newRecvBuf);
00376 }
00377 #endif
00378 
00379 /* \brief Specialization for unsigned short 
00380  */
00381 template <>
00382 void AlltoAllv(const Comm<int> &comm,
00383               const Environment &env,  
00384               const ArrayView<const unsigned short> &sendBuf,
00385               const ArrayView<const int> &sendCount,
00386               ArrayRCP<unsigned short> &recvBuf,  // output, allocated here
00387               ArrayRCP<int> &recvCount,  // output, allocated here
00388               bool countsAreUniform)
00389 {
00390   const short *sbuf = reinterpret_cast<const short *>(sendBuf.getRawPtr());
00391   ArrayView<const short> newSendBuf(sbuf, sendBuf.size());
00392   ArrayRCP<short> newRecvBuf;
00393 
00394   AlltoAllv<short>(comm, env, newSendBuf, sendCount, 
00395     newRecvBuf, recvCount, countsAreUniform);
00396 
00397   recvBuf = arcp_reinterpret_cast<unsigned short>(newRecvBuf);
00398 }
00399 
00400 /* \brief For data type unsigned char (no Teuchos::DirectSerializationTraits)
00401  */
00402 template <>
00403 void AlltoAllv(const Comm<int> &comm,
00404               const Environment &env,  
00405               const ArrayView<const unsigned char> &sendBuf,
00406               const ArrayView<const int> &sendCount,
00407               ArrayRCP<unsigned char> &recvBuf,      // output, allocated here
00408               ArrayRCP<int> &recvCount,  // output, allocated here
00409               bool countsAreUniform)
00410 {
00411   const char *sbuf = reinterpret_cast<const char *>(sendBuf.getRawPtr());
00412   ArrayView<const char> newSendBuf(sbuf, sendBuf.size());
00413   ArrayRCP<char> newRecvBuf;
00414 
00415   AlltoAllv<char>(comm, env, newSendBuf, sendCount, 
00416     newRecvBuf, recvCount, countsAreUniform);
00417 
00418   recvBuf = arcp_reinterpret_cast<unsigned char>(newRecvBuf);
00419 }
00420 
00432 template <typename T>
00433 void AlltoAll(const Comm<int> &comm,
00434               const Environment &env,
00435               const ArrayView<const T> &sendBuf,
00436               int count,
00437               ArrayRCP<T> &recvBuf)         // output - allocated here
00438 {
00439   int nprocs = comm.getSize();
00440 
00441   if (count == 0) return;   // count is the same on all procs
00442 
00443   int *counts = new int [nprocs];
00444   for (int i=0; i < nprocs; i++)
00445     counts[i] = count;
00446 
00447   ArrayView<const int> sendCounts(counts, nprocs);
00448   ArrayRCP<int> recvCounts;
00449 
00450   AlltoAllv<T>(comm, env, sendBuf, sendCounts, recvBuf, recvCounts, true);
00451 
00452   delete [] counts;
00453 }
00454 
00455 }                   // namespace Z2
00456 #endif