Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_PackedSparseMultiply.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00029 #ifndef KOKKOS_PACKEDSPARSEMULTIPLY_H
00030 #define KOKKOS_PACKEDSPARSEMULTIPLY_H
00031 
00032 #include "Kokkos_ConfigDefs.hpp"
00033 #include "Kokkos_CisMatrix.hpp" 
00034 #include "Kokkos_SparseOperation.hpp" 
00035 
00036 
00037 namespace Kokkos {
00038 
00040 
00069   template<typename OrdinalType, typename ScalarType>
00070   class PackedSparseMultiply: public virtual SparseOperation<OrdinalType, ScalarType> {
00071   public:
00072 
00074 
00076 
00078     PackedSparseMultiply();
00079   
00081     PackedSparseMultiply(const PackedSparseMultiply& source);
00082   
00084     virtual ~PackedSparseMultiply();
00086 
00087 
00089  
00091 
00099     virtual int initializeStructure(const CisMatrix<OrdinalType, ScalarType>& A, bool willKeepStructure = false);
00100  
00102 
00113     virtual int initializeValues(const CisMatrix<OrdinalType, ScalarType>& A, bool willKeepValues = false,
00114          bool checkStructure = false);
00115  
00117 
00119 
00121   
00123 
00131     virtual int apply(const MultiVector<OrdinalType, ScalarType>& x, MultiVector<OrdinalType, ScalarType>& y, 
00132           bool transA = false, bool conjA = false) const;
00134   
00136 
00138 
00140 
00142     virtual bool getCanUseStructure() const {return(false);};
00143 
00145 
00147     virtual bool getCanUseValues() const {return(false);};
00148 
00150     virtual const CisMatrix<OrdinalType, ScalarType> & getMatrix() const {
00151       if (matrixForValues_==0) return(*matrixForStructure_);
00152       else return(*matrixForValues_);
00153     };
00154     
00156   
00157   protected:
00158 
00159     void copyEntries();
00160     void deleteStructureAndValues();
00161 
00162     struct EntryStruct {
00163       OrdinalType index;
00164       ScalarType value;
00165     }; 
00166     typedef struct EntryStruct Entry;
00167     
00168     CisMatrix<OrdinalType, ScalarType> * matrixForStructure_;
00169     CisMatrix<OrdinalType, ScalarType> * matrixForValues_;
00170 
00171     bool isRowOriented_;
00172     bool haveStructure_;
00173     bool haveValues_;
00174     bool hasUnitDiagonal_;
00175   
00176     OrdinalType numRows_;
00177     OrdinalType numCols_;
00178     OrdinalType numRC_;
00179     OrdinalType numEntries_;
00180 
00181     OrdinalType * profile_;
00182     double costOfMatVec_;
00183     Entry * allEntries_;
00184   };
00185 
00186   //==============================================================================
00187   template<typename OrdinalType, typename ScalarType>
00188   PackedSparseMultiply<OrdinalType, ScalarType>::PackedSparseMultiply() 
00189     : matrixForStructure_(0),
00190       matrixForValues_(0),
00191       isRowOriented_(true),
00192       haveStructure_(false),
00193       haveValues_(false),
00194       hasUnitDiagonal_(false),
00195       numRows_(0),
00196       numCols_(0),
00197       numRC_(0),
00198       numEntries_(0),
00199       profile_(0),
00200       costOfMatVec_(0.0),
00201       allEntries_(0) {
00202   }
00203 
00204   //==============================================================================
00205   template<typename OrdinalType, typename ScalarType>
00206   PackedSparseMultiply<OrdinalType, ScalarType>::PackedSparseMultiply(const PackedSparseMultiply<OrdinalType, ScalarType> &source) 
00207     : matrixForStructure_(source.matrixForStructure_),
00208       matrixForValues_(source.matrixForValues_),
00209       isRowOriented_(source.isRowOriented_),
00210       haveStructure_(source.haveStructure_),
00211       haveValues_(source.haveValues_),
00212       hasUnitDiagonal_(source.hasUnitDiagonal_),
00213       numRows_(source.numRows_),
00214       numCols_(source.numCols_),
00215       numRC_(source.numRC_),
00216       numEntries_(source.numEntries_),
00217       profile_(source.profile_),
00218       costOfMatVec_(source.costOfMatVec_),
00219       allEntries_(source.allEntries_) {
00220 
00221     copyEntries();
00222   }
00223 
00224   //==============================================================================
00225   template<typename OrdinalType, typename ScalarType>
00226   void PackedSparseMultiply<OrdinalType, ScalarType>::copyEntries() {
00227 
00228     OrdinalType i;
00229 
00230     if (allEntries_!=0) {
00231       Entry * tmp_entries = new Entry[numEntries_];
00232       for (i=0; i< numEntries_; i++) {
00233   tmp_entries[i].index = allEntries_[i].index;
00234         tmp_entries[i].value = allEntries_[i].value;
00235       }
00236       allEntries_ = tmp_entries;
00237     }
00238     return;
00239   }
00240   //==============================================================================
00241   template<typename OrdinalType, typename ScalarType>
00242   void PackedSparseMultiply<OrdinalType, ScalarType>::deleteStructureAndValues() {
00243 
00244 
00245     OrdinalType i;
00246 
00247     if (profile_!=0) {
00248       delete [] profile_;
00249       profile_ = 0;
00250     }
00251 
00252     if (allEntries_!=0) {
00253       delete [] allEntries_;
00254       allEntries_ = 0;
00255     }
00256     return;
00257   }
00258   //==============================================================================
00259   template<typename OrdinalType, typename ScalarType>
00260   PackedSparseMultiply<OrdinalType, ScalarType>::~PackedSparseMultiply(){
00261 
00262     deleteStructureAndValues();
00263 
00264   }
00265 
00266   //==============================================================================
00267   template<typename OrdinalType, typename ScalarType>
00268   int PackedSparseMultiply<OrdinalType, ScalarType>::initializeStructure(const CisMatrix<OrdinalType, ScalarType>& A,
00269                        bool willKeepStructure) {
00270 
00271 
00272     if (haveStructure_) return(-1); // Can only call this one time!
00273 
00274     matrixForStructure_ = const_cast<CisMatrix<OrdinalType, ScalarType> *> (&A);
00275     OrdinalType i, j;
00276     isRowOriented_ = A.getIsRowOriented();
00277     hasUnitDiagonal_ = A.getHasImplicitUnitDiagonal();
00278     numRows_ = A.getNumRows();
00279     numCols_ = A.getNumCols();
00280     numEntries_ = A.getNumEntries();
00281     numRC_ = numCols_;
00282     if (isRowOriented_) numRC_ = numRows_;
00283 
00284     profile_ = new OrdinalType[numRC_];
00285 
00286     OrdinalType numRCEntries;
00287     OrdinalType * indicesRC;
00288 
00289       
00290     allEntries_ = new Entry[numEntries_]; // Allocate storage for all entries at once
00291     
00292     OrdinalType offset = 0;
00293     for (i=0; i< numRC_; i++) {
00294       int ierr = A.getIndices(i, numRCEntries, indicesRC);
00295       if (ierr<0) return(ierr);
00296       profile_[i] = numRCEntries;
00297       Entry * curRC = allEntries_+offset;
00298       for (j=0; j<numRCEntries; j++) curRC[j].index = indicesRC[j];
00299       offset += numRCEntries;
00300     }
00301 
00302     costOfMatVec_ = 2.0 * ((double) numEntries_);
00303     if (hasUnitDiagonal_) costOfMatVec_ += 2.0 * ((double) numRC_);
00304     haveStructure_ = true;
00305     return(0);
00306   }
00307 
00308   //==============================================================================
00309   template<typename OrdinalType, typename ScalarType>
00310   int PackedSparseMultiply<OrdinalType, ScalarType>::initializeValues(const CisMatrix<OrdinalType, ScalarType>& A, 
00311                      bool willKeepValues, bool checkStructure) {
00312 
00313     if (!haveStructure_) return(-1); // Must have structure first!
00314 
00315     matrixForValues_ = const_cast<CisMatrix<OrdinalType, ScalarType> *> (&A);
00316     OrdinalType i, j;
00317 
00318     ScalarType * valuesRC;
00319 
00320     OrdinalType offset = 0;
00321     for (i=0; i<numRC_; i++) {
00322       int ierr = A.getValues(i, valuesRC);
00323       if (ierr<0) return(ierr);
00324       Entry * curRC = allEntries_+offset;
00325       OrdinalType numRCEntries = profile_[i];
00326       for (j=0; j<numRCEntries; j++) curRC[j].value = valuesRC[j];
00327       offset += numRCEntries;
00328     }
00329     haveValues_ = true;
00330     return(0);
00331   }
00332 
00333 
00334   //==============================================================================
00335   template<typename OrdinalType, typename ScalarType>
00336   int PackedSparseMultiply<OrdinalType, ScalarType>::apply(const MultiVector<OrdinalType, ScalarType>& x, 
00337                 MultiVector<OrdinalType, ScalarType> & y,
00338                 bool transA, bool conjA) const {
00339     if (!haveValues_) return(-1); // Can't compute without values!
00340     if (conjA) return(-2); // Unsupported at this time
00341     if (x.getNumRows()!=numCols_) return(-3); // Number of cols in A not same as number of rows in x
00342     if (y.getNumRows()!=numRows_) return(-4); // Number of rows in A not same as number of rows in x
00343     OrdinalType numVectors = x.getNumCols();
00344     if (numVectors!=y.getNumCols()) return(-5); // Not the same number of vectors in x and y
00345 
00346     OrdinalType i, j, k, curNumEntries;
00347     Entry * curEntries = allEntries_;
00348 
00349     OrdinalType * profile = profile_;
00350 
00351     ScalarType ** xpp = x.getValues();
00352     ScalarType ** ypp = y.getValues();
00353 
00354     if ((isRowOriented_ && !transA) ||
00355   (!isRowOriented_ && transA)) {
00356       ScalarType sum = 0;
00357       for(i = 0; i < numRC_; i++) {
00358   curNumEntries = *profile++;
00359   for (k=0; k<numVectors; k++) {
00360     ScalarType * xp = xpp[k];
00361     ScalarType * yp = ypp[k];
00362     if (hasUnitDiagonal_)
00363       sum = xp[i];
00364     else
00365       sum = 0.0;
00366     for(j = 0; j < curNumEntries; j++)
00367       sum += curEntries[j].value * xp[curEntries[j].index];
00368     yp[i] = sum;
00369   }
00370   curEntries += curNumEntries;
00371       }
00372     }
00373     else {
00374       
00375       for (k=0; k<numVectors; k++) {
00376   ScalarType * yp = ypp[k];
00377   if (hasUnitDiagonal_) {
00378     ScalarType * xp = xpp[k];
00379     for(i = 0; i < numRC_; i++)
00380       yp[i] = xp[i]; // Initialize y
00381   }
00382   else
00383     for(i = 0; i < numRC_; i++)
00384       yp[i] = 0.0; // Initialize y
00385       }
00386       for(i = 0; i < numRC_; i++) {
00387   curNumEntries = *profile++;
00388   for (k=0; k<numVectors; k++) {
00389     ScalarType * xp = xpp[k];
00390     ScalarType * yp = ypp[k];
00391     for(j = 0; j < curNumEntries; j++)
00392       yp[curEntries[j].index] += curEntries[j].value * xp[i];
00393   }
00394   curEntries += curNumEntries;
00395       }
00396     }
00397     updateFlops(this->costOfMatVec_ * ((double) numVectors));
00398     return(0);
00399   }
00400 
00401 } // namespace Kokkos
00402 #endif /* KOKKOS_PACKEDSPARSEMULTIPLY_H */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends