Kokkos Node API and Local Linear Algebra Kernels Version of the Day
TbbTsqr_FactorTask.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_TBB_FactorTask_hpp
00043 #define __TSQR_TBB_FactorTask_hpp
00044 
00045 #include <tbb/task.h>
00046 #include <TbbTsqr_Partitioner.hpp>
00047 #include <Tsqr_SequentialTsqr.hpp>
00048 #include <Teuchos_Assert.hpp>
00049 #include <algorithm>
00050 
00053 
00054 namespace TSQR {
00055   namespace TBB {
00056     
00060     template<class LocalOrdinal, class Scalar, class TimerType>
00061     class FactorTask : public tbb::task {
00062     public:
00063       typedef MatView<LocalOrdinal, Scalar> mat_view;
00064       typedef ConstMatView<LocalOrdinal, Scalar> const_mat_view;
00065       typedef std::pair<mat_view, mat_view> split_t;
00066       typedef std::pair<const_mat_view, const_mat_view> const_split_t;
00067 
00070       typedef typename SequentialTsqr<LocalOrdinal, Scalar>::FactorOutput SeqOutput;
00075       typedef std::vector<std::vector<Scalar> > ParOutput;
00079       typedef typename std::pair<std::vector<SeqOutput>, ParOutput> FactorOutput;
00080 
00086       FactorTask (const size_t P_first__, 
00087       const size_t P_last__,
00088       mat_view A,
00089       mat_view* const A_top_ptr,
00090       std::vector<SeqOutput>& seq_outputs,
00091       ParOutput& par_output,
00092       const SequentialTsqr<LocalOrdinal, Scalar>& seq,
00093       double& my_seq_timing,
00094       double& min_seq_timing,
00095       double& max_seq_timing,
00096       const bool contiguous_cache_blocks) :
00097   P_first_ (P_first__),
00098   P_last_ (P_last__),
00099   A_ (A),
00100   A_top_ptr_ (A_top_ptr),
00101   seq_outputs_ (seq_outputs),
00102   par_output_ (par_output),
00103   seq_ (seq),
00104   contiguous_cache_blocks_ (contiguous_cache_blocks),
00105   my_seq_timing_ (my_seq_timing),
00106   min_seq_timing_ (min_seq_timing),
00107   max_seq_timing_ (max_seq_timing)
00108       {}
00109 
00110       tbb::task* execute () 
00111       {
00112   if (P_first_ > P_last_ || A_.empty())
00113     return NULL;
00114   else if (P_first_ == P_last_)
00115     {
00116       execute_base_case ();
00117       return NULL;
00118     }
00119   else
00120     {
00121       // Recurse on two intervals: [P_first, P_mid] and [P_mid+1, P_last]
00122       const size_t P_mid = (P_first_ + P_last_) / 2;
00123       split_t A_split = 
00124         partitioner_.split (A_, P_first_, P_mid, P_last_,
00125           contiguous_cache_blocks_);
00126       // The partitioner may decide that the current block A_
00127       // has too few rows to be worth splitting.  In that case,
00128       // A_split.second (the bottom block) will be empty.  We
00129       // can deal with this by treating it as the base case.
00130       if (A_split.second.empty() || A_split.second.nrows() == 0)
00131         {
00132     execute_base_case ();
00133     return NULL;
00134         }
00135 
00136       double top_timing;
00137       double top_min_timing = 0.0;
00138       double top_max_timing = 0.0;
00139       double bot_timing;
00140       double bot_min_timing = 0.0;
00141       double bot_max_timing = 0.0;
00142 
00143       FactorTask& topTask = *new( allocate_child() )
00144         FactorTask (P_first_, P_mid, A_split.first, A_top_ptr_, 
00145         seq_outputs_, par_output_, seq_,
00146         top_timing, top_min_timing, top_max_timing,
00147         contiguous_cache_blocks_);
00148       // After the task finishes, A_bot will be set to the topmost
00149       // partition of A_split.second.  This will let us combine
00150       // the two subproblems (using factor_pair()) after their
00151       // tasks complete.
00152       mat_view A_bot;
00153       FactorTask& botTask = *new( allocate_child() )
00154         FactorTask (P_mid+1, P_last_, A_split.second, &A_bot, 
00155         seq_outputs_, par_output_, seq_,
00156         bot_timing, bot_min_timing, bot_max_timing,
00157         contiguous_cache_blocks_);
00158       set_ref_count (3); // 3 children (2 + 1 for the wait)
00159       spawn (topTask);
00160       spawn_and_wait_for_all (botTask);
00161       
00162       // Combine the two results
00163       factor_pair (P_first_, P_mid+1, *A_top_ptr_, A_bot);
00164 
00165       top_min_timing = (top_min_timing == 0.0) ? top_timing : top_min_timing;
00166       top_max_timing = (top_max_timing == 0.0) ? top_timing : top_max_timing;
00167 
00168       bot_min_timing = (bot_min_timing == 0.0) ? bot_timing : bot_min_timing;
00169       bot_max_timing = (bot_max_timing == 0.0) ? bot_timing : bot_max_timing;
00170 
00171       min_seq_timing_ = std::min (top_min_timing, bot_min_timing);
00172       max_seq_timing_ = std::min (top_max_timing, bot_max_timing);
00173 
00174       return NULL;
00175     }
00176       }
00177 
00178     private:
00179       const size_t P_first_, P_last_;
00180       mat_view A_;
00181       mat_view* const A_top_ptr_;
00182       std::vector<SeqOutput>& seq_outputs_;
00183       ParOutput& par_output_;
00184       SequentialTsqr<LocalOrdinal, Scalar> seq_;
00185       TSQR::Combine<LocalOrdinal, Scalar> combine_;
00186       Partitioner<LocalOrdinal, Scalar> partitioner_;
00187       const bool contiguous_cache_blocks_;
00188       double& my_seq_timing_;
00189       double& min_seq_timing_;
00190       double& max_seq_timing_;
00191 
00192       void 
00193       factor_pair (const size_t P_top,
00194        const size_t P_bot,
00195        mat_view& A_top, // different than A_top_
00196        mat_view& A_bot)
00197       {
00198   const char thePrefix[] = "TSQR::TBB::Factor::factor_pair: ";
00199   TEUCHOS_TEST_FOR_EXCEPTION(P_top == P_bot, std::logic_error,
00200          thePrefix << "Should never get here! P_top == P_bot (= " 
00201          << P_top << "), that is, the indices of the thread "
00202          "partitions are the same.");
00203   // We only read and write the upper ncols x ncols triangle of
00204   // each block.
00205   TEUCHOS_TEST_FOR_EXCEPTION(A_top.ncols() != A_bot.ncols(), std::logic_error,
00206          thePrefix << "The top cache block A_top is " 
00207          << A_top.nrows() << " x " << A_top.ncols() 
00208          << ", and the bottom cache block A_bot is "
00209          << A_bot.nrows() << " x " << A_bot.ncols() 
00210          << "; this means we can't factor [A_top; A_bot].");
00211   const LocalOrdinal ncols = A_top.ncols();
00212   std::vector<Scalar>& tau = par_output_[P_bot];
00213   std::vector<Scalar> work (ncols);
00214   combine_.factor_pair (ncols, A_top.get(), A_top.lda(),
00215             A_bot.get(), A_bot.lda(), &tau[0], &work[0]);
00216       }
00217 
00218       void
00219       execute_base_case () 
00220       {
00221   TimerType timer("");
00222   timer.start();
00223   seq_outputs_[P_first_] = 
00224     seq_.factor (A_.nrows(), A_.ncols(), A_.get(), 
00225            A_.lda(), contiguous_cache_blocks_);
00226   // Assign the topmost cache block of the current partition to
00227   // *A_top_ptr_.  Every base case invocation does this, so that
00228   // we can combine subproblems.  The root task also does this,
00229   // but for a different reason: so that we can extract the R
00230   // factor, once we're done with the factorization.
00231   *A_top_ptr_ = seq_.top_block (A_, contiguous_cache_blocks_);
00232   my_seq_timing_ = timer.stop();
00233       }
00234     };
00235   } // namespace TBB
00236 } // namespace TSQR
00237 
00238 #endif // __TSQR_TBB_FactorTask_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends