Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_PackedSparseMultiply.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef KOKKOS_PACKEDSPARSEMULTIPLY_H
00043 #define KOKKOS_PACKEDSPARSEMULTIPLY_H
00044 
00045 #include "Kokkos_ConfigDefs.hpp"
00046 #include "Kokkos_CisMatrix.hpp" 
00047 #include "Kokkos_SparseOperation.hpp" 
00048 
00049 
00050 namespace Kokkos {
00051 
00053 
00082   template<typename OrdinalType, typename ScalarType>
00083   class PackedSparseMultiply: public virtual SparseOperation<OrdinalType, ScalarType> {
00084   public:
00085 
00087 
00089 
00091     PackedSparseMultiply();
00092   
00094     PackedSparseMultiply(const PackedSparseMultiply& source);
00095   
00097     virtual ~PackedSparseMultiply();
00099 
00100 
00102  
00104 
00112     virtual int initializeStructure(const CisMatrix<OrdinalType, ScalarType>& A, bool willKeepStructure = false);
00113  
00115 
00126     virtual int initializeValues(const CisMatrix<OrdinalType, ScalarType>& A, bool willKeepValues = false,
00127          bool checkStructure = false);
00128  
00130 
00132 
00134   
00136 
00144     virtual int apply(const MultiVector<OrdinalType, ScalarType>& x, MultiVector<OrdinalType, ScalarType>& y, 
00145           bool transA = false, bool conjA = false) const;
00147   
00149 
00151 
00153 
00155     virtual bool getCanUseStructure() const {return(false);};
00156 
00158 
00160     virtual bool getCanUseValues() const {return(false);};
00161 
00163     virtual const CisMatrix<OrdinalType, ScalarType> & getMatrix() const {
00164       if (matrixForValues_==0) return(*matrixForStructure_);
00165       else return(*matrixForValues_);
00166     };
00167     
00169   
00170   protected:
00171 
00172     void copyEntries();
00173     void deleteStructureAndValues();
00174 
00175     struct EntryStruct {
00176       OrdinalType index;
00177       ScalarType value;
00178     }; 
00179     typedef struct EntryStruct Entry;
00180     
00181     CisMatrix<OrdinalType, ScalarType> * matrixForStructure_;
00182     CisMatrix<OrdinalType, ScalarType> * matrixForValues_;
00183 
00184     bool isRowOriented_;
00185     bool haveStructure_;
00186     bool haveValues_;
00187     bool hasUnitDiagonal_;
00188   
00189     OrdinalType numRows_;
00190     OrdinalType numCols_;
00191     OrdinalType numRC_;
00192     OrdinalType numEntries_;
00193 
00194     OrdinalType * profile_;
00195     double costOfMatVec_;
00196     Entry * allEntries_;
00197   };
00198 
00199   //==============================================================================
00200   template<typename OrdinalType, typename ScalarType>
00201   PackedSparseMultiply<OrdinalType, ScalarType>::PackedSparseMultiply() 
00202     : matrixForStructure_(0),
00203       matrixForValues_(0),
00204       isRowOriented_(true),
00205       haveStructure_(false),
00206       haveValues_(false),
00207       hasUnitDiagonal_(false),
00208       numRows_(0),
00209       numCols_(0),
00210       numRC_(0),
00211       numEntries_(0),
00212       profile_(0),
00213       costOfMatVec_(0.0),
00214       allEntries_(0) {
00215   }
00216 
00217   //==============================================================================
00218   template<typename OrdinalType, typename ScalarType>
00219   PackedSparseMultiply<OrdinalType, ScalarType>::PackedSparseMultiply(const PackedSparseMultiply<OrdinalType, ScalarType> &source) 
00220     : matrixForStructure_(source.matrixForStructure_),
00221       matrixForValues_(source.matrixForValues_),
00222       isRowOriented_(source.isRowOriented_),
00223       haveStructure_(source.haveStructure_),
00224       haveValues_(source.haveValues_),
00225       hasUnitDiagonal_(source.hasUnitDiagonal_),
00226       numRows_(source.numRows_),
00227       numCols_(source.numCols_),
00228       numRC_(source.numRC_),
00229       numEntries_(source.numEntries_),
00230       profile_(source.profile_),
00231       costOfMatVec_(source.costOfMatVec_),
00232       allEntries_(source.allEntries_) {
00233 
00234     copyEntries();
00235   }
00236 
00237   //==============================================================================
00238   template<typename OrdinalType, typename ScalarType>
00239   void PackedSparseMultiply<OrdinalType, ScalarType>::copyEntries() {
00240 
00241     OrdinalType i;
00242 
00243     if (allEntries_!=0) {
00244       Entry * tmp_entries = new Entry[numEntries_];
00245       for (i=0; i< numEntries_; i++) {
00246   tmp_entries[i].index = allEntries_[i].index;
00247         tmp_entries[i].value = allEntries_[i].value;
00248       }
00249       allEntries_ = tmp_entries;
00250     }
00251     return;
00252   }
00253   //==============================================================================
00254   template<typename OrdinalType, typename ScalarType>
00255   void PackedSparseMultiply<OrdinalType, ScalarType>::deleteStructureAndValues() {
00256 
00257 
00258     OrdinalType i;
00259 
00260     if (profile_!=0) {
00261       delete [] profile_;
00262       profile_ = 0;
00263     }
00264 
00265     if (allEntries_!=0) {
00266       delete [] allEntries_;
00267       allEntries_ = 0;
00268     }
00269     return;
00270   }
00271   //==============================================================================
00272   template<typename OrdinalType, typename ScalarType>
00273   PackedSparseMultiply<OrdinalType, ScalarType>::~PackedSparseMultiply(){
00274 
00275     deleteStructureAndValues();
00276 
00277   }
00278 
00279   //==============================================================================
00280   template<typename OrdinalType, typename ScalarType>
00281   int PackedSparseMultiply<OrdinalType, ScalarType>::initializeStructure(const CisMatrix<OrdinalType, ScalarType>& A,
00282                        bool willKeepStructure) {
00283 
00284 
00285     if (haveStructure_) return(-1); // Can only call this one time!
00286 
00287     matrixForStructure_ = const_cast<CisMatrix<OrdinalType, ScalarType> *> (&A);
00288     OrdinalType i, j;
00289     isRowOriented_ = A.getIsRowOriented();
00290     hasUnitDiagonal_ = A.getHasImplicitUnitDiagonal();
00291     numRows_ = A.getNumRows();
00292     numCols_ = A.getNumCols();
00293     numEntries_ = A.getNumEntries();
00294     numRC_ = numCols_;
00295     if (isRowOriented_) numRC_ = numRows_;
00296 
00297     profile_ = new OrdinalType[numRC_];
00298 
00299     OrdinalType numRCEntries;
00300     OrdinalType * indicesRC;
00301 
00302       
00303     allEntries_ = new Entry[numEntries_]; // Allocate storage for all entries at once
00304     
00305     OrdinalType offset = 0;
00306     for (i=0; i< numRC_; i++) {
00307       int ierr = A.getIndices(i, numRCEntries, indicesRC);
00308       if (ierr<0) return(ierr);
00309       profile_[i] = numRCEntries;
00310       Entry * curRC = allEntries_+offset;
00311       for (j=0; j<numRCEntries; j++) curRC[j].index = indicesRC[j];
00312       offset += numRCEntries;
00313     }
00314 
00315     costOfMatVec_ = 2.0 * ((double) numEntries_);
00316     if (hasUnitDiagonal_) costOfMatVec_ += 2.0 * ((double) numRC_);
00317     haveStructure_ = true;
00318     return(0);
00319   }
00320 
00321   //==============================================================================
00322   template<typename OrdinalType, typename ScalarType>
00323   int PackedSparseMultiply<OrdinalType, ScalarType>::initializeValues(const CisMatrix<OrdinalType, ScalarType>& A, 
00324                      bool willKeepValues, bool checkStructure) {
00325 
00326     if (!haveStructure_) return(-1); // Must have structure first!
00327 
00328     matrixForValues_ = const_cast<CisMatrix<OrdinalType, ScalarType> *> (&A);
00329     OrdinalType i, j;
00330 
00331     ScalarType * valuesRC;
00332 
00333     OrdinalType offset = 0;
00334     for (i=0; i<numRC_; i++) {
00335       int ierr = A.getValues(i, valuesRC);
00336       if (ierr<0) return(ierr);
00337       Entry * curRC = allEntries_+offset;
00338       OrdinalType numRCEntries = profile_[i];
00339       for (j=0; j<numRCEntries; j++) curRC[j].value = valuesRC[j];
00340       offset += numRCEntries;
00341     }
00342     haveValues_ = true;
00343     return(0);
00344   }
00345 
00346 
00347   //==============================================================================
00348   template<typename OrdinalType, typename ScalarType>
00349   int PackedSparseMultiply<OrdinalType, ScalarType>::apply(const MultiVector<OrdinalType, ScalarType>& x, 
00350                 MultiVector<OrdinalType, ScalarType> & y,
00351                 bool transA, bool conjA) const {
00352     if (!haveValues_) return(-1); // Can't compute without values!
00353     if (conjA) return(-2); // Unsupported at this time
00354     if (x.getNumRows()!=numCols_) return(-3); // Number of cols in A not same as number of rows in x
00355     if (y.getNumRows()!=numRows_) return(-4); // Number of rows in A not same as number of rows in x
00356     OrdinalType numVectors = x.getNumCols();
00357     if (numVectors!=y.getNumCols()) return(-5); // Not the same number of vectors in x and y
00358 
00359     OrdinalType i, j, k, curNumEntries;
00360     Entry * curEntries = allEntries_;
00361 
00362     OrdinalType * profile = profile_;
00363 
00364     ScalarType ** xpp = x.getValues();
00365     ScalarType ** ypp = y.getValues();
00366 
00367     if ((isRowOriented_ && !transA) ||
00368   (!isRowOriented_ && transA)) {
00369       ScalarType sum = 0;
00370       for(i = 0; i < numRC_; i++) {
00371   curNumEntries = *profile++;
00372   for (k=0; k<numVectors; k++) {
00373     ScalarType * xp = xpp[k];
00374     ScalarType * yp = ypp[k];
00375     if (hasUnitDiagonal_)
00376       sum = xp[i];
00377     else
00378       sum = 0.0;
00379     for(j = 0; j < curNumEntries; j++)
00380       sum += curEntries[j].value * xp[curEntries[j].index];
00381     yp[i] = sum;
00382   }
00383   curEntries += curNumEntries;
00384       }
00385     }
00386     else {
00387       
00388       for (k=0; k<numVectors; k++) {
00389   ScalarType * yp = ypp[k];
00390   if (hasUnitDiagonal_) {
00391     ScalarType * xp = xpp[k];
00392     for(i = 0; i < numRC_; i++)
00393       yp[i] = xp[i]; // Initialize y
00394   }
00395   else
00396     for(i = 0; i < numRC_; i++)
00397       yp[i] = 0.0; // Initialize y
00398       }
00399       for(i = 0; i < numRC_; i++) {
00400   curNumEntries = *profile++;
00401   for (k=0; k<numVectors; k++) {
00402     ScalarType * xp = xpp[k];
00403     ScalarType * yp = ypp[k];
00404     for(j = 0; j < curNumEntries; j++)
00405       yp[curEntries[j].index] += curEntries[j].value * xp[i];
00406   }
00407   curEntries += curNumEntries;
00408       }
00409     }
00410     updateFlops(this->costOfMatVec_ * ((double) numVectors));
00411     return(0);
00412   }
00413 
00414 } // namespace Kokkos
00415 #endif /* KOKKOS_PACKEDSPARSEMULTIPLY_H */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends