Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_SequentialCholeskyQR.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_Tsqr_SequentialCholeskyQR_hpp
00043 #define __TSQR_Tsqr_SequentialCholeskyQR_hpp
00044 
#include <Tsqr_MatView.hpp>
#include <Tsqr_Blas.hpp>
#include <Tsqr_Lapack.hpp>
#include <Tsqr_CacheBlockingStrategy.hpp>
#include <Tsqr_CacheBlocker.hpp>
#include <Tsqr_ScalarTraits.hpp>
#include <Tsqr_Util.hpp>

#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
00056 
00059 
00060 namespace TSQR {
00061 
00071   template<class LocalOrdinal, class Scalar>
00072   class SequentialCholeskyQR {
00073   private:
00074     typedef MatView< LocalOrdinal, Scalar > mat_view;
00075     typedef ConstMatView< LocalOrdinal, Scalar > const_mat_view;
00076 
00077   public:
00078     typedef Scalar scalar_type;
00079     typedef LocalOrdinal ordinal_type;
00080 
00087     typedef int FactorOutput;
00088 
00093     size_t TEUCHOS_DEPRECATED cache_block_size () const { 
00094       return strategy_.cache_size_hint(); 
00095     }
00096 
00098     size_t cache_size_hint () const { return strategy_.cache_size_hint(); }
00099 
00105     SequentialCholeskyQR (const size_t theCacheSizeHint = 0) :
00106       strategy_ (theCacheSizeHint)
00107     {}
00108 
00116     bool QR_produces_R_factor_with_nonnegative_diagonal () const {
00117       return true;
00118     }
00119 
00126     FactorOutput
00127     factor (const LocalOrdinal nrows,
00128       const LocalOrdinal ncols,
00129       const Scalar A[],
00130       const LocalOrdinal lda, 
00131       Scalar R[],
00132       const LocalOrdinal ldr,
00133       const bool contiguous_cache_blocks = false)
00134     {
00135       CacheBlocker< LocalOrdinal, Scalar > blocker (nrows, ncols, strategy_);
00136       LAPACK< LocalOrdinal, Scalar > lapack;
00137       BLAS< LocalOrdinal, Scalar > blas;
00138       std::vector< Scalar > work (ncols);
00139       Matrix< LocalOrdinal, Scalar > ATA (ncols, ncols, Scalar(0));
00140       FactorOutput retval (0);
00141 
00142       if (contiguous_cache_blocks)
00143   {
00144     // Compute ATA := A^T * A, by iterating through the cache
00145     // blocks of A from top to bottom.
00146     //
00147     // We say "A_rest" because it points to the remaining part of
00148     // the matrix left to process; at the beginning, the "remaining"
00149     // part is the whole matrix, but that will change as the
00150     // algorithm progresses.
00151     mat_view A_rest (nrows, ncols, A, lda);
00152     // This call modifies A_rest (but not the actual matrix
00153     // entries; just the dimensions and current position).
00154     mat_view A_cur = blocker.split_top_block (A_rest, contiguous_cache_blocks);
00155     // Process the first cache block: ATA := A_cur^T * A_cur
00156     blas.GEMM ("T", "N", ncols, ncols, A_cur.nrows(), 
00157          Scalar(1), A_cur.get(), A_cur.lda(), A_cur.get(), A_cur.lda(),
00158          Scalar(0), ATA.get(), ATA.lda());
00159     // Process the remaining cache blocks in order.
00160     while (! A_rest.empty())
00161       {
00162         A_cur = blocker.split_top_block (A_rest, contiguous_cache_blocks);
00163         // ATA := ATA + A_cur^T * A_cur
00164         blas.GEMM ("T", "N", ncols, ncols, A_cur.nrows(), 
00165        Scalar(1), A_cur.get(), A_cur.lda(), A_cur.get(), A_cur.lda(),
00166        Scalar(1), ATA.get(), ATA.lda());
00167       }
00168   }
00169       else
00170   // Compute ATA := A^T * A, using a single BLAS call.
00171   blas.GEMM ("T", "N", ncols, ncols, nrows, 
00172        Scalar(1), A, lda, A, lda,
00173        Scalar(0), ATA.get(), ATA.lda());
00174 
00175       // Compute the Cholesky factorization of ATA in place, so that
00176       // A^T * A = R^T * R, where R is ncols by ncols upper
00177       // triangular.
00178       int info = 0;
00179       lapack.POTRF ("U", ncols, ATA.get(), ATA.lda(), &info);
00180       // FIXME (mfh 22 June 2010) The right thing to do here would be
00181       // to resort to a rank-revealing factorization, as Stathopoulos
00182       // and Wu (2002) do with their CholeskyQR + symmetric
00183       // eigensolver factorization.
00184       if (info != 0)
00185   throw std::runtime_error("Cholesky factorization failed");
00186 
00187       // Copy out the R factor
00188       fill_matrix (ncols, ncols, R, ldr, Scalar(0));
00189       copy_upper_triangle (ncols, ncols, R, ldr, ATA.get(), ATA.lda());
00190 
00191       // Compute A := A * R^{-1}.  We do this in place in A, using
00192       // BLAS' TRSM with the R factor (form POTRF) stored in the upper
00193       // triangle of ATA.
00194       {
00195   mat_view A_rest (nrows, ncols, A, lda);
00196   // This call modifies A_rest.
00197   mat_view A_cur = blocker.split_top_block (A_rest, contiguous_cache_blocks);
00198 
00199   // Compute A_cur / R (Matlab notation for A_cur * R^{-1}) in place.
00200   blas.TRSM ("R", "U", "N", "N", A_cur.nrows(), ncols, 
00201        Scalar(1), ATA.get(), ATA.lda(), A_cur.get(), A_cur.lda());
00202 
00203   // Process the remaining cache blocks in order.
00204   while (! A_rest.empty())
00205     {
00206       A_cur = blocker.split_top_block (A_rest, contiguous_cache_blocks);
00207       blas.TRSM ("R", "U", "N", "N", A_cur.nrows(), ncols, 
00208            Scalar(1), ATA.get(), ATA.lda(), A_cur.get(), A_cur.lda());
00209     }
00210       }
00211 
00212       return retval;
00213     }
00214 
00217     void
00218     explicit_Q (const LocalOrdinal nrows,
00219     const LocalOrdinal ncols_Q,
00220     const Scalar Q[],
00221     const LocalOrdinal ldq,
00222     const FactorOutput& factor_output,
00223     const LocalOrdinal ncols_C,
00224     Scalar C[],
00225     const LocalOrdinal ldc,
00226     const bool contiguous_cache_blocks = false)
00227     {
00228       if (ncols_Q != ncols_C)
00229   throw std::logic_error("SequentialCholeskyQR::explicit_Q() "
00230              "does not work if ncols_C != ncols_Q");
00231       const LocalOrdinal ncols = ncols_Q;
00232 
00233       if (contiguous_cache_blocks)
00234   {
00235     CacheBlocker< LocalOrdinal, Scalar > blocker (nrows, ncols, strategy_);
00236     mat_view C_rest (nrows, ncols, C, ldc);
00237     const_mat_view Q_rest (nrows, ncols, Q, ldq);
00238 
00239     mat_view C_cur = blocker.split_top_block (C_rest, contiguous_cache_blocks);
00240     const_mat_view Q_cur = blocker.split_top_block (Q_rest, contiguous_cache_blocks);
00241 
00242     while (! C_rest.empty())
00243       Q_cur.copy (C_cur);
00244   }
00245       else
00246   {
00247     mat_view C_view (nrows, ncols, C, ldc);
00248     C_view.copy (const_mat_view (nrows, ncols, Q, ldq));
00249   }
00250     }
00251 
00252 
00254     void
00255     cache_block (const LocalOrdinal nrows,
00256      const LocalOrdinal ncols,
00257      Scalar A_out[],
00258      const Scalar A_in[],
00259      const LocalOrdinal lda_in) const
00260     {
00261       CacheBlocker< LocalOrdinal, Scalar > blocker (nrows, ncols, strategy_);
00262       blocker.cache_block (nrows, ncols, A_out, A_in, lda_in);
00263     }
00264 
00265 
00267     void
00268     un_cache_block (const LocalOrdinal nrows,
00269         const LocalOrdinal ncols,
00270         Scalar A_out[],
00271         const LocalOrdinal lda_out,       
00272         const Scalar A_in[]) const
00273     {
00274       CacheBlocker< LocalOrdinal, Scalar > blocker (nrows, ncols, strategy_);
00275       blocker.un_cache_block (nrows, ncols, A_out, lda_out, A_in);
00276     }
00277 
00279     void
00280     fill_with_zeros (const LocalOrdinal nrows,
00281          const LocalOrdinal ncols,
00282          Scalar A[],
00283          const LocalOrdinal lda, 
00284          const bool contiguous_cache_blocks = false)
00285     {
00286       CacheBlocker< LocalOrdinal, Scalar > blocker (nrows, ncols, strategy_);
00287       blocker.fill_with_zeros (nrows, ncols, A, lda, contiguous_cache_blocks);
00288     }
00289 
00297     template< class MatrixViewType >
00298     MatrixViewType
00299     top_block (const MatrixViewType& C, 
00300          const bool contiguous_cache_blocks = false) const 
00301     {
00302       // The CacheBlocker object knows how to construct a view of the
00303       // top cache block of C.  This is complicated because cache
00304       // blocks (in C) may or may not be stored contiguously.  If they
00305       // are stored contiguously, the CacheBlocker knows the right
00306       // layout, based on the cache blocking strategy.
00307       CacheBlocker< LocalOrdinal, Scalar > blocker (C.nrows(), C.ncols(), strategy_);
00308 
00309       // C_top_block is a view of the topmost cache block of C.
00310       // C_top_block should have >= ncols rows, otherwise either cache
00311       // blocking is broken or the input matrix C itself had fewer
00312       // rows than columns.
00313       MatrixViewType C_top_block = blocker.top_block (C, contiguous_cache_blocks);
00314       if (C_top_block.nrows() < C_top_block.ncols())
00315   throw std::logic_error ("C\'s topmost cache block has fewer rows than "
00316         "columns");
00317       return C_top_block;
00318     }
00319 
00320   private:
00321     CacheBlockingStrategy< LocalOrdinal, Scalar > strategy_;
00322   };
00323   
00324 } // namespace TSQR
00325 
00326 #endif // __TSQR_Tsqr_SequentialCholeskyQR_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends