Kokkos Node API and Local Linear Algebra Kernels Version of the Day
TbbTsqr_ApplyTask.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_TBB_ApplyTask_hpp
00043 #define __TSQR_TBB_ApplyTask_hpp
00044 
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

#include <tbb/task.h>
#include <TbbTsqr_Partitioner.hpp>
#include <Tsqr_SequentialTsqr.hpp>
00048 
00051 
00052 namespace TSQR {
00053   namespace TBB {
00054     
00058     template< class LocalOrdinal, class Scalar, class TimerType >
00059     class ApplyTask : public tbb::task {
00060     public:
00061       typedef MatView<LocalOrdinal, Scalar> mat_view;
00062       typedef ConstMatView<LocalOrdinal, Scalar> const_mat_view;
00063       typedef std::pair<mat_view, mat_view> split_t;
00064       typedef std::pair<const_mat_view, const_mat_view> const_split_t;
00065       typedef std::pair<const_mat_view, mat_view> top_blocks_t;
00066       typedef std::vector<top_blocks_t> array_top_blocks_t;
00067 
00070       typedef typename SequentialTsqr<LocalOrdinal, Scalar>::FactorOutput SeqOutput;
00075       typedef std::vector<std::vector<Scalar> > ParOutput;
00079       typedef typename std::pair<std::vector<SeqOutput>, ParOutput> FactorOutput;
00080 
00086       ApplyTask (const size_t P_first__, 
00087      const size_t P_last__,
00088      ConstMatView<LocalOrdinal, Scalar> Q,
00089      MatView<LocalOrdinal, Scalar> C,
00090      array_top_blocks_t& top_blocks, 
00091      const FactorOutput& factor_output,
00092      const SequentialTsqr<LocalOrdinal, Scalar>& seq,
00093      double& my_seq_timing,
00094      double& min_seq_timing,
00095      double& max_seq_timing,
00096      const bool contiguous_cache_blocks) :
00097   P_first_ (P_first__), 
00098   P_last_ (P_last__), 
00099   Q_ (Q), 
00100   C_ (C),
00101   top_blocks_ (top_blocks), 
00102   factor_output_ (factor_output), 
00103   seq_ (seq), 
00104   apply_type_ (ApplyType::NoTranspose), // FIXME: modify to support Q^T and Q^H
00105   my_seq_timing_ (my_seq_timing),
00106   min_seq_timing_ (min_seq_timing),
00107   max_seq_timing_ (max_seq_timing),
00108   contiguous_cache_blocks_ (contiguous_cache_blocks)
00109       {}
00110 
00111       tbb::task* execute () 
00112       {
00113   if (P_first_ > P_last_ || Q_.empty() || C_.empty())
00114     return NULL;
00115   else if (P_first_ == P_last_)
00116     {
00117       execute_base_case ();
00118       return NULL;
00119     }
00120   else
00121     {
00122       // Recurse on two intervals: [P_first, P_mid] and [P_mid+1, P_last]
00123       const size_t P_mid = (P_first_ + P_last_) / 2;
00124       const_split_t Q_split = 
00125         partitioner_.split (Q_, P_first_, P_mid, P_last_,
00126           contiguous_cache_blocks_);
00127       split_t C_split = 
00128         partitioner_.split (C_, P_first_, P_mid, P_last_,
00129           contiguous_cache_blocks_);
00130 
00131       // The partitioner may decide that the current blocks Q_
00132       // and C_ have too few rows to be worth splitting.  In
00133       // that case, Q_split.second and C_split.second (the
00134       // bottom block) will be empty.  We can deal with this by
00135       // treating it as the base case.
00136       if (Q_split.second.empty() || Q_split.second.nrows() == 0)
00137         {
00138     execute_base_case ();
00139     return NULL;
00140         }
00141 
00142       double top_timing;
00143       double top_min_timing = 0.0;
00144       double top_max_timing = 0.0;
00145       double bot_timing;
00146       double bot_min_timing = 0.0;
00147       double bot_max_timing = 0.0;
00148 
00149       apply_pair (P_first_, P_mid+1);
00150       ApplyTask& topTask = *new( allocate_child() )
00151         ApplyTask (P_first_, P_mid, Q_split.first, C_split.first,
00152        top_blocks_, factor_output_, seq_, 
00153        top_timing, top_min_timing, top_max_timing,
00154        contiguous_cache_blocks_);
00155       ApplyTask& botTask = *new( allocate_child() )
00156         ApplyTask (P_mid+1, P_last_, Q_split.second, C_split.second,
00157        top_blocks_, factor_output_, seq_,
00158        bot_timing, bot_min_timing, bot_max_timing,
00159        contiguous_cache_blocks_);
00160 
00161       set_ref_count (3); // 3 children (2 + 1 for the wait)
00162       spawn (topTask);
00163       spawn_and_wait_for_all (botTask);
00164 
00165       top_min_timing = (top_min_timing == 0.0) ? top_timing : top_min_timing;
00166       top_max_timing = (top_max_timing == 0.0) ? top_timing : top_max_timing;
00167 
00168       bot_min_timing = (bot_min_timing == 0.0) ? bot_timing : bot_min_timing;
00169       bot_max_timing = (bot_max_timing == 0.0) ? bot_timing : bot_max_timing;
00170 
00171       min_seq_timing_ = std::min (top_min_timing, bot_min_timing);
00172       max_seq_timing_ = std::min (top_max_timing, bot_max_timing);
00173 
00174       return NULL;
00175     }
00176       }
00177 
00178     private:
00179       size_t P_first_, P_last_;
00180       const_mat_view Q_;
00181       mat_view C_;
00182       array_top_blocks_t& top_blocks_;
00183       const FactorOutput& factor_output_;
00184       SequentialTsqr<LocalOrdinal, Scalar> seq_;
00185       TSQR::ApplyType apply_type_;
00186       TSQR::Combine<LocalOrdinal, Scalar> combine_;
00187       Partitioner<LocalOrdinal, Scalar> partitioner_;
00188       double& my_seq_timing_;
00189       double& min_seq_timing_;
00190       double& max_seq_timing_;
00191       bool contiguous_cache_blocks_;
00192 
00193       void 
00194       execute_base_case ()
00195       {
00196   TimerType timer("");
00197   timer.start();
00198   const std::vector<SeqOutput>& seq_outputs = factor_output_.first;
00199   seq_.apply (apply_type_, Q_.nrows(), Q_.ncols(), 
00200         Q_.get(), Q_.lda(), seq_outputs[P_first_], 
00201         C_.ncols(), C_.get(), C_.lda(), 
00202         contiguous_cache_blocks_);
00203   my_seq_timing_ = timer.stop();
00204       }
00205 
00206       void 
00207       apply_pair (const size_t P_top, 
00208       const size_t P_bot) 
00209       {
00210   if (P_top == P_bot) 
00211     throw std::logic_error("apply_pair: should never get here!");
00212 
00213   const_mat_view& Q_bot = top_blocks_[P_bot].first;
00214   mat_view& C_top = top_blocks_[P_top].second;
00215   mat_view& C_bot = top_blocks_[P_bot].second;
00216 
00217   const ParOutput& par_output = factor_output_.second;
00218   const std::vector<Scalar>& tau = par_output[P_bot];
00219   std::vector<Scalar> work (C_top.ncols());
00220   combine_.apply_pair (apply_type_, C_top.ncols(), Q_bot.ncols(), 
00221            Q_bot.get(), Q_bot.lda(), &tau[0],
00222            C_top.get(), C_top.lda(), 
00223            C_bot.get(), C_bot.lda(), &work[0]);
00224       }
00225 
00226     };
00227 
00228   } // namespace TBB
00229 } // namespace TSQR
00230 
00231 
00232 #endif // __TSQR_TBB_ApplyTask_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends