Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_CombineFortran.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00045 #ifndef __TSQR_CombineFortran_hpp
00046 #define __TSQR_CombineFortran_hpp
00047 
00048 #include <Tsqr_ApplyType.hpp>
00049 #include <Tsqr_MatView.hpp>
00050 #include <Tsqr_ScalarTraits.hpp>
00051 #include <Tsqr_CombineDefault.hpp>
00052 
00053 
00054 namespace TSQR {
00055 
00065   template<class Scalar, bool is_complex = ScalarTraits<Scalar>::is_complex >
        // Primary template of CombineFortran.
        //
        // Template parameters:
        //   Scalar     - element type of the arrays (A, C, R, tau, work, ...)
        //                passed to the member functions.
        //   is_complex - defaults to ScalarTraits<Scalar>::is_complex and
        //                selects between the two partial specializations
        //                (false / true) that follow in this file.
        //
        // Every member function below is only declared; none is defined in
        // this class body.  Because both values of is_complex have a partial
        // specialization later in this file, this primary template
        // effectively just spells out the common interface.
        //
        // NOTE(review): judging by the names, these are the factor/apply
        // kernels TSQR uses to combine R factors; the precise contracts are
        // not visible here -- confirm against CombineDefault.
00066   class CombineFortran {
00067   private:
          // Shorthand for the CombineDefault instantiation with int ordinals.
00068     typedef CombineDefault<int, Scalar> combine_default_type;
00069 
00070   public:
          // Public type names; the ordinal (index) type is always int.
00071     typedef Scalar scalar_type;
00072     typedef typename ScalarTraits<Scalar>::magnitude_type magnitude_type;
00073     typedef int ordinal_type;
00074 
          // Default constructor; the body is empty.
00075     CombineFortran () {}
00076 
          // Declared only.  The name indicates it reports whether the QR
          // factorizations done by this class yield an R factor with a
          // nonnegative diagonal; the specializations show how the answer
          // is obtained.
00085     static bool QR_produces_R_factor_with_nonnegative_diagonal();
00086 
          // Declared only.  Takes an nrows-by-ncols array A with stride
          // argument lda, plus tau and work arrays.
          // NOTE(review): lda/tau/work presumably follow LAPACK's QR
          // conventions (leading dimension, scalar factors, scratch) --
          // TODO confirm.
00087     void
00088     factor_first (const ordinal_type nrows,
00089       const ordinal_type ncols,
00090       Scalar A[],
00091       const ordinal_type lda,
00092       Scalar tau[],
00093       Scalar work[]) const;
00094     
          // Declared only.  A and tau are read-only inputs (const); C and
          // work are writable.  applyType is passed by const reference.
00095     void
00096     apply_first (const ApplyType& applyType,
00097      const ordinal_type nrows,
00098      const ordinal_type ncols_C,
00099      const ordinal_type ncols_A,
00100      const Scalar A[],
00101      const ordinal_type lda,
00102      const Scalar tau[],
00103      Scalar C[],
00104      const ordinal_type ldc,
00105      Scalar work[]) const;
00106 
          // Declared only.  A and tau are read-only inputs; C_top, C_bot and
          // work are writable.
00107     void
00108     apply_inner (const ApplyType& apply_type,
00109      const ordinal_type m,
00110      const ordinal_type ncols_C,
00111      const ordinal_type ncols_Q,
00112      const Scalar A[],
00113      const ordinal_type lda,
00114      const Scalar tau[],
00115      Scalar C_top[],
00116      const ordinal_type ldc_top,
00117      Scalar C_bot[],
00118      const ordinal_type ldc_bot,
00119      Scalar work[]) const;
00120 
          // Declared only.  R, A, tau and work are all writable.
00121     void
00122     factor_inner (const ordinal_type m,
00123       const ordinal_type n,
00124       Scalar R[],
00125       const ordinal_type ldr,
00126       Scalar A[],
00127       const ordinal_type lda,
00128       Scalar tau[],
00129       Scalar work[]) const;
00130 
          // Declared only.  R_top, R_bot, tau and work are all writable.
00131     void
00132     factor_pair (const ordinal_type n,
00133      Scalar R_top[],
00134      const ordinal_type ldr_top,
00135      Scalar R_bot[],
00136      const ordinal_type ldr_bot,
00137      Scalar tau[],
00138      Scalar work[]) const;
00139     
          // Declared only.  R_bot and tau are read-only inputs; C_top, C_bot
          // and work are writable.
00140     void
00141     apply_pair (const ApplyType& apply_type,
00142     const ordinal_type ncols_C, 
00143     const ordinal_type ncols_Q, 
00144     const Scalar R_bot[], 
00145     const ordinal_type ldr_bot,
00146     const Scalar tau[], 
00147     Scalar C_top[], 
00148     const ordinal_type ldc_top, 
00149     Scalar C_bot[], 
00150     const ordinal_type ldc_bot, 
00151     Scalar work[]) const;
00152 
00153   private:
          // CombineDefault instance held by value.  It is mutable, so the
          // const member functions above are not restricted to its const
          // interface.  Nothing in this class body references it.
00154     mutable combine_default_type default_;
00155   };
00156 
00157   // "Forward declaration" for the real-arithmetic case.  The Fortran
00158   // back end works well here for Scalar = {float, double}.
        //
        // Strictly speaking this is the partial specialization for
        // is_complex == false.  factor_first and apply_first are defined
        // inline and forward to CombineDefault; apply_inner, factor_inner,
        // factor_pair and apply_pair are only declared here and must be
        // defined elsewhere.
        // NOTE(review): presumably those four are the members backed by the
        // Fortran kernels mentioned above -- confirm against their
        // out-of-line definitions.
00159   template< class Scalar >
00160   class CombineFortran< Scalar, false > {
00161   private:
          // Shorthand for the CombineDefault instantiation with int ordinals.
00162     typedef CombineDefault< int, Scalar > combine_default_type;
00163 
00164   public:
          // Public type names; the ordinal (index) type is always int.
00165     typedef Scalar scalar_type;
00166     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00167     typedef int ordinal_type;
00168 
          // Default constructor; the body is empty.
00169     CombineFortran () {}
00170 
          // Returns true only if both LAPACK<int, Scalar> and
          // CombineDefault<int, Scalar> report true from their own
          // QR_produces_R_factor_with_nonnegative_diagonal().
          // NOTE(review): LAPACK is not declared by any header named in this
          // file's visible #include list; presumably it arrives through one
          // of them -- verify.
00171     static bool QR_produces_R_factor_with_nonnegative_diagonal() {
00172       typedef LAPACK< int, Scalar > lapack_type;
00173 
00174       return lapack_type::QR_produces_R_factor_with_nonnegative_diagonal() &&
00175   combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00176     }
00177 
          // Forwards every argument, in the same order, to
          // default_.factor_first.
00178     void
00179     factor_first (const ordinal_type nrows,
00180       const ordinal_type ncols,
00181       Scalar A[],
00182       const ordinal_type lda,
00183       Scalar tau[],
00184       Scalar work[]) const
00185     {
            // `return <call>` inside a void function is well-formed only when
            // the call expression itself has type void.
00186       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00187     }
00188     
          // Forwards every argument, in the same order, to
          // default_.apply_first.  A and tau are read-only inputs; C and
          // work are writable.
00189     void
00190     apply_first (const ApplyType& applyType,
00191      const ordinal_type nrows,
00192      const ordinal_type ncols_C,
00193      const ordinal_type ncols_A,
00194      const Scalar A[],
00195      const ordinal_type lda,
00196      const Scalar tau[],
00197      Scalar C[],
00198      const ordinal_type ldc,
00199      Scalar work[]) const
00200     {
            // Same void-expression return as in factor_first.
00201       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00202            A, lda, tau, 
00203            C, ldc, work);
00204     }
00205 
          // Declared only; defined out of line (not visible in this file).
          // A and tau are read-only inputs; C_top, C_bot and work are
          // writable.
00206     void
00207     apply_inner (const ApplyType& apply_type,
00208      const ordinal_type m,
00209      const ordinal_type ncols_C,
00210      const ordinal_type ncols_Q,
00211      const Scalar A[],
00212      const ordinal_type lda,
00213      const Scalar tau[],
00214      Scalar C_top[],
00215      const ordinal_type ldc_top,
00216      Scalar C_bot[],
00217      const ordinal_type ldc_bot,
00218      Scalar work[]) const;
00219 
          // Declared only; defined out of line (not visible in this file).
          // R, A, tau and work are all writable.
00220     void
00221     factor_inner (const ordinal_type m,
00222       const ordinal_type n,
00223       Scalar R[],
00224       const ordinal_type ldr,
00225       Scalar A[],
00226       const ordinal_type lda,
00227       Scalar tau[],
00228       Scalar work[]) const;
00229 
          // Declared only; defined out of line (not visible in this file).
          // R_top, R_bot, tau and work are all writable.
00230     void
00231     factor_pair (const ordinal_type n,
00232      Scalar R_top[],
00233      const ordinal_type ldr_top,
00234      Scalar R_bot[],
00235      const ordinal_type ldr_bot,
00236      Scalar tau[],
00237      Scalar work[]) const;
00238     
          // Declared only; defined out of line (not visible in this file).
          // R_bot and tau are read-only inputs; C_top, C_bot and work are
          // writable.
00239     void
00240     apply_pair (const ApplyType& apply_type,
00241     const ordinal_type ncols_C, 
00242     const ordinal_type ncols_Q, 
00243     const Scalar R_bot[], 
00244     const ordinal_type ldr_bot,
00245     const Scalar tau[], 
00246     Scalar C_top[], 
00247     const ordinal_type ldc_top, 
00248     Scalar C_bot[], 
00249     const ordinal_type ldc_bot, 
00250     Scalar work[]) const;
00251 
00252   private:
          // CombineDefault instance used by factor_first and apply_first.
          // It is mutable, so those const member functions are not
          // restricted to its const interface.
00253     mutable combine_default_type default_;
00254   };
00255 
00256 
00257   // "Forward declaration" for complex-arithmetic version of
00258   // CombineFortran.  The Fortran code doesn't actually work for this
00259   // case, so we implement everything using CombineDefault.  This
00260   // will likely result in an ~2x slowdown for typical use cases.
        //
        // Strictly speaking this is the partial specialization for
        // is_complex == true.  Every member function is defined inline and
        // forwards its arguments, in the same order, to the CombineDefault
        // member of the same name.
00261   template< class Scalar >
00262   class CombineFortran< Scalar, true > {
00263   private:
          // Shorthand for the CombineDefault instantiation with int ordinals.
00264     typedef CombineDefault< int, Scalar > combine_default_type;
00265 
00266   public:
          // Public type names; the ordinal (index) type is always int.
00267     typedef Scalar scalar_type;
00268     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00269     typedef int ordinal_type;
00270 
          // Default constructor; the body is empty.
00271     CombineFortran () {}
00272 
          // Returns exactly what CombineDefault<int, Scalar> reports; unlike
          // the real-arithmetic specialization, LAPACK is not consulted here.
00273     static bool QR_produces_R_factor_with_nonnegative_diagonal() {
00274       return combine_default_type::QR_produces_R_factor_with_nonnegative_diagonal();
00275     }
00276 
          // Forwards every argument to default_.factor_first.
00277     void
00278     factor_first (const ordinal_type nrows,
00279       const ordinal_type ncols,
00280       Scalar A[],
00281       const ordinal_type lda,
00282       Scalar tau[],
00283       Scalar work[]) const
00284     {
            // `return <call>` inside a void function is well-formed only when
            // the call expression itself has type void.
00285       return default_.factor_first (nrows, ncols, A, lda, tau, work);
00286     }
00287     
          // Forwards every argument to default_.apply_first.  A and tau are
          // read-only inputs; C and work are writable.
00288     void
00289     apply_first (const ApplyType& applyType,
00290      const ordinal_type nrows,
00291      const ordinal_type ncols_C,
00292      const ordinal_type ncols_A,
00293      const Scalar A[],
00294      const ordinal_type lda,
00295      const Scalar tau[],
00296      Scalar C[],
00297      const ordinal_type ldc,
00298      Scalar work[]) const
00299     {
            // Same void-expression return as in factor_first.
00300       return default_.apply_first (applyType, nrows, ncols_C, ncols_A, 
00301            A, lda, tau, 
00302            C, ldc, work);
00303     }
00304 
          // Forwards every argument to default_.apply_inner.  A and tau are
          // read-only inputs; C_top, C_bot and work are writable.
00305     void
00306     apply_inner (const ApplyType& apply_type,
00307      const ordinal_type m,
00308      const ordinal_type ncols_C,
00309      const ordinal_type ncols_Q,
00310      const Scalar A[],
00311      const ordinal_type lda,
00312      const Scalar tau[],
00313      Scalar C_top[],
00314      const ordinal_type ldc_top,
00315      Scalar C_bot[],
00316      const ordinal_type ldc_bot,
00317      Scalar work[]) const
00318     {
00319       default_.apply_inner (apply_type, m, ncols_C, ncols_Q, 
00320           A, lda, tau, 
00321           C_top, ldc_top, C_bot, ldc_bot, work);
00322     }
00323 
          // Forwards every argument to default_.factor_inner.  R, A, tau and
          // work are all writable.
00324     void
00325     factor_inner (const ordinal_type m,
00326       const ordinal_type n,
00327       Scalar R[],
00328       const ordinal_type ldr,
00329       Scalar A[],
00330       const ordinal_type lda,
00331       Scalar tau[],
00332       Scalar work[]) const
00333     {
00334       default_.factor_inner (m, n, R, ldr, A, lda, tau, work);
00335     }
00336 
          // Forwards every argument to default_.factor_pair.  R_top, R_bot,
          // tau and work are all writable.
00337     void
00338     factor_pair (const ordinal_type n,
00339      Scalar R_top[],
00340      const ordinal_type ldr_top,
00341      Scalar R_bot[],
00342      const ordinal_type ldr_bot,
00343      Scalar tau[],
00344      Scalar work[]) const
00345     {
00346       default_.factor_pair (n, R_top, ldr_top, R_bot, ldr_bot, tau, work);
00347     }
00348     
          // Forwards every argument to default_.apply_pair.  R_bot and tau
          // are read-only inputs; C_top, C_bot and work are writable.
00349     void
00350     apply_pair (const ApplyType& apply_type,
00351     const ordinal_type ncols_C, 
00352     const ordinal_type ncols_Q, 
00353     const Scalar R_bot[], 
00354     const ordinal_type ldr_bot,
00355     const Scalar tau[], 
00356     Scalar C_top[], 
00357     const ordinal_type ldc_top, 
00358     Scalar C_bot[], 
00359     const ordinal_type ldc_bot, 
00360     Scalar work[]) const
00361     {
00362       default_.apply_pair (apply_type, ncols_C, ncols_Q, 
00363          R_bot, ldr_bot, tau, 
00364          C_top, ldc_top, C_bot, ldc_bot, work);
00365     }
00366 
00367   private:
00368     // Default implementation of TSQR::Combine copies data in and out
00369     // of a single matrix, which is given to LAPACK.  It's slow
00370     // because we expect the number of columns to be small, so copying
00371     // overhead is significant.  Experiments have shown a ~2x slowdown
00372     // due to copying overhead.
          //
          // The type spelled out below is the same as combine_default_type,
          // since ordinal_type is int and scalar_type is Scalar.  The member
          // is mutable, so the const member functions above are not
          // restricted to its const interface.
00373     mutable CombineDefault< ordinal_type, scalar_type > default_;
00374   };
00375 
00376 
00377 } // namespace TSQR
00378 
00379 #endif // __TSQR_CombineFortran_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends