00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #include "Teuchos_MPIComm.hpp"
00030 #include "Teuchos_ErrorPolling.hpp"
00031
00032
00033 using namespace Teuchos;
00034
00035 namespace Teuchos
00036 {
00037 const int MPIComm::INT = 1;
00038 const int MPIComm::FLOAT = 2;
00039 const int MPIComm::DOUBLE = 3;
00040 const int MPIComm::CHAR = 4;
00041
00042 const int MPIComm::SUM = 5;
00043 const int MPIComm::MIN = 6;
00044 const int MPIComm::MAX = 7;
00045 const int MPIComm::PROD = 8;
00046 }
00047
00048
00049 MPIComm::MPIComm()
00050 :
00051 #ifdef HAVE_MPI
00052 comm_(MPI_COMM_WORLD),
00053 #endif
00054 nProc_(0), myRank_(0)
00055 {
00056 init();
00057 }
00058
#ifdef HAVE_MPI
/*
 * Wrap an existing MPI communicator handle.  The handle is stored as given
 * (this constructor does not duplicate it); rank and size are then filled
 * in by init().
 */
MPIComm::MPIComm(MPI_Comm comm)
  : comm_(comm),
    nProc_(0),
    myRank_(0)
{
  this->init();
}
#endif
00066
00067 int MPIComm::mpiIsRunning() const
00068 {
00069 int mpiStarted = 0;
00070 #ifdef HAVE_MPI
00071 MPI_Initialized(&mpiStarted);
00072 #endif
00073 return mpiStarted;
00074 }
00075
00076 void MPIComm::init()
00077 {
00078 #ifdef HAVE_MPI
00079
00080 if (mpiIsRunning())
00081 {
00082 errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
00083 errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
00084 }
00085 else
00086 {
00087 nProc_ = 1;
00088 myRank_ = 0;
00089 }
00090
00091 #else
00092 nProc_ = 1;
00093 myRank_ = 0;
00094 #endif
00095 }
00096
#ifdef USE_MPI_GROUPS

/*
 * Construct a sub-communicator of `parent` containing the processes of
 * `group`.
 *
 * Resulting state on this process (MPI builds):
 *   - empty group                  : myRank_ = -1, nProc_ = 0
 *   - member of parent and group   : rank/size queried from the new comm_
 *   - member of parent, not group  : myRank_ = -1, nProc_ = -1
 *   - not a member of parent       : myRank_ = -1, nProc_ = -1
 * Without MPI, myRank_ and nProc_ keep their initial value 0.
 *
 * MPI_Comm_create is collective over the parent communicator, so every
 * process of `parent` must reach this constructor.
 *
 * NOTE(review): in the non-member branches comm_ keeps its initial value
 * MPI_COMM_WORLD; callers must check nProc_/myRank_ before communicating.
 */
MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group)
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD),
#endif
  nProc_(0), myRank_(0)
{
#ifdef HAVE_MPI
  if (group.getNProc() == 0)
  {
    myRank_ = -1;
    nProc_ = 0;
  }
  else if (parent.containsMe())
  {
    MPI_Comm parentComm = parent.comm_;
    MPI_Group newGroup = group.group_;

    errCheck(MPI_Comm_create(parentComm, newGroup, &comm_),
             "Comm_create");

    if (group.containsProc(parent.getRank()))
    {
      // Fixed: the original wrote to a non-existent member `rank_`.
      errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
      errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
    }
    else
    {
      myRank_ = -1;
      nProc_ = -1;
    }
  }
  else
  {
    myRank_ = -1;
    nProc_ = -1;
  }
#endif
}

#endif
00142
00143 MPIComm& MPIComm::world()
00144 {
00145 static MPIComm w = MPIComm();
00146 return w;
00147 }
00148
00149
00150 void MPIComm::synchronize() const
00151 {
00152 #ifdef HAVE_MPI
00153
00154 {
00155 if (mpiIsRunning())
00156 {
00157
00158
00159 TEUCHOS_POLL_FOR_FAILURES(*this);
00160
00161
00162 errCheck(::MPI_Barrier(comm_), "Barrier");
00163 }
00164 }
00165
00166 #endif
00167 }
00168
00169 void MPIComm::allToAll(void* sendBuf, int sendCount, int sendType,
00170 void* recvBuf, int recvCount, int recvType) const
00171 {
00172 #ifdef HAVE_MPI
00173
00174 {
00175 MPI_Datatype mpiSendType = getDataType(sendType);
00176 MPI_Datatype mpiRecvType = getDataType(recvType);
00177
00178
00179 if (mpiIsRunning())
00180 {
00181
00182
00183 TEUCHOS_POLL_FOR_FAILURES(*this);
00184
00185
00186 errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType,
00187 recvBuf, recvCount, mpiRecvType,
00188 comm_), "Alltoall");
00189 }
00190 }
00191
00192 #else
00193 (void)sendBuf;
00194 (void)sendCount;
00195 (void)sendType;
00196 (void)recvBuf;
00197 (void)recvCount;
00198 (void)recvType;
00199 #endif
00200 }
00201
00202 void MPIComm::allToAllv(void* sendBuf, int* sendCount,
00203 int* sendDisplacements, int sendType,
00204 void* recvBuf, int* recvCount,
00205 int* recvDisplacements, int recvType) const
00206 {
00207 #ifdef HAVE_MPI
00208
00209 {
00210 MPI_Datatype mpiSendType = getDataType(sendType);
00211 MPI_Datatype mpiRecvType = getDataType(recvType);
00212
00213 if (mpiIsRunning())
00214 {
00215
00216
00217 TEUCHOS_POLL_FOR_FAILURES(*this);
00218
00219
00220 errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType,
00221 recvBuf, recvCount, recvDisplacements, mpiRecvType,
00222 comm_), "Alltoallv");
00223 }
00224 }
00225
00226 #else
00227 (void)sendBuf;
00228 (void)sendCount;
00229 (void)sendDisplacements;
00230 (void)sendType;
00231 (void)recvBuf;
00232 (void)recvCount;
00233 (void)recvDisplacements;
00234 (void)recvType;
00235 #endif
00236 }
00237
00238 void MPIComm::gather(void* sendBuf, int sendCount, int sendType,
00239 void* recvBuf, int recvCount, int recvType,
00240 int root) const
00241 {
00242 #ifdef HAVE_MPI
00243
00244 {
00245 MPI_Datatype mpiSendType = getDataType(sendType);
00246 MPI_Datatype mpiRecvType = getDataType(recvType);
00247
00248
00249 if (mpiIsRunning())
00250 {
00251
00252
00253 TEUCHOS_POLL_FOR_FAILURES(*this);
00254
00255
00256 errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType,
00257 recvBuf, recvCount, mpiRecvType,
00258 root, comm_), "Gather");
00259 }
00260 }
00261
00262 #endif
00263 }
00264
00265 void MPIComm::gatherv(void* sendBuf, int sendCount, int sendType,
00266 void* recvBuf, int* recvCount, int* displacements, int recvType,
00267 int root) const
00268 {
00269 #ifdef HAVE_MPI
00270
00271 {
00272 MPI_Datatype mpiSendType = getDataType(sendType);
00273 MPI_Datatype mpiRecvType = getDataType(recvType);
00274
00275 if (mpiIsRunning())
00276 {
00277
00278
00279 TEUCHOS_POLL_FOR_FAILURES(*this);
00280
00281
00282 errCheck(::MPI_Gatherv(sendBuf, sendCount, mpiSendType,
00283 recvBuf, recvCount, displacements, mpiRecvType,
00284 root, comm_), "Gatherv");
00285 }
00286 }
00287
00288 #endif
00289 }
00290
00291 void MPIComm::allGather(void* sendBuf, int sendCount, int sendType,
00292 void* recvBuf, int recvCount,
00293 int recvType) const
00294 {
00295 #ifdef HAVE_MPI
00296
00297 {
00298 MPI_Datatype mpiSendType = getDataType(sendType);
00299 MPI_Datatype mpiRecvType = getDataType(recvType);
00300
00301 if (mpiIsRunning())
00302 {
00303
00304
00305 TEUCHOS_POLL_FOR_FAILURES(*this);
00306
00307
00308 errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType,
00309 recvBuf, recvCount,
00310 mpiRecvType, comm_),
00311 "AllGather");
00312 }
00313 }
00314
00315 #endif
00316 }
00317
00318
00319 void MPIComm::allGatherv(void* sendBuf, int sendCount, int sendType,
00320 void* recvBuf, int* recvCount,
00321 int* recvDisplacements,
00322 int recvType) const
00323 {
00324 #ifdef HAVE_MPI
00325
00326 {
00327 MPI_Datatype mpiSendType = getDataType(sendType);
00328 MPI_Datatype mpiRecvType = getDataType(recvType);
00329
00330 if (mpiIsRunning())
00331 {
00332
00333
00334 TEUCHOS_POLL_FOR_FAILURES(*this);
00335
00336
00337 errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType,
00338 recvBuf, recvCount, recvDisplacements,
00339 mpiRecvType,
00340 comm_),
00341 "AllGatherv");
00342 }
00343 }
00344
00345 #endif
00346 }
00347
00348
00349 void MPIComm::bcast(void* msg, int length, int type, int src) const
00350 {
00351 #ifdef HAVE_MPI
00352
00353 {
00354 if (mpiIsRunning())
00355 {
00356
00357
00358 TEUCHOS_POLL_FOR_FAILURES(*this);
00359
00360
00361 MPI_Datatype mpiType = getDataType(type);
00362 errCheck(::MPI_Bcast(msg, length, mpiType, src,
00363 comm_), "Bcast");
00364 }
00365 }
00366
00367 #endif
00368 }
00369
00370 void MPIComm::allReduce(void* input, void* result, int inputCount,
00371 int type, int op) const
00372 {
00373 #ifdef HAVE_MPI
00374
00375
00376 {
00377 MPI_Op mpiOp = getOp(op);
00378 MPI_Datatype mpiType = getDataType(type);
00379
00380 if (mpiIsRunning())
00381 {
00382 errCheck(::MPI_Allreduce(input, result, inputCount, mpiType,
00383 mpiOp, comm_),
00384 "Allreduce");
00385 }
00386 }
00387
00388 #endif
00389 }
00390
00391
00392 #ifdef HAVE_MPI
00393
00394 MPI_Datatype MPIComm::getDataType(int type)
00395 {
00396 TEST_FOR_EXCEPTION(
00397 !(type == INT || type==FLOAT
00398 || type==DOUBLE || type==CHAR),
00399 std::range_error,
00400 "invalid type " << type << " in MPIComm::getDataType");
00401
00402 if(type == INT) return MPI_INT;
00403 if(type == FLOAT) return MPI_FLOAT;
00404 if(type == DOUBLE) return MPI_DOUBLE;
00405
00406 return MPI_CHAR;
00407 }
00408
00409
00410 void MPIComm::errCheck(int errCode, const std::string& methodName)
00411 {
00412 TEST_FOR_EXCEPTION(errCode != 0, std::runtime_error,
00413 "MPI function MPI_" << methodName
00414 << " returned error code=" << errCode);
00415 }
00416
00417 MPI_Op MPIComm::getOp(int op)
00418 {
00419
00420 TEST_FOR_EXCEPTION(
00421 !(op == SUM || op==MAX
00422 || op==MIN || op==PROD),
00423 std::range_error,
00424 "invalid operator "
00425 << op << " in MPIComm::getOp");
00426
00427 if( op == SUM) return MPI_SUM;
00428 else if( op == MAX) return MPI_MAX;
00429 else if( op == MIN) return MPI_MIN;
00430 return MPI_PROD;
00431 }
00432
00433 #endif