|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
// $Id$
// $Source$
// @HEADER
// ***********************************************************************
//
// Sacado Package
// Copyright (2006) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
// (etphipp@sandia.gov).
00028 // 00029 // *********************************************************************** 00030 // @HEADER 00031 00032 #ifndef SACADO_FAD_BLAS_HPP 00033 #define SACADO_FAD_BLAS_HPP 00034 00035 #include "Teuchos_BLAS.hpp" 00036 #include "Sacado.hpp" 00037 #include "Sacado_CacheFad_DFad.hpp" 00038 #include "Sacado_dummy_arg.hpp" 00039 00040 namespace Sacado { 00041 00042 namespace Fad { 00043 00044 template <typename OrdinalType, typename FadType> 00045 class ArrayTraits { 00046 00047 typedef typename Sacado::ValueType<FadType>::type ValueType; 00048 typedef typename Sacado::ScalarType<FadType>::type scalar_type; 00049 typedef typename Sacado::dummy<ValueType,scalar_type>::type ScalarType; 00050 00051 public: 00052 00053 ArrayTraits(bool use_dynamic = true, 00054 OrdinalType workspace_size = 0); 00055 00056 ArrayTraits(const ArrayTraits& a); 00057 00058 ~ArrayTraits(); 00059 00060 void unpack(const FadType& a, OrdinalType& n_dot, ValueType& val, 00061 const ValueType*& dot) const; 00062 00063 void unpack(const FadType* a, OrdinalType n, OrdinalType inc, 00064 OrdinalType& n_dot, OrdinalType& inc_val, 00065 OrdinalType& inc_dot, 00066 const ValueType*& val, const ValueType*& dot) const; 00067 00068 void unpack(const FadType* A, OrdinalType m, OrdinalType n, 00069 OrdinalType lda, OrdinalType& n_dot, 00070 OrdinalType& lda_val, OrdinalType& lda_dot, 00071 const ValueType*& val, const ValueType*& dot) const; 00072 00073 void unpack(const ValueType& a, OrdinalType& n_dot, ValueType& val, 00074 const ValueType*& dot) const; 00075 00076 void unpack(const ValueType* a, OrdinalType n, OrdinalType inc, 00077 OrdinalType& n_dot, OrdinalType& inc_val, 00078 OrdinalType& inc_dot, 00079 const ValueType*& val, const ValueType*& dot) const; 00080 00081 void unpack(const ValueType* A, OrdinalType m, OrdinalType n, 00082 OrdinalType lda, OrdinalType& n_dot, 00083 OrdinalType& lda_val, OrdinalType& lda_dot, 00084 const ValueType*& val, const ValueType*& dot) const; 00085 00086 void 
unpack(const ScalarType& a, OrdinalType& n_dot, ScalarType& val, 00087 const ScalarType*& dot) const; 00088 00089 void unpack(const ScalarType* a, OrdinalType n, OrdinalType inc, 00090 OrdinalType& n_dot, OrdinalType& inc_val, 00091 OrdinalType& inc_dot, 00092 const ScalarType*& val, const ScalarType*& dot) const; 00093 00094 void unpack(const ScalarType* A, OrdinalType m, OrdinalType n, 00095 OrdinalType lda, OrdinalType& n_dot, 00096 OrdinalType& lda_val, OrdinalType& lda_dot, 00097 const ScalarType*& val, const ScalarType*& dot) const; 00098 00099 void unpack(FadType& a, OrdinalType& n_dot, OrdinalType& final_n_dot, 00100 ValueType& val, ValueType*& dot) const; 00101 00102 void unpack(FadType* a, OrdinalType n, OrdinalType inc, 00103 OrdinalType& n_dot, OrdinalType& final_n_dot, 00104 OrdinalType& inc_val, OrdinalType& inc_dot, 00105 ValueType*& val, ValueType*& dot) const; 00106 00107 void unpack(FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda, 00108 OrdinalType& n_dot, OrdinalType& final_n_dot, 00109 OrdinalType& lda_val, OrdinalType& lda_dot, 00110 ValueType*& val, ValueType*& dot) const; 00111 00112 void pack(FadType& a, OrdinalType n_dot, const ValueType& val, 00113 const ValueType* dot) const; 00114 00115 void pack(FadType* a, OrdinalType n, OrdinalType inc, 00116 OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot, 00117 const ValueType* val, const ValueType* dot) const; 00118 00119 void pack(FadType* A, OrdinalType m, OrdinalType n, 00120 OrdinalType lda, OrdinalType n_dot, 00121 OrdinalType lda_val, OrdinalType lda_dot, 00122 const ValueType* val, const ValueType* dot) const; 00123 00124 void free(const FadType& a, OrdinalType n_dot, 00125 const ValueType* dot) const; 00126 00127 void free(const FadType* a, OrdinalType n, OrdinalType n_dot, 00128 OrdinalType inc_val, OrdinalType inc_dot, 00129 const ValueType* val, const ValueType* dot) const; 00130 00131 void free(const FadType* A, OrdinalType m, OrdinalType n, 00132 OrdinalType 
n_dot, OrdinalType lda_val, OrdinalType lda_dot, 00133 const ValueType* val, const ValueType* dot) const; 00134 00135 void free(const ValueType& a, OrdinalType n_dot, 00136 const ValueType* dot) const {} 00137 00138 void free(const ValueType* a, OrdinalType n, OrdinalType n_dot, 00139 OrdinalType inc_val, OrdinalType inc_dot, 00140 const ValueType* val, const ValueType* dot) const {} 00141 00142 void free(const ValueType* A, OrdinalType m, OrdinalType n, 00143 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot, 00144 const ValueType* val, const ValueType* dot) const {} 00145 00146 void free(const ScalarType& a, OrdinalType n_dot, 00147 const ScalarType* dot) const {} 00148 00149 void free(const ScalarType* a, OrdinalType n, OrdinalType n_dot, 00150 OrdinalType inc_val, OrdinalType inc_dot, 00151 const ScalarType* val, const ScalarType* dot) const {} 00152 00153 void free(const ScalarType* A, OrdinalType m, OrdinalType n, 00154 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot, 00155 const ScalarType* val, const ScalarType* dot) const {} 00156 00157 ValueType* allocate_array(OrdinalType size) const; 00158 00159 void free_array(const ValueType* ptr, OrdinalType size) const; 00160 00161 bool is_array_contiguous(const FadType* a, OrdinalType n, 00162 OrdinalType n_dot) const; 00163 00164 protected: 00165 00167 bool use_dynamic; 00168 00170 OrdinalType workspace_size; 00171 00173 mutable ValueType *workspace; 00174 00176 mutable ValueType *workspace_pointer; 00177 00178 }; 00179 00180 template <typename T> struct ArrayValueType { typedef T type; }; 00181 00183 template <typename OrdinalType, typename FadType> 00184 class BLAS : public Teuchos::DefaultBLASImpl<OrdinalType,FadType> { 00185 00186 typedef typename Teuchos::ScalarTraits<FadType>::magnitudeType MagnitudeType; 00187 typedef typename Sacado::ValueType<FadType>::type ValueType; 00188 typedef typename Sacado::ScalarType<FadType>::type scalar_type; 00189 typedef typename 
Sacado::dummy<ValueType,scalar_type>::type ScalarType; 00190 typedef Teuchos::DefaultBLASImpl<OrdinalType,FadType> BLASType; 00191 00192 public: 00194 00195 00197 BLAS(bool use_default_impl = true, 00198 bool use_dynamic = true, OrdinalType static_workspace_size = 0); 00199 00201 00202 BLAS(const BLAS& x); 00203 00205 virtual ~BLAS(); 00206 00208 00210 00211 00213 void ROTG(FadType* da, FadType* db, MagnitudeType* c, FadType* s) const { 00214 BLASType::ROTG(da,db,c,s); 00215 } 00216 00218 void ROT(const OrdinalType n, FadType* dx, const OrdinalType incx, 00219 FadType* dy, const OrdinalType incy, MagnitudeType* c, 00220 FadType* s) const { 00221 BLASType::ROT(n,dx,incx,dy,incy,c,s); 00222 } 00223 00225 void SCAL(const OrdinalType n, const FadType& alpha, FadType* x, 00226 const OrdinalType incx) const; 00227 00229 void COPY(const OrdinalType n, const FadType* x, 00230 const OrdinalType incx, FadType* y, 00231 const OrdinalType incy) const; 00232 00234 template <typename alpha_type, typename x_type> 00235 void AXPY(const OrdinalType n, const alpha_type& alpha, 00236 const x_type* x, const OrdinalType incx, FadType* y, 00237 const OrdinalType incy) const; 00238 00240 typename Teuchos::ScalarTraits<FadType>::magnitudeType 00241 ASUM(const OrdinalType n, const FadType* x, 00242 const OrdinalType incx) const { 00243 return BLASType::ASUM(n,x,incx); 00244 } 00245 00247 template <typename x_type, typename y_type> 00248 FadType DOT(const OrdinalType n, const x_type* x, 00249 const OrdinalType incx, const y_type* y, 00250 const OrdinalType incy) const; 00251 00253 MagnitudeType NRM2(const OrdinalType n, const FadType* x, 00254 const OrdinalType incx) const; 00255 00257 OrdinalType IAMAX(const OrdinalType n, const FadType* x, 00258 const OrdinalType incx) const { 00259 return BLASType::IAMAX(n,x,incx); 00260 } 00261 00263 00265 00266 00272 template <typename alpha_type, typename A_type, typename x_type, 00273 typename beta_type> 00274 void GEMV(Teuchos::ETransp trans, const 
OrdinalType m, 00275 const OrdinalType n, 00276 const alpha_type& alpha, const A_type* A, 00277 const OrdinalType lda, const x_type* x, 00278 const OrdinalType incx, const beta_type& beta, 00279 FadType* y, const OrdinalType incy) const; 00280 00286 template <typename A_type> 00287 void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, 00288 Teuchos::EDiag diag, const OrdinalType n, 00289 const A_type* A, const OrdinalType lda, FadType* x, 00290 const OrdinalType incx) const; 00291 00293 template <typename alpha_type, typename x_type, typename y_type> 00294 void GER(const OrdinalType m, const OrdinalType n, 00295 const alpha_type& alpha, 00296 const x_type* x, const OrdinalType incx, 00297 const y_type* y, const OrdinalType incy, 00298 FadType* A, const OrdinalType lda) const; 00299 00301 00303 00304 00311 template <typename alpha_type, typename A_type, typename B_type, 00312 typename beta_type> 00313 void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, 00314 const OrdinalType m, const OrdinalType n, const OrdinalType k, 00315 const alpha_type& alpha, const A_type* A, const OrdinalType lda, 00316 const B_type* B, const OrdinalType ldb, const beta_type& beta, 00317 FadType* C, const OrdinalType ldc) const; 00318 00325 template <typename alpha_type, typename A_type, typename B_type, 00326 typename beta_type> 00327 void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, 00328 const OrdinalType n, 00329 const alpha_type& alpha, const A_type* A, 00330 const OrdinalType lda, const B_type* B, 00331 const OrdinalType ldb, 00332 const beta_type& beta, FadType* C, 00333 const OrdinalType ldc) const; 00334 00341 template <typename alpha_type, typename A_type> 00342 void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, 00343 Teuchos::ETransp transa, Teuchos::EDiag diag, 00344 const OrdinalType m, const OrdinalType n, 00345 const alpha_type& alpha, 00346 const A_type* A, const OrdinalType lda, 00347 FadType* B, const OrdinalType ldb) const; 00348 00356 
template <typename alpha_type, typename A_type> 00357 void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, 00358 Teuchos::ETransp transa, Teuchos::EDiag diag, 00359 const OrdinalType m, const OrdinalType n, 00360 const alpha_type& alpha, 00361 const A_type* A, const OrdinalType lda, 00362 FadType* B, const OrdinalType ldb) const; 00363 00365 00366 protected: 00367 00369 ArrayTraits<OrdinalType,FadType> arrayTraits; 00370 00372 Teuchos::BLAS<OrdinalType, ValueType> blas; 00373 00375 bool use_default_impl; 00376 00378 mutable std::vector<ValueType> gemv_Ax; 00379 00381 mutable std::vector<ValueType> gemm_AB; 00382 00383 protected: 00384 00386 template <typename x_type, typename y_type> 00387 void Fad_DOT(const OrdinalType n, 00388 const x_type* x, 00389 const OrdinalType incx, 00390 const OrdinalType n_x_dot, 00391 const x_type* x_dot, 00392 const OrdinalType incx_dot, 00393 const y_type* y, 00394 const OrdinalType incy, 00395 const OrdinalType n_y_dot, 00396 const y_type* y_dot, 00397 const OrdinalType incy_dot, 00398 ValueType& z, 00399 const OrdinalType n_z_dot, 00400 ValueType* zdot) const; 00401 00403 template <typename alpha_type, typename A_type, typename x_type, 00404 typename beta_type> 00405 void Fad_GEMV(Teuchos::ETransp trans, 00406 const OrdinalType m, 00407 const OrdinalType n, 00408 const alpha_type& alpha, 00409 const OrdinalType n_alpha_dot, 00410 const alpha_type* alpha_dot, 00411 const A_type* A, 00412 const OrdinalType lda, 00413 const OrdinalType n_A_dot, 00414 const A_type* A_dot, 00415 const OrdinalType lda_dot, 00416 const x_type* x, 00417 const OrdinalType incx, 00418 const OrdinalType n_x_dot, 00419 const x_type* x_dot, 00420 const OrdinalType incx_dot, 00421 const beta_type& beta, 00422 const OrdinalType n_beta_dot, 00423 const beta_type* beta_dot, 00424 ValueType* y, 00425 const OrdinalType incy, 00426 const OrdinalType n_y_dot, 00427 ValueType* y_dot, 00428 const OrdinalType incy_dot, 00429 const OrdinalType n_dot) const; 00430 00432 
template <typename alpha_type, typename x_type, typename y_type> 00433 void Fad_GER(const OrdinalType m, 00434 const OrdinalType n, 00435 const alpha_type& alpha, 00436 const OrdinalType n_alpha_dot, 00437 const alpha_type* alpha_dot, 00438 const x_type* x, 00439 const OrdinalType incx, 00440 const OrdinalType n_x_dot, 00441 const x_type* x_dot, 00442 const OrdinalType incx_dot, 00443 const y_type* y, 00444 const OrdinalType incy, 00445 const OrdinalType n_y_dot, 00446 const y_type* y_dot, 00447 const OrdinalType incy_dot, 00448 ValueType* A, 00449 const OrdinalType lda, 00450 const OrdinalType n_A_dot, 00451 ValueType* A_dot, 00452 const OrdinalType lda_dot, 00453 const OrdinalType n_dot) const; 00454 00456 template <typename alpha_type, typename A_type, typename B_type, 00457 typename beta_type> 00458 void Fad_GEMM(Teuchos::ETransp transa, 00459 Teuchos::ETransp transb, 00460 const OrdinalType m, 00461 const OrdinalType n, 00462 const OrdinalType k, 00463 const alpha_type& alpha, 00464 const OrdinalType n_alpha_dot, 00465 const alpha_type* alpha_dot, 00466 const A_type* A, 00467 const OrdinalType lda, 00468 const OrdinalType n_A_dot, 00469 const A_type* A_dot, 00470 const OrdinalType lda_dot, 00471 const B_type* B, 00472 const OrdinalType ldb, 00473 const OrdinalType n_B_dot, 00474 const B_type* B_dot, 00475 const OrdinalType ldb_dot, 00476 const beta_type& beta, 00477 const OrdinalType n_beta_dot, 00478 const beta_type* beta_dot, 00479 ValueType* C, 00480 const OrdinalType ldc, 00481 const OrdinalType n_C_dot, 00482 ValueType* C_dot, 00483 const OrdinalType ldc_dot, 00484 const OrdinalType n_dot) const; 00485 00487 template <typename alpha_type, typename A_type, typename B_type, 00488 typename beta_type> 00489 void Fad_SYMM(Teuchos::ESide side, 00490 Teuchos::EUplo uplo, 00491 const OrdinalType m, 00492 const OrdinalType n, 00493 const alpha_type& alpha, 00494 const OrdinalType n_alpha_dot, 00495 const alpha_type* alpha_dot, 00496 const A_type* A, 00497 const 
OrdinalType lda, 00498 const OrdinalType n_A_dot, 00499 const A_type* A_dot, 00500 const OrdinalType lda_dot, 00501 const B_type* B, 00502 const OrdinalType ldb, 00503 const OrdinalType n_B_dot, 00504 const B_type* B_dot, 00505 const OrdinalType ldb_dot, 00506 const beta_type& beta, 00507 const OrdinalType n_beta_dot, 00508 const beta_type* beta_dot, 00509 ValueType* C, 00510 const OrdinalType ldc, 00511 const OrdinalType n_C_dot, 00512 ValueType* C_dot, 00513 const OrdinalType ldc_dot, 00514 const OrdinalType n_dot) const; 00515 00517 template <typename alpha_type, typename A_type> 00518 void Fad_TRMM(Teuchos::ESide side, 00519 Teuchos::EUplo uplo, 00520 Teuchos::ETransp transa, 00521 Teuchos::EDiag diag, 00522 const OrdinalType m, 00523 const OrdinalType n, 00524 const alpha_type& alpha, 00525 const OrdinalType n_alpha_dot, 00526 const alpha_type* alpha_dot, 00527 const A_type* A, 00528 const OrdinalType lda, 00529 const OrdinalType n_A_dot, 00530 const A_type* A_dot, 00531 const OrdinalType lda_dot, 00532 ValueType* B, 00533 const OrdinalType ldb, 00534 const OrdinalType n_B_dot, 00535 ValueType* B_dot, 00536 const OrdinalType ldb_dot, 00537 const OrdinalType n_dot) const; 00538 00540 template <typename alpha_type, typename A_type> 00541 void Fad_TRSM(Teuchos::ESide side, 00542 Teuchos::EUplo uplo, 00543 Teuchos::ETransp transa, 00544 Teuchos::EDiag diag, 00545 const OrdinalType m, 00546 const OrdinalType n, 00547 const alpha_type& alpha, 00548 const OrdinalType n_alpha_dot, 00549 const alpha_type* alpha_dot, 00550 const A_type* A, 00551 const OrdinalType lda, 00552 const OrdinalType n_A_dot, 00553 const A_type* A_dot, 00554 const OrdinalType lda_dot, 00555 ValueType* B, 00556 const OrdinalType ldb, 00557 const OrdinalType n_B_dot, 00558 ValueType* B_dot, 00559 const OrdinalType ldb_dot, 00560 const OrdinalType n_dot) const; 00561 00562 }; // class FadBLAS 00563 00564 } // namespace Fad 00565 00566 // template <typename FadType> ArrayValueType<FadType> { typedef 
ValueType type; }; 00567 // template <> ArrayValueType<ValueType> { typedef ValueType type; }; 00568 // template <> ArrayValueType<ScalarType> { typedef ScalarType type; }; 00569 00570 } // namespace Sacado 00571 00572 // Here we provide partial specializations for Teuchos::BLAS for each Fad type 00573 #define TEUCHOS_BLAS_FAD_SPEC(FADTYPE) \ 00574 namespace Teuchos { \ 00575 template <typename OrdinalType, typename ValueT> \ 00576 class BLAS< OrdinalType, FADTYPE<ValueT> > : \ 00577 public Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT> > { \ 00578 public: \ 00579 BLAS(bool use_default_impl = true, bool use_dynamic = true, \ 00580 OrdinalType static_workspace_size = 0) : \ 00581 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT> >( \ 00582 use_default_impl, use_dynamic,static_workspace_size) {} \ 00583 BLAS(const BLAS& x) : \ 00584 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT> >(x) {} \ 00585 virtual ~BLAS() {} \ 00586 }; \ 00587 } \ 00588 namespace Sacado { \ 00589 namespace Fad { \ 00590 template <typename ValueT> \ 00591 struct ArrayValueType< FADTYPE<ValueT> > { \ 00592 typedef ValueT type; \ 00593 }; \ 00594 } \ 00595 } 00596 #define TEUCHOS_BLAS_SFAD_SPEC(FADTYPE) \ 00597 namespace Teuchos { \ 00598 template <typename OrdinalType, typename ValueT, int Num> \ 00599 class BLAS< OrdinalType, FADTYPE<ValueT,Num> > : \ 00600 public Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT,Num> > { \ 00601 public: \ 00602 BLAS(bool use_default_impl = true, bool use_dynamic = true, \ 00603 OrdinalType static_workspace_size = 0) : \ 00604 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT,Num> >( \ 00605 use_default_impl, use_dynamic, static_workspace_size) {} \ 00606 BLAS(const BLAS& x) : \ 00607 Sacado::Fad::BLAS< OrdinalType, FADTYPE<ValueT,Num> >(x) {} \ 00608 virtual ~BLAS() {} \ 00609 }; \ 00610 } \ 00611 namespace Sacado { \ 00612 namespace Fad { \ 00613 template <typename ValueT, int Num> \ 00614 struct ArrayValueType< FADTYPE<ValueT,Num> > { \ 00615 typedef ValueT type; \ 
00616 }; \ 00617 } \ 00618 } 00619 TEUCHOS_BLAS_FAD_SPEC(Sacado::Fad::DFad) 00620 TEUCHOS_BLAS_SFAD_SPEC(Sacado::Fad::SFad) 00621 TEUCHOS_BLAS_SFAD_SPEC(Sacado::Fad::SLFad) 00622 TEUCHOS_BLAS_FAD_SPEC(Sacado::Fad::DMFad) 00623 TEUCHOS_BLAS_FAD_SPEC(Sacado::Fad::DVFad) 00624 TEUCHOS_BLAS_FAD_SPEC(Sacado::ELRFad::DFad) 00625 TEUCHOS_BLAS_SFAD_SPEC(Sacado::ELRFad::SFad) 00626 TEUCHOS_BLAS_SFAD_SPEC(Sacado::ELRFad::SLFad) 00627 TEUCHOS_BLAS_FAD_SPEC(Sacado::CacheFad::DFad) 00628 00629 #undef TEUCHOS_BLAS_FAD_SPEC 00630 #undef TEUCHOS_BLAS_SFAD_SPEC 00631 00632 #include "Sacado_Fad_BLASImp.hpp" 00633 00634 #endif // SACADO_FAD_BLAS_HPP
1.7.4