Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_DistTsqrHelper.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_Tsqr_DistTsqrHelper_hpp
00043 #define __TSQR_Tsqr_DistTsqrHelper_hpp
00044 
00045 #include <Tsqr_MatView.hpp>
00046 #include <Tsqr_MessengerBase.hpp>
00047 #include <Tsqr_Combine.hpp>
00048 #include <Tsqr_Util.hpp>
00049 
00050 #include <algorithm> // std::min, std::max
00051 #include <sstream>
00052 #include <stdexcept>
00053 #include <vector>
00054 
00055 
00056 namespace TSQR {
00057 
00064   template<class LocalOrdinal, class Scalar>
00065   class DistTsqrHelper {
00066   public:
00067     DistTsqrHelper () {}
00068 
00069     void
00070     factor_pair (const LocalOrdinal ncols,
00071      std::vector< Scalar >& R_mine,
00072      const LocalOrdinal P_mine,
00073      const LocalOrdinal P_other,
00074      const LocalOrdinal tag, 
00075      MessengerBase<Scalar>* const messenger,
00076      std::vector<std::vector<Scalar> >& Q_factors,
00077      std::vector<std::vector<Scalar> >& tau_arrays,
00078      std::vector<Scalar >& work) 
00079     {
00080       using std::endl;
00081       using std::ostringstream;
00082       using std::vector;
00083 
00084       if (P_mine == P_other)
00085   return; // nothing to do
00086 
00087       const int P_top = std::min (P_mine, P_other);
00088       const int P_bot = std::max (P_mine, P_other);
00089       const LocalOrdinal nelts = ncols * ncols;
00090       const LocalOrdinal ldr = ncols;
00091       vector< Scalar > R_other (nelts);
00092       vector< Scalar > tau (ncols);
00093 
00094       // Send and receive R factor.
00095       messenger->swapData (&R_mine[0], &R_other[0], nelts, P_other, tag);
00096 
00097       Combine< LocalOrdinal, Scalar > combine;
00098       if (P_mine == P_top)
00099   {
00100     combine.factor_pair (ncols, &R_mine[0], ldr, &R_other[0], ldr, &tau[0], &work[0]);
00101     Q_factors.push_back (R_other);
00102     tau_arrays.push_back (tau);
00103   }
00104       else if (P_mine == P_bot)
00105   {
00106     combine.factor_pair (ncols, &R_other[0], ldr, &R_mine[0], ldr, &tau[0], &work[0]);
00107     Q_factors.push_back (R_mine);
00108     // Make sure that the "bottom" processor gets the current R
00109     // factor, which is returned in R_mine.
00110     copy_matrix (ncols, ncols, &R_mine[0], ldr, &R_other[0], ldr);
00111     tau_arrays.push_back (tau);
00112   }
00113       else
00114   {
00115     // mfh 16 Apr 2010: the troubles with assert statements are as follows:
00116     //
00117     // 1. They go away in a release build.
00118     // 2. They don't often print out useful diagnostic information.
00119     // 3. If you mistype the assert, like "assert(errcode = 1);" instead of 
00120     //    "assert(errcode == 1)", you'll get false positives.
00121     ostringstream os;
00122     os << "Should never get here: P_mine (= " << P_mine 
00123        << ") not one of P_top, P_bot = " << P_top << ", " << P_bot;
00124     throw std::logic_error (os.str());
00125   }
00126     }
00127 
00128     void
00129     factor_helper (const LocalOrdinal ncols,
00130        std::vector< Scalar >& R_mine,
00131        const LocalOrdinal my_rank,
00132        const LocalOrdinal P_first,
00133        const LocalOrdinal P_last,
00134        const LocalOrdinal tag,
00135        MessengerBase< Scalar >* const messenger,
00136        std::vector< std::vector< Scalar > >& Q_factors,
00137        std::vector< std::vector< Scalar > >& tau_arrays,
00138        std::vector< Scalar >& work)
00139     {
00140       using std::endl;
00141       using std::ostringstream;
00142       using std::vector;
00143 
00144       if (P_last <= P_first)
00145   return;
00146       else
00147   {
00148     const int P = P_last - P_first + 1;
00149     // Whether the interval [P_first, P_last] has an even number of
00150     // elements.  Our interval splitting scheme ensures that the
00151     // interval [P_first, P_mid - 1] always has an even number of
00152     // elements.
00153     const bool b_even = (P % 2 == 0);
00154     // We split the interval [P_first, P_last] into 2 intervals:
00155     // [P_first, P_mid-1], and [P_mid, P_last].  We bias the
00156     // splitting procedure so that the lower interval always has an
00157     // even number of processor ranks, and never has fewer processor
00158     // ranks than the higher interval.
00159     const int P_mid = b_even ? (P_first + P/2) : (P_first + P/2 + 1); 
00160 
00161     if (my_rank < P_mid) // Interval [P_first, P_mid-1]
00162       {
00163         factor_helper (ncols, R_mine, my_rank, P_first, P_mid - 1, 
00164            tag + 1, messenger, Q_factors, tau_arrays, work);
00165 
00166         // If there aren't an even number of processors in the
00167         // original interval, then the last processor in the lower
00168         // interval has to skip this round.
00169         if (b_even || my_rank < P_mid - 1)
00170     {
00171       const int my_offset = my_rank - P_first;
00172       const int P_other = P_mid + my_offset;
00173       if (P_other < P_mid || P_other > P_last)
00174         throw std::logic_error ("P_other not in [P_mid,P_last] range");
00175 
00176       factor_pair (ncols, R_mine, my_rank, P_other, tag, 
00177              messenger, Q_factors, tau_arrays, work);
00178     }
00179 
00180         // If I'm skipping this round, get the "current" R factor
00181         // from P_mid.
00182         if (! b_even && my_rank == P_mid - 1)
00183     {
00184       const int theTag = 142; // magic constant
00185       messenger->recv (&R_mine[0], ncols*ncols, P_mid, theTag);
00186     }
00187       }
00188     else // Interval [P_mid, P_last]
00189       {
00190         factor_helper (ncols, R_mine, my_rank, P_mid, P_last, 
00191            tag + 1, messenger, Q_factors, tau_arrays, work);
00192 
00193         const int my_offset = my_rank - P_mid;
00194         const int P_other = P_first + my_offset;
00195 
00196         if (P_other < P_first || P_other >= P_mid)
00197     throw std::logic_error ("P_other not in [P_first,P_mid-1] range");
00198         factor_pair (ncols, R_mine, my_rank, P_other, tag, 
00199          messenger, Q_factors, tau_arrays, work);
00200 
00201         // If Proc P_mid-1 is skipping this round, Proc P_mid will
00202         // send it the "current" R factor.
00203         if (! b_even)
00204     {
00205       const int theTag = 142; // magic constant
00206       messenger->send (&R_mine[0], ncols*ncols, P_mid-1, theTag);
00207     }
00208       }
00209   }
00210     }
00211 
00212     void
00213     apply_pair (const ApplyType& apply_type,
00214     const LocalOrdinal ncols_C,
00215     const LocalOrdinal ncols_Q,
00216     Scalar C_mine[],
00217     const LocalOrdinal ldc_mine,
00218     Scalar C_other[], // contiguous ncols_C x ncols_C scratch
00219     const LocalOrdinal P_mine,
00220     const LocalOrdinal P_other,
00221     const LocalOrdinal tag, 
00222     MessengerBase< Scalar >* const messenger,
00223     const std::vector< Scalar >& Q_cur,
00224     const std::vector< Scalar >& tau_cur,
00225     std::vector< Scalar >& work)
00226     {
00227       using std::endl;
00228       using std::ostringstream;
00229       using std::vector;
00230 
00231       if (P_mine == P_other)
00232   return; // nothing to do
00233     
00234       const int P_top = std::min (P_mine, P_other);
00235       const int P_bot = std::max (P_mine, P_other);
00236     
00237       const LocalOrdinal nelts = ncols_C * ncols_C;
00238       const LocalOrdinal ldq = ncols_Q;
00239       const LocalOrdinal ldc_other = ncols_C;
00240 
00241       // Send and receive C_mine resp. C_other to the other processor of
00242       // the pair.
00243       messenger->swapData (&C_mine[0], &C_other[0], nelts, P_other, tag);
00244 
00245       Combine< LocalOrdinal, Scalar > combine;
00246       if (P_mine == P_top)
00247   combine.apply_pair (apply_type, ncols_C, ncols_Q, &Q_cur[0], ldq, 
00248           &tau_cur[0], C_mine, ldc_mine, C_other, ldc_other, 
00249           &work[0]);
00250       else if (P_mine == P_bot)
00251   combine.apply_pair (apply_type, ncols_C, ncols_Q, &Q_cur[0], ldq, 
00252           &tau_cur[0], C_other, ldc_other, C_mine, ldc_mine, 
00253           &work[0]);
00254       else
00255   {
00256     ostringstream os;
00257     os << "Should never get here: P_mine (= " << P_mine 
00258        << ") not one of P_top, P_bot = " << P_top << ", " << P_bot;
00259     throw std::logic_error (os.str());
00260   }
00261     }
00262 
00263     void
00264     apply_helper (const ApplyType& apply_type,
00265       const LocalOrdinal ncols_C,
00266       const LocalOrdinal ncols_Q,
00267       Scalar C_mine[],
00268       const LocalOrdinal ldc_mine,
00269       Scalar C_other[], // contiguous ncols_C x ncols_C scratch
00270       const LocalOrdinal my_rank,
00271       const LocalOrdinal P_first,
00272       const LocalOrdinal P_last,
00273       const LocalOrdinal tag,
00274       MessengerBase< Scalar >* const messenger,
00275       const std::vector< std::vector< Scalar > >& Q_factors,
00276       const std::vector< std::vector< Scalar > >& tau_arrays,
00277       const LocalOrdinal cur_pos,
00278       std::vector< Scalar >& work)
00279     {
00280       using std::endl;
00281       using std::ostringstream;
00282       using std::vector;
00283 
00284       if (P_last <= P_first)
00285   return;
00286       else
00287   {
00288     const int P = P_last - P_first + 1;
00289     // Whether the interval [P_first, P_last] has an even number of
00290     // elements.  Our interval splitting scheme ensures that the
00291     // interval [P_first, P_mid - 1] always has an even number of
00292     // elements.
00293     const bool b_even = (P % 2 == 0);
00294     // We split the interval [P_first, P_last] into 2 intervals:
00295     // [P_first, P_mid-1], and [P_mid, P_last].  We bias the
00296     // splitting procedure so that the lower interval always has an
00297     // even number of processor ranks, and never has fewer processor
00298     // ranks than the higher interval.
00299     const int P_mid = b_even ? (P_first + P/2) : (P_first + P/2 + 1); 
00300 
00301     if (my_rank < P_mid) // Interval [P_first, P_mid - 1]
00302       {
00303         const bool b_participating = b_even || my_rank < P_mid - 1;
00304 
00305         if (cur_pos < 0)
00306     {
00307       ostringstream os;
00308       os << "On Proc " << my_rank << ": cur_pos (= " << cur_pos 
00309          << ") < 0; lower interval [" << P_first << "," << (P_mid-1)
00310          << "]; original interval [" << P_first << "," << P_last 
00311          << "]" << endl;
00312       throw std::logic_error (os.str());
00313     }
00314 
00315         // If there aren't an even number of processors in the
00316         // original interval, then the last processor in the lower
00317         // interval has to skip this round.  Since we skip this
00318         // round, don't decrement cur_pos (else we'll skip an entry
00319         // and eventually fall off the front of the array.
00320         int new_cur_pos;
00321         if (b_even || my_rank < P_mid - 1)
00322     {
00323       if (! b_participating)
00324         throw std::logic_error("Should never get here");
00325 
00326       const int my_offset = my_rank - P_first;
00327       const int P_other = P_mid + my_offset;
00328       // assert (P_mid <= P_other && P_other <= P_last);
00329       if (P_other < P_mid || P_other > P_last)
00330         throw std::logic_error("Should never get here");
00331 
00332       apply_pair (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00333             C_other, my_rank, P_other, tag, messenger, 
00334             Q_factors[cur_pos], tau_arrays[cur_pos], work);
00335       new_cur_pos = cur_pos - 1;
00336     }
00337         else
00338     {
00339       if (b_participating)
00340         throw std::logic_error("Should never get here");
00341 
00342       new_cur_pos = cur_pos;
00343     }
00344         apply_helper (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00345           C_other, my_rank, P_first, P_mid - 1, tag + 1, 
00346           messenger, Q_factors, tau_arrays, new_cur_pos, 
00347           work);
00348       }
00349     else
00350       {
00351         if (cur_pos < 0)
00352     {
00353       ostringstream os;
00354       os << "On Proc " << my_rank << ": cur_pos (= " << cur_pos 
00355          << ") < 0; upper interval [" << P_mid << "," << P_last
00356          << "]; original interval [" << P_first << "," << P_last 
00357          << "]" << endl;
00358       throw std::logic_error (os.str());
00359     }
00360 
00361         const int my_offset = my_rank - P_mid;
00362         const int P_other = P_first + my_offset;
00363         // assert (0 <= P_other && P_other < P_mid);
00364         apply_pair (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00365         C_other, my_rank, P_other, tag, messenger, 
00366         Q_factors[cur_pos], tau_arrays[cur_pos], work);
00367         apply_helper (apply_type, ncols_C, ncols_Q, C_mine, ldc_mine, 
00368           C_other, my_rank, P_mid, P_last, tag + 1, 
00369           messenger, Q_factors, tau_arrays, cur_pos - 1, 
00370           work);
00371       }
00372   }
00373     }
00374   };
00375 
00376 } // namespace TSQR
00377 
00378 #endif // __TSQR_Tsqr_DistTsqrHelper_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends