Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_DefaultSparseSolveKernelOps.hpp
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2009) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ************************************************************************
00027 //@HEADER
00028 
00029 #ifndef KOKKOS_DEFAULTSPARSESOLVE_KERNELOPS_HPP
00030 #define KOKKOS_DEFAULTSPARSESOLVE_KERNELOPS_HPP
00031 
00032 #ifndef KERNEL_PREFIX
00033 #define KERNEL_PREFIX
00034 #endif
00035 
00036 #ifdef __CUDACC__
00037 #include <Teuchos_ScalarTraitsCUDA.hpp>
00038 #else
00039 #include <Teuchos_ScalarTraits.hpp>
00040 #endif
00041 
00042 
00043 namespace Kokkos {
00044 
  //
  // Matrix formatting and solve options
  // (these apply to all four operations below)
  //
  // unitDiag == false: every row stores its diagonal entry; it is excluded from the
  //   off-diagonal sum and the row result is divided by it.
  // unitDiag == true:  every stored entry is treated as off-diagonal and the diagonal
  //   is implicitly one, so no division is performed.
  // upper (versus lower) determines the order of the solve and, when unitDiag == false,
  // the position of the stored diagonal within each row:
  //
  //   upper -> diagonal is the first entry on the row
  //   lower -> diagonal is the last entry on the row
  //
00056 
/// Sparse triangular solve A*x = y, one right-hand side per execute() call.
/// A is stored in compressed-row form: row r owns entries [begs[r], ends[r])
/// of inds (column indices) and vals (coefficients).
template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar>
struct DefaultSparseSolveOp1 {
  // mat data
  const size_t  *begs;
  const size_t  *ends;
  const Ordinal *inds;
  const Scalar  *vals;
  size_t numRows;
  // solve params: unitDiag -> every stored entry is off-diagonal, diagonal is 1
  //               upper    -> upper (true) or lower (false) triangular
  bool unitDiag, upper;
  // mv data: right-hand side i starts at y + i*ystride, its solution at x + i*xstride
  DomainScalar  *x;
  const RangeScalar *y;
  size_t xstride, ystride;

  /// Solve for right-hand side column i.
  ///
  /// Upper-triangular systems use backward substitution (row numRows-1 down to 0).
  /// Lower-triangular systems use forward substitution (row 0 up to numRows-1).
  /// If unitDiag is false, each row must store its diagonal: as the first entry
  /// of the row when upper, or as the last entry when lower.
  /// Both loops are empty when numRows == 0, so an empty matrix is a no-op.
  inline KERNEL_PREFIX void execute(size_t i) {
    DomainScalar      *xj = x + i * xstride;
    const RangeScalar *yj = y + i * ystride;
    if (upper) {
      // "row-- > 0" visits numRows-1 .. 0 without unsigned wrap-around, so no
      // iteration has to be peeled off (and numRows == 0 is handled for free).
      for (size_t row = numRows; row-- > 0; ) {
        const size_t begin = begs[row];
        const size_t end   = ends[row];
        // non-unit: the diagonal is the first entry; keep it out of the sum
        const size_t first = unitDiag ? begin : begin + 1;
        DomainScalar sum = static_cast<DomainScalar>(yj[row]);
        for (size_t c = first; c < end; ++c) {
          sum -= static_cast<DomainScalar>(vals[c]) * xj[inds[c]];
        }
        xj[row] = unitDiag ? sum : sum / static_cast<DomainScalar>(vals[begin]);
      }
    }
    else {
      for (size_t row = 0; row < numRows; ++row) {
        const size_t begin = begs[row];
        // non-unit: the diagonal is the last entry; keep it out of the sum.
        // Row 0 uses this same lookup rather than assuming it starts at vals[0].
        const size_t last  = unitDiag ? ends[row] : ends[row] - 1;
        DomainScalar sum = static_cast<DomainScalar>(yj[row]);
        for (size_t c = begin; c < last; ++c) {
          sum -= static_cast<DomainScalar>(vals[c]) * xj[inds[c]];
        }
        xj[row] = unitDiag ? sum : sum / static_cast<DomainScalar>(vals[last]);
      }
    }
  }
};
00133 
00134 
/// Sparse triangular solve A*x = y, one right-hand side per execute() call.
/// A is stored row by row: row r has numEntries[r] entries, with column
/// indices at inds_beg[r] and coefficients at vals_beg[r].
template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar>
struct DefaultSparseSolveOp2 {
  // mat data
  const Ordinal * const * inds_beg;
  const Scalar  * const * vals_beg;
  const size_t  *         numEntries;
  size_t numRows;
  // solve params: unitDiag -> every stored entry is off-diagonal, diagonal is 1
  //               upper    -> upper (true) or lower (false) triangular
  bool unitDiag, upper;
  // mv data: right-hand side i starts at y + i*ystride, its solution at x + i*xstride
  DomainScalar      *x;
  const RangeScalar *y;
  size_t xstride, ystride;

  /// Solve for right-hand side column i.
  ///
  /// Upper-triangular systems use backward substitution (row numRows-1 down to 0).
  /// Lower-triangular systems use forward substitution (row 0 up to numRows-1).
  /// If unitDiag is false, each row must store its diagonal: as entry 0 of the
  /// row when upper, or as entry numEntries[r]-1 when lower.
  /// Both loops are empty when numRows == 0, so an empty matrix is a no-op.
  inline KERNEL_PREFIX void execute(size_t i) {
    DomainScalar      *xj = x + i * xstride;
    const RangeScalar *yj = y + i * ystride;
    if (upper) {
      // "row-- > 0" visits numRows-1 .. 0 without unsigned wrap-around. Starting at
      // numRows-2 (as a peeled loop would) wraps to SIZE_MAX when numRows == 1.
      for (size_t row = numRows; row-- > 0; ) {
        const Scalar  *rowvals = vals_beg[row];
        const Ordinal *rowinds = inds_beg[row];
        const size_t   nE      = numEntries[row];
        // non-unit: the diagonal is entry 0; keep it out of the sum
        const size_t first = unitDiag ? 0 : 1;
        DomainScalar sum = static_cast<DomainScalar>(yj[row]);
        for (size_t j = first; j < nE; ++j) {
          sum -= static_cast<DomainScalar>(rowvals[j]) * xj[rowinds[j]];
        }
        xj[row] = unitDiag ? sum : sum / static_cast<DomainScalar>(rowvals[0]);
      }
    }
    else {
      // Every row, including row 0, is divided by its own diagonal when non-unit.
      for (size_t row = 0; row < numRows; ++row) {
        const Scalar  *rowvals = vals_beg[row];
        const Ordinal *rowinds = inds_beg[row];
        const size_t   nE      = numEntries[row];
        // non-unit: the diagonal is the last entry; keep it out of the sum
        const size_t last = unitDiag ? nE : nE - 1;
        DomainScalar sum = static_cast<DomainScalar>(yj[row]);
        for (size_t j = 0; j < last; ++j) {
          sum -= static_cast<DomainScalar>(rowvals[j]) * xj[rowinds[j]];
        }
        xj[row] = unitDiag ? sum : sum / static_cast<DomainScalar>(rowvals[last]);
      }
    }
  }
};
00242 
00243 
/// Sparse transposed triangular solve A^T*x = y, one right-hand side per
/// execute() call. A is stored in compressed-row form: row r owns entries
/// [begs[r], ends[r]) of inds (column indices) and vals (coefficients).
template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar>
struct DefaultSparseTransposeSolveOp1 {
  // mat data
  const size_t  *begs;
  const size_t  *ends;
  const Ordinal *inds;
  const Scalar  *vals;
  size_t numRows;
  // solve params: unitDiag -> every stored entry is off-diagonal, diagonal is 1
  //               upper    -> A is upper (true) or lower (false) triangular
  bool unitDiag, upper;
  // mv data: right-hand side i starts at y + i*ystride, its solution at x + i*xstride
  DomainScalar  *x;
  const RangeScalar *y;
  size_t xstride, ystride;

  /// Solve A^T * x = y for right-hand side column i.
  ///
  /// y is first copied into x, and the system is then solved in place by
  /// scattering each finished x[row] into the unknowns of the columns row r
  /// couples to. Upper A is processed in increasing row order, lower A in
  /// decreasing row order.
  /// If unitDiag is false, each row must store its diagonal: as the first entry
  /// of the row when upper, or as the last entry when lower.
  /// All loops are empty when numRows == 0, so an empty matrix is a no-op.
  inline KERNEL_PREFIX void execute(size_t i) {
    DomainScalar      *xj = x + i * xstride;
    const RangeScalar *yj = y + i * ystride;
    //
    // put y into x and solve the system in place;
    // this is copy-safe when x and y point to the same location.
    //
    for (size_t row = 0; row < numRows; ++row) {
      xj[row] = static_cast<DomainScalar>(yj[row]);
    }
    if (upper) {
      for (size_t row = 0; row < numRows; ++row) {
        const size_t beg = begs[row];
        const size_t end = ends[row];
        size_t first = beg;
        if (!unitDiag) {
          // non-unit upper: the diagonal is the first entry of the row
          xj[row] /= static_cast<DomainScalar>(vals[beg]);
          first = beg + 1;
        }
        // xj[row] is final here; hoist it out of the scatter loop
        const DomainScalar xrow = xj[row];
        for (size_t k = first; k < end; ++k) {
          xj[inds[k]] -= static_cast<DomainScalar>(vals[k]) * xrow;
        }
      }
    }
    else {
      // "row-- > 0" visits numRows-1 .. 0 without unsigned wrap-around
      for (size_t row = numRows; row-- > 0; ) {
        const size_t beg = begs[row];
        size_t last = ends[row];
        if (!unitDiag) {
          // non-unit lower: the diagonal is the last entry of the row
          // (row 0 included; it need not live at vals[0])
          last -= 1;
          xj[row] /= static_cast<DomainScalar>(vals[last]);
        }
        // xj[row] is final here; hoist it out of the scatter loop
        const DomainScalar xrow = xj[row];
        for (size_t k = beg; k < last; ++k) {
          xj[inds[k]] -= static_cast<DomainScalar>(vals[k]) * xrow;
        }
      }
    }
  }
};
00326 
00327 
/// Sparse transposed triangular solve A^T*x = y, one right-hand side per
/// execute() call. A is stored row by row: row r has numEntries[r] entries,
/// with column indices at inds_beg[r] and coefficients at vals_beg[r].
template <class Scalar, class Ordinal, class DomainScalar, class RangeScalar>
struct DefaultSparseTransposeSolveOp2 {
  // mat data
  const Ordinal * const * inds_beg;
  const Scalar  * const * vals_beg;
  const size_t  *         numEntries;
  size_t numRows;
  // solve params: unitDiag -> every stored entry is off-diagonal, diagonal is 1
  //               upper    -> A is upper (true) or lower (false) triangular
  bool unitDiag, upper;
  // mv data: right-hand side i starts at y + i*ystride, its solution at x + i*xstride
  DomainScalar      *x;
  const RangeScalar *y;
  size_t xstride, ystride;

  /// Solve A^T * x = y for right-hand side column i.
  ///
  /// y is first copied into x, and the system is then solved in place by
  /// scattering each finished x[row] into the unknowns of the columns row r
  /// couples to. Upper A is processed in increasing row order, lower A in
  /// decreasing row order.
  /// If unitDiag is false, each row must store its diagonal: as entry 0 when
  /// upper, or as entry numEntries[r]-1 when lower.
  /// All loops are empty when numRows == 0, so an empty matrix is a no-op.
  inline KERNEL_PREFIX void execute(size_t i) {
    DomainScalar      *xj = x + i * xstride;
    const RangeScalar *yj = y + i * ystride;
    //
    // put y into x and solve the system in place;
    // this is copy-safe when x and y point to the same location.
    //
    for (size_t row = 0; row < numRows; ++row) {
      xj[row] = static_cast<DomainScalar>(yj[row]);
    }
    if (upper) {
      for (size_t row = 0; row < numRows; ++row) {
        const Scalar  *rowvals = vals_beg[row];
        const Ordinal *rowinds = inds_beg[row];
        const size_t   nE      = numEntries[row];
        size_t first = 0;
        if (!unitDiag) {
          // non-unit upper: the diagonal is entry 0 of the row
          xj[row] /= static_cast<DomainScalar>(rowvals[0]);
          first = 1;
        }
        // xj[row] is final here; hoist it out of the scatter loop
        const DomainScalar xrow = xj[row];
        for (size_t j = first; j < nE; ++j) {
          xj[rowinds[j]] -= static_cast<DomainScalar>(rowvals[j]) * xrow;
        }
      }
    }
    else {
      // "row-- > 0" visits numRows-1 .. 0 without unsigned wrap-around
      for (size_t row = numRows; row-- > 0; ) {
        const Scalar  *rowvals = vals_beg[row];
        const Ordinal *rowinds = inds_beg[row];
        const size_t   nE      = numEntries[row];
        size_t last = nE;
        if (!unitDiag) {
          // non-unit lower: the diagonal is the last entry of the row
          last = nE - 1;
          xj[row] /= static_cast<DomainScalar>(rowvals[last]);
        }
        // xj[row] is final here; hoist it out of the scatter loop
        const DomainScalar xrow = xj[row];
        for (size_t j = 0; j < last; ++j) {
          xj[rowinds[j]] -= static_cast<DomainScalar>(rowvals[j]) * xrow;
        }
      }
    }
  }
};
00415 
00416 } // namespace Kokkos
00417 
00418 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends