RTOp Package Browser (Single Doxygen Collection) Version of the Day
RTOpPack_TOpLinearCombination_def.hpp
Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 // RTOp: Interfaces and Support Software for Vector Reduction Transformation
00005 //       Operations
00006 //                Copyright (2006) Sandia Corporation
00007 // 
00008 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00009 // license for use of this work by or on behalf of the U.S. Government.
00010 // 
00011 // Redistribution and use in source and binary forms, with or without
00012 // modification, are permitted provided that the following conditions are
00013 // met:
00014 //
00015 // 1. Redistributions of source code must retain the above copyright
00016 // notice, this list of conditions and the following disclaimer.
00017 //
00018 // 2. Redistributions in binary form must reproduce the above copyright
00019 // notice, this list of conditions and the following disclaimer in the
00020 // documentation and/or other materials provided with the distribution.
00021 //
00022 // 3. Neither the name of the Corporation nor the names of the
00023 // contributors may be used to endorse or promote products derived from
00024 // this software without specific prior written permission.
00025 //
00026 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00027 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00028 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00029 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00030 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00031 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00032 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00033 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00034 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00035 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00036 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00037 //
00038 // Questions? Contact Roscoe A. Bartlett (rabartl@sandia.gov) 
00039 // 
00040 // ***********************************************************************
00041 // @HEADER
00042 
00043 #ifndef RTOPPACK_TOP_LINEAR_COMBINATION_DEF_HPP
00044 #define RTOPPACK_TOP_LINEAR_COMBINATION_DEF_HPP
00045 
00046 
00047 #include "Teuchos_Workspace.hpp"
00048 
00049 
00050 namespace RTOpPack {
00051 
00052 
00053 // Constructor: stores beta_in in beta_ and, when alpha_in is non-empty,
00054 // forwards it to the alpha() setter (which treats an empty view as an
00055 // error, hence the guard).  Finally registers the operator's base name.
00056 template<class Scalar>
00057 TOpLinearCombination<Scalar>::TOpLinearCombination(
00058   const ArrayView<const Scalar> &alpha_in, const Scalar &beta_in )
00059   : beta_(beta_in)
00060 {
00061   if (alpha_in.size() != 0) { this->alpha(alpha_in); }
00062   this->setOpNameBase("TOpLinearCombination");
00063 }
00064 
00065 // Setter for the source-vector coefficients: assigns alpha_in to alpha_.
00066 // alpha_in must hold at least one entry.
00067 template<class Scalar>
00068 void TOpLinearCombination<Scalar>::alpha( const ArrayView<const Scalar> &alpha_in )
00069 {
00070   // Reject an empty coefficient list before the stored one is touched.
00071   TEUCHOS_TEST_FOR_EXCEPT( alpha_in.size() == 0 );
00072   this->alpha_ = alpha_in;
00073 }
00074 
00075 // Getter for the source-vector coefficients.
00076 // Returns the alpha_ member as a view of const Scalar.
00077 template<class Scalar>
00078 const ArrayView<const Scalar> TOpLinearCombination<Scalar>::alpha() const
00079 { return this->alpha_; }
00080 
00081 // Setter for beta_, the factor that multiplies the target vector z0 in apply_op_impl().
00082 template<class Scalar>
00083 void TOpLinearCombination<Scalar>::beta( const Scalar& beta_in ) { this->beta_ = beta_in; }
00084 
00085 // Getter for beta_; the value is returned by copy.
00086 template<class Scalar>
00087 Scalar TOpLinearCombination<Scalar>::beta() const { return this->beta_; }
00088 
00089 // Number of source vectors, i.e. the number of coefficients held in alpha_ (narrowed to int).
00090 template<class Scalar>
00091 int TOpLinearCombination<Scalar>::num_vecs() const { return static_cast<int>(alpha_.size()); }
00092 
00093 
00094 // Overridden from RTOpT
00095 // apply_op_impl: element-wise z0 = sum_k( alpha_[k]*v[k] ) + beta_*z0, with v[k] = sub_vecs[k]
00096 // and z0 = targ_sub_vecs[0]; reduct_obj_inout is only handed to the debug-mode validation.
00097 template<class Scalar>
00098 void TOpLinearCombination<Scalar>::apply_op_impl(
00099   const ArrayView<const ConstSubVectorView<Scalar> > &sub_vecs,
00100   const ArrayView<const SubVectorView<Scalar> > &targ_sub_vecs,
00101   const Ptr<ReductTarget> &reduct_obj_inout
00102   ) const
00103 {
00104   typedef RTOpPack::index_type idx_t; // element-counter type; same as subDim so long sub-vectors cannot overflow it
00105   using Teuchos::as;
00106   using Teuchos::Workspace;
00107   typedef Teuchos::ScalarTraits<Scalar> ST;
00108   typedef typename Teuchos::ArrayRCP<Scalar>::iterator iter_t;
00109   typedef typename Teuchos::ArrayRCP<const Scalar>::iterator const_iter_t;
00110   Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00111 
00112 #ifdef TEUCHOS_DEBUG
00113   validate_apply_op<Scalar>(*this, as<int>(alpha_.size()), 1, false,
00114     sub_vecs, targ_sub_vecs, reduct_obj_inout.getConst());
00115 #endif
00116   // One source sub-vector is expected per coefficient in alpha_.
00117   const int l_num_vecs = as<int>(alpha_.size());
00118 
00119   // Gather the element iterator and stride of the target and of every source sub-vector
00120   const idx_t subDim = targ_sub_vecs[0].subDim();
00121   iter_t z0_val = targ_sub_vecs[0].values().begin();
00122   const ptrdiff_t z0_s = targ_sub_vecs[0].stride();
00123   Workspace<const_iter_t> v_val(wss,l_num_vecs);
00124   Workspace<ptrdiff_t> v_s(wss,l_num_vecs,false);
00125   for( int k = 0; k < l_num_vecs; ++k ) {
00126 #ifdef TEUCHOS_DEBUG
00127     TEUCHOS_TEST_FOR_EXCEPT( sub_vecs[k].subDim() != subDim );
00128     TEUCHOS_TEST_FOR_EXCEPT( sub_vecs[k].globalOffset() != targ_sub_vecs[0].globalOffset() );
00129 #endif
00130     v_val[k] = sub_vecs[k].values().begin();
00131     v_s[k] = sub_vecs[k].stride();
00132   }
00133 
00134   //
00135   // Perform the operation.  l_num_vecs = 1 and 2 get hand-specialized loops
00136   // (per beta/alpha special value and unit vs. general stride) for performance.
00137   //
00138   if( l_num_vecs == 1 ) {
00139     //
00140     // z0 = alpha*v0 + beta*z0
00141     //
00142     const Scalar l_alpha = alpha_[0], l_beta = beta_;
00143     const_iter_t v0_val = v_val[0];
00144     const ptrdiff_t v0_s = v_s[0];
00145     if( l_beta==ST::zero() ) {
00146       // z0 = alpha*v0   (z0 is overwritten; its previous contents are never read)
00147       if( z0_s==1 && v0_s==1 ) {
00148         for( idx_t j = 0; j < subDim; ++j )
00149           (*z0_val++) = l_alpha * (*v0_val++);
00150       }
00151       else {
00152         for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00153           (*z0_val) = l_alpha * (*v0_val);
00154       }
00155     }
00156     else if( l_beta==ST::one() ) {
00157       //
00158       // z0 = alpha*v0 + z0
00159       //
00160       if( z0_s==1 && v0_s==1 ) {
00161         for( idx_t j = 0; j < subDim; ++j )
00162           (*z0_val++) += l_alpha * (*v0_val++);
00163       }
00164       else {
00165         for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00166           (*z0_val) += l_alpha * (*v0_val);
00167       }
00168     }
00169     else {
00170       // z0 = alpha*v0 + beta*z0
00171       if( z0_s==1 && v0_s==1 ) {
00172         for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00173           (*z0_val) = l_alpha * (*v0_val++) + l_beta*(*z0_val);
00174       }
00175       else {
00176         for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s )
00177           (*z0_val) = l_alpha * (*v0_val) + l_beta*(*z0_val);
00178       }
00179     }
00180   }
00181   else if( l_num_vecs == 2 ) {
00182     //
00183     // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00184     //
00185     const Scalar alpha0 = alpha_[0], alpha1=alpha_[1], l_beta = beta_;
00186     const_iter_t v0_val = v_val[0];
00187     const ptrdiff_t v0_s = v_s[0];
00188     const_iter_t v1_val = v_val[1];
00189     const ptrdiff_t v1_s = v_s[1];
00190     if( l_beta==ST::zero() ) {
00191       if( alpha0 == ST::one() ) {
00192         if( alpha1 == ST::one() ) {
00193           // z0 = v0 + v1
00194           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00195             for( idx_t j = 0; j < subDim; ++j )
00196               (*z0_val++) = (*v0_val++) + (*v1_val++);
00197           }
00198           else {
00199             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00200               (*z0_val) = (*v0_val) + (*v1_val);
00201           }
00202         }
00203         else {
00204           // z0 = v0 + alpha1*v1
00205           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00206             for( idx_t j = 0; j < subDim; ++j )
00207               (*z0_val++) = (*v0_val++) + alpha1*(*v1_val++);
00208           }
00209           else {
00210             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00211               (*z0_val) = (*v0_val) + alpha1*(*v1_val);
00212           }
00213         }
00214       }
00215       else {
00216         if( alpha1 == ST::one() ) {
00217           // z0 = alpha0*v0 + v1
00218           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00219             for( idx_t j = 0; j < subDim; ++j )
00220               (*z0_val++) = alpha0*(*v0_val++) + (*v1_val++);
00221           }
00222           else {
00223             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00224               (*z0_val) = alpha0*(*v0_val) + (*v1_val);
00225           }
00226         }
00227         else {
00228           // z0 = alpha0*v0 + alpha1*v1
00229           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00230             for( idx_t j = 0; j < subDim; ++j )
00231               (*z0_val++) = alpha0*(*v0_val++) + alpha1*(*v1_val++);
00232           }
00233           else {
00234             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00235               (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val);
00236           }
00237         }
00238       }
00239     }
00240     else if( l_beta==ST::one() ) {
00241       if( alpha0 == ST::one() ) {
00242         if( alpha1 == ST::one() ) {
00243           // z0 = v0 + v1 + z0
00244           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00245             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00246               (*z0_val) += (*v0_val++) + (*v1_val++);
00247           }
00248           else {
00249             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00250               (*z0_val) += (*v0_val) + (*v1_val);
00251           }
00252         }
00253         else {
00254           // z0 = v0 + alpha1*v1 + z0
00255           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00256             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00257               (*z0_val) += (*v0_val++) + alpha1*(*v1_val++);
00258           }
00259           else {
00260             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00261               (*z0_val) += (*v0_val) + alpha1*(*v1_val);
00262           }
00263         }
00264       }
00265       else {
00266         if( alpha1 == ST::one() ) {
00267           // z0 = alpha0*v0 + v1 + z0
00268           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00269             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00270               (*z0_val) += alpha0*(*v0_val++) + (*v1_val++);
00271           }
00272           else {
00273             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00274               (*z0_val) += alpha0*(*v0_val) + (*v1_val);
00275           }
00276         }
00277         else {
00278           // z0 = alpha0*v0 + alpha1*v1 + z0
00279           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00280             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00281               (*z0_val) += alpha0*(*v0_val++) + alpha1*(*v1_val++);
00282           }
00283           else {
00284             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00285               (*z0_val) += alpha0*(*v0_val) + alpha1*(*v1_val);
00286           }
00287         }
00288       }
00289     }
00290     else {
00291       if( alpha0 == ST::one() ) {
00292         if( alpha1 == ST::one() ) {
00293           // z0 = v0 + v1 + beta*z0
00294           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00295             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00296               (*z0_val) = (*v0_val++) + (*v1_val++) + l_beta*(*z0_val);
00297           }
00298           else {
00299             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00300               (*z0_val) = (*v0_val) + (*v1_val) + l_beta*(*z0_val);
00301           }
00302         }
00303         else {
00304           // z0 = v0 + alpha1*v1 + beta*z0
00305           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00306             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00307               (*z0_val) = (*v0_val++) + alpha1*(*v1_val++) + l_beta*(*z0_val);
00308           }
00309           else {
00310             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00311               (*z0_val) = (*v0_val) + alpha1*(*v1_val) + l_beta*(*z0_val);
00312           }
00313         }
00314       }
00315       else {
00316         if( alpha1 == ST::one() ) {
00317           // z0 = alpha0*v0 + v1 + beta*z0
00318           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00319             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00320               (*z0_val) = alpha0*(*v0_val++) + (*v1_val++) + l_beta*(*z0_val);
00321           }
00322           else {
00323             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00324               (*z0_val) = alpha0*(*v0_val) + (*v1_val) + l_beta*(*z0_val);
00325           }
00326         }
00327         else {
00328           // z0 = alpha0*v0 + alpha1*v1 + beta*z0
00329           if( z0_s==1 && v0_s==1 && v1_s==1 ) {
00330             for( idx_t j = 0; j < subDim; ++j, ++z0_val )
00331               (*z0_val) = alpha0*(*v0_val++) + alpha1*(*v1_val++) + l_beta*(*z0_val);
00332           }
00333           else {
00334             for( idx_t j = 0; j < subDim; ++j, z0_val+=z0_s, v0_val+=v0_s, v1_val+=v1_s )
00335               (*z0_val) = alpha0*(*v0_val) + alpha1*(*v1_val) + l_beta*(*z0_val);
00336           }
00337         }
00338       }
00339     }
00340   }
00341   else {
00342     //
00343     // Totally general implementation (but least efficient); used for every l_num_vecs other than 1 and 2
00344     //
00345     // z0 *= beta   (beta == 0 assigns zero instead of multiplying; beta == 1 is a no-op)
00346     if( beta_ == ST::zero() ) {
00347       for( idx_t j = 0; j < subDim; ++j, z0_val += z0_s )
00348         (*z0_val) = ST::zero();
00349     }
00350     else if( beta_ != ST::one() ) {
00351       for( idx_t j = 0; j < subDim; ++j, z0_val += z0_s )
00352         (*z0_val) *= beta_;
00353     }
00354     // z0 += sum( alpha[k]*v[k], k=0...l_num_vecs-1); rewind z0 first, the scaling pass may have advanced it
00355     z0_val = targ_sub_vecs[0].values().begin();
00356     for( idx_t j = 0; j < subDim; ++j, z0_val += z0_s ) {
00357       for( int k = 0; k < l_num_vecs; ++k ) {
00358         const Scalar
00359           &alpha_k = alpha_[k],
00360           &v_k_val = *v_val[k];
00361         (*z0_val) += alpha_k * v_k_val;
00362         v_val[k] += v_s[k];
00363       }
00364     }
00365   }
00366 }
00367 
00368 
00369 } // namespace RTOpPack
00370 
00371 
00372 #endif // RTOPPACK_TOP_LINEAR_COMBINATION_DEF_HPP
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines