Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_Mgs.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_Tsqr_Mgs_hpp
00043 #define __TSQR_Tsqr_Mgs_hpp
00044 
00045 #include <algorithm>
00046 #include <cassert>
00047 #include <cmath>
00048 #include <utility> // std::pair
00049 
00050 #include <Tsqr_MessengerBase.hpp>
00051 #include <Tsqr_Util.hpp>
00052 
00053 #include <Teuchos_RCP.hpp>
00054 #include <Teuchos_ScalarTraits.hpp>
00055 
00056 // #define MGS_DEBUG 1
00057 #ifdef MGS_DEBUG
00058 #  include <iostream>
00059 using std::cerr;
00060 using std::endl;
00061 #endif // MGS_DEBUG
00062 
00063 
00064 namespace TSQR {
00065 
00068   template<class LocalOrdinal, class Scalar>
00069   class MGS {
00070   public:
00071     typedef Scalar scalar_type;
00072     typedef LocalOrdinal ordinal_type;
00073     typedef Teuchos::ScalarTraits< Scalar > STS;
00074     typedef typename STS::magnitudeType magnitude_type;
00075 
00080     MGS (const Teuchos::RCP< MessengerBase< Scalar > >& messenger) : 
00081       messenger_ (messenger) {}
00082 
00091     bool QR_produces_R_factor_with_nonnegative_diagonal () const {
00092       return true;
00093     }
00094     
00096     void 
00097     mgs (const LocalOrdinal nrows_local, 
00098    const LocalOrdinal ncols, 
00099    Scalar A_local[], 
00100    const LocalOrdinal lda_local,
00101    Scalar R[],
00102    const LocalOrdinal ldr);
00103 
00104   private:
00105     Teuchos::RCP<MessengerBase<Scalar> > messenger_;
00106   };
00107 
00108 
00109   namespace details {
00110 
00111     template<class LocalOrdinal, class Scalar>
00112     class MgsOps {
00113     public:
00114       typedef Teuchos::ScalarTraits< Scalar > STS;
00115       typedef typename STS::magnitudeType magnitude_type;
00116 
00117       MgsOps (const Teuchos::RCP< MessengerBase< Scalar > >& messenger) : 
00118   messenger_ (messenger) {}
00119 
00120       void
00121       axpy (const LocalOrdinal nrows_local,
00122       const Scalar alpha,
00123       const Scalar x_local[],
00124       Scalar y_local[]) const
00125       {
00126   for (LocalOrdinal i = 0; i < nrows_local; ++i)
00127     y_local[i] = y_local[i] + alpha * x_local[i];
00128       }
00129 
00130       void
00131       scale (const LocalOrdinal nrows_local, 
00132        Scalar x_local[], 
00133        const Scalar denom) const
00134       {
00135   for (LocalOrdinal i = 0; i < nrows_local; ++i)
00136     x_local[i] = x_local[i] / denom;
00137       }
00138 
00141       Scalar
00142       dot (const LocalOrdinal nrows_local, 
00143      const Scalar x_local[], 
00144      const Scalar y_local[])
00145       {
00146   Scalar local_result (0);
00147 
00148 #ifdef MGS_DEBUG
00149   // for (LocalOrdinal k = 0; k != nrows_local; ++k)
00150   //   cerr << "(x[" << k << "], y[" << k << "]) = (" << x_local[k] << "," << y_local[k] << ")" << " ";
00151   //   cerr << endl;
00152 #endif // MGS_DEBUG
00153 
00154   for (LocalOrdinal i = 0; i < nrows_local; ++i)
00155     local_result += x_local[i] * STS::conjugate (y_local[i]);
00156 
00157 #ifdef MGS_DEBUG
00158     // cerr << "-- Final value on this proc = " << local_result << endl;
00159 #endif // MGS_DEBUG
00160 
00161   // FIXME (mfh 23 Apr 2010) Does MPI_SUM do the right thing for
00162   // complex or otherwise general MPI data types?  Perhaps an MPI_Op
00163   // should belong in the MessengerBase...
00164   return messenger_->globalSum (local_result);
00165       }
00166 
00167       magnitude_type
00168       norm2 (const LocalOrdinal nrows_local, 
00169        const Scalar x_local[])
00170       {
00171   Scalar localResult (0);
00172 
00173   // Doing the right thing in the complex case requires taking
00174   // an absolute value.  We want to avoid this additional cost
00175   // in the real case, which is why we check is_complex.
00176   if (STS::isComplex)
00177     {
00178       for (LocalOrdinal i = 0; i < nrows_local; ++i)
00179         {
00180     const Scalar xi = STS::magnitude (x_local[i]);
00181     localResult += xi * xi;
00182         }
00183     }
00184   else
00185     {
00186       for (LocalOrdinal i = 0; i < nrows_local; ++i)
00187         {
00188     const Scalar xi = x_local[i];
00189     localResult += xi * xi;
00190         }
00191     }
00192   const Scalar globalResult = messenger_->globalSum (localResult);
00193   // sqrt doesn't make sense if the type of Scalar is complex,
00194   // even if the imaginary part of global_result is zero.
00195   return STS::squareroot (STS::magnitude (globalResult));
00196       }
00197 
00198       Scalar
00199       project (const LocalOrdinal nrows_local, 
00200          const Scalar q_local[], 
00201          Scalar v_local[])
00202       {
00203   const Scalar coeff = this->dot (nrows_local, v_local, q_local);
00204   this->axpy (nrows_local, -coeff, q_local, v_local);
00205   return coeff;
00206       }
00207 
00208     private:
00209       Teuchos::RCP< MessengerBase< Scalar > > messenger_;
00210     };
00211   } // namespace details
00212 
00213 
00214   template<class LocalOrdinal, class Scalar>
00215   void
00216   MGS<LocalOrdinal, Scalar>::mgs (const LocalOrdinal nrows_local, 
00217           const LocalOrdinal ncols, 
00218           Scalar A_local[], 
00219           const LocalOrdinal lda_local,
00220           Scalar R[],
00221           const LocalOrdinal ldr)
00222   {
00223     details::MgsOps<LocalOrdinal, Scalar> ops (messenger_);
00224     
00225     for (LocalOrdinal j = 0; j < ncols; ++j)
00226       {
00227   Scalar* const v = &A_local[j*lda_local];
00228   for (LocalOrdinal i = 0; i < j; ++i)
00229     {
00230       const Scalar* const q = &A_local[i*lda_local];
00231       R[i + j*ldr] = ops.project (nrows_local, q, v);
00232 #ifdef MGS_DEBUG
00233       if (my_rank == 0)
00234         cerr << "(i,j) = (" << i << "," << j << "): coeff = " << R[i + j*ldr] << endl;
00235 #endif // MGS_DEBUG
00236     }
00237   const magnitude_type denom = ops.norm2 (nrows_local, v);
00238 #ifdef MGS_DEBUG
00239     if (my_rank == 0)
00240       cerr << "j = " << j << ": denom = " << denom << endl;
00241 #endif // MGS_DEBUG
00242 
00243   // FIXME (mfh 29 Apr 2010)
00244   //
00245   // NOTE IMPLICIT CAST.  This should work for complex numbers.
00246   // If it doesn't work for your Scalar data type, it means that
00247   // you need a different data type for the diagonal elements of
00248   // the R factor, than you need for the other elements.  This
00249   // is unlikely if we're comparing MGS against a Householder QR
00250   // factorization; I don't really understand how the latter
00251   // would work (not that it couldn't be given a sensible
00252   // interpretation) in the case of Scalars that aren't plain
00253   // old real or complex numbers.
00254   R[j + j*ldr] = Scalar (denom);
00255   ops.scale (nrows_local, v, denom);
00256       }
00257   }
00258 
00259 } // namespace TSQR
00260 
00261 #endif // __TSQR_Tsqr_Mgs_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends