Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_Util.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00045 
00046 #ifndef __TSQR_Tsqr_Util_hpp
00047 #define __TSQR_Tsqr_Util_hpp
00048 
00049 #include "Tsqr_ScalarTraits.hpp"
00050 
00051 #include <algorithm>
00052 #include <complex>
00053 #include <ostream>
00054 
00055 
00056 namespace TSQR {
00057 
00072   template<class Scalar, bool isComplex>
00073   class ScalarPrinter {
00074   public:
00077     void operator() (std::ostream& out, const Scalar& elt) const;
00078   };
00079 
00080   // Partial specialization for real Scalar
00081   template< class Scalar >
00082   class ScalarPrinter< Scalar, false > {
00083   public:
00084     void operator() (std::ostream& out, const Scalar& elt) const {
00085       out << elt;
00086     }
00087   };
00088 
00089   // Partial specialization for complex Scalar
00090   template< class Scalar >
00091   class ScalarPrinter< Scalar, true > {
00092   public:
00093     void operator() (std::ostream& out, const Scalar& elt) const {
00094       typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00095 
00096       const magnitude_type ZERO (0);
00097       const magnitude_type& realPart = std::real (elt);
00098       const magnitude_type& imagPart = std::imag (elt);
00099 
00100       out << realPart;
00101       if (imagPart < ZERO)
00102   out << "-" << ScalarTraits< Scalar >::abs(imagPart) << "*i";
00103       else if (imagPart > ZERO)
00104   out << "+" << imagPart << "*i";
00105     }
00106   };
00107 
00108   template< class LocalOrdinal, class Scalar >
00109   void
00110   print_local_matrix (std::ostream& out,
00111           const LocalOrdinal nrows_local,
00112           const LocalOrdinal ncols,
00113           const Scalar A[],
00114           const LocalOrdinal lda)
00115   {
00116     ScalarPrinter< Scalar, ScalarTraits< Scalar >::is_complex > printer;
00117     for (LocalOrdinal i = 0; i < nrows_local; i++)
00118       {
00119   for (LocalOrdinal j = 0; j < ncols; j++)
00120     {
00121       const Scalar& curElt = A[i + j*lda];
00122       printer (out, curElt);
00123       if (j < ncols - 1)
00124         out << ", ";
00125     }
00126   out << ";" << std::endl;
00127       }
00128   }
00129 
00130   template< class Ordinal, class Scalar >
00131   void
00132   copy_matrix (const Ordinal nrows,
00133          const Ordinal ncols,
00134          Scalar* const A,
00135          const Ordinal lda,
00136          const Scalar* const B,
00137          const Ordinal ldb)
00138   {
00139     for (Ordinal j = 0; j < ncols; j++)
00140       {
00141   Scalar* const A_j = &A[j*lda];
00142   const Scalar* const B_j = &B[j*ldb];
00143   std::copy (B_j, B_j + nrows, A_j);
00144       }
00145   }
00146 
00147   template< class Ordinal, class Scalar >
00148   void
00149   fill_matrix (const Ordinal nrows,
00150          const Ordinal ncols,
00151          Scalar* const A,
00152          const Ordinal lda,
00153          const Scalar& default_val)
00154   {
00155     for (Ordinal j = 0; j < ncols; j++)
00156       {
00157   Scalar* const A_j = &A[j*lda];
00158   std::fill (A_j, A_j + nrows, default_val);
00159       }
00160   }
00161 
00162 
00163   template< class Ordinal, class Scalar, class Generator >
00164   void
00165   generate_matrix (const Ordinal nrows,
00166        const Ordinal ncols,
00167        Scalar* const A,
00168        const Ordinal lda,
00169        Generator gen)
00170   {
00171     for (Ordinal j = 0; j < ncols; j++)
00172       {
00173   Scalar* const A_j = &A[j*lda];
00174   std::generate (A_j, A_j + nrows, gen);
00175       }
00176   }
00177 
00178   template< class Ordinal, class Scalar >
00179   void
00180   copy_upper_triangle (const Ordinal nrows,
00181            const Ordinal ncols,
00182            Scalar* const R_out,
00183            const Ordinal ldr_out,
00184            const Scalar* const R_in,
00185            const Ordinal ldr_in)
00186   {
00187     if (nrows >= ncols)
00188       {
00189   for (Ordinal j = 0; j < ncols; j++)
00190     {
00191       Scalar* const A_j = &R_out[j*ldr_out];
00192       const Scalar* const B_j = &R_in[j*ldr_in];
00193       for (Ordinal i = 0; i <= j; i++)
00194         A_j[i] = B_j[i];
00195     }
00196       }
00197     else
00198       {
00199   copy_upper_triangle (nrows, nrows, R_out, ldr_out, R_in, ldr_in);
00200   for (Ordinal j = nrows; j < ncols; j++)
00201     {
00202       Scalar* const A_j = &R_out[j*ldr_out];
00203       const Scalar* const B_j = &R_in[j*ldr_in];
00204       for (Ordinal i = 0; i < nrows; i++)
00205         A_j[i] = B_j[i];
00206     }
00207       }
00208   }
00209 
00210 
00211   template< class Scalar >
00212   class SumSquare {
00213   public:
00214     Scalar operator() (const Scalar& result, const Scalar& x) const { 
00215       return result + x*x; 
00216     }
00217   };
00218 
00219   // Specialization for complex numbers
00220   template< class Scalar >
00221   class SumSquare< std::complex< Scalar > >  {
00222   public:
00223     Scalar operator() (const std::complex<Scalar>& result, 
00224            const std::complex<Scalar>& x) const { 
00225       const Scalar absval = std::norm (x);
00226       return result + absval * absval; 
00227     }
00228   };
00229 
00230   template< class Ordinal, class Scalar >
00231   void
00232   pack_R_factor (const Ordinal nrows, 
00233      const Ordinal ncols, 
00234      const Scalar R_in[], 
00235      const Ordinal ldr_in,
00236      Scalar buffer[])
00237   {
00238     Ordinal count = 0; // current position in output buffer
00239     if (nrows >= ncols)
00240       for (Ordinal j = 0; j < ncols; j++)
00241   for (Ordinal i = 0; i <= j; i++)
00242     buffer[count++] = R_in[i + j*ldr_in];
00243     else
00244       for (Ordinal j = 0; j < nrows; j++)
00245   for (Ordinal i = 0; i <= j; i++)
00246     buffer[count++] = R_in[i + j*ldr_in];
00247   }
00248 
00249   template< class Ordinal, class Scalar >
00250   void
00251   unpack_R_factor (const Ordinal nrows, 
00252        const Ordinal ncols, 
00253        Scalar R_out[], 
00254        const Ordinal ldr_out,
00255        const Scalar buffer[])
00256   {
00257     Ordinal count = 0; // current position in input buffer
00258     if (nrows >= ncols)
00259       for (Ordinal j = 0; j < ncols; j++)
00260   for (Ordinal i = 0; i <= j; i++)
00261     R_out[i + j*ldr_out] = buffer[count++];
00262     else
00263       for (Ordinal j = 0; j < nrows; j++)
00264   for (Ordinal i = 0; i <= j; i++)
00265     R_out[i + j*ldr_out] = buffer[count++];
00266   }
00267 
00268 } // namespace TSQR
00269 
00270 #endif // __TSQR_Tsqr_Util_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends