Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_DefaultSparseMultiplyKernelOps.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 //
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 //
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 //
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov)
00038 //
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef KOKKOS_DEFAULTSPARSEMULTIPLY_KERNELOPS_HPP
00043 #define KOKKOS_DEFAULTSPARSEMULTIPLY_KERNELOPS_HPP
00044 
00045 #ifndef KERNEL_PREFIX
00046 #define KERNEL_PREFIX
00047 #endif
00048 
00049 #ifdef __CUDACC__
00050 #include <Teuchos_ScalarTraitsCUDA.hpp>
00051 #else
00052 #include <Teuchos_ScalarTraits.hpp>
00053 #endif
00054 
00055 namespace Kokkos {
00056 
00057   template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar, int NO_BETA_AND_OVERWRITE>
00058   struct DefaultSparseMultiplyOp {
00059     // mat data
00060     const size_t  *ptrs;
00061     const Ordinal *inds;
00062     const Scalar  *vals;
00063     // matvec params
00064     RangeScalar        alpha, beta;
00065     size_t numRows;
00066     // mv data
00067     const DomainScalar  *x;
00068     RangeScalar         *y;
00069     size_t numRHS, xstride, ystride;
00070 
00071     inline KERNEL_PREFIX void execute(size_t row) {
00072       const Scalar  *v = vals + ptrs[row];
00073       const Ordinal *i = inds + ptrs[row],
00074                    *ie = inds + ptrs[row+1];
00075       if (NO_BETA_AND_OVERWRITE) {
00076         for (size_t j=0; j<numRHS; ++j) y[j*ystride+row] = Teuchos::ScalarTraits<RangeScalar>::zero();
00077       }
00078       else {
00079         for (size_t j=0; j<numRHS; ++j) y[j*ystride+row] *= beta;
00080       }
00081       // save the extra multiplication if possible
00082       if (alpha == Teuchos::ScalarTraits<RangeScalar>::one()) {
00083         while (i != ie)
00084         {
00085           const  Scalar val = *v++;
00086           const Ordinal ind = *i++;
00087           for (size_t j=0; j<numRHS; ++j) y[j*ystride+row] += (RangeScalar)val * (RangeScalar)x[j*xstride+ind];
00088         }
00089       }
00090       else { // alpha != one
00091         while (i != ie)
00092         {
00093           const  Scalar val = *v++;
00094           const Ordinal ind = *i++;
00095           for (size_t j=0; j<numRHS; ++j) y[j*ystride+row] += alpha * (RangeScalar)val * (RangeScalar)x[j*xstride+ind];
00096         }
00097       }
00098     }
00099   };
00100 
00101 
00102   // mfh 15 June 2012: I added the ScalarIsComplex Boolean template parameter to
00103   // avoid build errors due to attempts to cast from Scalar=int to
00104   // RangeScalar=std::complex<double>.  This additional template parameter is a
00105   // standard technique to specialize code for the real or complex arithmetic
00106   // case, when attempts to write a single code for both cases would result in
00107   // syntax errors.  The template parameter refers to RangeScalar, because we
00108   // have to decide whether the RangeScalar constructor takes one or two
00109   // arguments.
00110   //
00111   // If you wish to optimize further, you might like to add another Boolean
00112   // template parameter for whether Scalar is real or complex.  First, this
00113   // would let you avoid calling Teuchos::ScalarTraits<Scalar>::conjugate().
00114   // This should inline to nothing, but not all compilers are good at inlining.
00115   // Second, this would simplify the code for the case when both RangeScalar and
00116   // Scalar are both complex.  However, this second Boolean template parameter
00117   // is not necessary for correctness.
00118   template <class Scalar, // the type of entries in the sparse matrix
00119             class Ordinal, // the type of indices in the sparse matrix
00120             class DomainScalar, // the type of entries in the input (multi)vector
00121             class RangeScalar, // the type of entries in the output (multi)vector
00122             int NO_BETA_AND_OVERWRITE,
00123             bool RangeScalarIsComplex=Teuchos::ScalarTraits<RangeScalar>::isComplex> // whether RangeScalar is complex valued.
00124   struct DefaultSparseTransposeMultiplyOp;
00125 
00126   // Partial specialization for when RangeScalar is real (not complex) valued.
00127   template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar, int NO_BETA_AND_OVERWRITE>
00128   struct DefaultSparseTransposeMultiplyOp<Scalar, Ordinal, DomainScalar, RangeScalar, NO_BETA_AND_OVERWRITE, false> {
00129     // mat data
00130     const size_t  *ptrs;
00131     const Ordinal *inds;
00132     const Scalar  *vals;
00133     // matvec params
00134     RangeScalar        alpha, beta;
00135     size_t numRows, numCols;
00136     // mv data
00137     const DomainScalar  *x;
00138     RangeScalar         *y;
00139     size_t numRHS, xstride, ystride;
00140 
00141     inline void execute() {
00142       using Teuchos::ScalarTraits;
00143 
00144       if (NO_BETA_AND_OVERWRITE) {
00145         for (size_t j=0; j<numRHS; ++j) {
00146           RangeScalar *yp = y+j*ystride;
00147           for (size_t row=0; row<numCols; ++row) {
00148             yp[row] = ScalarTraits<RangeScalar>::zero();
00149           }
00150         }
00151       }
00152       else {
00153         for (size_t j=0; j<numRHS; ++j) {
00154           RangeScalar *yp = y+j*ystride;
00155           for (size_t row=0; row<numCols; ++row) {
00156             yp[row] *= beta;
00157           }
00158         }
00159       }
00160       // save the extra multiplication if possible
00161       if (alpha == ScalarTraits<RangeScalar>::one()) {
00162         for (size_t colAt=0; colAt < numRows; ++colAt) {
00163           const Scalar  *v  = vals + ptrs[colAt];
00164           const Ordinal *i  = inds + ptrs[colAt];
00165           const Ordinal *ie = inds + ptrs[colAt+1];
00166           // sparse outer product: AT[:,colAt] * X[ind]
00167           while (i != ie) {
00168             const  Scalar val = ScalarTraits<Scalar>::conjugate (*v++);
00169             const Ordinal ind = *i++;
00170             for (size_t j = 0; j < numRHS; ++j) {
00171               // mfh 15 June 2012: Casting Scalar to RangeScalar may produce a
00172               // build warning, e.g., if Scalar is int and RangeScalar is
00173               // double.  The way to get it to work is not to rely on casting
00174               // Scalar to RangeScalar.  Instead, rely on C++'s type promotion
00175               // rules.  For now, we just cast the input vector's value to
00176               // RangeScalar, to handle the common iterative refinement case of
00177               // an output vector with higher precision.
00178               y[j*ystride+ind] += val * RangeScalar (x[j*xstride+colAt]);
00179             }
00180           }
00181         }
00182       }
00183       else { // alpha != one
00184         for (size_t colAt=0; colAt < numRows; ++colAt) {
00185           const Scalar  *v  = vals + ptrs[colAt];
00186           const Ordinal *i  = inds + ptrs[colAt];
00187           const Ordinal *ie = inds + ptrs[colAt+1];
00188           // sparse outer product: AT[:,colAt] * X[ind
00189           while (i != ie) {
00190             const  Scalar val = ScalarTraits<Scalar>::conjugate (*v++);
00191             const Ordinal ind = *i++;
00192             for (size_t j=0; j<numRHS; ++j) {
00193               // mfh 15 June 2012: See notes above about build warnings when
00194               // casting val from Scalar to RangeScalar.
00195               y[j*ystride+ind] += alpha * val * RangeScalar (x[j*xstride+colAt]);
00196             }
00197           }
00198         }
00199       }
00200     }
00201   };
00202 
00203   // Partial specialization for when RangeScalar is complex valued.  The most
00204   // common case is that RangeScalar is a specialization of std::complex, but
00205   // this is not necessarily the case.  For example, CUDA has its own complex
00206   // arithmetic type, which is necessary because std::complex's methods are not
00207   // marked as device methods.  This code assumes that RangeScalar, being
00208   // complex valued, has a constructor which takes two arguments, each of which
00209   // can be converted to Teuchos::ScalarTraits<RangeScalar>::magnitudeType.
00210   template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar, int NO_BETA_AND_OVERWRITE>
00211   struct DefaultSparseTransposeMultiplyOp<Scalar, Ordinal, DomainScalar, RangeScalar, NO_BETA_AND_OVERWRITE, true> {
00212     // mat data
00213     const size_t  *ptrs;
00214     const Ordinal *inds;
00215     const Scalar  *vals;
00216     // matvec params
00217     RangeScalar        alpha, beta;
00218     size_t numRows, numCols;
00219     // mv data
00220     const DomainScalar  *x;
00221     RangeScalar         *y;
00222     size_t numRHS, xstride, ystride;
00223 
00224     inline void execute() {
00225       using Teuchos::ScalarTraits;
00226       typedef typename ScalarTraits<RangeScalar>::magnitudeType RSMT;
00227 
00228       if (NO_BETA_AND_OVERWRITE) {
00229         for (size_t j=0; j<numRHS; ++j) {
00230           RangeScalar *yp = y+j*ystride;
00231           for (size_t row=0; row<numCols; ++row) {
00232             yp[row] = ScalarTraits<RangeScalar>::zero();
00233           }
00234         }
00235       }
00236       else {
00237         for (size_t j=0; j<numRHS; ++j) {
00238           RangeScalar *yp = y+j*ystride;
00239           for (size_t row=0; row<numCols; ++row) {
00240             yp[row] *= beta;
00241           }
00242         }
00243       }
00244       // save the extra multiplication if possible
00245       if (alpha == ScalarTraits<RangeScalar>::one()) {
00246         for (size_t colAt=0; colAt < numRows; ++colAt) {
00247           const Scalar  *v  = vals + ptrs[colAt];
00248           const Ordinal *i  = inds + ptrs[colAt];
00249           const Ordinal *ie = inds + ptrs[colAt+1];
00250           // sparse outer product: AT[:,colAt] * X[ind]
00251           while (i != ie) {
00252             const  Scalar val = ScalarTraits<Scalar>::conjugate (*v++);
00253             const Ordinal ind = *i++;
00254             for (size_t j = 0; j < numRHS; ++j) {
00255               // mfh 15 June 2012: Casting Scalar to RangeScalar via a
00256               // static_cast won't work if Scalar is int and RangeScalar is
00257               // std::complex<double>.  (This is valid code.)  This is because
00258               // std::complex<double> doesn't have a constructor that takes one
00259               // int argument.  Furthermore, the C++ standard library doesn't
00260               // define operator*(int, std::complex<double>), so we can't rely
00261               // on C++'s type promotion rules.  However, any reasonable complex
00262               // arithmetic type should have a two-argument constructor that
00263               // takes arguments convertible to RSMT
00264               // (ScalarTraits<RangeScalar>::magnitudeType), and we should also
00265               // be able to cast from Scalar to RSMT.
00266               //
00267               // The mess with taking the real and imaginary components of val
00268               // is because Scalar could be either real or complex, but the
00269               // two-argument constructor of RangeScalar expects two real
00270               // arguments.  You can get rid of this by adding another template
00271               // parameter for whether Scalar is real or complex.
00272               y[j*ystride+ind] += RangeScalar (RSMT (ScalarTraits<Scalar>::real (val)),
00273                                                RSMT (ScalarTraits<Scalar>::imag (val))) *
00274                 RangeScalar (x[j*xstride+colAt]);
00275             }
00276           }
00277         }
00278       }
00279       else { // alpha != one
00280         for (size_t colAt=0; colAt < numRows; ++colAt) {
00281           const Scalar  *v  = vals + ptrs[colAt];
00282           const Ordinal *i  = inds + ptrs[colAt];
00283           const Ordinal *ie = inds + ptrs[colAt+1];
00284           // sparse outer product: AT[:,colAt] * X[ind]
00285           while (i != ie) {
00286             const  Scalar val = ScalarTraits<Scalar>::conjugate (*v++);
00287             const Ordinal ind = *i++;
00288             for (size_t j=0; j<numRHS; ++j) {
00289               // mfh 15 June 2012: See notes above about it sometimes being
00290               // invalid to cast val from Scalar to RangeScalar.
00291               y[j*ystride+ind] += alpha *
00292                 RangeScalar (RSMT (ScalarTraits<Scalar>::real (val)),
00293                              RSMT (ScalarTraits<Scalar>::imag (val))) *
00294                 RangeScalar (x[j*xstride+colAt]);
00295             }
00296           }
00297         }
00298       }
00299     }
00300   };
00301 
00302 } // namespace Kokkos
00303 
00304 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends