RTOp Package Browser (Single Doxygen Collection) Version of the Day
RTOpPack_TOpLinearCombination_def.hpp
Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 // RTOp: Interfaces and Support Software for Vector Reduction Transformation
00005 //       Operations
00006 //                Copyright (2006) Sandia Corporation
00007 // 
00008 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00009 // license for use of this work by or on behalf of the U.S. Government.
00010 // 
00011 // This library is free software; you can redistribute it and/or modify
00012 // it under the terms of the GNU Lesser General Public License as
00013 // published by the Free Software Foundation; either version 2.1 of the
00014 // License, or (at your option) any later version.
00015 //  
00016 // This library is distributed in the hope that it will be useful, but
00017 // WITHOUT ANY WARRANTY; without even the implied warranty of
00018 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00019 // Lesser General Public License for more details.
00020 //  
00021 // You should have received a copy of the GNU Lesser General Public
00022 // License along with this library; if not, write to the Free Software
00023 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00024 // USA
00025 // Questions? Contact Roscoe A. Bartlett (rabartl@sandia.gov) 
00026 // 
00027 // ***********************************************************************
00028 // @HEADER
00029 
00030 #ifndef RTOPPACK_TOP_LINEAR_COMBINATION_DEF_HPP
00031 #define RTOPPACK_TOP_LINEAR_COMBINATION_DEF_HPP
00032 
00033 
00034 #include "Teuchos_Workspace.hpp"
00035 
00036 
00037 namespace RTOpPack {
00038 
00039 
00040 template<class Scalar>
00041 TOpLinearCombination<Scalar>::TOpLinearCombination(
00042   const ArrayView<const Scalar> &alpha_in,
00043   const Scalar &beta_in
00044   )
00045   :beta_(beta_in)
00046 {
00047   if (alpha_in.size())
00048     this->alpha(alpha_in);
00049   this->setOpNameBase("TOpLinearCombination");
00050 }
00051 
00052 
00053 
00054 template<class Scalar>
00055 void TOpLinearCombination<Scalar>::alpha(
00056   const ArrayView<const Scalar> &alpha_in )
00057 {
00058   TEST_FOR_EXCEPT( alpha_in.size() == 0 );
00059   alpha_ = alpha_in;
00060 }
00061 
00062 
00063 template<class Scalar>
00064 const ArrayView<const Scalar>
00065 TOpLinearCombination<Scalar>::alpha() const
00066 { return alpha_; }
00067 
00068 
00069 template<class Scalar>
00070 void TOpLinearCombination<Scalar>::beta( const Scalar& beta_in ) { beta_ = beta_in; }
00071 
00072 
00073 template<class Scalar>
00074 Scalar TOpLinearCombination<Scalar>::beta() const { return beta_; }
00075 
00076 
00077 template<class Scalar>
00078 int TOpLinearCombination<Scalar>::num_vecs() const { return alpha_.size(); }
00079 
00080 
00081 // Overridden from RTOpT
00082 
00083 
00084 template<class Scalar>
00085 void TOpLinearCombination<Scalar>::apply_op_impl(
00086   const ArrayView<const ConstSubVectorView<Scalar> > &sub_vecs,
00087   const ArrayView<const SubVectorView<Scalar> > &targ_sub_vecs,
00088   const Ptr<ReductTarget> &reduct_obj_inout
00089   ) const
00090 {
00091 
00092   using Teuchos::as;
00093   using Teuchos::Workspace;
00094   typedef Teuchos::ScalarTraits<Scalar> ST;
00095   typedef typename Teuchos::ArrayRCP<Scalar>::iterator iter_t;
00096   typedef typename Teuchos::ArrayRCP<const Scalar>::iterator const_iter_t;
00097   Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00098 
00099 #ifdef TEUCHOS_DEBUG
00100   validate_apply_op<Scalar>(*this, as<int>(alpha_.size()), 1, false,
00101     sub_vecs, targ_sub_vecs, reduct_obj_inout.getConst());
00102 #endif
00103 
00104   const int l_num_vecs = alpha_.size();
00105 
00106   // Get iterators to local data
00107   const RTOpPack::index_type subDim = targ_sub_vecs[0].subDim();
00108   iter_t z0_val = targ_sub_vecs[0].values().begin();
00109   const ptrdiff_t z0_s = targ_sub_vecs[0].stride();
00110   Workspace<const_iter_t> v_val(wss,l_num_vecs);
00111   Workspace<ptrdiff_t> v_s(wss,l_num_vecs,false);
00112   for( int k = 0; k < l_num_vecs; ++k ) {
00113 #ifdef TEUCHOS_DEBUG
00114     TEST_FOR_EXCEPT( sub_vecs[k].subDim() != subDim );
00115     TEST_FOR_EXCEPT( sub_vecs[k].globalOffset() != targ_sub_vecs[0].globalOffset() );
00116 #endif          
00117     v_val[k] = sub_vecs[k].values().begin();
00118     v_s[k] = sub_vecs[k].stride();
00119   }
00120 
00121   //
00122   // Perform the operation and specialize the cases for l_num_vecs = 1 and 2
00123   // in order to get good performance.
00124   //
00125   if( l_num_vecs == 1 ) {
00126     //
00127     // z0 = alpha*v0 + beta*z0
00128     //
00129     const Scalar l_alpha = alpha_[0], l_beta = beta_;
00130     const_iter_t v0_val = v_val[0];
00131     const ptrdiff_t v0_s = v_s[0]; 
00132     if( l_beta==ST::zero() ) {
00133       // z0 = alpha*v0
00134       if( z0_s==1 && v0_s==1 ) {
00135         for( int j = 0; j < subDim; ++j )
00136           (*z0_val++) = l_alpha * (*v0_val++);
00137       }
00138       else {
00139         for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00140           (*z0_val) = l_alpha * (*v0_val);
00141       }
00142     }
00143     else if( l_beta==ST::one() ) {
00144       //
00145       // z0 = alpha*v0 + z0
00146       //
00147       if( z0_s==1 && v0_s==1 ) {
00148         for( int j = 0; j < subDim; ++j )
00149           (*z0_val++) += l_alpha * (*v0_val++);
00150       }
00151       else {
00152         for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00153           (*z0_val) += l_alpha * (*v0_val);
00154       }
00155     }
00156     else {
00157       // z0 = alpha*v0 + beta*z0
00158       if( z0_s==1 && v0_s==1 ) {
00159         for( int j = 0; j < subDim; ++j, ++z0_val )
00160           (*z0_val) = l_alpha * (*v0_val++) + l_beta*(*z0_val);
00161       }
00162       else {
00163         for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00164           (*z0_val) = l_alpha * (*v0_val) + l_beta*(*z0_val);
00165       }
00166     }
00167   }
00168   else if( l_num_vecs == 2 ) {
00169     //
00170     // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00171     //
00172     const Scalar alpha0 = alpha_[0], alpha1=alpha_[1], l_beta = beta_;
00173     const_iter_t v0_val = v_val[0];
00174     const ptrdiff_t v0_s = v_s[0]; 
00175     const_iter_t v1_val = v_val[1];
00176     const ptrdiff_t v1_s = v_s[1]; 
00177     if( l_beta==ST::zero() ) {
00178       if( alpha0 == ST::one() ) {
00179         if( alpha1 == ST::one() ) {
00180           // z0 = v0 + v1
00181           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00182             for( int j = 0; j < subDim; ++j )
00183               (*z0_val++) = (*v0_val++) + (*v1_val++);
00184           }
00185           else {
00186             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00187               (*z0_val) = (*v0_val) + (*v1_val);
00188           }
00189         }
00190         else {
00191           // z0 = v0 + alpha1*v1
00192           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00193             for( int j = 0; j < subDim; ++j )
00194               (*z0_val++) = (*v0_val++) + alpha1*(*v1_val++);
00195           }
00196           else {
00197             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00198               (*z0_val) = (*v0_val) + alpha1*(*v1_val);
00199           }
00200         }
00201       }
00202       else {
00203         if( alpha1 == ST::one() ) {
00204           // z0 = alpha0*v0 + v1
00205           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00206             for( int j = 0; j < subDim; ++j )
00207               (*z0_val++) = alpha0*(*v0_val++) + (*v1_val++);
00208           }
00209           else {
00210             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00211               (*z0_val) = alpha0*(*v0_val) + (*v1_val);
00212           }
00213         }
00214         else {
00215           // z0 = alpha0*v0 + alpha1*v1
00216           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00217             for( int j = 0; j < subDim; ++j )
00218               (*z0_val++) = alpha0*(*v0_val++) + alpha1*(*v1_val++);
00219           }
00220           else {
00221             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00222               (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val);
00223           }
00224         }
00225       }
00226     }
00227     else if( l_beta==ST::one() ) {
00228       if( alpha0 == ST::one() ) {
00229         if( alpha1 == ST::one() ) {
00230           // z0 = v0 + v1 + z0
00231           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00232             for( int j = 0; j < subDim; ++j, ++z0_val )
00233               (*z0_val) += (*v0_val++) + (*v1_val++);
00234           }
00235           else {
00236             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00237               (*z0_val) += (*v0_val) + (*v1_val);
00238           }
00239         }
00240         else {
00241           // z0 = v0 + alpha1*v1 + z0
00242           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00243             for( int j = 0; j < subDim; ++j, ++z0_val )
00244               (*z0_val) += (*v0_val++) + alpha1*(*v1_val++);
00245           }
00246           else {
00247             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00248               (*z0_val) += (*v0_val) + alpha1*(*v1_val);
00249           }
00250         }
00251       }
00252       else {
00253         if( alpha1 == ST::one() ) {
00254           // z0 = alpha0*v0 + v1 + z0
00255           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00256             for( int j = 0; j < subDim; ++j, ++z0_val )
00257               (*z0_val) += alpha0*(*v0_val++) + (*v1_val++);
00258           }
00259           else {
00260             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00261               (*z0_val) += alpha0*(*v0_val) + (*v1_val);
00262           }
00263         }
00264         else {
00265           // z0 = alpha0*v0 + alpha1*v1 + z0
00266           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00267             for( int j = 0; j < subDim; ++j, ++z0_val )
00268               (*z0_val) += alpha0*(*v0_val++) + alpha1*(*v1_val++);
00269           }
00270           else {
00271             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00272               (*z0_val) += alpha0*(*v0_val) + alpha1*(*v1_val);
00273           }
00274         }
00275       }
00276     }
00277     else {
00278       if( alpha0 == ST::one() ) {
00279         if( alpha1 == ST::one() ) {
00280           // z0 = v0 + v1 + beta*z0
00281           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00282             for( int j = 0; j < subDim; ++j, ++z0_val )
00283               (*z0_val) = (*v0_val++) + (*v1_val++) + l_beta*(*z0_val);
00284           }
00285           else {
00286             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00287               (*z0_val) = (*v0_val) + (*v1_val) + l_beta*(*z0_val);
00288           }
00289         }
00290         else {
00291           // z0 = v0 + alpha1*v1 + beta*z0
00292           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00293             for( int j = 0; j < subDim; ++j, ++z0_val )
00294               (*z0_val) = (*v0_val++) + alpha1*(*v1_val++) + l_beta*(*z0_val);
00295           }
00296           else {
00297             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00298               (*z0_val) = (*v0_val) + alpha1*(*v1_val) + l_beta*(*z0_val);
00299           }
00300         }
00301       }
00302       else {
00303         if( alpha1 == ST::one() ) {
00304           // z0 = alpha0*v0 + v1 + beta*z0
00305           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00306             for( int j = 0; j < subDim; ++j, ++z0_val )
00307               (*z0_val) = alpha0*(*v0_val++) + (*v1_val++) + l_beta*(*z0_val);
00308           }
00309           else {
00310             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00311               (*z0_val) = alpha0*(*v0_val) + (*v1_val) + l_beta*(*z0_val);
00312           }
00313         }
00314         else {
00315           // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00316           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00317             for( int j = 0; j < subDim; ++j, ++z0_val )
00318               (*z0_val) = alpha0*(*v0_val++) + alpha1*(*v1_val++) + l_beta*(*z0_val);
00319           }
00320           else {
00321             for( int j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00322               (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val) + l_beta*(*z0_val);
00323           }
00324         }
00325       }
00326     }
00327   }
00328   else {
00329     //
00330     // Totally general implementation (but least efficient)
00331     //
00332     // z0 *= beta
00333     if( beta_ == ST::zero() ) {
00334       for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00335         (*z0_val) = ST::zero();
00336     }
00337     else if( beta_ != ST::one() ) {
00338       for( int j = 0; j < subDim; ++j, z0_val += z0_s )
00339         (*z0_val) *= beta_;
00340     }
00341     // z0 += sum( alpha[k]*v[k], k=0...l_num_vecs-1)
00342     z0_val = targ_sub_vecs[0].values().begin();
00343     for( int j = 0; j < subDim; ++j, z0_val += z0_s ) {
00344       for( int k = 0; k < l_num_vecs; ++k ) {
00345         const Scalar
00346           &alpha_k = alpha_[k],
00347           &v_k_val = *v_val[k];
00348         (*z0_val) += alpha_k * v_k_val;
00349         v_val[k] += v_s[k];
00350       }
00351     }
00352   }
00353 }
00354 
00355 
00356 } // namespace RTOpPack
00357 
00358 
00359 #endif // RTOPPACK_TOP_LINEAR_COMBINATION_DEF_HPP
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines