Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_CombineNative.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00032 #ifndef __TSQR_CombineNative_hpp
00033 #define __TSQR_CombineNative_hpp
00034 
00035 #include "Teuchos_ScalarTraits.hpp"
00036 
00037 #include "Tsqr_ApplyType.hpp"
00038 #include "Tsqr_Blas.hpp"
00039 #include "Tsqr_Lapack.hpp"
00040 #include "Tsqr_CombineDefault.hpp"
00041 
00042 
00043 namespace TSQR {
00044 
00060   template< class Ordinal, class Scalar, bool isComplex = Teuchos::ScalarTraits< Scalar >::isComplex >
00061   class CombineNative 
00062   {
00063   public:
00064     typedef Scalar scalar_type;
00065     typedef typename Teuchos::ScalarTraits< Scalar >::magnitudeType magnitude_type;
00066     typedef Ordinal ordinal_type;
00067 
00068   private:
00069     typedef BLAS< ordinal_type, scalar_type > blas_type;
00070     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00071     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00072 
00073   public:
00074 
00075     CombineNative () {}
00076 
00084     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00085       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00086   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00087     }
00088 
00089     void
00090     factor_first (const Ordinal nrows,
00091       const Ordinal ncols,
00092       Scalar A[],
00093       const Ordinal lda,
00094       Scalar tau[],
00095       Scalar work[]) const
00096     {
00097       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00098     }
00099 
00100     void
00101     apply_first (const ApplyType& applyType,
00102      const Ordinal nrows,
00103      const Ordinal ncols_C,
00104      const Ordinal ncols_A,
00105      const Scalar A[],
00106      const Ordinal lda,
00107      const Scalar tau[],
00108      Scalar C[],
00109      const Ordinal ldc,
00110      Scalar work[]) const
00111     {
00112       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00113            A, lda, tau, C, ldc, work);
00114     }
00115 
00116     void
00117     apply_inner (const ApplyType& applyType,
00118      const Ordinal m,
00119      const Ordinal ncols_C,
00120      const Ordinal ncols_Q,
00121      const Scalar A[],
00122      const Ordinal lda,
00123      const Scalar tau[],
00124      Scalar C_top[],
00125      const Ordinal ldc_top,
00126      Scalar C_bot[],
00127      const Ordinal ldc_bot,
00128      Scalar work[]) const;
00129 
00130     void
00131     factor_inner (const Ordinal m,
00132       const Ordinal n,
00133       Scalar R[],
00134       const Ordinal ldr,
00135       Scalar A[],
00136       const Ordinal lda,
00137       Scalar tau[],
00138       Scalar work[]) const;
00139 
00140     void
00141     factor_pair (const Ordinal n,
00142      Scalar R_top[],
00143      const Ordinal ldr_top,
00144      Scalar R_bot[],
00145      const Ordinal ldr_bot,
00146      Scalar tau[],
00147      Scalar work[]) const;
00148 
00149     void
00150     apply_pair (const ApplyType& applyType,
00151     const Ordinal ncols_C,
00152     const Ordinal ncols_Q,
00153     const Scalar R_bot[],
00154     const Ordinal ldr_bot,
00155     const Scalar tau[],
00156     Scalar C_top[],
00157     const Ordinal ldc_top,
00158     Scalar C_bot[],
00159     const Ordinal ldc_bot,
00160     Scalar work[]) const;
00161 
00162   private:
00163     mutable combine_default_type default_;
00164   };
00165      
00166 
00169   template< class Ordinal, class Scalar >
00170   class CombineNative< Ordinal, Scalar, false >
00171   {
00172   public:
00173     typedef Scalar scalar_type;
00174     typedef typename Teuchos::ScalarTraits< Scalar >::magnitudeType magnitude_type;
00175     typedef Ordinal ordinal_type;
00176 
00177   private:
00178     typedef BLAS< ordinal_type, scalar_type > blas_type;
00179     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00180     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00181 
00182   public:
00183     CombineNative () {}
00184 
00185     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00186       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00187   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00188     }
00189 
00190     void
00191     factor_first (const Ordinal nrows,
00192       const Ordinal ncols,
00193       Scalar A[],
00194       const Ordinal lda,
00195       Scalar tau[],
00196       Scalar work[]) const
00197     {
00198       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00199     }
00200 
00201     void
00202     apply_first (const ApplyType& applyType,
00203      const Ordinal nrows,
00204      const Ordinal ncols_C,
00205      const Ordinal ncols_A,
00206      const Scalar A[],
00207      const Ordinal lda,
00208      const Scalar tau[],
00209      Scalar C[],
00210      const Ordinal ldc,
00211      Scalar work[]) const
00212     {
00213       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00214            A, lda, tau, C, ldc, work);
00215     }
00216 
00217     void
00218     apply_inner (const ApplyType& applyType,
00219      const Ordinal m,
00220      const Ordinal ncols_C,
00221      const Ordinal ncols_Q,
00222      const Scalar A[],
00223      const Ordinal lda,
00224      const Scalar tau[],
00225      Scalar C_top[],
00226      const Ordinal ldc_top,
00227      Scalar C_bot[],
00228      const Ordinal ldc_bot,
00229      Scalar work[]) const;
00230 
00231     void
00232     factor_inner (const Ordinal m,
00233       const Ordinal n,
00234       Scalar R[],
00235       const Ordinal ldr,
00236       Scalar A[],
00237       const Ordinal lda,
00238       Scalar tau[],
00239       Scalar work[]) const;
00240 
00241     void
00242     factor_pair (const Ordinal n,
00243      Scalar R_top[],
00244      const Ordinal ldr_top,
00245      Scalar R_bot[],
00246      const Ordinal ldr_bot,
00247      Scalar tau[],
00248      Scalar work[]) const;
00249 
00250     void
00251     apply_pair (const ApplyType& applyType,
00252     const Ordinal ncols_C,
00253     const Ordinal ncols_Q,
00254     const Scalar R_bot[],
00255     const Ordinal ldr_bot,
00256     const Scalar tau[],
00257     Scalar C_top[],
00258     const Ordinal ldc_top,
00259     Scalar C_bot[],
00260     const Ordinal ldc_bot,
00261     Scalar work[]) const;
00262 
00263   private:
00264     mutable combine_default_type default_;
00265   };
00266 
00267 
00270   template< class Ordinal, class Scalar >
00271   class CombineNative< Ordinal, Scalar, true >
00272   {
00273   public:
00274     typedef Scalar scalar_type;
00275     typedef typename Teuchos::ScalarTraits< Scalar >::magnitudeType magnitude_type;
00276     typedef Ordinal ordinal_type;
00277 
00278   private:
00279     typedef BLAS< ordinal_type, scalar_type > blas_type;
00280     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00281     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00282 
00283   public:
00284     CombineNative () {}
00285 
00286     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00287       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00288   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00289     }
00290 
00291     void
00292     factor_first (const Ordinal nrows,
00293       const Ordinal ncols,
00294       Scalar A[],
00295       const Ordinal lda,
00296       Scalar tau[],
00297       Scalar work[]) const
00298     {
00299       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00300     }
00301 
00302     void
00303     apply_first (const ApplyType& applyType,
00304      const Ordinal nrows,
00305      const Ordinal ncols_C,
00306      const Ordinal ncols_A,
00307      const Scalar A[],
00308      const Ordinal lda,
00309      const Scalar tau[],
00310      Scalar C[],
00311      const Ordinal ldc,
00312      Scalar work[]) const
00313     {
00314       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00315            A, lda, tau, C, ldc, work);
00316     }
00317 
00318     void
00319     apply_inner (const ApplyType& applyType,
00320      const Ordinal m,
00321      const Ordinal ncols_C,
00322      const Ordinal ncols_Q,
00323      const Scalar A[],
00324      const Ordinal lda,
00325      const Scalar tau[],
00326      Scalar C_top[],
00327      const Ordinal ldc_top,
00328      Scalar C_bot[],
00329      const Ordinal ldc_bot,
00330      Scalar work[]) const
00331     {
00332       return default_.apply_inner (applyType, m, ncols_C, ncols_Q,
00333            A, lda, tau,
00334            C_top, ldc_top, C_bot, ldc_bot,
00335            work);
00336     }
00337 
00338     void
00339     factor_inner (const Ordinal m,
00340       const Ordinal n,
00341       Scalar R[],
00342       const Ordinal ldr,
00343       Scalar A[],
00344       const Ordinal lda,
00345       Scalar tau[],
00346       Scalar work[]) const
00347     {
00348       return default_.factor_inner (m, n, R, ldr, A, lda, tau, work);
00349     }
00350 
00351     void
00352     factor_pair (const Ordinal n,
00353      Scalar R_top[],
00354      const Ordinal ldr_top,
00355      Scalar R_bot[],
00356      const Ordinal ldr_bot,
00357      Scalar tau[],
00358      Scalar work[]) const
00359     {
00360       return default_.factor_pair (n, R_top, ldr_top, R_bot, ldr_bot, tau, work);
00361     }
00362 
00363     void
00364     apply_pair (const ApplyType& applyType,
00365     const Ordinal ncols_C,
00366     const Ordinal ncols_Q,
00367     const Scalar R_bot[],
00368     const Ordinal ldr_bot,
00369     const Scalar tau[],
00370     Scalar C_top[],
00371     const Ordinal ldc_top,
00372     Scalar C_bot[],
00373     const Ordinal ldc_bot,
00374     Scalar work[]) const
00375     {
00376       return default_.apply_pair (applyType, ncols_C, ncols_Q,
00377           R_bot, ldr_bot, tau,
00378           C_top, ldc_top, C_bot, ldc_bot,
00379           work);
00380     }
00381 
00382   private:
00383     mutable combine_default_type default_;
00384   };
00385      
00386 
00387   template< class Ordinal, class Scalar >
00388   void
00389   CombineNative< Ordinal, Scalar, false >::
00390   factor_inner (const Ordinal m,
00391     const Ordinal n,
00392     Scalar R[],
00393     const Ordinal ldr,
00394     Scalar A[],
00395     const Ordinal lda,
00396     Scalar tau[],
00397     Scalar work[]) const
00398   {
00399     const Scalar ZERO(0), ONE(1);
00400     lapack_type lapack;
00401     blas_type blas;
00402 
00403     for (Ordinal k = 0; k < n; ++k)
00404       work[k] = ZERO;
00405 
00406     for (Ordinal k = 0; k < n-1; ++k)
00407       {
00408   Scalar& R_kk = R[ k + k * ldr ];
00409   Scalar* const A_1k = &A[ 0 + k * lda ];
00410   Scalar* const A_1kp1 = &A[ 0 + (k+1) * lda ];
00411 
00412   lapack.LARFP (m + 1, R_kk, A_1k, 1, tau[k]);
00413   blas.GEMV ("T", m, n-k-1, ONE, A_1kp1, lda, A_1k, 1, ZERO, work, 1);
00414 
00415   for (Ordinal j = k+1; j < n; ++j)
00416     {
00417       Scalar& R_kj = R[ k + j*ldr ];
00418 
00419       work[j-k-1] += R_kj;
00420       R_kj -= tau[k] * work[j-k-1];
00421     }
00422   blas.GER (m, n-k-1, -tau[k], A_1k, 1, work, 1, A_1kp1, lda);
00423       }
00424     Scalar& R_nn = R[ (n-1) + (n-1) * ldr ];
00425     Scalar* const A_1n = &A[ 0 + (n-1) * lda ];
00426 
00427     lapack.LARFP (m+1, R_nn, A_1n, 1, tau[n-1]);
00428   }
00429 
00430 
00431   template< class Ordinal, class Scalar >
00432   void
00433   CombineNative< Ordinal, Scalar, false >::
00434   apply_inner (const ApplyType& applyType,
00435          const Ordinal m,
00436          const Ordinal ncols_C,
00437          const Ordinal ncols_Q,
00438          const Scalar A[],
00439          const Ordinal lda,
00440          const Scalar tau[],
00441          Scalar C_top[],
00442          const Ordinal ldc_top,
00443          Scalar C_bot[],
00444          const Ordinal ldc_bot,
00445          Scalar work[]) const
00446   {
00447     const Scalar ZERO(0);
00448     blas_type blas;
00449 
00450     //Scalar* const y = work;
00451     for (Ordinal i = 0; i < ncols_C; ++i)
00452       work[i] = ZERO;
00453     
00454     Ordinal j_start, j_end, j_step;
00455     if (applyType == ApplyType::NoTranspose)
00456       {
00457   j_start = ncols_Q - 1;
00458   j_end = -1; // exclusive
00459   j_step = -1;
00460       }
00461     else
00462       {
00463   j_start = 0; 
00464   j_end = ncols_Q; // exclusive
00465   j_step = +1;
00466       }
00467     for (Ordinal j = j_start; j != j_end; j += j_step)
00468       {
00469   const Scalar* const A_1j = &A[ 0 + j*lda ];
00470 
00471   //blas.GEMV ("T", m, ncols_C, ONE, C_bot, ldc_bot, A_1j, 1, ZERO, &y[0], 1);  
00472   for (Ordinal i = 0; i < ncols_C; ++i)
00473     {
00474       work[i] = ZERO;
00475       for (Ordinal k = 0; k < m; ++k)
00476         work[i] += A_1j[k] * C_bot[ k + i*ldc_bot ];
00477 
00478       work[i] += C_top[ j + i*ldc_top ];
00479     }
00480   for (Ordinal k = 0; k < ncols_C; ++k)
00481     C_top[ j + k*ldc_top ] -= tau[j] * work[k];
00482 
00483   blas.GER (m, ncols_C, -tau[j], A_1j, 1, work, 1, C_bot, ldc_bot);
00484       }
00485   }
00486 
00487 
00488   template< class Ordinal, class Scalar >
00489   void
00490   CombineNative< Ordinal, Scalar, false >::
00491   factor_pair (const Ordinal n,
00492          Scalar R_top[],
00493          const Ordinal ldr_top,
00494          Scalar R_bot[],
00495          const Ordinal ldr_bot,
00496          Scalar tau[],
00497          Scalar work[]) const
00498   {
00499     const Scalar ZERO(0), ONE(1);
00500     lapack_type lapack;
00501     blas_type blas;
00502 
00503     for (Ordinal k = 0; k < n; ++k)
00504       work[k] = ZERO;
00505 
00506     for (Ordinal k = 0; k < n-1; ++k)
00507       {
00508   Scalar& R_top_kk = R_top[ k + k * ldr_top ];
00509   Scalar* const R_bot_1k = &R_bot[ 0 + k * ldr_bot ];
00510   Scalar* const R_bot_1kp1 = &R_bot[ 0 + (k+1) * ldr_bot ];
00511 
00512   // k+2: 1 element in R_top (R_top(k,k)), and k+1 elements in
00513   // R_bot (R_bot(1:k,k), in 1-based indexing notation).
00514   lapack.LARFP (k+2, R_top_kk, R_bot_1k, 1, tau[k]);
00515   // One-based indexing, Matlab version of the GEMV call below:
00516   // work(1:k) := R_bot(1:k,k+1:n)' * R_bot(1:k,k) 
00517   blas.GEMV ("T", k+1, n-k-1, ONE, R_bot_1kp1, ldr_bot, R_bot_1k, 1, ZERO, work, 1);
00518   
00519   for (Ordinal j = k+1; j < n; ++j)
00520     {
00521       Scalar& R_top_kj = R_top[ k + j*ldr_top ];
00522       work[j-k-1] += R_top_kj;
00523       R_top_kj -= tau[k] * work[j-k-1];
00524     }
00525   blas.GER (k+1, n-k-1, -tau[k], R_bot_1k, 1, work, 1, R_bot_1kp1, ldr_bot);
00526       }
00527     Scalar& R_top_nn = R_top[ (n-1) + (n-1)*ldr_top ];
00528     Scalar* const R_bot_1n = &R_bot[ 0 + (n-1)*ldr_bot ];
00529 
00530     // n+1: 1 element in R_top (n,n), and n elements in R_bot (the
00531     // whole last column).
00532     lapack.LARFP (n+1, R_top_nn, R_bot_1n, 1, tau[n-1]);
00533   }
00534 
00535 
00536   template< class Ordinal, class Scalar >
00537   void
00538   CombineNative< Ordinal, Scalar, false >::
00539   apply_pair (const ApplyType& applyType,
00540         const Ordinal ncols_C,
00541         const Ordinal ncols_Q,
00542         const Scalar R_bot[],
00543         const Ordinal ldr_bot,
00544         const Scalar tau[],
00545         Scalar C_top[],
00546         const Ordinal ldc_top,
00547         Scalar C_bot[],
00548         const Ordinal ldc_bot,
00549         Scalar work[]) const
00550   {
00551     const Scalar ZERO(0);
00552     blas_type blas;
00553 
00554     for (Ordinal i = 0; i < ncols_C; ++i)
00555       work[i] = ZERO;
00556     
00557     Ordinal j_start, j_end, j_step;
00558     if (applyType == ApplyType::NoTranspose)
00559       {
00560   j_start = ncols_Q - 1;
00561   j_end = -1; // exclusive
00562   j_step = -1;
00563       }
00564     else
00565       {
00566   j_start = 0; 
00567   j_end = ncols_Q; // exclusive
00568   j_step = +1;
00569       }
00570     for (Ordinal j_Q = j_start; j_Q != j_end; j_Q += j_step)
00571       { // Using Householder reflector stored in column j_Q of R_bot
00572   const Scalar* const R_bot_col = &R_bot[ 0 + j_Q*ldr_bot ];
00573 
00574   // In 1-based indexing notation, with k in 1, 2, ..., ncols_C
00575   // (inclusive): (Output is length ncols_C row vector)
00576   //
00577   // work(1:j) := R_bot(1:j,j)' * C_bot(1:j, 1:ncols_C) - C_top(j, 1:ncols_C)
00578   for (Ordinal j_C = 0; j_C < ncols_C; ++j_C)
00579     { // For each column j_C of [C_top; C_bot], update row j_Q
00580       // of C_top and rows 1:j_Q of C_bot.  (Again, this is in
00581       // 1-based indexing notation.
00582 
00583       Scalar work_j_C = ZERO;
00584       const Scalar* const C_bot_col = &C_bot[ 0 + j_C*ldc_bot ];
00585 
00586       for (Ordinal k = 0; k <= j_Q; ++k)
00587         work_j_C += R_bot_col[k] * C_bot_col[k];
00588 
00589       work_j_C += C_top[ j_Q + j_C*ldc_top ];
00590       work[j_C] = work_j_C;
00591     }
00592   for (Ordinal j_C = 0; j_C < ncols_C; ++j_C)
00593     C_top[ j_Q + j_C*ldc_top ] -= tau[j_Q] * work[j_C];
00594 
00595   blas.GER (j_Q+1, ncols_C, -tau[j_Q], R_bot_col, 1, work, 1, C_bot, ldc_bot);
00596       }
00597   }
00598 } // namespace TSQR
00599 
00600 
00601 
00602 #endif // __TSQR_CombineNative_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends