Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_DefaultSparseSolveKernelOps.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef KOKKOS_DEFAULTSPARSESOLVE_KERNELOPS_HPP
00043 #define KOKKOS_DEFAULTSPARSESOLVE_KERNELOPS_HPP
00044 
00045 #ifndef KERNEL_PREFIX
00046 #define KERNEL_PREFIX
00047 #endif
00048 
00049 #ifdef __CUDACC__
00050 #include <Teuchos_ScalarTraitsCUDA.hpp>
00051 #else
00052 #include <Teuchos_ScalarTraits.hpp>
00053 #endif
00054 
00055 
00056 namespace Kokkos {
00057 
  //
  // Matrix format and solve options
  // (these apply to both solve operations below, in each of their four
  //  upper/lower x unit/non-unit cases)
  //
  // unitDiag == false: each row stores its diagonal entry and the solve divides by it.
  // unitDiag == true:  the diagonal is implicitly one and is not stored; every stored
  //                    entry of a row is treated as off-diagonal and no division is done.
  // upper (versus lower) determines the order of the solve and the location of the diagonal:
  //
  //   upper -> diagonal is the first entry in the row
  //   lower -> diagonal is the last entry in the row
  //
00069 
00070   template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar>
00071   struct DefaultSparseSolveOp {
00072     // mat data
00073     const size_t  *ptrs;
00074     const Ordinal *inds;
00075     const Scalar  *vals;
00076     size_t numRows;
00077     // matvec params
00078     bool unitDiag, upper;
00079     // mv data
00080     DomainScalar  *x;
00081     const RangeScalar *y;
00082     size_t numRHS, xstride, ystride;
00083 
00084     inline KERNEL_PREFIX void execute() {
00085       // solve for X in A * X = Y
00086       // 
00087       // upper triangular requires backwards substition, solving in reverse order
00088       // must unroll the last iteration, because decrement results in wrap-around
00089       // 
00090       if (upper && unitDiag) {
00091         // upper + unit
00092         for (size_t j=0; j<numRHS; ++j) x[j*xstride+numRows-1] = (DomainScalar)y[j*ystride+numRows-1];
00093         for (size_t r=2; r < numRows+1; ++r) {
00094           const size_t row = numRows - r; // for row=numRows-2 to 0 step -1
00095           const Ordinal *i = inds+ptrs[row],
00096                        *ie = inds+ptrs[row+1];
00097           const Scalar  *v = vals+ptrs[row];
00098           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] = (DomainScalar)y[j*ystride+row];
00099           while (i != ie) {
00100             const Ordinal ind = *i++;
00101             const Scalar  val = *v++;
00102             for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] -= (DomainScalar)val * x[j*xstride+ind];
00103           }
00104         }
00105       }
00106       else if (upper && !unitDiag) {
00107         // upper + non-unit
00108         DomainScalar diag = vals[ptrs[numRows-1]];
00109         for (size_t j=0; j<numRHS; ++j) x[j*xstride+numRows-1] = (DomainScalar)( y[j*ystride+numRows-1] / diag );
00110         for (size_t r=2; r < numRows+1; ++r) {
00111           const size_t row = numRows - r; // for row=numRows-2 to 0 step -1
00112           const Ordinal *i = inds+ptrs[row],
00113                        *ie = inds+ptrs[row+1];
00114           const Scalar  *v = vals+ptrs[row];
00115           // capture and skip the diag
00116           ++i;
00117           diag = *v++;
00118           //
00119           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] = (DomainScalar) y[j*ystride+row];
00120           while (i != ie) {
00121             const Ordinal ind = *i++;
00122             const Scalar  val = *v++;
00123             for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] -= (DomainScalar) val * x[j*xstride+ind];
00124           }
00125           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] /= diag;
00126         }
00127       }
00128       else if (!upper && unitDiag) {
00129         // lower + unit
00130         for (size_t j=0; j<numRHS; ++j) x[j*xstride] = (DomainScalar) y[j*ystride];
00131         for (size_t row=1; row < numRows; ++row) {
00132           const Ordinal *i = inds+ptrs[row],
00133                        *ie = inds+ptrs[row+1];
00134           const Scalar  *v = vals+ptrs[row];
00135           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] = (DomainScalar) y[j*ystride+row];
00136           while (i != ie) {
00137             const Ordinal ind = *i++;
00138             const Scalar  val = *v++;
00139             for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] -= (DomainScalar) val * x[j*xstride+ind];
00140           }
00141         }
00142       }
00143       else if (!upper && !unitDiag) {
00144         // lower + non-unit
00145         DomainScalar diag = vals[0];
00146         for (size_t j=0; j<numRHS; ++j) x[j*xstride] = (DomainScalar)( y[j*ystride] / (RangeScalar)diag );
00147         for (size_t row=1; row < numRows; ++row) {
00148           const Ordinal *i = inds+ptrs[row],
00149                        *ie = inds+ptrs[row+1];
00150           const Scalar  *v = vals+ptrs[row];
00151           // skip the diag
00152           --ie;
00153           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] = (DomainScalar) y[j*ystride+row];
00154           while (i != ie) {
00155             const Ordinal ind = *i++;
00156             const Scalar  val = *v++;
00157             for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] -= (DomainScalar) val * x[j*xstride+ind];
00158           }
00159           diag = *v;
00160           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] /= diag;
00161         }
00162       }
00163     }
00164   };
00165 
00166   template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar>
00167   struct DefaultSparseTransposeSolveOp {
00168     // mat data
00169     const size_t  *ptrs;
00170     const Ordinal *inds;
00171     const Scalar  *vals;
00172     size_t numRows;
00173     // matvec params
00174     bool unitDiag, upper;
00175     // mv data
00176     DomainScalar  *x;
00177     const RangeScalar *y;
00178     size_t numRHS, xstride, ystride;
00179 
00180     inline KERNEL_PREFIX void execute() {
00181       // solve for X in A^H * X = Y
00182       // 
00183       // put y into x and solve system in-situ
00184       // this is copy-safe, in the scenario that x and y point to the same location.
00185       //
00186       for (size_t rhs=0; rhs < numRHS; ++rhs) {
00187         for (size_t row=0; row < numRows; ++row) {
00188           x[rhs*xstride+row] = y[rhs*xstride+row];
00189         }
00190       }
00191       // 
00192       if (upper && unitDiag) {
00193         // upper + unit
00194         for (size_t row=0; row < numRows-1; ++row) {
00195           const Ordinal *i = inds+ptrs[row],
00196                        *ie = inds+ptrs[row+1];
00197           const Scalar  *v = vals+ptrs[row];
00198           while (i != ie) {
00199             const Ordinal ind = *i++;
00200             const Scalar  val = *v++;
00201             for (size_t j=0; j<numRHS; ++j) x[j*xstride+ind] -= (DomainScalar)val * x[j*xstride+row];
00202           }
00203         }
00204       }
00205       else if (upper && !unitDiag) {
00206         // upper + non-unit; diag is first element in row
00207         DomainScalar diag;
00208         for (size_t row=0; row < numRows-1; ++row) {
00209           const Ordinal *i = inds+ptrs[row],
00210                        *ie = inds+ptrs[row+1];
00211           const Scalar  *v = vals+ptrs[row];
00212           // capture and skip the diag
00213           ++i;
00214           diag = *v++;
00215           //
00216           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] /= diag;
00217           while (i != ie) {
00218             const Ordinal ind = *i++;
00219             const Scalar  val = *v++;
00220             for (size_t j=0; j<numRHS; ++j) x[j*xstride+ind] -= (DomainScalar)val * x[j*xstride+row];
00221           }
00222         }
00223         diag = vals[ptrs[numRows-1]];
00224         for (size_t j=0; j<numRHS; ++j) x[j*xstride+numRows-1] /= diag;
00225       }
00226       else if (!upper && unitDiag) {
00227         // lower + unit
00228         for (size_t row=numRows-1; row > 0; --row) {
00229           const Ordinal *i = inds+ptrs[row],
00230                        *ie = inds+ptrs[row+1];
00231           const Scalar  *v = vals+ptrs[row];
00232           while (i != ie) {
00233             const Ordinal ind = *i++;
00234             const Scalar  val = *v++;
00235             for (size_t j=0; j<numRHS; ++j) x[j*xstride+ind] -= (DomainScalar)val * x[j*xstride+row];
00236           }
00237         }
00238       }
00239       else if (!upper && !unitDiag) {
00240         // lower + non-unit; diag is last element in row
00241         DomainScalar diag;
00242         for (size_t row=numRows-1; row > 0; --row) {
00243           const Ordinal *i = inds+ptrs[row],
00244                        *ie = inds+ptrs[row+1];
00245           const Scalar  *v = vals+ptrs[row];
00246           // capture and skip the diag
00247           diag = v[ie-i-1];
00248           --ie;
00249           //
00250           for (size_t j=0; j<numRHS; ++j) x[j*xstride+row] /= diag;
00251           while (i != ie) {
00252             const Ordinal ind = *i++;
00253             const Scalar  val = *v++;
00254             for (size_t j=0; j<numRHS; ++j) x[j*xstride+ind] -= (DomainScalar)val * x[j*xstride+row];
00255           }
00256         }
00257         // last row
00258         diag = (DomainScalar)vals[0];
00259         for (size_t j=0; j<numRHS; ++j) x[j*xstride] /= diag;
00260       }
00261     }
00262   };
00263 
00264 } // namespace Kokkos
00265 
00266 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends