Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_CuspOps.cuh
00001 //@HEADER
00002 // ************************************************************************
00003 // 
00004 //          Kokkos: Node API and Parallel Node Kernels
00005 //              Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 //@HEADER
00041 
00042 #ifndef KOKKOS_CUSPOPS_CUH
00043 #define KOKKOS_CUSPOPS_CUH
00044 
00045 #include <cusp/array1d.h>
00046 #include <cusp/csr_matrix.h>
00047 #include <cusp/multiply.h>
00048 #include <cusp/transpose.h>
00049 
00050 #include "Kokkos_CuspWrappers.hpp"
00051 
00052 // Out-of-line definitions of the cusp-based adaptors declared in Kokkos::Cuspdetails (see Kokkos_CuspWrappers.hpp)
// Sparse matrix times (multi)vector for a CSR matrix, applied once per right-hand side.
//
// The CSR arrays (rowptrs[0..numRows], colinds[0..nnz), values[0..nnz)) are wrapped in
// thrust::device_ptr views, so they are expected to reside in device memory. The matrix is
// numRows x numCols with nnz stored entries.
//
// x and y hold numRHS columns each. Column j of x starts at x + j*xstride, and column j of y
// starts at y + j*ystride. For every column, cusp::detail::device::spmv_csr_vector(A, x_j, y_j)
// is invoked. NOTE(review): that kernel presumably overwrites y_j with A*x_j, and x/y are
// presumably device pointers as well. Confirm both against the cusp version in use.
//
// Parameters:
//   numRows, numCols, nnz    - matrix dimensions and number of stored entries
//   rowptrs, colinds, values - CSR row offsets, column indices and values of A
//   numRHS                   - number of right-hand-side columns to process (0 => no-op)
//   x, xstride               - input multivector and distance between its columns
//   y, ystride               - output multivector and distance between its columns
template <class Offset, class Ordinal, class ScalarA, class ScalarX, class ScalarY>
void Kokkos::Cuspdetails::cuspCrsMultiply(
                          Ordinal numRows, Ordinal numCols, Ordinal nnz,
                          const Offset *rowptrs, const Ordinal *colinds, const ScalarA *values,
                          Ordinal numRHS, const ScalarX *x, Ordinal xstride, ScalarY *y, Ordinal ystride)
{
  typedef  thrust::device_ptr<const Offset>                                   off_ptr;
  typedef  thrust::device_ptr<const Ordinal>                                  ind_ptr;
  typedef  thrust::device_ptr<const ScalarA>                                  val_ptr;
  typedef  cusp::array1d_view< off_ptr >                                 off_arr_view;
  typedef  cusp::array1d_view< ind_ptr >                                 ind_arr_view;
  typedef  cusp::array1d_view< val_ptr >                                 val_arr_view;
  typedef  cusp::csr_matrix_view<off_arr_view, ind_arr_view, val_arr_view>   csr_view;
  // Non-owning views over the caller's CSR arrays; no data is copied.
  off_arr_view rptrs = cusp::make_array1d_view(off_ptr(rowptrs), off_ptr(rowptrs+numRows+1));
  ind_arr_view cinds = cusp::make_array1d_view(ind_ptr(colinds), ind_ptr(colinds+nnz));
  val_arr_view vals  = cusp::make_array1d_view(val_ptr(values),  val_ptr(values+nnz));
  csr_view A = cusp::make_csr_matrix_view(numRows, numCols, nnz, rptrs, cinds, vals);
  // The loop counter uses Ordinal (the type of numRHS). An int counter would cause a
  // signed/unsigned comparison when Ordinal is unsigned, and could overflow when Ordinal is
  // wider than int.
  for (Ordinal j = 0; j < numRHS; ++j) {
    cusp::detail::device::spmv_csr_vector(A, x, y);
    x += xstride;   // advance to the next input column
    y += ystride;   // advance to the next output column
  }
}
00076 
// Transposes a CSR matrix A (numRows x numCols, nnz entries) into caller-provided CSR storage
// for At (numCols x numRows, nnz entries) using cusp::transpose.
//
// Every array is wrapped in a thrust::device_ptr view, so all six arrays are expected to live
// in device memory. The output arrays must already be allocated by the caller:
//   rowptrs_t : numCols + 1 entries
//   colinds_t : nnz entries
//   values_t  : nnz entries
// The input arrays (rowptrs: numRows + 1, colinds/values: nnz) are only read.
template <class Offset, class Ordinal, class Scalar>
void Kokkos::Cuspdetails::cuspCrsTranspose(Ordinal numRows, Ordinal numCols, Ordinal nnz,
                                           const Offset *rowptrs,   const Ordinal *colinds,   const Scalar *values,
                                           Offset *rowptrs_t,       Ordinal *colinds_t,       Scalar *values_t)
{
  // --- read-only source side -------------------------------------------------
  typedef cusp::array1d_view< thrust::device_ptr<const Offset> >   src_offsets_t;
  typedef cusp::array1d_view< thrust::device_ptr<const Ordinal> >  src_indices_t;
  typedef cusp::array1d_view< thrust::device_ptr<const Scalar> >   src_values_t;
  typedef cusp::csr_matrix_view<src_offsets_t, src_indices_t, src_values_t>  src_matrix_t;

  // --- writable destination side ---------------------------------------------
  typedef cusp::array1d_view< thrust::device_ptr<Offset> >   dst_offsets_t;
  typedef cusp::array1d_view< thrust::device_ptr<Ordinal> >  dst_indices_t;
  typedef cusp::array1d_view< thrust::device_ptr<Scalar> >   dst_values_t;
  typedef cusp::csr_matrix_view<dst_offsets_t, dst_indices_t, dst_values_t>  dst_matrix_t;

  // One-past-the-end pointers for each input array.
  const Offset  *srcOffsetsEnd = rowptrs + numRows + 1;
  const Ordinal *srcIndicesEnd = colinds + nnz;
  const Scalar  *srcValuesEnd  = values  + nnz;

  src_offsets_t srcOffsets = cusp::make_array1d_view(thrust::device_ptr<const Offset>(rowptrs),
                                                     thrust::device_ptr<const Offset>(srcOffsetsEnd));
  src_indices_t srcIndices = cusp::make_array1d_view(thrust::device_ptr<const Ordinal>(colinds),
                                                     thrust::device_ptr<const Ordinal>(srcIndicesEnd));
  src_values_t  srcValues  = cusp::make_array1d_view(thrust::device_ptr<const Scalar>(values),
                                                     thrust::device_ptr<const Scalar>(srcValuesEnd));
  src_matrix_t src = cusp::make_csr_matrix_view(numRows, numCols, nnz, srcOffsets, srcIndices, srcValues);

  // One-past-the-end pointers for each output array; the transpose has numCols rows.
  Offset  *dstOffsetsEnd = rowptrs_t + numCols + 1;
  Ordinal *dstIndicesEnd = colinds_t + nnz;
  Scalar  *dstValuesEnd  = values_t  + nnz;

  dst_offsets_t dstOffsets = cusp::make_array1d_view(thrust::device_ptr<Offset>(rowptrs_t),
                                                     thrust::device_ptr<Offset>(dstOffsetsEnd));
  dst_indices_t dstIndices = cusp::make_array1d_view(thrust::device_ptr<Ordinal>(colinds_t),
                                                     thrust::device_ptr<Ordinal>(dstIndicesEnd));
  dst_values_t  dstValues  = cusp::make_array1d_view(thrust::device_ptr<Scalar>(values_t),
                                                     thrust::device_ptr<Scalar>(dstValuesEnd));
  // Dimensions are swapped for the transpose.
  dst_matrix_t dst = cusp::make_csr_matrix_view(numCols, numRows, nnz, dstOffsets, dstIndices, dstValues);

  cusp::transpose(src, dst);
}
00106 
00107 #endif // KOKKOS_CUSPOPS_CUH
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends