Thyra_SpmdMultiVectorBase.hpp

00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //    Thyra: Interfaces and Support for Abstract Numerical Algorithms
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef THYRA_SPMD_MULTI_VECTOR_BASE_HPP
00030 #define THYRA_SPMD_MULTI_VECTOR_BASE_HPP
00031 
00032 
00033 #include "Thyra_SpmdMultiVectorBaseDecl.hpp"
00034 #include "Thyra_MultiVectorDefaultBase.hpp"
00035 #include "Thyra_SingleScalarEuclideanLinearOpBase.hpp"
00036 #include "Thyra_SpmdVectorSpaceDefaultBase.hpp"
00037 #include "Thyra_DetachedMultiVectorView.hpp"
00038 #include "RTOpPack_SPMD_apply_op.hpp"
00039 #include "RTOp_parallel_helpers.h"
00040 #include "Teuchos_Workspace.hpp"
00041 #include "Teuchos_dyn_cast.hpp"
00042 #include "Teuchos_Time.hpp"
00043 #include "Teuchos_CommHelpers.hpp"
00044 
00045 
00046 // Define to see some timing output!
00047 //#define THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00048 
00049 
00050 namespace Thyra {
00051 
00052 
00053 template<class Scalar>
00054 SpmdMultiVectorBase<Scalar>::SpmdMultiVectorBase()
00055   :in_applyOp_(false)
00056   ,globalDim_(0)
00057   ,localOffset_(-1)
00058   ,localSubDim_(0)
00059   ,numCols_(0)
00060   ,nonconstLocalValuesViewPtr_(0)
00061   ,localValuesViewPtr_(0)
00062 {}
00063 
00064 
00065 // Overridden from EuclideanLinearOpBase
00066 
00067 
00068 template<class Scalar>
00069 RCP< const ScalarProdVectorSpaceBase<Scalar> >
00070 SpmdMultiVectorBase<Scalar>::rangeScalarProdVecSpc() const
00071 {
00072   return Teuchos::rcp_dynamic_cast<const ScalarProdVectorSpaceBase<Scalar> >(
00073     spmdSpace(),true
00074     );
00075 }
00076 
00077 
00078 // Overridden from LinearOpBase
00079 
00080 
00081 template<class Scalar>
00082 void SpmdMultiVectorBase<Scalar>::apply(
00083   const EOpTransp M_trans
00084   ,const MultiVectorBase<Scalar> &X
00085   ,MultiVectorBase<Scalar> *Y
00086   ,const Scalar alpha
00087   ,const Scalar beta
00088   ) const
00089 {
00090   this->single_scalar_euclidean_apply_impl(M_trans,X,Y,alpha,beta);
00091 }
00092 
00093 
00094 // Overridden from MultiVectorBase
00095 
00096 
00097 template<class Scalar>
00098 void SpmdMultiVectorBase<Scalar>::mvMultiReductApplyOpImpl(
00099   const RTOpPack::RTOpT<Scalar> &pri_op,
00100   const ArrayView<const Ptr<const MultiVectorBase<Scalar> > > &multi_vecs,
00101   const ArrayView<const Ptr<MultiVectorBase<Scalar> > > &targ_multi_vecs,
00102   const ArrayView<const Ptr<RTOpPack::ReductTarget> > &reduct_objs,
00103   const Index pri_first_ele_offset_in,
00104   const Index pri_sub_dim_in,
00105   const Index pri_global_offset_in,
00106   const Index sec_first_ele_offset_in,
00107   const Index sec_sub_dim_in
00108   ) const
00109 {
00110   using Teuchos::dyn_cast;
00111   using Teuchos::Workspace;
00112   Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00113   const Index numCols = this->domain()->dim();
00114   const SpmdVectorSpaceBase<Scalar> &spmdSpc = *spmdSpace();
00115 #ifdef TEUCHOS_DEBUG
00116   TEST_FOR_EXCEPTION(
00117     in_applyOp_, std::invalid_argument
00118     ,"SpmdMultiVectorBase<>::mvMultiReductApplyOpImpl(...): Error, this method is being entered recursively which is a "
00119     "clear sign that one of the methods acquireDetachedView(...), releaseDetachedView(...) or commitDetachedView(...) "
00120     "was not implemented properly!"
00121     );
00122   apply_op_validate_input(
00123     "SpmdMultiVectorBase<>::mvMultiReductApplyOpImpl(...)", *this->domain(), *this->range(),
00124     pri_op, multi_vecs, targ_multi_vecs, reduct_objs,
00125     pri_first_ele_offset_in, pri_sub_dim_in, pri_global_offset_in,
00126     sec_first_ele_offset_in, sec_sub_dim_in
00127     );
00128 #endif
00129   // Flag that we are in applyOp()
00130   in_applyOp_ = true;
00131   // First see if this is a locally replicated vector in which case
00132   // we treat this as a local operation only.
00133   const bool locallyReplicated = (localSubDim_ == globalDim_);
00134   // Get the overlap in the current process with the input logical sub-vector
00135   // from (first_ele_offset_in,sub_dim_in,global_offset_in)
00136   Teuchos_Index overlap_first_local_ele_off = 0;
00137   Teuchos_Index overlap_local_sub_dim = 0;
00138   Teuchos_Index overlap_global_offset = 0;
00139   RTOp_parallel_calc_overlap(
00140     globalDim_, localSubDim_, localOffset_, pri_first_ele_offset_in, pri_sub_dim_in, pri_global_offset_in
00141     ,&overlap_first_local_ele_off, &overlap_local_sub_dim, &overlap_global_offset
00142     );
00143   const Range1D
00144     local_rng = (
00145       overlap_first_local_ele_off>=0
00146       ? Range1D( localOffset_+overlap_first_local_ele_off, localOffset_+overlap_first_local_ele_off+overlap_local_sub_dim-1 )
00147       : Range1D::Invalid
00148       ),
00149     col_rng(
00150       sec_first_ele_offset_in
00151       ,sec_sub_dim_in >= 0 ? sec_first_ele_offset_in+sec_sub_dim_in-1 : numCols-1
00152       );
00153   // Create sub-vector views of all of the *participating* local data
00154   Workspace<RTOpPack::ConstSubMultiVectorView<Scalar> > sub_multi_vecs(wss,multi_vecs.size());
00155   Workspace<RTOpPack::SubMultiVectorView<Scalar> > targ_sub_multi_vecs(wss,targ_multi_vecs.size());
00156   if( overlap_first_local_ele_off >= 0 ) {
00157     for(int k = 0; k < multi_vecs.size(); ++k ) {
00158       multi_vecs[k]->acquireDetachedView( local_rng, col_rng, &sub_multi_vecs[k] );
00159       sub_multi_vecs[k].setGlobalOffset( overlap_global_offset );
00160     }
00161     for(int k = 0; k < targ_multi_vecs.size(); ++k ) {
00162       targ_multi_vecs[k]->acquireDetachedView( local_rng, col_rng, &targ_sub_multi_vecs[k] );
00163       targ_sub_multi_vecs[k].setGlobalOffset( overlap_global_offset );
00164     }
00165   }
00166   Workspace<RTOpPack::ReductTarget*> reduct_objs_ptr(wss,reduct_objs.size());
00167   for (int k = 0; k < reduct_objs.size(); ++k) {
00168     reduct_objs_ptr[k] = &*reduct_objs[k];
00169   }
00170   // Apply the RTOp operator object (all processors must participate)
00171   RTOpPack::SPMD_apply_op(
00172     locallyReplicated ? NULL : spmdSpc.getComm().get() // comm
00173     ,pri_op // op
00174     ,col_rng.size() // num_cols
00175     ,multi_vecs.size() // multi_vecs.size()
00176     ,multi_vecs.size() && overlap_first_local_ele_off>=0 ? &sub_multi_vecs[0] : NULL // sub_multi_vecs
00177     ,targ_multi_vecs.size() // targ_multi_vecs.size()
00178     ,targ_multi_vecs.size() && overlap_first_local_ele_off>=0 ? &targ_sub_multi_vecs[0] : NULL// targ_sub_multi_vecs
00179     ,reduct_objs.size() ? &reduct_objs_ptr[0] : 0 // reduct_objs
00180     );
00181   // Free and commit the local data
00182   if( overlap_first_local_ele_off >= 0 ) {
00183     for(int k = 0; k < multi_vecs.size(); ++k ) {
00184       sub_multi_vecs[k].setGlobalOffset(local_rng.lbound());
00185       multi_vecs[k]->releaseDetachedView( &sub_multi_vecs[k] );
00186     }
00187     for(int k = 0; k < targ_multi_vecs.size(); ++k ) {
00188       targ_sub_multi_vecs[k].setGlobalOffset(local_rng.lbound());
00189       targ_multi_vecs[k]->commitDetachedView( &targ_sub_multi_vecs[k] );
00190     }
00191   }
00192   // Flag that we are leaving applyOp()
00193   in_applyOp_ = false;
00194 }
00195 
00196 
00197 template<class Scalar>
00198 void SpmdMultiVectorBase<Scalar>::acquireDetachedMultiVectorViewImpl(
00199   const Range1D &rowRng_in,
00200   const Range1D &colRng_in,
00201   RTOpPack::ConstSubMultiVectorView<Scalar> *sub_mv
00202   ) const
00203 {
00204   const Range1D rowRng = validateRowRange(rowRng_in);
00205   const Range1D colRng = validateColRange(colRng_in);
00206   if( rowRng.lbound() < localOffset_ || localOffset_+localSubDim_-1 < rowRng.ubound() ) {
00207     // rng consists of off-processor elements so use the default implementation!
00208     MultiVectorDefaultBase<Scalar>::acquireDetachedMultiVectorViewImpl(
00209       rowRng_in,colRng_in,sub_mv
00210       );
00211     return;
00212   }
00213   /*
00214     if (localValuesViewPtr_) {
00215     freeLocalData( localValuesViewPtr_ );
00216     localValuesViewPtr_ = 0;
00217     }
00218   */
00219   // 2007/06/08: rabartl: ABove, this logic for this is all wrong when partial
00220   // views are requested. Therefore, I must assume here that these are always
00221   // direct views. The problem is that client code can ask for one column
00222   // view at a time which is perfectly okay. However, the current way this is
00223   // setup does not handle this well. This all needs to be reworked to clean
00224   // this up.
00225   const Scalar *localValues = NULL;
00226   int leadingDim = 0;
00227   this->getLocalData(&localValues,&leadingDim);
00228   localValuesViewPtr_ = localValues;
00229   sub_mv->initialize(
00230     rowRng.lbound() // globalOffset
00231     ,rowRng.size() // subDim
00232     ,colRng.lbound() // colOffset
00233     ,colRng.size() // numSubCols
00234     ,localValues
00235     +(rowRng.lbound()-localOffset_)
00236     +colRng.lbound()*leadingDim // values
00237     ,leadingDim // leadingDim
00238     );
00239 }
00240 
00241 
00242 template<class Scalar>
00243 void SpmdMultiVectorBase<Scalar>::releaseDetachedMultiVectorViewImpl(
00244   RTOpPack::ConstSubMultiVectorView<Scalar>* sub_mv
00245   ) const
00246 {
00247   if(
00248     sub_mv->globalOffset() < localOffset_ 
00249     ||
00250     localOffset_+localSubDim_ < sub_mv->globalOffset()+sub_mv->subDim()
00251     )
00252   {
00253     // Let the default implementation handle it!
00254     MultiVectorDefaultBase<Scalar>::releaseDetachedMultiVectorViewImpl(sub_mv);
00255     return;
00256   }
00257   /*
00258     #ifdef TEUCHOS_DEBUG
00259     TEST_FOR_EXCEPT( localValuesViewPtr_ == 0 );
00260     #endif
00261     freeLocalData( localValuesViewPtr_ );
00262     localValuesViewPtr_ = 0;
00263   */
00264   // 2007/06/08: rabartl: See comment in acquireDetachedView(...) above!
00265   sub_mv->set_uninitialized();
00266 }
00267 
00268 
00269 template<class Scalar>
00270 void SpmdMultiVectorBase<Scalar>::acquireNonconstDetachedMultiVectorViewImpl(
00271   const Range1D &rowRng_in,
00272   const Range1D &colRng_in,
00273   RTOpPack::SubMultiVectorView<Scalar> *sub_mv
00274   )
00275 {
00276   const Range1D rowRng = validateRowRange(rowRng_in);
00277   const Range1D colRng = validateColRange(colRng_in);
00278   if(
00279     rowRng.lbound() < localOffset_
00280     ||
00281     localOffset_+localSubDim_-1 < rowRng.ubound()
00282     )
00283   {
00284     // rng consists of off-processor elements so use the default implementation!
00285     MultiVectorDefaultBase<Scalar>::acquireNonconstDetachedMultiVectorViewImpl(
00286       rowRng_in, colRng_in, sub_mv
00287       );
00288     return;
00289   }
00290   // rng consists of all local data so get it!
00291   /*
00292     if (nonconstLocalValuesViewPtr_) {
00293     commitLocalData( nonconstLocalValuesViewPtr_ );
00294     nonconstLocalValuesViewPtr_ = 0;
00295     }
00296   */
00297   // 2007/06/08: rabartl: See comment in acquireDetachedView(...) above!
00298   Scalar *localValues = NULL;
00299   int leadingDim = 0;
00300   this->getLocalData(&localValues,&leadingDim);
00301   nonconstLocalValuesViewPtr_ = localValues;
00302   sub_mv->initialize(
00303     rowRng.lbound() // globalOffset
00304     ,rowRng.size() // subDim
00305     ,colRng.lbound() // colOffset
00306     ,colRng.size() // numSubCols
00307     ,localValues
00308     +(rowRng.lbound()-localOffset_)
00309     +colRng.lbound()*leadingDim // values
00310     ,leadingDim // leadingDim
00311     );
00312 }
00313 
00314 
00315 template<class Scalar>
00316 void SpmdMultiVectorBase<Scalar>::commitNonconstDetachedMultiVectorViewImpl(
00317   RTOpPack::SubMultiVectorView<Scalar>* sub_mv
00318   )
00319 {
00320   if(
00321     sub_mv->globalOffset() < localOffset_
00322     ||
00323     localOffset_+localSubDim_ < sub_mv->globalOffset()+sub_mv->subDim()
00324     )
00325   {
00326     // Let the default implementation handle it!
00327     MultiVectorDefaultBase<Scalar>::commitNonconstDetachedMultiVectorViewImpl(sub_mv);
00328     return;
00329   }
00330   /*
00331     #ifdef TEUCHOS_DEBUG
00332     TEST_FOR_EXCEPT( nonconstLocalValuesViewPtr_ == 0 );
00333     #endif
00334     commitLocalData( nonconstLocalValuesViewPtr_ );
00335     nonconstLocalValuesViewPtr_ = 0;
00336   */
00337   // 2007/06/08: rabartl: See comment in acquireDetachedView(...) above!
00338   sub_mv->set_uninitialized();
00339 }
00340 
00341 
00342 // protected
00343 
00344 
00345 template<class Scalar>
00346 void SpmdMultiVectorBase<Scalar>::euclideanApply(
00347   const EOpTransp M_trans
00348   ,const MultiVectorBase<Scalar> &X
00349   ,MultiVectorBase<Scalar> *Y
00350   ,const Scalar alpha
00351   ,const Scalar beta
00352   ) const
00353 {
00354   typedef Teuchos::ScalarTraits<Scalar> ST;
00355   using Teuchos::Workspace;
00356   Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00357 
00358 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00359   Teuchos::Time timerTotal("dummy",true);
00360   Teuchos::Time timer("dummy");
00361 #endif
00362 
00363   //
00364   // This function performs one of two operations.
00365   //
00366   // The first operation (M_trans == NOTRANS) is:
00367   //
00368   // Y = beta * Y + alpha * M * X
00369   //
00370   // where Y and M have compatible (distributed?) range vector
00371   // spaces and X is a locally replicated serial multi-vector. This
00372   // operation does not require any global communication.
00373   //
00374   // The second operation (M_trans == TRANS) is:
00375   //
00376   // Y = beta * Y + alpha * M' * X
00377   //
00378   // where M and X have compatible (distributed?) range vector spaces
00379   // and Y is a locally replicated serial multi-vector. This operation
00380   // requires a local reduction.
00381   //
00382 
00383   //
00384   // Get spaces and validate compatibility
00385   //
00386 
00387   // Get the SpmdVectorSpace
00388   const SpmdVectorSpaceBase<Scalar> &spmdSpc = *this->spmdSpace();
00389 
00390   // Get the Spmd communicator
00391   const RCP<const Teuchos::Comm<Index> >
00392     comm = spmdSpc.getComm();
00393 #ifdef TEUCHOS_DEBUG
00394   const VectorSpaceBase<Scalar>
00395     &Y_range = *Y->range(),
00396     &X_range = *X.range();
00397 //  std::cout << "SpmdMultiVectorBase<Scalar>::apply(...): comm = " << comm << std::endl;
00398   TEST_FOR_EXCEPTION(
00399     ( globalDim_ > localSubDim_ ) && comm.get()==NULL, std::logic_error
00400     ,"SpmdMultiVectorBase<Scalar>::apply(...MultiVectorBase<Scalar>...): Error!"
00401     );
00402   // ToDo: Write a good general validation function that I can call that will replace
00403   // all of these TEST_FOR_EXCEPTION(...) uses
00404 
00405   TEST_FOR_EXCEPTION(
00406     real_trans(M_trans)==NOTRANS && !spmdSpc.isCompatible(Y_range), Exceptions::IncompatibleVectorSpaces
00407     ,"SpmdMultiVectorBase<Scalar>::apply(...MultiVectorBase<Scalar>...): Error!"
00408     );
00409   TEST_FOR_EXCEPTION(
00410     real_trans(M_trans)==TRANS && !spmdSpc.isCompatible(X_range), Exceptions::IncompatibleVectorSpaces
00411     ,"SpmdMultiVectorBase<Scalar>::apply(...MultiVectorBase<Scalar>...): Error!"
00412     );
00413 #endif
00414 
00415   //
00416   // Get explicit (local) views of Y, M and X
00417   //
00418 
00419 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00420   timer.start();
00421 #endif
00422  
00423   DetachedMultiVectorView<Scalar>
00424     Y_local(
00425       *Y,
00426       real_trans(M_trans)==NOTRANS ? Range1D(localOffset_,localOffset_+localSubDim_-1) : Range1D(),
00427       Range1D()
00428       );
00429   ConstDetachedMultiVectorView<Scalar>
00430     M_local(
00431       *this,
00432       Range1D(localOffset_,localOffset_+localSubDim_-1),
00433       Range1D()
00434       );
00435   ConstDetachedMultiVectorView<Scalar>
00436     X_local(
00437       X
00438       ,real_trans(M_trans)==NOTRANS ? Range1D() : Range1D(localOffset_,localOffset_+localSubDim_-1)
00439       ,Range1D()
00440       );
00441 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00442   timer.stop();
00443   std::cout << "\nSpmdMultiVectorBase<Scalar>::apply(...): Time for getting view = " << timer.totalElapsedTime() << " seconds\n";
00444 #endif
00445 #ifdef TEUCHOS_DEBUG    
00446   TEST_FOR_EXCEPTION(
00447     real_trans(M_trans)==NOTRANS && ( M_local.numSubCols() != X_local.subDim() || X_local.numSubCols() != Y_local.numSubCols() )
00448     , Exceptions::IncompatibleVectorSpaces
00449     ,"SpmdMultiVectorBase<Scalar>::apply(...MultiVectorBase<Scalar>...): Error!"
00450     );
00451   TEST_FOR_EXCEPTION(
00452     real_trans(M_trans)==TRANS && ( M_local.subDim() != X_local.subDim() || X_local.numSubCols() != Y_local.numSubCols() )
00453     , Exceptions::IncompatibleVectorSpaces
00454     ,"SpmdMultiVectorBase<Scalar>::apply(...MultiVectorBase<Scalar>...): Error!"
00455     );
00456 #endif
00457 
00458   //
00459   // If nonlocal (i.e. M_trans==TRANS) then create temporary storage
00460   // for:
00461   //
00462   // Y_local_tmp = alpha * M(local) * X(local) : on nonroot processes
00463   //
00464   // or
00465   //
00466   // Y_local_tmp = beta*Y_local + alpha * M(local) * X(local) : on root process (localOffset_==0)
00467   // 
00468   // and set
00469   //
00470   // localBeta = ( localOffset_ == 0 ? beta : 0.0 )
00471   //
00472   // Above, we choose localBeta such that we will only perform
00473   // Y_local = beta * Y_local + ... on one process (the root
00474   // process where localOffset_==0x). Then, when we add up Y_local
00475   // on all of the processors and we will get the correct result.
00476   //
00477   // If strictly local (i.e. M_trans == NOTRANS) then set:
00478   //
00479   // Y_local_tmp = Y_local
00480   // localBeta = beta
00481   //
00482 
00483 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00484   timer.start();
00485 #endif
00486  
00487   Workspace<Scalar> Y_local_tmp_store(wss, Y_local.subDim()*Y_local.numSubCols(), false);
00488   RTOpPack::SubMultiVectorView<Scalar> Y_local_tmp;
00489   Scalar localBeta;
00490   if( real_trans(M_trans) == TRANS && globalDim_ > localSubDim_ ) {
00491     // Nonlocal
00492     Y_local_tmp.initialize(
00493       0, Y_local.subDim()
00494       ,0, Y_local.numSubCols()
00495       ,&Y_local_tmp_store[0], Y_local.subDim() // leadingDim == subDim (columns are adjacent)
00496       );
00497     if( localOffset_ == 0 ) {
00498       // Root process: Must copy Y_local into Y_local_tmp
00499       for( int j = 0; j < Y_local.numSubCols(); ++j ) {
00500         Scalar *Y_local_j = Y_local.values() + Y_local.leadingDim()*j;
00501         std::copy( Y_local_j, Y_local_j + Y_local.subDim(), Y_local_tmp.values() + Y_local_tmp.leadingDim()*j );
00502       }
00503       localBeta = beta;
00504     }
00505     else {
00506       // Not the root process
00507       localBeta = 0.0;
00508     }
00509   }
00510   else {
00511     // Local
00512     Y_local_tmp = Y_local.smv(); // Shallow copy only!
00513     localBeta = beta;
00514   }
00515 
00516 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00517   timer.stop();
00518   std::cout << "\nSpmdMultiVectorBase<Scalar>::apply(...): Time for setting up Y_local_tmp and localBeta = " << timer.totalElapsedTime() << " seconds\n";
00519 #endif
00520  
00521   //
00522   // Perform the local multiplication:
00523   //
00524   // Y(local) = localBeta * Y(local) + alpha * op(M(local)) * X(local)
00525   //
00526   // or in BLAS lingo:
00527   //
00528   // C = beta * C + alpha * op(A) * op(B)
00529   //
00530 
00531 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00532   timer.start();
00533 #endif
00534   Teuchos::ETransp t_transp;
00535   if(ST::isComplex) {
00536     switch(M_trans) {
00537       case NOTRANS: t_transp = Teuchos::NO_TRANS; break;
00538       case TRANS: t_transp = Teuchos::TRANS; break;
00539       case CONJTRANS: t_transp = Teuchos::CONJ_TRANS; break;
00540       default: TEST_FOR_EXCEPT(true);
00541     }
00542   }
00543   else {
00544     switch(real_trans(M_trans)) {
00545       case NOTRANS: t_transp = Teuchos::NO_TRANS; break;
00546       case TRANS: t_transp = Teuchos::TRANS; break;
00547       default: TEST_FOR_EXCEPT(true);
00548     }
00549   }
00550   if (M_local.numSubCols() > 0) {
00551     blas_.GEMM(
00552       t_transp // TRANSA
00553       ,Teuchos::NO_TRANS // TRANSB
00554       ,Y_local.subDim() // M
00555       ,Y_local.numSubCols() // N
00556       ,real_trans(M_trans)==NOTRANS ? M_local.numSubCols() : M_local.subDim() // K
00557       ,alpha // ALPHA
00558       ,const_cast<Scalar*>(M_local.values()) // A
00559       ,M_local.leadingDim() // LDA
00560       ,const_cast<Scalar*>(X_local.values()) // B
00561       ,X_local.leadingDim() // LDB
00562       ,localBeta // BETA
00563       ,Y_local_tmp.values().get() // C
00564       ,Y_local_tmp.leadingDim() // LDC
00565       );
00566   }
00567   else {
00568     std::fill( Y_local_tmp.values().begin(), Y_local_tmp.values().end(),
00569       ST::zero() );
00570   }
00571 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00572   timer.stop();
00573   std::cout
00574     << "\nSpmdMultiVectorBase<Scalar>::apply(...): Time for GEMM = "
00575     << timer.totalElapsedTime() << " seconds\n";
00576 #endif
00577 
00578   if( comm.get() ) {
00579  
00580     //
00581     // Perform the global reduction of Y_local_tmp back into Y_local
00582     //
00583  
00584     if( real_trans(M_trans)==TRANS && globalDim_ > localSubDim_ ) {
00585       // Contiguous buffer for final reduction
00586       Workspace<Scalar> Y_local_final_buff(wss,Y_local.subDim()*Y_local.numSubCols(),false);
00587       // Perform the reduction
00588       Teuchos::reduceAll<Index,Scalar>(
00589         *comm,Teuchos::REDUCE_SUM,Y_local_final_buff.size(),Y_local_tmp.values().get(),
00590         &Y_local_final_buff[0]
00591         );
00592       // Load Y_local_final_buff back into Y_local
00593       const Scalar *Y_local_final_buff_ptr = &Y_local_final_buff[0];
00594       for( int j = 0; j < Y_local.numSubCols(); ++j ) {
00595         Scalar *Y_local_ptr = Y_local.values() + Y_local.leadingDim()*j;
00596         for( int i = 0; i < Y_local.subDim(); ++i ) {
00597           (*Y_local_ptr++) = (*Y_local_final_buff_ptr++);
00598         }
00599       }
00600     }
00601   }
00602   else {
00603 
00604     // When you get here the view Y_local will be committed back to Y
00605     // in the destructor to Y_local
00606 
00607   }
00608 
00609 #ifdef THYRA_SPMD_MULTI_VECTOR_BASE_PRINT_TIMES
00610   timer.stop();
00611   std::cout 
00612     << "\nSpmdMultiVectorBase<Scalar>::apply(...): Total time = "
00613     << timerTotal.totalElapsedTime() << " seconds\n";
00614 #endif
00615 
00616 }
00617 
00618 
00619 // Overridden from SingleScalarEuclideanLinearOpBase
00620 
00621 
00622 template<class Scalar>
00623 bool SpmdMultiVectorBase<Scalar>::opSupported(EOpTransp M_trans) const
00624 {
00625   typedef Teuchos::ScalarTraits<Scalar> ST;
00626   return ( ST::isComplex ? ( M_trans!=CONJ ) : true );
00627 }
00628 
00629 template<class Scalar>
00630 void SpmdMultiVectorBase<Scalar>::updateSpmdSpace()
00631 {
00632   if(globalDim_ == 0) {
00633     const SpmdVectorSpaceBase<Scalar> *l_spmdSpace = this->spmdSpace().get();
00634     if(l_spmdSpace) {
00635       globalDim_ = l_spmdSpace->dim();
00636       localOffset_ = l_spmdSpace->localOffset();
00637       localSubDim_ = l_spmdSpace->localSubDim();
00638       numCols_ = this->domain()->dim();
00639     }
00640     else {
00641       globalDim_ = 0;
00642       localOffset_ = -1;
00643       localSubDim_ = 0;
00644       numCols_ = 0;
00645     }
00646   }
00647 }
00648 
00649 
00650 template<class Scalar>
00651 Range1D SpmdMultiVectorBase<Scalar>::validateRowRange( const Range1D &rowRng_in ) const
00652 {
00653   const Range1D rowRng = Teuchos::full_range(rowRng_in,0,globalDim_-1);
00654 #ifdef TEUCHOS_DEBUG
00655   TEST_FOR_EXCEPTION(
00656     !( 0 <= rowRng.lbound() && rowRng.ubound() < globalDim_ ), std::invalid_argument
00657     ,"SpmdMultiVectorBase<Scalar>::validateRowRange(rowRng): Error, the range rowRng = ["
00658     <<rowRng.lbound()<<","<<rowRng.ubound()<<"] is not "
00659     "in the range [0,"<<(globalDim_-1)<<"]!"
00660     );
00661 #endif
00662   return rowRng;
00663 }
00664 
00665 
00666 template<class Scalar>
00667 Range1D SpmdMultiVectorBase<Scalar>::validateColRange( const Range1D &colRng_in ) const
00668 {
00669   const Range1D colRng = Teuchos::full_range(colRng_in,0,numCols_-1);
00670 #ifdef TEUCHOS_DEBUG
00671   TEST_FOR_EXCEPTION(
00672     !(0 <= colRng.lbound() && colRng.ubound() < numCols_), std::invalid_argument
00673     ,"SpmdMultiVectorBase<Scalar>::validateColRange(colRng): Error, the range colRng = ["
00674     <<colRng.lbound()<<","<<colRng.ubound()<<"] is not "
00675     "in the range [0,"<<(numCols_-1)<<"]!"
00676     );
00677 #endif
00678   return colRng;
00679 }
00680 
00681 
00682 } // end namespace Thyra
00683 
00684 
00685 #endif // THYRA_SPMD_MULTI_VECTOR_BASE_HPP

Generated on Wed May 12 21:26:54 2010 for Thyra Operator/Vector Support by  doxygen 1.4.7