Kokkos Node API and Local Linear Algebra Kernels Version of the Day
TbbTsqr_FactorTask.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00029 #ifndef __TSQR_TBB_FactorTask_hpp
00030 #define __TSQR_TBB_FactorTask_hpp
00031 
00032 #include <tbb/task.h>
00033 #include <TbbTsqr_Partitioner.hpp>
00034 #include <Tsqr_SequentialTsqr.hpp>
00035 #include <Teuchos_TestForException.hpp>
00036 #include <algorithm>
00037 
00040 
00041 namespace TSQR {
00042   namespace TBB {
00043     
00047     template<class LocalOrdinal, class Scalar, class TimerType>
00048     class FactorTask : public tbb::task {
00049     public:
00050       typedef MatView<LocalOrdinal, Scalar> mat_view;
00051       typedef ConstMatView<LocalOrdinal, Scalar> const_mat_view;
00052       typedef std::pair<mat_view, mat_view> split_t;
00053       typedef std::pair<const_mat_view, const_mat_view> const_split_t;
00054 
00057       typedef typename SequentialTsqr<LocalOrdinal, Scalar>::FactorOutput SeqOutput;
00062       typedef std::vector<std::vector<Scalar> > ParOutput;
00066       typedef typename std::pair<std::vector<SeqOutput>, ParOutput> FactorOutput;
00067 
00073       FactorTask (const size_t P_first__, 
00074       const size_t P_last__,
00075       mat_view A,
00076       mat_view* const A_top_ptr,
00077       std::vector<SeqOutput>& seq_outputs,
00078       ParOutput& par_output,
00079       const SequentialTsqr<LocalOrdinal, Scalar>& seq,
00080       double& my_seq_timing,
00081       double& min_seq_timing,
00082       double& max_seq_timing,
00083       const bool contiguous_cache_blocks) :
00084   P_first_ (P_first__),
00085   P_last_ (P_last__),
00086   A_ (A),
00087   A_top_ptr_ (A_top_ptr),
00088   seq_outputs_ (seq_outputs),
00089   par_output_ (par_output),
00090   seq_ (seq),
00091   contiguous_cache_blocks_ (contiguous_cache_blocks),
00092   my_seq_timing_ (my_seq_timing),
00093   min_seq_timing_ (min_seq_timing),
00094   max_seq_timing_ (max_seq_timing)
00095       {}
00096 
00097       tbb::task* execute () 
00098       {
00099   if (P_first_ > P_last_ || A_.empty())
00100     return NULL;
00101   else if (P_first_ == P_last_)
00102     {
00103       execute_base_case ();
00104       return NULL;
00105     }
00106   else
00107     {
00108       // Recurse on two intervals: [P_first, P_mid] and [P_mid+1, P_last]
00109       const size_t P_mid = (P_first_ + P_last_) / 2;
00110       split_t A_split = 
00111         partitioner_.split (A_, P_first_, P_mid, P_last_,
00112           contiguous_cache_blocks_);
00113       // The partitioner may decide that the current block A_
00114       // has too few rows to be worth splitting.  In that case,
00115       // A_split.second (the bottom block) will be empty.  We
00116       // can deal with this by treating it as the base case.
00117       if (A_split.second.empty() || A_split.second.nrows() == 0)
00118         {
00119     execute_base_case ();
00120     return NULL;
00121         }
00122 
00123       double top_timing;
00124       double top_min_timing = 0.0;
00125       double top_max_timing = 0.0;
00126       double bot_timing;
00127       double bot_min_timing = 0.0;
00128       double bot_max_timing = 0.0;
00129 
00130       FactorTask& topTask = *new( allocate_child() )
00131         FactorTask (P_first_, P_mid, A_split.first, A_top_ptr_, 
00132         seq_outputs_, par_output_, seq_,
00133         top_timing, top_min_timing, top_max_timing,
00134         contiguous_cache_blocks_);
00135       // After the task finishes, A_bot will be set to the topmost
00136       // partition of A_split.second.  This will let us combine
00137       // the two subproblems (using factor_pair()) after their
00138       // tasks complete.
00139       mat_view A_bot;
00140       FactorTask& botTask = *new( allocate_child() )
00141         FactorTask (P_mid+1, P_last_, A_split.second, &A_bot, 
00142         seq_outputs_, par_output_, seq_,
00143         bot_timing, bot_min_timing, bot_max_timing,
00144         contiguous_cache_blocks_);
00145       set_ref_count (3); // 3 children (2 + 1 for the wait)
00146       spawn (topTask);
00147       spawn_and_wait_for_all (botTask);
00148       
00149       // Combine the two results
00150       factor_pair (P_first_, P_mid+1, *A_top_ptr_, A_bot);
00151 
00152       top_min_timing = (top_min_timing == 0.0) ? top_timing : top_min_timing;
00153       top_max_timing = (top_max_timing == 0.0) ? top_timing : top_max_timing;
00154 
00155       bot_min_timing = (bot_min_timing == 0.0) ? bot_timing : bot_min_timing;
00156       bot_max_timing = (bot_max_timing == 0.0) ? bot_timing : bot_max_timing;
00157 
00158       min_seq_timing_ = std::min (top_min_timing, bot_min_timing);
00159       max_seq_timing_ = std::min (top_max_timing, bot_max_timing);
00160 
00161       return NULL;
00162     }
00163       }
00164 
00165     private:
00166       const size_t P_first_, P_last_;
00167       mat_view A_;
00168       mat_view* const A_top_ptr_;
00169       std::vector<SeqOutput>& seq_outputs_;
00170       ParOutput& par_output_;
00171       SequentialTsqr<LocalOrdinal, Scalar> seq_;
00172       TSQR::Combine<LocalOrdinal, Scalar> combine_;
00173       Partitioner<LocalOrdinal, Scalar> partitioner_;
00174       const bool contiguous_cache_blocks_;
00175       double& my_seq_timing_;
00176       double& min_seq_timing_;
00177       double& max_seq_timing_;
00178 
00179       void 
00180       factor_pair (const size_t P_top,
00181        const size_t P_bot,
00182        mat_view& A_top, // different than A_top_
00183        mat_view& A_bot)
00184       {
00185   const char thePrefix[] = "TSQR::TBB::Factor::factor_pair: ";
00186   TEST_FOR_EXCEPTION(P_top == P_bot, std::logic_error,
00187          thePrefix << "Should never get here! P_top == P_bot (= " 
00188          << P_top << "), that is, the indices of the thread "
00189          "partitions are the same.");
00190   // We only read and write the upper ncols x ncols triangle of
00191   // each block.
00192   TEST_FOR_EXCEPTION(A_top.ncols() != A_bot.ncols(), std::logic_error,
00193          thePrefix << "The top cache block A_top is " 
00194          << A_top.nrows() << " x " << A_top.ncols() 
00195          << ", and the bottom cache block A_bot is "
00196          << A_bot.nrows() << " x " << A_bot.ncols() 
00197          << "; this means we can't factor [A_top; A_bot].");
00198   const LocalOrdinal ncols = A_top.ncols();
00199   std::vector<Scalar>& tau = par_output_[P_bot];
00200   std::vector<Scalar> work (ncols);
00201   combine_.factor_pair (ncols, A_top.get(), A_top.lda(),
00202             A_bot.get(), A_bot.lda(), &tau[0], &work[0]);
00203       }
00204 
00205       void
00206       execute_base_case () 
00207       {
00208   TimerType timer("");
00209   timer.start();
00210   seq_outputs_[P_first_] = 
00211     seq_.factor (A_.nrows(), A_.ncols(), A_.get(), 
00212            A_.lda(), contiguous_cache_blocks_);
00213   // Assign the topmost cache block of the current partition to
00214   // *A_top_ptr_.  Every base case invocation does this, so that
00215   // we can combine subproblems.  The root task also does this,
00216   // but for a different reason: so that we can extract the R
00217   // factor, once we're done with the factorization.
00218   *A_top_ptr_ = seq_.top_block (A_, contiguous_cache_blocks_);
00219   my_seq_timing_ = timer.stop();
00220       }
00221     };
00222   } // namespace TBB
00223 } // namespace TSQR
00224 
00225 #endif // __TSQR_TBB_FactorTask_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends