Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_CombineFortran.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00032 #ifndef __TSQR_CombineFortran_hpp
00033 #define __TSQR_CombineFortran_hpp
00034 
00035 #include <Tsqr_ApplyType.hpp>
00036 #include <Tsqr_MatView.hpp>
00037 #include <Tsqr_ScalarTraits.hpp>
00038 #include <Tsqr_CombineDefault.hpp>
00039 
00040 
00041 namespace TSQR {
00042 
00052   template<class Scalar, bool is_complex = ScalarTraits<Scalar>::is_complex >
00053   class CombineFortran {
00054   private:
00055     typedef CombineDefault<int, Scalar> combine_default_type;
00056 
00057   public:
00058     typedef Scalar scalar_type;
00059     typedef typename ScalarTraits<Scalar>::magnitude_type magnitude_type;
00060     typedef int ordinal_type;
00061 
00062     CombineFortran () {}
00063 
00072     static bool QR_produces_R_factor_with_nonnegative_diagonal();
00073 
00074     void
00075     factor_first (const ordinal_type nrows,
00076       const ordinal_type ncols,
00077       Scalar A[],
00078       const ordinal_type lda,
00079       Scalar tau[],
00080       Scalar work[]) const;
00081     
00082     void
00083     apply_first (const ApplyType& applyType,
00084      const ordinal_type nrows,
00085      const ordinal_type ncols_C,
00086      const ordinal_type ncols_A,
00087      const Scalar A[],
00088      const ordinal_type lda,
00089      const Scalar tau[],
00090      Scalar C[],
00091      const ordinal_type ldc,
00092      Scalar work[]) const;
00093 
00094     void
00095     apply_inner (const ApplyType& apply_type,
00096      const ordinal_type m,
00097      const ordinal_type ncols_C,
00098      const ordinal_type ncols_Q,
00099      const Scalar A[],
00100      const ordinal_type lda,
00101      const Scalar tau[],
00102      Scalar C_top[],
00103      const ordinal_type ldc_top,
00104      Scalar C_bot[],
00105      const ordinal_type ldc_bot,
00106      Scalar work[]) const;
00107 
00108     void
00109     factor_inner (const ordinal_type m,
00110       const ordinal_type n,
00111       Scalar R[],
00112       const ordinal_type ldr,
00113       Scalar A[],
00114       const ordinal_type lda,
00115       Scalar tau[],
00116       Scalar work[]) const;
00117 
00118     void
00119     factor_pair (const ordinal_type n,
00120      Scalar R_top[],
00121      const ordinal_type ldr_top,
00122      Scalar R_bot[],
00123      const ordinal_type ldr_bot,
00124      Scalar tau[],
00125      Scalar work[]) const;
00126     
00127     void
00128     apply_pair (const ApplyType& apply_type,
00129     const ordinal_type ncols_C, 
00130     const ordinal_type ncols_Q, 
00131     const Scalar R_bot[], 
00132     const ordinal_type ldr_bot,
00133     const Scalar tau[], 
00134     Scalar C_top[], 
00135     const ordinal_type ldc_top, 
00136     Scalar C_bot[], 
00137     const ordinal_type ldc_bot, 
00138     Scalar work[]) const;
00139 
00140   private:
00141     mutable combine_default_type default_;
00142   };
00143 
00144   // "Forward declaration" for the real-arithmetic case.  The Fortran
00145   // back end works well here for Scalar = {float, double}.
00146   template< class Scalar >
00147   class CombineFortran< Scalar, false > {
00148   private:
00149     typedef CombineDefault< int, Scalar > combine_default_type;
00150 
00151   public:
00152     typedef Scalar scalar_type;
00153     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00154     typedef int ordinal_type;
00155 
00156     CombineFortran () {}
00157 
00158     static bool QR_produces_R_factor_with_nonnegative_diagonal() {
00159       typedef LAPACK< int, Scalar > lapack_type;
00160 
00161       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00162   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00163     }
00164 
00165     void
00166     factor_first (const ordinal_type nrows,
00167       const ordinal_type ncols,
00168       Scalar A[],
00169       const ordinal_type lda,
00170       Scalar tau[],
00171       Scalar work[]) const
00172     {
00173       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00174     }
00175     
00176     void
00177     apply_first (const ApplyType& applyType,
00178      const ordinal_type nrows,
00179      const ordinal_type ncols_C,
00180      const ordinal_type ncols_A,
00181      const Scalar A[],
00182      const ordinal_type lda,
00183      const Scalar tau[],
00184      Scalar C[],
00185      const ordinal_type ldc,
00186      Scalar work[]) const
00187     {
00188       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00189            A, lda, tau, 
00190            C, ldc, work);
00191     }
00192 
00193     void
00194     apply_inner (const ApplyType& apply_type,
00195      const ordinal_type m,
00196      const ordinal_type ncols_C,
00197      const ordinal_type ncols_Q,
00198      const Scalar A[],
00199      const ordinal_type lda,
00200      const Scalar tau[],
00201      Scalar C_top[],
00202      const ordinal_type ldc_top,
00203      Scalar C_bot[],
00204      const ordinal_type ldc_bot,
00205      Scalar work[]) const;
00206 
00207     void
00208     factor_inner (const ordinal_type m,
00209       const ordinal_type n,
00210       Scalar R[],
00211       const ordinal_type ldr,
00212       Scalar A[],
00213       const ordinal_type lda,
00214       Scalar tau[],
00215       Scalar work[]) const;
00216 
00217     void
00218     factor_pair (const ordinal_type n,
00219      Scalar R_top[],
00220      const ordinal_type ldr_top,
00221      Scalar R_bot[],
00222      const ordinal_type ldr_bot,
00223      Scalar tau[],
00224      Scalar work[]) const;
00225     
00226     void
00227     apply_pair (const ApplyType& apply_type,
00228     const ordinal_type ncols_C, 
00229     const ordinal_type ncols_Q, 
00230     const Scalar R_bot[], 
00231     const ordinal_type ldr_bot,
00232     const Scalar tau[], 
00233     Scalar C_top[], 
00234     const ordinal_type ldc_top, 
00235     Scalar C_bot[], 
00236     const ordinal_type ldc_bot, 
00237     Scalar work[]) const;
00238 
00239   private:
00240     mutable combine_default_type default_;
00241   };
00242 
00243 
00244   // "Forward declaration" for complex-arithmetic version of
00245   // CombineFortran.  The Fortran code doesn't actually work for this
00246   // case, so we implement everything using CombineDefault.  This
00247   // will likely result in an ~2x slowdown for typical use cases.
00248   template< class Scalar >
00249   class CombineFortran< Scalar, true > {
00250   private:
00251     typedef CombineDefault< int, Scalar > combine_default_type;
00252 
00253   public:
00254     typedef Scalar scalar_type;
00255     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00256     typedef int ordinal_type;
00257 
00258     CombineFortran () {}
00259 
00260     static bool QR_produces_R_factor_with_nonnegative_diagonal() {
00261       return combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00262     }
00263 
00264     void
00265     factor_first (const ordinal_type nrows,
00266       const ordinal_type ncols,
00267       Scalar A[],
00268       const ordinal_type lda,
00269       Scalar tau[],
00270       Scalar work[]) const
00271     {
00272       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00273     }
00274     
00275     void
00276     apply_first (const ApplyType& applyType,
00277      const ordinal_type nrows,
00278      const ordinal_type ncols_C,
00279      const ordinal_type ncols_A,
00280      const Scalar A[],
00281      const ordinal_type lda,
00282      const Scalar tau[],
00283      Scalar C[],
00284      const ordinal_type ldc,
00285      Scalar work[]) const
00286     {
00287       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00288            A, lda, tau, 
00289            C, ldc, work);
00290     }
00291 
00292     void
00293     apply_inner (const ApplyType& apply_type,
00294      const ordinal_type m,
00295      const ordinal_type ncols_C,
00296      const ordinal_type ncols_Q,
00297      const Scalar A[],
00298      const ordinal_type lda,
00299      const Scalar tau[],
00300      Scalar C_top[],
00301      const ordinal_type ldc_top,
00302      Scalar C_bot[],
00303      const ordinal_type ldc_bot,
00304      Scalar work[]) const
00305     {
00306       default_.apply_inner (apply_type, m, ncols_C, ncols_Q, 
00307           A, lda, tau, 
00308           C_top, ldc_top, C_bot, ldc_bot, work);
00309     }
00310 
00311     void
00312     factor_inner (const ordinal_type m,
00313       const ordinal_type n,
00314       Scalar R[],
00315       const ordinal_type ldr,
00316       Scalar A[],
00317       const ordinal_type lda,
00318       Scalar tau[],
00319       Scalar work[]) const
00320     {
00321       default_.factor_inner (m, n, R, ldr, A, lda, tau, work);
00322     }
00323 
00324     void
00325     factor_pair (const ordinal_type n,
00326      Scalar R_top[],
00327      const ordinal_type ldr_top,
00328      Scalar R_bot[],
00329      const ordinal_type ldr_bot,
00330      Scalar tau[],
00331      Scalar work[]) const
00332     {
00333       default_.factor_pair (n, R_top, ldr_top, R_bot, ldr_bot, tau, work);
00334     }
00335     
00336     void
00337     apply_pair (const ApplyType& apply_type,
00338     const ordinal_type ncols_C, 
00339     const ordinal_type ncols_Q, 
00340     const Scalar R_bot[], 
00341     const ordinal_type ldr_bot,
00342     const Scalar tau[], 
00343     Scalar C_top[], 
00344     const ordinal_type ldc_top, 
00345     Scalar C_bot[], 
00346     const ordinal_type ldc_bot, 
00347     Scalar work[]) const
00348     {
00349       default_.apply_pair (apply_type, ncols_C, ncols_Q, 
00350          R_bot, ldr_bot, tau, 
00351          C_top, ldc_top, C_bot, ldc_bot, work);
00352     }
00353 
00354   private:
00355     // Default implementation of TSQR::Combine copies data in and out
00356     // of a single matrix, which is given to LAPACK.  It's slow
00357     // because we expect the number of columns to be small, so copying
00358     // overhead is significant.  Experiments have shown a ~2x slowdown
00359     // due to copying overhead.
00360     mutable CombineDefault< ordinal_type, scalar_type > default_;
00361   };
00362 
00363 
00364 } // namespace TSQR
00365 
00366 #endif // __TSQR_CombineFortran_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends