Thyra_SerialMultiVectorBase.hpp

00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //    Thyra: Interfaces and Support for Abstract Numerical Algorithms
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef THYRA_SERIAL_MULTI_VECTOR_BASE_HPP
00030 #define THYRA_SERIAL_MULTI_VECTOR_BASE_HPP
00031 
00032 #include "Thyra_SerialMultiVectorBaseDecl.hpp"
00033 #include "Thyra_MultiVectorDefaultBase.hpp"
00034 #include "Thyra_SingleScalarEuclideanLinearOpBase.hpp"
00035 #include "Thyra_SerialVectorSpaceBase.hpp"
00036 #include "Thyra_ExplicitMultiVectorView.hpp"
00037 #include "Thyra_apply_op_helper.hpp"
00038 #include "RTOp_parallel_helpers.h"
00039 #include "Teuchos_Workspace.hpp"
00040 #include "Teuchos_dyn_cast.hpp"
00041 #include "Teuchos_Time.hpp"
00042 
00043 // Define to see some timing output!
00044 //#define THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00045 
00046 namespace Thyra {
00047 
00048 template<class Scalar>
00049 SerialMultiVectorBase<Scalar>::SerialMultiVectorBase()
00050   :in_applyOp_(false)
00051   ,numRows_(0)
00052   ,numCols_(0)
00053 {}
00054 
00055 // Overridden from LinearOpBase
00056 
00057 /*
00058 
00059 template<class Scalar>
00060 void SerialMultiVectorBase<Scalar>::apply(
00061   const ETransp                     M_trans
00062   ,const MultiVectorBase<Scalar>    &X
00063   ,MultiVectorBase<Scalar>          *Y
00064   ,const Scalar                     alpha
00065   ,const Scalar                     beta
00066   ) const
00067 {
00068   this->single_scalar_euclidean_apply_impl(M_trans,X,Y,alpha,beta);
00069 }
00070 
00071 */
00072 
00073 // Overridden from MultiVectorBase
00074 
00075 template<class Scalar>
00076 void SerialMultiVectorBase<Scalar>::applyOp(
00077   const RTOpPack::RTOpT<Scalar>   &pri_op
00078   ,const int                      num_multi_vecs
00079   ,const MultiVectorBase<Scalar>* multi_vecs[]
00080   ,const int                      num_targ_multi_vecs
00081   ,MultiVectorBase<Scalar>*       targ_multi_vecs[]
00082   ,RTOpPack::ReductTarget*        reduct_objs[]
00083   ,const Index                    pri_first_ele_in
00084   ,const Index                    pri_sub_dim_in
00085   ,const Index                    pri_global_offset_in
00086   ,const Index                    sec_first_ele_in
00087   ,const Index                    sec_sub_dim_in
00088   ) const
00089 {
00090 #ifdef _DEBUG
00091   // ToDo: Validate input!
00092   TEST_FOR_EXCEPTION(
00093     in_applyOp_, std::invalid_argument
00094     ,"SerialMultiVectorBase<>::applyOp(...): Error, this method is being entered recursively which is a "
00095     "clear sign that one of the methods getSubMultiVector(...), freeSubMultiVector(...) or commitSubMultiVector(...) "
00096     "was not implemented properly!"
00097     );
00098   apply_op_validate_input(
00099     "SerialMultiVectorBase<Scalar>::applyOp(...)", *this->domain(), *this->range()
00100     ,pri_op,num_multi_vecs,multi_vecs,num_targ_multi_vecs,targ_multi_vecs
00101     ,reduct_objs,pri_first_ele_in,pri_sub_dim_in,pri_global_offset_in
00102     ,sec_first_ele_in,sec_sub_dim_in
00103     );
00104 #endif
00105   in_applyOp_ = true;
00106   apply_op_serial(
00107     *(this->domain()),*(this->range())
00108     ,pri_op,num_multi_vecs,multi_vecs,num_targ_multi_vecs,targ_multi_vecs
00109     ,reduct_objs,pri_first_ele_in,pri_sub_dim_in,pri_global_offset_in
00110     ,sec_first_ele_in,sec_sub_dim_in
00111     );
00112   in_applyOp_ = false;
00113 }
00114 
00115 template<class Scalar>
00116 void SerialMultiVectorBase<Scalar>::getSubMultiVector(
00117   const Range1D                       &rowRng_in
00118   ,const Range1D                      &colRng_in
00119   ,RTOpPack::SubMultiVectorT<Scalar>  *sub_mv
00120   ) const
00121 {
00122   const Range1D rowRng = validateRowRange(rowRng_in);
00123   const Range1D colRng = validateColRange(colRng_in);
00124   const Scalar *localValues = NULL; int leadingDim = 0;
00125   this->getData(&localValues,&leadingDim);
00126   sub_mv->initialize(
00127     rowRng.lbound()-1                             // globalOffset
00128     ,rowRng.size()                                // subDim
00129     ,colRng.lbound()-1                            // colOffset
00130     ,colRng.size()                                // numSubCols
00131     ,localValues
00132     +(rowRng.lbound()-1)
00133     +(colRng.lbound()-1)*leadingDim               // values
00134     ,leadingDim                                   // leadingDim
00135     );
00136 }
00137 
00138 template<class Scalar>
00139 void SerialMultiVectorBase<Scalar>::freeSubMultiVector(
00140   RTOpPack::SubMultiVectorT<Scalar>* sub_mv
00141   ) const
00142 {
00143   freeData( sub_mv->values() );
00144   sub_mv->set_uninitialized();
00145 }
00146 
00147 template<class Scalar>
00148 void SerialMultiVectorBase<Scalar>::getSubMultiVector(
00149   const Range1D                                &rowRng_in
00150   ,const Range1D                               &colRng_in
00151   ,RTOpPack::MutableSubMultiVectorT<Scalar>    *sub_mv
00152   )
00153 {
00154   const Range1D rowRng = validateRowRange(rowRng_in);
00155   const Range1D colRng = validateColRange(colRng_in);
00156   Scalar *localValues = NULL; int leadingDim = 0;
00157   this->getData(&localValues,&leadingDim);
00158   sub_mv->initialize(
00159     rowRng.lbound()-1                             // globalOffset
00160     ,rowRng.size()                                // subDim
00161     ,colRng.lbound()-1                            // colOffset
00162     ,colRng.size()                                // numSubCols
00163     ,localValues
00164     +(rowRng.lbound()-1)
00165     +(colRng.lbound()-1)*leadingDim               // values
00166     ,leadingDim                                   // leadingDim
00167     );
00168 }
00169 
00170 template<class Scalar>
00171 void SerialMultiVectorBase<Scalar>::commitSubMultiVector(
00172   RTOpPack::MutableSubMultiVectorT<Scalar>* sub_mv
00173   )
00174 {
00175   commitData( sub_mv->values() );
00176   sub_mv->set_uninitialized();
00177 }
00178 
00179 // protected
00180 
00181 
00182 // Overridden from SingleScalarEuclideanLinearOpBase
00183 
00184 template<class Scalar>
00185 bool SerialMultiVectorBase<Scalar>::opSupported(ETransp M_trans) const
00186 {
00187   typedef Teuchos::ScalarTraits<Scalar> ST;
00188   return ( ST::isComplex ? ( M_trans!=CONJ ) : true );
00189 }
00190 
00191 template<class Scalar>
00192 void SerialMultiVectorBase<Scalar>::euclideanApply(
00193   const ETransp                     M_trans
00194   ,const MultiVectorBase<Scalar>    &X
00195   ,MultiVectorBase<Scalar>          *Y
00196   ,const Scalar                     alpha
00197   ,const Scalar                     beta
00198   ) const
00199 {
00200 
00201   typedef Teuchos::ScalarTraits<Scalar> ST;
00202 
00203 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00204   Teuchos::Time timerTotal("dummy",true);
00205   Teuchos::Time timer("dummy");
00206 #endif
00207 
00208   //
00209   // This function performs one of two operations.
00210   //
00211   // The first operation (M_trans == NOTRANS) is:
00212 
00213   //     Y = beta * Y + alpha * M * X
00214   //
00215   // The second operation (M_trans == TRANS) is:
00216   //
00217   //     Y = beta * Y + alpha * M' * X
00218   //
00219 
00220 #ifdef _DEBUG
00221   THYRA_ASSERT_LINEAR_OP_MULTIVEC_APPLY_SPACES("SerialMultiVectorBase<Scalar>::euclideanApply()",*this,M_trans,X,Y);
00222 #endif
00223 
00224   //
00225   // Get explicit views of Y, M and X
00226   //
00227 
00228 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00229   timer.start();
00230 #endif
00231   ExplicitMutableMultiVectorView<Scalar>  Y_local(*Y);
00232   ExplicitMultiVectorView<Scalar>         M_local(*this);
00233   ExplicitMultiVectorView<Scalar>         X_local(X);
00234 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00235   timer.stop();
00236   std::cout << "\nSerialMultiVectorBase<Scalar>::apply(...): Time for getting view = " << timer.totalElapsedTime() << " seconds\n";
00237 #endif
00238     
00239   //
00240   // Perform the multiplication:
00241   //
00242   //     Y(local) = localBeta * Y(local) + alpha * op(M(local)) * X(local)
00243   //
00244   // or in BLAS lingo:
00245   //
00246   //     C        = beta      * C        + alpha * op(A)        * op(B)
00247   //
00248 
00249 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00250   timer.start();
00251 #endif
00252   Teuchos::ETransp t_transp;
00253   if(ST::isComplex) {
00254     switch(M_trans) {
00255       case NOTRANS:   t_transp = Teuchos::NO_TRANS;     break;
00256       case TRANS:     t_transp = Teuchos::TRANS;        break;
00257       case CONJTRANS: t_transp = Teuchos::CONJ_TRANS;   break;
00258       default: TEST_FOR_EXCEPT(true);
00259     }
00260   }
00261   else {
00262     switch(real_trans(M_trans)) {
00263       case NOTRANS:   t_transp = Teuchos::NO_TRANS;     break;
00264       case TRANS:     t_transp = Teuchos::TRANS;        break;
00265       default: TEST_FOR_EXCEPT(true);
00266     }
00267   }
00268   blas_.GEMM(
00269     t_transp                                                                 // TRANSA
00270     ,Teuchos::NO_TRANS                                                       // TRANSB
00271     ,Y_local.subDim()                                                        // M
00272     ,Y_local.numSubCols()                                                    // N
00273     ,real_trans(M_trans)==NOTRANS ? M_local.numSubCols() : M_local.subDim()  // K
00274     ,alpha                                                                   // ALPHA
00275     ,const_cast<Scalar*>(M_local.values())                                   // A
00276     ,M_local.leadingDim()                                                    // LDA
00277     ,const_cast<Scalar*>(X_local.values())                                   // B
00278     ,X_local.leadingDim()                                                    // LDB
00279     ,beta                                                                    // BETA
00280     ,Y_local.values()                                                        // C
00281     ,Y_local.leadingDim()                                                    // LDC
00282     );
00283 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00284   timer.stop();
00285   std::cout << "\nSerialMultiVectorBase<Scalar>::apply(...): Time for GEMM = " << timer.totalElapsedTime() << " seconds\n";
00286 #endif
00287 
00288 #ifdef THYRA_SERIAL_MULTI_VECTOR_BASE_PRINT_TIMES
00289   timer.stop();
00290   std::cout << "\nSerialMultiVectorBase<Scalar>::apply(...): Total time = " << timerTotal.totalElapsedTime() << " seconds\n";
00291 #endif
00292 
00293 }
00294 
00295 // Miscellaneous functions for subclasses to call
00296 
00297 template<class Scalar>
00298 void SerialMultiVectorBase<Scalar>::updateSpace()
00299 {
00300   if(numRows_ == 0) {
00301     const VectorSpaceBase<Scalar> *range = this->range().get();
00302     if(range) {
00303       numRows_    = range->dim();
00304       numCols_    = this->domain()->dim();
00305     }
00306     else {
00307       numRows_    = 0;
00308       numCols_    = 0;
00309     }
00310   }
00311 }
00312 
00313 template<class Scalar>
00314 Range1D SerialMultiVectorBase<Scalar>::validateRowRange( const Range1D &rowRng_in ) const
00315 {
00316   const Range1D rowRng = RangePack::full_range(rowRng_in,1,numRows_);
00317 #ifdef _DEBUG
00318   TEST_FOR_EXCEPTION(
00319     rowRng.lbound() < 1 || numRows_ < rowRng.ubound(), std::invalid_argument
00320     ,"SerialMultiVectorBase<Scalar>::validateRowRange(rowRng): Error, the range rowRng = ["
00321     <<rowRng.lbound()<<","<<rowRng.ubound()<<"] is not "
00322     "in the range [1,"<<numRows_<<"]!"
00323     );
00324 #endif
00325   return rowRng;
00326 }
00327 
00328 template<class Scalar>
00329 Range1D SerialMultiVectorBase<Scalar>::validateColRange( const Range1D &colRng_in ) const
00330 {
00331   const Range1D colRng = RangePack::full_range(colRng_in,1,numCols_);
00332 #ifdef _DEBUG
00333   TEST_FOR_EXCEPTION(
00334     colRng.lbound() < 1 || numCols_ < colRng.ubound(), std::invalid_argument
00335     ,"SerialMultiVectorBase<Scalar>::validateColRange(colRng): Error, the range colRng = ["
00336     <<colRng.lbound()<<","<<colRng.ubound()<<"] is not "
00337     "in the range [1,"<<numCols_<<"]!"
00338     );
00339 #endif
00340   return colRng;
00341 }
00342 
00343 } // end namespace Thyra
00344 
00345 #endif // THYRA_SERIAL_MULTI_VECTOR_BASE_HPP

Generated on Thu Sep 18 12:39:52 2008 for Thyra ANA Operator/VectorBase Interfaces and Related Software by doxygen 1.3.9.1