Anasazi Version of the Day
Tsqr_MatView.hpp
00001 // @HEADER
00002 // ***********************************************************************
00003 //
00004 //                 Anasazi: Block Eigensolvers Package
00005 //                 Copyright (2010) Sandia Corporation
00006 //
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 //
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00025 //
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifndef __TSQR_Tsqr_MatView_hpp
00030 #define __TSQR_Tsqr_MatView_hpp
00031 
00032 #include <cstring> // NULL
00033 
00034 // Define for bounds checking and other safety features, undefine for speed.
00035 // #define TSQR_MATVIEW_DEBUG 1
00036 
00037 #ifdef TSQR_MATVIEW_DEBUG
00038 #  include <limits>
00039 #endif // TSQR_MATVIEW_DEBUG
00040 
00041 #include <sstream>
00042 #include <stdexcept>
00043 
00046 
00047 namespace TSQR {
00048 
/// Copy the entries of B into A.
///
/// Both arguments are accessed through nrows(), ncols() and
/// operator()(i,j).  For each column j, the address &A(0,j) (resp.
/// &B(0,j)) is taken and the column is then indexed as a contiguous
/// array, so each column of both matrices is assumed to have unit row
/// stride (column-major storage).
///
/// \param A [out] Destination; must have the same dimensions as B.
/// \param B [in]  Source.
/// \throws std::invalid_argument if the dimensions of A and B differ.
template< class MatrixViewType1, class MatrixViewType2 >
void
matrixCopy (MatrixViewType1& A, const MatrixViewType2& B)
{
  typedef typename MatrixViewType1::ordinal_type ordinal_type;

  const ordinal_type A_nrows = A.nrows();
  const ordinal_type A_ncols = A.ncols();
  if (A_nrows != B.nrows() || A_ncols != B.ncols())
    throw std::invalid_argument("Dimensions of A and B are not compatible");

  // Nothing to copy.  Returning early also avoids forming &A(0,j) and
  // &B(0,j) when there are zero rows: row index 0 is then out of range,
  // and a bounds-checked operator() (e.g. with TSQR_MATVIEW_DEBUG
  // defined) would throw even though the copy is a no-op.
  if (A_nrows == 0 || A_ncols == 0)
    return;

  for (ordinal_type j = 0; j < A_ncols; ++j)
    {
      typename MatrixViewType1::scalar_type* const A_j = &A(0,j);
      const typename MatrixViewType2::scalar_type* const B_j = &B(0,j);
      for (ordinal_type i = 0; i < A_nrows; ++i)
        A_j[i] = B_j[i];
    }
}
00066 
/// Entry-wise equality test for two column-major matrix views.
///
/// Returns false immediately if the row or column counts differ.
/// Otherwise walks both matrices column by column, using each one's own
/// leading dimension (lda()) to locate column j, and returns false at
/// the first position where A's entry != B's entry; returns true if no
/// such position exists.
///
/// \param A [in] First matrix (needs nrows, ncols, lda, get, and the
///   ordinal_type / pointer_type typedefs).
/// \param B [in] Second matrix (same requirements).
template< class FirstMatrixViewType, class SecondMatrixViewType >
bool
matrix_equal (FirstMatrixViewType& A, SecondMatrixViewType& B)
{
  const bool same_shape =
    (A.nrows() == B.nrows()) && (A.ncols() == B.ncols());
  if (! same_shape)
    return false;

  typedef typename FirstMatrixViewType::ordinal_type lhs_index;
  typedef typename SecondMatrixViewType::ordinal_type rhs_index;
  typedef typename FirstMatrixViewType::pointer_type lhs_ptr;
  typedef typename SecondMatrixViewType::pointer_type rhs_ptr;

  const lhs_index num_rows = A.nrows();
  const lhs_index lhs_stride = A.lda();
  const lhs_index num_cols = A.ncols();
  const rhs_index rhs_stride = B.lda();

  const lhs_ptr lhs_base = A.get();
  const rhs_ptr rhs_base = B.get();

  for (lhs_index col = 0; col < num_cols; ++col)
    {
      // Start of column `col` in each matrix.
      const lhs_ptr lhs_col = lhs_base + col * lhs_stride;
      const rhs_ptr rhs_col = rhs_base + col * rhs_stride;
      for (lhs_index row = 0; row < num_rows; ++row)
        {
          if (lhs_col[row] != rhs_col[row])
            return false;
        }
    }
  return true;
}
00094 
00095 #ifdef TSQR_MATVIEW_DEBUG
/// Argument checker used by the MatView / ConstMatView constructors when
/// TSQR_MATVIEW_DEBUG is defined.
template< class Ordinal, class Scalar >
class MatViewVerify {
public:
  /// Validate the dimensions of a column-major matrix view.
  ///
  /// Every problem found is appended to one message, one line per
  /// problem, in this order:
  /// - Ordinal is not an integer type;
  /// - num_rows, num_cols, or leading_dim is negative (checked only
  ///   for signed Ordinal types);
  /// - leading_dim < num_rows.
  /// If at least one problem was found, std::invalid_argument is thrown
  /// with the accumulated message.  The data pointer A is not inspected.
  static void
  verify (const Ordinal num_rows,
          const Ordinal num_cols,
          const Scalar* const A,
          const Ordinal leading_dim)
  {
    typedef std::numeric_limits<Ordinal> traits;

    std::ostringstream report;
    int num_problems = 0;

    if (traits::is_integer == false)
      {
        report << "Error: Ordinal type must be an integer." << '\n';
        ++num_problems;
      }

    // Negativity is only possible for signed ordinal types.
    const bool may_be_negative = traits::is_signed;
    if (may_be_negative && num_rows < 0)
      {
        report << "Error: num_rows (= " << num_rows << ") < 0." << '\n';
        ++num_problems;
      }
    if (may_be_negative && num_cols < 0)
      {
        report << "Error: num_cols (= " << num_cols << ") < 0." << '\n';
        ++num_problems;
      }
    if (may_be_negative && leading_dim < 0)
      {
        report << "Error: leading_dim (= " << leading_dim << ") < 0." << '\n';
        ++num_problems;
      }

    if (leading_dim < num_rows)
      {
        report << "Error: leading_dim (= " << leading_dim
               << ") < num_rows (= " << num_rows << ")." << '\n';
        ++num_problems;
      }

    if (num_problems > 0)
      throw std::invalid_argument (report.str());
  }
};
00141 #endif // TSQR_MATVIEW_DEBUG
00142 
00143 
00144   // Forward declaration
00145   template< class Ordinal, class Scalar >
00146   class ConstMatView;
00147 
00148   // Forward declaration
00149   template< class Ordinal, class Scalar >
00150   class Matrix;
00151 
00155   template< class Ordinal, class Scalar >
00156   class MatView {
00157   public:
00158     typedef Scalar scalar_type;
00159     typedef Ordinal ordinal_type;
00160     typedef Scalar* pointer_type;
00161 
00164     MatView () : nrows_(0), ncols_(0), lda_(0), A_(NULL) {}
00165 
00166     MatView (const Ordinal num_rows, 
00167        const Ordinal num_cols, 
00168        Scalar* const A, 
00169        const Ordinal leading_dim) :
00170       nrows_(num_rows),
00171       ncols_(num_cols),
00172       lda_(leading_dim),
00173       A_(A)
00174     {
00175 #ifdef TSQR_MATVIEW_DEBUG
00176       MatViewVerify< Ordinal, Scalar >::verify (num_rows, num_cols, A, leading_dim);
00177 #endif // TSQR_MATVIEW_DEBUG
00178     }
00179 
00180     MatView (const MatView& view) :
00181       nrows_(view.nrows()),
00182       ncols_(view.ncols()),
00183       lda_(view.lda()),
00184       A_(view.get())
00185     {}
00186 
00187     MatView& operator= (const MatView& view) {
00188       if (this != &view)
00189   {
00190     nrows_ = view.nrows();
00191     ncols_ = view.ncols();
00192     A_ = view.get();
00193     lda_ = view.lda();
00194   }
00195       return *this;
00196     }
00197 
00203     Scalar& operator() (const Ordinal i, const Ordinal j) const 
00204     {
00205 #ifdef TSQR_MATVIEW_DEBUG
00206       if (std::numeric_limits< Ordinal >::is_signed)
00207   {
00208     if (i < 0 || i >= nrows())
00209       throw std::invalid_argument("Row range invalid");
00210     else if (j < 0 || j >= ncols())
00211       throw std::invalid_argument("Column range invalid");
00212   }
00213       else
00214   {
00215     if (i >= nrows())
00216       throw std::invalid_argument("Row range invalid");
00217     else if (j >= ncols())
00218       throw std::invalid_argument("Column range invalid");
00219   }
00220       if (A_ == NULL)
00221   throw std::logic_error("Attempt to reference NULL data");
00222 #endif // TSQR_MATVIEW_DEBUG
00223       return A_[i + j*lda()];
00224     }
00225 
00226     Ordinal nrows() const { return nrows_; }
00227     Ordinal ncols() const { return ncols_; }
00228     Ordinal lda() const { return lda_; }
00229 
00234     pointer_type get() const { return A_; }
00235     bool empty() const { return nrows() == 0 || ncols() == 0; }
00236 
00239     MatView row_block (const Ordinal firstRow, const Ordinal lastRow) 
00240     {
00241 #ifdef TSQR_MATVIEW_DEBUG
00242       if (std::numeric_limits< Ordinal >::is_signed)
00243   {
00244     if (firstRow < 0 || firstRow > lastRow || lastRow >= nrows())
00245       throw std::invalid_argument ("Row range invalid");
00246   }
00247       else
00248   {
00249     if (firstRow > lastRow || lastRow >= nrows())
00250       throw std::invalid_argument ("Row range invalid");
00251   }
00252 #endif // TSQR_MATVIEW_DEBUG
00253       return MatView (lastRow - firstRow + 1, ncols(), get() + firstRow, lda());
00254     }
00255 
00271     MatView split_top (const Ordinal nrows_top, 
00272            const bool b_contiguous_blocks = false)
00273     {
00274 #ifdef TSQR_MATVIEW_DEBUG
00275       if (std::numeric_limits< Ordinal >::is_signed && nrows_top < 0)
00276   {
00277     std::ostringstream os;
00278     os << "nrows_top (= " << nrows_top << ") < 0";
00279     throw std::invalid_argument (os.str());
00280   }
00281       else if (nrows_top > nrows())
00282   {
00283     std::ostringstream os;
00284     os << "nrows_top (= " << nrows_top << ") > nrows (= " << nrows() << ")";
00285     throw std::invalid_argument (os.str());
00286   }
00287 #endif // TSQR_MATVIEW_DEBUG
00288 
00289       Scalar* const A_top_ptr = get();
00290       Scalar* A_rest_ptr;
00291       const Ordinal nrows_rest = nrows() - nrows_top;
00292       Ordinal lda_top, lda_rest;
00293       if (b_contiguous_blocks)
00294   {
00295     lda_top = nrows_top;
00296     lda_rest = nrows_rest;
00297     A_rest_ptr = A_top_ptr + nrows_top * ncols();
00298   }
00299       else
00300   {
00301     lda_top = lda();
00302     lda_rest = lda();
00303     A_rest_ptr = A_top_ptr + nrows_top;
00304   }
00305       MatView A_top (nrows_top, ncols(), get(), lda_top);
00306       A_ = A_rest_ptr;
00307       nrows_ = nrows_rest;
00308       lda_ = lda_rest;
00309 
00310       return A_top;
00311     }
00312 
00315     MatView split_bottom (const Ordinal nrows_bottom, 
00316         const bool b_contiguous_blocks = false)
00317     {
00318 #ifdef TSQR_MATVIEW_DEBUG
00319       if (std::numeric_limits< Ordinal >::is_signed && nrows_bottom < 0)
00320   throw std::invalid_argument ("nrows_bottom < 0");
00321       if (nrows_bottom > nrows())
00322   throw std::invalid_argument ("nrows_bottom > nrows");
00323 #endif // TSQR_MATVIEW_DEBUG
00324 
00325       Scalar* const A_rest_ptr = get();
00326       Scalar* A_bottom_ptr;
00327       const Ordinal nrows_rest = nrows() - nrows_bottom;
00328       Ordinal lda_bottom, lda_rest;
00329       if (b_contiguous_blocks)
00330   {
00331     lda_bottom = nrows_bottom;
00332     lda_rest = nrows() - nrows_bottom;
00333     A_bottom_ptr = A_rest_ptr + nrows_rest * ncols();
00334   }
00335       else
00336   {
00337     lda_bottom = lda();
00338     lda_rest = lda();
00339     A_bottom_ptr = A_rest_ptr + nrows_rest;
00340   }
00341       MatView A_bottom (nrows_bottom, ncols(), A_bottom_ptr, lda_bottom);
00342       A_ = A_rest_ptr;
00343       nrows_ = nrows_rest;
00344       lda_ = lda_rest;
00345 
00346       return A_bottom;
00347     }
00348 
00349     void
00350     fill (const scalar_type& value) 
00351     {
00352       const ordinal_type num_rows = nrows();
00353       const ordinal_type num_cols = ncols();
00354       const ordinal_type stride = lda();
00355 
00356       scalar_type* A_j = get();
00357       for (ordinal_type j = 0; j < num_cols; ++j, A_j += stride)
00358   for (ordinal_type i = 0; i < num_rows; ++i)
00359     A_j[i] = value;
00360     }
00361 
00365     void
00366     copy (const MatView< ordinal_type, scalar_type >& B) {
00367       matrixCopy (*this, B);
00368     }
00369     void
00370     copy (const ConstMatView< ordinal_type, scalar_type >& B) {
00371       matrixCopy (*this, B);
00372     }
00373     void
00374     copy (const Matrix< ordinal_type, scalar_type >& B) {
00375       matrixCopy (*this, B);
00376     }
00377 
00378     bool operator== (const MatView& rhs) const {
00379       return nrows() == rhs.nrows() && ncols() == rhs.ncols() && 
00380   lda() == rhs.lda() && get() == rhs.get();
00381     }
00382 
00383     bool operator!= (const MatView& rhs) const {
00384       return nrows() != rhs.nrows() || ncols() != rhs.ncols() || 
00385   lda() != rhs.lda() || get() != rhs.get();
00386     }
00387 
00388   private:
00389     ordinal_type nrows_, ncols_, lda_;
00390     scalar_type* A_;
00391   };
00392 
00393 
/// \class ConstMatView
/// \brief Non-owning, read-only view of a column-major matrix.
///
/// Same layout conventions as MatView: entry (i,j) is located at
/// get()[i + j*lda()].  Entries cannot be modified through this view,
/// but the view itself (pointer and dimensions) can be repartitioned
/// with split_top / split_bottom.  Copying a ConstMatView copies the
/// view, not the entries.
template< class Ordinal, class Scalar >
class ConstMatView {
public:
  typedef Scalar scalar_type;
  typedef Ordinal ordinal_type;
  typedef const Scalar* pointer_type;

  /// Empty view: 0 x 0, leading dimension 0, NULL data pointer.
  ConstMatView () : nrows_(0), ncols_(0), lda_(0), A_(NULL) {}

  /// Read-only view of the num_rows x num_cols column-major matrix
  /// stored at A, with leading dimension leading_dim.  Ownership of A is
  /// not taken.  If TSQR_MATVIEW_DEBUG is defined, the arguments are
  /// checked by MatViewVerify::verify, which throws
  /// std::invalid_argument on inconsistent dimensions.
  ConstMatView (const Ordinal num_rows,
                const Ordinal num_cols,
                const Scalar* const A,
                const Ordinal leading_dim) :
    nrows_(num_rows),
    ncols_(num_cols),
    lda_(leading_dim),
    A_(A)
  {
#ifdef TSQR_MATVIEW_DEBUG
    MatViewVerify< Ordinal, Scalar >::verify (num_rows, num_cols, A, leading_dim);
#endif // TSQR_MATVIEW_DEBUG
  }

  /// Shallow copy: the new view refers to the same storage.
  ConstMatView (const ConstMatView& view) :
    nrows_(view.nrows()),
    ncols_(view.ncols()),
    lda_(view.lda()),
    A_(view.get())
  {}

  /// Shallow assignment: *this afterwards refers to view's storage.
  ConstMatView& operator= (const ConstMatView& view) {
    if (this != &view)
      {
        nrows_ = view.nrows();
        ncols_ = view.ncols();
        lda_ = view.lda();
        A_ = view.get();
      }
    return *this;
  }

  /// Const reference to entry (i,j).
  ///
  /// If TSQR_MATVIEW_DEBUG is defined, throws std::invalid_argument for
  /// an out-of-range i or j, and std::logic_error if the data pointer is
  /// NULL.  Otherwise no checking is done.
  const Scalar& operator() (const Ordinal i, const Ordinal j) const
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed)
      {
        if (i < 0 || i >= nrows())
          throw std::invalid_argument("Row range invalid");
        else if (j < 0 || j >= ncols())
          throw std::invalid_argument("Column range invalid");
      }
    else
      {
        if (i >= nrows())
          throw std::invalid_argument("Row range invalid");
        else if (j >= ncols())
          throw std::invalid_argument("Column range invalid");
      }
    if (A_ == NULL)
      throw std::logic_error("Attempt to reference NULL data");
#endif // TSQR_MATVIEW_DEBUG
    return A_[i + j*lda()];
  }

  /// Number of rows in the view.
  Ordinal nrows() const { return nrows_; }
  /// Number of columns in the view.
  Ordinal ncols() const { return ncols_; }
  /// Leading dimension: distance (in entries) between the starts of
  /// consecutive columns.
  Ordinal lda() const { return lda_; }
  /// Pointer to entry (0,0); NULL for a default-constructed view.
  pointer_type get() const { return A_; }
  /// True if the view has zero rows or zero columns.
  bool empty() const { return nrows() == 0 || ncols() == 0; }

  /// View of rows firstRow through lastRow (inclusive) and all columns,
  /// sharing this view's storage and leading dimension.
  ///
  /// If TSQR_MATVIEW_DEBUG is defined, throws std::invalid_argument
  /// unless 0 <= firstRow <= lastRow < nrows().  This matches the check
  /// in MatView::row_block.  Without the firstRow > lastRow test, an
  /// inverted range would produce a nonpositive (or, for unsigned
  /// Ordinal, wrapped-around) row count.
  ConstMatView rowBlock (const Ordinal firstRow,
                         const Ordinal lastRow) const
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed)
      {
        if (firstRow < 0 || firstRow > lastRow || lastRow >= nrows())
          throw std::invalid_argument ("Row range invalid");
      }
    else
      {
        // Negative values are impossible for unsigned Ordinal.
        if (firstRow > lastRow || lastRow >= nrows())
          throw std::invalid_argument ("Row range invalid");
      }
#endif // TSQR_MATVIEW_DEBUG
    return ConstMatView (lastRow - firstRow + 1, ncols(), get() + firstRow, lda());
  }

  /// Split off the top nrows_top rows.
  ///
  /// Returns a view of the first nrows_top rows; afterwards *this views
  /// the remaining nrows() - nrows_top rows.
  ///
  /// \param nrows_top [in] Number of rows in the returned top block.
  /// \param b_contiguous_blocks [in] If false (the default), both pieces
  ///   keep this view's leading dimension.  If true, the storage is
  ///   treated as two consecutive column-major blocks: the top block with
  ///   leading dimension nrows_top, immediately followed
  ///   (nrows_top*ncols() entries later) by the rest with leading
  ///   dimension equal to its own row count.
  ///
  /// If TSQR_MATVIEW_DEBUG is defined, throws std::invalid_argument
  /// unless 0 <= nrows_top <= nrows().
  ConstMatView split_top (const Ordinal nrows_top,
                          const bool b_contiguous_blocks = false)
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed && nrows_top < 0)
      throw std::invalid_argument ("nrows_top < 0");
    if (nrows_top > nrows())
      throw std::invalid_argument ("nrows_top > nrows");
#endif // TSQR_MATVIEW_DEBUG

    pointer_type const A_top_ptr = get();
    pointer_type A_rest_ptr;
    const Ordinal nrows_rest = nrows() - nrows_top;
    Ordinal lda_top, lda_rest;
    if (b_contiguous_blocks)
      {
        lda_top = nrows_top;
        lda_rest = nrows_rest;
        A_rest_ptr = A_top_ptr + nrows_top * ncols();
      }
    else
      {
        lda_top = lda();
        lda_rest = lda();
        A_rest_ptr = A_top_ptr + nrows_top;
      }
    ConstMatView A_top (nrows_top, ncols(), A_top_ptr, lda_top);
    A_ = A_rest_ptr;
    nrows_ = nrows_rest;
    lda_ = lda_rest;

    return A_top;
  }

  /// Split off the bottom nrows_bottom rows.
  ///
  /// Returns a view of the last nrows_bottom rows; afterwards *this views
  /// the first nrows() - nrows_bottom rows.  b_contiguous_blocks has the
  /// same meaning as in split_top: if true, the remaining (top) block has
  /// leading dimension equal to its row count and the bottom block
  /// starts (nrows() - nrows_bottom)*ncols() entries after it, with
  /// leading dimension nrows_bottom.
  ///
  /// If TSQR_MATVIEW_DEBUG is defined, throws std::invalid_argument
  /// unless 0 <= nrows_bottom <= nrows().
  ConstMatView split_bottom (const Ordinal nrows_bottom,
                             const bool b_contiguous_blocks = false)
  {
#ifdef TSQR_MATVIEW_DEBUG
    if (std::numeric_limits< Ordinal >::is_signed && nrows_bottom < 0)
      throw std::invalid_argument ("nrows_bottom < 0");
    if (nrows_bottom > nrows())
      throw std::invalid_argument ("nrows_bottom > nrows");
#endif // TSQR_MATVIEW_DEBUG

    pointer_type const A_rest_ptr = get();
    pointer_type A_bottom_ptr;
    const ordinal_type nrows_rest = nrows() - nrows_bottom;
    ordinal_type lda_bottom, lda_rest;
    if (b_contiguous_blocks)
      {
        lda_bottom = nrows_bottom;
        lda_rest = nrows() - nrows_bottom;
        A_bottom_ptr = A_rest_ptr + nrows_rest * ncols();
      }
    else
      {
        lda_bottom = lda();
        lda_rest = lda();
        A_bottom_ptr = A_rest_ptr + nrows_rest;
      }
    ConstMatView A_bottom (nrows_bottom, ncols(), A_bottom_ptr, lda_bottom);
    A_ = A_rest_ptr;
    nrows_ = nrows_rest;
    lda_ = lda_rest;

    return A_bottom;
  }

  /// Views are equal if they have the same dimensions, leading
  /// dimension, and data pointer.  Entries are not compared.
  bool operator== (const ConstMatView& rhs) const {
    return nrows() == rhs.nrows() && ncols() == rhs.ncols() &&
      lda() == rhs.lda() && get() == rhs.get();
  }

  /// Negation of operator==.
  bool operator!= (const ConstMatView& rhs) const {
    return nrows() != rhs.nrows() || ncols() != rhs.ncols() ||
      lda() != rhs.lda() || get() != rhs.get();
  }

private:
  ordinal_type nrows_, ncols_, lda_;
  pointer_type A_;
};
00587 
00588 } // namespace TSQR
00589 
00590 
00591 #endif // __TSQR_Tsqr_MatView_hpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends