Kokkos Node API and Local Linear Algebra Kernels Version of the Day
TsqrAdaptor.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_Trilinos_TsqrAdaptor_hpp
00043 #define __TSQR_Trilinos_TsqrAdaptor_hpp
00044 
#include <Tsqr_ConfigDefs.hpp>
#include <Teuchos_SerialDenseMatrix.hpp>
#include <TsqrTypeAdaptor.hpp>
#include <TsqrCommFactory.hpp>
#include <Tsqr_GlobalVerify.hpp>
#include <Teuchos_ScalarTraits.hpp>

#include <sstream>
#include <stdexcept>
#include <vector>
00057 
00058 
00059 namespace TSQR {
00060 
00068   namespace Trilinos {
00069 
00118     template<class S, class LO, class GO, class MV>
00119     class TsqrAdaptor {
00120     public:
00121       typedef S   scalar_type;
00122       typedef LO  local_ordinal_type;
00123       typedef GO  global_ordinal_type;
00124       typedef MV  multivector_type;
00125 
00126       typedef typename Teuchos::ScalarTraits<scalar_type>::magnitudeType magnitude_type;
00127 
00128       typedef TsqrTypeAdaptor<S, LO, GO, MV>        type_adaptor;
00129       typedef typename type_adaptor::factory_type   factory_type;
00130 
00131       typedef typename type_adaptor::node_tsqr_type node_tsqr_type;
00132       typedef typename type_adaptor::node_tsqr_ptr  node_tsqr_ptr;
00133 
00134       typedef typename type_adaptor::comm_type      comm_type;
00135       typedef typename type_adaptor::comm_ptr       comm_ptr;
00136 
00137       typedef typename type_adaptor::dist_tsqr_type dist_tsqr_type;
00138       typedef typename type_adaptor::dist_tsqr_ptr  dist_tsqr_ptr;
00139 
00140       typedef typename type_adaptor::tsqr_type      tsqr_type;
00141       typedef typename type_adaptor::tsqr_ptr       tsqr_ptr;
00142 
00143       typedef typename tsqr_type::FactorOutput      factor_output_type;
00144       typedef Teuchos::SerialDenseMatrix<LO, S>     dense_matrix_type;
00145       typedef Teuchos::RCP< MessengerBase<S> >      scalar_messenger_ptr;
00146       typedef Teuchos::RCP< MessengerBase<LO> >     ordinal_messenger_ptr;
00147 
00149       virtual ~TsqrAdaptor() {}
00150 
00168       void
00169       factorExplicit (multivector_type& A, 
00170           multivector_type& Q, 
00171           dense_matrix_type& R,
00172           const bool contiguousCacheBlocks = false)
00173       {
00174   // Lazily init the intranode part of TSQR if necessary.
00175   initNodeTsqr (A);
00176 
00177   factor_output_type output = factor (A, R, contiguousCacheBlocks);
00178   explicitQ (A, output, Q, contiguousCacheBlocks);
00179       }
00180 
00217       virtual factor_output_type
00218       factor (multivector_type& A, 
00219         dense_matrix_type& R,
00220         const bool contiguousCacheBlocks = false)
00221       {
00222   // Lazily init the intranode part of TSQR if necessary.
00223   initNodeTsqr (A);
00224 
00225   local_ordinal_type nrowsLocal, ncols, LDA;
00226   fetchDims (A, nrowsLocal, ncols, LDA);
00227   // This is guaranteed to be _correct_ for any Node type, but
00228   // won't necessary be efficient.  The desired model is that
00229   // A_local requires no copying.
00230   Teuchos::ArrayRCP< scalar_type > A_local = fetchNonConstView (A);
00231 
00232   // Reshape R if necessary.  This operation zeros out all the
00233   // entries of R, which is what we want anyway.
00234   if (R.numRows() != ncols || R.numCols() != ncols)
00235     {
00236       if (0 != R.shape (ncols, ncols))
00237         throw std::runtime_error ("Failed to reshape matrix R");
00238     }
00239   return pTsqr_->factor (nrowsLocal, ncols, A_local.get(), LDA, 
00240              R.values(), R.stride(), contiguousCacheBlocks);
00241       }
00242 
00269       virtual void 
00270       explicitQ (const multivector_type& Q_in, 
00271      const factor_output_type& factorOutput,
00272      multivector_type& Q_out, 
00273      const bool contiguousCacheBlocks = false)
00274       {
00275   using Teuchos::ArrayRCP;
00276 
00277   // Lazily init the intranode part of TSQR if necessary.
00278   initNodeTsqr (Q_in);
00279 
00280   local_ordinal_type nrowsLocal, ncols_in, LDQ_in;
00281   fetchDims (Q_in, nrowsLocal, ncols_in, LDQ_in);
00282   local_ordinal_type nrowsLocal_out, ncols_out, LDQ_out;
00283   fetchDims (Q_out, nrowsLocal_out, ncols_out, LDQ_out);
00284 
00285   if (nrowsLocal_out != nrowsLocal)
00286     {
00287       std::ostringstream os;
00288       os << "TSQR explicit Q: input Q factor\'s node-local part has a di"
00289         "fferent number of rows (" << nrowsLocal << ") than output Q fac"
00290         "tor\'s node-local part (" << nrowsLocal_out << ").";
00291       throw std::runtime_error (os.str());
00292     }
00293   ArrayRCP< const scalar_type > pQin = fetchConstView (Q_in);
00294   ArrayRCP< scalar_type > pQout = fetchNonConstView (Q_out);
00295   pTsqr_->explicit_Q (nrowsLocal, 
00296           ncols_in, pQin.get(), LDQ_in, 
00297           factorOutput,
00298           ncols_out, pQout.get(), LDQ_out,
00299           contiguousCacheBlocks);
00300       }
00301 
00326       local_ordinal_type
00327       revealRank (multivector_type& Q,
00328       dense_matrix_type& R,
00329       const magnitude_type relativeTolerance,
00330       const bool contiguousCacheBlocks = false) const
00331       {
00332   using Teuchos::ArrayRCP;
00333 
00334   // Lazily init the intranode part of TSQR if necessary.
00335   initNodeTsqr (Q);
00336 
00337   local_ordinal_type nrowsLocal, ncols, ldqLocal;
00338   fetchDims (Q, nrowsLocal, ncols, ldqLocal);
00339 
00340   ArrayRCP< scalar_type > Q_ptr = fetchNonConstView (Q);
00341   return pTsqr_->reveal_rank (nrowsLocal, ncols, 
00342             Q_ptr.get(), ldqLocal,
00343             R.values(), R.stride(), 
00344             relativeTolerance, 
00345             contiguousCacheBlocks);
00346       }
00347 
00358       virtual void 
00359       cacheBlock (const multivector_type& A_in, 
00360       multivector_type& A_out)
00361       {
00362   using Teuchos::ArrayRCP;
00363 
00364   // Lazily init the intranode part of TSQR if necessary.
00365   initNodeTsqr (A_in);
00366 
00367   local_ordinal_type nrowsLocal, ncols, LDA_in;
00368   fetchDims (A_in, nrowsLocal, ncols, LDA_in);
00369   local_ordinal_type nrowsLocal_out, ncols_out, LDA_out;
00370   fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out);
00371 
00372   if (nrowsLocal_out != nrowsLocal)
00373     {
00374       std::ostringstream os;
00375       os << "TSQR cache block: the input matrix\'s node-local part has a"
00376         " different number of rows (" << nrowsLocal << ") than the outpu"
00377         "t matrix\'s node-local part (" << nrowsLocal_out << ").";
00378       throw std::runtime_error (os.str());
00379     }
00380   else if (ncols_out != ncols)
00381     {
00382       std::ostringstream os;
00383       os << "TSQR cache block: the input matrix\'s node-local part has a"
00384         " different number of columns (" << ncols << ") than the output "
00385         "matrix\'s node-local part (" << ncols_out << ").";
00386       throw std::runtime_error (os.str());
00387     }
00388   ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in);
00389   ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out);
00390   pTsqr_->cache_block (nrowsLocal, ncols, pA_out.get(), 
00391            pA_in.get(), LDA_in);
00392       }
00393 
00399       virtual void 
00400       unCacheBlock (const multivector_type& A_in, 
00401         multivector_type& A_out)
00402       {
00403   using Teuchos::ArrayRCP;
00404 
00405   // Lazily init the intranode part of TSQR if necessary.
00406   initNodeTsqr (A_in);
00407 
00408   local_ordinal_type nrowsLocal, ncols, LDA_in;
00409   fetchDims (A_in, nrowsLocal, ncols, LDA_in);
00410   local_ordinal_type nrowsLocal_out, ncols_out, LDA_out;
00411   fetchDims (A_out, nrowsLocal_out, ncols_out, LDA_out);
00412 
00413   if (nrowsLocal_out != nrowsLocal)
00414     {
00415       std::ostringstream os;
00416       os << "TSQR un-cache-block: the input matrix\'s node-local part ha"
00417         "s a different number of rows (" << nrowsLocal << ") than the ou"
00418         "tput matrix\'s node-local part (" << nrowsLocal_out << ").";
00419       throw std::runtime_error (os.str());
00420     }
00421   else if (ncols_out != ncols)
00422     {
00423       std::ostringstream os;
00424       os << "TSQR cache block: the input matrix\'s node-local part has a"
00425         " different number of columns (" << ncols << ") than the output "
00426         "matrix\'s node-local part (" << ncols_out << ").";
00427       throw std::runtime_error (os.str());
00428     }
00429   ArrayRCP< const scalar_type > pA_in = fetchConstView (A_in);
00430   ArrayRCP< scalar_type > pA_out = fetchNonConstView (A_out);
00431   pTsqr_->un_cache_block (nrowsLocal, ncols, pA_out.get(), 
00432         LDA_out, pA_in.get());
00433       }
00434       
00446       virtual std::vector< magnitude_type >
00447       verify (const multivector_type& A,
00448         const multivector_type& Q,
00449         const Teuchos::SerialDenseMatrix< local_ordinal_type, scalar_type >& R)
00450       {
00451   using Teuchos::ArrayRCP;
00452 
00453   local_ordinal_type nrowsLocal_A, ncols_A, LDA;
00454   local_ordinal_type nrowsLocal_Q, ncols_Q, LDQ;
00455   fetchDims (A, nrowsLocal_A, ncols_A, LDA);
00456   fetchDims (Q, nrowsLocal_Q, ncols_Q, LDQ);
00457   if (nrowsLocal_A != nrowsLocal_Q)
00458     throw std::runtime_error ("A and Q must have same number of rows");
00459   else if (ncols_A != ncols_Q)
00460     throw std::runtime_error ("A and Q must have same number of columns");
00461   else if (ncols_A != R.numCols())
00462     throw std::runtime_error ("A and R must have same number of columns");
00463   else if (R.numRows() < R.numCols())
00464     throw std::runtime_error ("R must have no fewer rows than columns");
00465 
00466   // Const views suffice for verification
00467   ArrayRCP< const scalar_type > A_ptr = fetchConstView (A);
00468   ArrayRCP< const scalar_type > Q_ptr = fetchConstView (Q);
00469   return global_verify (nrowsLocal_A, ncols_A, A_ptr.get(), LDA,
00470             Q_ptr.get(), LDQ, R.values(), R.stride(), 
00471             pScalarMessenger_.get());
00472       }
00473 
00474     protected:
00475 
00484       void 
00485       init (const multivector_type& mv,
00486       const Teuchos::RCP<Teuchos::ParameterList>& plist)
00487       {
00488   // This is done in a multivector type - dependent way.
00489   fetchMessengers (mv, pScalarMessenger_, pOrdinalMessenger_);
00490 
00491   factory_type factory;
00492   // plist and pScalarMessenger_ are inputs.  Construct *pTsqr_.
00493   factory.makeTsqr (plist, pScalarMessenger_, pTsqr_);
00494       }
00495 
00496       // Lazily init the intranode part of TSQR if necessary.
00497       virtual void initNodeTsqr (const multivector_type& A);
00498 
00499     private:
00515       virtual void 
00516       fetchDims (const multivector_type& A, 
00517      local_ordinal_type& nrowsLocal, 
00518      local_ordinal_type& ncols, 
00519      local_ordinal_type& LDA) const = 0;
00520 
00528       virtual Teuchos::ArrayRCP<scalar_type> 
00529       fetchNonConstView (multivector_type& A) const = 0;
00530 
00538       virtual Teuchos::ArrayRCP<const scalar_type> 
00539       fetchConstView (const multivector_type& A) const = 0;
00540 
00554       virtual void
00555       fetchMessengers (const multivector_type& mv,
00556            scalar_messenger_ptr& pScalarMessenger,
00557            ordinal_messenger_ptr& pOrdinalMessenger) const = 0;
00558 
00560       scalar_messenger_ptr pScalarMessenger_;
00561 
00564       ordinal_messenger_ptr pOrdinalMessenger_;
00565 
00568       tsqr_ptr pTsqr_;
00569     };
00570 
00571   } // namespace Trilinos
00572 } // namespace TSQR
00573 
00574 #endif // __TSQR_Trilinos_TsqrAdaptor_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends