Kokkos Node API and Local Linear Algebra Kernels Version of the Day
Kokkos_Mkl_RawSparseKernels.cpp
00001 #include "Kokkos_Mkl_RawSparseKernels.hpp"
00002 #include "Teuchos_ConfigDefs.hpp"
00003 #include <mkl.h>
00004 
00006 // Specializations for Scalar=float
00008 
00009 namespace Kokkos {
00010   namespace Mkl {
00011 
00012     template<>
00013     void RawSparseKernels<float, MKL_INT>::
00014     csrmv (const char* const transa,
00015            const MKL_INT m, // Number of rows in A
00016            const MKL_INT k, // Number of columns in A
00017            const float& alpha,
00018            const char* const matdescra,
00019            const float* const val,
00020            const MKL_INT* const ind,
00021            const MKL_INT* const ptrBegin,
00022            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00023            const float* const x,
00024            const float& beta,
00025            float* const y)
00026     {
00027       mkl_scsrmv ((char*) transa, (MKL_INT*) &m, (MKL_INT*) &k,
00028                   (float*) &alpha, (char*) matdescra,
00029                   (float*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00030                   (float*) x, (float*) &beta, y);
00031     }
00032 
00033     template<>
00034     void RawSparseKernels<float, MKL_INT>::
00035     csrmm (const char* const transa,
00036            const MKL_INT m, // number of rows of A
00037            const MKL_INT n, // number of columns of C
00038            const MKL_INT k, // number of columns of A
00039            const float& alpha,
00040            const char* const matdescra,
00041            const float* const val,
00042            const MKL_INT* const ind,
00043            const MKL_INT* const ptrBegin,
00044            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00045            const float* const B,
00046            const MKL_INT LDB,
00047            const float& beta,
00048            float* const C,
00049            const MKL_INT LDC)
00050     {
00051       mkl_scsrmm ((char*) transa, (MKL_INT*) &m, (MKL_INT*) &n, (MKL_INT*) &k,
00052                   (float*) &alpha, (char*) matdescra,
00053                   (float*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00054                   (float*) B, (MKL_INT*) &LDB, (float*) &beta, C, (MKL_INT*) &LDC);
00055     }
00056 
00057     template<>
00058     void RawSparseKernels<float, MKL_INT>::
00059     csrsv (const char* const transa,
00060            const MKL_INT m,
00061            const float& alpha,
00062            const char* const matdescra,
00063            const float* const val,
00064            const MKL_INT* const ind,
00065            const MKL_INT* const ptrBegin,
00066            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00067            const float* const x,
00068            float* const y)
00069     {
00070       mkl_scsrsv ((char*) transa, (MKL_INT*) &m, (float*) &alpha, (char*) matdescra,
00071                   (float*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00072                   (float*) x, y);
00073     }
00074 
00075     template<>
00076     void RawSparseKernels<float, MKL_INT>::
00077     csrsm (const char* const transa,
00078            const MKL_INT m, // Number of columns in A
00079            const MKL_INT n, // Number of columns in C
00080            const float& alpha,
00081            const char* const matdescra,
00082            const float* const val,
00083            const MKL_INT* const ind,
00084            const MKL_INT* const ptrBegin,
00085            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00086            const float* const B,
00087            const MKL_INT LDB,
00088            float* const C,
00089            const MKL_INT LDC)
00090     {
00091       mkl_scsrsm ((char*) transa, (MKL_INT*) &m, (MKL_INT*) &n,
00092                   (float*) &alpha, (char*) matdescra,
00093                   (float*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00094                   (float*) B, (MKL_INT*) &LDB, C, (MKL_INT*) &LDC);
00095     }
00096 
00098     // Specializations for Scalar=double
00100 
00101     template<>
00102     void RawSparseKernels<double, MKL_INT>::
00103     csrmv (const char* const transa,
00104            const MKL_INT m, // Number of rows in A
00105            const MKL_INT k, // Number of columns in A
00106            const double& alpha,
00107            const char* const matdescra,
00108            const double* const val,
00109            const MKL_INT* const ind,
00110            const MKL_INT* const ptrBegin,
00111            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00112            const double* const x,
00113            const double& beta,
00114            double* const y)
00115     {
00116       mkl_dcsrmv ((char*) transa, (MKL_INT*) &m, (MKL_INT*) &k,
00117                   (double*) &alpha, (char*) matdescra,
00118                   (double*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00119                   (double*) x, (double*) &beta, y);
00120     }
00121 
00122     template<>
00123     void RawSparseKernels<double, MKL_INT>::
00124     csrmm (const char* const transa,
00125            const MKL_INT m, // number of rows of A
00126            const MKL_INT n, // number of columns of C
00127            const MKL_INT k, // number of columns of A
00128            const double& alpha,
00129            const char* const matdescra,
00130            const double* const val,
00131            const MKL_INT* const ind,
00132            const MKL_INT* const ptrBegin,
00133            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00134            const double* const B,
00135            const MKL_INT LDB,
00136            const double& beta,
00137            double* const C,
00138            const MKL_INT LDC)
00139     {
00140       mkl_dcsrmm ((char*) transa, (MKL_INT*) &m, (MKL_INT*) &n, (MKL_INT*) &k,
00141                   (double*) &alpha, (char*) matdescra,
00142                   (double*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00143                   (double*) B, (MKL_INT*) &LDB, (double*) &beta, C, (MKL_INT*) &LDC);
00144     }
00145 
00146     template<>
00147     void RawSparseKernels<double, MKL_INT>::
00148     csrsv (const char* const transa,
00149            const MKL_INT m,
00150            const double& alpha,
00151            const char* const matdescra,
00152            const double* const val,
00153            const MKL_INT* const ind,
00154            const MKL_INT* const ptrBegin,
00155            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00156            const double* const x,
00157            double* const y)
00158     {
00159       mkl_dcsrsv ((char*) transa, (MKL_INT*) &m, (double*) &alpha, (char*) matdescra,
00160                   (double*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00161                   (double*) x, y);
00162     }
00163 
00164     template<>
00165     void RawSparseKernels<double, MKL_INT>::
00166     csrsm (const char* const transa,
00167            const MKL_INT m, // Number of columns in A
00168            const MKL_INT n, // Number of columns in C
00169            const double& alpha,
00170            const char* const matdescra,
00171            const double* const val,
00172            const MKL_INT* const ind,
00173            const MKL_INT* const ptrBegin,
00174            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00175            const double* const B,
00176            const MKL_INT LDB,
00177            double* const C,
00178            const MKL_INT LDC)
00179     {
00180       mkl_dcsrsm ((char*) transa, (MKL_INT*) &m, (MKL_INT*) &n,
00181                   (double*) &alpha, (char*) matdescra,
00182                   (double*) val, (MKL_INT*) ind, (MKL_INT*) ptrBegin, (MKL_INT*) ptrEnd,
00183                   (double*) B, (MKL_INT*) &LDB, C, (MKL_INT*) &LDC);
00184     }
00185 
00186     // MKL always defines the complex-arithmetic routines, but only build
00187     // wrappers for them if Teuchos' complex arithmetic support is enabled.
00188 #ifdef HAVE_TEUCHOS_COMPLEX
00189 
00191     // Specializations for Scalar=std::complex<float>
00193 
00194     template<>
00195     void RawSparseKernels<std::complex<float>, MKL_INT>::
00196     csrmv (const char* const transa,
00197            const MKL_INT m, // Number of rows in A
00198            const MKL_INT k, // Number of columns in A
00199            const std::complex<float>& alpha,
00200            const char* const matdescra,
00201            const std::complex<float>* const val,
00202            const MKL_INT* const ind,
00203            const MKL_INT* const ptrBegin,
00204            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00205            const std::complex<float>* const x,
00206            const std::complex<float>& beta,
00207            std::complex<float>* const y)
00208     {
00209       mkl_ccsrmv ((char*) transa,
00210                   (MKL_INT*) &m,
00211                   (MKL_INT*) &k,
00212                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (&alpha)),
00213                   (char*) matdescra,
00214                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (val)),
00215                   (MKL_INT*) ind,
00216                   (MKL_INT*) ptrBegin,
00217                   (MKL_INT*) ptrEnd,
00218                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (x)),
00219                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (&beta)),
00220                   reinterpret_cast<MKL_Complex8*> (y));
00221     }
00222 
00223     template<>
00224     void RawSparseKernels<std::complex<float>, MKL_INT>::
00225     csrmm (const char* const transa,
00226            const MKL_INT m, // number of rows of A
00227            const MKL_INT n, // number of columns of C
00228            const MKL_INT k, // number of columns of A
00229            const std::complex<float>& alpha,
00230            const char* const matdescra,
00231            const std::complex<float>* const val,
00232            const MKL_INT* const ind,
00233            const MKL_INT* const ptrBegin,
00234            const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00235            const std::complex<float>* const B,
00236            const MKL_INT LDB,
00237            const std::complex<float>& beta,
00238            std::complex<float>* const C,
00239            const MKL_INT LDC)
00240     {
00241       mkl_ccsrmm ((char*) transa,
00242                   (MKL_INT*) &m,
00243                   (MKL_INT*) &n,
00244                   (MKL_INT*) &k,
00245                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (&alpha)),
00246                   (char*) matdescra,
00247                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (val)),
00248                   (MKL_INT*) ind,
00249                   (MKL_INT*) ptrBegin,
00250                   (MKL_INT*) ptrEnd,
00251                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (B)), (MKL_INT*) &LDB,
00252                   reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (&beta)),
00253                   reinterpret_cast<MKL_Complex8*> (C), (MKL_INT*) &LDC);
00254   }
00255 
00256   template<>
00257   void RawSparseKernels<std::complex<float>, MKL_INT>::
00258   csrsv (const char* const transa,
00259          const MKL_INT m,
00260          const std::complex<float>& alpha,
00261          const char* const matdescra,
00262          const std::complex<float>* const val,
00263          const MKL_INT* const ind,
00264          const MKL_INT* const ptrBegin,
00265          const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00266          const std::complex<float>* const x,
00267          std::complex<float>* const y)
00268   {
00269     mkl_ccsrsv ((char*) transa,
00270                 (MKL_INT*) &m,
00271                 reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (&alpha)),
00272                 (char*) matdescra,
00273                 reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (val)),
00274                 (MKL_INT*) ind,
00275                 (MKL_INT*) ptrBegin,
00276                 (MKL_INT*) ptrEnd,
00277                 reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (x)),
00278                 reinterpret_cast<MKL_Complex8*> (y));
00279   }
00280 
00281   template<>
00282   void RawSparseKernels<std::complex<float>, MKL_INT>::
00283   csrsm (const char* const transa,
00284          const MKL_INT m, // Number of columns in A
00285          const MKL_INT n, // Number of columns in C
00286          const std::complex<float>& alpha,
00287          const char* const matdescra,
00288          const std::complex<float>* const val,
00289          const MKL_INT* const ind,
00290          const MKL_INT* const ptrBegin,
00291          const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00292          const std::complex<float>* const B,
00293          const MKL_INT LDB,
00294          std::complex<float>* const C,
00295          const MKL_INT LDC)
00296   {
00297     mkl_ccsrsm ((char*) transa,
00298                 (MKL_INT*) &m,
00299                 (MKL_INT*) &n,
00300                 reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (&alpha)),
00301                 (char*) matdescra,
00302                 reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (val)),
00303                 (MKL_INT*) ind,
00304                 (MKL_INT*) ptrBegin,
00305                 (MKL_INT*) ptrEnd,
00306                 reinterpret_cast<MKL_Complex8*> (const_cast<std::complex<float>* > (B)), (MKL_INT*) &LDB,
00307                 reinterpret_cast<MKL_Complex8*> (C), (MKL_INT*) &LDC);
00308   }
00309 
00311   // Specializations for Scalar=std::complex<double>
00313 
00314   template<>
00315   void RawSparseKernels<std::complex<double>, MKL_INT>::
00316   csrmv (const char* const transa,
00317          const MKL_INT m, // Number of rows in A
00318          const MKL_INT k, // Number of columns in A
00319          const std::complex<double>& alpha,
00320          const char* const matdescra,
00321          const std::complex<double>* const val,
00322          const MKL_INT* const ind,
00323          const MKL_INT* const ptrBegin,
00324          const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00325          const std::complex<double>* const x,
00326          const std::complex<double>& beta,
00327          std::complex<double>* const y)
00328   {
00329     mkl_zcsrmv ((char*) transa,
00330                 (MKL_INT*) &m,
00331                 (MKL_INT*) &k,
00332                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (&alpha)),
00333                 (char*) matdescra,
00334                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (val)),
00335                 (MKL_INT*) ind,
00336                 (MKL_INT*) ptrBegin,
00337                 (MKL_INT*) ptrEnd,
00338                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (x)),
00339                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (&beta)),
00340                 reinterpret_cast<MKL_Complex16*> (y));
00341   }
00342 
00343   template<>
00344   void RawSparseKernels<std::complex<double>, MKL_INT>::
00345   csrmm (const char* const transa,
00346          const MKL_INT m, // number of rows of A
00347          const MKL_INT n, // number of columns of C
00348          const MKL_INT k, // number of columns of A
00349          const std::complex<double>& alpha,
00350          const char* const matdescra,
00351          const std::complex<double>* const val,
00352          const MKL_INT* const ind,
00353          const MKL_INT* const ptrBegin,
00354          const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00355          const std::complex<double>* const B,
00356          const MKL_INT LDB,
00357          const std::complex<double>& beta,
00358          std::complex<double>* const C,
00359          const MKL_INT LDC)
00360   {
00361     mkl_zcsrmm ((char*) transa,
00362                 (MKL_INT*) &m,
00363                 (MKL_INT*) &n,
00364                 (MKL_INT*) &k,
00365                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (&alpha)),
00366                 (char*) matdescra,
00367                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (val)),
00368                 (MKL_INT*) ind,
00369                 (MKL_INT*) ptrBegin,
00370                 (MKL_INT*) ptrEnd,
00371                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (B)), (MKL_INT*) &LDB,
00372                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (&beta)),
00373                 reinterpret_cast<MKL_Complex16*> (C), (MKL_INT*) &LDC);
00374   }
00375 
00376   template<>
00377   void RawSparseKernels<std::complex<double>, MKL_INT>::
00378   csrsv (const char* const transa,
00379          const MKL_INT m,
00380          const std::complex<double>& alpha,
00381          const char* const matdescra,
00382          const std::complex<double>* const val,
00383          const MKL_INT* const ind,
00384          const MKL_INT* const ptrBegin,
00385          const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00386          const std::complex<double>* const x,
00387          std::complex<double>* const y)
00388   {
00389     mkl_zcsrsv ((char*) transa,
00390                 (MKL_INT*) &m,
00391                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (&alpha)),
00392                 (char*) matdescra,
00393                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (val)),
00394                 (MKL_INT*) ind,
00395                 (MKL_INT*) ptrBegin,
00396                 (MKL_INT*) ptrEnd,
00397                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (x)),
00398                 reinterpret_cast<MKL_Complex16*> (y));
00399   }
00400 
00401   template<>
00402   void RawSparseKernels<std::complex<double>, MKL_INT>::
00403   csrsm (const char* const transa,
00404          const MKL_INT m, // Number of columns in A
00405          const MKL_INT n, // Number of columns in C
00406          const std::complex<double>& alpha,
00407          const char* const matdescra,
00408          const std::complex<double>* const val,
00409          const MKL_INT* const ind,
00410          const MKL_INT* const ptrBegin,
00411          const MKL_INT* const ptrEnd, // hint: ptrEnd = &ptrBegin[1]
00412          const std::complex<double>* const B,
00413          const MKL_INT LDB,
00414          std::complex<double>* const C,
00415          const MKL_INT LDC)
00416   {
00417     mkl_zcsrsm ((char*) transa,
00418                 (MKL_INT*) &m,
00419                 (MKL_INT*) &n,
00420                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (&alpha)),
00421                 (char*) matdescra,
00422                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (val)),
00423                 (MKL_INT*) ind,
00424                 (MKL_INT*) ptrBegin,
00425                 (MKL_INT*) ptrEnd,
00426                 reinterpret_cast<MKL_Complex16*> (const_cast<std::complex<double>* > (B)), (MKL_INT*) &LDB,
00427                 reinterpret_cast<MKL_Complex16*> (C), (MKL_INT*) &LDC);
00428   }
00429 
00430 } // namespace Mkl
00431 } // namespace Kokkos
00432 
00433 #endif // HAVE_TEUCHOS_COMPLEX
00434 
00435 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends