Anasazi Version of the Day
TsqrAdaptor.hpp
Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 //
00004 //                 Anasazi: Block Eigensolvers Package
00005 //                 Copyright (2010) Sandia Corporation
00006 //
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 //
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00025 //
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef __TSQR_Trilinos_TsqrAdaptor_hpp
00030 #define __TSQR_Trilinos_TsqrAdaptor_hpp
00031 
00035 
00036 #include "AnasaziConfigDefs.hpp"
00037 #include "Teuchos_SerialDenseMatrix.hpp"
00038 
00039 #include "TsqrTypeAdaptor.hpp"
00040 #include "TsqrCommFactory.hpp"
00041 
00042 #include "Tsqr_GlobalVerify.hpp"
00043 #include "Tsqr_ScalarTraits.hpp"
00044 
00045 #include <stdexcept>
00046 #include <sstream>
00047 
00050 
00051 namespace TSQR {
00052   namespace Trilinos {
00053 
00104     template< class S, class LO, class GO, class MV >
00105     class TsqrAdaptor {
00106     public:
00107       typedef S   scalar_type;
00108       typedef LO  local_ordinal_type;
00109       typedef GO  global_ordinal_type;
00110       typedef MV  multivector_type;
00111 
00112       typedef typename TSQR::ScalarTraits< scalar_type >::magnitude_type magnitude_type;
00113 
00114       typedef TsqrTypeAdaptor< S, LO, GO, MV >      type_adaptor;
00115       typedef typename type_adaptor::factory_type   factory_type;
00116 
00117       typedef typename type_adaptor::node_tsqr_type node_tsqr_type;
00118       typedef typename type_adaptor::node_tsqr_ptr  node_tsqr_ptr;
00119 
00120       typedef typename type_adaptor::comm_type      comm_type;
00121       typedef typename type_adaptor::comm_ptr       comm_ptr;
00122 
00123       typedef typename type_adaptor::dist_tsqr_type dist_tsqr_type;
00124       typedef typename type_adaptor::dist_tsqr_ptr  dist_tsqr_ptr;
00125 
00126       typedef typename type_adaptor::tsqr_type      tsqr_type;
00127       typedef typename type_adaptor::tsqr_ptr       tsqr_ptr;
00128 
00129       typedef typename tsqr_type::FactorOutput      factor_output_type;
00130       typedef Teuchos::SerialDenseMatrix< LO, S >   dense_matrix_type;
00131       typedef Teuchos::RCP< MessengerBase< S > >    scalar_messenger_ptr;
00132       typedef Teuchos::RCP< MessengerBase< LO > >   ordinal_messenger_ptr;
00133 
00134       virtual ~TsqrAdaptor() {}
00135 
00172       virtual factor_output_type
00173       factor (multivector_type& A, 
00174         dense_matrix_type& R,
00175         const bool contiguousCacheBlocks = false)
00176       {
00177   local_ordinal_type nrowsLocal, ncols, LDA;
00178   fetchDims (A, nrowsLocal, ncols, LDA);
00179   // This is guaranteed to be _correct_ for any Node type, but
00180   // won't necessary be efficient.  The desired model is that
00181   // A_local requires no copying.
00182   Teuchos::ArrayRCP< scalar_type > A_local = fetchNonConstView (A);
00183 
00184   // Reshape R if necessary.  This operation zeros out all the
00185   // entries of R, which is what we want anyway.
00186   if (R.numRows() != ncols || R.numCols() != ncols)
00187     {
00188       if (0 != R.shape (ncols, ncols))
00189         throw std::runtime_error ("Failed to reshape matrix R");
00190     }
00191   return pTsqr_->factor (nrowsLocal, ncols, A_local.get(), LDA, 
00192              R.values(), R.stride(), contiguousCacheBlocks);
00193       }
00194 
00221       virtual void 
00222       explicitQ (const multivector_type& Q_in, 
00223      const factor_output_type& factorOutput,
00224      multivector_type& Q_out, 
00225      const bool contiguousCacheBlocks = false)
00226       {
00227   using Teuchos::ArrayRCP;
00228 
00229   local_ordinal_type nrowsLocal, ncols_in, LDQ_in;
00230   fetchDims (Q_in, nrowsLocal, ncols_in, LDQ_in);
00231   local_ordinal_type nrowsLocal_out, ncols_out, LDQ_out;
00232   fetchDims (Q_out, nrowsLocal_out, ncols_out, LDQ_out);
00233 
00234   if (nrowsLocal_out != nrowsLocal)
00235     {
00236       std::ostringstream os;
00237       os << "TSQR explicit Q: input Q factor\'s node-local part has a di"
00238         "fferent number of rows (" << nrowsLocal << ") than output Q fac"
00239         "tor\'s node-local part (" << nrowsLocal_out << ").";
00240       throw std::runtime_error (os.str());
00241     }
00242   ArrayRCP< const scalar_type > pQin = fetchConstView (Q_in);
00243   ArrayRCP< scalar_type > pQout = fetchNonConstView (Q_out);
00244   pTsqr_->explicit_Q (nrowsLocal, 
00245           ncols_in, pQin.get(), LDQ_in, 
00246           factorOutput,
00247           ncols_out, pQout.get(), LDQ_out,
00248           contiguousCacheBlocks);
00249       }
00250 
00275       local_ordinal_type
00276       revealRank (multivector_type& Q,
00277       dense_matrix_type& R,
00278       const magnitude_type relativeTolerance,
00279       const bool contiguousCacheBlocks = false) const
00280       {
00281   using Teuchos::ArrayRCP;
00282 
00283   local_ordinal_type nrowsLocal, ncols, ldqLocal;
00284   fetchDims (Q, nrowsLocal, ncols, ldqLocal);
00285 
00286   ArrayRCP< scalar_type > Q_ptr = fetchNonConstView (Q);
00287   return pTsqr_->reveal_rank (nrowsLocal, ncols, 
00288             Q_ptr.get(), ldqLocal,
00289             R.values(), R.stride(), 
00290             relativeTolerance, 
00291             contiguousCacheBlocks);
00292       }
00293 
00304       virtual void 
00305       cacheBlock (const multivector_type& A_in, 
00306       multivector_type& A_out)
00307       {
00308   using Teuchos::ArrayRCP;
00309 
00310   local_ordinal_type nrowsLocal, ncols, LDA_in;
00311   fetchDims (A_in, nrowsLocal, ncols, LDA_in);
00312   local_ordinal_type nrowsLocal_out, ncols_out, LDA_out;
00313   fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out);
00314 
00315   if (nrowsLocal_out != nrowsLocal)
00316     {
00317       std::ostringstream os;
00318       os << "TSQR cache block: the input matrix\'s node-local part has a"
00319         " different number of rows (" << nrowsLocal << ") than the outpu"
00320         "t matrix\'s node-local part (" << nrowsLocal_out << ").";
00321       throw std::runtime_error (os.str());
00322     }
00323   else if (ncols_out != ncols)
00324     {
00325       std::ostringstream os;
00326       os << "TSQR cache block: the input matrix\'s node-local part has a"
00327         " different number of columns (" << ncols << ") than the output "
00328         "matrix\'s node-local part (" << ncols_out << ").";
00329       throw std::runtime_error (os.str());
00330     }
00331   ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in);
00332   ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out);
00333   pTsqr_->cache_block (nrowsLocal, ncols, pA_out.get(), 
00334            pA_in.get(), LDA_in);
00335       }
00336 
00342       virtual void 
00343       unCacheBlock (const multivector_type& A_in, 
00344         multivector_type& A_out)
00345       {
00346   using Teuchos::ArrayRCP;
00347 
00348   local_ordinal_type nrowsLocal, ncols, LDA_in;
00349   fetchDims (A_in, nrowsLocal, ncols, LDA_in);
00350   local_ordinal_type nrowsLocal_out, ncols_out, LDA_out;
00351   fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out);
00352 
00353   if (nrowsLocal_out != nrowsLocal)
00354     {
00355       std::ostringstream os;
00356       os << "TSQR un-cache-block: the input matrix\'s node-local part ha"
00357         "s a different number of rows (" << nrowsLocal << ") than the ou"
00358         "tput matrix\'s node-local part (" << nrowsLocal_out << ").";
00359       throw std::runtime_error (os.str());
00360     }
00361   else if (ncols_out != ncols)
00362     {
00363       std::ostringstream os;
00364       os << "TSQR cache block: the input matrix\'s node-local part has a"
00365         " different number of columns (" << ncols << ") than the output "
00366         "matrix\'s node-local part (" << ncols_out << ").";
00367       throw std::runtime_error (os.str());
00368     }
00369   ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in);
00370   ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out);
00371   pTsqr_->un_cache_block (nrowsLocal, ncols, pA_out.get(), 
00372         LDA_out, pA_in.get());
00373       }
00374 
00377       virtual std::pair< magnitude_type, magnitude_type >
00378       verify (const multivector_type& A,
00379         const multivector_type& Q,
00380         const Teuchos::SerialDenseMatrix< local_ordinal_type, scalar_type >& R)
00381       {
00382   using Teuchos::ArrayRCP;
00383 
00384   local_ordinal_type nrowsLocal_A, ncols_A, LDA;
00385   local_ordinal_type nrowsLocal_Q, ncols_Q, LDQ;
00386   fetchDims (A, nrowsLocal_A, ncols_A, LDA);
00387   fetchDims (Q, nrowsLocal_Q, ncols_Q, LDQ);
00388   if (nrowsLocal_A != nrowsLocal_Q)
00389     throw std::runtime_error ("A and Q must have same number of rows");
00390   else if (ncols_A != ncols_Q)
00391     throw std::runtime_error ("A and Q must have same number of columns");
00392   else if (ncols_A != R.numCols())
00393     throw std::runtime_error ("A and R must have same number of columns");
00394   else if (R.numRows() < R.numCols())
00395     throw std::runtime_error ("R must have no fewer rows than columns");
00396 
00397   // Const views suffice for verification
00398   ArrayRCP< const scalar_type > A_ptr = fetchConstView (A);
00399   ArrayRCP< const scalar_type > Q_ptr = fetchConstView (Q);
00400   return global_verify (nrowsLocal_A, ncols_A, A_ptr.get(), LDA,
00401             Q_ptr.get(), LDQ, R.values(), R.stride(), 
00402             pScalarMessenger_.get());
00403       }
00404 
00405     protected:
00408       void 
00409       init (const multivector_type& mv,
00410       const Teuchos::ParameterList& plist)
00411       {
00412   // This is done in a multivector type - dependent way.
00413   fetchMessengers (mv, pScalarMessenger_, pOrdinalMessenger_);
00414 
00415   factory_type factory;
00416   // plist and pScalarMessenger_ are inputs.  Construct *pTsqr_.
00417   factory.makeTsqr (plist, pScalarMessenger_, pTsqr_);
00418       }
00419 
00420     private:
00436       virtual void 
00437       fetchDims (const multivector_type& A, 
00438      local_ordinal_type& nrowsLocal, 
00439      local_ordinal_type& ncols, 
00440      local_ordinal_type& LDA) const = 0;
00441 
00449       virtual Teuchos::ArrayRCP< scalar_type > 
00450       fetchNonConstView (multivector_type& A) const = 0;
00451 
00459       virtual Teuchos::ArrayRCP< const scalar_type > 
00460       fetchConstView (const multivector_type& A) const = 0;
00461 
00464       virtual void
00465       fetchMessengers (const multivector_type& mv,
00466            scalar_messenger_ptr& pScalarMessenger,
00467            ordinal_messenger_ptr& pOrdinalMessenger) const = 0;
00468 
00471       scalar_messenger_ptr pScalarMessenger_;
00472 
00475       ordinal_messenger_ptr pOrdinalMessenger_;
00476 
00479       tsqr_ptr pTsqr_;
00480     };
00481 
00482   } // namespace Trilinos
00483 } // namespace TSQR
00484 
00485 #endif // __TSQR_Trilinos_TsqrAdaptor_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends