Anasazi Version of the Day
Tsqr_DistTsqrHelper.hpp
00001 // @HEADER
00002 // ***********************************************************************
00003 //
00004 //                 Anasazi: Block Eigensolvers Package
00005 //                 Copyright (2010) Sandia Corporation
00006 //
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 //
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00025 //
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef __TSQR_Tsqr_DistTsqrHelper_hpp
00030 #define __TSQR_Tsqr_DistTsqrHelper_hpp
00031 
00032 #include <Tsqr_MatView.hpp>
00033 #include <Tsqr_MessengerBase.hpp>
00034 #include <Tsqr_Combine.hpp>
00035 #include <Tsqr_Util.hpp>
00036 
00037 #include <algorithm> // std::min, std::max
00038 #include <sstream>
00039 #include <stdexcept>
00040 #include <vector>
00041 
00044 
00045 namespace TSQR {
00046 
00053   template< class LocalOrdinal, class Scalar >
00054   class DistTsqrHelper {
00055   public:
00056     DistTsqrHelper () {}
00057 
00058     void
00059     factor_pair (const LocalOrdinal ncols,
00060      std::vector< Scalar >& R_mine,
00061      const LocalOrdinal P_mine,
00062      const LocalOrdinal P_other,
00063      const LocalOrdinal tag, 
00064      MessengerBase< Scalar >* const messenger,
00065      std::vector< std::vector< Scalar > >& Q_factors,
00066      std::vector< std::vector< Scalar > >& tau_arrays,
00067      std::vector< Scalar >& work) 
00068     {
00069       using std::endl;
00070       using std::ostringstream;
00071       using std::vector;
00072 
00073       if (P_mine == P_other)
00074   return; // nothing to do
00075 
00076       const int P_top = std::min (P_mine, P_other);
00077       const int P_bot = std::max (P_mine, P_other);
00078       const LocalOrdinal nelts = ncols * ncols;
00079       const LocalOrdinal ldr = ncols;
00080       vector< Scalar > R_other (nelts);
00081       vector< Scalar > tau (ncols);
00082 
00083       // Send and receive R factor.
00084       messenger->swapData (&R_mine[0], &R_other[0], nelts, P_other, tag);
00085 
00086       Combine< LocalOrdinal, Scalar > combine;
00087       if (P_mine == P_top)
00088   {
00089     combine.factor_pair (ncols, &R_mine[0], ldr, &R_other[0], ldr, &tau[0], &work[0]);
00090     Q_factors.push_back (R_other);
00091     tau_arrays.push_back (tau);
00092   }
00093       else if (P_mine == P_bot)
00094   {
00095     combine.factor_pair (ncols, &R_other[0], ldr, &R_mine[0], ldr, &tau[0], &work[0]);
00096     Q_factors.push_back (R_mine);
00097     // Make sure that the "bottom" processor gets the current R
00098     // factor, which is returned in R_mine.
00099     copy_matrix (ncols, ncols, &R_mine[0], ldr, &R_other[0], ldr);
00100     tau_arrays.push_back (tau);
00101   }
00102       else
00103   {
00104     // mfh 16 Apr 2010: the troubles with assert statements are as follows:
00105     //
00106     // 1. They go away in a release build.
00107     // 2. They don't often print out useful diagnostic information.
00108     // 3. If you mistype the assert, like "assert(errcode = 1);" instead of 
00109     //    "assert(errcode == 1)", you'll get false positives.
00110     ostringstream os;
00111     os << "Should never get here: P_mine (= " << P_mine 
00112        << ") not one of P_top, P_bot = " << P_top << ", " << P_bot;
00113     throw std::logic_error (os.str());
00114   }
00115     }
00116 
00117     void
00118     factor_helper (const LocalOrdinal ncols,
00119        std::vector< Scalar >& R_mine,
00120        const LocalOrdinal my_rank,
00121        const LocalOrdinal P_first,
00122        const LocalOrdinal P_last,
00123        const LocalOrdinal tag,
00124        MessengerBase< Scalar >* const messenger,
00125        std::vector< std::vector< Scalar > >& Q_factors,
00126        std::vector< std::vector< Scalar > >& tau_arrays,
00127        std::vector< Scalar >& work)
00128     {
00129       using std::endl;
00130       using std::ostringstream;
00131       using std::vector;
00132 
00133       if (P_last <= P_first)
00134   return;
00135       else
00136   {
00137     const int P = P_last - P_first + 1;
00138     // Whether the interval [P_first, P_last] has an even number of
00139     // elements.  Our interval splitting scheme ensures that the
00140     // interval [P_first, P_mid - 1] always has an even number of
00141     // elements.
00142     const bool b_even = (P % 2 == 0);
00143     // We split the interval [P_first, P_last] into 2 intervals:
00144     // [P_first, P_mid-1], and [P_mid, P_last].  We bias the
00145     // splitting procedure so that the lower interval always has an
00146     // even number of processor ranks, and never has fewer processor
00147     // ranks than the higher interval.
00148     const int P_mid = b_even ? (P_first + P/2) : (P_first + P/2 + 1); 
00149 
00150     if (my_rank < P_mid) // Interval [P_first, P_mid-1]
00151       {
00152         factor_helper (ncols, R_mine, my_rank, P_first, P_mid - 1, 
00153            tag + 1, messenger, Q_factors, tau_arrays, work);
00154 
00155         // If there aren't an even number of processors in the
00156         // original interval, then the last processor in the lower
00157         // interval has to skip this round.
00158         if (b_even || my_rank < P_mid - 1)
00159     {
00160       const int my_offset = my_rank - P_first;
00161       const int P_other = P_mid + my_offset;
00162       if (P_other < P_mid || P_other > P_last)
00163         throw std::logic_error ("P_other not in [P_mid,P_last] range");
00164 
00165       factor_pair (ncols, R_mine, my_rank, P_other, tag, 
00166              messenger, Q_factors, tau_arrays, work);
00167     }
00168 
00169         // If I'm skipping this round, get the "current" R factor
00170         // from P_mid.
00171         if (! b_even && my_rank == P_mid - 1)
00172     {
00173       const int theTag = 142; // magic constant
00174       messenger->recv (&R_mine[0], ncols*ncols, P_mid, theTag);
00175     }
00176       }
00177     else // Interval [P_mid, P_last]
00178       {
00179         factor_helper (ncols, R_mine, my_rank, P_mid, P_last, 
00180            tag + 1, messenger, Q_factors, tau_arrays, work);
00181 
00182         const int my_offset = my_rank - P_mid;
00183         const int P_other = P_first + my_offset;
00184 
00185         if (P_other < P_first || P_other >= P_mid)
00186     throw std::logic_error ("P_other not in [P_first,P_mid-1] range");
00187         factor_pair (ncols, R_mine, my_rank, P_other, tag, 
00188          messenger, Q_factors, tau_arrays, work);
00189 
00190         // If Proc P_mid-1 is skipping this round, Proc P_mid will
00191         // send it the "current" R factor.
00192         if (! b_even)
00193     {
00194       const int theTag = 142; // magic constant
00195       messenger->send (&R_mine[0], ncols*ncols, P_mid-1, theTag);
00196     }
00197       }
00198   }
00199     }
00200 
00201     void
00202     apply_pair (const ApplyType& apply_type,
00203     const LocalOrdinal ncols_C,
00204     const LocalOrdinal ncols_Q,
00205     Scalar C_mine[],
00206     const LocalOrdinal ldc_mine,
00207     Scalar C_other[], // contiguous ncols_C x ncols_C scratch
00208     const LocalOrdinal P_mine,
00209     const LocalOrdinal P_other,
00210     const LocalOrdinal tag, 
00211     MessengerBase< Scalar >* const messenger,
00212     const std::vector< Scalar >& Q_cur,
00213     const std::vector< Scalar >& tau_cur,
00214     std::vector< Scalar >& work)
00215     {
00216       using std::endl;
00217       using std::ostringstream;
00218       using std::vector;
00219 
00220       if (P_mine == P_other)
00221   return; // nothing to do
00222     
00223       const int P_top = std::min (P_mine, P_other);
00224       const int P_bot = std::max (P_mine, P_other);
00225     
00226       const LocalOrdinal nelts = ncols_C * ncols_C;
00227       const LocalOrdinal ldq = ncols_Q;
00228       const LocalOrdinal ldc_other = ncols_C;
00229 
00230       // Send and receive C_mine resp. C_other to the other processor of
00231       // the pair.
00232       messenger->swapData (&C_mine[0], &C_other[0], nelts, P_other, tag);
00233 
00234       Combine< LocalOrdinal, Scalar > combine;
00235       if (P_mine == P_top)
00236   combine.apply_pair (apply_type, ncols_C, ncols_Q, &Q_cur[0], ldq, 
00237           &tau_cur[0], C_mine, ldc_mine, C_other, ldc_other, 
00238           &work[0]);
00239       else if (P_mine == P_bot)
00240   combine.apply_pair (apply_type, ncols_C, ncols_Q, &Q_cur[0], ldq, 
00241           &tau_cur[0], C_other, ldc_other, C_mine, ldc_mine, 
00242           &work[0]);
00243       else
00244   {
00245     ostringstream os;
00246     os << "Should never get here: P_mine (= " << P_mine 
00247        << ") not one of P_top, P_bot = " << P_top << ", " << P_bot;
00248     throw std::logic_error (os.str());
00249   }
00250     }
00251 
00252     void
00253     apply_helper (const ApplyType& apply_type,
00254       const LocalOrdinal ncols_C,
00255       const LocalOrdinal ncols_Q,
00256       Scalar C_mine[],
00257       const LocalOrdinal ldc_mine,
00258       Scalar C_other[], // contiguous ncols_C x ncols_C scratch
00259       const LocalOrdinal my_rank,
00260       const LocalOrdinal P_first,
00261       const LocalOrdinal P_last,
00262       const LocalOrdinal tag,
00263       MessengerBase< Scalar >* const messenger,
00264       const std::vector< std::vector< Scalar > >& Q_factors,
00265       const std::vector< std::vector< Scalar > >& tau_arrays,
00266       const LocalOrdinal cur_pos,
00267       std::vector< Scalar >& work)
00268     {
00269       using std::endl;
00270       using std::ostringstream;
00271       using std::vector;
00272 
00273       if (P_last <= P_first)
00274   return;
00275       else
00276   {
00277     const int P = P_last - P_first + 1;
00278     // Whether the interval [P_first, P_last] has an even number of
00279     // elements.  Our interval splitting scheme ensures that the
00280     // interval [P_first, P_mid - 1] always has an even number of
00281     // elements.
00282     const bool b_even = (P % 2 == 0);
00283     // We split the interval [P_first, P_last] into 2 intervals:
00284     // [P_first, P_mid-1], and [P_mid, P_last].  We bias the
00285     // splitting procedure so that the lower interval always has an
00286     // even number of processor ranks, and never has fewer processor
00287     // ranks than the higher interval.
00288     const int P_mid = b_even ? (P_first + P/2) : (P_first + P/2 + 1); 
00289 
00290     if (my_rank < P_mid) // Interval [P_first, P_mid - 1]
00291       {
00292         const bool b_participating = b_even || my_rank < P_mid - 1;
00293 
00294         if (cur_pos < 0)
00295     {
00296       ostringstream os;
00297       os << "On Proc " << my_rank << ": cur_pos (= " << cur_pos 
00298          << ") < 0; lower interval [" << P_first << "," << (P_mid-1)
00299          << "]; original interval [" << P_first << "," << P_last 
00300          << "]" << endl;
00301       throw std::logic_error (os.str());
00302     }
00303 
00304         // If there aren't an even number of processors in the
00305         // original interval, then the last processor in the lower
00306         // interval has to skip this round.  Since we skip this
00307         // round, don't decrement cur_pos (else we'll skip an entry
00308         // and eventually fall off the front of the array.
00309         int new_cur_pos;
00310         if (b_even || my_rank < P_mid - 1)
00311     {
00312       if (! b_participating)
00313         throw std::logic_error("Should never get here");
00314 
00315       const int my_offset = my_rank - P_first;
00316       const int P_other = P_mid + my_offset;
00317       // assert (P_mid <= P_other && P_other <= P_last);
00318       if (P_other < P_mid || P_other > P_last)
00319         throw std::logic_error("Should never get here");
00320 
00321       apply_pair (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00322             C_other, my_rank, P_other, tag, messenger, 
00323             Q_factors[cur_pos], tau_arrays[cur_pos], work);
00324       new_cur_pos = cur_pos - 1;
00325     }
00326         else
00327     {
00328       if (b_participating)
00329         throw std::logic_error("Should never get here");
00330 
00331       new_cur_pos = cur_pos;
00332     }
00333         apply_helper (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00334           C_other, my_rank, P_first, P_mid - 1, tag + 1, 
00335           messenger, Q_factors, tau_arrays, new_cur_pos, 
00336           work);
00337       }
00338     else
00339       {
00340         if (cur_pos < 0)
00341     {
00342       ostringstream os;
00343       os << "On Proc " << my_rank << ": cur_pos (= " << cur_pos 
00344          << ") < 0; upper interval [" << P_mid << "," << P_last
00345          << "]; original interval [" << P_first << "," << P_last 
00346          << "]" << endl;
00347       throw std::logic_error (os.str());
00348     }
00349 
00350         const int my_offset = my_rank - P_mid;
00351         const int P_other = P_first + my_offset;
00352         // assert (0 <= P_other && P_other < P_mid);
00353         apply_pair (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00354         C_other, my_rank, P_other, tag, messenger, 
00355         Q_factors[cur_pos], tau_arrays[cur_pos], work);
00356         apply_helper (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00357           C_other, my_rank, P_mid, P_last, tag + 1, 
00358           messenger, Q_factors, tau_arrays, cur_pos - 1, 
00359           work);
00360       }
00361   }
00362     }
00363   };
00364 
00365 } // namespace TSQR
00366 
00367 #endif // __TSQR_Tsqr_DistTsqrHelper_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends