00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00030 #define RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00031
00032 #include "RTOpPack_RTOpT.hpp"
00033 #include "Teuchos_Workspace.hpp"
00034
00035 namespace RTOpPack {
00036
00046 template<class Scalar>
00047 class TOpLinearCombination : public RTOpT<Scalar> {
00048 public:
00050 TOpLinearCombination(
00051 const int num_vecs = 0
00052 ,const Scalar alpha[] = NULL
00053 ,const Scalar &beta = Teuchos::ScalarTraits<Scalar>::zero()
00054 )
00055 :RTOpT<Scalar>("TOpLinearCombination")
00056 ,beta_(beta)
00057 { if(num_vecs) this->alpha(num_vecs,alpha); }
00059 void beta( const Scalar& beta ) { beta_ = beta; }
00061 Scalar beta() const { return beta_; }
00063 void alpha(
00064 const int num_vecs
00065 ,const Scalar alpha[]
00066 )
00067 {
00068 TEST_FOR_EXCEPT( num_vecs<=0 || alpha==NULL );
00069 alpha_.resize(0);
00070 alpha_.insert(alpha_.begin(),alpha,alpha+num_vecs);
00071 }
00073 int num_vecs() const { return alpha_.size(); }
00075 const Scalar* alpha() const { return &alpha_[0]; }
00079 void apply_op(
00080 const int num_vecs, const SubVectorT<Scalar> sub_vecs[]
00081 ,const int num_targ_vecs, const MutableSubVectorT<Scalar> targ_sub_vecs[]
00082 ,ReductTarget *reduct_obj
00083 ) const
00084 {
00085 typedef Teuchos::ScalarTraits<Scalar> ST;
00086 using Teuchos::Workspace;
00087 Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00088
00089 #ifdef _DEBUG
00090 TEST_FOR_EXCEPT( static_cast<int>(alpha_.size()) != num_vecs );
00091 TEST_FOR_EXCEPT( sub_vecs == NULL );
00092 TEST_FOR_EXCEPT( num_targ_vecs != 1 );
00093 TEST_FOR_EXCEPT( targ_sub_vecs == NULL );
00094 #endif
00095
00096 const RTOpPack::index_type subDim = targ_sub_vecs[0].subDim();
00097 Scalar *z0_val = targ_sub_vecs[0].values();
00098 const ptrdiff_t z0_s = targ_sub_vecs[0].stride();
00099 Workspace<const Scalar*> v_val(wss,num_vecs,false);
00100 Workspace<ptrdiff_t> v_s(wss,num_vecs,false);
00101 for( int k = 0; k < num_vecs; ++k ) {
00102 #ifdef _DEBUG
00103 TEST_FOR_EXCEPT( sub_vecs[k].subDim() != subDim );
00104 TEST_FOR_EXCEPT( sub_vecs[k].globalOffset() != targ_sub_vecs[0].globalOffset() );
00105 #endif
00106 v_val[k] = sub_vecs[k].values();
00107 v_s[k] = sub_vecs[k].stride();
00108 }
00109
00110
00111
00112
00113 if( num_vecs == 1 ) {
00114
00115
00116
00117 const Scalar alpha = alpha_[0], beta = beta_;
00118 const Scalar *v0_val = v_val[0];
00119 const ptrdiff_t v0_s = v_s[0];
00120 if( beta==ST::zero() ) {
00121
00122 if( z0_s==1 && v0_s==1 ) {
00123 for( int j = 0; j < subDim; ++j )
00124 (*z0_val++) = alpha * (*v0_val++);
00125 }
00126 else {
00127 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00128 (*z0_val) = alpha * (*v0_val);
00129 }
00130 }
00131 else if( beta==ST::one() ) {
00132
00133
00134
00135 if( z0_s==1 && v0_s==1 ) {
00136 for( int j = 0; j < subDim; ++j )
00137 (*z0_val++) += alpha * (*v0_val++);
00138 }
00139 else {
00140 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00141 (*z0_val) += alpha * (*v0_val);
00142 }
00143 }
00144 else {
00145
00146 if( z0_s==1 && v0_s==1 ) {
00147 for( int j = 0; j < subDim; ++j, ++z0_val )
00148 (*z0_val) = alpha * (*v0_val++) + beta*(*z0_val);
00149 }
00150 else {
00151 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00152 (*z0_val) = alpha * (*v0_val) + beta*(*z0_val);
00153 }
00154 }
00155 }
00156 else if( num_vecs == 2 ) {
00157
00158
00159
00160 const Scalar alpha0 = alpha_[0], alpha1=alpha_[1], beta = beta_;
00161 const Scalar *v0_val = v_val[0];
00162 const ptrdiff_t v0_s = v_s[0];
00163 const Scalar *v1_val = v_val[1];
00164 const ptrdiff_t v1_s = v_s[1];
00165 if( beta==ST::zero() ) {
00166 if( alpha0 == ST::one() ) {
00167 if( alpha1 == ST::one() ) {
00168
00169 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00170 for( int j = 0; j < subDim; ++j )
00171 (*z0_val++) = (*v0_val++) + (*v1_val++);
00172 }
00173 else {
00174 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00175 (*z0_val) = (*v0_val) + (*v1_val);
00176 }
00177 }
00178 else {
00179
00180 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00181 for( int j = 0; j < subDim; ++j )
00182 (*z0_val++) = (*v0_val++) + alpha1*(*v1_val++);
00183 }
00184 else {
00185 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00186 (*z0_val) = (*v0_val) + alpha1*(*v1_val);
00187 }
00188 }
00189 }
00190 else {
00191 if( alpha1 == ST::one() ) {
00192
00193 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00194 for( int j = 0; j < subDim; ++j )
00195 (*z0_val++) = alpha0*(*v0_val++) + (*v1_val++);
00196 }
00197 else {
00198 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00199 (*z0_val) = alpha0*(*v0_val) + (*v1_val);
00200 }
00201 }
00202 else {
00203
00204 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00205 for( int j = 0; j < subDim; ++j )
00206 (*z0_val++) = alpha0*(*v0_val++) + alpha1*(*v1_val++);
00207 }
00208 else {
00209 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00210 (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val);
00211 }
00212 }
00213 }
00214 }
00215 else if( beta==ST::one() ) {
00216 if( alpha0 == ST::one() ) {
00217 if( alpha1 == ST::one() ) {
00218
00219 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00220 for( int j = 0; j < subDim; ++j, ++z0_val )
00221 (*z0_val) += (*v0_val++) + (*v1_val++);
00222 }
00223 else {
00224 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00225 (*z0_val) += (*v0_val) + (*v1_val);
00226 }
00227 }
00228 else {
00229
00230 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00231 for( int j = 0; j < subDim; ++j, ++z0_val )
00232 (*z0_val) += (*v0_val++) + alpha1*(*v1_val++);
00233 }
00234 else {
00235 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00236 (*z0_val) += (*v0_val) + alpha1*(*v1_val);
00237 }
00238 }
00239 }
00240 else {
00241 if( alpha1 == ST::one() ) {
00242
00243 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00244 for( int j = 0; j < subDim; ++j, ++z0_val )
00245 (*z0_val) += alpha0*(*v0_val++) + (*v1_val++);
00246 }
00247 else {
00248 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00249 (*z0_val) += alpha0*(*v0_val) + (*v1_val);
00250 }
00251 }
00252 else {
00253
00254 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00255 for( int j = 0; j < subDim; ++j, ++z0_val )
00256 (*z0_val) += alpha0*(*v0_val++) + alpha1*(*v1_val++);
00257 }
00258 else {
00259 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00260 (*z0_val) += alpha0*(*v0_val) + alpha1*(*v1_val);
00261 }
00262 }
00263 }
00264 }
00265 else {
00266 if( alpha0 == ST::one() ) {
00267 if( alpha1 == ST::one() ) {
00268
00269 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00270 for( int j = 0; j < subDim; ++j, ++z0_val )
00271 (*z0_val) = (*v0_val++) + (*v1_val++) + beta*(*z0_val);
00272 }
00273 else {
00274 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00275 (*z0_val) = (*v0_val) + (*v1_val) + beta*(*z0_val);
00276 }
00277 }
00278 else {
00279
00280 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00281 for( int j = 0; j < subDim; ++j, ++z0_val )
00282 (*z0_val) = (*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00283 }
00284 else {
00285 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00286 (*z0_val) = (*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00287 }
00288 }
00289 }
00290 else {
00291 if( alpha1 == ST::one() ) {
00292
00293 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00294 for( int j = 0; j < subDim; ++j, ++z0_val )
00295 (*z0_val) = alpha0*(*v0_val++) + (*v1_val++) + beta*(*z0_val);
00296 }
00297 else {
00298 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00299 (*z0_val) = alpha0*(*v0_val) + (*v1_val) + beta*(*z0_val);
00300 }
00301 }
00302 else {
00303
00304 if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00305 for( int j = 0; j < subDim; ++j, ++z0_val )
00306 (*z0_val) = alpha0*(*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00307 }
00308 else {
00309 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00310 (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00311 }
00312 }
00313 }
00314 }
00315 }
00316 else {
00317
00318
00319
00320
00321 if( beta_ == ST::zero() ) {
00322 for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00323 (*z0_val) = ST::zero();
00324 }
00325 else if( beta_ != ST::one() ) {
00326 for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00327 (*z0_val) *= beta_;
00328 }
00329
00330 z0_val = targ_sub_vecs[0].values();
00331 for( int j = 0; j < subDim; ++j, z0_val += z0_s ) {
00332 for( int k = 0; k < num_vecs; ++k ) {
00333 const Scalar
00334 &alpha_k = alpha_[k],
00335 &v_k_val = *v_val[k];
00336 (*z0_val) += alpha_k * v_k_val;
00337 v_val[k] += v_s[k];
00338 }
00339 }
00340 }
00341 }
00343 private:
00344 Scalar beta_;
00345 std::vector<Scalar> alpha_;
00346 };
00347
00348 }
00349
00350 #endif // RTOPPACK_TOP_LINEAR_COMBINATION_HPP