Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_CombineNative.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00045 #ifndef __TSQR_CombineNative_hpp
00046 #define __TSQR_CombineNative_hpp
00047 
00048 #include "Teuchos_ScalarTraits.hpp"
00049 
00050 #include "Tsqr_ApplyType.hpp"
00051 #include "Tsqr_Blas.hpp"
00052 #include "Tsqr_Lapack.hpp"
00053 #include "Tsqr_CombineDefault.hpp"
00054 
00055 
00056 namespace TSQR {
00057 
00073   template< class Ordinal, class Scalar, bool isComplex = Teuchos::ScalarTraits< Scalar >::isComplex >
00074   class CombineNative 
00075   {
00076   public:
00077     typedef Scalar scalar_type;
00078     typedef typename Teuchos::ScalarTraits< Scalar >::magnitudeType magnitude_type;
00079     typedef Ordinal ordinal_type;
00080 
00081   private:
00082     typedef BLAS< ordinal_type, scalar_type > blas_type;
00083     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00084     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00085 
00086   public:
00087 
00088     CombineNative () {}
00089 
00097     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00098       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00099   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00100     }
00101 
00102     void
00103     factor_first (const Ordinal nrows,
00104       const Ordinal ncols,
00105       Scalar A[],
00106       const Ordinal lda,
00107       Scalar tau[],
00108       Scalar work[]) const
00109     {
00110       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00111     }
00112 
00113     void
00114     apply_first (const ApplyType& applyType,
00115      const Ordinal nrows,
00116      const Ordinal ncols_C,
00117      const Ordinal ncols_A,
00118      const Scalar A[],
00119      const Ordinal lda,
00120      const Scalar tau[],
00121      Scalar C[],
00122      const Ordinal ldc,
00123      Scalar work[]) const
00124     {
00125       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00126            A, lda, tau, C, ldc, work);
00127     }
00128 
00129     void
00130     apply_inner (const ApplyType& applyType,
00131      const Ordinal m,
00132      const Ordinal ncols_C,
00133      const Ordinal ncols_Q,
00134      const Scalar A[],
00135      const Ordinal lda,
00136      const Scalar tau[],
00137      Scalar C_top[],
00138      const Ordinal ldc_top,
00139      Scalar C_bot[],
00140      const Ordinal ldc_bot,
00141      Scalar work[]) const;
00142 
00143     void
00144     factor_inner (const Ordinal m,
00145       const Ordinal n,
00146       Scalar R[],
00147       const Ordinal ldr,
00148       Scalar A[],
00149       const Ordinal lda,
00150       Scalar tau[],
00151       Scalar work[]) const;
00152 
00153     void
00154     factor_pair (const Ordinal n,
00155      Scalar R_top[],
00156      const Ordinal ldr_top,
00157      Scalar R_bot[],
00158      const Ordinal ldr_bot,
00159      Scalar tau[],
00160      Scalar work[]) const;
00161 
00162     void
00163     apply_pair (const ApplyType& applyType,
00164     const Ordinal ncols_C,
00165     const Ordinal ncols_Q,
00166     const Scalar R_bot[],
00167     const Ordinal ldr_bot,
00168     const Scalar tau[],
00169     Scalar C_top[],
00170     const Ordinal ldc_top,
00171     Scalar C_bot[],
00172     const Ordinal ldc_bot,
00173     Scalar work[]) const;
00174 
00175   private:
00176     mutable combine_default_type default_;
00177   };
00178      
00179 
00182   template< class Ordinal, class Scalar >
00183   class CombineNative< Ordinal, Scalar, false >
00184   {
00185   public:
00186     typedef Scalar scalar_type;
00187     typedef typename Teuchos::ScalarTraits< Scalar >::magnitudeType magnitude_type;
00188     typedef Ordinal ordinal_type;
00189 
00190   private:
00191     typedef BLAS< ordinal_type, scalar_type > blas_type;
00192     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00193     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00194 
00195   public:
00196     CombineNative () {}
00197 
00198     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00199       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00200   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00201     }
00202 
00203     void
00204     factor_first (const Ordinal nrows,
00205       const Ordinal ncols,
00206       Scalar A[],
00207       const Ordinal lda,
00208       Scalar tau[],
00209       Scalar work[]) const
00210     {
00211       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00212     }
00213 
00214     void
00215     apply_first (const ApplyType& applyType,
00216      const Ordinal nrows,
00217      const Ordinal ncols_C,
00218      const Ordinal ncols_A,
00219      const Scalar A[],
00220      const Ordinal lda,
00221      const Scalar tau[],
00222      Scalar C[],
00223      const Ordinal ldc,
00224      Scalar work[]) const
00225     {
00226       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00227            A, lda, tau, C, ldc, work);
00228     }
00229 
00230     void
00231     apply_inner (const ApplyType& applyType,
00232      const Ordinal m,
00233      const Ordinal ncols_C,
00234      const Ordinal ncols_Q,
00235      const Scalar A[],
00236      const Ordinal lda,
00237      const Scalar tau[],
00238      Scalar C_top[],
00239      const Ordinal ldc_top,
00240      Scalar C_bot[],
00241      const Ordinal ldc_bot,
00242      Scalar work[]) const;
00243 
00244     void
00245     factor_inner (const Ordinal m,
00246       const Ordinal n,
00247       Scalar R[],
00248       const Ordinal ldr,
00249       Scalar A[],
00250       const Ordinal lda,
00251       Scalar tau[],
00252       Scalar work[]) const;
00253 
00254     void
00255     factor_pair (const Ordinal n,
00256      Scalar R_top[],
00257      const Ordinal ldr_top,
00258      Scalar R_bot[],
00259      const Ordinal ldr_bot,
00260      Scalar tau[],
00261      Scalar work[]) const;
00262 
00263     void
00264     apply_pair (const ApplyType& applyType,
00265     const Ordinal ncols_C,
00266     const Ordinal ncols_Q,
00267     const Scalar R_bot[],
00268     const Ordinal ldr_bot,
00269     const Scalar tau[],
00270     Scalar C_top[],
00271     const Ordinal ldc_top,
00272     Scalar C_bot[],
00273     const Ordinal ldc_bot,
00274     Scalar work[]) const;
00275 
00276   private:
00277     mutable combine_default_type default_;
00278   };
00279 
00280 
00283   template< class Ordinal, class Scalar >
00284   class CombineNative< Ordinal, Scalar, true >
00285   {
00286   public:
00287     typedef Scalar scalar_type;
00288     typedef typename Teuchos::ScalarTraits< Scalar >::magnitudeType magnitude_type;
00289     typedef Ordinal ordinal_type;
00290 
00291   private:
00292     typedef BLAS< ordinal_type, scalar_type > blas_type;
00293     typedef LAPACK< ordinal_type, scalar_type > lapack_type;
00294     typedef CombineDefault< ordinal_type, scalar_type > combine_default_type;
00295 
00296   public:
00297     CombineNative () {}
00298 
00299     static bool QR_produces_R_factor_with_nonnegative_diagonal() { 
00300       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00301   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00302     }
00303 
00304     void
00305     factor_first (const Ordinal nrows,
00306       const Ordinal ncols,
00307       Scalar A[],
00308       const Ordinal lda,
00309       Scalar tau[],
00310       Scalar work[]) const
00311     {
00312       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00313     }
00314 
00315     void
00316     apply_first (const ApplyType& applyType,
00317      const Ordinal nrows,
00318      const Ordinal ncols_C,
00319      const Ordinal ncols_A,
00320      const Scalar A[],
00321      const Ordinal lda,
00322      const Scalar tau[],
00323      Scalar C[],
00324      const Ordinal ldc,
00325      Scalar work[]) const
00326     {
00327       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00328            A, lda, tau, C, ldc, work);
00329     }
00330 
00331     void
00332     apply_inner (const ApplyType& applyType,
00333      const Ordinal m,
00334      const Ordinal ncols_C,
00335      const Ordinal ncols_Q,
00336      const Scalar A[],
00337      const Ordinal lda,
00338      const Scalar tau[],
00339      Scalar C_top[],
00340      const Ordinal ldc_top,
00341      Scalar C_bot[],
00342      const Ordinal ldc_bot,
00343      Scalar work[]) const
00344     {
00345       return default_.apply_inner (applyType, m, ncols_C, ncols_Q,
00346            A, lda, tau,
00347            C_top, ldc_top, C_bot, ldc_bot,
00348            work);
00349     }
00350 
00351     void
00352     factor_inner (const Ordinal m,
00353       const Ordinal n,
00354       Scalar R[],
00355       const Ordinal ldr,
00356       Scalar A[],
00357       const Ordinal lda,
00358       Scalar tau[],
00359       Scalar work[]) const
00360     {
00361       return default_.factor_inner (m, n, R, ldr, A, lda, tau, work);
00362     }
00363 
00364     void
00365     factor_pair (const Ordinal n,
00366      Scalar R_top[],
00367      const Ordinal ldr_top,
00368      Scalar R_bot[],
00369      const Ordinal ldr_bot,
00370      Scalar tau[],
00371      Scalar work[]) const
00372     {
00373       return default_.factor_pair (n, R_top, ldr_top, R_bot, ldr_bot, tau, work);
00374     }
00375 
00376     void
00377     apply_pair (const ApplyType& applyType,
00378     const Ordinal ncols_C,
00379     const Ordinal ncols_Q,
00380     const Scalar R_bot[],
00381     const Ordinal ldr_bot,
00382     const Scalar tau[],
00383     Scalar C_top[],
00384     const Ordinal ldc_top,
00385     Scalar C_bot[],
00386     const Ordinal ldc_bot,
00387     Scalar work[]) const
00388     {
00389       return default_.apply_pair (applyType, ncols_C, ncols_Q,
00390           R_bot, ldr_bot, tau,
00391           C_top, ldc_top, C_bot, ldc_bot,
00392           work);
00393     }
00394 
00395   private:
00396     mutable combine_default_type default_;
00397   };
00398      
00399 
00400   template< class Ordinal, class Scalar >
00401   void
00402   CombineNative< Ordinal, Scalar, false >::
00403   factor_inner (const Ordinal m,
00404     const Ordinal n,
00405     Scalar R[],
00406     const Ordinal ldr,
00407     Scalar A[],
00408     const Ordinal lda,
00409     Scalar tau[],
00410     Scalar work[]) const
00411   {
00412     const Scalar ZERO(0), ONE(1);
00413     lapack_type lapack;
00414     blas_type blas;
00415 
00416     for (Ordinal k = 0; k < n; ++k)
00417       work[k] = ZERO;
00418 
00419     for (Ordinal k = 0; k < n-1; ++k)
00420       {
00421   Scalar& R_kk = R[ k + k * ldr ];
00422   Scalar* const A_1k = &A[ 0 + k * lda ];
00423   Scalar* const A_1kp1 = &A[ 0 + (k+1) * lda ];
00424 
00425   lapack.LARFP (m + 1, R_kk, A_1k, 1, tau[k]);
00426   blas.GEMV ("T", m, n-k-1, ONE, A_1kp1, lda, A_1k, 1, ZERO, work, 1);
00427 
00428   for (Ordinal j = k+1; j < n; ++j)
00429     {
00430       Scalar& R_kj = R[ k + j*ldr ];
00431 
00432       work[j-k-1] += R_kj;
00433       R_kj -= tau[k] * work[j-k-1];
00434     }
00435   blas.GER (m, n-k-1, -tau[k], A_1k, 1, work, 1, A_1kp1, lda);
00436       }
00437     Scalar& R_nn = R[ (n-1) + (n-1) * ldr ];
00438     Scalar* const A_1n = &A[ 0 + (n-1) * lda ];
00439 
00440     lapack.LARFP (m+1, R_nn, A_1n, 1, tau[n-1]);
00441   }
00442 
00443 
00444   template< class Ordinal, class Scalar >
00445   void
00446   CombineNative< Ordinal, Scalar, false >::
00447   apply_inner (const ApplyType& applyType,
00448          const Ordinal m,
00449          const Ordinal ncols_C,
00450          const Ordinal ncols_Q,
00451          const Scalar A[],
00452          const Ordinal lda,
00453          const Scalar tau[],
00454          Scalar C_top[],
00455          const Ordinal ldc_top,
00456          Scalar C_bot[],
00457          const Ordinal ldc_bot,
00458          Scalar work[]) const
00459   {
00460     const Scalar ZERO(0);
00461     blas_type blas;
00462 
00463     //Scalar* const y = work;
00464     for (Ordinal i = 0; i < ncols_C; ++i)
00465       work[i] = ZERO;
00466     
00467     Ordinal j_start, j_end, j_step;
00468     if (applyType == ApplyType::NoTranspose)
00469       {
00470   j_start = ncols_Q - 1;
00471   j_end = -1; // exclusive
00472   j_step = -1;
00473       }
00474     else
00475       {
00476   j_start = 0; 
00477   j_end = ncols_Q; // exclusive
00478   j_step = +1;
00479       }
00480     for (Ordinal j = j_start; j != j_end; j += j_step)
00481       {
00482   const Scalar* const A_1j = &A[ 0 + j*lda ];
00483 
00484   //blas.GEMV ("T", m, ncols_C, ONE, C_bot, ldc_bot, A_1j, 1, ZERO, &y[0], 1);  
00485   for (Ordinal i = 0; i < ncols_C; ++i)
00486     {
00487       work[i] = ZERO;
00488       for (Ordinal k = 0; k < m; ++k)
00489         work[i] += A_1j[k] * C_bot[ k + i*ldc_bot ];
00490 
00491       work[i] += C_top[ j + i*ldc_top ];
00492     }
00493   for (Ordinal k = 0; k < ncols_C; ++k)
00494     C_top[ j + k*ldc_top ] -= tau[j] * work[k];
00495 
00496   blas.GER (m, ncols_C, -tau[j], A_1j, 1, work, 1, C_bot, ldc_bot);
00497       }
00498   }
00499 
00500 
00501   template< class Ordinal, class Scalar >
00502   void
00503   CombineNative< Ordinal, Scalar, false >::
00504   factor_pair (const Ordinal n,
00505          Scalar R_top[],
00506          const Ordinal ldr_top,
00507          Scalar R_bot[],
00508          const Ordinal ldr_bot,
00509          Scalar tau[],
00510          Scalar work[]) const
00511   {
00512     const Scalar ZERO(0), ONE(1);
00513     lapack_type lapack;
00514     blas_type blas;
00515 
00516     for (Ordinal k = 0; k < n; ++k)
00517       work[k] = ZERO;
00518 
00519     for (Ordinal k = 0; k < n-1; ++k)
00520       {
00521   Scalar& R_top_kk = R_top[ k + k * ldr_top ];
00522   Scalar* const R_bot_1k = &R_bot[ 0 + k * ldr_bot ];
00523   Scalar* const R_bot_1kp1 = &R_bot[ 0 + (k+1) * ldr_bot ];
00524 
00525   // k+2: 1 element in R_top (R_top(k,k)), and k+1 elements in
00526   // R_bot (R_bot(1:k,k), in 1-based indexing notation).
00527   lapack.LARFP (k+2, R_top_kk, R_bot_1k, 1, tau[k]);
00528   // One-based indexing, Matlab version of the GEMV call below:
00529   // work(1:k) := R_bot(1:k,k+1:n)' * R_bot(1:k,k) 
00530   blas.GEMV ("T", k+1, n-k-1, ONE, R_bot_1kp1, ldr_bot, R_bot_1k, 1, ZERO, work, 1);
00531   
00532   for (Ordinal j = k+1; j < n; ++j)
00533     {
00534       Scalar& R_top_kj = R_top[ k + j*ldr_top ];
00535       work[j-k-1] += R_top_kj;
00536       R_top_kj -= tau[k] * work[j-k-1];
00537     }
00538   blas.GER (k+1, n-k-1, -tau[k], R_bot_1k, 1, work, 1, R_bot_1kp1, ldr_bot);
00539       }
00540     Scalar& R_top_nn = R_top[ (n-1) + (n-1)*ldr_top ];
00541     Scalar* const R_bot_1n = &R_bot[ 0 + (n-1)*ldr_bot ];
00542 
00543     // n+1: 1 element in R_top (n,n), and n elements in R_bot (the
00544     // whole last column).
00545     lapack.LARFP (n+1, R_top_nn, R_bot_1n, 1, tau[n-1]);
00546   }
00547 
00548 
00549   template< class Ordinal, class Scalar >
00550   void
00551   CombineNative< Ordinal, Scalar, false >::
00552   apply_pair (const ApplyType& applyType,
00553         const Ordinal ncols_C,
00554         const Ordinal ncols_Q,
00555         const Scalar R_bot[],
00556         const Ordinal ldr_bot,
00557         const Scalar tau[],
00558         Scalar C_top[],
00559         const Ordinal ldc_top,
00560         Scalar C_bot[],
00561         const Ordinal ldc_bot,
00562         Scalar work[]) const
00563   {
00564     const Scalar ZERO(0);
00565     blas_type blas;
00566 
00567     for (Ordinal i = 0; i < ncols_C; ++i)
00568       work[i] = ZERO;
00569     
00570     Ordinal j_start, j_end, j_step;
00571     if (applyType == ApplyType::NoTranspose)
00572       {
00573   j_start = ncols_Q - 1;
00574   j_end = -1; // exclusive
00575   j_step = -1;
00576       }
00577     else
00578       {
00579   j_start = 0; 
00580   j_end = ncols_Q; // exclusive
00581   j_step = +1;
00582       }
00583     for (Ordinal j_Q = j_start; j_Q != j_end; j_Q += j_step)
00584       { // Using Householder reflector stored in column j_Q of R_bot
00585   const Scalar* const R_bot_col = &R_bot[ 0 + j_Q*ldr_bot ];
00586 
00587   // In 1-based indexing notation, with k in 1, 2, ..., ncols_C
00588   // (inclusive): (Output is length ncols_C row vector)
00589   //
00590   // work(1:j) := R_bot(1:j,j)' * C_bot(1:j, 1:ncols_C) - C_top(j, 1:ncols_C)
00591   for (Ordinal j_C = 0; j_C < ncols_C; ++j_C)
00592     { // For each column j_C of [C_top; C_bot], update row j_Q
00593       // of C_top and rows 1:j_Q of C_bot.  (Again, this is in
00594       // 1-based indexing notation.
00595 
00596       Scalar work_j_C = ZERO;
00597       const Scalar* const C_bot_col = &C_bot[ 0 + j_C*ldc_bot ];
00598 
00599       for (Ordinal k = 0; k <= j_Q; ++k)
00600         work_j_C += R_bot_col[k] * C_bot_col[k];
00601 
00602       work_j_C += C_top[ j_Q + j_C*ldc_top ];
00603       work[j_C] = work_j_C;
00604     }
00605   for (Ordinal j_C = 0; j_C < ncols_C; ++j_C)
00606     C_top[ j_Q + j_C*ldc_top ] -= tau[j_Q] * work[j_C];
00607 
00608   blas.GER (j_Q+1, ncols_C, -tau[j_Q], R_bot_col, 1, work, 1, C_bot, ldc_bot);
00609       }
00610   }
00611 } // namespace TSQR
00612 
00613 
00614 
00615 #endif // __TSQR_CombineNative_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends