00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #include "Teuchos_MPIComm.hpp"
00030
00031
00032 using namespace Teuchos;
00033
00034 namespace Teuchos
00035 {
00036 const int MPIComm::INT = 1;
00037 const int MPIComm::FLOAT = 2;
00038 const int MPIComm::DOUBLE = 3;
00039 const int MPIComm::CHAR = 4;
00040
00041 const int MPIComm::SUM = 5;
00042 const int MPIComm::MIN = 6;
00043 const int MPIComm::MAX = 7;
00044 const int MPIComm::PROD = 8;
00045 }
00046
00047
00048 MPIComm::MPIComm()
00049 :
00050 #ifdef HAVE_MPI
00051 comm_(MPI_COMM_WORLD),
00052 #endif
00053 nProc_(0), myRank_(0)
00054 {
00055 init();
00056 }
00057
#ifdef HAVE_MPI
/*
 * Construct a wrapper around the caller-supplied MPI communicator.
 * Only available when built with HAVE_MPI. The rank and process count
 * are filled in by init().
 */
MPIComm::MPIComm(MPI_Comm comm)
  : comm_(comm),
    nProc_(0),
    myRank_(0)
{
  init();
}
#endif
00065
00066 void MPIComm::init()
00067 {
00068 #ifdef HAVE_MPI
00069
00070 errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");
00071
00072 errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
00073
00074 #else
00075 nProc_ = 1;
00076 myRank_ = 0;
00077 #endif
00078 }
00079
#ifdef USE_MPI_GROUPS

/*
 * Construct a communicator for the processes of `group`, created from the
 * communicator of `parent`. Compiled only when USE_MPI_GROUPS is defined.
 *
 * parent  communicator the new one is created from
 * group   process group that defines membership of the new communicator
 *
 * Resulting state (HAVE_MPI builds):
 *  - empty group:                         myRank_ = -1, nProc_ = 0
 *  - this process is in parent and group: myRank_ and nProc_ are queried
 *                                         from the new communicator
 *  - otherwise:                           myRank_ = -1, nProc_ = -1
 *
 * Without HAVE_MPI the body is empty, so nProc_ and myRank_ keep the
 * value 0 given in the initializer list.
 *
 * A nonzero return code from an MPI call is reported through errCheck().
 */
MPIComm::MPIComm(const MPIComm& parent, const MPIGroup& group)
  :
#ifdef HAVE_MPI
  comm_(MPI_COMM_WORLD),
#endif
  nProc_(0), myRank_(0)
{
#ifdef HAVE_MPI
  if (group.getNProc()==0)
    {
      myRank_ = -1;
      nProc_ = 0;
    }
  else if (parent.containsMe())
    {
      MPI_Comm parentComm = parent.comm_;
      MPI_Group newGroup = group.group_;

      // Called by every process of the parent communicator, including
      // those that are not members of the new group.
      errCheck(MPI_Comm_create(parentComm, newGroup, &comm_),
               "Comm_create");

      if (group.containsProc(parent.getRank()))
        {
          // The rank is stored in myRank_, the member this class
          // initializes and fills everywhere else.
          errCheck(MPI_Comm_rank(comm_, &myRank_), "Comm_rank");

          errCheck(MPI_Comm_size(comm_, &nProc_), "Comm_size");
        }
      else
        {
          // This process is not part of the new communicator.
          myRank_ = -1;
          nProc_ = -1;
        }
    }
  else
    {
      // This process is not part of the parent communicator.
      myRank_ = -1;
      nProc_ = -1;
    }
#endif
}

#endif
00125
00126 MPIComm& MPIComm::world()
00127 {
00128 static MPIComm w = MPIComm();
00129 return w;
00130 }
00131
00132
00133 void MPIComm::synchronize() const
00134 {
00135 #ifdef HAVE_MPI
00136
00137 {
00138 errCheck(::MPI_Barrier(comm_), "Barrier");
00139 }
00140
00141 #endif
00142 }
00143
00144 void MPIComm::allToAll(void* sendBuf, int sendCount, int sendType,
00145 void* recvBuf, int recvCount, int recvType) const
00146 {
00147 #ifdef HAVE_MPI
00148
00149 {
00150 MPI_Datatype mpiSendType = getDataType(sendType);
00151 MPI_Datatype mpiRecvType = getDataType(recvType);
00152
00153 errCheck(::MPI_Alltoall(sendBuf, sendCount, mpiSendType,
00154 recvBuf, recvCount, mpiRecvType,
00155 comm_), "Alltoall");
00156 }
00157
00158 #endif
00159 }
00160
00161 void MPIComm::allToAllv(void* sendBuf, int* sendCount,
00162 int* sendDisplacements, int sendType,
00163 void* recvBuf, int* recvCount,
00164 int* recvDisplacements, int recvType) const
00165 {
00166 #ifdef HAVE_MPI
00167
00168 {
00169 MPI_Datatype mpiSendType = getDataType(sendType);
00170 MPI_Datatype mpiRecvType = getDataType(recvType);
00171
00172 errCheck(::MPI_Alltoallv(sendBuf, sendCount, sendDisplacements, mpiSendType,
00173 recvBuf, recvCount, recvDisplacements, mpiRecvType,
00174 comm_), "Alltoallv");
00175 }
00176
00177 #endif
00178 }
00179
00180 void MPIComm::gather(void* sendBuf, int sendCount, int sendType,
00181 void* recvBuf, int recvCount, int recvType,
00182 int root) const
00183 {
00184 #ifdef HAVE_MPI
00185
00186 {
00187 MPI_Datatype mpiSendType = getDataType(sendType);
00188 MPI_Datatype mpiRecvType = getDataType(recvType);
00189
00190 errCheck(::MPI_Gather(sendBuf, sendCount, mpiSendType,
00191 recvBuf, recvCount, mpiRecvType,
00192 root, comm_), "Gather");
00193 }
00194
00195 #endif
00196 }
00197
00198 void MPIComm::allGather(void* sendBuf, int sendCount, int sendType,
00199 void* recvBuf, int recvCount,
00200 int recvType) const
00201 {
00202 #ifdef HAVE_MPI
00203
00204 {
00205 MPI_Datatype mpiSendType = getDataType(sendType);
00206 MPI_Datatype mpiRecvType = getDataType(recvType);
00207
00208 errCheck(::MPI_Allgather(sendBuf, sendCount, mpiSendType,
00209 recvBuf, recvCount,
00210 mpiRecvType, comm_),
00211 "AllGather");
00212 }
00213
00214 #endif
00215 }
00216
00217
00218 void MPIComm::allGatherv(void* sendBuf, int sendCount, int sendType,
00219 void* recvBuf, int* recvCount,
00220 int* recvDisplacements,
00221 int recvType) const
00222 {
00223 #ifdef HAVE_MPI
00224
00225 {
00226 MPI_Datatype mpiSendType = getDataType(sendType);
00227 MPI_Datatype mpiRecvType = getDataType(recvType);
00228
00229 errCheck(::MPI_Allgatherv(sendBuf, sendCount, mpiSendType,
00230 recvBuf, recvCount, recvDisplacements,
00231 mpiRecvType,
00232 comm_),
00233 "AllGatherv");
00234 }
00235
00236 #endif
00237 }
00238
00239
00240 void MPIComm::bcast(void* msg, int length, int type, int src) const
00241 {
00242 #ifdef HAVE_MPI
00243
00244 {
00245 MPI_Datatype mpiType = getDataType(type);
00246 errCheck(::MPI_Bcast(msg, length, mpiType, src,
00247 comm_), "Bcast");
00248 }
00249
00250 #endif
00251 }
00252
00253 void MPIComm::allReduce(void* input, void* result, int inputCount,
00254 int type, int op) const
00255 {
00256 #ifdef HAVE_MPI
00257
00258
00259 {
00260 MPI_Op mpiOp = getOp(op);
00261 MPI_Datatype mpiType = getDataType(type);
00262
00263 errCheck(::MPI_Allreduce(input, result, inputCount, mpiType,
00264 mpiOp, comm_),
00265 "Allreduce");
00266 }
00267
00268 #endif
00269 }
00270
00271
00272 #ifdef HAVE_MPI
00273
00274 MPI_Datatype MPIComm::getDataType(int type)
00275 {
00276 TEST_FOR_EXCEPTION( !(type == INT || type==FLOAT
00277 || type==DOUBLE || type==CHAR),
00278 range_error,
00279 "invalid type " << type << " in MPIComm::getDataType");
00280
00281 if(type == INT) return MPI_INT;
00282 if(type == FLOAT) return MPI_FLOAT;
00283 if(type == DOUBLE) return MPI_DOUBLE;
00284
00285 return MPI_CHAR;
00286 }
00287
00288
00289 void MPIComm::errCheck(int errCode, const string& methodName)
00290 {
00291 TEST_FOR_EXCEPTION(errCode != 0, runtime_error,
00292 "MPI function MPI_" << methodName
00293 << " returned error code=" << errCode);
00294 }
00295
00296 MPI_Op MPIComm::getOp(int op)
00297 {
00298
00299 TEST_FOR_EXCEPTION( !(op == SUM || op==MAX
00300 || op==MIN || op==PROD),
00301 range_error,
00302 "invalid operator "
00303 << op << " in MPIComm::getOp");
00304
00305 if( op == SUM) return MPI_SUM;
00306 else if( op == MAX) return MPI_MAX;
00307 else if( op == MIN) return MPI_MIN;
00308 return MPI_PROD;
00309 }
00310
00311 #endif