Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_Mgs.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00029 #ifndef __TSQR_Tsqr_Mgs_hpp
00030 #define __TSQR_Tsqr_Mgs_hpp
00031 
00032 #include <algorithm>
00033 #include <cassert>
00034 #include <cmath>
00035 #include <utility> // std::pair
00036 
00037 #include <Tsqr_MessengerBase.hpp>
00038 #include <Tsqr_Util.hpp>
00039 
00040 #include <Teuchos_RCP.hpp>
00041 #include <Teuchos_ScalarTraits.hpp>
00042 
00043 // #define MGS_DEBUG 1
00044 #ifdef MGS_DEBUG
00045 #  include <iostream>
00046 using std::cerr;
00047 using std::endl;
00048 #endif // MGS_DEBUG
00049 
00050 
00051 namespace TSQR {
00052 
00055   template<class LocalOrdinal, class Scalar>
00056   class MGS {
00057   public:
00058     typedef Scalar scalar_type;
00059     typedef LocalOrdinal ordinal_type;
00060     typedef Teuchos::ScalarTraits< Scalar > STS;
00061     typedef typename STS::magnitudeType magnitude_type;
00062 
00067     MGS (const Teuchos::RCP< MessengerBase< Scalar > >& messenger) : 
00068       messenger_ (messenger) {}
00069 
00078     bool QR_produces_R_factor_with_nonnegative_diagonal () const {
00079       return true;
00080     }
00081     
00083     void 
00084     mgs (const LocalOrdinal nrows_local, 
00085    const LocalOrdinal ncols, 
00086    Scalar A_local[], 
00087    const LocalOrdinal lda_local,
00088    Scalar R[],
00089    const LocalOrdinal ldr);
00090 
00091   private:
00092     Teuchos::RCP<MessengerBase<Scalar> > messenger_;
00093   };
00094 
00095 
00096   namespace details {
00097 
00098     template<class LocalOrdinal, class Scalar>
00099     class MgsOps {
00100     public:
00101       typedef Teuchos::ScalarTraits< Scalar > STS;
00102       typedef typename STS::magnitudeType magnitude_type;
00103 
00104       MgsOps (const Teuchos::RCP< MessengerBase< Scalar > >& messenger) : 
00105   messenger_ (messenger) {}
00106 
00107       void
00108       axpy (const LocalOrdinal nrows_local,
00109       const Scalar alpha,
00110       const Scalar x_local[],
00111       Scalar y_local[]) const
00112       {
00113   for (LocalOrdinal i = 0; i < nrows_local; ++i)
00114     y_local[i] = y_local[i] + alpha * x_local[i];
00115       }
00116 
00117       void
00118       scale (const LocalOrdinal nrows_local, 
00119        Scalar x_local[], 
00120        const Scalar denom) const
00121       {
00122   for (LocalOrdinal i = 0; i < nrows_local; ++i)
00123     x_local[i] = x_local[i] / denom;
00124       }
00125 
00128       Scalar
00129       dot (const LocalOrdinal nrows_local, 
00130      const Scalar x_local[], 
00131      const Scalar y_local[])
00132       {
00133   Scalar local_result (0);
00134 
00135 #ifdef MGS_DEBUG
00136   // for (LocalOrdinal k = 0; k != nrows_local; ++k)
00137   //   cerr << "(x[" << k << "], y[" << k << "]) = (" << x_local[k] << "," << y_local[k] << ")" << " ";
00138   //   cerr << endl;
00139 #endif // MGS_DEBUG
00140 
00141   for (LocalOrdinal i = 0; i < nrows_local; ++i)
00142     local_result += x_local[i] * STS::conjugate (y_local[i]);
00143 
00144 #ifdef MGS_DEBUG
00145     // cerr << "-- Final value on this proc = " << local_result << endl;
00146 #endif // MGS_DEBUG
00147 
00148   // FIXME (mfh 23 Apr 2010) Does MPI_SUM do the right thing for
00149   // complex or otherwise general MPI data types?  Perhaps an MPI_Op
00150   // should belong in the MessengerBase...
00151   return messenger_->globalSum (local_result);
00152       }
00153 
00154       magnitude_type
00155       norm2 (const LocalOrdinal nrows_local, 
00156        const Scalar x_local[])
00157       {
00158   Scalar localResult (0);
00159 
00160   // Doing the right thing in the complex case requires taking
00161   // an absolute value.  We want to avoid this additional cost
00162   // in the real case, which is why we check is_complex.
00163   if (STS::isComplex)
00164     {
00165       for (LocalOrdinal i = 0; i < nrows_local; ++i)
00166         {
00167     const Scalar xi = STS::magnitude (x_local[i]);
00168     localResult += xi * xi;
00169         }
00170     }
00171   else
00172     {
00173       for (LocalOrdinal i = 0; i < nrows_local; ++i)
00174         {
00175     const Scalar xi = x_local[i];
00176     localResult += xi * xi;
00177         }
00178     }
00179   const Scalar globalResult = messenger_->globalSum (localResult);
00180   // sqrt doesn't make sense if the type of Scalar is complex,
00181   // even if the imaginary part of global_result is zero.
00182   return STS::squareroot (STS::magnitude (globalResult));
00183       }
00184 
00185       Scalar
00186       project (const LocalOrdinal nrows_local, 
00187          const Scalar q_local[], 
00188          Scalar v_local[])
00189       {
00190   const Scalar coeff = this->dot (nrows_local, v_local, q_local);
00191   this->axpy (nrows_local, -coeff, q_local, v_local);
00192   return coeff;
00193       }
00194 
00195     private:
00196       Teuchos::RCP< MessengerBase< Scalar > > messenger_;
00197     };
00198   } // namespace details
00199 
00200 
00201   template<class LocalOrdinal, class Scalar>
00202   void
00203   MGS<LocalOrdinal, Scalar>::mgs (const LocalOrdinal nrows_local, 
00204           const LocalOrdinal ncols, 
00205           Scalar A_local[], 
00206           const LocalOrdinal lda_local,
00207           Scalar R[],
00208           const LocalOrdinal ldr)
00209   {
00210     details::MgsOps<LocalOrdinal, Scalar> ops (messenger_);
00211     
00212     for (LocalOrdinal j = 0; j < ncols; ++j)
00213       {
00214   Scalar* const v = &A_local[j*lda_local];
00215   for (LocalOrdinal i = 0; i < j; ++i)
00216     {
00217       const Scalar* const q = &A_local[i*lda_local];
00218       R[i + j*ldr] = ops.project (nrows_local, q, v);
00219 #ifdef MGS_DEBUG
00220       if (my_rank == 0)
00221         cerr << "(i,j) = (" << i << "," << j << "): coeff = " << R[i + j*ldr] << endl;
00222 #endif // MGS_DEBUG
00223     }
00224   const magnitude_type denom = ops.norm2 (nrows_local, v);
00225 #ifdef MGS_DEBUG
00226     if (my_rank == 0)
00227       cerr << "j = " << j << ": denom = " << denom << endl;
00228 #endif // MGS_DEBUG
00229 
00230   // FIXME (mfh 29 Apr 2010)
00231   //
00232   // NOTE IMPLICIT CAST.  This should work for complex numbers.
00233   // If it doesn't work for your Scalar data type, it means that
00234   // you need a different data type for the diagonal elements of
00235   // the R factor, than you need for the other elements.  This
00236   // is unlikely if we're comparing MGS against a Householder QR
00237   // factorization; I don't really understand how the latter
00238   // would work (not that it couldn't be given a sensible
00239   // interpretation) in the case of Scalars that aren't plain
00240   // old real or complex numbers.
00241   R[j + j*ldr] = Scalar (denom);
00242   ops.scale (nrows_local, v, denom);
00243       }
00244   }
00245 
00246 } // namespace TSQR
00247 
00248 #endif // __TSQR_Tsqr_Mgs_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends