Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_RMessenger.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_RMessenger_hpp
00043 #define __TSQR_RMessenger_hpp
00044 
00045 #include <Tsqr_MatView.hpp>
00046 #include <Tsqr_MessengerBase.hpp>
00047 #include <Teuchos_RCP.hpp>
00048 
00049 #include <algorithm>
00050 #include <vector>
00051 
00052 
00053 namespace TSQR {
00054 
00061   template<class Ordinal, class Scalar>
00062   class RMessenger {
00063   public:
00064     typedef Scalar scalar_type;
00065     typedef Ordinal ordinal_type;
00066     typedef MessengerBase< Scalar > messenger_type;
00067     typedef Teuchos::RCP< messenger_type > messenger_ptr;
00068 
00072     RMessenger (const messenger_ptr& messenger) :
00073       messenger_ (messenger) {}
00074 
00075     template<class ConstMatrixViewType>
00076     void
00077     send (const ConstMatrixViewType& R, const int destProc)
00078     {
00079       pack (R);
00080       messenger_->send (&buffer_[0], buffer_.size(), destProc, 0);
00081     }
00082 
00083     template<class MatrixViewType>
00084     void
00085     recv (MatrixViewType& R, const int srcProc)
00086     {
00087       const typename MatrixViewType::ordinal_type ncols = R.ncols();
00088       const Ordinal buflen = buffer_length (ncols);
00089       buffer_.resize (buflen);
00090       messenger_->recv (&buffer_[0], buflen, srcProc, 0);
00091       unpack (R);
00092     }
00093 
00094     template<class MatrixViewType>
00095     void
00096     broadcast (MatrixViewType& R, const int rootProc)
00097     {
00098       const int myRank = messenger_->rank();
00099       if (myRank == rootProc)
00100   pack (R);
00101       messenger_->broadcast (&buffer_[0], buffer_length (R.ncols()), rootProc);
00102       if (myRank != rootProc)
00103   unpack (R);
00104     }
00105 
00107     RMessenger (const RMessenger& rhs) :
00108       messenger_ (rhs.messenger_), 
00109       buffer_ (0) // don't need to copy the buffer
00110     {}
00111 
00113     RMessenger& operator= (const RMessenger& rhs) {
00114       if (this != &rhs)
00115   {
00116     this->messenger_ = rhs.messenger_;
00117     // Don't need to do anything to this->buffer_; the various
00118     // operations such as pack() will resize it as necessary.
00119   }
00120       return *this;
00121     }
00122 
00123 
00124   private:
00125     messenger_ptr messenger_;
00126     std::vector< Scalar > buffer_;
00127 
00129     RMessenger ();
00130 
00135     Ordinal buffer_length (const Ordinal ncols) const {
00136       return (ncols * (ncols + Ordinal(1))) / Ordinal(2);
00137     }
00138 
00139     template<class ConstMatrixViewType>
00140     void
00141     pack (const ConstMatrixViewType& R)
00142     {
00143       typedef typename ConstMatrixViewType::scalar_type view_scalar_type;
00144       typedef typename ConstMatrixViewType::ordinal_type view_ordinal_type;
00145       typedef typename std::vector< Scalar >::iterator iter_type;
00146 
00147       const view_ordinal_type ncols = R.ncols();
00148       const Ordinal buf_length = buffer_length (ncols);
00149       buffer_.resize (buf_length);
00150       iter_type iter = buffer_.begin();
00151       for (view_ordinal_type j = 0; j < ncols; ++j)
00152   {
00153     const view_scalar_type* const R_j = &R(0,j);
00154     std::copy (R_j, R_j + (j+1), iter);
00155     iter += (j+1);
00156   }
00157     }
00158 
00159     template<class MatrixViewType>
00160     void
00161     unpack (MatrixViewType& R)
00162     {
00163       typedef typename MatrixViewType::ordinal_type view_ordinal_type;
00164       typedef typename std::vector< Scalar >::const_iterator const_iter_type;
00165 
00166       const view_ordinal_type ncols = R.ncols();
00167       const_iter_type iter = buffer_.begin();
00168       for (view_ordinal_type j = 0; j < ncols; ++j)
00169   {
00170     std::copy (iter, iter + (j+1), &R(0,j));
00171     iter += (j+1);
00172   }
00173     }
00174   };
00175 
00176 
00188   template<class MatrixViewType, class ConstMatrixViewType>
00189   void
00190   scatterStack (const ConstMatrixViewType& R_stack, 
00191     MatrixViewType& R_local,
00192     const Teuchos::RCP<MessengerBase<typename MatrixViewType::scalar_type> >& messenger)
00193   {
00194     typedef typename MatrixViewType::ordinal_type ordinal_type;
00195     typedef typename MatrixViewType::scalar_type scalar_type;
00196     typedef ConstMatView< ordinal_type, scalar_type > const_view_type;
00197 
00198     const int nprocs = messenger->size();
00199     const int my_rank = messenger->rank();
00200 
00201     if (my_rank == 0)
00202       {
00203   const ordinal_type ncols = R_stack.ncols();
00204 
00205   // Copy data from top ncols x ncols block of R_stack into R_local.
00206   const_view_type R_stack_view_first (ncols, ncols, R_stack.get(), R_stack.lda());
00207   R_local.copy (R_stack_view_first);
00208 
00209   // Loop through all other processors, sending each the next
00210   // ncols x ncols block of R_stack.
00211   RMessenger< ordinal_type, scalar_type > sender (messenger);
00212   for (int destProc = 1; destProc < nprocs; ++destProc)
00213     {
00214       const scalar_type* const R_ptr = R_stack.get() + destProc*ncols;
00215       const_view_type R_stack_view_cur (ncols, ncols, R_ptr, R_stack.lda());
00216       sender.send (R_stack_view_cur, destProc);
00217     }
00218       }
00219     else
00220       {
00221   const int srcProc = 0;
00222   R_local.fill (scalar_type(0));
00223   RMessenger< ordinal_type, scalar_type > receiver (messenger);
00224   receiver.recv (R_local, srcProc);
00225       }
00226   }
00227 
00228 
00229 
00230 
00231   template<class MatrixViewType, class ConstMatrixViewType>
00232   void
00233   gatherStack (MatrixViewType& R_stack, 
00234          ConstMatrixViewType& R_local,
00235          const Teuchos::RCP<MessengerBase<typename MatrixViewType::scalar_type> >& messenger)
00236   {
00237     typedef typename MatrixViewType::ordinal_type ordinal_type;
00238     typedef typename MatrixViewType::scalar_type scalar_type;
00239     typedef MatView<ordinal_type, scalar_type> matrix_view_type;
00240 
00241     const int nprocs = messenger->size();
00242     const int my_rank = messenger->rank();
00243 
00244     if (my_rank == 0)
00245       {
00246   const ordinal_type ncols = R_stack.ncols();
00247 
00248   // Copy data from R_local into top ncols x ncols block of R_stack.
00249   matrix_view_type R_stack_view_first (ncols, ncols, R_stack.get(), R_stack.lda());
00250   R_stack_view_first.copy (R_local);
00251 
00252   // Loop through all other processors, fetching their matrix data.
00253   RMessenger< ordinal_type, scalar_type > receiver (messenger);
00254   for (int srcProc = 1; srcProc < nprocs; ++srcProc)
00255     {
00256       const scalar_type* const R_ptr = R_stack.get() + srcProc*ncols;
00257       matrix_view_type R_stack_view_cur (ncols, ncols, R_ptr, R_stack.lda());
00258       // Fill (the lower triangle) with zeros, since
00259       // RMessenger::recv() only writes to the upper triangle.
00260       R_stack_view_cur.fill (scalar_type (0));
00261       receiver.recv (R_stack_view_cur, srcProc);
00262     }
00263       }
00264     else
00265       {
00266   // We only read R_stack on Proc 0, not on this proc.
00267   // Send data from R_local to Proc 0.
00268   const int destProc = 0;
00269   RMessenger< ordinal_type, scalar_type > sender (messenger);
00270   sender.send (R_local, destProc);
00271       }
00272     messenger->barrier();
00273   }
00274 
00275 } // namespace TSQR
00276 
00277 #endif // __TSQR_RMessenger_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends