Kokkos Node API and Local Linear Algebra Kernels Version of the Day
TbbTsqr_TbbMgs.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_TBB_TbbMgs_hpp
00043 #define __TSQR_TBB_TbbMgs_hpp
00044 
00045 #include <algorithm>
00046 #include <cassert>
00047 #include <cmath>
00048 #include <numeric>
00049 #include <utility> // std::pair
00050 
00051 #include <Tsqr_MessengerBase.hpp>
00052 #include <Teuchos_ScalarTraits.hpp>
00053 #include <Tsqr_Util.hpp>
00054 
00055 #include <Teuchos_RCP.hpp>
00056 
00057 #include <tbb/blocked_range.h>
00058 #include <tbb/parallel_for.h>
00059 #include <tbb/parallel_reduce.h>
00060 #include <tbb/partitioner.h>
00061 
00062 // #define TBB_MGS_DEBUG 1
00063 #ifdef TBB_MGS_DEBUG
00064 #  include <iostream>
00065 using std::cerr;
00066 using std::endl;
00067 #endif // TBB_MGS_DEBUG
00068 
00071 
00072 namespace TSQR {
00073   namespace TBB {
00074 
00075     // Forward declaration
00076     template< class LocalOrdinal, class Scalar >
00077     class TbbMgs {
00078     public:
00079       typedef Scalar scalar_type;
00080       typedef LocalOrdinal ordinal_type;
00081       typedef typename Teuchos::ScalarTraits<Scalar>::magnitudeType magnitude_type;
00082       typedef MessengerBase< Scalar > messenger_type;
00083       typedef Teuchos::RCP< messenger_type > messenger_ptr;
00084 
00085       TbbMgs (const messenger_ptr& messenger) :
00086   messenger_ (messenger) {}
00087     
00088       void 
00089       mgs (const LocalOrdinal nrows_local, 
00090      const LocalOrdinal ncols, 
00091      Scalar A_local[], 
00092      const LocalOrdinal lda_local,
00093      Scalar R[],
00094      const LocalOrdinal ldr);
00095 
00096     private:
00097       messenger_ptr messenger_;
00098     };
00099 
00102 
00103     namespace details {
00104 
00107       template< class LocalOrdinal, class Scalar >
00108       class TbbDot {
00109       public:
00110   void 
00111   operator() (const tbb::blocked_range< LocalOrdinal >& r) 
00112   {
00113     typedef Teuchos::ScalarTraits< Scalar > STS;
00114 
00115     // The TBB book likes this copying of pointers into the local routine.
00116     // It probably helps the compiler do optimizations.
00117     const Scalar* const x = x_;
00118     const Scalar* const y = y_;
00119     Scalar local_result = result_;
00120 
00121 #ifdef TBB_MGS_DEBUG
00122     // cerr << "Range: [" << r.begin() << ", " << r.end() << ")" << endl;
00123     // for (LocalOrdinal k = r.begin(); k != r.end(); ++k)
00124     //   cerr << "(x[" << k << "], y[" << k << "]) = (" << x[k] << "," << y[k] << ")" << " ";
00125     // cerr << endl;
00126 #endif // TBB_MGS_DEBUG
00127 
00128     for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00129       local_result += x[i] * STS::conjugate (y[i]);
00130 
00131 #ifdef TBB_MGS_DEBUG
00132     //    cerr << "-- Final value = " << local_result << endl;
00133 #endif // TBB_MGS_DEBUG
00134 
00135     result_ = local_result;
00136   }
00138   Scalar result() const { return result_; }
00139 
00141   TbbDot (const Scalar* const x, const Scalar* const y) :
00142     result_ (Scalar(0)), x_ (x), y_ (y) {}
00143 
00145   TbbDot (TbbDot& d, tbb::split) : 
00146     result_ (Scalar(0)), x_ (d.x_), y_ (d.y_)
00147   {}
00150   void join (const TbbDot& d) { result_ += d.result(); }
00151 
00152       private:
00153   // Default constructor doesn't make sense.
00154   TbbDot ();
00155 
00156   Scalar result_;
00157   const Scalar* const x_;
00158   const Scalar* const y_;
00159       };
00160 
00161       template< class LocalOrdinal, class Scalar >
00162       class TbbScale {
00163       public:
00164   TbbScale (Scalar* const x, const Scalar& denom) : 
00165     x_ (x), denom_ (denom) {}
00166 
00167   // TBB demands that this be a "const" operator, in order for
00168   // the parallel_for expression to compile.  Strictly speaking,
00169   // it is const, because it does not change the address of the
00170   // pointer x_ (only the values stored there).
00171   void 
00172   operator() (const tbb::blocked_range< LocalOrdinal >& r) const 
00173   {
00174     // TBB likes arrays to have their pointers copied like this in
00175     // the operator() method.  I suspect it has something to do
00176     // with compiler optimizations.  If C++ supported the
00177     // "restrict" keyword, here would be a good place to add it...
00178     Scalar* const x = x_;
00179     const Scalar denom = denom_;
00180     for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00181       x[i] = x[i] / denom;
00182   }
00183       private:
00184   Scalar* const x_;
00185   const Scalar denom_;
00186       };
00187 
00188       template< class LocalOrdinal, class Scalar >
00189       class TbbAxpy {
00190       public:
00191   TbbAxpy (const Scalar& alpha, const Scalar* const x, Scalar* const y) : 
00192     alpha_ (alpha), x_ (x), y_ (y) 
00193   {}
00194   // TBB demands that this be a "const" operator, in order for
00195   // the parallel_for expression to compile.  Strictly speaking,
00196   // it is const, because it does change the address of the
00197   // pointer y_ (only the values stored there).
00198   void 
00199   operator() (const tbb::blocked_range< LocalOrdinal >& r) const 
00200   {
00201     const Scalar alpha = alpha_;
00202     const Scalar* const x = x_;
00203     Scalar* const y = y_;
00204     for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00205       y[i] = y[i] + alpha * x[i];
00206   }
00207       private:
00208   const Scalar alpha_;
00209   const Scalar* const x_;
00210   Scalar* const y_;
00211       };
00212 
00213       template< class LocalOrdinal, class Scalar >
00214       class TbbNormSquared {
00215       public:
00216   typedef Teuchos::ScalarTraits< Scalar > STS;
00217   typedef typename STS::magnitudeType magnitude_type;
00218 
00219   void operator() (const tbb::blocked_range< LocalOrdinal >& r) {
00220     // Doing the right thing in the complex case requires taking
00221     // an absolute value.  We want to avoid this additional cost
00222     // in the real case, which is why we check is_complex.
00223     if (STS::isComplex) 
00224       {
00225         // The TBB book favors copying array pointers into the
00226         // local routine.  It probably helps the compiler do
00227         // optimizations.
00228         const Scalar* const x = x_;
00229         for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00230     {
00231       // One could implement this by computing
00232       //
00233       // result_ += STS::real (x[i] * STS::conjugate(x[i]));
00234       //
00235       // However, in terms of type theory, it's much more
00236       // natural to start with a magnitude_type before
00237       // doing the multiplication.
00238       const magnitude_type xi = STS::magnitude (x[i]);
00239       result_ += xi * xi;
00240     }
00241       }
00242     else
00243       {
00244         const Scalar* const x = x_;
00245         for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00246     {
00247       const Scalar xi = x[i];
00248       result_ += xi * xi;
00249     }
00250       }
00251   }
00252   magnitude_type result() const { return result_; }
00253 
00254   TbbNormSquared (const Scalar* const x) :
00255     result_ (magnitude_type(0)), x_ (x) {}
00256   TbbNormSquared (TbbNormSquared& d, tbb::split) : 
00257     result_ (magnitude_type(0)), x_ (d.x_) {}
00258   void join (const TbbNormSquared& d) { result_ += d.result(); }
00259 
00260       private:
00261   // Default constructor doesn't make sense
00262   TbbNormSquared ();
00263 
00264   magnitude_type result_;
00265   const Scalar* const x_;
00266       };
00267 
00270   
00271       template< class LocalOrdinal, class Scalar >
00272       class TbbMgsOps {
00273       private:
00274   typedef tbb::blocked_range< LocalOrdinal > range_type;
00275   typedef Teuchos::ScalarTraits< Scalar > STS;
00276 
00277       public:
00278   typedef MessengerBase< Scalar > messenger_type;
00279   typedef Teuchos::RCP< messenger_type > messenger_ptr;
00280   typedef typename Teuchos::ScalarTraits< Scalar >::magnitudeType magnitude_type;
00281 
00282   TbbMgsOps (const messenger_ptr& messenger) :
00283     messenger_ (messenger) {}
00284 
00285   void
00286   axpy (const LocalOrdinal nrows_local,
00287         const Scalar alpha,
00288         const Scalar x_local[],
00289         Scalar y_local[]) const
00290   {
00291     using tbb::auto_partitioner;
00292     using tbb::parallel_for;
00293 
00294     TbbAxpy< LocalOrdinal, Scalar > axpyer (alpha, x_local, y_local);
00295     parallel_for (range_type(0, nrows_local), axpyer, auto_partitioner());
00296   }
00297 
00298   void
00299   scale (const LocalOrdinal nrows_local, 
00300          Scalar x_local[], 
00301          const Scalar denom) const
00302   {
00303     using tbb::auto_partitioner;
00304     using tbb::parallel_for;
00305 
00306     // "scaler" is spelled that way (and not as "scalar") on
00307     // purpose.  Think about it.
00308     TbbScale< LocalOrdinal, Scalar > scaler (x_local, denom);
00309     parallel_for (range_type(0, nrows_local), scaler, auto_partitioner());
00310   }
00311 
00314   Scalar
00315   dot (const LocalOrdinal nrows_local, 
00316        const Scalar x_local[], 
00317        const Scalar y_local[])
00318   {
00319     Scalar localResult (0);
00320     if (true)
00321       {
00322         // FIXME (mfh 26 Aug 2010) I'm not sure why I did this
00323         // (i.e., why I wrote "if (true)" here).  Certainly the
00324         // branch that is currently enabled should produce
00325         // correct behavior.  I suspect the nonenabled branch
00326         // will not.
00327         if (true)
00328     {
00329       TbbDot< LocalOrdinal, Scalar > dotter (x_local, y_local);
00330       dotter(range_type(0, nrows_local));
00331       localResult = dotter.result();
00332     }
00333         else
00334     {
00335       using tbb::auto_partitioner;
00336       using tbb::parallel_reduce;
00337 
00338       TbbDot< LocalOrdinal, Scalar > dotter (x_local, y_local);
00339       parallel_reduce (range_type(0, nrows_local),
00340            dotter, auto_partitioner());
00341       localResult = dotter.result();
00342     }
00343       }
00344     else 
00345       {
00346         for (LocalOrdinal i = 0; i != nrows_local; ++i)
00347     localResult += x_local[i] * STS::conjugate (y_local[i]);
00348       }
00349     
00350     // FIXME (mfh 23 Apr 2010) Does MPI_SUM do the right thing for
00351     // complex or otherwise general MPI data types?  Perhaps an MPI_Op
00352     // should belong in the MessengerBase...
00353     return messenger_->globalSum (localResult);
00354   }
00355 
00356   magnitude_type
00357   norm2 (const LocalOrdinal nrows_local, 
00358          const Scalar x_local[])
00359   {
00360     using tbb::auto_partitioner;
00361     using tbb::parallel_reduce;
00362 
00363     TbbNormSquared< LocalOrdinal, Scalar > normer (x_local);
00364     parallel_reduce (range_type(0, nrows_local), normer, auto_partitioner());
00365     const magnitude_type localResult = normer.result();
00366     // FIXME (mfh 12 Oct 2010) This involves an implicit
00367     // typecast from Scalar to magnitude_type.
00368     const magnitude_type globalResult = messenger_->globalSum (localResult);
00369     // Make sure that sqrt's argument is a magnitude_type.  Of
00370     // course global_result should be nonnegative real, but we
00371     // want the compiler to pick up the correct sqrt function.
00372     return Teuchos::ScalarTraits< magnitude_type >::squareroot (globalResult);
00373   }
00374 
00375   Scalar
00376   project (const LocalOrdinal nrows_local, 
00377      const Scalar q_local[], 
00378      Scalar v_local[])
00379   {
00380     const Scalar coeff = this->dot (nrows_local, v_local, q_local);
00381     this->axpy (nrows_local, -coeff, q_local, v_local);
00382     return coeff;
00383   }
00384 
00385       private:
00386   messenger_ptr messenger_;
00387       };
00388     } // namespace details
00389 
00392 
00393     template< class LocalOrdinal, class Scalar >
00394     void
00395     TbbMgs< LocalOrdinal, Scalar >::mgs (const LocalOrdinal nrows_local, 
00396            const LocalOrdinal ncols, 
00397            Scalar A_local[], 
00398            const LocalOrdinal lda_local,
00399            Scalar R[],
00400            const LocalOrdinal ldr)
00401     {
00402       details::TbbMgsOps< LocalOrdinal, Scalar > ops (messenger_);
00403       
00404       for (LocalOrdinal j = 0; j < ncols; ++j)
00405   {
00406     Scalar* const v = &A_local[j*lda_local];
00407     for (LocalOrdinal i = 0; i < j; ++i)
00408       {
00409         const Scalar* const q = &A_local[i*lda_local];
00410         R[i + j*ldr] = ops.project (nrows_local, q, v);
00411 #ifdef TBB_MGS_DEBUG
00412         if (my_rank == 0)
00413     cerr << "(i,j) = (" << i << "," << j << "): coeff = " 
00414          << R[i + j*ldr] << endl;
00415 #endif // TBB_MGS_DEBUG
00416       }
00417     const magnitude_type denom = ops.norm2 (nrows_local, v);
00418 #ifdef TBB_MGS_DEBUG
00419     if (my_rank == 0)
00420       cerr << "j = " << j << ": denom = " << denom << endl;
00421 #endif // TBB_MGS_DEBUG
00422 
00423     // FIXME (mfh 29 Apr 2010)
00424     //
00425     // NOTE IMPLICIT CAST.  This should work for complex numbers.
00426     // If it doesn't work for your Scalar data type, it means that
00427     // you need a different data type for the diagonal elements of
00428     // the R factor, than you need for the other elements.  This
00429     // is unlikely if we're comparing MGS against a Householder QR
00430     // factorization; I don't really understand how the latter
00431     // would work (not that it couldn't be given a sensible
00432     // interpretation) in the case of Scalars that aren't plain
00433     // old real or complex numbers.
00434     R[j + j*ldr] = Scalar (denom);
00435     ops.scale (nrows_local, v, denom);
00436   }
00437     }
00438   } // namespace TBB
00439 } // namespace TSQR
00440 
00441 #endif // __TSQR_TBB_TbbMgs_hpp
00442 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends