00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef THYRA_SERIAL_MULTI_VECTOR_BASE_HPP
00030 #define THYRA_SERIAL_MULTI_VECTOR_BASE_HPP
00031
00032 #include "Thyra_SerialMultiVectorBaseDecl.hpp"
00033 #include "Thyra_MultiVectorDefaultBase.hpp"
00034 #include "Thyra_SingleScalarEuclideanLinearOpBase.hpp"
00035 #include "Thyra_SerialVectorSpaceBase.hpp"
00036 #include "Thyra_ExplicitMultiVectorView.hpp"
00037 #include "Thyra_apply_op_helper.hpp"
00038 #include "RTOp_parallel_helpers.h"
00039 #include "Teuchos_Workspace.hpp"
00040 #include "Teuchos_dyn_cast.hpp"
00041 #include "Teuchos_Time.hpp"
00042
00043
00044
00045
00046 namespace Thyra {
00047
00048 template<class Scalar>
00049 SerialMultiVectorBase<Scalar>::SerialMultiVectorBase()
00050 :in_applyOp_(false)
00051 ,numRows_(0)
00052 ,numCols_(0)
00053 {}
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075 template<class Scalar>
00076 void SerialMultiVectorBase<Scalar>::applyOp(
00077 const RTOpPack::RTOpT<Scalar> &pri_op
00078 ,const int num_multi_vecs
00079 ,const MultiVectorBase<Scalar>* multi_vecs[]
00080 ,const int num_targ_multi_vecs
00081 ,MultiVectorBase<Scalar>* targ_multi_vecs[]
00082 ,RTOpPack::ReductTarget* reduct_objs[]
00083 ,const Index pri_first_ele_in
00084 ,const Index pri_sub_dim_in
00085 ,const Index pri_global_offset_in
00086 ,const Index sec_first_ele_in
00087 ,const Index sec_sub_dim_in
00088 ) const
00089 {
00090 #ifdef _DEBUG
00091
00092 TEST_FOR_EXCEPTION(
00093 in_applyOp_, std::invalid_argument
00094 ,"SerialMultiVectorBase<>::applyOp(...): Error, this method is being entered recursively which is a "
00095 "clear sign that one of the methods getSubMultiVector(...), freeSubMultiVector(...) or commitSubMultiVector(...) "
00096 "was not implemented properly!"
00097 );
00098 apply_op_validate_input(
00099 "SerialMultiVectorBase<Scalar>::applyOp(...)", *this->domain(), *this->range()
00100 ,pri_op,num_multi_vecs,multi_vecs,num_targ_multi_vecs,targ_multi_vecs
00101 ,reduct_objs,pri_first_ele_in,pri_sub_dim_in,pri_global_offset_in
00102 ,sec_first_ele_in,sec_sub_dim_in
00103 );
00104 #endif
00105 in_applyOp_ = true;
00106 apply_op_serial(
00107 *(this->domain()),*(this->range())
00108 ,pri_op,num_multi_vecs,multi_vecs,num_targ_multi_vecs,targ_multi_vecs
00109 ,reduct_objs,pri_first_ele_in,pri_sub_dim_in,pri_global_offset_in
00110 ,sec_first_ele_in,sec_sub_dim_in
00111 );
00112 in_applyOp_ = false;
00113 }
00114
00115 template<class Scalar>
00116 void SerialMultiVectorBase<Scalar>::getSubMultiVector(
00117 const Range1D &rowRng_in
00118 ,const Range1D &colRng_in
00119 ,RTOpPack::SubMultiVectorT<Scalar> *sub_mv
00120 ) const
00121 {
00122 const Range1D rowRng = validateRowRange(rowRng_in);
00123 const Range1D colRng = validateColRange(colRng_in);
00124 const Scalar *localValues = NULL; int leadingDim = 0;
00125 this->getData(&localValues,&leadingDim);
00126 sub_mv->initialize(
00127 rowRng.lbound()-1
00128 ,rowRng.size()
00129 ,colRng.lbound()-1
00130 ,colRng.size()
00131 ,localValues
00132 +(rowRng.lbound()-1)
00133 +(colRng.lbound()-1)*leadingDim
00134 ,leadingDim
00135 );
00136 }
00137
00138 template<class Scalar>
00139 void SerialMultiVectorBase<Scalar>::freeSubMultiVector(
00140 RTOpPack::SubMultiVectorT<Scalar>* sub_mv
00141 ) const
00142 {
00143 freeData( sub_mv->values() );
00144 sub_mv->set_uninitialized();
00145 }
00146
00147 template<class Scalar>
00148 void SerialMultiVectorBase<Scalar>::getSubMultiVector(
00149 const Range1D &rowRng_in
00150 ,const Range1D &colRng_in
00151 ,RTOpPack::MutableSubMultiVectorT<Scalar> *sub_mv
00152 )
00153 {
00154 const Range1D rowRng = validateRowRange(rowRng_in);
00155 const Range1D colRng = validateColRange(colRng_in);
00156 Scalar *localValues = NULL; int leadingDim = 0;
00157 this->getData(&localValues,&leadingDim);
00158 sub_mv->initialize(
00159 rowRng.lbound()-1
00160 ,rowRng.size()
00161 ,colRng.lbound()-1
00162 ,colRng.size()
00163 ,localValues
00164 +(rowRng.lbound()-1)
00165 +(colRng.lbound()-1)*leadingDim
00166 ,leadingDim
00167 );
00168 }
00169
00170 template<class Scalar>
00171 void SerialMultiVectorBase<Scalar>::commitSubMultiVector(
00172 RTOpPack::MutableSubMultiVectorT<Scalar>* sub_mv
00173 )
00174 {
00175 commitData( sub_mv->values() );
00176 sub_mv->set_uninitialized();
00177 }
00178
00179
00180
00181
00182
00183
00184 template<class Scalar>
00185 bool SerialMultiVectorBase<Scalar>::opSupported(ETransp M_trans) const
00186 {
00187 typedef Teuchos::ScalarTraits<Scalar> ST;
00188 return ( ST::isComplex ? ( M_trans!=CONJ ) : true );
00189 }
00190
00191 template<class Scalar>
00192 void SerialMultiVectorBase<Scalar>::euclideanApply(
00193 const ETransp M_trans
00194 ,const MultiVectorBase<Scalar> &X
00195 ,MultiVectorBase<Scalar> *Y
00196 ,const Scalar alpha
00197 ,const Scalar beta
00198 ) const
00199 {
00200
00201 typedef Teuchos::ScalarTraits<Scalar> ST;
00202
00203 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00204 Teuchos::Time timerTotal("dummy",true);
00205 Teuchos::Time timer("dummy");
00206 #endif
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220 #ifdef _DEBUG
00221 THYRA_ASSERT_LINEAR_OP_MULTIVEC_APPLY_SPACES("SerialMultiVectorBase<Scalar>::euclideanApply()",*this,M_trans,X,Y);
00222 #endif
00223
00224
00225
00226
00227
00228 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00229 timer.start();
00230 #endif
00231 ExplicitMutableMultiVectorView<Scalar> Y_local(*Y);
00232 ExplicitMultiVectorView<Scalar> M_local(*this);
00233 ExplicitMultiVectorView<Scalar> X_local(X);
00234 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00235 timer.stop();
00236 std::cout << "\nSerialMultiVectorBase<Scalar>::apply(...): Time for getting view = " << timer.totalElapsedTime() << " seconds\n";
00237 #endif
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00250 timer.start();
00251 #endif
00252 Teuchos::ETransp t_transp;
00253 if(ST::isComplex) {
00254 switch(M_trans) {
00255 case NOTRANS: t_transp = Teuchos::NO_TRANS; break;
00256 case TRANS: t_transp = Teuchos::TRANS; break;
00257 case CONJTRANS: t_transp = Teuchos::CONJ_TRANS; break;
00258 default: TEST_FOR_EXCEPT(true);
00259 }
00260 }
00261 else {
00262 switch(real_trans(M_trans)) {
00263 case NOTRANS: t_transp = Teuchos::NO_TRANS; break;
00264 case TRANS: t_transp = Teuchos::TRANS; break;
00265 default: TEST_FOR_EXCEPT(true);
00266 }
00267 }
00268 blas_.GEMM(
00269 t_transp
00270 ,Teuchos::NO_TRANS
00271 ,Y_local.subDim()
00272 ,Y_local.numSubCols()
00273 ,real_trans(M_trans)==NOTRANS ? M_local.numSubCols() : M_local.subDim()
00274 ,alpha
00275 ,const_cast<Scalar*>(M_local.values())
00276 ,M_local.leadingDim()
00277 ,const_cast<Scalar*>(X_local.values())
00278 ,X_local.leadingDim()
00279 ,beta
00280 ,Y_local.values()
00281 ,Y_local.leadingDim()
00282 );
00283 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00284 timer.stop();
00285 std::cout << "\nSerialMultiVectorBase<Scalar>::apply(...): Time for GEMM = " << timer.totalElapsedTime() << " seconds\n";
00286 #endif
00287
00288 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00289 timer.stop();
00290 std::cout << "\nSerialMultiVectorBase<Scalar>::apply(...): Total time = " << timerTotal.totalElapsedTime() << " seconds\n";
00291 #endif
00292
00293 }
00294
00295
00296
00297 template<class Scalar>
00298 void SerialMultiVectorBase<Scalar>::updateSpace()
00299 {
00300 if(numRows_ == 0) {
00301 const VectorSpaceBase<Scalar> *range = this->range().get();
00302 if(range) {
00303 numRows_ = range->dim();
00304 numCols_ = this->domain()->dim();
00305 }
00306 else {
00307 numRows_ = 0;
00308 numCols_ = 0;
00309 }
00310 }
00311 }
00312
00313 template<class Scalar>
00314 Range1D SerialMultiVectorBase<Scalar>::validateRowRange( const Range1D &rowRng_in ) const
00315 {
00316 const Range1D rowRng = RangePack::full_range(rowRng_in,1,numRows_);
00317 #ifdef _DEBUG
00318 TEST_FOR_EXCEPTION(
00319 rowRng.lbound() < 1 || numRows_ < rowRng.ubound(), std::invalid_argument
00320 ,"SerialMultiVectorBase<Scalar>::validateRowRange(rowRng): Error, the range rowRng = ["
00321 <<rowRng.lbound()<<","<<rowRng.ubound()<<"] is not "
00322 "in the range [1,"<<numRows_<<"]!"
00323 );
00324 #endif
00325 return rowRng;
00326 }
00327
00328 template<class Scalar>
00329 Range1D SerialMultiVectorBase<Scalar>::validateColRange( const Range1D &colRng_in ) const
00330 {
00331 const Range1D colRng = RangePack::full_range(colRng_in,1,numCols_);
00332 #ifdef _DEBUG
00333 TEST_FOR_EXCEPTION(
00334 colRng.lbound() < 1 || numCols_ < colRng.ubound(), std::invalid_argument
00335 ,"SerialMultiVectorBase<Scalar>::validateColRange(colRng): Error, the range colRng = ["
00336 <<colRng.lbound()<<","<<colRng.ubound()<<"] is not "
00337 "in the range [1,"<<numCols_<<"]!"
00338 );
00339 #endif
00340 return colRng;
00341 }
00342
00343 }
00344
00345 #endif // THYRA_SERIAL_MULTI_VECTOR_BASE_HPP