Anasazi Version of the Day
TbbTsqr_TbbMgs.hpp
00001 // @HEADER
00002 // ***********************************************************************
00003 //
00004 //                 Anasazi: Block Eigensolvers Package
00005 //                 Copyright (2010) Sandia Corporation
00006 //
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 //
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00025 //
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef __TSQR_TBB_TbbMgs_hpp
00030 #define __TSQR_TBB_TbbMgs_hpp
00031 
00032 #include <algorithm>
00033 #include <cassert>
00034 #include <cmath>
00035 #include <numeric>
00036 #include <utility> // std::pair
00037 
00038 #include <Tsqr_MessengerBase.hpp>
00039 #include <Tsqr_ScalarTraits.hpp>
00040 #include <Tsqr_Util.hpp>
00041 
00042 #include <Teuchos_RCP.hpp>
00043 
00044 #include <tbb/blocked_range.h>
00045 #include <tbb/parallel_for.h>
00046 #include <tbb/parallel_reduce.h>
00047 #include <tbb/partitioner.h>
00048 
00049 // #define TBB_MGS_DEBUG 1
00050 #ifdef TBB_MGS_DEBUG
00051 #  include <iostream>
00052 using std::cerr;
00053 using std::endl;
00054 #endif // TBB_MGS_DEBUG
00055 
00058 
00059 namespace TSQR {
00060   namespace TBB {
00061 
00062     // Forward declaration
00063     template< class LocalOrdinal, class Scalar >
00064     class TbbMgs {
00065     public:
00066       typedef Scalar scalar_type;
00067       typedef LocalOrdinal ordinal_type;
00068       typedef typename ScalarTraits<Scalar>::magnitude_type magnitude_type;
00069       typedef MessengerBase< Scalar > messenger_type;
00070       typedef Teuchos::RCP< messenger_type > messenger_ptr;
00071 
00072       TbbMgs (const messenger_ptr& messenger) :
00073   messenger_ (messenger) {}
00074     
00075       void 
00076       mgs (const LocalOrdinal nrows_local, 
00077      const LocalOrdinal ncols, 
00078      Scalar A_local[], 
00079      const LocalOrdinal lda_local,
00080      Scalar R[],
00081      const LocalOrdinal ldr);
00082 
00083     private:
00084       messenger_ptr messenger_;
00085     };
00086 
00089 
00090     namespace details {
00091 
00094       template< class LocalOrdinal, class Scalar >
00095       class TbbDot {
00096       public:
00097   void operator() (const tbb::blocked_range< LocalOrdinal >& r) {
00098     // The TBB book likes this copying of pointers into the local routine.
00099     // It probably helps the compiler do optimizations.
00100     const Scalar* const x = x_;
00101     const Scalar* const y = y_;
00102     Scalar local_result = result_;
00103 
00104 #ifdef TBB_MGS_DEBUG
00105     // cerr << "Range: [" << r.begin() << ", " << r.end() << ")" << endl;
00106     // for (LocalOrdinal k = r.begin(); k != r.end(); ++k)
00107     //   cerr << "(x[" << k << "], y[" << k << "]) = (" << x[k] << "," << y[k] << ")" << " ";
00108     // cerr << endl;
00109 #endif // TBB_MGS_DEBUG
00110 
00111     for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00112       local_result += x[i] * ScalarTraits< Scalar >::conj (y[i]);
00113 
00114 #ifdef TBB_MGS_DEBUG
00115     //    cerr << "-- Final value = " << local_result << endl;
00116 #endif // TBB_MGS_DEBUG
00117 
00118     result_ = local_result;
00119   }
00121   Scalar result() const { return result_; }
00122 
00124   TbbDot (const Scalar* const x, const Scalar* const y) :
00125     result_ (Scalar(0)), x_ (x), y_ (y) {}
00126 
00128   TbbDot (TbbDot& d, tbb::split) : 
00129     result_ (Scalar(0)), x_ (d.x_), y_ (d.y_)
00130   {}
00133   void join (const TbbDot& d) { result_ += d.result(); }
00134 
00135       private:
00136   // Default constructor doesn't make sense.
00137   TbbDot ();
00138 
00139   Scalar result_;
00140   const Scalar* const x_;
00141   const Scalar* const y_;
00142       };
00143 
00144       template< class LocalOrdinal, class Scalar >
00145       class TbbScale {
00146       public:
00147   TbbScale (Scalar* const x, const Scalar& denom) : x_ (x), denom_ (denom) {}
00148 
00149   // TBB demands that this be a "const" operator, in order for
00150   // the parallel_for expression to compile.  Strictly speaking,
00151   // it is const, because it does not change the address of the
00152   // pointer x_ (only the values stored there).
00153   void operator() (const tbb::blocked_range< LocalOrdinal >& r) const {
00154     // TBB likes arrays to have their pointers copied like this in
00155     // the operator() method.  I suspect it has something to do
00156     // with compiler optimizations.  If C++ supported the
00157     // "restrict" keyword, here would be a good place to add it...
00158     Scalar* const x = x_;
00159     const Scalar denom = denom_;
00160     for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00161       x[i] = x[i] / denom;
00162   }
00163       private:
00164   Scalar* const x_;
00165   const Scalar denom_;
00166       };
00167 
00168       template< class LocalOrdinal, class Scalar >
00169       class TbbAxpy {
00170       public:
00171   TbbAxpy (const Scalar& alpha, const Scalar* const x, Scalar* const y) : 
00172     alpha_ (alpha), x_ (x), y_ (y) 
00173   {}
00174   // TBB demands that this be a "const" operator, in order for
00175   // the parallel_for expression to compile.  Strictly speaking,
00176   // it is const, because it does change the address of the
00177   // pointer y_ (only the values stored there).
00178   void operator() (const tbb::blocked_range< LocalOrdinal >& r) const {
00179     const Scalar alpha = alpha_;
00180     const Scalar* const x = x_;
00181     Scalar* const y = y_;
00182     for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00183       y[i] = y[i] + alpha * x[i];
00184   }
00185       private:
00186   const Scalar alpha_;
00187   const Scalar* const x_;
00188   Scalar* const y_;
00189       };
00190 
00191       template< class LocalOrdinal, class Scalar >
00192       class TbbNormSquared {
00193       public:
00194   typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00195 
00196   void operator() (const tbb::blocked_range< LocalOrdinal >& r) {
00197     // Doing the right thing in the complex case requires taking
00198     // an absolute value.  We want to avoid this additional cost
00199     // in the real case, which is why we check is_complex.
00200     if (ScalarTraits< Scalar >::is_complex) 
00201       {
00202         // The TBB book favors copying array pointers into the
00203         // local routine.  It probably helps the compiler do
00204         // optimizations.
00205         const Scalar* const x = x_;
00206         for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00207     {
00208       const magnitude_type xi = ScalarTraits< Scalar >::abs (x[i]);
00209       result_ += xi * xi;
00210     }
00211       }
00212     else
00213       {
00214         const Scalar* const x = x_;
00215         for (LocalOrdinal i = r.begin(); i != r.end(); ++i)
00216     {
00217       const Scalar xi = x[i];
00218       result_ += xi * xi;
00219     }
00220       }
00221   }
00222   magnitude_type result() const { return result_; }
00223 
00224   TbbNormSquared (const Scalar* const x) :
00225     result_ (magnitude_type(0)), x_ (x) {}
00226   TbbNormSquared (TbbNormSquared& d, tbb::split) : 
00227     result_ (magnitude_type(0)), x_ (d.x_) {}
00228   void join (const TbbNormSquared& d) { result_ += d.result(); }
00229 
00230       private:
00231   // Default constructor doesn't make sense
00232   TbbNormSquared ();
00233 
00234   magnitude_type result_;
00235   const Scalar* const x_;
00236       };
00237 
00240   
00241       template< class LocalOrdinal, class Scalar >
00242       class TbbMgsOps {
00243       private:
00244   typedef tbb::blocked_range< LocalOrdinal > range_type;
00245 
00246       public:
00247   typedef MessengerBase< Scalar > messenger_type;
00248   typedef Teuchos::RCP< messenger_type > messenger_ptr;
00249 
00250   TbbMgsOps (const messenger_ptr& messenger) :
00251     messenger_ (messenger) {}
00252 
00253   void
00254   axpy (const LocalOrdinal nrows_local,
00255         const Scalar alpha,
00256         const Scalar x_local[],
00257         Scalar y_local[]) const
00258   {
00259     using tbb::auto_partitioner;
00260     using tbb::parallel_for;
00261 
00262     TbbAxpy< LocalOrdinal, Scalar > axpyer (alpha, x_local, y_local);
00263     parallel_for (range_type(0, nrows_local), axpyer, auto_partitioner());
00264   }
00265 
00266   void
00267   scale (const LocalOrdinal nrows_local, 
00268          Scalar x_local[], 
00269          const Scalar denom) const
00270   {
00271     using tbb::auto_partitioner;
00272     using tbb::parallel_for;
00273 
00274     // "scaler" is spelled that way (and not as "scalar") on
00275     // purpose.  Think about it.
00276     TbbScale< LocalOrdinal, Scalar > scaler (x_local, denom);
00277     parallel_for (range_type(0, nrows_local), scaler, auto_partitioner());
00278   }
00279 
00282   Scalar
00283   dot (const LocalOrdinal nrows_local, 
00284        const Scalar x_local[], 
00285        const Scalar y_local[])
00286   {
00287     Scalar local_result (0);
00288     if (true)
00289       {
00290         // FIXME (mfh 26 Aug 2010) I'm not sure why I did this
00291         // (i.e., why I wrote "if (true)" here).  Certainly the
00292         // branch that is currently enabled should produce
00293         // correct behavior.  I suspect the nonenabled branch
00294         // will not.
00295         if (true)
00296     {
00297       TbbDot< LocalOrdinal, Scalar > dotter (x_local, y_local);
00298       dotter(range_type(0, nrows_local));
00299       local_result = dotter.result();
00300     }
00301         else
00302     {
00303       using tbb::auto_partitioner;
00304       using tbb::parallel_reduce;
00305 
00306       TbbDot< LocalOrdinal, Scalar > dotter (x_local, y_local);
00307       parallel_reduce (range_type(0, nrows_local), dotter, auto_partitioner());
00308       local_result = dotter.result();
00309     }
00310       }
00311     else 
00312       {
00313         for (LocalOrdinal i = 0; i != nrows_local; ++i)
00314     local_result += x_local[i] * ScalarTraits< Scalar >::conj(y_local[i]);
00315       }
00316     
00317     // FIXME (mfh 23 Apr 2010) Does MPI_SUM do the right thing for
00318     // complex or otherwise general MPI data types?  Perhaps an MPI_Op
00319     // should belong in the MessengerBase...
00320     return messenger_->globalSum (local_result);
00321   }
00322 
00323   typename ScalarTraits< Scalar >::magnitude_type 
00324   norm2 (const LocalOrdinal nrows_local, 
00325          const Scalar x_local[])
00326   {
00327     using tbb::auto_partitioner;
00328     using tbb::parallel_reduce;
00329     typedef typename ScalarTraits< Scalar >::magnitude_type magnitude_type;
00330 
00331     TbbNormSquared< LocalOrdinal, Scalar > normer (x_local);
00332     parallel_reduce (range_type(0, nrows_local), normer, auto_partitioner());
00333     const magnitude_type local_result = normer.result();
00334     const magnitude_type global_result = messenger_->globalSum (local_result);
00335     // Make sure that sqrt's argument is a magnitude_type.  Of
00336     // course global_result should be nonnegative real, but we
00337     // want the compiler to pick up the correct sqrt function.
00338     return sqrt (magnitude_type (global_result));
00339   }
00340 
00341   Scalar
00342   project (const LocalOrdinal nrows_local, 
00343      const Scalar q_local[], 
00344      Scalar v_local[])
00345   {
00346     const Scalar coeff = this->dot (nrows_local, v_local, q_local);
00347     this->axpy (nrows_local, -coeff, q_local, v_local);
00348     return coeff;
00349   }
00350 
00351       private:
00352   messenger_ptr messenger_;
00353       };
00354     } // namespace details
00355 
00358 
00359     template< class LocalOrdinal, class Scalar >
00360     void
00361     TbbMgs< LocalOrdinal, Scalar >::mgs (const LocalOrdinal nrows_local, 
00362            const LocalOrdinal ncols, 
00363            Scalar A_local[], 
00364            const LocalOrdinal lda_local,
00365            Scalar R[],
00366            const LocalOrdinal ldr)
00367     {
00368       details::TbbMgsOps< LocalOrdinal, Scalar > ops (messenger_);
00369       
00370       for (LocalOrdinal j = 0; j < ncols; ++j)
00371   {
00372     Scalar* const v = &A_local[j*lda_local];
00373     for (LocalOrdinal i = 0; i < j; ++i)
00374       {
00375         const Scalar* const q = &A_local[i*lda_local];
00376         R[i + j*ldr] = ops.project (nrows_local, q, v);
00377 #ifdef TBB_MGS_DEBUG
00378         if (my_rank == 0)
00379     cerr << "(i,j) = (" << i << "," << j << "): coeff = " 
00380          << R[i + j*ldr] << endl;
00381 #endif // TBB_MGS_DEBUG
00382       }
00383     const magnitude_type denom = ops.norm2 (nrows_local, v);
00384 #ifdef TBB_MGS_DEBUG
00385     if (my_rank == 0)
00386       cerr << "j = " << j << ": denom = " << denom << endl;
00387 #endif // TBB_MGS_DEBUG
00388 
00389     // FIXME (mfh 29 Apr 2010)
00390     //
00391     // NOTE IMPLICIT CAST.  This should work for complex numbers.
00392     // If it doesn't work for your Scalar data type, it means that
00393     // you need a different data type for the diagonal elements of
00394     // the R factor, than you need for the other elements.  This
00395     // is unlikely if we're comparing MGS against a Householder QR
00396     // factorization; I don't really understand how the latter
00397     // would work (not that it couldn't be given a sensible
00398     // interpretation) in the case of Scalars that aren't plain
00399     // old real or complex numbers.
00400     R[j + j*ldr] = Scalar (denom);
00401     ops.scale (nrows_local, v, denom);
00402   }
00403     }
00404   } // namespace TBB
00405 } // namespace TSQR
00406 
00407 #endif // __TSQR_TBB_TbbMgs_hpp
00408 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends