Kokkos Node API and Local Linear Algebra Kernels Version of the Day
TbbTsqr_ApplyTask.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00029 #ifndef __TSQR_TBB_ApplyTask_hpp
00030 #define __TSQR_TBB_ApplyTask_hpp
00031 
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

#include <tbb/task.h>
#include <TbbTsqr_Partitioner.hpp>
#include <Tsqr_SequentialTsqr.hpp>
00035 
00038 
00039 namespace TSQR {
00040   namespace TBB {
00041     
00045     template< class LocalOrdinal, class Scalar, class TimerType >
00046     class ApplyTask : public tbb::task {
00047     public:
00048       typedef MatView<LocalOrdinal, Scalar> mat_view;
00049       typedef ConstMatView<LocalOrdinal, Scalar> const_mat_view;
00050       typedef std::pair<mat_view, mat_view> split_t;
00051       typedef std::pair<const_mat_view, const_mat_view> const_split_t;
00052       typedef std::pair<const_mat_view, mat_view> top_blocks_t;
00053       typedef std::vector<top_blocks_t> array_top_blocks_t;
00054 
00057       typedef typename SequentialTsqr<LocalOrdinal, Scalar>::FactorOutput SeqOutput;
00062       typedef std::vector<std::vector<Scalar> > ParOutput;
00066       typedef typename std::pair<std::vector<SeqOutput>, ParOutput> FactorOutput;
00067 
00073       ApplyTask (const size_t P_first__, 
00074      const size_t P_last__,
00075      ConstMatView<LocalOrdinal, Scalar> Q,
00076      MatView<LocalOrdinal, Scalar> C,
00077      array_top_blocks_t& top_blocks, 
00078      const FactorOutput& factor_output,
00079      const SequentialTsqr<LocalOrdinal, Scalar>& seq,
00080      double& my_seq_timing,
00081      double& min_seq_timing,
00082      double& max_seq_timing,
00083      const bool contiguous_cache_blocks) :
00084   P_first_ (P_first__), 
00085   P_last_ (P_last__), 
00086   Q_ (Q), 
00087   C_ (C),
00088   top_blocks_ (top_blocks), 
00089   factor_output_ (factor_output), 
00090   seq_ (seq), 
00091   apply_type_ (ApplyType::NoTranspose), // FIXME: modify to support Q^T and Q^H
00092   my_seq_timing_ (my_seq_timing),
00093   min_seq_timing_ (min_seq_timing),
00094   max_seq_timing_ (max_seq_timing),
00095   contiguous_cache_blocks_ (contiguous_cache_blocks)
00096       {}
00097 
00098       tbb::task* execute () 
00099       {
00100   if (P_first_ > P_last_ || Q_.empty() || C_.empty())
00101     return NULL;
00102   else if (P_first_ == P_last_)
00103     {
00104       execute_base_case ();
00105       return NULL;
00106     }
00107   else
00108     {
00109       // Recurse on two intervals: [P_first, P_mid] and [P_mid+1, P_last]
00110       const size_t P_mid = (P_first_ + P_last_) / 2;
00111       const_split_t Q_split = 
00112         partitioner_.split (Q_, P_first_, P_mid, P_last_,
00113           contiguous_cache_blocks_);
00114       split_t C_split = 
00115         partitioner_.split (C_, P_first_, P_mid, P_last_,
00116           contiguous_cache_blocks_);
00117 
00118       // The partitioner may decide that the current blocks Q_
00119       // and C_ have too few rows to be worth splitting.  In
00120       // that case, Q_split.second and C_split.second (the
00121       // bottom block) will be empty.  We can deal with this by
00122       // treating it as the base case.
00123       if (Q_split.second.empty() || Q_split.second.nrows() == 0)
00124         {
00125     execute_base_case ();
00126     return NULL;
00127         }
00128 
00129       double top_timing;
00130       double top_min_timing = 0.0;
00131       double top_max_timing = 0.0;
00132       double bot_timing;
00133       double bot_min_timing = 0.0;
00134       double bot_max_timing = 0.0;
00135 
00136       apply_pair (P_first_, P_mid+1);
00137       ApplyTask& topTask = *new( allocate_child() )
00138         ApplyTask (P_first_, P_mid, Q_split.first, C_split.first,
00139        top_blocks_, factor_output_, seq_, 
00140        top_timing, top_min_timing, top_max_timing,
00141        contiguous_cache_blocks_);
00142       ApplyTask& botTask = *new( allocate_child() )
00143         ApplyTask (P_mid+1, P_last_, Q_split.second, C_split.second,
00144        top_blocks_, factor_output_, seq_,
00145        bot_timing, bot_min_timing, bot_max_timing,
00146        contiguous_cache_blocks_);
00147 
00148       set_ref_count (3); // 3 children (2 + 1 for the wait)
00149       spawn (topTask);
00150       spawn_and_wait_for_all (botTask);
00151 
00152       top_min_timing = (top_min_timing == 0.0) ? top_timing : top_min_timing;
00153       top_max_timing = (top_max_timing == 0.0) ? top_timing : top_max_timing;
00154 
00155       bot_min_timing = (bot_min_timing == 0.0) ? bot_timing : bot_min_timing;
00156       bot_max_timing = (bot_max_timing == 0.0) ? bot_timing : bot_max_timing;
00157 
00158       min_seq_timing_ = std::min (top_min_timing, bot_min_timing);
00159       max_seq_timing_ = std::min (top_max_timing, bot_max_timing);
00160 
00161       return NULL;
00162     }
00163       }
00164 
00165     private:
00166       size_t P_first_, P_last_;
00167       const_mat_view Q_;
00168       mat_view C_;
00169       array_top_blocks_t& top_blocks_;
00170       const FactorOutput& factor_output_;
00171       SequentialTsqr<LocalOrdinal, Scalar> seq_;
00172       TSQR::ApplyType apply_type_;
00173       TSQR::Combine<LocalOrdinal, Scalar> combine_;
00174       Partitioner<LocalOrdinal, Scalar> partitioner_;
00175       double& my_seq_timing_;
00176       double& min_seq_timing_;
00177       double& max_seq_timing_;
00178       bool contiguous_cache_blocks_;
00179 
00180       void 
00181       execute_base_case ()
00182       {
00183   TimerType timer("");
00184   timer.start();
00185   const std::vector<SeqOutput>& seq_outputs = factor_output_.first;
00186   seq_.apply (apply_type_, Q_.nrows(), Q_.ncols(), 
00187         Q_.get(), Q_.lda(), seq_outputs[P_first_], 
00188         C_.ncols(), C_.get(), C_.lda(), 
00189         contiguous_cache_blocks_);
00190   my_seq_timing_ = timer.stop();
00191       }
00192 
00193       void 
00194       apply_pair (const size_t P_top, 
00195       const size_t P_bot) 
00196       {
00197   if (P_top == P_bot) 
00198     throw std::logic_error("apply_pair: should never get here!");
00199 
00200   const_mat_view& Q_bot = top_blocks_[P_bot].first;
00201   mat_view& C_top = top_blocks_[P_top].second;
00202   mat_view& C_bot = top_blocks_[P_bot].second;
00203 
00204   const ParOutput& par_output = factor_output_.second;
00205   const std::vector<Scalar>& tau = par_output[P_bot];
00206   std::vector<Scalar> work (C_top.ncols());
00207   combine_.apply_pair (apply_type_, C_top.ncols(), Q_bot.ncols(), 
00208            Q_bot.get(), Q_bot.lda(), &tau[0],
00209            C_top.get(), C_top.lda(), 
00210            C_bot.get(), C_bot.lda(), &work[0]);
00211       }
00212 
00213     };
00214 
00215   } // namespace TBB
00216 } // namespace TSQR
00217 
00218 
00219 #endif // __TSQR_TBB_ApplyTask_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends