RTOpPack_TOpLinearCombination.hpp

00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //      Thyra: Interfaces and Support Code for the Interoperability of Abstract Numerical Algorithms 
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00030 #define RTOPPACK_TOP_LINEAR_COMBINATION_HPP
00031 
00032 #include "RTOpPack_RTOpT.hpp"
00033 #include "Teuchos_Workspace.hpp"
00034 
00035 namespace RTOpPack {
00036 
00046 template<class Scalar>
00047 class TOpLinearCombination : public RTOpT<Scalar> {
00048 public:
00050   TOpLinearCombination(
00051     const int       num_vecs  = 0
00052     ,const Scalar   alpha[]   = NULL
00053     ,const Scalar   &beta     = Teuchos::ScalarTraits<Scalar>::zero()
00054     )
00055     :RTOpT<Scalar>("TOpLinearCombination")
00056      ,beta_(beta)
00057     { if(num_vecs) this->alpha(num_vecs,alpha); }
00059   void beta( const Scalar& beta ) { beta_ = beta; }
00061   Scalar beta() const { return beta_; }
00063   void alpha( 
00064     const int       num_vecs
00065     ,const Scalar   alpha[]    
00066     )
00067     {
00068       TEST_FOR_EXCEPT( num_vecs<=0 || alpha==NULL );
00069       alpha_.resize(0);
00070       alpha_.insert(alpha_.begin(),alpha,alpha+num_vecs);
00071     }
00073   int num_vecs() const { return alpha_.size(); }
00075   const Scalar* alpha() const { return &alpha_[0]; }
00079   void apply_op(
00080     const int   num_vecs,       const SubVectorT<Scalar>         sub_vecs[]
00081     ,const int  num_targ_vecs,  const MutableSubVectorT<Scalar>  targ_sub_vecs[]
00082     ,ReductTarget *reduct_obj
00083     ) const
00084     {
00085       typedef Teuchos::ScalarTraits<Scalar> ST;
00086       using Teuchos::Workspace;
00087       Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00088       // Validate input
00089 #ifdef _DEBUG
00090       TEST_FOR_EXCEPT( static_cast<int>(alpha_.size()) != num_vecs );
00091       TEST_FOR_EXCEPT( sub_vecs == NULL );
00092       TEST_FOR_EXCEPT( num_targ_vecs != 1 );
00093       TEST_FOR_EXCEPT( targ_sub_vecs == NULL );
00094 #endif
00095       // Get pointers to local data
00096       const RTOpPack::index_type    subDim  = targ_sub_vecs[0].subDim();
00097       Scalar                        *z0_val = targ_sub_vecs[0].values();
00098       const ptrdiff_t               z0_s    = targ_sub_vecs[0].stride();
00099       Workspace<const Scalar*> v_val(wss,num_vecs,false);
00100       Workspace<ptrdiff_t>     v_s(wss,num_vecs,false);
00101       for( int k = 0; k < num_vecs; ++k ) {
00102 #ifdef _DEBUG
00103         TEST_FOR_EXCEPT( sub_vecs[k].subDim() != subDim );
00104         TEST_FOR_EXCEPT( sub_vecs[k].globalOffset() != targ_sub_vecs[0].globalOffset() );
00105 #endif          
00106         v_val[k] = sub_vecs[k].values();
00107         v_s[k]   = sub_vecs[k].stride();
00108       }
00109       //
00110       // Perform the operation and specialize the cases for num_vecs = 1 and 2
00111       // in order to get good performance.
00112       //
00113       if( num_vecs == 1 ) {
00114         //
00115         // z0 = alpha*v0 + beta*z0
00116         //
00117         const Scalar alpha = alpha_[0], beta = beta_;
00118         const Scalar       *v0_val = v_val[0];
00119         const ptrdiff_t    v0_s    = v_s[0]; 
00120         if( beta==ST::zero() ) {
00121           // z0 = alpha*v0
00122           if( z0_s==1 && v0_s==1 ) {
00123             for( int j = 0; j < subDim; ++j )
00124               (*z0_val++) = alpha * (*v0_val++);
00125           }
00126           else {
00127             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00128               (*z0_val) = alpha * (*v0_val);
00129           }
00130         }
00131         else if( beta==ST::one() ) {
00132           //
00133           // z0 = alpha*v0 + z0
00134           //
00135           if( z0_s==1 && v0_s==1 ) {
00136             for( int j = 0; j < subDim; ++j )
00137               (*z0_val++) += alpha * (*v0_val++);
00138           }
00139           else {
00140             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00141               (*z0_val) += alpha * (*v0_val);
00142           }
00143         }
00144         else {
00145           // z0 = alpha*v0 + beta*z0
00146           if( z0_s==1 && v0_s==1 ) {
00147             for( int j = 0; j < subDim; ++j, ++z0_val )
00148               (*z0_val) = alpha * (*v0_val++) + beta*(*z0_val);
00149           }
00150           else {
00151             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00152               (*z0_val) = alpha * (*v0_val) + beta*(*z0_val);
00153           }
00154         }
00155       }
00156       else if( num_vecs == 2 ) {
00157         //
00158         // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00159         //
00160         const Scalar alpha0 = alpha_[0], alpha1=alpha_[1], beta = beta_;
00161         const Scalar     *v0_val = v_val[0];
00162         const ptrdiff_t  v0_s    = v_s[0]; 
00163         const Scalar     *v1_val = v_val[1];
00164         const ptrdiff_t  v1_s    = v_s[1]; 
00165         if( beta==ST::zero() ) {
00166           if( alpha0 == ST::one() ) {
00167             if( alpha1 == ST::one() ) {
00168               // z0 = v0 + v1
00169               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00170                 for( int j = 0; j < subDim; ++j )
00171                   (*z0_val++) = (*v0_val++) + (*v1_val++);
00172               }
00173               else {
00174                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00175                   (*z0_val) = (*v0_val) + (*v1_val);
00176               }
00177             }
00178             else {
00179               // z0 = v0 + alpha1*v1
00180               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00181                 for( int j = 0; j < subDim; ++j )
00182                   (*z0_val++) = (*v0_val++) + alpha1*(*v1_val++);
00183               }
00184               else {
00185                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00186                   (*z0_val) = (*v0_val) + alpha1*(*v1_val);
00187               }
00188             }
00189           }
00190           else {
00191             if( alpha1 == ST::one() ) {
00192               // z0 = alpha0*v0 + v1
00193               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00194                 for( int j = 0; j < subDim; ++j )
00195                   (*z0_val++) = alpha0*(*v0_val++) + (*v1_val++);
00196               }
00197               else {
00198                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00199                   (*z0_val) = alpha0*(*v0_val) + (*v1_val);
00200               }
00201             }
00202             else {
00203               // z0 = alpha0*v0 + alpha1*v1
00204               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00205                 for( int j = 0; j < subDim; ++j )
00206                   (*z0_val++) = alpha0*(*v0_val++) + alpha1*(*v1_val++);
00207               }
00208               else {
00209                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00210                   (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val);
00211               }
00212             }
00213           }
00214         }
00215         else if( beta==ST::one() ) {
00216           if( alpha0 == ST::one() ) {
00217             if( alpha1 == ST::one() ) {
00218               // z0 = v0 + v1 + z0
00219               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00220                 for( int j = 0; j < subDim; ++j, ++z0_val )
00221                   (*z0_val) += (*v0_val++) + (*v1_val++);
00222               }
00223               else {
00224                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00225                   (*z0_val) += (*v0_val) + (*v1_val);
00226               }
00227             }
00228             else {
00229               // z0 = v0 + alpha1*v1 + z0
00230               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00231                 for( int j = 0; j < subDim; ++j, ++z0_val )
00232                   (*z0_val) += (*v0_val++) + alpha1*(*v1_val++);
00233               }
00234               else {
00235                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00236                   (*z0_val) += (*v0_val) + alpha1*(*v1_val);
00237               }
00238             }
00239           }
00240           else {
00241             if( alpha1 == ST::one() ) {
00242               // z0 = alpha0*v0 + v1 + z0
00243               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00244                 for( int j = 0; j < subDim; ++j, ++z0_val )
00245                   (*z0_val) += alpha0*(*v0_val++) + (*v1_val++);
00246               }
00247               else {
00248                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00249                   (*z0_val) += alpha0*(*v0_val) + (*v1_val);
00250               }
00251             }
00252             else {
00253               // z0 = alpha0*v0 + alpha1*v1 + z0
00254               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00255                 for( int j = 0; j < subDim; ++j, ++z0_val )
00256                   (*z0_val) += alpha0*(*v0_val++) + alpha1*(*v1_val++);
00257               }
00258               else {
00259                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00260                   (*z0_val) += alpha0*(*v0_val) + alpha1*(*v1_val);
00261               }
00262             }
00263           }
00264         }
00265         else {
00266           if( alpha0 == ST::one() ) {
00267             if( alpha1 == ST::one() ) {
00268               // z0 = v0 + v1 + beta*z0
00269               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00270                 for( int j = 0; j < subDim; ++j, ++z0_val )
00271                   (*z0_val) = (*v0_val++) + (*v1_val++) + beta*(*z0_val);
00272               }
00273               else {
00274                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00275                   (*z0_val) = (*v0_val) + (*v1_val) + beta*(*z0_val);
00276               }
00277             }
00278             else {
00279               // z0 = v0 + alpha1*v1 + beta*z0
00280               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00281                 for( int j = 0; j < subDim; ++j, ++z0_val )
00282                   (*z0_val) = (*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00283               }
00284               else {
00285                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00286                   (*z0_val) = (*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00287               }
00288             }
00289           }
00290           else {
00291             if( alpha1 == ST::one() ) {
00292               // z0 = alpha0*v0 + v1 + beta*z0
00293               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00294                 for( int j = 0; j < subDim; ++j, ++z0_val )
00295                   (*z0_val) = alpha0*(*v0_val++) + (*v1_val++) + beta*(*z0_val);
00296               }
00297               else {
00298                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00299                   (*z0_val) = alpha0*(*v0_val) + (*v1_val) + beta*(*z0_val);
00300               }
00301             }
00302             else {
00303               // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00304               if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00305                 for( int j = 0; j < subDim; ++j, ++z0_val )
00306                   (*z0_val) = alpha0*(*v0_val++) + alpha1*(*v1_val++) + beta*(*z0_val);
00307               }
00308               else {
00309                 for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00310                   (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val) + beta*(*z0_val);
00311               }
00312             }
00313           }
00314         }
00315       }
00316       else {
00317         //
00318         // Totally general implementation (but least efficient)
00319         //
00320         // z0 *= beta
00321         if( beta_ == ST::zero() ) {
00322           for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00323             (*z0_val) = ST::zero();
00324         }
00325         else if( beta_ != ST::one() ) {
00326           for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00327             (*z0_val) *= beta_;
00328         }
00329         // z0 += sum( alpha[k]*v[k], k=0...num_vecs-1)
00330         z0_val = targ_sub_vecs[0].values();
00331         for( int j = 0; j < subDim; ++j, z0_val += z0_s ) {
00332           for( int k = 0; k < num_vecs; ++k ) {
00333             const Scalar
00334               &alpha_k = alpha_[k],
00335               &v_k_val = *v_val[k];
00336             (*z0_val) += alpha_k * v_k_val;
00337             v_val[k] += v_s[k];
00338           }
00339         }
00340       }
00341     }
00343 private:
00344   Scalar               beta_;
00345   std::vector<Scalar>  alpha_;
00346 }; // class TOpLinearCombination
00347 
00348 } // namespace RTOpPack
00349 
00350 #endif // RTOPPACK_TOP_LINEAR_COMBINATION_HPP

Generated on Thu Sep 18 12:39:44 2008 for RTOp : Vector Reduction/Transformation Operators by doxygen 1.3.9.1