|
Teuchos - Trilinos Tools Package Version of the Day
|
00001 // @HEADER 00002 // *********************************************************************** 00003 // 00004 // Teuchos: Common Tools Package 00005 // Copyright (2004) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // Redistribution and use in source and binary forms, with or without 00011 // modification, are permitted provided that the following conditions are 00012 // met: 00013 // 00014 // 1. Redistributions of source code must retain the above copyright 00015 // notice, this list of conditions and the following disclaimer. 00016 // 00017 // 2. Redistributions in binary form must reproduce the above copyright 00018 // notice, this list of conditions and the following disclaimer in the 00019 // documentation and/or other materials provided with the distribution. 00020 // 00021 // 3. Neither the name of the Corporation nor the names of the 00022 // contributors may be used to endorse or promote products derived from 00023 // this software without specific prior written permission. 00024 // 00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY 00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE 00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00036 // 00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00038 // 00039 // *********************************************************************** 00040 // @HEADER 00041 00042 #ifndef TEUCHOS_MPICONTAINERCOMM_H 00043 #define TEUCHOS_MPICONTAINERCOMM_H 00044 00049 #include "Teuchos_ConfigDefs.hpp" 00050 #include "Teuchos_Array.hpp" 00051 #include "Teuchos_MPIComm.hpp" 00052 #include "Teuchos_MPITraits.hpp" 00053 00054 namespace Teuchos 00055 { 00062 template <class T> class MPIContainerComm 00063 { 00064 public: 00065 00067 static void bcast(T& x, int src, const MPIComm& comm); 00068 00070 static void bcast(Array<T>& x, int src, const MPIComm& comm); 00071 00073 static void bcast(Array<Array<T> >& x, 00074 int src, const MPIComm& comm); 00075 00077 static void allGather(const T& outgoing, 00078 Array<T>& incoming, 00079 const MPIComm& comm); 00080 00082 static void allToAll(const Array<T>& outgoing, 00083 Array<Array<T> >& incoming, 00084 const MPIComm& comm); 00085 00087 static void allToAll(const Array<Array<T> >& outgoing, 00088 Array<Array<T> >& incoming, 00089 const MPIComm& comm); 00090 00092 static void gatherv(const Array<T>& outgoing, 00093 Array<Array<T> >& incoming, 00094 int rootRank, 00095 const MPIComm& comm); 00096 00098 static void accumulate(const T& localValue, Array<T>& sums, T& total, 00099 const MPIComm& comm); 00100 00101 private: 00103 static void getBigArray(const Array<Array<T> >& x, 00104 Array<T>& bigArray, 00105 Array<int>& offsets); 00106 00108 static void getSmallArrays(const Array<T>& bigArray, 00109 const Array<int>& offsets, 00110 Array<Array<T> >& x); 00111 00112 00113 }; 00114 00115 00116 #ifndef DOXYGEN_SHOULD_SKIP_THIS 00117 00120 template <> class MPIContainerComm<std::string> 00121 { 00122 public: 00123 static void bcast(std::string& x, int src, const MPIComm& comm); 00124 00126 static void bcast(Array<std::string>& x, int src, const MPIComm& comm); 00127 00129 static void bcast(Array<Array<std::string> >& x, 00130 int src, const MPIComm& comm); 00131 00133 static void allGather(const std::string& outgoing, 00134 Array<std::string>& incoming, 00135 const MPIComm& comm); 00136 00138 static void gatherv(const Array<std::string>& outgoing, 00139 Array<Array<std::string> >& incoming, 00140 int rootRank, 00141 const MPIComm& comm); 00142 00150 static void pack(const Array<std::string>& x, 00151 Array<char>& packed); 00152 00155 static void unpack(const Array<char>& packed, 00156 Array<std::string>& x); 00157 private: 00159 static void getBigArray(const Array<std::string>& x, 00160 Array<char>& bigArray, 00161 Array<int>& offsets); 00162 00165 static void getStrings(const Array<char>& bigArray, 00166 const Array<int>& offsets, 00167 Array<std::string>& x); 00168 }; 00169 00170 #endif // DOXYGEN_SHOULD_SKIP_THIS 00171 00172 /* --------- generic functions for primitives ------------------- */ 00173 00174 template <class T> inline void MPIContainerComm<T>::bcast(T& x, int src, 00175 const MPIComm& comm) 00176 { 00177 comm.bcast((void*)&x, 1, MPITraits<T>::type(), src); 00178 } 00179 00180 00181 /* ----------- generic functions for arrays of primitives ----------- */ 00182 00183 template <class T> 00184 inline void MPIContainerComm<T>::bcast(Array<T>& x, int src, const MPIComm& comm) 00185 { 00186 int len = x.length(); 00187 MPIContainerComm<int>::bcast(len, src, comm); 00188 00189 if (comm.getRank() != src) 00190 { 00191 x.resize(len); 00192 } 00193 if (len==0) return; 00194 00195 /* then broadcast the contents */ 00196 comm.bcast((void*) &(x[0]), (int) len, 00197 MPITraits<T>::type(), src); 00198 } 00199 00200 00201 00202 /* ---------- generic function for arrays of arrays ----------- */ 00203 00204 template <class T> 00205 inline void MPIContainerComm<T>::bcast(Array<Array<T> >& x, int src, const MPIComm& comm) 00206 { 00207 Array<T> bigArray; 00208 Array<int> offsets; 00209 00210 if (src==comm.getRank()) 00211 { 00212 getBigArray(x, bigArray, offsets); 00213 } 00214 00215 bcast(bigArray, src, comm); 00216 MPIContainerComm<int>::bcast(offsets, src, comm); 00217 00218 if (src != comm.getRank()) 00219 { 00220 getSmallArrays(bigArray, offsets, x); 00221 } 00222 } 00223 00224 /* ---------- generic gather and scatter ------------------------ */ 00225 00226 template <class T> inline 00227 void MPIContainerComm<T>::allToAll(const Array<T>& outgoing, 00228 Array<Array<T> >& incoming, 00229 const MPIComm& comm) 00230 { 00231 int numProcs = comm.getNProc(); 00232 00233 // catch degenerate case 00234 if (numProcs==1) 00235 { 00236 incoming.resize(1); 00237 incoming[0] = outgoing; 00238 return; 00239 } 00240 00241 Array<T> sb(numProcs * outgoing.length()); 00242 Array<T> rb(numProcs * outgoing.length()); 00243 00244 T* sendBuf = new T[numProcs * outgoing.length()]; 00245 TEST_FOR_EXCEPTION(sendBuf==0, 00246 std::runtime_error, "Comm::allToAll failed to allocate sendBuf"); 00247 00248 T* recvBuf = new T[numProcs * outgoing.length()]; 00249 TEST_FOR_EXCEPTION(recvBuf==0, 00250 std::runtime_error, "Comm::allToAll failed to allocate recvBuf"); 00251 00252 int i; 00253 for (i=0; i<numProcs; i++) 00254 { 00255 for (int j=0; j<outgoing.length(); j++) 00256 { 00257 sendBuf[i*outgoing.length() + j] = outgoing[j]; 00258 } 00259 } 00260 00261 00262 00263 comm.allToAll(sendBuf, outgoing.length(), MPITraits<T>::type(), 00264 recvBuf, outgoing.length(), MPITraits<T>::type()); 00265 00266 incoming.resize(numProcs); 00267 00268 for (i=0; i<numProcs; i++) 00269 { 00270 incoming[i].resize(outgoing.length()); 00271 for (int j=0; j<outgoing.length(); j++) 00272 { 00273 incoming[i][j] = recvBuf[i*outgoing.length() + j]; 00274 } 00275 } 00276 00277 delete [] sendBuf; 00278 delete [] recvBuf; 00279 } 00280 00281 template <class T> inline 00282 void MPIContainerComm<T>::allToAll(const Array<Array<T> >& outgoing, 00283 Array<Array<T> >& incoming, const MPIComm& comm) 00284 { 00285 int numProcs = comm.getNProc(); 00286 00287 // catch degenerate case 00288 if (numProcs==1) 00289 { 00290 incoming = outgoing; 00291 return; 00292 } 00293 00294 int* sendMesgLength = new int[numProcs]; 00295 TEST_FOR_EXCEPTION(sendMesgLength==0, 00296 std::runtime_error, "failed to allocate sendMesgLength"); 00297 int* recvMesgLength = new int[numProcs]; 00298 TEST_FOR_EXCEPTION(recvMesgLength==0, 00299 std::runtime_error, "failed to allocate recvMesgLength"); 00300 00301 int p = 0; 00302 for (p=0; p<numProcs; p++) 00303 { 00304 sendMesgLength[p] = outgoing[p].length(); 00305 } 00306 00307 comm.allToAll(sendMesgLength, 1, MPIComm::INT, 00308 recvMesgLength, 1, MPIComm::INT); 00309 00310 00311 int totalSendLength = 0; 00312 int totalRecvLength = 0; 00313 for (p=0; p<numProcs; p++) 00314 { 00315 totalSendLength += sendMesgLength[p]; 00316 totalRecvLength += recvMesgLength[p]; 00317 } 00318 00319 T* sendBuf = new T[totalSendLength]; 00320 TEST_FOR_EXCEPTION(sendBuf==0, 00321 std::runtime_error, "failed to allocate sendBuf"); 00322 T* recvBuf = new T[totalRecvLength]; 00323 TEST_FOR_EXCEPTION(recvBuf==0, 00324 std::runtime_error, "failed to allocate recvBuf"); 00325 00326 int* sendDisp = new int[numProcs]; 00327 TEST_FOR_EXCEPTION(sendDisp==0, 00328 std::runtime_error, "failed to allocate sendDisp"); 00329 int* recvDisp = new int[numProcs]; 00330 TEST_FOR_EXCEPTION(recvDisp==0, 00331 std::runtime_error, "failed to allocate recvDisp"); 00332 00333 int count = 0; 00334 sendDisp[0] = 0; 00335 recvDisp[0] = 0; 00336 00337 for (p=0; p<numProcs; p++) 00338 { 00339 for (int i=0; i<outgoing[p].length(); i++) 00340 { 00341 sendBuf[count] = outgoing[p][i]; 00342 count++; 00343 } 00344 if (p>0) 00345 { 00346 sendDisp[p] = sendDisp[p-1] + sendMesgLength[p-1]; 00347 recvDisp[p] = recvDisp[p-1] + recvMesgLength[p-1]; 00348 } 00349 } 00350 00351 comm.allToAllv(sendBuf, sendMesgLength, 00352 sendDisp, MPITraits<T>::type(), 00353 recvBuf, recvMesgLength, 00354 recvDisp, MPITraits<T>::type()); 00355 00356 incoming.resize(numProcs); 00357 for (p=0; p<numProcs; p++) 00358 { 00359 incoming[p].resize(recvMesgLength[p]); 00360 for (int i=0; i<recvMesgLength[p]; i++) 00361 { 00362 incoming[p][i] = recvBuf[recvDisp[p] + i]; 00363 } 00364 } 00365 00366 delete [] sendBuf; 00367 delete [] sendMesgLength; 00368 delete [] sendDisp; 00369 delete [] recvBuf; 00370 delete [] recvMesgLength; 00371 delete [] recvDisp; 00372 } 00373 00374 template <class T> inline 00375 void MPIContainerComm<T>::allGather(const T& outgoing, Array<T>& incoming, 00376 const MPIComm& comm) 00377 { 00378 int nProc = comm.getNProc(); 00379 incoming.resize(nProc); 00380 00381 if (nProc==1) 00382 { 00383 incoming[0] = outgoing; 00384 } 00385 else 00386 { 00387 comm.allGather((void*) &outgoing, 1, MPITraits<T>::type(), 00388 (void*) &(incoming[0]), 1, MPITraits<T>::type()); 00389 } 00390 } 00391 00392 template <class T> inline 00393 void MPIContainerComm<T>::accumulate(const T& localValue, Array<T>& sums, 00394 T& total, 00395 const MPIComm& comm) 00396 { 00397 Array<T> contributions; 00398 allGather(localValue, contributions, comm); 00399 sums.resize(comm.getNProc()); 00400 sums[0] = 0; 00401 total = contributions[0]; 00402 00403 for (int i=0; i<comm.getNProc()-1; i++) 00404 { 00405 total += contributions[i+1]; 00406 sums[i+1] = sums[i] + contributions[i]; 00407 } 00408 } 00409 00410 00411 00412 00413 template <class T> inline 00414 void MPIContainerComm<T>::getBigArray(const Array<Array<T> >& x, Array<T>& bigArray, 00415 Array<int>& offsets) 00416 { 00417 offsets.resize(x.length()+1); 00418 int totalLength = 0; 00419 00420 for (int i=0; i<x.length(); i++) 00421 { 00422 offsets[i] = totalLength; 00423 totalLength += x[i].length(); 00424 } 00425 offsets[x.length()] = totalLength; 00426 00427 bigArray.resize(totalLength); 00428 00429 for (int i=0; i<x.length(); i++) 00430 { 00431 for (int j=0; j<x[i].length(); j++) 00432 { 00433 bigArray[offsets[i]+j] = x[i][j]; 00434 } 00435 } 00436 } 00437 00438 template <class T> inline 00439 void MPIContainerComm<T>::getSmallArrays(const Array<T>& bigArray, 00440 const Array<int>& offsets, 00441 Array<Array<T> >& x) 00442 { 00443 x.resize(offsets.length()-1); 00444 for (int i=0; i<x.length(); i++) 00445 { 00446 x[i].resize(offsets[i+1]-offsets[i]); 00447 for (int j=0; j<x[i].length(); j++) 00448 { 00449 x[i][j] = bigArray[offsets[i] + j]; 00450 } 00451 } 00452 } 00453 00454 00455 #ifndef DOXYGEN_SHOULD_SKIP_THIS 00456 00457 /* --------------- std::string specializations --------------------- */ 00458 00459 inline void MPIContainerComm<std::string>::bcast(std::string& x, 00460 int src, const MPIComm& comm) 00461 { 00462 int len = x.length(); 00463 MPIContainerComm<int>::bcast(len, src, comm); 00464 00465 x.resize(len); 00466 comm.bcast((void*)&(x[0]), len, MPITraits<char>::type(), src); 00467 } 00468 00469 00470 inline void MPIContainerComm<std::string>::bcast(Array<std::string>& x, int src, 00471 const MPIComm& comm) 00472 { 00473 /* begin by packing all the data into a big char array. This will 00474 * take a little time, but will be cheaper than multiple MPI calls */ 00475 Array<char> bigArray; 00476 Array<int> offsets; 00477 if (comm.getRank()==src) 00478 { 00479 getBigArray(x, bigArray, offsets); 00480 } 00481 00482 /* now broadcast the big array and the offsets */ 00483 MPIContainerComm<char>::bcast(bigArray, src, comm); 00484 MPIContainerComm<int>::bcast(offsets, src, comm); 00485 00486 /* finally, reassemble the array of strings */ 00487 if (comm.getRank() != src) 00488 { 00489 getStrings(bigArray, offsets, x); 00490 } 00491 } 00492 00493 inline void MPIContainerComm<std::string>::bcast(Array<Array<std::string> >& x, 00494 int src, const MPIComm& comm) 00495 { 00496 int len = x.length(); 00497 MPIContainerComm<int>::bcast(len, src, comm); 00498 00499 x.resize(len); 00500 for (int i=0; i<len; i++) 00501 { 00502 MPIContainerComm<std::string>::bcast(x[i], src, comm); 00503 } 00504 } 00505 00506 00507 inline void MPIContainerComm<std::string>::allGather(const std::string& outgoing, 00508 Array<std::string>& incoming, 00509 const MPIComm& comm) 00510 { 00511 int nProc = comm.getNProc(); 00512 00513 int sendCount = outgoing.length(); 00514 00515 incoming.resize(nProc); 00516 00517 int* recvCounts = new int[nProc]; 00518 int* recvDisplacements = new int[nProc]; 00519 00520 /* share lengths with all procs */ 00521 comm.allGather((void*) &sendCount, 1, MPIComm::INT, 00522 (void*) recvCounts, 1, MPIComm::INT); 00523 00524 00525 int recvSize = 0; 00526 recvDisplacements[0] = 0; 00527 for (int i=0; i<nProc; i++) 00528 { 00529 recvSize += recvCounts[i]; 00530 if (i < nProc-1) 00531 { 00532 recvDisplacements[i+1] = recvDisplacements[i]+recvCounts[i]; 00533 } 00534 } 00535 00536 char* recvBuf = new char[recvSize]; 00537 00538 comm.allGatherv((void*) outgoing.c_str(), sendCount, MPIComm::CHAR, 00539 recvBuf, recvCounts, recvDisplacements, MPIComm::CHAR); 00540 00541 for (int j=0; j<nProc; j++) 00542 { 00543 char* start = recvBuf + recvDisplacements[j]; 00544 char* tmp = new char[recvCounts[j]+1]; 00545 std::memcpy(tmp, start, recvCounts[j]); 00546 tmp[recvCounts[j]] = '\0'; 00547 incoming[j] = std::string(tmp); 00548 delete [] tmp; 00549 } 00550 00551 delete [] recvCounts; 00552 delete [] recvDisplacements; 00553 delete [] recvBuf; 00554 } 00555 00556 inline void MPIContainerComm<std::string>::gatherv(const Array<std::string>& outgoing, 00557 Array<Array<std::string> >& incoming, 00558 int root, 00559 const MPIComm& comm) 00560 { 00561 int nProc = comm.getNProc(); 00562 00563 Array<char> packedLocalArray; 00564 pack(outgoing, packedLocalArray); 00565 00566 int sendCount = packedLocalArray.size(); 00567 00568 /* gather the message sizes from all procs */ 00569 Array<int> recvCounts(nProc); 00570 Array<int> recvDisplacements(nProc); 00571 00572 comm.gather((void*) &sendCount, 1, MPIComm::INT, 00573 (void*) &(recvCounts[0]), 1, MPIComm::INT, root); 00574 00575 /* compute the displacements */ 00576 int recvSize = 0; 00577 if (root == comm.getRank()) 00578 { 00579 recvDisplacements[0] = 0; 00580 for (int i=0; i<nProc; i++) 00581 { 00582 recvSize += recvCounts[i]; 00583 if (i < nProc-1) 00584 { 00585 recvDisplacements[i+1] = recvDisplacements[i]+recvCounts[i]; 00586 } 00587 } 00588 } 00589 00590 /* set the size to 1 on non-root procs */ 00591 Array<char> recvBuf(std::max(1,recvSize)); 00592 00593 00594 void* sendBuf = (void*) &(packedLocalArray[0]); 00595 void* inBuf = (void*) &(recvBuf[0]); 00596 int* inCounts = &(recvCounts[0]); 00597 int* inDisps = &(recvDisplacements[0]); 00598 00599 /* gather the packed data */ 00600 comm.gatherv( sendBuf, sendCount, MPIComm::CHAR, 00601 inBuf, inCounts, inDisps, 00602 MPIComm::CHAR, root); 00603 00604 /* on the root, unpack the data */ 00605 if (comm.getRank()==root) 00606 { 00607 incoming.resize(nProc); 00608 for (int j=0; j<nProc; j++) 00609 { 00610 char* start = &(recvBuf[0]) + recvDisplacements[j]; 00611 Array<char> tmp(recvCounts[j]+1); 00612 std::memcpy(&(tmp[0]), start, recvCounts[j]); 00613 tmp[recvCounts[j]] = '\0'; 00614 unpack(tmp, incoming[j]); 00615 } 00616 } 00617 00618 00619 } 00620 00621 00622 inline void MPIContainerComm<std::string>::getBigArray(const Array<std::string>& x, 00623 Array<char>& bigArray, 00624 Array<int>& offsets) 00625 { 00626 offsets.resize(x.length()+1); 00627 int totalLength = 0; 00628 00629 for (int i=0; i<x.length(); i++) 00630 { 00631 offsets[i] = totalLength; 00632 totalLength += x[i].length(); 00633 } 00634 offsets[x.length()] = totalLength; 00635 00636 bigArray.resize(totalLength); 00637 00638 for (int i=0; i<x.length(); i++) 00639 { 00640 for (unsigned int j=0; j<x[i].length(); j++) 00641 { 00642 bigArray[offsets[i]+j] = x[i][j]; 00643 } 00644 } 00645 } 00646 00647 inline void MPIContainerComm<std::string>::pack(const Array<std::string>& x, 00648 Array<char>& bigArray) 00649 { 00650 Array<int> offsets(x.size()+1); 00651 int headerSize = (x.size()+2) * sizeof(int); 00652 00653 int totalLength = headerSize; 00654 00655 for (int i=0; i<x.length(); i++) 00656 { 00657 offsets[i] = totalLength; 00658 totalLength += x[i].length(); 00659 } 00660 offsets[x.length()] = totalLength; 00661 00662 /* The array will be packed as follows: 00663 * [numStrs, offset1, ... offsetN, characters data] 00664 */ 00665 00666 bigArray.resize(totalLength); 00667 00668 int* header = reinterpret_cast<int*>( &(bigArray[0]) ); 00669 header[0] = x.size(); 00670 for (Array<std::string>::size_type i=0; i<=x.size(); i++) 00671 { 00672 header[i+1] = offsets[i]; 00673 } 00674 00675 for (int i=0; i<x.length(); i++) 00676 { 00677 for (unsigned int j=0; j<x[i].length(); j++) 00678 { 00679 bigArray[offsets[i]+j] = x[i][j]; 00680 } 00681 } 00682 } 00683 00684 inline void MPIContainerComm<std::string>::unpack(const Array<char>& packed, 00685 Array<std::string>& x) 00686 { 00687 const int* header = reinterpret_cast<const int*>( &(packed[0]) ); 00688 00689 x.resize(header[0]); 00690 Array<int> offsets(x.size()+1); 00691 for (Array<std::string>::size_type i=0; i<=x.size(); i++) offsets[i] = header[i+1]; 00692 00693 for (Array<std::string>::size_type i=0; i<x.size(); i++) 00694 { 00695 x[i].resize(offsets[i+1]-offsets[i]); 00696 for (std::string::size_type j=0; j<x[i].length(); j++) 00697 { 00698 x[i][j] = packed[offsets[i] + j]; 00699 } 00700 } 00701 } 00702 00703 inline void MPIContainerComm<std::string>::getStrings(const Array<char>& bigArray, 00704 const Array<int>& offsets, 00705 Array<std::string>& x) 00706 { 00707 x.resize(offsets.length()-1); 00708 for (int i=0; i<x.length(); i++) 00709 { 00710 x[i].resize(offsets[i+1]-offsets[i]); 00711 for (unsigned int j=0; j<x[i].length(); j++) 00712 { 00713 x[i][j] = bigArray[offsets[i] + j]; 00714 } 00715 } 00716 } 00717 #endif // DOXYGEN_SHOULD_SKIP_THIS 00718 00719 } 00720 00721 00722 #endif 00723 00724
1.7.4