// Kokkos Node API and Local Linear Algebra Kernels Version of the Day
// Tsqr_DistTsqrRB.hpp
00001 #ifndef __TSQR_DistTsqrRB_hpp
00002 #define __TSQR_DistTsqrRB_hpp
00003 
#include <Tsqr_ApplyType.hpp>
#include <Tsqr_Combine.hpp>
#include <Tsqr_Matrix.hpp>
#include <Tsqr_StatTimeMonitor.hpp>

#include <Teuchos_ScalarTraits.hpp>
#include <Teuchos_TimeMonitor.hpp>

#include <algorithm>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
00020 
00021 namespace TSQR {
00022 
00030   template< class LocalOrdinal, class Scalar >
00031   class DistTsqrRB {
00032   public:
00033     typedef LocalOrdinal ordinal_type;
00034     typedef Scalar scalar_type;
00035     typedef typename Teuchos::ScalarTraits< scalar_type >::magnitudeType magnitude_type;
00036     typedef MatView< ordinal_type, scalar_type > matview_type;
00037     typedef Matrix< ordinal_type, scalar_type > matrix_type;
00038     typedef int rank_type;
00039     typedef Combine< ordinal_type, scalar_type > combine_type;
00040 
00045     DistTsqrRB (const Teuchos::RCP< MessengerBase< scalar_type > >& messenger) :
00046       messenger_ (messenger),
00047       totalTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorExplicit() total time")),
00048       reduceCommTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorReduce() communication time")),
00049       reduceTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::factorReduce() total time")),
00050       bcastCommTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::explicitQBroadcast() communication time")),
00051       bcastTime_ (Teuchos::TimeMonitor::getNewTimer ("DistTsqrRB::explicitQBroadcast() total time"))
00052     {}
00053 
00057     void
00058     getStats (std::vector< TimeStats >& stats) const
00059     {
00060       const int numTimers = 5;
00061       stats.resize (std::max (stats.size(), static_cast<size_t>(numTimers)));
00062 
00063       stats[0] = totalStats_;
00064       stats[1] = reduceCommStats_;
00065       stats[2] = reduceStats_;
00066       stats[3] = bcastCommStats_;
00067       stats[4] = bcastStats_;
00068     }
00069 
00073     void
00074     getStatsLabels (std::vector< std::string >& labels) const
00075     {
00076       const int numTimers = 5;
00077       labels.resize (std::max (labels.size(), static_cast<size_t>(numTimers)));
00078 
00079       labels[0] = totalTime_->name();
00080       labels[1] = reduceCommTime_->name();
00081       labels[2] = reduceTime_->name();
00082       labels[3] = bcastCommTime_->name();
00083       labels[4] = bcastTime_->name();
00084     }
00085 
00088     bool QR_produces_R_factor_with_nonnegative_diagonal () const {
00089       return combine_type::QR_produces_R_factor_with_nonnegative_diagonal();
00090     }
00091 
00108     void
00109     factorExplicit (matview_type R_mine, matview_type Q_mine)
00110     {
00111       StatTimeMonitor totalMonitor (*totalTime_, totalStats_);
00112 
00113       // Dimension sanity checks.  R_mine should have at least as many
00114       // rows as columns (since we will be working on the upper
00115       // triangle).  Q_mine should have the same number of rows as
00116       // R_mine has columns, but Q_mine may have any number of
00117       // columns.  (It depends on how many columns of the explicit Q
00118       // factor we want to compute.)
00119       if (R_mine.nrows() < R_mine.ncols())
00120   {
00121     std::ostringstream os;
00122     os << "R factor input has fewer rows (" << R_mine.nrows() 
00123        << ") than columns (" << R_mine.ncols() << ")";
00124     // This is a logic error because TSQR users should not be
00125     // calling this method directly.
00126     throw std::logic_error (os.str());
00127   }
00128       else if (Q_mine.nrows() != R_mine.ncols())
00129   {
00130     std::ostringstream os;
00131     os << "Q factor input must have the same number of rows as the R "
00132       "factor input has columns.  Q has " << Q_mine.nrows() 
00133        << " rows, but R has " << R_mine.ncols() << " columns.";
00134     // This is a logic error because TSQR users should not be
00135     // calling this method directly.
00136     throw std::logic_error (os.str());
00137   }
00138 
00139       // The factorization is a recursion over processors [P_first, P_last].
00140       const rank_type P_mine = messenger_->rank();
00141       const rank_type P_first = 0;
00142       const rank_type P_last = messenger_->size() - 1;
00143 
00144       // Intermediate Q factors are stored implicitly.  QFactors[k] is
00145       // an upper triangular matrix of Householder reflectors, and
00146       // tauArrays[k] contains its corresponding scaling factors (TAU,
00147       // in LAPACK notation).  These two arrays will be filled in by
00148       // factorReduce().  Different MPI processes will have different
00149       // numbers of elements in these arrays.  In fact, on some
00150       // processes these arrays may be empty on output.  This is a
00151       // feature, not a bug!  
00152       //
00153       // Even though QFactors and tauArrays have the same type has the
00154       // first resp. second elements of DistTsqr::FactorOutput, they
00155       // are not compatible with the output of DistTsqr::factor() and
00156       // cannot be used as the input to DistTsqr::apply() or
00157       // DistTsqr::explicit_Q().  This is because factor() computes a
00158       // general factorization suitable for applying Q (or Q^T or Q^*)
00159       // to any compatible matrix, whereas factorExplicit() computes a
00160       // factorization specifically for the purpose of forming the
00161       // explicit Q factor.  The latter lets us use a broadcast to
00162       // compute Q, rather than a more message-intensive all-to-all
00163       // (butterfly).
00164       std::vector< matrix_type > QFactors;
00165       std::vector< std::vector< scalar_type > > tauArrays;
00166 
00167       {
00168   StatTimeMonitor reduceMonitor (*reduceTime_, reduceStats_);
00169   factorReduce (R_mine, P_mine, P_first, P_last, QFactors, tauArrays);
00170       }
00171 
00172       if (QFactors.size() != tauArrays.size())
00173   {
00174     std::ostringstream os;
00175     os << "QFactors and tauArrays should have the same number of element"
00176       "s after factorReduce() returns, but they do not.  QFactors has " 
00177        << QFactors.size() << " elements, but tauArrays has " 
00178        << tauArrays.size() << " elements.";
00179     throw std::logic_error (os.str());
00180   }
00181 
00182       Q_mine.fill (scalar_type (0));
00183       if (messenger_->rank() == 0)
00184   {
00185     for (ordinal_type j = 0; j < Q_mine.ncols(); ++j)
00186       Q_mine(j, j) = scalar_type (1);
00187   }
00188       // Scratch space for computing results to send to other processors.
00189       matrix_type Q_other (Q_mine.nrows(), Q_mine.ncols(), scalar_type (0));
00190       const rank_type numSteps = QFactors.size() - 1;
00191 
00192       {
00193   StatTimeMonitor bcastMonitor (*bcastTime_, bcastStats_);
00194   explicitQBroadcast (R_mine, Q_mine, Q_other.view(), 
00195           P_mine, P_first, P_last,
00196           numSteps, QFactors, tauArrays);
00197       }
00198     }
00199 
00200   private:
00201 
00202     void
00203     factorReduce (matview_type R_mine,
00204       const rank_type P_mine, 
00205       const rank_type P_first,
00206       const rank_type P_last,
00207       std::vector< matrix_type >& QFactors,
00208       std::vector< std::vector< scalar_type > >& tauArrays)
00209     {
00210       if (P_last < P_first)
00211   {
00212     std::ostringstream os;
00213     os << "Programming error in factorReduce() recursion: interval "
00214       "[P_first, P_last] is invalid: P_first = " << P_first 
00215        << ", P_last = " << P_last << ".";
00216     throw std::logic_error (os.str());
00217   }
00218       else if (P_mine < P_first || P_mine > P_last)
00219   {
00220     std::ostringstream os;
00221     os << "Programming error in factorReduce() recursion: P_mine (= " 
00222        << P_mine << ") is not in current process rank interval " 
00223        << "[P_first = " << P_first << ", P_last = " << P_last << "]";
00224     throw std::logic_error (os.str());
00225   }
00226       else if (P_last == P_first)
00227   return; // skip singleton intervals (see explanation below)
00228       else
00229   {
00230     // Recurse on two intervals: [P_first, P_mid-1] and [P_mid,
00231     // P_last].  For example, if [P_first, P_last] = [0, 9],
00232     // P_mid = floor( (0+9+1)/2 ) = 5 and the intervals are
00233     // [0,4] and [5,9].  
00234     // 
00235     // If [P_first, P_last] = [4,6], P_mid = floor( (4+6+1)/2 )
00236     // = 5 and the intervals are [4,4] (a singleton) and [5,6].
00237     // The latter case shows that singleton intervals may arise.
00238     // We treat them as a base case in the recursion.  Process 4
00239     // won't be skipped completely, though; it will get combined
00240     // with the result from [5,6].
00241 
00242     // Adding 1 and doing integer division works like "ceiling."
00243     const rank_type P_mid = (P_first + P_last + 1) / 2;
00244 
00245     if (P_mine < P_mid) // Interval [P_first, P_mid-1]
00246       factorReduce (R_mine, P_mine, P_first, P_mid - 1,
00247         QFactors, tauArrays);
00248     else // Interval [P_mid, P_last]
00249       factorReduce (R_mine, P_mine, P_mid, P_last,
00250         QFactors, tauArrays);
00251 
00252     // This only does anything if P_mine is either P_first or P_mid.
00253     if (P_mine == P_first)
00254       {
00255         const ordinal_type numCols = R_mine.ncols();
00256         matrix_type R_other (numCols, numCols);
00257         recv_R (R_other, P_mid);
00258 
00259         std::vector< scalar_type > tau (numCols);
00260         // Don't shrink the workspace array; doing so may
00261         // require expensive reallocation every time we send /
00262         // receive data.
00263         resizeWork (numCols);
00264         combine_.factor_pair (numCols, R_mine.get(), R_mine.lda(), 
00265             R_other.get(), R_other.lda(), 
00266             &tau[0], &work_[0]);
00267         QFactors.push_back (R_other);
00268         tauArrays.push_back (tau);
00269       }
00270     else if (P_mine == P_mid)
00271       send_R (R_mine, P_first);
00272   }
00273     }
00274 
00275     void
00276     explicitQBroadcast (matview_type R_mine,
00277       matview_type Q_mine,
00278       matview_type Q_other, // workspace
00279       const rank_type P_mine, 
00280       const rank_type P_first,
00281       const rank_type P_last,
00282       const rank_type curpos,
00283       std::vector< matrix_type >& QFactors,
00284       std::vector< std::vector< scalar_type > >& tauArrays)
00285     {
00286       if (P_last < P_first)
00287   {
00288     std::ostringstream os;
00289     os << "Programming error in explicitQBroadcast() recursion: interval"
00290       " [P_first, P_last] is invalid: P_first = " << P_first 
00291        << ", P_last = " << P_last << ".";
00292     throw std::logic_error (os.str());
00293   }
00294       else if (P_mine < P_first || P_mine > P_last)
00295   {
00296     std::ostringstream os;
00297     os << "Programming error in explicitQBroadcast() recursion: P_mine "
00298       "(= " << P_mine << ") is not in current process rank interval " 
00299        << "[P_first = " << P_first << ", P_last = " << P_last << "]";
00300     throw std::logic_error (os.str());
00301   }
00302       else if (P_last == P_first)
00303   return; // skip singleton intervals
00304       else
00305   {
00306     // Adding 1 and integer division works like "ceiling."
00307     const rank_type P_mid = (P_first + P_last + 1) / 2;
00308     rank_type newpos = curpos;
00309     if (P_mine == P_first)
00310       {
00311         if (curpos < 0)
00312     {
00313       std::ostringstream os;
00314       os << "Programming error: On the current P_first (= " 
00315          << P_first << ") proc: curpos (= " << curpos << ") < 0";
00316       throw std::logic_error (os.str());
00317     }
00318         // Q_impl, tau: implicitly stored local Q factor.
00319         matrix_type& Q_impl = QFactors[curpos];
00320         std::vector< scalar_type >& tau = tauArrays[curpos];
00321         
00322         // Apply implicitly stored local Q factor to 
00323         //   [Q_mine; 
00324         //    Q_other]
00325         // where Q_other = zeros(Q_mine.nrows(), Q_mine.ncols()).
00326         // Overwrite both Q_mine and Q_other with the result.
00327         Q_other.fill (scalar_type (0));
00328         combine_.apply_pair (ApplyType::NoTranspose, 
00329            Q_mine.ncols(), Q_impl.ncols(),
00330            Q_impl.get(), Q_impl.lda(), &tau[0],
00331            Q_mine.get(), Q_mine.lda(),
00332            Q_other.get(), Q_other.lda(), &work_[0]);
00333         // Send the resulting Q_other, and the final R factor, to P_mid.
00334         send_Q_R (Q_other, R_mine, P_mid);
00335         newpos = curpos - 1;
00336       }
00337     else if (P_mine == P_mid)
00338       // P_first computed my explicit Q factor component.
00339       // Receive it, and the final R factor, from P_first.
00340       recv_Q_R (Q_mine, R_mine, P_first);
00341 
00342     if (P_mine < P_mid) // Interval [P_first, P_mid-1]
00343       explicitQBroadcast (R_mine, Q_mine, Q_other, 
00344         P_mine, P_first, P_mid - 1,
00345         newpos, QFactors, tauArrays);
00346     else // Interval [P_mid, P_last]
00347       explicitQBroadcast (R_mine, Q_mine, Q_other, 
00348         P_mine, P_mid, P_last,
00349         newpos, QFactors, tauArrays);
00350   }
00351     }
00352 
00353     template< class ConstMatrixType1, class ConstMatrixType2 >
00354     void
00355     send_Q_R (const ConstMatrixType1& Q,
00356         const ConstMatrixType2& R,
00357         const rank_type destProc) 
00358     {
00359       StatTimeMonitor bcastCommMonitor (*bcastCommTime_, bcastCommStats_);
00360 
00361       const ordinal_type R_numCols = R.ncols();
00362       const ordinal_type Q_size = Q.nrows() * Q.ncols();
00363       const ordinal_type R_size = (R_numCols * (R_numCols + 1)) / 2;
00364       const ordinal_type numElts = Q_size + R_size;
00365 
00366       // Don't shrink the workspace array; doing so would still be
00367       // correct, but may require reallocation of data when it needs
00368       // to grow again.
00369       resizeWork (numElts);
00370 
00371       // Pack the Q data into the workspace array.
00372       matview_type Q_contig (Q.nrows(), Q.ncols(), &work_[0], Q.nrows());
00373       Q_contig.copy (Q);
00374       // Pack the R data into the workspace array.
00375       pack_R (R, &work_[Q_size]);
00376       messenger_->send (&work_[0], numElts, destProc, 0);
00377     }
00378 
00379     template< class MatrixType1, class MatrixType2 >
00380     void
00381     recv_Q_R (MatrixType1& Q, 
00382         MatrixType2& R, 
00383         const rank_type srcProc)
00384     {
00385       StatTimeMonitor bcastCommMonitor (*bcastCommTime_, bcastCommStats_);
00386 
00387       const ordinal_type R_numCols = R.ncols();
00388       const ordinal_type Q_size = Q.nrows() * Q.ncols();
00389       const ordinal_type R_size = (R_numCols * (R_numCols + 1)) / 2;
00390       const ordinal_type numElts = Q_size + R_size;
00391 
00392       // Don't shrink the workspace array; doing so would still be
00393       // correct, but may require reallocation of data when it needs
00394       // to grow again.
00395       resizeWork (numElts);
00396 
00397       messenger_->recv (&work_[0], numElts, srcProc, 0);
00398 
00399       // Unpack the C data from the workspace array.
00400       Q.copy (matview_type (Q.nrows(), Q.ncols(), &work_[0], Q.nrows()));
00401       // Unpack the R data from the workspace array.
00402       unpack_R (R, &work_[Q_size]);
00403     }
00404 
00405     template< class ConstMatrixType >
00406     void
00407     send_R (const ConstMatrixType& R, const rank_type destProc)
00408     {
00409       StatTimeMonitor reduceCommMonitor (*reduceCommTime_, reduceCommStats_);
00410 
00411       const ordinal_type numCols = R.ncols();
00412       const ordinal_type numElts = (numCols * (numCols+1)) / 2;
00413 
00414       // Don't shrink the workspace array; doing so would still be
00415       // correct, but may require reallocation of data when it needs
00416       // to grow again.
00417       resizeWork (numElts);
00418       // Pack the R data into the workspace array.
00419       pack_R (R, &work_[0]);
00420       messenger_->send (&work_[0], numElts, destProc, 0);
00421     }
00422 
00423     template< class MatrixType >
00424     void
00425     recv_R (MatrixType& R, const rank_type srcProc)
00426     {
00427       StatTimeMonitor reduceCommMonitor (*reduceCommTime_, reduceCommStats_);
00428 
00429       const ordinal_type numCols = R.ncols();
00430       const ordinal_type numElts = (numCols * (numCols+1)) / 2;
00431 
00432       // Don't shrink the workspace array; doing so would still be
00433       // correct, but may require reallocation of data when it needs
00434       // to grow again.
00435       resizeWork (numElts);
00436       messenger_->recv (&work_[0], numElts, srcProc, 0);
00437       // Unpack the R data from the workspace array.
00438       unpack_R (R, &work_[0]);
00439     }
00440 
00441     template< class MatrixType >
00442     static void 
00443     unpack_R (MatrixType& R, const scalar_type buf[])
00444     {
00445       ordinal_type curpos = 0;
00446       for (ordinal_type j = 0; j < R.ncols(); ++j)
00447   {
00448     scalar_type* const R_j = &R(0, j);
00449     for (ordinal_type i = 0; i <= j; ++i)
00450       R_j[i] = buf[curpos++];
00451   }
00452     }
00453 
00454     template< class ConstMatrixType >
00455     static void 
00456     pack_R (const ConstMatrixType& R, scalar_type buf[])
00457     {
00458       ordinal_type curpos = 0;
00459       for (ordinal_type j = 0; j < R.ncols(); ++j)
00460   {
00461     const scalar_type* const R_j = &R(0, j);
00462     for (ordinal_type i = 0; i <= j; ++i)
00463       buf[curpos++] = R_j[i];
00464   }
00465     }
00466 
00467     void
00468     resizeWork (const ordinal_type numElts)
00469     {
00470       typedef typename std::vector< scalar_type >::size_type vec_size_type;
00471       work_.resize (std::max (work_.size(), static_cast< vec_size_type >(numElts)));
00472     }
00473 
00474   private:
00475     combine_type combine_;
00476     Teuchos::RCP< MessengerBase< scalar_type > > messenger_;
00477     std::vector< scalar_type > work_;
00478 
00479     // Timers for various phases of the factorization.  Time is
00480     // cumulative over all calls of factorExplicit().
00481     Teuchos::RCP< Teuchos::Time > totalTime_;
00482     Teuchos::RCP< Teuchos::Time > reduceCommTime_;
00483     Teuchos::RCP< Teuchos::Time > reduceTime_;
00484     Teuchos::RCP< Teuchos::Time > bcastCommTime_;
00485     Teuchos::RCP< Teuchos::Time > bcastTime_;
00486 
00487     TimeStats totalStats_, reduceCommStats_, reduceStats_, bcastCommStats_, bcastStats_;
00488   };
00489 
00490 } // namespace TSQR
00491 
00492 #endif // __TSQR_DistTsqrRB_hpp
// (Doxygen navigation footer removed: All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends)