Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Tsqr_MatView.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00029 #ifndef __TSQR_Tsqr_MatView_hpp
00030 #define __TSQR_Tsqr_MatView_hpp
00031 
00032 #include <cstring> // NULL
00033 
00034 // Define for bounds checking and other safety features, undefine for speed.
00035 // #define TSQR_MATVIEW_DEBUG 1
00036 
00037 #ifdef TSQR_MATVIEW_DEBUG
00038 #  include <limits>
00039 #endif // TSQR_MATVIEW_DEBUG
00040 
00041 #include <sstream>
00042 #include <stdexcept>
00043 
00046 
00047 namespace TSQR {
00048 
/// \brief Copy the entries of B into A.
///
/// For each column j, the entries of column j are reached through the
/// address of entry (0,j) and then indexed as a contiguous run of nrows()
/// entries.  The view types must therefore store each column contiguously,
/// as MatView and ConstMatView do: entry (i,j) is at get()[i + j*lda()].
///
/// \param A [out] Destination view; must have the same dimensions as B.
/// \param B [in]  Source view.
///
/// \throws std::invalid_argument if A and B have different dimensions.
template< class MatrixViewType1, class MatrixViewType2 >
void
matrixCopy (MatrixViewType1& A, const MatrixViewType2& B)
{
  typedef typename MatrixViewType1::ordinal_type ordinal_type;

  const ordinal_type A_nrows = A.nrows();
  const ordinal_type A_ncols = A.ncols();
  if (A_nrows != B.nrows() || A_ncols != B.ncols())
    {
      using std::endl;
      std::ostringstream os;
      os << "matrixCopy: dimensions of A (output matrix) "
        "and B (input matrix) are not compatible." 
         << endl
         << "A is " << A.nrows() << " x " << A.ncols() 
         << ", and B is " << B.nrows() << " x " << B.ncols() << ".";
      throw std::invalid_argument(os.str());
    }

  // An empty matrix has no entry (0,j).  Taking &A(0,j) anyway would trip
  // the row-range check of a debug-mode view, which throws "Row range
  // invalid" when i >= nrows().  It would also refer to a nonexistent entry
  // in a non-debug build.  There is nothing to copy, so stop here.
  if (A_nrows == 0 || A_ncols == 0)
    return;

  for (ordinal_type j = 0; j < A_ncols; ++j)
    {
      typename MatrixViewType1::scalar_type* const A_j = &A(0,j);
      const typename MatrixViewType2::scalar_type* const B_j = &B(0,j);
      for (ordinal_type i = 0; i < A_nrows; ++i)
        A_j[i] = B_j[i];
    }
}
00074 
/// \brief Entry-wise equality test for two column-major matrix views.
///
/// Two views are equal when they have the same number of rows and columns
/// and every pair of corresponding entries compares equal.  The leading
/// dimensions may differ; only the nrows() x ncols() entries are compared.
///
/// \param A [in] First view; must provide ordinal_type, pointer_type,
///   nrows(), ncols(), lda() and get().
/// \param B [in] Second view, with the same requirements as A.
/// \return true if the shapes and all entries match, false otherwise.
template< class FirstMatrixViewType, class SecondMatrixViewType >
bool
matrix_equal (FirstMatrixViewType& A, SecondMatrixViewType& B)
{
  typedef typename FirstMatrixViewType::ordinal_type first_ordinal_type;
  typedef typename SecondMatrixViewType::ordinal_type second_ordinal_type;
  typedef typename FirstMatrixViewType::pointer_type first_pointer_type;
  typedef typename SecondMatrixViewType::pointer_type second_pointer_type;

  // Views of different shapes can never be equal.
  const bool sameShape = (A.nrows() == B.nrows()) && (A.ncols() == B.ncols());
  if (! sameShape)
    return false;

  const first_ordinal_type numRows = A.nrows();
  const first_ordinal_type numCols = A.ncols();
  const first_ordinal_type strideA = A.lda();
  const second_ordinal_type strideB = B.lda();

  // Walk both views one column at a time, advancing each column pointer
  // by its own view's leading dimension.
  first_pointer_type colA = A.get();
  second_pointer_type colB = B.get();
  for (first_ordinal_type col = 0; col < numCols; ++col)
    {
      for (first_ordinal_type row = 0; row < numRows; ++row)
        {
          if (colA[row] != colB[row])
            return false;
        }
      colA += strideA;
      colB += strideB;
    }
  return true;
}
00102 
00103 #ifdef TSQR_MATVIEW_DEBUG
/// \class MatViewVerify
/// \brief Argument validation for the MatView / ConstMatView constructors.
///
/// This helper is compiled only when TSQR_MATVIEW_DEBUG is defined.
template< class Ordinal, class Scalar >
class MatViewVerify {
public:
  /// \brief Check the dimensions and leading dimension of a matrix view.
  ///
  /// Every failed check appends one line to the error message, so a single
  /// exception reports all of the problems found.  The checks are:
  ///   - Ordinal must be an integer type;
  ///   - for a signed Ordinal, num_rows, num_cols and leading_dim must be
  ///     nonnegative;
  ///   - leading_dim must be at least num_rows.
  ///
  /// \param num_rows [in] Number of rows of the view.
  /// \param num_cols [in] Number of columns of the view.
  /// \param A [in] Data pointer; not examined by these checks.
  /// \param leading_dim [in] Leading dimension (column stride) of the view.
  ///
  /// \throws std::invalid_argument if any check fails.
  static void 
  verify (const Ordinal num_rows, 
          const Ordinal num_cols, 
          const Scalar* const A, 
          const Ordinal leading_dim)
  {
    std::ostringstream errors;
    bool failed = false;

    if (! std::numeric_limits<Ordinal>::is_integer) {
      errors << "Error: Ordinal type must be an integer." << std::endl;
      failed = true;
    }

    // Negative values are only possible when Ordinal is a signed type.
    const bool ordinalIsSigned = std::numeric_limits<Ordinal>::is_signed;
    if (ordinalIsSigned && num_rows < 0) {
      errors << "Error: num_rows (= " << num_rows << ") < 0." << std::endl;
      failed = true;
    }
    if (ordinalIsSigned && num_cols < 0) {
      errors << "Error: num_cols (= " << num_cols << ") < 0." << std::endl;
      failed = true;
    }
    if (ordinalIsSigned && leading_dim < 0) {
      errors << "Error: leading_dim (= " << leading_dim << ") < 0." << std::endl;
      failed = true;
    }

    if (leading_dim < num_rows) {
      errors << "Error: leading_dim (= " << leading_dim << ") < num_rows (= " << num_rows << ")." << std::endl;
      failed = true;
    }

    if (failed)
      throw std::invalid_argument (errors.str());
  }
};
00149 #endif // TSQR_MATVIEW_DEBUG
00150 
00151 
// Forward declaration
template< class Ordinal, class Scalar >
class ConstMatView;

// Forward declaration
template< class Ordinal, class Scalar >
class Matrix;

/// \class MatView
/// \brief Non-owning view of a column-major dense matrix with mutable entries.
///
/// Entry (i,j) is stored at get()[i + j*lda()].  A MatView never allocates
/// or frees memory.  Copying or assigning a view copies only the
/// dimensions, the leading dimension and the data pointer, never the entries.
///
/// Defining TSQR_MATVIEW_DEBUG enables argument and bounds checking.
template< class Ordinal, class Scalar >
class MatView {
public:
  typedef Scalar scalar_type;
  typedef Ordinal ordinal_type;
  typedef Scalar* pointer_type;

  /// \brief Empty view: 0 x 0, leading dimension 0, NULL data pointer.
  MatView () : nrows_(0), ncols_(0), lda_(0), A_(NULL) {}

  /// \brief View of the num_rows x num_cols matrix whose (0,0) entry is at A.
  ///
  /// \param num_rows [in] Number of rows.
  /// \param num_cols [in] Number of columns.
  /// \param A [in] Pointer to entry (0,0); the view does not own it.
  /// \param leading_dim [in] Stride between the starts of consecutive columns.
  ///
  /// With TSQR_MATVIEW_DEBUG defined, the arguments are checked by
  /// MatViewVerify::verify, which throws std::invalid_argument on failure.
  MatView (const Ordinal num_rows, 
           const Ordinal num_cols, 
           Scalar* const A, 
           const Ordinal leading_dim) :
    nrows_(num_rows),
    ncols_(num_cols),
    lda_(leading_dim),
    A_(A)
  {
#ifdef TSQR_MATVIEW_DEBUG
    MatViewVerify< Ordinal, Scalar >::verify (num_rows, num_cols, A, leading_dim);
#endif // TSQR_MATVIEW_DEBUG
  }

  /// \brief Shallow copy: the new view refers to the same data as view.
  MatView (const MatView& view) :
    nrows_(view.nrows()),
    ncols_(view.ncols()),
    lda_(view.lda()),
    A_(view.get())
  {}

  /// \brief Shallow assignment: afterwards *this refers to view's data.
  MatView& operator= (const MatView& view) {
    if (this != &view)
      {
        nrows_ = view.nrows();
        ncols_ = view.ncols();
        A_ = view.get();
        lda_ = view.lda();
      }
    return *this;
  }

  /// \brief Reference to entry (i,j).
  ///
  /// The method is const because it does not modify the view itself.  The
  /// returned reference still allows modifying the viewed entry.  With
  /// TSQR_MATVIEW_DEBUG defined, throws std::invalid_argument for an
  /// out-of-range index and std::logic_error if the data pointer is NULL.
  Scalar& operator() (const Ordinal i, const Ordinal j) const 
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed)
      {
        if (i < 0 || i >= nrows())
          throw std::invalid_argument("Row range invalid");
        else if (j < 0 || j >= ncols())
          throw std::invalid_argument("Column range invalid");
      }
    else
      {
        if (i >= nrows())
          throw std::invalid_argument("Row range invalid");
        else if (j >= ncols())
          throw std::invalid_argument("Column range invalid");
      }
    if (A_ == NULL)
      throw std::logic_error("Attempt to reference NULL data");
#endif // TSQR_MATVIEW_DEBUG
    return A_[i + j*lda()];
  }

  /// \brief Number of rows.
  Ordinal nrows() const { return nrows_; }
  /// \brief Number of columns.
  Ordinal ncols() const { return ncols_; }
  /// \brief Leading dimension: stride between the starts of consecutive columns.
  Ordinal lda() const { return lda_; }

  /// \brief Pointer to entry (0,0); NULL for a default-constructed view.
  pointer_type get() const { return A_; }
  /// \brief True if the view has no rows or no columns.
  bool empty() const { return nrows() == 0 || ncols() == 0; }

  /// \brief View of rows firstRow through lastRow (inclusive), all columns.
  ///
  /// The result shares the data and the leading dimension of *this.  The
  /// method is const, like ConstMatView::rowBlock, so that a row block can
  /// be taken from a view held by const reference.  With TSQR_MATVIEW_DEBUG
  /// defined, throws std::invalid_argument ("Row range invalid") if
  /// firstRow < 0 (signed Ordinal only), firstRow > lastRow, or
  /// lastRow >= nrows().
  MatView row_block (const Ordinal firstRow, const Ordinal lastRow) const
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed)
      {
        if (firstRow < 0 || firstRow > lastRow || lastRow >= nrows())
          throw std::invalid_argument ("Row range invalid");
      }
    else
      {
        if (firstRow > lastRow || lastRow >= nrows())
          throw std::invalid_argument ("Row range invalid");
      }
#endif // TSQR_MATVIEW_DEBUG
    return MatView (lastRow - firstRow + 1, ncols(), get() + firstRow, lda());
  }

  /// \brief Split off and return the top nrows_top rows.
  ///
  /// Afterwards *this views the remaining nrows() - nrows_top rows.
  ///
  /// \param nrows_top [in] Number of rows in the returned top block.
  /// \param b_contiguous_blocks [in] If false (the default), both blocks
  ///   keep lda() and the remainder starts nrows_top entries after the
  ///   top block.  If true, each block gets a leading dimension equal to
  ///   its own row count, and the remainder starts nrows_top*ncols()
  ///   entries after the top block.
  ///
  /// With TSQR_MATVIEW_DEBUG defined, throws std::invalid_argument if
  /// nrows_top < 0 (signed Ordinal only) or nrows_top > nrows().
  MatView split_top (const Ordinal nrows_top, 
                     const bool b_contiguous_blocks = false)
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed && nrows_top < 0)
      {
        std::ostringstream os;
        os << "nrows_top (= " << nrows_top << ") < 0";
        throw std::invalid_argument (os.str());
      }
    else if (nrows_top > nrows())
      {
        std::ostringstream os;
        os << "nrows_top (= " << nrows_top << ") > nrows (= " << nrows() << ")";
        throw std::invalid_argument (os.str());
      }
#endif // TSQR_MATVIEW_DEBUG

    Scalar* const A_top_ptr = get();
    Scalar* A_rest_ptr;
    const Ordinal nrows_rest = nrows() - nrows_top;
    Ordinal lda_top, lda_rest;
    if (b_contiguous_blocks)
      {
        lda_top = nrows_top;
        lda_rest = nrows_rest;
        A_rest_ptr = A_top_ptr + nrows_top * ncols();
      }
    else
      {
        lda_top = lda();
        lda_rest = lda();
        A_rest_ptr = A_top_ptr + nrows_top;
      }
    MatView A_top (nrows_top, ncols(), A_top_ptr, lda_top);
    A_ = A_rest_ptr;
    nrows_ = nrows_rest;
    lda_ = lda_rest;

    return A_top;
  }

  /// \brief Split off and return the bottom nrows_bottom rows.
  ///
  /// Afterwards *this views the remaining top nrows() - nrows_bottom rows,
  /// still starting at the current data pointer.
  ///
  /// \param nrows_bottom [in] Number of rows in the returned bottom block.
  /// \param b_contiguous_blocks [in] If false (the default), both blocks
  ///   keep lda() and the bottom block starts at the first row after the
  ///   remainder.  If true, each block gets a leading dimension equal to
  ///   its own row count, and the bottom block starts
  ///   (nrows() - nrows_bottom)*ncols() entries after the data pointer.
  ///
  /// With TSQR_MATVIEW_DEBUG defined, throws std::invalid_argument if
  /// nrows_bottom < 0 (signed Ordinal only) or nrows_bottom > nrows().
  MatView split_bottom (const Ordinal nrows_bottom, 
                        const bool b_contiguous_blocks = false)
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed && nrows_bottom < 0)
      throw std::invalid_argument ("nrows_bottom < 0");
    if (nrows_bottom > nrows())
      throw std::invalid_argument ("nrows_bottom > nrows");
#endif // TSQR_MATVIEW_DEBUG

    Scalar* const A_rest_ptr = get();
    Scalar* A_bottom_ptr;
    const Ordinal nrows_rest = nrows() - nrows_bottom;
    Ordinal lda_bottom, lda_rest;
    if (b_contiguous_blocks)
      {
        lda_bottom = nrows_bottom;
        lda_rest = nrows_rest;
        A_bottom_ptr = A_rest_ptr + nrows_rest * ncols();
      }
    else
      {
        lda_bottom = lda();
        lda_rest = lda();
        A_bottom_ptr = A_rest_ptr + nrows_rest;
      }
    MatView A_bottom (nrows_bottom, ncols(), A_bottom_ptr, lda_bottom);
    A_ = A_rest_ptr;
    nrows_ = nrows_rest;
    lda_ = lda_rest;

    return A_bottom;
  }

  /// \brief Set every entry of the view to value.
  ///
  /// Only the first nrows() entries of each column are written.  Any
  /// padding between nrows() and lda() is left untouched.
  void
  fill (const scalar_type& value) 
  {
    const ordinal_type num_rows = nrows();
    const ordinal_type num_cols = ncols();
    const ordinal_type stride = lda();

    scalar_type* A_j = get();
    for (ordinal_type j = 0; j < num_cols; ++j, A_j += stride)
      for (ordinal_type i = 0; i < num_rows; ++i)
        A_j[i] = value;
  }

  /// \brief Copy the entries of B into this view, using matrixCopy.
  ///
  /// \throws std::invalid_argument if B's dimensions differ from ours.
  void
  copy (const MatView< ordinal_type, scalar_type >& B) {
    matrixCopy (*this, B);
  }
  /// \brief Copy the entries of B into this view, using matrixCopy.
  void
  copy (const ConstMatView< ordinal_type, scalar_type >& B) {
    matrixCopy (*this, B);
  }
  /// \brief Copy the entries of B into this view, using matrixCopy.
  void
  copy (const Matrix< ordinal_type, scalar_type >& B) {
    matrixCopy (*this, B);
  }

  /// \brief View identity: same dimensions, leading dimension and data pointer.
  ///
  /// This does not compare entries; use matrix_equal for that.
  bool operator== (const MatView& rhs) const {
    return nrows() == rhs.nrows() && ncols() == rhs.ncols() && 
      lda() == rhs.lda() && get() == rhs.get();
  }

  /// \brief Negation of operator==.
  bool operator!= (const MatView& rhs) const {
    return ! (*this == rhs);
  }

private:
  ordinal_type nrows_, ncols_, lda_;
  scalar_type* A_;
};
00400 
00401 
/// \class ConstMatView
/// \brief Non-owning, read-only view of a column-major dense matrix.
///
/// Entry (i,j) is stored at get()[i + j*lda()].  A ConstMatView never
/// allocates or frees memory.  Copying or assigning a view copies only the
/// dimensions, the leading dimension and the data pointer.
///
/// Defining TSQR_MATVIEW_DEBUG enables argument and bounds checking.
template< class Ordinal, class Scalar >
class ConstMatView {
public:
  typedef Scalar scalar_type;
  typedef Ordinal ordinal_type;
  typedef const Scalar* pointer_type;

