Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_Util.hpp
Go to the documentation of this file.
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00032 
00033 #ifndef __TSQR_Tsqr_Util_hpp
00034 #define __TSQR_Tsqr_Util_hpp
00035 
00036 #include "Tsqr_ScalarTraits.hpp"
00037 
00038 #include <algorithm>
00039 #include <complex>
00040 #include <ostream>
00041 
00042 
00043 namespace TSQR {
00044 
00059   template<class Scalar, bool isComplex>
00060   class ScalarPrinter {
00061   public:
00064     void operator() (std::ostream& out, const Scalar& elt) const;
00065   };
00066 
00067   // Partial specialization for real Scalar
00068   template< class Scalar >
00069   class ScalarPrinter< Scalar, false > {
00070   public:
00071     void operator() (std::ostream& out, const Scalar& elt) const {
00072       out << elt;
00073     }
00074   };
00075 
00076   // Partial specialization for complex Scalar
00077   template< class Scalar >
00078   class ScalarPrinter< Scalar, true > {
00079   public:
00080     void operator() (std::ostream& out, const Scalar& elt) const {
00081       typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00082 
00083       const magnitude_type ZERO (0);
00084       const magnitude_type& realPart = std::real (elt);
00085       const magnitude_type& imagPart = std::imag (elt);
00086 
00087       out << realPart;
00088       if (imagPart < ZERO)
00089   out << "-" << ScalarTraits< Scalar >::abs(imagPart) << "*i";
00090       else if (imagPart > ZERO)
00091   out << "+" << imagPart << "*i";
00092     }
00093   };
00094 
00095   template< class LocalOrdinal, class Scalar >
00096   void
00097   print_local_matrix (std::ostream& out,
00098           const LocalOrdinal nrows_local,
00099           const LocalOrdinal ncols,
00100           const Scalar A[],
00101           const LocalOrdinal lda)
00102   {
00103     ScalarPrinter< Scalar, ScalarTraits< Scalar >::is_complex > printer;
00104     for (LocalOrdinal i = 0; i < nrows_local; i++)
00105       {
00106   for (LocalOrdinal j = 0; j < ncols; j++)
00107     {
00108       const Scalar& curElt = A[i + j*lda];
00109       printer (out, curElt);
00110       if (j < ncols - 1)
00111         out << ", ";
00112     }
00113   out << ";" << std::endl;
00114       }
00115   }
00116 
00117   template< class Ordinal, class Scalar >
00118   void
00119   copy_matrix (const Ordinal nrows,
00120          const Ordinal ncols,
00121          Scalar* const A,
00122          const Ordinal lda,
00123          const Scalar* const B,
00124          const Ordinal ldb)
00125   {
00126     for (Ordinal j = 0; j < ncols; j++)
00127       {
00128   Scalar* const A_j = &A[j*lda];
00129   const Scalar* const B_j = &B[j*ldb];
00130   std::copy (B_j, B_j + nrows, A_j);
00131       }
00132   }
00133 
00134   template< class Ordinal, class Scalar >
00135   void
00136   fill_matrix (const Ordinal nrows,
00137          const Ordinal ncols,
00138          Scalar* const A,
00139          const Ordinal lda,
00140          const Scalar& default_val)
00141   {
00142     for (Ordinal j = 0; j < ncols; j++)
00143       {
00144   Scalar* const A_j = &A[j*lda];
00145   std::fill (A_j, A_j + nrows, default_val);
00146       }
00147   }
00148 
00149 
00150   template< class Ordinal, class Scalar, class Generator >
00151   void
00152   generate_matrix (const Ordinal nrows,
00153        const Ordinal ncols,
00154        Scalar* const A,
00155        const Ordinal lda,
00156        Generator gen)
00157   {
00158     for (Ordinal j = 0; j < ncols; j++)
00159       {
00160   Scalar* const A_j = &A[j*lda];
00161   std::generate (A_j, A_j + nrows, gen);
00162       }
00163   }
00164 
00165   template< class Ordinal, class Scalar >
00166   void
00167   copy_upper_triangle (const Ordinal nrows,
00168            const Ordinal ncols,
00169            Scalar* const R_out,
00170            const Ordinal ldr_out,
00171            const Scalar* const R_in,
00172            const Ordinal ldr_in)
00173   {
00174     if (nrows >= ncols)
00175       {
00176   for (Ordinal j = 0; j < ncols; j++)
00177     {
00178       Scalar* const A_j = &R_out[j*ldr_out];
00179       const Scalar* const B_j = &R_in[j*ldr_in];
00180       for (Ordinal i = 0; i <= j; i++)
00181         A_j[i] = B_j[i];
00182     }
00183       }
00184     else
00185       {
00186   copy_upper_triangle (nrows, nrows, R_out, ldr_out, R_in, ldr_in);
00187   for (Ordinal j = nrows; j < ncols; j++)
00188     {
00189       Scalar* const A_j = &R_out[j*ldr_out];
00190       const Scalar* const B_j = &R_in[j*ldr_in];
00191       for (Ordinal i = 0; i < nrows; i++)
00192         A_j[i] = B_j[i];
00193     }
00194       }
00195   }
00196 
00197 
00198   template< class Scalar >
00199   class SumSquare {
00200   public:
00201     Scalar operator() (const Scalar& result, const Scalar& x) const { 
00202       return result + x*x; 
00203     }
00204   };
00205 
00206   // Specialization for complex numbers
00207   template< class Scalar >
00208   class SumSquare< std::complex< Scalar > >  {
00209   public:
00210     Scalar operator() (const std::complex<Scalar>& result, 
00211            const std::complex<Scalar>& x) const { 
00212       const Scalar absval = std::norm (x);
00213       return result + absval * absval; 
00214     }
00215   };
00216 
00217   template< class Ordinal, class Scalar >
00218   void
00219   pack_R_factor (const Ordinal nrows, 
00220      const Ordinal ncols, 
00221      const Scalar R_in[], 
00222      const Ordinal ldr_in,
00223      Scalar buffer[])
00224   {
00225     Ordinal count = 0; // current position in output buffer
00226     if (nrows >= ncols)
00227       for (Ordinal j = 0; j < ncols; j++)
00228   for (Ordinal i = 0; i <= j; i++)
00229     buffer[count++] = R_in[i + j*ldr_in];
00230     else
00231       for (Ordinal j = 0; j < nrows; j++)
00232   for (Ordinal i = 0; i <= j; i++)
00233     buffer[count++] = R_in[i + j*ldr_in];
00234   }
00235 
00236   template< class Ordinal, class Scalar >
00237   void
00238   unpack_R_factor (const Ordinal nrows, 
00239        const Ordinal ncols, 
00240        Scalar R_out[], 
00241        const Ordinal ldr_out,
00242        const Scalar buffer[])
00243   {
00244     Ordinal count = 0; // current position in input buffer
00245     if (nrows >= ncols)
00246       for (Ordinal j = 0; j < ncols; j++)
00247   for (Ordinal i = 0; i <= j; i++)
00248     R_out[i + j*ldr_out] = buffer[count++];
00249     else
00250       for (Ordinal j = 0; j < nrows; j++)
00251   for (Ordinal i = 0; i <= j; i++)
00252     R_out[i + j*ldr_out] = buffer[count++];
00253   }
00254 
00255 } // namespace TSQR
00256 
00257 #endif // __TSQR_Tsqr_Util_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends