Kokkos Node API and Local Linear Algebra Kernels Version of the Day
TsqrAdaptor.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00029 #ifndef __TSQR_Trilinos_TsqrAdaptor_hpp
00030 #define __TSQR_Trilinos_TsqrAdaptor_hpp
00031 
00035 #include <Tsqr_ConfigDefs.hpp>
00036 #include <Teuchos_SerialDenseMatrix.hpp>
00037 #include <TsqrTypeAdaptor.hpp>
00038 #include <TsqrCommFactory.hpp>
00039 #include <Tsqr_GlobalVerify.hpp>
00040 #include <Teuchos_ScalarTraits.hpp>
00041 
00042 #include <stdexcept>
00043 #include <sstream>
00044 
00045 
00046 namespace TSQR {
00047 
00055   namespace Trilinos {
00056 
00105     template<class S, class LO, class GO, class MV>
00106     class TsqrAdaptor {
00107     public:
00108       typedef S   scalar_type;
00109       typedef LO  local_ordinal_type;
00110       typedef GO  global_ordinal_type;
00111       typedef MV  multivector_type;
00112 
00113       typedef typename Teuchos::ScalarTraits<scalar_type>::magnitudeType magnitude_type;
00114 
00115       typedef TsqrTypeAdaptor<S, LO, GO, MV>        type_adaptor;
00116       typedef typename type_adaptor::factory_type   factory_type;
00117 
00118       typedef typename type_adaptor::node_tsqr_type node_tsqr_type;
00119       typedef typename type_adaptor::node_tsqr_ptr  node_tsqr_ptr;
00120 
00121       typedef typename type_adaptor::comm_type      comm_type;
00122       typedef typename type_adaptor::comm_ptr       comm_ptr;
00123 
00124       typedef typename type_adaptor::dist_tsqr_type dist_tsqr_type;
00125       typedef typename type_adaptor::dist_tsqr_ptr  dist_tsqr_ptr;
00126 
00127       typedef typename type_adaptor::tsqr_type      tsqr_type;
00128       typedef typename type_adaptor::tsqr_ptr       tsqr_ptr;
00129 
00130       typedef typename tsqr_type::FactorOutput      factor_output_type;
00131       typedef Teuchos::SerialDenseMatrix<LO, S>     dense_matrix_type;
00132       typedef Teuchos::RCP< MessengerBase<S> >      scalar_messenger_ptr;
00133       typedef Teuchos::RCP< MessengerBase<LO> >     ordinal_messenger_ptr;
00134 
00136       virtual ~TsqrAdaptor() {}
00137 
00155       void
00156       factorExplicit (multivector_type& A, 
00157           multivector_type& Q, 
00158           dense_matrix_type& R,
00159           const bool contiguousCacheBlocks = false)
00160       {
00161   factor_output_type output = factor (A, R, contiguousCacheBlocks);
00162   explicitQ (A, output, Q, contiguousCacheBlocks);
00163       }
00164 
00201       virtual factor_output_type
00202       factor (multivector_type& A, 
00203         dense_matrix_type& R,
00204         const bool contiguousCacheBlocks = false)
00205       {
00206   local_ordinal_type nrowsLocal, ncols, LDA;
00207   fetchDims (A, nrowsLocal, ncols, LDA);
00208   // This is guaranteed to be _correct_ for any Node type, but
00209   // won't necessary be efficient.  The desired model is that
00210   // A_local requires no copying.
00211   Teuchos::ArrayRCP< scalar_type > A_local = fetchNonConstView (A);
00212 
00213   // Reshape R if necessary.  This operation zeros out all the
00214   // entries of R, which is what we want anyway.
00215   if (R.numRows() != ncols || R.numCols() != ncols)
00216     {
00217       if (0 != R.shape (ncols, ncols))
00218         throw std::runtime_error ("Failed to reshape matrix R");
00219     }
00220   return pTsqr_->factor (nrowsLocal, ncols, A_local.get(), LDA, 
00221              R.values(), R.stride(), contiguousCacheBlocks);
00222       }
00223 
00250       virtual void 
00251       explicitQ (const multivector_type& Q_in, 
00252      const factor_output_type& factorOutput,
00253      multivector_type& Q_out, 
00254      const bool contiguousCacheBlocks = false)
00255       {
00256   using Teuchos::ArrayRCP;
00257 
00258   local_ordinal_type nrowsLocal, ncols_in, LDQ_in;
00259   fetchDims (Q_in, nrowsLocal, ncols_in, LDQ_in);
00260   local_ordinal_type nrowsLocal_out, ncols_out, LDQ_out;
00261   fetchDims (Q_out, nrowsLocal_out, ncols_out, LDQ_out);
00262 
00263   if (nrowsLocal_out != nrowsLocal)
00264     {
00265       std::ostringstream os;
00266       os << "TSQR explicit Q: input Q factor\'s node-local part has a di"
00267         "fferent number of rows (" << nrowsLocal << ") than output Q fac"
00268         "tor\'s node-local part (" << nrowsLocal_out << ").";
00269       throw std::runtime_error (os.str());
00270     }
00271   ArrayRCP< const scalar_type > pQin = fetchConstView (Q_in);
00272   ArrayRCP< scalar_type > pQout = fetchNonConstView (Q_out);
00273   pTsqr_->explicit_Q (nrowsLocal, 
00274           ncols_in, pQin.get(), LDQ_in, 
00275           factorOutput,
00276           ncols_out, pQout.get(), LDQ_out,
00277           contiguousCacheBlocks);
00278       }
00279 
00304       local_ordinal_type
00305       revealRank (multivector_type& Q,
00306       dense_matrix_type& R,
00307       const magnitude_type relativeTolerance,
00308       const bool contiguousCacheBlocks = false) const
00309       {
00310   using Teuchos::ArrayRCP;
00311 
00312   local_ordinal_type nrowsLocal, ncols, ldqLocal;
00313   fetchDims (Q, nrowsLocal, ncols, ldqLocal);
00314 
00315   ArrayRCP< scalar_type > Q_ptr = fetchNonConstView (Q);
00316   return pTsqr_->reveal_rank (nrowsLocal, ncols, 
00317             Q_ptr.get(), ldqLocal,
00318             R.values(), R.stride(), 
00319             relativeTolerance, 
00320             contiguousCacheBlocks);
00321       }
00322 
00333       virtual void 
00334       cacheBlock (const multivector_type& A_in, 
00335       multivector_type& A_out)
00336       {
00337   using Teuchos::ArrayRCP;
00338 
00339   local_ordinal_type nrowsLocal, ncols, LDA_in;
00340   fetchDims (A_in, nrowsLocal, ncols, LDA_in);
00341   local_ordinal_type nrowsLocal_out, ncols_out, LDA_out;
00342   fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out);
00343 
00344   if (nrowsLocal_out != nrowsLocal)
00345     {
00346       std::ostringstream os;
00347       os << "TSQR cache block: the input matrix\'s node-local part has a"
00348         " different number of rows (" << nrowsLocal << ") than the outpu"
00349         "t matrix\'s node-local part (" << nrowsLocal_out << ").";
00350       throw std::runtime_error (os.str());
00351     }
00352   else if (ncols_out != ncols)
00353     {
00354       std::ostringstream os;
00355       os << "TSQR cache block: the input matrix\'s node-local part has a"
00356         " different number of columns (" << ncols << ") than the output "
00357         "matrix\'s node-local part (" << ncols_out << ").";
00358       throw std::runtime_error (os.str());
00359     }
00360   ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in);
00361   ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out);
00362   pTsqr_->cache_block (nrowsLocal, ncols, pA_out.get(), 
00363            pA_in.get(), LDA_in);
00364       }
00365 
00371       virtual void 
00372       unCacheBlock (const multivector_type& A_in, 
00373         multivector_type& A_out)
00374       {
00375   using Teuchos::ArrayRCP;
00376 
00377   local_ordinal_type nrowsLocal, ncols, LDA_in;
00378   fetchDims (A_in, nrowsLocal, ncols, LDA_in);
00379   local_ordinal_type nrowsLocal_out, ncols_out, LDA_out;
00380   fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out);
00381 
00382   if (nrowsLocal_out != nrowsLocal)
00383     {
00384       std::ostringstream os;
00385       os << "TSQR un-cache-block: the input matrix\'s node-local part ha"
00386         "s a different number of rows (" << nrowsLocal << ") than the ou"
00387         "tput matrix\'s node-local part (" << nrowsLocal_out << ").";
00388       throw std::runtime_error (os.str());
00389     }
00390   else if (ncols_out != ncols)
00391     {
00392       std::ostringstream os;
00393       os << "TSQR cache block: the input matrix\'s node-local part has a"
00394         " different number of columns (" << ncols << ") than the output "
00395         "matrix\'s node-local part (" << ncols_out << ").";
00396       throw std::runtime_error (os.str());
00397     }
00398   ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in);
00399   ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out);
00400   pTsqr_->un_cache_block (nrowsLocal, ncols, pA_out.get(), 
00401         LDA_out, pA_in.get());
00402       }
00403       
00415       virtual std::vector< magnitude_type >
00416       verify (const multivector_type& A,
00417         const multivector_type& Q,
00418         const Teuchos::SerialDenseMatrix< local_ordinal_type, scalar_type >& R)
00419       {
00420   using Teuchos::ArrayRCP;
00421 
00422   local_ordinal_type nrowsLocal_A, ncols_A, LDA;
00423   local_ordinal_type nrowsLocal_Q, ncols_Q, LDQ;
00424   fetchDims (A, nrowsLocal_A, ncols_A, LDA);
00425   fetchDims (Q, nrowsLocal_Q, ncols_Q, LDQ);
00426   if (nrowsLocal_A != nrowsLocal_Q)
00427     throw std::runtime_error ("A and Q must have same number of rows");
00428   else if (ncols_A != ncols_Q)
00429     throw std::runtime_error ("A and Q must have same number of columns");
00430   else if (ncols_A != R.numCols())
00431     throw std::runtime_error ("A and R must have same number of columns");
00432   else if (R.numRows() < R.numCols())
00433     throw std::runtime_error ("R must have no fewer rows than columns");
00434 
00435   // Const views suffice for verification
00436   ArrayRCP< const scalar_type > A_ptr = fetchConstView (A);
00437   ArrayRCP< const scalar_type > Q_ptr = fetchConstView (Q);
00438   return global_verify (nrowsLocal_A, ncols_A, A_ptr.get(), LDA,
00439             Q_ptr.get(), LDQ, R.values(), R.stride(), 
00440             pScalarMessenger_.get());
00441       }
00442 
00443     protected:
00444 
00453       void 
00454       init (const multivector_type& mv,
00455       const Teuchos::ParameterList& plist)
00456       {
00457   // This is done in a multivector type - dependent way.
00458   fetchMessengers (mv, pScalarMessenger_, pOrdinalMessenger_);
00459 
00460   factory_type factory;
00461   // plist and pScalarMessenger_ are inputs.  Construct *pTsqr_.
00462   factory.makeTsqr (plist, pScalarMessenger_, pTsqr_);
00463       }
00464 
00465     private:
00481       virtual void 
00482       fetchDims (const multivector_type& A, 
00483      local_ordinal_type& nrowsLocal, 
00484      local_ordinal_type& ncols, 
00485      local_ordinal_type& LDA) const = 0;
00486 
00494       virtual Teuchos::ArrayRCP<scalar_type> 
00495       fetchNonConstView (multivector_type& A) const = 0;
00496 
00504       virtual Teuchos::ArrayRCP<const scalar_type> 
00505       fetchConstView (const multivector_type& A) const = 0;
00506 
00520       virtual void
00521       fetchMessengers (const multivector_type& mv,
00522            scalar_messenger_ptr& pScalarMessenger,
00523            ordinal_messenger_ptr& pOrdinalMessenger) const = 0;
00524 
00526       scalar_messenger_ptr pScalarMessenger_;
00527 
00530       ordinal_messenger_ptr pOrdinalMessenger_;
00531 
00534       tsqr_ptr pTsqr_;
00535     };
00536 
00537   } // namespace Trilinos
00538 } // namespace TSQR
00539 
00540 #endif // __TSQR_Trilinos_TsqrAdaptor_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends