00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00031 #define RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00032
00033 #include "RTOpPack_RTOpT.hpp"
00034 #include "Teuchos_Workspace.hpp"
00035
00036 namespace RTOpPack {
00037
00048 template<class Scalar>
00049 class TOpLinearCombination : public RTOpT<Scalar> {
00050 public:
00052 TOpLinearCombination(
00053 const int num_vecs = 0
00054 ,const Scalar alpha[] = NULL
00055 ,const Scalar &beta = Teuchos::ScalarTraits<Scalar>::zero()
00056 )
00057 :RTOpT<Scalar>("TOpLinearCombination")
00058 ,beta_(beta)
00059 { if(num_vecs) this->alpha(num_vecs,alpha); }
00061 void beta( const Scalar& beta ) { beta_ = beta; }
00063 Scalar beta() const { return beta_; }
00065 void alpha(
00066 const int num_vecs
00067 ,const Scalar alpha[]
00068 )
00069 {
00070 TEST_FOR_EXCEPT( num_vecs<=0 || alpha==NULL );
00071 alpha_.resize(0);
00072 alpha_.insert(alpha_.begin(),alpha,alpha+num_vecs);
00073 }
00075 int num_vecs() const { return alpha_.size(); }
00077 const Scalar* alpha() const { return &alpha_[0]; }
00081 void apply_op(
00082 const int num_vecs, const ConstSubVectorView<Scalar> sub_vecs[]
00083 ,const int num_targ_vecs, const SubVectorView<Scalar> targ_sub_vecs[]
00084 ,ReductTarget *reduct_obj
00085 ) const
00086 {
00087 typedef Teuchos::ScalarTraits<Scalar> ST;
00088 using Teuchos::Workspace;
00089 Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00090
00091 #ifdef TEUCHOS_DEBUG
00092 TEST_FOR_EXCEPT( static_cast<int>(alpha_.size()) != num_vecs );
00093 TEST_FOR_EXCEPT( sub_vecs == NULL );
00094 TEST_FOR_EXCEPT( num_targ_vecs != 1 );
00095 TEST_FOR_EXCEPT( targ_sub_vecs == NULL );
00096 #endif
00097
00098 const RTOpPack::index_type subDim = targ_sub_vecs[0].subDim();
00099 Scalar *z0_val = targ_sub_vecs[0].values();
00100 const ptrdiff_t z0_s = targ_sub_vecs[0].stride();
00101 Workspace<const Scalar*> v_val(wss,num_vecs,false);
00102 Workspace<ptrdiff_t> v_s(wss,num_vecs,false);
00103 for( int k = 0; k < num_vecs; ++k ) {
00104 #ifdef TEUCHOS_DEBUG
00105 TEST_FOR_EXCEPT( sub_vecs[k].subDim() != subDim );
00106 TEST_FOR_EXCEPT( sub_vecs[k].globalOffset() != targ_sub_vecs[0].globalOffset() );
00107 #endif
00108 v_val[k] = sub_vecs[k].values();
00109 v_s[k] = sub_vecs[k].stride();
00110 }
00111
00112
00113
00114
00115 if( num_vecs == 1 ) {
00116
00117
00118
00119 const Scalar alpha = alpha_[0], beta = beta_;
00120 const Scalar *v0_val = v_val[0];
00121 const ptrdiff_t v0_s = v_s[0];
00122 if( beta==ST::zero() ) {
00123
00124 if( z0_s==1 && v0_s==1 ) {
00125 for( int j = 0; j < subDim; ++j )
00126 (*z0_val++) = alpha * (*v0_val++);
00127 }
00128 else {
00129 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00130 (*z0_val) = alpha * (*v0_val);
00131 }
00132 }
00133 else if( beta==ST::one() ) {
00134
00135
00136
00137 if( z0_s==1 && v0_s==1 ) {
00138 for( int j = 0; j < subDim; ++j )
00139 (*z0_val++) += alpha * (*v0_val++);
00140 }
00141 else {
00142 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00143 (*z0_val) += alpha * (*v0_val);
00144 }
00145 }
00146 else {
00147
00148 if( z0_s==1 && v0_s==1 ) {
00149 for( int j = 0; j < subDim; ++j, ++z0_val )
00150 (*z0_val) = alpha * (*v0_val++) + beta*(*z0_val);
00151 }
00152 else {
00153 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00154 (*z0_val) = alpha * (*v0_val) + beta*(*z0_val);
00155 }
00156 }
00157 }
00158 else if( num_vecs == 2 ) {
00159
00160
00161
00162 const Scalar alpha0 = alpha_[0], alpha1=alpha_[1], beta = beta_;
00163 const Scalar *v0_val = v_val[0];
00164 const ptrdiff_t v0_s = v_s[0];
00165 const Scalar *v1_val = v_val[1];
00166 const ptrdiff_t v1_s = v_s[1];
00167 if( beta==ST::zero() ) {
00168 if( alpha0 == ST::one() ) {
00169 if( alpha1 == ST::one() ) {
00170
00171 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00172 for( int j = 0; j < subDim; ++j )
00173 (*z0_val++) = (*v0_val++) + (*v1_val++);
00174 }
00175 else {
00176 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00177 (*z0_val) = (*v0_val) + (*v1_val);
00178 }
00179 }
00180 else {
00181
00182 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00183 for( int j = 0; j < subDim; ++j )
00184 (*z0_val++) = (*v0_val++) + alpha1*(*v1_val++);
00185 }
00186 else {
00187 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00188 (*z0_val) = (*v0_val) + alpha1*(*v1_val);
00189 }
00190 }
00191 }
00192 else {
00193 if( alpha1 == ST::one() ) {
00194
00195 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00196 for( int j = 0; j < subDim; ++j )
00197 (*z0_val++) = alpha0*(*v0_val++) + (*v1_val++);
00198 }
00199 else {
00200 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00201 (*z0_val) = alpha0*(*v0_val) + (*v1_val);
00202 }
00203 }
00204 else {
00205
00206 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00207 for( int j = 0; j < subDim; ++j )
00208 (*z0_val++) = alpha0*(*v0_val++) + alpha1*(*v1_val++);
00209 }
00210 else {
00211 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00212 (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val);
00213 }
00214 }
00215 }
00216 }
00217 else if( beta==ST::one() ) {
00218 if( alpha0 == ST::one() ) {
00219 if( alpha1 == ST::one() ) {
00220
00221 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00222 for( int j = 0; j < subDim; ++j, ++z0_val )
00223 (*z0_val) += (*v0_val++) + (*v1_val++);
00224 }
00225 else {
00226 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00227 (*z0_val) += (*v0_val) + (*v1_val);
00228 }
00229 }
00230 else {
00231
00232 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00233 for( int j = 0; j < subDim; ++j, ++z0_val )
00234 (*z0_val) += (*v0_val++) + alpha1*(*v1_val++);
00235 }
00236 else {
00237 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00238 (*z0_val) += (*v0_val) + alpha1*(*v1_val);
00239 }
00240 }
00241 }
00242 else {
00243 if( alpha1 == ST::one() ) {
00244
00245 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00246 for( int j = 0; j < subDim; ++j, ++z0_val )
00247 (*z0_val) += alpha0*(*v0_val++) + (*v1_val++);
00248 }
00249 else {
00250 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00251 (*z0_val) += alpha0*(*v0_val) + (*v1_val);
00252 }
00253 }
00254 else {
00255
00256 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00257 for( int j = 0; j < subDim; ++j, ++z0_val )
00258 (*z0_val) += alpha0*(*v0_val++) + alpha1*(*v1_val++);
00259 }
00260 else {
00261 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00262 (*z0_val) += alpha0*(*v0_val) + alpha1*(*v1_val);
00263 }
00264 }
00265 }
00266 }
00267 else {
00268 if( alpha0 == ST::one() ) {
00269 if( alpha1 == ST::one() ) {
00270
00271 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00272 for( int j = 0; j < subDim; ++j, ++z0_val )
00273 (*z0_val) = (*v0_val++) + (*v1_val++) + beta*(*z0_val);
00274 }
00275 else {
00276 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00277 (*z0_val) = (*v0_val) + (*v1_val) + beta*(*z0_val);
00278 }
00279 }
00280 else {
00281
00282 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00283 for( int j = 0; j < subDim; ++j, ++z0_val )
00284 (*z0_val) = (*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00285 }
00286 else {
00287 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00288 (*z0_val) = (*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00289 }
00290 }
00291 }
00292 else {
00293 if( alpha1 == ST::one() ) {
00294
00295 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00296 for( int j = 0; j < subDim; ++j, ++z0_val )
00297 (*z0_val) = alpha0*(*v0_val++) + (*v1_val++) + beta*(*z0_val);
00298 }
00299 else {
00300 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00301 (*z0_val) = alpha0*(*v0_val) + (*v1_val) + beta*(*z0_val);
00302 }
00303 }
00304 else {
00305
00306 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00307 for( int j = 0; j < subDim; ++j, ++z0_val )
00308 (*z0_val) = alpha0*(*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00309 }
00310 else {
00311 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00312 (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00313 }
00314 }
00315 }
00316 }
00317 }
00318 else {
00319
00320
00321
00322
00323 if( beta_ == ST::zero() ) {
00324 for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00325 (*z0_val) = ST::zero();
00326 }
00327 else if( beta_ != ST::one() ) {
00328 for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00329 (*z0_val) *= beta_;
00330 }
00331
00332 z0_val = targ_sub_vecs[0].values();
00333 for( int j = 0; j < subDim; ++j, z0_val += z0_s ) {
00334 for( int k = 0; k < num_vecs; ++k ) {
00335 const Scalar
00336 &alpha_k = alpha_[k],
00337 &v_k_val = *v_val[k];
00338 (*z0_val) += alpha_k * v_k_val;
00339 v_val[k] += v_s[k];
00340 }
00341 }
00342 }
00343 }
00345 private:
00346 Scalar beta_;
00347 std::vector<Scalar> alpha_;
00348 };
00349
00350 }
00351
00352 #endif // RTOPPACK_TOP_LINEAR_COMBINATION_HPP