Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_MatView.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef __TSQR_Tsqr_MatView_hpp
00043 #define __TSQR_Tsqr_MatView_hpp
00044 
00045 #include <cstring> // NULL
00046 
00047 // Define for bounds checking and other safety features, undefine for speed.
00048 // #define TSQR_MATVIEW_DEBUG 1
00049 
00050 #ifdef TSQR_MATVIEW_DEBUG
00051 #  include <limits>
00052 #endif // TSQR_MATVIEW_DEBUG
00053 
00054 #include <sstream>
00055 #include <stdexcept>
00056 
00059 
00060 namespace TSQR {
00061 
/// \brief Copy the entries of the matrix B into the matrix A.
///
/// Both arguments must provide nested \c ordinal_type and \c scalar_type
/// typedefs, nrows(), ncols(), and an operator()(i,j) returning a reference
/// to entry (i,j). Each column is assumed to be contiguous in memory (column
/// j of A is written through &A(0,j) as a plain array of nrows() entries).
///
/// \param A [out] Destination matrix; must have the same dimensions as B.
/// \param B [in]  Source matrix.
///
/// \throw std::invalid_argument if the dimensions of A and B differ.
template< class MatrixViewType1, class MatrixViewType2 >
void
matrixCopy (MatrixViewType1& A, const MatrixViewType2& B)
{
  typedef typename MatrixViewType1::ordinal_type ordinal_type;
  typedef typename MatrixViewType1::scalar_type output_scalar_type;
  typedef typename MatrixViewType2::scalar_type input_scalar_type;

  const ordinal_type A_nrows = A.nrows();
  const ordinal_type A_ncols = A.ncols();
  if (A_nrows != B.nrows() || A_ncols != B.ncols())
    {
      using std::endl;
      std::ostringstream os;
      os << "matrixCopy: dimensions of A (output matrix) "
        "and B (input matrix) are not compatible."
         << endl
         << "A is " << A.nrows() << " x " << A.ncols()
         << ", and B is " << B.nrows() << " x " << B.ncols() << ".";
      throw std::invalid_argument(os.str());
    }

  // Nothing to copy.  Returning here also avoids evaluating A(0,j) and
  // B(0,j) on a view with zero rows: that is an out-of-range access (it
  // throws when TSQR_MATVIEW_DEBUG is defined) and may offset a NULL pointer.
  if (A_nrows == 0 || A_ncols == 0)
    return;

  for (ordinal_type j = 0; j < A_ncols; ++j)
    {
      output_scalar_type* const A_j = &A(0,j);
      const input_scalar_type* const B_j = &B(0,j);
      for (ordinal_type i = 0; i < A_nrows; ++i)
        A_j[i] = B_j[i];
    }
}
00087 
/// \brief Test whether two matrices have the same dimensions and entries.
///
/// Both arguments must provide nested \c ordinal_type and \c pointer_type
/// typedefs and nrows(), ncols(), lda(), get().  Entries are compared with
/// operator!= in column-major order; the leading dimensions of A and B may
/// differ.
///
/// \return true if A and B have equal dimensions and every entry of A equals
///   the corresponding entry of B; false otherwise.
template< class FirstMatrixViewType, class SecondMatrixViewType >
bool
matrix_equal (FirstMatrixViewType& A, SecondMatrixViewType& B)
{
  if (A.nrows() != B.nrows() || A.ncols() != B.ncols())
    return false;

  typedef typename FirstMatrixViewType::ordinal_type first_ordinal_type;
  typedef typename SecondMatrixViewType::ordinal_type second_ordinal_type;
  typedef typename FirstMatrixViewType::pointer_type first_pointer_type;
  typedef typename SecondMatrixViewType::pointer_type second_pointer_type;

  const first_ordinal_type nrows = A.nrows();
  const first_ordinal_type ncols = A.ncols();

  // Matching empty matrices are equal.  Returning before touching the data
  // avoids advancing a possibly NULL pointer by lda (undefined behavior).
  if (nrows == 0 || ncols == 0)
    return true;

  const first_ordinal_type A_lda = A.lda();
  const second_ordinal_type B_lda = B.lda();

  first_pointer_type A_j = A.get();
  second_pointer_type B_j = B.get();

  for (first_ordinal_type j = 0; j < ncols; ++j, A_j += A_lda, B_j += B_lda)
    for (first_ordinal_type i = 0; i < nrows; ++i)
      if (A_j[i] != B_j[i])
        return false;

  return true;
}
00115 
00116 #ifdef TSQR_MATVIEW_DEBUG
template< class Ordinal, class Scalar >
class MatViewVerify {
public:
  /// \brief Validate the arguments of a matrix view constructor.
  ///
  /// Appends one "Error: ..." line for every problem found, in this order:
  /// Ordinal is not an integer type; (signed Ordinal only) num_rows,
  /// num_cols, or leading_dim is negative; leading_dim < num_rows.
  ///
  /// \param A  Data pointer of the view; accepted for symmetry with the view
  ///   constructors but not examined.
  ///
  /// \throw std::invalid_argument whose message is the concatenation of all
  ///   error lines, if at least one problem was found.
  static void
  verify (const Ordinal num_rows,
          const Ordinal num_cols,
          const Scalar* const A,
          const Ordinal leading_dim)
  {
    typedef std::numeric_limits<Ordinal> limits;
    std::ostringstream problems;

    if (! limits::is_integer)
      problems << "Error: Ordinal type must be an integer." << std::endl;

    // Negative values can only occur (and are only tested) for signed Ordinal.
    if (limits::is_signed)
      {
        if (num_rows < 0)
          problems << "Error: num_rows (= " << num_rows << ") < 0."
                   << std::endl;
        if (num_cols < 0)
          problems << "Error: num_cols (= " << num_cols << ") < 0."
                   << std::endl;
        if (leading_dim < 0)
          problems << "Error: leading_dim (= " << leading_dim << ") < 0."
                   << std::endl;
      }

    if (leading_dim < num_rows)
      problems << "Error: leading_dim (= " << leading_dim
               << ") < num_rows (= " << num_rows << ")." << std::endl;

    // Each failed check above writes a non-empty line, so an empty report
    // means every check passed.
    const std::string report = problems.str();
    if (! report.empty())
      throw std::invalid_argument (report);
  }
};
00162 #endif // TSQR_MATVIEW_DEBUG
00163 
00164 
00165   // Forward declaration
00166   template< class Ordinal, class Scalar >
00167   class ConstMatView;
00168 
00169   // Forward declaration
00170   template< class Ordinal, class Scalar >
00171   class Matrix;
00172 
00176   template< class Ordinal, class Scalar >
00177   class MatView {
00178   public:
00179     typedef Scalar scalar_type;
00180     typedef Ordinal ordinal_type;
00181     typedef Scalar* pointer_type;
00182 
00185     MatView () : nrows_(0), ncols_(0), lda_(0), A_(NULL) {}
00186 
00187     MatView (const Ordinal num_rows, 
00188        const Ordinal num_cols, 
00189        Scalar* const A, 
00190        const Ordinal leading_dim) :
00191       nrows_(num_rows),
00192       ncols_(num_cols),
00193       lda_(leading_dim),
00194       A_(A)
00195     {
00196 #ifdef TSQR_MATVIEW_DEBUG
00197       MatViewVerify< Ordinal, Scalar >::verify (num_rows, num_cols, A, leading_dim);
00198 #endif // TSQR_MATVIEW_DEBUG
00199     }
00200 
00201     MatView (const MatView& view) :
00202       nrows_(view.nrows()),
00203       ncols_(view.ncols()),
00204       lda_(view.lda()),
00205       A_(view.get())
00206     {}
00207 
00208     MatView& operator= (const MatView& view) {
00209       if (this != &view)
00210   {
00211     nrows_ = view.nrows();
00212     ncols_ = view.ncols();
00213     A_ = view.get();
00214     lda_ = view.lda();
00215   }
00216       return *this;
00217     }
00218 
00224     Scalar& operator() (const Ordinal i, const Ordinal j) const 
00225     {
00226 #ifdef TSQR_MATVIEW_DEBUG
00227       if (std::numeric_limits< Ordinal >::is_signed)
00228   {
00229     if (i < 0 || i >= nrows())
00230       throw std::invalid_argument("Row range invalid");
00231     else if (j < 0 || j >= ncols())
00232       throw std::invalid_argument("Column range invalid");
00233   }
00234       else
00235   {
00236     if (i >= nrows())
00237       throw std::invalid_argument("Row range invalid");
00238     else if (j >= ncols())
00239       throw std::invalid_argument("Column range invalid");
00240   }
00241       if (A_ == NULL)
00242   throw std::logic_error("Attempt to reference NULL data");
00243 #endif // TSQR_MATVIEW_DEBUG
00244       return A_[i + j*lda()];
00245     }
00246 
00247     Ordinal nrows() const { return nrows_; }
00248     Ordinal ncols() const { return ncols_; }
00249     Ordinal lda() const { return lda_; }
00250 
00255     pointer_type get() const { return A_; }
00256     bool empty() const { return nrows() == 0 || ncols() == 0; }
00257 
00260     MatView row_block (const Ordinal firstRow, const Ordinal lastRow) 
00261     {
00262 #ifdef TSQR_MATVIEW_DEBUG
00263       if (std::numeric_limits< Ordinal >::is_signed)
00264   {
00265     if (firstRow < 0 || firstRow > lastRow || lastRow >= nrows())
00266       throw std::invalid_argument ("Row range invalid");
00267   }
00268       else
00269   {
00270     if (firstRow > lastRow || lastRow >= nrows())
00271       throw std::invalid_argument ("Row range invalid");
00272   }
00273 #endif // TSQR_MATVIEW_DEBUG
00274       return MatView (lastRow - firstRow + 1, ncols(), get() + firstRow, lda());
00275     }
00276 
00292     MatView split_top (const Ordinal nrows_top, 
00293            const bool b_contiguous_blocks = false)
00294     {
00295 #ifdef TSQR_MATVIEW_DEBUG
00296       if (std::numeric_limits< Ordinal >::is_signed && nrows_top < 0)
00297   {
00298     std::ostringstream os;
00299     os << "nrows_top (= " << nrows_top << ") < 0";
00300     throw std::invalid_argument (os.str());
00301   }
00302       else if (nrows_top > nrows())
00303   {
00304     std::ostringstream os;
00305     os << "nrows_top (= " << nrows_top << ") > nrows (= " << nrows() << ")";
00306     throw std::invalid_argument (os.str());
00307   }
00308 #endif // TSQR_MATVIEW_DEBUG
00309 
00310       Scalar* const A_top_ptr = get();
00311       Scalar* A_rest_ptr;
00312       const Ordinal nrows_rest = nrows() - nrows_top;
00313       Ordinal lda_top, lda_rest;
00314       if (b_contiguous_blocks)
00315   {
00316     lda_top = nrows_top;
00317     lda_rest = nrows_rest;
00318     A_rest_ptr = A_top_ptr + nrows_top * ncols();
00319   }
00320       else
00321   {
00322     lda_top = lda();
00323     lda_rest = lda();
00324     A_rest_ptr = A_top_ptr + nrows_top;
00325   }
00326       MatView A_top (nrows_top, ncols(), get(), lda_top);
00327       A_ = A_rest_ptr;
00328       nrows_ = nrows_rest;
00329       lda_ = lda_rest;
00330 
00331       return A_top;
00332     }
00333 
00336     MatView split_bottom (const Ordinal nrows_bottom, 
00337         const bool b_contiguous_blocks = false)
00338     {
00339 #ifdef TSQR_MATVIEW_DEBUG
00340       if (std::numeric_limits< Ordinal >::is_signed && nrows_bottom < 0)
00341   throw std::invalid_argument ("nrows_bottom < 0");
00342       if (nrows_bottom > nrows())
00343   throw std::invalid_argument ("nrows_bottom > nrows");
00344 #endif // TSQR_MATVIEW_DEBUG
00345 
00346       Scalar* const A_rest_ptr = get();
00347       Scalar* A_bottom_ptr;
00348       const Ordinal nrows_rest = nrows() - nrows_bottom;
00349       Ordinal lda_bottom, lda_rest;
00350       if (b_contiguous_blocks)
00351   {
00352     lda_bottom = nrows_bottom;
00353     lda_rest = nrows() - nrows_bottom;
00354     A_bottom_ptr = A_rest_ptr + nrows_rest * ncols();
00355   }
00356       else
00357   {
00358     lda_bottom = lda();
00359     lda_rest = lda();
00360     A_bottom_ptr = A_rest_ptr + nrows_rest;
00361   }
00362       MatView A_bottom (nrows_bottom, ncols(), A_bottom_ptr, lda_bottom);
00363       A_ = A_rest_ptr;
00364       nrows_ = nrows_rest;
00365       lda_ = lda_rest;
00366 
00367       return A_bottom;
00368     }
00369 
00370     void
00371     fill (const scalar_type& value) 
00372     {
00373       const ordinal_type num_rows = nrows();
00374       const ordinal_type num_cols = ncols();
00375       const ordinal_type stride = lda();
00376 
00377       scalar_type* A_j = get();
00378       for (ordinal_type j = 0; j < num_cols; ++j, A_j += stride)
00379   for (ordinal_type i = 0; i < num_rows; ++i)
00380     A_j[i] = value;
00381     }
00382 
00386     void
00387     copy (const MatView< ordinal_type, scalar_type >& B) {
00388       matrixCopy (*this, B);
00389     }
00390     void
00391     copy (const ConstMatView< ordinal_type, scalar_type >& B) {
00392       matrixCopy (*this, B);
00393     }
00394     void
00395     copy (const Matrix< ordinal_type, scalar_type >& B) {
00396       matrixCopy (*this, B);
00397     }
00398 
00399     bool operator== (const MatView& rhs) const {
00400       return nrows() == rhs.nrows() && ncols() == rhs.ncols() && 
00401   lda() == rhs.lda() && get() == rhs.get();
00402     }
00403 
00404     bool operator!= (const MatView& rhs) const {
00405       return nrows() != rhs.nrows() || ncols() != rhs.ncols() || 
00406   lda() != rhs.lda() || get() != rhs.get();
00407     }
00408 
00409   private:
00410     ordinal_type nrows_, ncols_, lda_;
00411     scalar_type* A_;
00412   };
00413 
00414 
00422   template< class Ordinal, class Scalar >
00423   class ConstMatView {
00424   public:
00425     typedef Scalar scalar_type;
00426     typedef Ordinal ordinal_type;
00427     typedef const Scalar* pointer_type;
00428 
00429     ConstMatView () : nrows_(0), ncols_(0), lda_(0), A_(NULL) {}
00430 
00433     ConstMatView (const Ordinal num_rows, 
00434       const Ordinal num_cols, 
00435       const Scalar* const A, 
00436       const Ordinal leading_dim) :
00437       nrows_(num_rows),
00438       ncols_(num_cols),
00439       lda_(leading_dim),
00440       A_(A)
00441     {
00442 #ifdef TSQR_MATVIEW_DEBUG
00443       MatViewVerify< Ordinal, Scalar >::verify (num_rows, num_cols, A, leading_dim);
00444 #endif // TSQR_MATVIEW_DEBUG
00445     }
00446 
00447     ConstMatView (const ConstMatView& view) :
00448       nrows_(view.nrows()),
00449       ncols_(view.ncols()),
00450       lda_(view.lda()),
00451       A_(view.get())
00452     {}
00453 
00454     ConstMatView& operator= (const ConstMatView& view) {
00455       if (this != &view)
00456   {
00457     nrows_ = view.nrows();
00458     ncols_ = view.ncols();
00459     lda_ = view.lda();
00460     A_ = view.get();
00461   }
00462       return *this;
00463     }
00464 
00465     const Scalar& operator() (const Ordinal i, const Ordinal j) const 
00466     {
00467 #ifdef TSQR_MATVIEW_DEBUG
00468       if (std::numeric_limits< Ordinal >::is_signed)
00469   {
00470     if (i < 0 || i >= nrows())
00471       throw std::invalid_argument("Row range invalid");
00472     else if (j < 0 || j >= ncols())
00473       throw std::invalid_argument("Column range invalid");
00474   }
00475       else
00476   {
00477     if (i >= nrows())
00478       throw std::invalid_argument("Row range invalid");
00479     else if (j >= ncols())
00480       throw std::invalid_argument("Column range invalid");
00481   }
00482       if (A_ == NULL)
00483   throw std::logic_error("Attempt to reference NULL data");
00484 #endif // TSQR_MATVIEW_DEBUG
00485       return A_[i + j*lda()];
00486     }
00487 
00488     Ordinal nrows() const { return nrows_; }
00489     Ordinal ncols() const { return ncols_; }
00490     Ordinal lda() const { return lda_; }
00491     pointer_type get() const { return A_; }
00492     bool empty() const { return nrows() == 0 || ncols() == 0; }
00493 
00496     ConstMatView rowBlock (const Ordinal firstRow, 
00497          const Ordinal lastRow) const
00498     {
00499 #ifdef TSQR_MATVIEW_DEBUG
00500       if (firstRow < 0 || lastRow >= nrows())
00501   throw std::invalid_argument ("Row range invalid");
00502 #endif // TSQR_MATVIEW_DEBUG
00503       return ConstMatView (lastRow - firstRow + 1, ncols(), get() + firstRow, lda());
00504     }
00505 
00506 
00522     ConstMatView split_top (const Ordinal nrows_top, 
00523           const bool b_contiguous_blocks = false)
00524     {
00525 #ifdef TSQR_MATVIEW_DEBUG
00526       if (std::numeric_limits< Ordinal >::is_signed && nrows_top < 0)
00527   throw std::invalid_argument ("nrows_top < 0");
00528       if (nrows_top > nrows())
00529   throw std::invalid_argument ("nrows_top > nrows");
00530 #endif // TSQR_MATVIEW_DEBUG
00531 
00532       pointer_type const A_top_ptr = get();
00533       pointer_type A_rest_ptr;
00534       const Ordinal nrows_rest = nrows() - nrows_top;
00535       Ordinal lda_top, lda_rest;
00536       if (b_contiguous_blocks)
00537   {
00538     lda_top = nrows_top;
00539     lda_rest = nrows_rest;
00540     A_rest_ptr = A_top_ptr + nrows_top * ncols();
00541   }
00542       else
00543   {
00544     lda_top = lda();
00545     lda_rest = lda();
00546     A_rest_ptr = A_top_ptr + nrows_top;
00547   }
00548       ConstMatView A_top (nrows_top, ncols(), get(), lda_top);
00549       A_ = A_rest_ptr;
00550       nrows_ = nrows_rest;
00551       lda_ = lda_rest;
00552 
00553       return A_top;
00554     }
00555 
00556 
00559     ConstMatView split_bottom (const Ordinal nrows_bottom, 
00560              const bool b_contiguous_blocks = false)
00561     {
00562 #ifdef TSQR_MATVIEW_DEBUG
00563       if (std::numeric_limits< Ordinal >::is_signed && nrows_bottom < 0)
00564   throw std::invalid_argument ("nrows_bottom < 0");
00565       if (nrows_bottom > nrows())
00566   throw std::invalid_argument ("nrows_bottom > nrows");
00567 #endif // TSQR_MATVIEW_DEBUG
00568 
00569       pointer_type const A_rest_ptr = get();
00570       pointer_type A_bottom_ptr;
00571       const ordinal_type nrows_rest = nrows() - nrows_bottom;
00572       ordinal_type lda_bottom, lda_rest;
00573       if (b_contiguous_blocks)
00574   {
00575     lda_bottom = nrows_bottom;
00576     lda_rest = nrows() - nrows_bottom;
00577     A_bottom_ptr = A_rest_ptr + nrows_rest * ncols();
00578   }
00579       else
00580   {
00581     lda_bottom = lda();
00582     lda_rest = lda();
00583     A_bottom_ptr = A_rest_ptr + nrows_rest;
00584   }
00585       ConstMatView A_bottom (nrows_bottom, ncols(), A_bottom_ptr, lda_bottom);
00586       A_ = A_rest_ptr;
00587       nrows_ = nrows_rest;
00588       lda_ = lda_rest;
00589 
00590       return A_bottom;
00591     }
00592 
00593     bool operator== (const ConstMatView& rhs) const {
00594       return nrows() == rhs.nrows() && ncols() == rhs.ncols() && 
00595   lda() == rhs.lda() && get() == rhs.get();
00596     }
00597 
00598     bool operator!= (const ConstMatView& rhs) const {
00599       return nrows() != rhs.nrows() || ncols() != rhs.ncols() || 
00600   lda() != rhs.lda() || get() != rhs.get();
00601     }
00602 
00603 
00604   private:
00605     ordinal_type nrows_, ncols_, lda_;
00606     pointer_type A_;
00607   };
00608 
00609 } // namespace TSQR
00610 
00611 
00612 #endif // __TSQR_Tsqr_MatView_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends