Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_ThrustGPUNode.cuh
00001 #ifndef KOKKOS_THRUSTGPUNODE_CUH_
00002 #define KOKKOS_THRUSTGPUNODE_CUH_
00003 
00004 #include <thrust/for_each.h>
00005 #include <thrust/transform_reduce.h>
00006 #include <thrust/iterator/counting_iterator.h>
00007 
00008 // must define this before including any kernels
00009 #define KERNEL_PREFIX __device__ __host__
00010 
00011 // MUST define this to prevent bringing in implementation of CUDANodeMemoryModel (and therefore, half of Teuchos)
00012 #define KOKKOS_NO_INCLUDE_INSTANTIATIONS
00013 #include <Kokkos_ThrustGPUNode.hpp>
00014 
00015 namespace Kokkos {
00016 
00017   template <class WDPin> 
00018   struct ThrustExecuteWrapper {
00019     mutable WDPin wd;
00020 
00021     inline ThrustExecuteWrapper(WDPin in) : wd(in) {}
00022 
00023     __device__ __host__ inline void operator()(int i) const {
00024       wd.execute(i);
00025     }
00026   };
00027 
00028   template <class WDPin> 
00029   struct ThrustReduceWrapper {
00030     mutable WDPin wd;
00031     inline ThrustReduceWrapper (WDPin in) : wd(in) {}
00032 
00033     __device__ __host__ inline 
00034     typename WDPin::ReductionType 
00035     operator()(typename WDPin::ReductionType x, typename WDPin::ReductionType y) {
00036       return wd.reduce(x,y);
00037     }
00038   };
00039 
00040   template <class WDPin>
00041   struct ThrustGenerateWrapper {
00042     mutable WDPin wd;
00043     inline ThrustGenerateWrapper (WDPin in) : wd(in) {}
00044  
00045     __device__ __host__ inline 
00046     typename WDPin::ReductionType
00047     operator()(int i) {
00048       return wd.generate(i);
00049     }
00050   };
00051 
00052   template <class WDP>
00053   void ThrustGPUNode::parallel_for(int begin, int end, WDP wd) {
00054 #ifdef HAVE_KOKKOS_DEBUG
00055     cudaError_t err = cudaGetLastError();
00056     TEST_FOR_EXCEPTION( cudaSuccess != err, std::runtime_error, 
00057         "Kokkos::ThrustGPUNode::" << __FUNCTION__ << ": " 
00058         << "cudaGetLastError() returned error before function call:\n"
00059         << cudaGetErrorString(err) );
00060 #endif
00061     // wrap in Thrust and hand to thrust::for_each
00062     ThrustExecuteWrapper<WDP> body(wd);  
00063     thrust::counting_iterator<int,thrust::device_space_tag> bit(begin),
00064                                                             eit(end);
00065     thrust::for_each( bit, eit, body );
00066 #ifdef HAVE_KOKKOS_DEBUG
00067     err = cudaThreadSynchronize();
00068     TEST_FOR_EXCEPTION( cudaSuccess != err, std::runtime_error, 
00069         "Kokkos::ThrustGPUNode::" << __FUNCTION__ << ": " 
00070         << "cudaThreadSynchronize() returned error after function call:\n"
00071         << cudaGetErrorString(err) );
00072 #endif
00073   };
00074 
00075   template <class WDP>
00076   typename WDP::ReductionType
00077   ThrustGPUNode::parallel_reduce(int begin, int end, WDP wd) 
00078   {
00079 #ifdef HAVE_KOKKOS_DEBUG
00080     cudaError_t err = cudaGetLastError();
00081     TEST_FOR_EXCEPTION( cudaSuccess != err, std::runtime_error, 
00082         "Kokkos::ThrustGPUNode::" << __FUNCTION__ << ": " 
00083         << "cudaGetLastError() returned error before function call:\n"
00084         << cudaGetErrorString(err) );
00085 #endif
00086     // wrap in Thrust and hand to thrust::transform_reduce
00087     thrust::counting_iterator<int,thrust::device_space_tag> bit(begin),
00088                                                             eit(end);
00089     ThrustReduceWrapper<WDP> ROp(wd);
00090     ThrustGenerateWrapper<WDP> TOp(wd);
00091     typename WDP::ReductionType init = wd.identity(), ret;
00092     ret = thrust::transform_reduce( bit, eit, TOp, init, ROp );
00093 #ifdef HAVE_KOKKOS_DEBUG
00094     err = cudaThreadSynchronize();
00095     TEST_FOR_EXCEPTION( cudaSuccess != err, std::runtime_error, 
00096         "Kokkos::ThrustGPUNode::" << __FUNCTION__ << ": " 
00097         << "cudaThreadSynchronize() returned error after function call:\n"
00098         << cudaGetErrorString(err) );
00099 #endif
00100     return ret;
00101   }
00102 
00103 }
00104 
00105 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends