RTOpPack_TOpLinearCombination.hpp

00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 // RTOp: Interfaces and Support Software for Vector Reduction Transformation
00005 //       Operations
00006 //                Copyright (2006) Sandia Corporation
00007 // 
00008 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00009 // license for use of this work by or on behalf of the U.S. Government.
00010 // 
00011 // This library is free software; you can redistribute it and/or modify
00012 // it under the terms of the GNU Lesser General Public License as
00013 // published by the Free Software Foundation; either version 2.1 of the
00014 // License, or (at your option) any later version.
00015 //  
00016 // This library is distributed in the hope that it will be useful, but
00017 // WITHOUT ANY WARRANTY; without even the implied warranty of
00018 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00019 // Lesser General Public License for more details.
00020 //  
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License along with this library; if not, write to the Free Software
00023 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00024 // USA
00025 // Questions? Contact Roscoe A. Bartlett (rabartl@sandia.gov) 
00026 // 
00027 // ***********************************************************************
00028 // @HEADER
00029 
00030 #ifndef RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00031 #define RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00032 
00033 #include "RTOpPack_RTOpT.hpp"
00034 #include "Teuchos_Workspace.hpp"
00035 
00036 namespace RTOpPack {
00037 
00048 template<class Scalar>
00049 class TOpLinearCombination : public RTOpT<Scalar> {
00050 public:
00052   TOpLinearCombination(
00053     const int       num_vecs  = 0
00054     ,const Scalar   alpha[]   = NULL
00055     ,const Scalar   &beta     = Teuchos::ScalarTraits<Scalar>::zero()
00056     )
00057     :RTOpT<Scalar>("TOpLinearCombination")
00058      ,beta_(beta)
00059     { if(num_vecs) this->alpha(num_vecs,alpha); }
00061   void beta( const Scalar& beta ) { beta_ = beta; }
00063   Scalar beta() const { return beta_; }
00065   void alpha( 
00066     const int       num_vecs
00067     ,const Scalar   alpha[]    
00068     )
00069     {
00070       TEST_FOR_EXCEPT( num_vecs<=0 || alpha==NULL );
00071       alpha_.resize(0);
00072       alpha_.insert(alpha_.begin(),alpha,alpha+num_vecs);
00073     }
00075   int num_vecs() const { return alpha_.size(); }
00077   const Scalar* alpha() const { return &alpha_[0]; }
00081   void apply_op(
00082     const int   num_vecs,       const ConstSubVectorView<Scalar>         sub_vecs[]
00083     ,const int  num_targ_vecs,  const SubVectorView<Scalar>  targ_sub_vecs[]
00084     ,ReductTarget *reduct_obj
00085     ) const
00086     {
00087       typedef Teuchos::ScalarTraits<Scalar> ST;
00088       using Teuchos::Workspace;
00089       Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00090       // Validate input
00091 #ifdef TEUCHOS_DEBUG
00092       TEST_FOR_EXCEPT( static_cast<int>(alpha_.size()) != num_vecs );
00093       TEST_FOR_EXCEPT( sub_vecs == NULL );
00094       TEST_FOR_EXCEPT( num_targ_vecs != 1 );
00095       TEST_FOR_EXCEPT( targ_sub_vecs == NULL );
00096 #endif
00097       // Get pointers to local data
00098       const RTOpPack::index_type    subDim  = targ_sub_vecs[0].subDim();
00099       Scalar                        *z0_val = targ_sub_vecs[0].values();
00100       const ptrdiff_t               z0_s    = targ_sub_vecs[0].stride();
00101       Workspace<const Scalar*> v_val(wss,num_vecs,false);
00102       Workspace<ptrdiff_t>     v_s(wss,num_vecs,false);
00103       for( int k = 0; k < num_vecs; ++k ) {
00104 #ifdef TEUCHOS_DEBUG
00105         TEST_FOR_EXCEPT( sub_vecs[k].subDim() != subDim );
00106         TEST_FOR_EXCEPT( sub_vecs[k].globalOffset() != targ_sub_vecs[0].globalOffset() );
00107 #endif          
00108         v_val[k] = sub_vecs[k].values();
00109         v_s[k]   = sub_vecs[k].stride();
00110       }
00111       //
00112       // Perform the operation and specialize the cases for num_vecs = 1 and 2
00113       // in order to get good performance.
00114       //
00115       if( num_vecs == 1 ) {
00116         //
00117         // z0 = alpha*v0 + beta*z0
00118         //
00119         const Scalar alpha = alpha_[0], beta = beta_;
00120         const Scalar       *v0_val = v_val[0];
00121         const ptrdiff_t    v0_s    = v_s[0]; 
00122         if( beta==ST::zero() ) {
00123           // z0 = alpha*v0
00124           if( z0_s==1 && v0_s==1 ) {
00125             for( int j = 0; j < subDim; ++j )
00126               (*z0_val++) = alpha * (*v0_val++);
00127           }
00128           else {
00129             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00130               (*z0_val) = alpha * (*v0_val);
00131           }
00132         }
00133         else if( beta==ST::one() ) {
00134           //
00135           // z0 = alpha*v0 + z0
00136           //
00137           if( z0_s==1 && v0_s==1 ) {
00138             for( int j = 0; j < subDim; ++j )
00139               (*z0_val++) += alpha * (*v0_val++);
00140           }
00141           else {
00142             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00143               (*z0_val) += alpha * (*v0_val);
00144           }
00145         }
00146         else {
00147           // z0 = alpha*v0 + beta*z0
00148           if( z0_s==1 && v0_s==1 ) {
00149             for( int j = 0; j < subDim; ++j, ++z0_val )
00150               (*z0_val) = alpha * (*v0_val++) + beta*(*z0_val);
00151           }
00152           else {
00153             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00154               (*z0_val) = alpha * (*v0_val) + beta*(*z0_val);
00155           }
00156         }
00157       }
00158       else if( num_vecs == 2 ) {
00159         //
00160         // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00161         //
00162         const Scalar alpha0 = alpha_[0], alpha1=alpha_[1], beta = beta_;
00163         const Scalar     *v0_val = v_val[0];
00164         const ptrdiff_t  v0_s    = v_s[0]; 
00165         const Scalar     *v1_val = v_val[1];
00166         const ptrdiff_t  v1_s    = v_s[1]; 
00167         if( beta==ST::zero() ) {
00168           if( alpha0 == ST::one() ) {
00169             if( alpha1 == ST::one() ) {
00170               // z0 = v0 + v1
00171               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00172                 for( int j = 0; j < subDim; ++j )
00173                   (*z0_val++) = (*v0_val++) + (*v1_val++);
00174               }
00175               else {
00176                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00177                   (*z0_val) = (*v0_val) + (*v1_val);
00178               }
00179             }
00180             else {
00181               // z0 = v0 + alpha1*v1
00182               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00183                 for( int j = 0; j < subDim; ++j )
00184                   (*z0_val++) = (*v0_val++) + alpha1*(*v1_val++);
00185               }
00186               else {
00187                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00188                   (*z0_val) = (*v0_val) + alpha1*(*v1_val);
00189               }
00190             }
00191           }
00192           else {
00193             if( alpha1 == ST::one() ) {
00194               // z0 = alpha0*v0 + v1
00195               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00196                 for( int j = 0; j < subDim; ++j )
00197                   (*z0_val++) = alpha0*(*v0_val++) + (*v1_val++);
00198               }
00199               else {
00200                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00201                   (*z0_val) = alpha0*(*v0_val) + (*v1_val);
00202               }
00203             }
00204             else {
00205               // z0 = alpha0*v0 + alpha1*v1
00206               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00207                 for( int j = 0; j < subDim; ++j )
00208                   (*z0_val++) = alpha0*(*v0_val++) + alpha1*(*v1_val++);
00209               }
00210               else {
00211                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00212                   (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val);
00213               }
00214             }
00215           }
00216         }
00217         else if( beta==ST::one() ) {
00218           if( alpha0 == ST::one() ) {
00219             if( alpha1 == ST::one() ) {
00220               // z0 = v0 + v1 + z0
00221               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00222                 for( int j = 0; j < subDim; ++j, ++z0_val )
00223                   (*z0_val) += (*v0_val++) + (*v1_val++);
00224               }
00225               else {
00226                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00227                   (*z0_val) += (*v0_val) + (*v1_val);
00228               }
00229             }
00230             else {
00231               // z0 = v0 + alpha1*v1 + z0
00232               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00233                 for( int j = 0; j < subDim; ++j, ++z0_val )
00234                   (*z0_val) += (*v0_val++) + alpha1*(*v1_val++);
00235               }
00236               else {
00237                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00238                   (*z0_val) += (*v0_val) + alpha1*(*v1_val);
00239               }
00240             }
00241           }
00242           else {
00243             if( alpha1 == ST::one() ) {
00244               // z0 = alpha0*v0 + v1 + z0
00245               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00246                 for( int j = 0; j < subDim; ++j, ++z0_val )
00247                   (*z0_val) += alpha0*(*v0_val++) + (*v1_val++);
00248               }
00249               else {
00250                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00251                   (*z0_val) += alpha0*(*v0_val) + (*v1_val);
00252               }
00253             }
00254             else {
00255               // z0 = alpha0*v0 + alpha1*v1 + z0
00256               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00257                 for( int j = 0; j < subDim; ++j, ++z0_val )
00258                   (*z0_val) += alpha0*(*v0_val++) + alpha1*(*v1_val++);
00259               }
00260               else {
00261                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00262                   (*z0_val) += alpha0*(*v0_val) + alpha1*(*v1_val);
00263               }
00264             }
00265           }
00266         }
00267         else {
00268           if( alpha0 == ST::one() ) {
00269             if( alpha1 == ST::one() ) {
00270               // z0 = v0 + v1 + beta*z0
00271               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00272                 for( int j = 0; j < subDim; ++j, ++z0_val )
00273                   (*z0_val) = (*v0_val++) + (*v1_val++) + beta*(*z0_val);
00274               }
00275               else {
00276                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00277                   (*z0_val) = (*v0_val) + (*v1_val) + beta*(*z0_val);
00278               }
00279             }
00280             else {
00281               // z0 = v0 + alpha1*v1 + beta*z0
00282               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00283                 for( int j = 0; j < subDim; ++j, ++z0_val )
00284                   (*z0_val) = (*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00285               }
00286               else {
00287                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00288                   (*z0_val) = (*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00289               }
00290             }
00291           }
00292           else {
00293             if( alpha1 == ST::one() ) {
00294               // z0 = alpha0*v0 + v1 + beta*z0
00295               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00296                 for( int j = 0; j < subDim; ++j, ++z0_val )
00297                   (*z0_val) = alpha0*(*v0_val++) + (*v1_val++) + beta*(*z0_val);
00298               }
00299               else {
00300                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00301                   (*z0_val) = alpha0*(*v0_val) + (*v1_val) + beta*(*z0_val);
00302               }
00303             }
00304             else {
00305               // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00306               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00307                 for( int j = 0; j < subDim; ++j, ++z0_val )
00308                   (*z0_val) = alpha0*(*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00309               }
00310               else {
00311                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00312                   (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00313               }
00314             }
00315           }
00316         }
00317       }
00318       else {
00319         //
00320         // Totally general implementation (but least efficient)
00321         //
00322         // z0 *= beta
00323         if( beta_ == ST::zero() ) {
00324           for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00325             (*z0_val) = ST::zero();
00326         }
00327         else if( beta_ != ST::one() ) {
00328           for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00329             (*z0_val) *= beta_;
00330         }
00331         // z0 += sum( alpha[k]*v[k], k=0...num_vecs-1)
00332         z0_val = targ_sub_vecs[0].values();
00333         for( int j = 0; j < subDim; ++j, z0_val += z0_s ) {
00334           for( int k = 0; k < num_vecs; ++k ) {
00335             const Scalar
00336               &alpha_k = alpha_[k],
00337               &v_k_val = *v_val[k];
00338             (*z0_val) += alpha_k * v_k_val;
00339             v_val[k] += v_s[k];
00340           }
00341         }
00342       }
00343     }
00345 private:
00346   Scalar                  beta_;
00347   std::vector<Scalar>  alpha_;
00348 }; // class TOpLinearCombination
00349 
00350 } // namespace RTOpPack
00351 
00352 #endif // RTOPPACK_TOP_LINEAR_COMBINATION_HPP

Generated on Tue Oct 20 12:46:13 2009 for Collection of Concrete Vector Reduction/Transformation Operator Implementations by doxygen 1.4.7