// Tsqr_DistTsqrRB.hpp — Anasazi (Trilinos) "version of the day" snapshot.
#ifndef __TSQR_DistTsqrRB_hpp
#define __TSQR_DistTsqrRB_hpp

#include <Tsqr_ApplyType.hpp>
#include <Tsqr_Combine.hpp>
#include <Tsqr_Matrix.hpp>
#include <Tsqr_ScalarTraits.hpp>
#include <Tsqr_StatTimeMonitor.hpp>

#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <utility>
#include <vector>

00019 namespace TSQR {
00020 
00028   template< class LocalOrdinal, class Scalar >
00029   class DistTsqrRB {
00030   public:
00031     typedef LocalOrdinal ordinal_type;
00032     typedef Scalar scalar_type;
00033     typedef typename ScalarTraits< scalar_type >::magnitude_type magnitude_type;
00034     typedef MatView< ordinal_type, scalar_type > matview_type;
00035     typedef Matrix< ordinal_type, scalar_type > matrix_type;
00036     typedef int rank_type;
00037     typedef Combine< ordinal_type, scalar_type > combine_type;
00038 
00043     DistTsqrRB (const Teuchos::RCP< MessengerBase< scalar_type > >& messenger) :
00044       messenger_ (messenger),
00045       totalTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorExplicit() total time")),
00046       reduceCommTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorReduce() communication time")),
00047       reduceTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorReduce() total time")),
00048       bcastCommTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::explicitQBroadcast() communication time")),
00049       bcastTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::explicitQBroadcast() total time"))
00050     {}
00051 
00055     void
00056     getStats (std::vector< TimeStats >& stats) const
00057     {
00058       const int numTimers = 5;
00059       stats.resize (std::max (stats.size(), static_cast<size_t>(numTimers)));
00060 
00061       stats[0] = totalStats_;
00062       stats[1] = reduceCommStats_;
00063       stats[2] = reduceStats_;
00064       stats[3] = bcastCommStats_;
00065       stats[4] = bcastStats_;
00066     }
00067 
00071     void
00072     getStatsLabels (std::vector< std::string >& labels) const
00073     {
00074       const int numTimers = 5;
00075       labels.resize (std::max (labels.size(), static_cast<size_t>(numTimers)));
00076 
00077       labels[0] = totalTime_->name();
00078       labels[1] = reduceCommTime_->name();
00079       labels[2] = reduceTime_->name();
00080       labels[3] = bcastCommTime_->name();
00081       labels[4] = bcastTime_->name();
00082     }
00083 
00086     bool QR_produces_R_factor_with_nonnegative_diagonal () const {
00087       return combine_type::QR_produces_R_factor_with_nonnegative_diagonal();
00088     }
00089 
00106     void
00107     factorExplicit (matview_type R_mine, matview_type Q_mine)
00108     {
00109       StatTimeMonitor totalMonitor (*totalTime_, totalStats_);
00110 
00111       // Dimension sanity checks.  R_mine should have at least as many
00112       // rows as columns (since we will be working on the upper
00113       // triangle).  Q_mine should have the same number of rows as
00114       // R_mine has columns, but Q_mine may have any number of
00115       // columns.  (It depends on how many columns of the explicit Q
00116       // factor we want to compute.)
00117       if (R_mine.nrows() < R_mine.ncols())
00118   {
00119     std::ostringstream os;
00120     os << "R factor input has fewer rows (" << R_mine.nrows() 
00121        << ") than columns (" << R_mine.ncols() << ")";
00122     // This is a logic error because TSQR users should not be
00123     // calling this method directly.
00124     throw std::logic_error (os.str());
00125   }
00126       else if (Q_mine.nrows() != R_mine.ncols())
00127   {
00128     std::ostringstream os;
00129     os << "Q factor input must have the same number of rows as the R "
00130       "factor input has columns.  Q has " << Q_mine.nrows() 
00131        << " rows, but R has " << R_mine.ncols() << " columns.";
00132     // This is a logic error because TSQR users should not be
00133     // calling this method directly.
00134     throw std::logic_error (os.str());
00135   }
00136 
00137       // The factorization is a recursion over processors [P_first, P_last].
00138       const rank_type P_mine = messenger_->rank();
00139       const rank_type P_first = 0;
00140       const rank_type P_last = messenger_->size() - 1;
00141 
00142       // Intermediate Q factors are stored implicitly.  QFactors[k] is
00143       // an upper triangular matrix of Householder reflectors, and
00144       // tauArrays[k] contains its corresponding scaling factors (TAU,
00145       // in LAPACK notation).  These two arrays will be filled in by
00146       // factorReduce().  Different MPI processes will have different
00147       // numbers of elements in these arrays.  In fact, on some
00148       // processes these arrays may be empty on output.  This is a
00149       // feature, not a bug!  
00150       //
00151       // Even though QFactors and tauArrays have the same type has the
00152       // first resp. second elements of DistTsqr::FactorOutput, they
00153       // are not compatible with the output of DistTsqr::factor() and
00154       // cannot be used as the input to DistTsqr::apply() or
00155       // DistTsqr::explicit_Q().  This is because factor() computes a
00156       // general factorization suitable for applying Q (or Q^T or Q^*)
00157       // to any compatible matrix, whereas factorExplicit() computes a
00158       // factorization specifically for the purpose of forming the
00159       // explicit Q factor.  The latter lets us use a broadcast to
00160       // compute Q, rather than a more message-intensive all-to-all
00161       // (butterfly).
00162       std::vector< matrix_type > QFactors;
00163       std::vector< std::vector< scalar_type > > tauArrays;
00164 
00165       {
00166   StatTimeMonitor reduceMonitor (*reduceTime_, reduceStats_);
00167   factorReduce (R_mine, P_mine, P_first, P_last, QFactors, tauArrays);
00168       }
00169 
00170       if (QFactors.size() != tauArrays.size())
00171   {
00172     std::ostringstream os;
00173     os << "QFactors and tauArrays should have the same number of element"
00174       "s after factorReduce() returns, but they do not.  QFactors has " 
00175        << QFactors.size() << " elements, but tauArrays has " 
00176        << tauArrays.size() << " elements.";
00177     throw std::logic_error (os.str());
00178   }
00179 
00180       Q_mine.fill (scalar_type (0));
00181       if (messenger_->rank() == 0)
00182   {
00183     for (ordinal_type j = 0; j < Q_mine.ncols(); ++j)
00184       Q_mine(j, j) = scalar_type (1);
00185   }
00186       // Scratch space for computing results to send to other processors.
00187       matrix_type Q_other (Q_mine.nrows(), Q_mine.ncols(), scalar_type (0));
00188       const rank_type numSteps = QFactors.size() - 1;
00189 
00190       {
00191   StatTimeMonitor bcastMonitor (*bcastTime_, bcastStats_);
00192   explicitQBroadcast (R_mine, Q_mine, Q_other.view(), 
00193           P_mine, P_first, P_last,
00194           numSteps, QFactors, tauArrays);
00195       }
00196     }
00197 
00198   private:
00199 
00200     void
00201     factorReduce (matview_type R_mine,
00202       const rank_type P_mine, 
00203       const rank_type P_first,
00204       const rank_type P_last,
00205       std::vector< matrix_type >& QFactors,
00206       std::vector< std::vector< scalar_type > >& tauArrays)
00207     {
00208       if (P_last < P_first)
00209   {
00210     std::ostringstream os;
00211     os << "Programming error in factorReduce() recursion: interval "
00212       "[P_first, P_last] is invalid: P_first = " << P_first 
00213        << ", P_last = " << P_last << ".";
00214     throw std::logic_error (os.str());
00215   }
00216       else if (P_mine < P_first || P_mine > P_last)
00217   {
00218     std::ostringstream os;
00219     os << "Programming error in factorReduce() recursion: P_mine (= " 
00220        << P_mine << ") is not in current process rank interval " 
00221        << "[P_first = " << P_first << ", P_last = " << P_last << "]";
00222     throw std::logic_error (os.str());
00223   }
00224       else if (P_last == P_first)
00225   return; // skip singleton intervals (see explanation below)
00226       else
00227   {
00228     // Recurse on two intervals: [P_first, P_mid-1] and [P_mid,
00229     // P_last].  For example, if [P_first, P_last] = [0, 9],
00230     // P_mid = floor( (0+9+1)/2 ) = 5 and the intervals are
00231     // [0,4] and [5,9].  
00232     // 
00233     // If [P_first, P_last] = [4,6], P_mid = floor( (4+6+1)/2 )
00234     // = 5 and the intervals are [4,4] (a singleton) and [5,6].
00235     // The latter case shows that singleton intervals may arise.
00236     // We treat them as a base case in the recursion.  Process 4
00237     // won't be skipped completely, though; it will get combined
00238     // with the result from [5,6].
00239 
00240     // Adding 1 and doing integer division works like "ceiling."
00241     const rank_type P_mid = (P_first + P_last + 1) / 2;
00242 
00243     if (P_mine < P_mid) // Interval [P_first, P_mid-1]
00244       factorReduce (R_mine, P_mine, P_first, P_mid - 1,
00245         QFactors, tauArrays);
00246     else // Interval [P_mid, P_last]
00247       factorReduce (R_mine, P_mine, P_mid, P_last,
00248         QFactors, tauArrays);
00249 
00250     // This only does anything if P_mine is either P_first or P_mid.
00251     if (P_mine == P_first)
00252       {
00253         const ordinal_type numCols = R_mine.ncols();
00254         matrix_type R_other (numCols, numCols);
00255         recv_R (R_other, P_mid);
00256 
00257         std::vector< scalar_type > tau (numCols);
00258         // Don't shrink the workspace array; doing so may
00259         // require expensive reallocation every time we send /
00260         // receive data.
00261         resizeWork (numCols);
00262         combine_.factor_pair (numCols, R_mine.get(), R_mine.lda(), 
00263             R_other.get(), R_other.lda(), 
00264             &tau[0], &work_[0]);
00265         QFactors.push_back (R_other);
00266         tauArrays.push_back (tau);
00267       }
00268     else if (P_mine == P_mid)
00269       send_R (R_mine, P_first);
00270   }
00271     }
00272 
00273     void
00274     explicitQBroadcast (matview_type R_mine,
00275       matview_type Q_mine,
00276       matview_type Q_other, // workspace
00277       const rank_type P_mine, 
00278       const rank_type P_first,
00279       const rank_type P_last,
00280       const rank_type curpos,
00281       std::vector< matrix_type >& QFactors,
00282       std::vector< std::vector< scalar_type > >& tauArrays)
00283     {
00284       if (P_last < P_first)
00285   {
00286     std::ostringstream os;
00287     os << "Programming error in explicitQBroadcast() recursion: interval"
00288       " [P_first, P_last] is invalid: P_first = " << P_first 
00289        << ", P_last = " << P_last << ".";
00290     throw std::logic_error (os.str());
00291   }
00292       else if (P_mine < P_first || P_mine > P_last)
00293   {
00294     std::ostringstream os;
00295     os << "Programming error in explicitQBroadcast() recursion: P_mine "
00296       "(= " << P_mine << ") is not in current process rank interval " 
00297        << "[P_first = " << P_first << ", P_last = " << P_last << "]";
00298     throw std::logic_error (os.str());
00299   }
00300       else if (P_last == P_first)
00301   return; // skip singleton intervals
00302       else
00303   {
00304     // Adding 1 and integer division works like "ceiling."
00305     const rank_type P_mid = (P_first + P_last + 1) / 2;
00306     rank_type newpos = curpos;
00307     if (P_mine == P_first)
00308       {
00309         if (curpos < 0)
00310     {
00311       std::ostringstream os;
00312       os << "Programming error: On the current P_first (= " 
00313          << P_first << ") proc: curpos (= " << curpos << ") < 0";
00314       throw std::logic_error (os.str());
00315     }
00316         // Q_impl, tau: implicitly stored local Q factor.
00317         matrix_type& Q_impl = QFactors[curpos];
00318         std::vector< scalar_type >& tau = tauArrays[curpos];
00319         
00320         // Apply implicitly stored local Q factor to 
00321         //   [Q_mine; 
00322         //    Q_other]
00323         // where Q_other = zeros(Q_mine.nrows(), Q_mine.ncols()).
00324         // Overwrite both Q_mine and Q_other with the result.
00325         Q_other.fill (scalar_type (0));
00326         combine_.apply_pair (ApplyType::NoTranspose, 
00327            Q_mine.ncols(), Q_impl.ncols(),
00328            Q_impl.get(), Q_impl.lda(), &tau[0],
00329            Q_mine.get(), Q_mine.lda(),
00330            Q_other.get(), Q_other.lda(), &work_[0]);
00331         // Send the resulting Q_other, and the final R factor, to P_mid.
00332         send_Q_R (Q_other, R_mine, P_mid);
00333         newpos = curpos - 1;
00334       }
00335     else if (P_mine == P_mid)
00336       // P_first computed my explicit Q factor component.
00337       // Receive it, and the final R factor, from P_first.
00338       recv_Q_R (Q_mine, R_mine, P_first);
00339 
00340     if (P_mine < P_mid) // Interval [P_first, P_mid-1]
00341       explicitQBroadcast (R_mine, Q_mine, Q_other, 
00342         P_mine, P_first, P_mid - 1,
00343         newpos, QFactors, tauArrays);
00344     else // Interval [P_mid, P_last]
00345       explicitQBroadcast (R_mine, Q_mine, Q_other, 
00346         P_mine, P_mid, P_last,
00347         newpos, QFactors, tauArrays);
00348   }
00349     }
00350 
00351     template< class ConstMatrixType1, class ConstMatrixType2 >
00352     void
00353     send_Q_R (const ConstMatrixType1& Q,
00354         const ConstMatrixType2& R,
00355         const rank_type destProc) 
00356     {
00357       StatTimeMonitor bcastCommMonitor (*bcastCommTime_, bcastCommStats_);
00358 
00359       const ordinal_type R_numCols = R.ncols();
00360       const ordinal_type Q_size = Q.nrows() * Q.ncols();
00361       const ordinal_type R_size = (R_numCols * (R_numCols + 1)) / 2;
00362       const ordinal_type numElts = Q_size + R_size;
00363 
00364       // Don't shrink the workspace array; doing so would still be
00365       // correct, but may require reallocation of data when it needs
00366       // to grow again.
00367       resizeWork (numElts);
00368 
00369       // Pack the Q data into the workspace array.
00370       matview_type Q_contig (Q.nrows(), Q.ncols(), &work_[0], Q.nrows());
00371       Q_contig.copy (Q);
00372       // Pack the R data into the workspace array.
00373       pack_R (R, &work_[Q_size]);
00374       messenger_->send (&work_[0], numElts, destProc, 0);
00375     }
00376 
00377     template< class MatrixType1, class MatrixType2 >
00378     void
00379     recv_Q_R (MatrixType1& Q, 
00380         MatrixType2& R, 
00381         const rank_type srcProc)
00382     {
00383       StatTimeMonitor bcastCommMonitor (*bcastCommTime_, bcastCommStats_);
00384 
00385       const ordinal_type R_numCols = R.ncols();
00386       const ordinal_type Q_size = Q.nrows() * Q.ncols();
00387       const ordinal_type R_size = (R_numCols * (R_numCols + 1)) / 2;
00388       const ordinal_type numElts = Q_size + R_size;
00389 
00390       // Don't shrink the workspace array; doing so would still be
00391       // correct, but may require reallocation of data when it needs
00392       // to grow again.
00393       resizeWork (numElts);
00394 
00395       messenger_->recv (&work_[0], numElts, srcProc, 0);
00396 
00397       // Unpack the C data from the workspace array.
00398       Q.copy (matview_type (Q.nrows(), Q.ncols(), &work_[0], Q.nrows()));
00399       // Unpack the R data from the workspace array.
00400       unpack_R (R, &work_[Q_size]);
00401     }
00402 
00403     template< class ConstMatrixType >
00404     void
00405     send_R (const ConstMatrixType& R, const rank_type destProc)
00406     {
00407       StatTimeMonitor reduceCommMonitor (*reduceCommTime_, reduceCommStats_);
00408 
00409       const ordinal_type numCols = R.ncols();
00410       const ordinal_type numElts = (numCols * (numCols+1)) / 2;
00411 
00412       // Don't shrink the workspace array; doing so would still be
00413       // correct, but may require reallocation of data when it needs
00414       // to grow again.
00415       resizeWork (numElts);
00416       // Pack the R data into the workspace array.
00417       pack_R (R, &work_[0]);
00418       messenger_->send (&work_[0], numElts, destProc, 0);
00419     }
00420 
00421     template< class MatrixType >
00422     void
00423     recv_R (MatrixType& R, const rank_type srcProc)
00424     {
00425       StatTimeMonitor reduceCommMonitor (*reduceCommTime_, reduceCommStats_);
00426 
00427       const ordinal_type numCols = R.ncols();
00428       const ordinal_type numElts = (numCols * (numCols+1)) / 2;
00429 
00430       // Don't shrink the workspace array; doing so would still be
00431       // correct, but may require reallocation of data when it needs
00432       // to grow again.
00433       resizeWork (numElts);
00434       messenger_->recv (&work_[0], numElts, srcProc, 0);
00435       // Unpack the R data from the workspace array.
00436       unpack_R (R, &work_[0]);
00437     }
00438 
00439     template< class MatrixType >
00440     static void 
00441     unpack_R (MatrixType& R, const scalar_type buf[])
00442     {
00443       ordinal_type curpos = 0;
00444       for (ordinal_type j = 0; j < R.ncols(); ++j)
00445   {
00446     scalar_type* const R_j = &R(0, j);
00447     for (ordinal_type i = 0; i <= j; ++i)
00448       R_j[i] = buf[curpos++];
00449   }
00450     }
00451 
00452     template< class ConstMatrixType >
00453     static void 
00454     pack_R (const ConstMatrixType& R, scalar_type buf[])
00455     {
00456       ordinal_type curpos = 0;
00457       for (ordinal_type j = 0; j < R.ncols(); ++j)
00458   {
00459     const scalar_type* const R_j = &R(0, j);
00460     for (ordinal_type i = 0; i <= j; ++i)
00461       buf[curpos++] = R_j[i];
00462   }
00463     }
00464 
00465     void
00466     resizeWork (const ordinal_type numElts)
00467     {
00468       typedef typename std::vector< scalar_type >::size_type vec_size_type;
00469       work_.resize (std::max (work_.size(), static_cast< vec_size_type >(numElts)));
00470     }
00471 
00472   private:
00473     combine_type combine_;
00474     Teuchos::RCP< MessengerBase< scalar_type > > messenger_;
00475     std::vector< scalar_type > work_;
00476 
00477     // Timers for various phases of the factorization.  Time is
00478     // cumulative over all calls of factorExplicit().
00479     Teuchos::RCP< Teuchos::Time > totalTime_;
00480     Teuchos::RCP< Teuchos::Time > reduceCommTime_;
00481     Teuchos::RCP< Teuchos::Time > reduceTime_;
00482     Teuchos::RCP< Teuchos::Time > bcastCommTime_;
00483     Teuchos::RCP< Teuchos::Time > bcastTime_;
00484 
00485     TimeStats totalStats_, reduceCommStats_, reduceStats_, bcastCommStats_, bcastStats_;
00486   };
00487 
00488 } // namespace TSQR

#endif // __TSQR_DistTsqrRB_hpp