  /// \brief Empty view: 0 x 0, leading dimension 0, NULL data pointer.
  ConstMatView () : nrows_(0), ncols_(0), lda_(0), A_(NULL) {}

  /// \brief Read-only view of the num_rows x num_cols matrix at A.
  ///
  /// \param num_rows [in] Number of rows.
  /// \param num_cols [in] Number of columns.
  /// \param A [in] Pointer to entry (0,0); the view does not own it.
  /// \param leading_dim [in] Stride between the starts of consecutive columns.
  ///
  /// With TSQR_MATVIEW_DEBUG defined, the arguments are checked by
  /// MatViewVerify::verify, which throws std::invalid_argument on failure.
  ConstMatView (const Ordinal num_rows, 
                const Ordinal num_cols, 
                const Scalar* const A, 
                const Ordinal leading_dim) :
    nrows_(num_rows),
    ncols_(num_cols),
    lda_(leading_dim),
    A_(A)
  {
#ifdef TSQR_MATVIEW_DEBUG
    MatViewVerify< Ordinal, Scalar >::verify (num_rows, num_cols, A, leading_dim);
#endif // TSQR_MATVIEW_DEBUG
  }

  /// \brief Shallow copy: the new view refers to the same data as view.
  ConstMatView (const ConstMatView& view) :
    nrows_(view.nrows()),
    ncols_(view.ncols()),
    lda_(view.lda()),
    A_(view.get())
  {}

  /// \brief Shallow assignment: afterwards *this refers to view's data.
  ConstMatView& operator= (const ConstMatView& view) {
    if (this != &view)
      {
        nrows_ = view.nrows();
        ncols_ = view.ncols();
        lda_ = view.lda();
        A_ = view.get();
      }
    return *this;
  }

  /// \brief Const reference to entry (i,j).
  ///
  /// With TSQR_MATVIEW_DEBUG defined, throws std::invalid_argument for an
  /// out-of-range index and std::logic_error if the data pointer is NULL.
  const Scalar& operator() (const Ordinal i, const Ordinal j) const 
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed)
      {
        if (i < 0 || i >= nrows())
          throw std::invalid_argument("Row range invalid");
        else if (j < 0 || j >= ncols())
          throw std::invalid_argument("Column range invalid");
      }
    else
      {
        if (i >= nrows())
          throw std::invalid_argument("Row range invalid");
        else if (j >= ncols())
          throw std::invalid_argument("Column range invalid");
      }
    if (A_ == NULL)
      throw std::logic_error("Attempt to reference NULL data");
#endif // TSQR_MATVIEW_DEBUG
    return A_[i + j*lda()];
  }

  /// \brief Number of rows.
  Ordinal nrows() const { return nrows_; }
  /// \brief Number of columns.
  Ordinal ncols() const { return ncols_; }
  /// \brief Leading dimension: stride between the starts of consecutive columns.
  Ordinal lda() const { return lda_; }
  /// \brief Pointer to entry (0,0); NULL for a default-constructed view.
  pointer_type get() const { return A_; }
  /// \brief True if the view has no rows or no columns.
  bool empty() const { return nrows() == 0 || ncols() == 0; }

  /// \brief View of rows firstRow through lastRow (inclusive), all columns.
  ///
  /// The result shares the data and the leading dimension of *this.  With
  /// TSQR_MATVIEW_DEBUG defined, throws std::invalid_argument ("Row range
  /// invalid") if firstRow < 0 (signed Ordinal only), firstRow > lastRow,
  /// or lastRow >= nrows().  Rejecting firstRow > lastRow prevents a
  /// nonpositive row count.  The negative test is skipped for unsigned
  /// Ordinal types, where it could never be true.
  ConstMatView rowBlock (const Ordinal firstRow, 
                         const Ordinal lastRow) const
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed)
      {
        if (firstRow < 0 || firstRow > lastRow || lastRow >= nrows())
          throw std::invalid_argument ("Row range invalid");
      }
    else
      {
        if (firstRow > lastRow || lastRow >= nrows())
          throw std::invalid_argument ("Row range invalid");
      }
#endif // TSQR_MATVIEW_DEBUG
    return ConstMatView (lastRow - firstRow + 1, ncols(), get() + firstRow, lda());
  }

  /// \brief Split off and return the top nrows_top rows.
  ///
  /// Afterwards *this views the remaining nrows() - nrows_top rows.
  ///
  /// \param nrows_top [in] Number of rows in the returned top block.
  /// \param b_contiguous_blocks [in] If false (the default), both blocks
  ///   keep lda() and the remainder starts nrows_top entries after the
  ///   top block.  If true, each block gets a leading dimension equal to
  ///   its own row count, and the remainder starts nrows_top*ncols()
  ///   entries after the top block.
  ///
  /// With TSQR_MATVIEW_DEBUG defined, throws std::invalid_argument if
  /// nrows_top < 0 (signed Ordinal only) or nrows_top > nrows().
  ConstMatView split_top (const Ordinal nrows_top, 
                          const bool b_contiguous_blocks = false)
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed && nrows_top < 0)
      throw std::invalid_argument ("nrows_top < 0");
    if (nrows_top > nrows())
      throw std::invalid_argument ("nrows_top > nrows");
#endif // TSQR_MATVIEW_DEBUG

    pointer_type const A_top_ptr = get();
    pointer_type A_rest_ptr;
    const Ordinal nrows_rest = nrows() - nrows_top;
    Ordinal lda_top, lda_rest;
    if (b_contiguous_blocks)
      {
        lda_top = nrows_top;
        lda_rest = nrows_rest;
        A_rest_ptr = A_top_ptr + nrows_top * ncols();
      }
    else
      {
        lda_top = lda();
        lda_rest = lda();
        A_rest_ptr = A_top_ptr + nrows_top;
      }
    ConstMatView A_top (nrows_top, ncols(), A_top_ptr, lda_top);
    A_ = A_rest_ptr;
    nrows_ = nrows_rest;
    lda_ = lda_rest;

    return A_top;
  }

  /// \brief Split off and return the bottom nrows_bottom rows.
  ///
  /// Afterwards *this views the remaining top nrows() - nrows_bottom rows,
  /// still starting at the current data pointer.
  ///
  /// \param nrows_bottom [in] Number of rows in the returned bottom block.
  /// \param b_contiguous_blocks [in] If false (the default), both blocks
  ///   keep lda() and the bottom block starts at the first row after the
  ///   remainder.  If true, each block gets a leading dimension equal to
  ///   its own row count, and the bottom block starts
  ///   (nrows() - nrows_bottom)*ncols() entries after the data pointer.
  ///
  /// With TSQR_MATVIEW_DEBUG defined, throws std::invalid_argument if
  /// nrows_bottom < 0 (signed Ordinal only) or nrows_bottom > nrows().
  ConstMatView split_bottom (const Ordinal nrows_bottom, 
                             const bool b_contiguous_blocks = false)
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed && nrows_bottom < 0)
      throw std::invalid_argument ("nrows_bottom < 0");
    if (nrows_bottom > nrows())
      throw std::invalid_argument ("nrows_bottom > nrows");
#endif // TSQR_MATVIEW_DEBUG

    pointer_type const A_rest_ptr = get();
    pointer_type A_bottom_ptr;
    const ordinal_type nrows_rest = nrows() - nrows_bottom;
    ordinal_type lda_bottom, lda_rest;
    if (b_contiguous_blocks)
      {
        lda_bottom = nrows_bottom;
        lda_rest = nrows_rest;
        A_bottom_ptr = A_rest_ptr + nrows_rest * ncols();
      }
    else
      {
        lda_bottom = lda();
        lda_rest = lda();
        A_bottom_ptr = A_rest_ptr + nrows_rest;
      }
    ConstMatView A_bottom (nrows_bottom, ncols(), A_bottom_ptr, lda_bottom);
    A_ = A_rest_ptr;
    nrows_ = nrows_rest;
    lda_ = lda_rest;

    return A_bottom;
  }

  /// \brief View identity: same dimensions, leading dimension and data pointer.
  ///
  /// This does not compare entries; use matrix_equal for that.
  bool operator== (const ConstMatView& rhs) const {
    return nrows() == rhs.nrows() && ncols() == rhs.ncols() && 
      lda() == rhs.lda() && get() == rhs.get();
  }

  /// \brief Negation of operator==.
  bool operator!= (const ConstMatView& rhs) const {
    return ! (*this == rhs);
  }

private:
  ordinal_type nrows_, ncols_, lda_;
  pointer_type A_;
};
00595 
00596 } // namespace TSQR
00597 
00598 
00599 #endif // __TSQR_Tsqr_MatView_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends