Anasazi Version of the Day
Tsqr_CombineNative.hpp
00001 // @HEADER
00002 // ***********************************************************************
00003 //
00004 //                 Anasazi: Block Eigensolvers Package
00005 //                 Copyright (2010) Sandia Corporation
00006 //
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 //
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00025 //
00026 // ***********************************************************************
00027 // @HEADER
00028 #ifndef __TSQR_CombineNative_hpp
00029 #define __TSQR_CombineNative_hpp
00030 
00031 #include "Tsqr_ApplyType.hpp"
00032 #include "Tsqr_ScalarTraits.hpp"
00033 #include "Tsqr_Blas.hpp"
00034 #include "Tsqr_Lapack.hpp"
00035 #include "Tsqr_CombineDefault.hpp"
00036 
00037 // #define TSQR_COMBINE_NATIVE_DEBUG 1
00038 #ifdef TSQR_COMBINE_NATIVE_DEBUG
00039 #include "Tsqr_Util.hpp"
00040 #include <iostream>
00041 using std::cerr;
00042 using std::endl;
00043 #endif // TSQR_COMBINE_NATIVE_DEBUG
00044 
00047 
00048 namespace TSQR {
00049 
00064   template< class Ordinal, class Scalar, bool isComplex = ScalarTraits< Scalar >::is_complex >
00065   class CombineNative 
00066   {
00067   public:
00068     typedef Scalar scalar_type;
00069     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00070     typedef Ordinal ordinal_type;
00071 
00072   private:
00073     typedef BLAS< ordinal_type, scalar_type > blas_type;
00074     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00075     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00076 
00077   public:
00078 
00079     CombineNative () {}
00080 
00088     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00089       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00090   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00091     }
00092 
00093     void
00094     factor_first (const Ordinal nrows,
00095       const Ordinal ncols,
00096       Scalar A[],
00097       const Ordinal lda,
00098       Scalar tau[],
00099       Scalar work[]) const
00100     {
00101       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00102     }
00103 
00104     void
00105     apply_first (const ApplyType& applyType,
00106      const Ordinal nrows,
00107      const Ordinal ncols_C,
00108      const Ordinal ncols_A,
00109      const Scalar A[],
00110      const Ordinal lda,
00111      const Scalar tau[],
00112      Scalar C[],
00113      const Ordinal ldc,
00114      Scalar work[]) const
00115     {
00116       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00117            A, lda, tau, C, ldc, work);
00118     }
00119 
00120     void
00121     apply_inner (const ApplyType& applyType,
00122      const Ordinal m,
00123      const Ordinal ncols_C,
00124      const Ordinal ncols_Q,
00125      const Scalar A[],
00126      const Ordinal lda,
00127      const Scalar tau[],
00128      Scalar C_top[],
00129      const Ordinal ldc_top,
00130      Scalar C_bot[],
00131      const Ordinal ldc_bot,
00132      Scalar work[]) const;
00133 
00134     void
00135     factor_inner (const Ordinal m,
00136       const Ordinal n,
00137       Scalar R[],
00138       const Ordinal ldr,
00139       Scalar A[],
00140       const Ordinal lda,
00141       Scalar tau[],
00142       Scalar work[]) const;
00143 
00144     void
00145     factor_pair (const Ordinal n,
00146      Scalar R_top[],
00147      const Ordinal ldr_top,
00148      Scalar R_bot[],
00149      const Ordinal ldr_bot,
00150      Scalar tau[],
00151      Scalar work[]) const;
00152 
00153     void
00154     apply_pair (const ApplyType& applyType,
00155     const Ordinal ncols_C,
00156     const Ordinal ncols_Q,
00157     const Scalar R_bot[],
00158     const Ordinal ldr_bot,
00159     const Scalar tau[],
00160     Scalar C_top[],
00161     const Ordinal ldc_top,
00162     Scalar C_bot[],
00163     const Ordinal ldc_bot,
00164     Scalar work[]) const;
00165 
00166   private:
00167     mutable combine_default_type default_;
00168   };
00169      
00170 
00173   template< class Ordinal, class Scalar >
00174   class CombineNative< Ordinal, Scalar, false >
00175   {
00176   public:
00177     typedef Scalar scalar_type;
00178     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00179     typedef Ordinal ordinal_type;
00180 
00181   private:
00182     typedef BLAS< ordinal_type, scalar_type > blas_type;
00183     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00184     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00185 
00186   public:
00187     CombineNative () {}
00188 
00189     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00190       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00191   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00192     }
00193 
00194     void
00195     factor_first (const Ordinal nrows,
00196       const Ordinal ncols,
00197       Scalar A[],
00198       const Ordinal lda,
00199       Scalar tau[],
00200       Scalar work[]) const
00201     {
00202       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00203     }
00204 
00205     void
00206     apply_first (const ApplyType& applyType,
00207      const Ordinal nrows,
00208      const Ordinal ncols_C,
00209      const Ordinal ncols_A,
00210      const Scalar A[],
00211      const Ordinal lda,
00212      const Scalar tau[],
00213      Scalar C[],
00214      const Ordinal ldc,
00215      Scalar work[]) const
00216     {
00217       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00218            A, lda, tau, C, ldc, work);
00219     }
00220 
00221     void
00222     apply_inner (const ApplyType& applyType,
00223      const Ordinal m,
00224      const Ordinal ncols_C,
00225      const Ordinal ncols_Q,
00226      const Scalar A[],
00227      const Ordinal lda,
00228      const Scalar tau[],
00229      Scalar C_top[],
00230      const Ordinal ldc_top,
00231      Scalar C_bot[],
00232      const Ordinal ldc_bot,
00233      Scalar work[]) const;
00234 
00235     void
00236     factor_inner (const Ordinal m,
00237       const Ordinal n,
00238       Scalar R[],
00239       const Ordinal ldr,
00240       Scalar A[],
00241       const Ordinal lda,
00242       Scalar tau[],
00243       Scalar work[]) const;
00244 
00245     void
00246     factor_pair (const Ordinal n,
00247      Scalar R_top[],
00248      const Ordinal ldr_top,
00249      Scalar R_bot[],
00250      const Ordinal ldr_bot,
00251      Scalar tau[],
00252      Scalar work[]) const;
00253 
00254     void
00255     apply_pair (const ApplyType& applyType,
00256     const Ordinal ncols_C,
00257     const Ordinal ncols_Q,
00258     const Scalar R_bot[],
00259     const Ordinal ldr_bot,
00260     const Scalar tau[],
00261     Scalar C_top[],
00262     const Ordinal ldc_top,
00263     Scalar C_bot[],
00264     const Ordinal ldc_bot,
00265     Scalar work[]) const;
00266 
00267   private:
00268     mutable combine_default_type default_;
00269   };
00270 
00271 
00274   template< class Ordinal, class Scalar >
00275   class CombineNative< Ordinal, Scalar, true >
00276   {
00277   public:
00278     typedef Scalar scalar_type;
00279     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00280     typedef Ordinal ordinal_type;
00281 
00282   private:
00283     typedef BLAS< ordinal_type, scalar_type > blas_type;
00284     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00285     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00286 
00287   public:
00288     CombineNative () {}
00289 
00290     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00291       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00292   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00293     }
00294 
00295     void
00296     factor_first (const Ordinal nrows,
00297       const Ordinal ncols,
00298       Scalar A[],
00299       const Ordinal lda,
00300       Scalar tau[],
00301       Scalar work[]) const
00302     {
00303       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00304     }
00305 
00306     void
00307     apply_first (const ApplyType& applyType,
00308      const Ordinal nrows,
00309      const Ordinal ncols_C,
00310      const Ordinal ncols_A,
00311      const Scalar A[],
00312      const Ordinal lda,
00313      const Scalar tau[],
00314      Scalar C[],
00315      const Ordinal ldc,
00316      Scalar work[]) const
00317     {
00318       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00319            A, lda, tau, C, ldc, work);
00320     }
00321 
00322     void
00323     apply_inner (const ApplyType& applyType,
00324      const Ordinal m,
00325      const Ordinal ncols_C,
00326      const Ordinal ncols_Q,
00327      const Scalar A[],
00328      const Ordinal lda,
00329      const Scalar tau[],
00330      Scalar C_top[],
00331      const Ordinal ldc_top,
00332      Scalar C_bot[],
00333      const Ordinal ldc_bot,
00334      Scalar work[]) const
00335     {
00336       return default_.apply_inner (applyType, m, ncols_C, ncols_Q,
00337            A, lda, tau,
00338            C_top, ldc_top, C_bot, ldc_bot,
00339            work);
00340     }
00341 
00342     void
00343     factor_inner (const Ordinal m,
00344       const Ordinal n,
00345       Scalar R[],
00346       const Ordinal ldr,
00347       Scalar A[],
00348       const Ordinal lda,
00349       Scalar tau[],
00350       Scalar work[]) const
00351     {
00352       return default_.factor_inner (m, n, R, ldr, A, lda, tau, work);
00353     }
00354 
00355     void
00356     factor_pair (const Ordinal n,
00357      Scalar R_top[],
00358      const Ordinal ldr_top,
00359      Scalar R_bot[],
00360      const Ordinal ldr_bot,
00361      Scalar tau[],
00362      Scalar work[]) const
00363     {
00364       return default_.factor_pair (n, R_top, ldr_top, R_bot, ldr_bot, tau, work);
00365     }
00366 
00367     void
00368     apply_pair (const ApplyType& applyType,
00369     const Ordinal ncols_C,
00370     const Ordinal ncols_Q,
00371     const Scalar R_bot[],
00372     const Ordinal ldr_bot,
00373     const Scalar tau[],
00374     Scalar C_top[],
00375     const Ordinal ldc_top,
00376     Scalar C_bot[],
00377     const Ordinal ldc_bot,
00378     Scalar work[]) const
00379     {
00380       return default_.apply_pair (applyType, ncols_C, ncols_Q,
00381           R_bot, ldr_bot, tau,
00382           C_top, ldc_top, C_bot, ldc_bot,
00383           work);
00384     }
00385 
00386   private:
00387     mutable combine_default_type default_;
00388   };
00389      
00390 
00391   template< class Ordinal, class Scalar >
00392   void
00393   CombineNative< Ordinal, Scalar, false >::
00394   factor_inner (const Ordinal m,
00395     const Ordinal n,
00396     Scalar R[],
00397     const Ordinal ldr,
00398     Scalar A[],
00399     const Ordinal lda,
00400     Scalar tau[],
00401     Scalar work[]) const
00402   {
00403     const Scalar ZERO(0), ONE(1);
00404     lapack_type lapack;
00405     blas_type blas;
00406 
00407     for (Ordinal k = 0; k < n; ++k)
00408       work[k] = ZERO;
00409 
00410     for (Ordinal k = 0; k < n-1; ++k)
00411       {
00412   Scalar& R_kk = R[ k + k * ldr ];
00413   Scalar* const A_1k = &A[ 0 + k * lda ];
00414   Scalar* const A_1kp1 = &A[ 0 + (k+1) * lda ];
00415 
00416   lapack.LARFP (m + 1, R_kk, A_1k, 1, tau[k]);
00417   blas.GEMV ("T", m, n-k-1, ONE, A_1kp1, lda, A_1k, 1, ZERO, work, 1);
00418 
00419   for (Ordinal j = k+1; j < n; ++j)
00420     {
00421       Scalar& R_kj = R[ k + j*ldr ];
00422 
00423       work[j-k-1] += R_kj;
00424       R_kj -= tau[k] * work[j-k-1];
00425     }
00426   blas.GER (m, n-k-1, -tau[k], A_1k, 1, work, 1, A_1kp1, lda);
00427       }
00428     Scalar& R_nn = R[ (n-1) + (n-1) * ldr ];
00429     Scalar* const A_1n = &A[ 0 + (n-1) * lda ];
00430 
00431     lapack.LARFP (m+1, R_nn, A_1n, 1, tau[n-1]);
00432   }
00433 
00434 
00435   template< class Ordinal, class Scalar >
00436   void
00437   CombineNative< Ordinal, Scalar, false >::
00438   apply_inner (const ApplyType& applyType,
00439          const Ordinal m,
00440          const Ordinal ncols_C,
00441          const Ordinal ncols_Q,
00442          const Scalar A[],
00443          const Ordinal lda,
00444          const Scalar tau[],
00445          Scalar C_top[],
00446          const Ordinal ldc_top,
00447          Scalar C_bot[],
00448          const Ordinal ldc_bot,
00449          Scalar work[]) const
00450   {
00451     const Scalar ZERO(0);
00452     blas_type blas;
00453 
00454     //Scalar* const y = work;
00455     for (Ordinal i = 0; i < ncols_C; ++i)
00456       work[i] = ZERO;
00457     
00458     Ordinal j_start, j_end, j_step;
00459     if (applyType == ApplyType::NoTranspose)
00460       {
00461   j_start = ncols_Q - 1;
00462   j_end = -1; // exclusive
00463   j_step = -1;
00464       }
00465     else
00466       {
00467   j_start = 0; 
00468   j_end = ncols_Q; // exclusive
00469   j_step = +1;
00470       }
00471     for (Ordinal j = j_start; j != j_end; j += j_step)
00472       {
00473   const Scalar* const A_1j = &A[ 0 + j*lda ];
00474 
00475   //blas.GEMV ("T", m, ncols_C, ONE, C_bot, ldc_bot, A_1j, 1, ZERO, &y[0], 1);  
00476   for (Ordinal i = 0; i < ncols_C; ++i)
00477     {
00478       work[i] = ZERO;
00479       for (Ordinal k = 0; k < m; ++k)
00480         work[i] += A_1j[k] * C_bot[ k + i*ldc_bot ];
00481 
00482       work[i] += C_top[ j + i*ldc_top ];
00483     }
00484   for (Ordinal k = 0; k < ncols_C; ++k)
00485     C_top[ j + k*ldc_top ] -= tau[j] * work[k];
00486 
00487   blas.GER (m, ncols_C, -tau[j], A_1j, 1, work, 1, C_bot, ldc_bot);
00488       }
00489   }
00490 
00491 
00492   template< class Ordinal, class Scalar >
00493   void
00494   CombineNative< Ordinal, Scalar, false >::
00495   factor_pair (const Ordinal n,
00496          Scalar R_top[],
00497          const Ordinal ldr_top,
00498          Scalar R_bot[],
00499          const Ordinal ldr_bot,
00500          Scalar tau[],
00501          Scalar work[]) const
00502   {
00503     const Scalar ZERO(0), ONE(1);
00504     lapack_type lapack;
00505     blas_type blas;
00506 
00507     for (Ordinal k = 0; k < n; ++k)
00508       work[k] = ZERO;
00509 
00510     for (Ordinal k = 0; k < n-1; ++k)
00511       {
00512   Scalar& R_top_kk = R_top[ k + k * ldr_top ];
00513   Scalar* const R_bot_1k = &R_bot[ 0 + k * ldr_bot ];
00514   Scalar* const R_bot_1kp1 = &R_bot[ 0 + (k+1) * ldr_bot ];
00515 
00516   // k+2: 1 element in R_top (R_top(k,k)), and k+1 elements in
00517   // R_bot (R_bot(1:k,k), in 1-based indexing notation).
00518   lapack.LARFP (k+2, R_top_kk, R_bot_1k, 1, tau[k]);
00519   // One-based indexing, Matlab version of the GEMV call below:
00520   // work(1:k) := R_bot(1:k,k+1:n)' * R_bot(1:k,k) 
00521   blas.GEMV ("T", k+1, n-k-1, ONE, R_bot_1kp1, ldr_bot, R_bot_1k, 1, ZERO, work, 1);
00522   
00523   for (Ordinal j = k+1; j < n; ++j)
00524     {
00525       Scalar& R_top_kj = R_top[ k + j*ldr_top ];
00526       work[j-k-1] += R_top_kj;
00527       R_top_kj -= tau[k] * work[j-k-1];
00528     }
00529   blas.GER (k+1, n-k-1, -tau[k], R_bot_1k, 1, work, 1, R_bot_1kp1, ldr_bot);
00530       }
00531     Scalar& R_top_nn = R_top[ (n-1) + (n-1)*ldr_top ];
00532     Scalar* const R_bot_1n = &R_bot[ 0 + (n-1)*ldr_bot ];
00533 
00534     // n+1: 1 element in R_top (n,n), and n elements in R_bot (the
00535     // whole last column).
00536     lapack.LARFP (n+1, R_top_nn, R_bot_1n, 1, tau[n-1]);
00537   }
00538 
00539 
00540   template< class Ordinal, class Scalar >
00541   void
00542   CombineNative< Ordinal, Scalar, false >::
00543   apply_pair (const ApplyType& applyType,
00544         const Ordinal ncols_C,
00545         const Ordinal ncols_Q,
00546         const Scalar R_bot[],
00547         const Ordinal ldr_bot,
00548         const Scalar tau[],
00549         Scalar C_top[],
00550         const Ordinal ldc_top,
00551         Scalar C_bot[],
00552         const Ordinal ldc_bot,
00553         Scalar work[]) const
00554   {
00555     const Scalar ZERO(0);
00556     blas_type blas;
00557 
00558     for (Ordinal i = 0; i < ncols_C; ++i)
00559       work[i] = ZERO;
00560     
00561     Ordinal j_start, j_end, j_step;
00562     if (applyType == ApplyType::NoTranspose)
00563       {
00564   j_start = ncols_Q - 1;
00565   j_end = -1; // exclusive
00566   j_step = -1;
00567       }
00568     else
00569       {
00570   j_start = 0; 
00571   j_end = ncols_Q; // exclusive
00572   j_step = +1;
00573       }
00574     for (Ordinal j_Q = j_start; j_Q != j_end; j_Q += j_step)
00575       { // Using Householder reflector stored in column j_Q of R_bot
00576   const Scalar* const R_bot_col = &R_bot[ 0 + j_Q*ldr_bot ];
00577 
00578   // In 1-based indexing notation, with k in 1, 2, ..., ncols_C
00579   // (inclusive): (Output is length ncols_C row vector)
00580   //
00581   // work(1:j) := R_bot(1:j,j)' * C_bot(1:j, 1:ncols_C) - C_top(j, 1:ncols_C)
00582   for (Ordinal j_C = 0; j_C < ncols_C; ++j_C)
00583     { // For each column j_C of [C_top; C_bot], update row j_Q
00584       // of C_top and rows 1:j_Q of C_bot.  (Again, this is in
00585       // 1-based indexing notation.
00586 
00587       Scalar work_j_C = ZERO;
00588       const Scalar* const C_bot_col = &C_bot[ 0 + j_C*ldc_bot ];
00589 
00590       for (Ordinal k = 0; k <= j_Q; ++k)
00591         work_j_C += R_bot_col[k] * C_bot_col[k];
00592 
00593       work_j_C += C_top[ j_Q + j_C*ldc_top ];
00594       work[j_C] = work_j_C;
00595     }
00596   for (Ordinal j_C = 0; j_C < ncols_C; ++j_C)
00597     C_top[ j_Q + j_C*ldc_top ] -= tau[j_Q] * work[j_C];
00598 
00599   blas.GER (j_Q+1, ncols_C, -tau[j_Q], R_bot_col, 1, work, 1, C_bot, ldc_bot);
00600       }
00601   }
00602 } // namespace TSQR
00603 
00604 
00605 
00606 #endif // __TSQR_CombineNative_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends