|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
00001 // $Id$ 00002 // $Source$ 00003 // @HEADER 00004 // *********************************************************************** 00005 // 00006 // Sacado Package 00007 // Copyright (2006) Sandia Corporation 00008 // 00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, 00010 // the U.S. Government retains certain rights in this software. 00011 // 00012 // This library is free software; you can redistribute it and/or modify 00013 // it under the terms of the GNU Lesser General Public License as 00014 // published by the Free Software Foundation; either version 2.1 of the 00015 // License, or (at your option) any later version. 00016 // 00017 // This library is distributed in the hope that it will be useful, but 00018 // WITHOUT ANY WARRANTY; without even the implied warranty of 00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00020 // Lesser General Public License for more details. 00021 // 00022 // You should have received a copy of the GNU Lesser General Public 00023 // License along with this library; if not, write to the Free Software 00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00025 // USA 00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps 00027 // (etphipp@sandia.gov). 00028 // 00029 // *********************************************************************** 00030 // @HEADER 00031 00032 #include "Teuchos_Assert.hpp" 00033 00034 template <typename OrdinalType, typename FadType> 00035 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00036 ArrayTraits(bool use_dynamic_, 00037 OrdinalType workspace_size_) : 00038 use_dynamic(use_dynamic_), 00039 workspace_size(workspace_size_), 00040 workspace(NULL), 00041 workspace_pointer(NULL) 00042 { 00043 if (workspace_size > 0) { 00044 workspace = new ValueType[workspace_size]; 00045 workspace_pointer = workspace; 00046 } 00047 } 00048 00049 template <typename OrdinalType, typename FadType> 00050 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00051 ArrayTraits(const ArrayTraits& a) : 00052 use_dynamic(a.use_dynamic), 00053 workspace_size(a.workspace_size), 00054 workspace(NULL), 00055 workspace_pointer(NULL) 00056 { 00057 if (workspace_size > 0) { 00058 workspace = new ValueType*[workspace_size]; 00059 workspace_pointer = workspace; 00060 } 00061 } 00062 00063 00064 template <typename OrdinalType, typename FadType> 00065 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00066 ~ArrayTraits() 00067 { 00068 // #ifdef SACADO_DEBUG 00069 // TEUCHOS_TEST_FOR_EXCEPTION(workspace_pointer != workspace, 00070 // std::logic_error, 00071 // "ArrayTraits::~ArrayTraits(): " << 00072 // "Destructor called with non-zero used workspace. " << 00073 // "Currently used size is " << workspace_pointer-workspace << 00074 // "."); 00075 00076 // #endif 00077 00078 if (workspace_size > 0) 00079 delete [] workspace; 00080 } 00081 00082 template <typename OrdinalType, typename FadType> 00083 void 00084 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00085 unpack(const FadType& a, OrdinalType& n_dot, ValueType& val, 00086 const ValueType*& dot) const 00087 { 00088 n_dot = a.size(); 00089 val = a.val(); 00090 if (n_dot > 0) 00091 dot = &a.fastAccessDx(0); 00092 else 00093 dot = NULL; 00094 } 00095 00096 template <typename OrdinalType, typename FadType> 00097 void 00098 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00099 unpack(const FadType* a, OrdinalType n, OrdinalType inc, 00100 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot, 00101 const ValueType*& cval, const ValueType*& cdot) const 00102 { 00103 if (n == 0) { 00104 n_dot = 0; 00105 inc_val = 0; 00106 inc_dot = 0; 00107 cval = NULL; 00108 cdot = NULL; 00109 return; 00110 } 00111 00112 n_dot = a[0].size(); 00113 bool is_contiguous = is_array_contiguous(a, n, n_dot); 00114 if (is_contiguous) { 00115 inc_val = inc; 00116 inc_dot = inc; 00117 cval = &a[0].val(); 00118 if (n_dot > 0) 00119 cdot = &a[0].fastAccessDx(0); 00120 } 00121 else { 00122 inc_val = 1; 00123 inc_dot = 0; 00124 ValueType *val = allocate_array(n); 00125 ValueType *dot = NULL; 00126 if (n_dot > 0) { 00127 inc_dot = 1; 00128 dot = allocate_array(n*n_dot); 00129 } 00130 for (OrdinalType i=0; i<n; i++) { 00131 val[i] = a[i*inc].val(); 00132 for (OrdinalType j=0; j<n_dot; j++) 00133 dot[j*n+i] = a[i*inc].fastAccessDx(j); 00134 } 00135 00136 cval = val; 00137 cdot = dot; 00138 } 00139 } 00140 00141 template <typename OrdinalType, typename FadType> 00142 void 00143 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00144 unpack(const FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda, 00145 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot, 00146 const ValueType*& cval, const ValueType*& cdot) const 00147 { 00148 if (m*n == 0) { 00149 n_dot = 0; 00150 lda_val = 0; 00151 lda_dot = 0; 00152 cval = NULL; 00153 cdot = NULL; 00154 return; 00155 } 00156 00157 n_dot = A[0].size(); 00158 bool is_contiguous = is_array_contiguous(A, m*n, n_dot); 00159 if (is_contiguous) { 00160 lda_val = lda; 00161 lda_dot = lda; 00162 cval = &A[0].val(); 00163 if (n_dot > 0) 00164 cdot = &A[0].fastAccessDx(0); 00165 } 00166 else { 00167 lda_val = m; 00168 lda_dot = 0; 00169 ValueType *val = allocate_array(m*n); 00170 ValueType *dot = NULL; 00171 if (n_dot > 0) { 00172 lda_dot = m; 00173 dot = allocate_array(m*n*n_dot); 00174 } 00175 for (OrdinalType j=0; j<n; j++) { 00176 for (OrdinalType i=0; i<m; i++) { 00177 val[j*m+i] = A[j*lda+i].val(); 00178 for (OrdinalType k=0; k<n_dot; k++) 00179 dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k); 00180 } 00181 } 00182 00183 cval = val; 00184 cdot = dot; 00185 } 00186 } 00187 00188 template <typename OrdinalType, typename FadType> 00189 void 00190 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00191 unpack(const ValueType& a, OrdinalType& n_dot, ValueType& val, 00192 const ValueType*& dot) const 00193 { 00194 n_dot = 0; 00195 val = a; 00196 dot = NULL; 00197 } 00198 00199 template <typename OrdinalType, typename FadType> 00200 void 00201 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00202 unpack(const ValueType* a, OrdinalType n, OrdinalType inc, 00203 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot, 00204 const ValueType*& cval, const ValueType*& cdot) const 00205 { 00206 n_dot = 0; 00207 inc_val = inc; 00208 inc_dot = 0; 00209 cval = a; 00210 cdot = NULL; 00211 } 00212 00213 template <typename OrdinalType, typename FadType> 00214 void 00215 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00216 unpack(const ValueType* A, OrdinalType m, OrdinalType n, OrdinalType lda, 00217 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot, 00218 const ValueType*& cval, const ValueType*& cdot) const 00219 { 00220 n_dot = 0; 00221 lda_val = lda; 00222 lda_dot = 0; 00223 cval = A; 00224 cdot = NULL; 00225 } 00226 00227 template <typename OrdinalType, typename FadType> 00228 void 00229 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00230 unpack(const ScalarType& a, OrdinalType& n_dot, ScalarType& val, 00231 const ScalarType*& dot) const 00232 { 00233 n_dot = 0; 00234 val = a; 00235 dot = NULL; 00236 } 00237 00238 template <typename OrdinalType, typename FadType> 00239 void 00240 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00241 unpack(const ScalarType* a, OrdinalType n, OrdinalType inc, 00242 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot, 00243 const ScalarType*& cval, const ScalarType*& cdot) const 00244 { 00245 n_dot = 0; 00246 inc_val = inc; 00247 inc_dot = 0; 00248 cval = a; 00249 cdot = NULL; 00250 } 00251 00252 template <typename OrdinalType, typename FadType> 00253 void 00254 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00255 unpack(const ScalarType* A, OrdinalType m, OrdinalType n, OrdinalType lda, 00256 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot, 00257 const ScalarType*& cval, const ScalarType*& cdot) const 00258 { 00259 n_dot = 0; 00260 lda_val = lda; 00261 lda_dot = 0; 00262 cval = A; 00263 cdot = NULL; 00264 } 00265 00266 template <typename OrdinalType, typename FadType> 00267 void 00268 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00269 unpack(FadType& a, OrdinalType& n_dot, OrdinalType& final_n_dot, ValueType& val, 00270 ValueType*& dot) const 00271 { 00272 n_dot = a.size(); 00273 val = a.val(); 00274 #ifdef SACADO_DEBUG 00275 TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot, 00276 std::logic_error, 00277 "ArrayTraits::unpack(): FadType has wrong number of " << 00278 "derivative components. Got " << n_dot << 00279 ", expected " << final_n_dot << "."); 00280 #endif 00281 if (n_dot > final_n_dot) 00282 final_n_dot = n_dot; 00283 00284 OrdinalType n_avail = a.availableSize(); 00285 if (n_avail < final_n_dot) { 00286 dot = alloate(final_n_dot); 00287 for (OrdinalType i=0; i<final_n_dot; i++) 00288 dot[i] = 0.0; 00289 } 00290 else if (n_avail > 0) 00291 dot = &a.fastAccessDx(0); 00292 else 00293 dot = NULL; 00294 } 00295 00296 template <typename OrdinalType, typename FadType> 00297 void 00298 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00299 unpack(FadType* a, OrdinalType n, OrdinalType inc, OrdinalType& n_dot, 00300 OrdinalType& final_n_dot, OrdinalType& inc_val, OrdinalType& inc_dot, 00301 ValueType*& val, ValueType*& dot) const 00302 { 00303 if (n == 0) { 00304 inc_val = 0; 00305 inc_dot = 0; 00306 val = NULL; 00307 dot = NULL; 00308 return; 00309 } 00310 00311 n_dot = a[0].size(); 00312 bool is_contiguous = is_array_contiguous(a, n, n_dot); 00313 #ifdef SACADO_DEBUG 00314 TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot, 00315 std::logic_error, 00316 "ArrayTraits::unpack(): FadType has wrong number of " << 00317 "derivative components. Got " << n_dot << 00318 ", expected " << final_n_dot << "."); 00319 #endif 00320 if (n_dot > final_n_dot) 00321 final_n_dot = n_dot; 00322 00323 if (is_contiguous) { 00324 inc_val = inc; 00325 val = &a[0].val(); 00326 } 00327 else { 00328 inc_val = 1; 00329 val = allocate_array(n); 00330 for (OrdinalType i=0; i<n; i++) 00331 val[i] = a[i*inc].val(); 00332 } 00333 00334 OrdinalType n_avail = a[0].availableSize(); 00335 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) { 00336 inc_dot = inc; 00337 dot = &a[0].fastAccessDx(0); 00338 } 00339 else if (final_n_dot > 0) { 00340 inc_dot = 1; 00341 dot = allocate_array(n*final_n_dot); 00342 for (OrdinalType i=0; i<n; i++) { 00343 if (n_dot > 0) 00344 for (OrdinalType j=0; j<n_dot; j++) 00345 dot[j*n+i] = a[i*inc].fastAccessDx(j); 00346 else 00347 for (OrdinalType j=0; j<final_n_dot; j++) 00348 dot[j*n+i] = 0.0; 00349 } 00350 } 00351 else { 00352 inc_dot = 0; 00353 dot = NULL; 00354 } 00355 } 00356 00357 template <typename OrdinalType, typename FadType> 00358 void 00359 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00360 unpack(FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda, 00361 OrdinalType& n_dot, OrdinalType& final_n_dot, 00362 OrdinalType& lda_val, OrdinalType& lda_dot, 00363 ValueType*& val, ValueType*& dot) const 00364 { 00365 if (m*n == 0) { 00366 lda_val = 0; 00367 lda_dot = 0; 00368 val = NULL; 00369 dot = NULL; 00370 return; 00371 } 00372 00373 n_dot = A[0].size(); 00374 bool is_contiguous = is_array_contiguous(A, m*n, n_dot); 00375 #ifdef SACADO_DEBUG 00376 TEUCHOS_TEST_FOR_EXCEPTION(n_dot > 0 && final_n_dot > 0 && final_n_dot != n_dot, 00377 std::logic_error, 00378 "ArrayTraits::unpack(): FadType has wrong number of " << 00379 "derivative components. Got " << n_dot << 00380 ", expected " << final_n_dot << "."); 00381 #endif 00382 if (n_dot > final_n_dot) 00383 final_n_dot = n_dot; 00384 00385 if (is_contiguous) { 00386 lda_val = lda; 00387 val = &A[0].val(); 00388 } 00389 else { 00390 lda_val = m; 00391 val = allocate_array(m*n); 00392 for (OrdinalType j=0; j<n; j++) 00393 for (OrdinalType i=0; i<m; i++) 00394 val[j*m+i] = A[j*lda+i].val(); 00395 } 00396 00397 OrdinalType n_avail = A[0].availableSize(); 00398 if (is_contiguous && n_avail >= final_n_dot && final_n_dot > 0) { 00399 lda_dot = lda; 00400 dot = &A[0].fastAccessDx(0); 00401 } 00402 else if (final_n_dot > 0) { 00403 lda_dot = m; 00404 dot = allocate_array(m*n*final_n_dot); 00405 for (OrdinalType j=0; j<n; j++) { 00406 for (OrdinalType i=0; i<m; i++) { 00407 if (n_dot > 0) 00408 for (OrdinalType k=0; k<n_dot; k++) 00409 dot[(k*n+j)*m+i] = A[j*lda+i].fastAccessDx(k); 00410 else 00411 for (OrdinalType k=0; k<final_n_dot; k++) 00412 dot[(k*n+j)*m+i] = 0.0; 00413 } 00414 } 00415 } 00416 else { 00417 lda_dot = 0; 00418 dot = NULL; 00419 } 00420 } 00421 00422 template <typename OrdinalType, typename FadType> 00423 void 00424 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00425 pack(FadType& a, OrdinalType n_dot, const ValueType& val, 00426 const ValueType* dot) const 00427 { 00428 a.val() = val; 00429 00430 if (n_dot == 0) 00431 return; 00432 00433 if (a.size() != n_dot) 00434 a.resize(n_dot); 00435 if (a.dx() != dot) 00436 for (OrdinalType i=0; i<n_dot; i++) 00437 a.fastAccessDx(i) = dot[i]; 00438 } 00439 00440 template <typename OrdinalType, typename FadType> 00441 void 00442 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00443 pack(FadType* a, OrdinalType n, OrdinalType inc, 00444 OrdinalType n_dot, OrdinalType inc_val, OrdinalType inc_dot, 00445 const ValueType* val, const ValueType* dot) const 00446 { 00447 if (n == 0) 00448 return; 00449 00450 // Copy values 00451 if (&a[0].val() != val) 00452 for (OrdinalType i=0; i<n; i++) 00453 a[i*inc].val() = val[i*inc_val]; 00454 00455 if (n_dot == 0) 00456 return; 00457 00458 // Resize derivative arrays 00459 if (a[0].size() != n_dot) 00460 for (OrdinalType i=0; i<n; i++) 00461 a[i*inc].resize(n_dot); 00462 00463 // Copy derivatives 00464 if (a[0].dx() != dot) 00465 for (OrdinalType i=0; i<n; i++) 00466 for (OrdinalType j=0; j<n_dot; j++) 00467 a[i*inc].fastAccessDx(j) = dot[(i+j*n)*inc_dot]; 00468 } 00469 00470 template <typename OrdinalType, typename FadType> 00471 void 00472 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00473 pack(FadType* A, OrdinalType m, OrdinalType n, OrdinalType lda, 00474 OrdinalType n_dot, OrdinalType lda_val, OrdinalType lda_dot, 00475 const ValueType* val, const ValueType* dot) const 00476 { 00477 if (m*n == 0) 00478 return; 00479 00480 // Copy values 00481 if (&A[0].val() != val) 00482 for (OrdinalType j=0; j<n; j++) 00483 for (OrdinalType i=0; i<m; i++) 00484 A[i+j*lda].val() = val[i+j*lda_val]; 00485 00486 if (n_dot == 0) 00487 return; 00488 00489 // Resize derivative arrays 00490 if (A[0].size() != n_dot) 00491 for (OrdinalType j=0; j<n; j++) 00492 for (OrdinalType i=0; i<m; i++) 00493 A[i+j*lda].resize(n_dot); 00494 00495 // Copy derivatives 00496 if (A[0].dx() != dot) 00497 for (OrdinalType j=0; j<n; j++) 00498 for (OrdinalType i=0; i<m; i++) 00499 for (OrdinalType k=0; k<n_dot; k++) 00500 A[i+j*lda].fastAccessDx(k) = dot[i+(j+k*n)*lda_dot]; 00501 } 00502 00503 template <typename OrdinalType, typename FadType> 00504 void 00505 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00506 free(const FadType& a, OrdinalType n_dot, const ValueType* dot) const 00507 { 00508 if (n_dot > 0 && a.dx() != dot) { 00509 free_array(dot, n_dot); 00510 } 00511 } 00512 00513 template <typename OrdinalType, typename FadType> 00514 void 00515 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00516 free(const FadType* a, OrdinalType n, OrdinalType n_dot, 00517 OrdinalType inc_val, OrdinalType inc_dot, 00518 const ValueType* val, const ValueType* dot) const 00519 { 00520 if (n == 0) 00521 return; 00522 00523 if (val != &a[0].val()) 00524 free_array(val, n*inc_val); 00525 00526 if (n_dot > 0 && a[0].dx() != dot) 00527 free_array(dot, n*inc_dot*n_dot); 00528 } 00529 00530 template <typename OrdinalType, typename FadType> 00531 void 00532 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00533 free(const FadType* A, OrdinalType m, OrdinalType n, OrdinalType n_dot, 00534 OrdinalType lda_val, OrdinalType lda_dot, 00535 const ValueType* val, const ValueType* dot) const 00536 { 00537 if (m*n == 0) 00538 return; 00539 00540 if (val != &A[0].val()) 00541 free_array(val, lda_val*n); 00542 00543 if (n_dot > 0 && A[0].dx() != dot) 00544 free_array(dot, lda_dot*n*n_dot); 00545 } 00546 00547 template <typename OrdinalType, typename FadType> 00548 typename Sacado::Fad::ArrayTraits<OrdinalType,FadType>::ValueType* 00549 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00550 allocate_array(OrdinalType size) const 00551 { 00552 if (use_dynamic) 00553 return new ValueType[size]; 00554 00555 #ifdef SACADO_DEBUG 00556 TEUCHOS_TEST_FOR_EXCEPTION(workspace_pointer + size - workspace > workspace_size, 00557 std::logic_error, 00558 "ArrayTraits::allocate_array(): " << 00559 "Requested workspace memory beyond size allocated. " << 00560 "Workspace size is " << workspace_size << 00561 ", currently used is " << workspace_pointer-workspace << 00562 ", requested size is " << size << "."); 00563 00564 #endif 00565 00566 ValueType *v = workspace_pointer; 00567 workspace_pointer += size; 00568 return v; 00569 } 00570 00571 template <typename OrdinalType, typename FadType> 00572 void 00573 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00574 free_array(const ValueType* ptr, OrdinalType size) const 00575 { 00576 if (use_dynamic && ptr != NULL) 00577 delete [] ptr; 00578 else 00579 workspace_pointer -= size; 00580 } 00581 00582 template <typename OrdinalType, typename FadType> 00583 bool 00584 Sacado::Fad::ArrayTraits<OrdinalType,FadType>:: 00585 is_array_contiguous(const FadType* a, OrdinalType n, OrdinalType n_dot) const 00586 { 00587 return (n > 0) && 00588 (&(a[n-1].val())-&(a[0].val()) == n-1) && 00589 (a[n-1].dx()-a[0].dx() == n-1); 00590 } 00591 00592 template <typename OrdinalType, typename FadType> 00593 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00594 BLAS(bool use_default_impl_, 00595 bool use_dynamic_, OrdinalType static_workspace_size_) : 00596 BLASType(), 00597 arrayTraits(use_dynamic_, static_workspace_size_), 00598 blas(), 00599 use_default_impl(use_default_impl_) 00600 { 00601 } 00602 00603 template <typename OrdinalType, typename FadType> 00604 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00605 BLAS(const BLAS& x) : 00606 BLASType(x), 00607 arrayTraits(x.arrayTraits), 00608 blas(x.blas), 00609 use_default_impl(x.use_default_impl) 00610 { 00611 } 00612 00613 template <typename OrdinalType, typename FadType> 00614 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00615 ~BLAS() 00616 { 00617 } 00618 00619 template <typename OrdinalType, typename FadType> 00620 void 00621 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00622 SCAL(const OrdinalType n, const FadType& alpha, FadType* x, 00623 const OrdinalType incx) const 00624 { 00625 if (use_default_impl) { 00626 BLASType::SCAL(n,alpha,x,incx); 00627 return; 00628 } 00629 00630 // Unpack input values & derivatives 00631 ValueType alpha_val; 00632 const ValueType *alpha_dot; 00633 ValueType *x_val, *x_dot; 00634 OrdinalType n_alpha_dot, n_x_dot, n_dot; 00635 OrdinalType incx_val, incx_dot; 00636 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 00637 n_dot = n_alpha_dot; 00638 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot, 00639 x_val, x_dot); 00640 00641 #ifdef SACADO_DEBUG 00642 // Check sizes are consistent 00643 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 00644 (n_x_dot != n_dot && n_x_dot != 0), 00645 std::logic_error, 00646 "BLAS::SCAL(): All arguments must have " << 00647 "the same number of derivative components, or none"); 00648 #endif 00649 00650 // Call differentiated routine 00651 if (n_x_dot > 0) 00652 blas.SCAL(n*n_x_dot, alpha_val, x_dot, incx_dot); 00653 for (OrdinalType i=0; i<n_alpha_dot; i++) 00654 blas.AXPY(n, alpha_dot[i], x_val, incx_val, x_dot+i*n*incx_dot, incx_dot); 00655 blas.SCAL(n, alpha_val, x_val, incx_val); 00656 00657 // Pack values and derivatives for result 00658 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot); 00659 00660 // Free temporary arrays 00661 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 00662 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot); 00663 } 00664 00665 template <typename OrdinalType, typename FadType> 00666 void 00667 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00668 COPY(const OrdinalType n, const FadType* x, const OrdinalType incx, 00669 FadType* y, const OrdinalType incy) const 00670 { 00671 if (use_default_impl) { 00672 BLASType::COPY(n,x,incx,y,incy); 00673 return; 00674 } 00675 00676 if (n == 0) 00677 return; 00678 00679 OrdinalType n_x_dot = x[0].size(); 00680 OrdinalType n_y_dot = y[0].size(); 00681 if (n_x_dot == 0 || n_y_dot == 0 || n_x_dot != n_y_dot || 00682 !arrayTraits.is_array_contiguous(x, n, n_x_dot) || 00683 !arrayTraits.is_array_contiguous(y, n, n_y_dot)) 00684 BLASType::COPY(n,x,incx,y,incy); 00685 else { 00686 blas.COPY(n, &x[0].val(), incx, &y[0].val(), incy); 00687 blas.COPY(n*n_x_dot, &x[0].fastAccessDx(0), incx, &y[0].fastAccessDx(0), 00688 incy); 00689 } 00690 } 00691 00692 template <typename OrdinalType, typename FadType> 00693 template <typename alpha_type, typename x_type> 00694 void 00695 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00696 AXPY(const OrdinalType n, const alpha_type& alpha, const x_type* x, 00697 const OrdinalType incx, FadType* y, const OrdinalType incy) const 00698 { 00699 if (use_default_impl) { 00700 BLASType::AXPY(n,alpha,x,incx,y,incy); 00701 return; 00702 } 00703 00704 // Unpack input values & derivatives 00705 typename ArrayValueType<alpha_type>::type alpha_val; 00706 const typename ArrayValueType<alpha_type>::type *alpha_dot; 00707 const typename ArrayValueType<x_type>::type *x_val, *x_dot; 00708 ValueType *y_val, *y_dot; 00709 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_dot; 00710 OrdinalType incx_val, incy_val, incx_dot, incy_dot; 00711 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 00712 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00713 00714 // Compute size 00715 n_dot = 0; 00716 if (n_alpha_dot > 0) 00717 n_dot = n_alpha_dot; 00718 else if (n_x_dot > 0) 00719 n_dot = n_x_dot; 00720 00721 // Unpack and allocate y 00722 arrayTraits.unpack(y, n, incy, n_y_dot, n_dot, incy_val, incy_dot, y_val, 00723 y_dot); 00724 00725 #ifdef SACADO_DEBUG 00726 // Check sizes are consistent 00727 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 00728 (n_x_dot != n_dot && n_x_dot != 0) || 00729 (n_y_dot != n_dot && n_y_dot != 0), 00730 std::logic_error, 00731 "BLAS::AXPY(): All arguments must have " << 00732 "the same number of derivative components, or none"); 00733 #endif 00734 00735 // Call differentiated routine 00736 if (n_x_dot > 0) 00737 blas.AXPY(n*n_x_dot, alpha_val, x_dot, incx_dot, y_dot, incy_dot); 00738 for (OrdinalType i=0; i<n_alpha_dot; i++) 00739 blas.AXPY(n, alpha_dot[i], x_val, incx_val, y_dot+i*n*incy_dot, incy_dot); 00740 blas.AXPY(n, alpha_val, x_val, incx_val, y_val, incy_val); 00741 00742 // Pack values and derivatives for result 00743 arrayTraits.pack(y, n, incy, n_dot, incy_val, incy_dot, y_val, y_dot); 00744 00745 // Free temporary arrays 00746 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 00747 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00748 arrayTraits.free(y, n, n_dot, incy_val, incy_dot, y_val, y_dot); 00749 } 00750 00751 template <typename OrdinalType, typename FadType> 00752 template <typename x_type, typename y_type> 00753 FadType 00754 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00755 DOT(const OrdinalType n, const x_type* x, const OrdinalType incx, 00756 const y_type* y, const OrdinalType incy) const 00757 { 00758 if (use_default_impl) 00759 return BLASType::DOT(n,x,incx,y,incy); 00760 00761 // Unpack input values & derivatives 00762 const typename ArrayValueType<x_type>::type *x_val, *x_dot; 00763 const typename ArrayValueType<y_type>::type *y_val, *y_dot; 00764 OrdinalType n_x_dot, n_y_dot; 00765 OrdinalType incx_val, incy_val, incx_dot, incy_dot; 00766 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00767 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot); 00768 00769 // Compute size 00770 OrdinalType n_z_dot = 0; 00771 if (n_x_dot > 0) 00772 n_z_dot = n_x_dot; 00773 else if (n_y_dot > 0) 00774 n_z_dot = n_y_dot; 00775 00776 // Unpack and allocate z 00777 FadType z(n_z_dot, 0.0); 00778 ValueType& z_val = z.val(); 00779 ValueType *z_dot = &z.fastAccessDx(0); 00780 00781 // Call differentiated routine 00782 Fad_DOT(n, x_val, incx_val, n_x_dot, x_dot, incx_dot, 00783 y_val, incy_val, n_y_dot, y_dot, incy_dot, 00784 z_val, n_z_dot, z_dot); 00785 00786 // Free temporary arrays 00787 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00788 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot); 00789 00790 return z; 00791 } 00792 00793 template <typename OrdinalType, typename FadType> 00794 typename Sacado::Fad::BLAS<OrdinalType,FadType>::MagnitudeType 00795 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00796 NRM2(const OrdinalType n, const FadType* x, const OrdinalType incx) const 00797 { 00798 if (use_default_impl) 00799 return BLASType::NRM2(n,x,incx); 00800 00801 // Unpack input values & derivatives 00802 const ValueType *x_val, *x_dot; 00803 OrdinalType n_x_dot, incx_val, incx_dot; 00804 arrayTraits.unpack(x, n, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00805 00806 // Unpack and allocate z 00807 MagnitudeType z(n_x_dot, 0.0); 00808 00809 // Call differentiated routine 00810 z.val() = blas.NRM2(n, x_val, incx_val); 00811 // if (!Teuchos::ScalarTraits<FadType>::isComplex && incx_dot == 1) 00812 // blas.GEMV(Teuchos::TRANS, n, n_x_dot, 1.0/z.val(), x_dot, n, x_val, 00813 // incx_val, 1.0, &z.fastAccessDx(0), OrdinalType(1)); 00814 // else 00815 for (OrdinalType i=0; i<n_x_dot; i++) 00816 z.fastAccessDx(i) = 00817 Teuchos::ScalarTraits<ValueType>::magnitude(blas.DOT(n, x_dot+i*n*incx_dot, incx_dot, x_val, incx_val)) / z.val(); 00818 00819 // Free temporary arrays 00820 arrayTraits.free(x, n, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00821 00822 return z; 00823 } 00824 00825 template <typename OrdinalType, typename FadType> 00826 template <typename alpha_type, typename A_type, typename x_type, 00827 typename beta_type> 00828 void 00829 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00830 GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, 00831 const alpha_type& alpha, const A_type* A, 00832 const OrdinalType lda, const x_type* x, 00833 const OrdinalType incx, const beta_type& beta, 00834 FadType* y, const OrdinalType incy) const 00835 { 00836 if (use_default_impl) { 00837 BLASType::GEMV(trans,m,n,alpha,A,lda,x,incx,beta,y,incy); 00838 return; 00839 } 00840 00841 OrdinalType n_x_rows = n; 00842 OrdinalType n_y_rows = m; 00843 if (trans != Teuchos::NO_TRANS) { 00844 n_x_rows = m; 00845 n_y_rows = n; 00846 } 00847 00848 // Unpack input values & derivatives 00849 typename ArrayValueType<alpha_type>::type alpha_val; 00850 const typename ArrayValueType<alpha_type>::type *alpha_dot; 00851 typename ArrayValueType<beta_type>::type beta_val; 00852 const typename ArrayValueType<beta_type>::type *beta_dot; 00853 const typename ArrayValueType<A_type>::type *A_val, *A_dot; 00854 const typename ArrayValueType<x_type>::type *x_val, *x_dot; 00855 ValueType *y_val, *y_dot; 00856 OrdinalType n_alpha_dot, n_A_dot, n_x_dot, n_beta_dot, n_y_dot, n_dot; 00857 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot; 00858 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 00859 arrayTraits.unpack(A, m, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot); 00860 arrayTraits.unpack(x, n_x_rows, incx, n_x_dot, incx_val, incx_dot, x_val, 00861 x_dot); 00862 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot); 00863 00864 // Compute size 00865 n_dot = 0; 00866 if (n_alpha_dot > 0) 00867 n_dot = n_alpha_dot; 00868 else if (n_A_dot > 0) 00869 n_dot = n_A_dot; 00870 else if (n_x_dot > 0) 00871 n_dot = n_x_dot; 00872 else if (n_beta_dot > 0) 00873 n_dot = n_beta_dot; 00874 00875 // Unpack and allocate y 00876 arrayTraits.unpack(y, n_y_rows, incy, n_y_dot, n_dot, incy_val, incy_dot, 00877 y_val, y_dot); 00878 00879 // Call differentiated routine 00880 Fad_GEMV(trans, m, n, alpha_val, n_alpha_dot, alpha_dot, A_val, lda_val, 00881 n_A_dot, A_dot, lda_dot, x_val, incx_val, n_x_dot, x_dot, incx_dot, 00882 beta_val, n_beta_dot, beta_dot, y_val, incy_val, n_y_dot, y_dot, 00883 incy_dot, n_dot); 00884 00885 // Pack values and derivatives for result 00886 arrayTraits.pack(y, n_y_rows, incy, n_dot, incy_val, incy_dot, y_val, y_dot); 00887 00888 // Free temporary arrays 00889 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 00890 arrayTraits.free(A, m, n, n_A_dot, lda_val, lda_dot, A_val, A_dot); 00891 arrayTraits.free(x, n_x_rows, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00892 arrayTraits.free(beta, n_beta_dot, beta_dot); 00893 arrayTraits.free(y, n_y_rows, n_dot, incy_val, incy_dot, y_val, y_dot); 00894 } 00895 00896 template <typename OrdinalType, typename FadType> 00897 template <typename A_type> 00898 void 00899 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00900 TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, 00901 const OrdinalType n, const A_type* A, const OrdinalType lda, 00902 FadType* x, const OrdinalType incx) const 00903 { 00904 if (use_default_impl) { 00905 BLASType::TRMV(uplo,trans,diag,n,A,lda,x,incx); 00906 return; 00907 } 00908 00909 // Unpack input values & derivatives 00910 const typename ArrayValueType<A_type>::type *A_val, *A_dot; 00911 ValueType *x_val, *x_dot; 00912 OrdinalType n_A_dot, n_x_dot, n_dot; 00913 OrdinalType lda_val, incx_val, lda_dot, incx_dot; 00914 arrayTraits.unpack(A, n, n, lda, n_A_dot, lda_val, lda_dot, A_val, A_dot); 00915 n_dot = n_A_dot; 00916 arrayTraits.unpack(x, n, incx, n_x_dot, n_dot, incx_val, incx_dot, x_val, 00917 x_dot); 00918 00919 #ifdef SACADO_DEBUG 00920 // Check sizes are consistent 00921 TEUCHOS_TEST_FOR_EXCEPTION((n_A_dot != n_dot && n_A_dot != 0) || 00922 (n_x_dot != n_dot && n_x_dot != 0), 00923 std::logic_error, 00924 "BLAS::TRMV(): All arguments must have " << 00925 "the same number of derivative components, or none"); 00926 #endif 00927 00928 // Compute [xd_1 .. xd_n] = A*[xd_1 .. xd_n] 00929 if (n_x_dot > 0) { 00930 if (incx_dot == 1) 00931 blas.TRMM(Teuchos::LEFT_SIDE, uplo, trans, diag, n, n_x_dot, 1.0, A_val, 00932 lda_val, x_dot, n); 00933 else 00934 for (OrdinalType i=0; i<n_x_dot; i++) 00935 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_dot+i*incx_dot*n, 00936 incx_dot); 00937 } 00938 00939 // Compute [xd_1 .. xd_n] = [Ad_1*x .. Ad_n*x] 00940 if (gemv_Ax.size() != std::size_t(n)) 00941 gemv_Ax.resize(n); 00942 for (OrdinalType i=0; i<n_A_dot; i++) { 00943 blas.COPY(n, x_val, incx_val, &gemv_Ax[0], OrdinalType(1)); 00944 blas.TRMV(uplo, trans, Teuchos::NON_UNIT_DIAG, n, A_dot+i*lda_dot*n, 00945 lda_dot, &gemv_Ax[0], OrdinalType(1)); 00946 blas.AXPY(n, 1.0, &gemv_Ax[0], OrdinalType(1), x_dot+i*incx_dot*n, 00947 incx_dot); 00948 } 00949 00950 // Compute x = A*x 00951 blas.TRMV(uplo, trans, diag, n, A_val, lda_val, x_val, incx_val); 00952 00953 // Pack values and derivatives for result 00954 arrayTraits.pack(x, n, incx, n_dot, incx_val, incx_dot, x_val, x_dot); 00955 00956 // Free temporary arrays 00957 arrayTraits.free(A, n, n, n_A_dot, lda_val, lda_dot, A_val, A_dot); 00958 arrayTraits.free(x, n, n_dot, incx_val, incx_dot, x_val, x_dot); 00959 } 00960 00961 template <typename OrdinalType, typename FadType> 00962 template <typename alpha_type, typename x_type, typename y_type> 00963 void 00964 Sacado::Fad::BLAS<OrdinalType,FadType>:: 00965 GER(const OrdinalType m, const OrdinalType n, const alpha_type& alpha, 00966 const x_type* x, const OrdinalType incx, 00967 const y_type* y, const OrdinalType incy, 00968 FadType* A, const OrdinalType lda) const 00969 { 00970 if (use_default_impl) { 00971 BLASType::GER(m,n,alpha,x,incx,y,incy,A,lda); 00972 return; 00973 } 00974 00975 // Unpack input values & derivatives 00976 typename ArrayValueType<alpha_type>::type alpha_val; 00977 const typename ArrayValueType<alpha_type>::type *alpha_dot; 00978 const typename ArrayValueType<x_type>::type *x_val, *x_dot; 00979 const typename ArrayValueType<y_type>::type *y_val, *y_dot; 00980 ValueType *A_val, *A_dot; 00981 OrdinalType n_alpha_dot, n_x_dot, n_y_dot, n_A_dot, n_dot; 00982 OrdinalType lda_val, incx_val, incy_val, lda_dot, incx_dot, incy_dot; 00983 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 00984 arrayTraits.unpack(x, m, incx, n_x_dot, incx_val, incx_dot, x_val, x_dot); 00985 arrayTraits.unpack(y, n, incy, n_y_dot, incy_val, incy_dot, y_val, y_dot); 00986 00987 // Compute size 00988 n_dot = 0; 00989 if (n_alpha_dot > 0) 00990 n_dot = n_alpha_dot; 00991 else if (n_x_dot > 0) 00992 n_dot = n_x_dot; 00993 else if (n_y_dot > 0) 00994 n_dot = n_y_dot; 00995 00996 // Unpack and allocate A 00997 arrayTraits.unpack(A, m, n, lda, n_A_dot, n_dot, lda_val, lda_dot, A_val, 00998 A_dot); 00999 01000 // Call differentiated routine 01001 Fad_GER(m, n, alpha_val, n_alpha_dot, alpha_dot, x_val, incx_val, 01002 n_x_dot, x_dot, incx_dot, y_val, incy_val, n_y_dot, y_dot, 01003 incy_dot, A_val, lda_val, n_A_dot, A_dot, lda_dot, n_dot); 01004 01005 // Pack values and derivatives for result 01006 arrayTraits.pack(A, m, n, lda, n_dot, lda_val, lda_dot, A_val, A_dot); 01007 01008 // Free temporary arrays 01009 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 01010 arrayTraits.free(x, m, n_x_dot, incx_val, incx_dot, x_val, x_dot); 01011 arrayTraits.free(y, n, n_y_dot, incy_val, incy_dot, y_val, y_dot); 01012 arrayTraits.free(A, m, n, n_dot, lda_val, lda_dot, A_val, A_dot); 01013 } 01014 01015 template <typename OrdinalType, typename FadType> 01016 template <typename alpha_type, typename A_type, typename B_type, 01017 typename beta_type> 01018 void 01019 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01020 GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, 01021 const OrdinalType m, const OrdinalType n, const OrdinalType k, 01022 const alpha_type& alpha, const A_type* A, const OrdinalType lda, 01023 const B_type* B, const OrdinalType ldb, const beta_type& beta, 01024 FadType* C, const OrdinalType ldc) const 01025 { 01026 if (use_default_impl) { 01027 BLASType::GEMM(transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc); 01028 return; 01029 } 01030 01031 OrdinalType n_A_rows = m; 01032 OrdinalType n_A_cols = k; 01033 if (transa != Teuchos::NO_TRANS) { 01034 n_A_rows = k; 01035 n_A_cols = m; 01036 } 01037 01038 OrdinalType n_B_rows = k; 01039 OrdinalType n_B_cols = n; 01040 if (transb != Teuchos::NO_TRANS) { 01041 n_B_rows = n; 01042 n_B_cols = k; 01043 } 01044 01045 // Unpack input values & derivatives 01046 typename ArrayValueType<alpha_type>::type alpha_val; 01047 const typename ArrayValueType<alpha_type>::type *alpha_dot; 01048 typename ArrayValueType<beta_type>::type beta_val; 01049 const typename ArrayValueType<beta_type>::type *beta_dot; 01050 const typename ArrayValueType<A_type>::type *A_val, *A_dot; 01051 const typename ArrayValueType<B_type>::type *B_val, *B_dot; 01052 ValueType *C_val, *C_dot; 01053 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot, n_dot; 01054 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot; 01055 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 01056 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot, 01057 A_val, A_dot); 01058 arrayTraits.unpack(B, n_B_rows, n_B_cols, ldb, n_B_dot, ldb_val, ldb_dot, 01059 B_val, B_dot); 01060 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot); 01061 01062 // Compute size 01063 n_dot = 0; 01064 if (n_alpha_dot > 0) 01065 n_dot = n_alpha_dot; 01066 else if (n_A_dot > 0) 01067 n_dot = n_A_dot; 01068 else if (n_B_dot > 0) 01069 n_dot = n_B_dot; 01070 else if (n_beta_dot > 0) 01071 n_dot = n_beta_dot; 01072 01073 // Unpack and allocate C 01074 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val, 01075 C_dot); 01076 01077 // Call differentiated routine 01078 Fad_GEMM(transa, transb, m, n, k, 01079 alpha_val, n_alpha_dot, alpha_dot, 01080 A_val, lda_val, n_A_dot, A_dot, lda_dot, 01081 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, 01082 beta_val, n_beta_dot, beta_dot, 01083 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot); 01084 01085 // Pack values and derivatives for result 01086 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot); 01087 01088 // Free temporary arrays 01089 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 01090 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val, 01091 A_dot); 01092 arrayTraits.free(B, n_B_rows, n_B_cols, n_B_dot, ldb_val, ldb_dot, B_val, 01093 B_dot); 01094 arrayTraits.free(beta, n_beta_dot, beta_dot); 01095 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot); 01096 } 01097 01098 template <typename OrdinalType, typename FadType> 01099 template <typename alpha_type, typename A_type, typename B_type, 01100 typename beta_type> 01101 void 01102 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01103 SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, 01104 const OrdinalType m, const OrdinalType n, 01105 const alpha_type& alpha, const A_type* A, const OrdinalType lda, 01106 const B_type* B, const OrdinalType ldb, const beta_type& beta, 01107 FadType* C, const OrdinalType ldc) const 01108 { 01109 if (use_default_impl) { 01110 BLASType::SYMM(side,uplo,m,n,alpha,A,lda,B,ldb,beta,C,ldc); 01111 return; 01112 } 01113 01114 OrdinalType n_A_rows = m; 01115 OrdinalType n_A_cols = m; 01116 if (side == Teuchos::RIGHT_SIDE) { 01117 n_A_rows = n; 01118 n_A_cols = n; 01119 } 01120 01121 // Unpack input values & derivatives 01122 typename ArrayValueType<alpha_type>::type alpha_val; 01123 const typename ArrayValueType<alpha_type>::type *alpha_dot; 01124 typename ArrayValueType<beta_type>::type beta_val; 01125 const typename ArrayValueType<beta_type>::type *beta_dot; 01126 const typename ArrayValueType<A_type>::type *A_val, *A_dot; 01127 const typename ArrayValueType<B_type>::type *B_val, *B_dot; 01128 ValueType *C_val, *C_dot; 01129 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_beta_dot, n_C_dot, n_dot; 01130 OrdinalType lda_val, ldb_val, ldc_val, lda_dot, ldb_dot, ldc_dot; 01131 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 01132 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot, 01133 A_val, A_dot); 01134 arrayTraits.unpack(B, m, n, ldb, n_B_dot, ldb_val, ldb_dot, B_val, B_dot); 01135 arrayTraits.unpack(beta, n_beta_dot, beta_val, beta_dot); 01136 01137 // Compute size 01138 n_dot = 0; 01139 if (n_alpha_dot > 0) 01140 n_dot = n_alpha_dot; 01141 else if (n_A_dot > 0) 01142 n_dot = n_A_dot; 01143 else if (n_B_dot > 0) 01144 n_dot = n_B_dot; 01145 else if (n_beta_dot > 0) 01146 n_dot = n_beta_dot; 01147 01148 // Unpack and allocate C 01149 arrayTraits.unpack(C, m, n, ldc, n_C_dot, n_dot, ldc_val, ldc_dot, C_val, 01150 C_dot); 01151 01152 // Call differentiated routine 01153 Fad_SYMM(side, uplo, m, n, 01154 alpha_val, n_alpha_dot, alpha_dot, 01155 A_val, lda_val, n_A_dot, A_dot, lda_dot, 01156 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, 01157 beta_val, n_beta_dot, beta_dot, 01158 C_val, ldc_val, n_C_dot, C_dot, ldc_dot, n_dot); 01159 01160 // Pack values and derivatives for result 01161 arrayTraits.pack(C, m, n, ldc, n_dot, ldc_val, ldc_dot, C_val, C_dot); 01162 01163 // Free temporary arrays 01164 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 01165 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val, 01166 A_dot); 01167 arrayTraits.free(B, m, n, n_B_dot, ldb_val, ldb_dot, B_val, B_dot); 01168 arrayTraits.free(beta, n_beta_dot, beta_dot); 01169 arrayTraits.free(C, m, n, n_dot, ldc_val, ldc_dot, C_val, C_dot); 01170 } 01171 01172 template <typename OrdinalType, typename FadType> 01173 template <typename alpha_type, typename A_type> 01174 void 01175 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01176 TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, 01177 Teuchos::ETransp transa, Teuchos::EDiag diag, 01178 const OrdinalType m, const OrdinalType n, 01179 const alpha_type& alpha, const A_type* A, const OrdinalType lda, 01180 FadType* B, const OrdinalType ldb) const 01181 { 01182 if (use_default_impl) { 01183 BLASType::TRMM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb); 01184 return; 01185 } 01186 01187 OrdinalType n_A_rows = m; 01188 OrdinalType n_A_cols = m; 01189 if (side == Teuchos::RIGHT_SIDE) { 01190 n_A_rows = n; 01191 n_A_cols = n; 01192 } 01193 01194 // Unpack input values & derivatives 01195 typename ArrayValueType<alpha_type>::type alpha_val; 01196 const typename ArrayValueType<alpha_type>::type *alpha_dot; 01197 const typename ArrayValueType<A_type>::type *A_val, *A_dot; 01198 ValueType *B_val, *B_dot; 01199 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot; 01200 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot; 01201 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 01202 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot, 01203 A_val, A_dot); 01204 01205 // Compute size 01206 n_dot = 0; 01207 if (n_alpha_dot > 0) 01208 n_dot = n_alpha_dot; 01209 else if (n_A_dot > 0) 01210 n_dot = n_A_dot; 01211 01212 // Unpack and allocate B 01213 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val, 01214 B_dot); 01215 01216 // Call differentiated routine 01217 Fad_TRMM(side, uplo, transa, diag, m, n, 01218 alpha_val, n_alpha_dot, alpha_dot, 01219 A_val, lda_val, n_A_dot, A_dot, lda_dot, 01220 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot); 01221 01222 // Pack values and derivatives for result 01223 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot); 01224 01225 // Free temporary arrays 01226 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 01227 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val, 01228 A_dot); 01229 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot); 01230 } 01231 01232 template <typename OrdinalType, typename FadType> 01233 template <typename alpha_type, typename A_type> 01234 void 01235 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01236 TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, 01237 Teuchos::ETransp transa, Teuchos::EDiag diag, 01238 const OrdinalType m, const OrdinalType n, 01239 const alpha_type& alpha, const A_type* A, const OrdinalType lda, 01240 FadType* B, const OrdinalType ldb) const 01241 { 01242 if (use_default_impl) { 01243 BLASType::TRSM(side,uplo,transa,diag,m,n,alpha,A,lda,B,ldb); 01244 return; 01245 } 01246 01247 OrdinalType n_A_rows = m; 01248 OrdinalType n_A_cols = m; 01249 if (side == Teuchos::RIGHT_SIDE) { 01250 n_A_rows = n; 01251 n_A_cols = n; 01252 } 01253 01254 // Unpack input values & derivatives 01255 typename ArrayValueType<alpha_type>::type alpha_val; 01256 const typename ArrayValueType<alpha_type>::type *alpha_dot; 01257 const typename ArrayValueType<A_type>::type *A_val, *A_dot; 01258 ValueType *B_val, *B_dot; 01259 OrdinalType n_alpha_dot, n_A_dot, n_B_dot, n_dot; 01260 OrdinalType lda_val, ldb_val, lda_dot, ldb_dot; 01261 arrayTraits.unpack(alpha, n_alpha_dot, alpha_val, alpha_dot); 01262 arrayTraits.unpack(A, n_A_rows, n_A_cols, lda, n_A_dot, lda_val, lda_dot, 01263 A_val, A_dot); 01264 01265 // Compute size 01266 n_dot = 0; 01267 if (n_alpha_dot > 0) 01268 n_dot = n_alpha_dot; 01269 else if (n_A_dot > 0) 01270 n_dot = n_A_dot; 01271 01272 // Unpack and allocate B 01273 arrayTraits.unpack(B, m, n, ldb, n_B_dot, n_dot, ldb_val, ldb_dot, B_val, 01274 B_dot); 01275 01276 // Call differentiated routine 01277 Fad_TRSM(side, uplo, transa, diag, m, n, 01278 alpha_val, n_alpha_dot, alpha_dot, 01279 A_val, lda_val, n_A_dot, A_dot, lda_dot, 01280 B_val, ldb_val, n_B_dot, B_dot, ldb_dot, n_dot); 01281 01282 // Pack values and derivatives for result 01283 arrayTraits.pack(B, m, n, ldb, n_dot, ldb_val, ldb_dot, B_val, B_dot); 01284 01285 // Free temporary arrays 01286 arrayTraits.free(alpha, n_alpha_dot, alpha_dot); 01287 arrayTraits.free(A, n_A_rows, n_A_cols, n_A_dot, lda_val, lda_dot, A_val, 01288 A_dot); 01289 arrayTraits.free(B, m, n, n_dot, ldb_val, ldb_dot, B_val, B_dot); 01290 } 01291 01292 template <typename OrdinalType, typename FadType> 01293 template <typename x_type, typename y_type> 01294 void 01295 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01296 Fad_DOT(const OrdinalType n, 01297 const x_type* x, 01298 const OrdinalType incx, 01299 const OrdinalType n_x_dot, 01300 const x_type* x_dot, 01301 const OrdinalType incx_dot, 01302 const y_type* y, 01303 const OrdinalType incy, 01304 const OrdinalType n_y_dot, 01305 const y_type* y_dot, 01306 const OrdinalType incy_dot, 01307 ValueType& z, 01308 const OrdinalType n_z_dot, 01309 ValueType* z_dot) const 01310 { 01311 #ifdef SACADO_DEBUG 01312 // Check sizes are consistent 01313 TEUCHOS_TEST_FOR_EXCEPTION((n_x_dot != n_z_dot && n_x_dot != 0) || 01314 (n_y_dot != n_z_dot && n_y_dot != 0), 01315 std::logic_error, 01316 "BLAS::Fad_DOT(): All arguments must have " << 01317 "the same number of derivative components, or none"); 01318 #endif 01319 01320 // Compute [zd_1 .. zd_n] = [xd_1 .. xd_n]^T*y 01321 if (n_x_dot > 0) { 01322 if (incx_dot == OrdinalType(1)) 01323 blas.GEMV(Teuchos::TRANS, n, n_x_dot, 1.0, x_dot, n, y, incy, 0.0, z_dot, 01324 OrdinalType(1)); 01325 else 01326 for (OrdinalType i=0; i<n_z_dot; i++) 01327 z_dot[i] = blas.DOT(n, x_dot+i*incx_dot*n, incx_dot, y, incy); 01328 } 01329 01330 // Compute [zd_1 .. zd_n] += [yd_1 .. yd_n]^T*x 01331 if (n_y_dot > 0) { 01332 if (incy_dot == OrdinalType(1) && 01333 !Teuchos::ScalarTraits<ValueType>::isComplex) 01334 blas.GEMV(Teuchos::TRANS, n, n_y_dot, 1.0, y_dot, n, x, incx, 1.0, z_dot, 01335 OrdinalType(1)); 01336 else 01337 for (OrdinalType i=0; i<n_z_dot; i++) 01338 z_dot[i] += blas.DOT(n, x, incx, y_dot+i*incy_dot*n, incy_dot); 01339 } 01340 01341 // Compute z = x^T*y 01342 z = blas.DOT(n, x, incx, y, incy); 01343 } 01344 01345 template <typename OrdinalType, typename FadType> 01346 template <typename alpha_type, typename A_type, typename x_type, 01347 typename beta_type> 01348 void 01349 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01350 Fad_GEMV(Teuchos::ETransp trans, 01351 const OrdinalType m, 01352 const OrdinalType n, 01353 const alpha_type& alpha, 01354 const OrdinalType n_alpha_dot, 01355 const alpha_type* alpha_dot, 01356 const A_type* A, 01357 const OrdinalType lda, 01358 const OrdinalType n_A_dot, 01359 const A_type* A_dot, 01360 const OrdinalType lda_dot, 01361 const x_type* x, 01362 const OrdinalType incx, 01363 const OrdinalType n_x_dot, 01364 const x_type* x_dot, 01365 const OrdinalType incx_dot, 01366 const beta_type& beta, 01367 const OrdinalType n_beta_dot, 01368 const beta_type* beta_dot, 01369 ValueType* y, 01370 const OrdinalType incy, 01371 const OrdinalType n_y_dot, 01372 ValueType* y_dot, 01373 const OrdinalType incy_dot, 01374 const OrdinalType n_dot) const 01375 { 01376 #ifdef SACADO_DEBUG 01377 // Check sizes are consistent 01378 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 01379 (n_A_dot != n_dot && n_A_dot != 0) || 01380 (n_x_dot != n_dot && n_x_dot != 0) || 01381 (n_beta_dot != n_dot && n_beta_dot != 0) || 01382 (n_y_dot != n_dot && n_y_dot != 0), 01383 std::logic_error, 01384 "BLAS::Fad_GEMV(): All arguments must have " << 01385 "the same number of derivative components, or none"); 01386 #endif 01387 OrdinalType n_A_rows = m; 01388 OrdinalType n_A_cols = n; 01389 OrdinalType n_x_rows = n; 01390 OrdinalType n_y_rows = m; 01391 if (trans == Teuchos::TRANS) { 01392 n_A_rows = n; 01393 n_A_cols = m; 01394 n_x_rows = m; 01395 n_y_rows = n; 01396 } 01397 01398 // Compute [yd_1 .. yd_n] = beta*[yd_1 .. yd_n] 01399 if (n_y_dot > 0) 01400 blas.SCAL(n_y_rows*n_y_dot, beta, y_dot, incy_dot); 01401 01402 // Compute [yd_1 .. yd_n] = alpha*A*[xd_1 .. xd_n] 01403 if (n_x_dot > 0) { 01404 if (incx_dot == 1) 01405 blas.GEMM(trans, Teuchos::NO_TRANS, n_A_rows, n_dot, n_A_cols, 01406 alpha, A, lda, x_dot, n_x_rows, 1.0, y_dot, n_y_rows); 01407 else 01408 for (OrdinalType i=0; i<n_x_dot; i++) 01409 blas.GEMV(trans, m, n, alpha, A, lda, x_dot+i*incx_dot*n_x_rows, 01410 incx_dot, 1.0, y_dot+i*incy_dot*n_y_rows, incy_dot); 01411 } 01412 01413 // Compute [yd_1 .. yd_n] += diag([alphad_1 .. alphad_n])*A*x 01414 if (n_alpha_dot > 0) { 01415 if (gemv_Ax.size() != std::size_t(n)) 01416 gemv_Ax.resize(n); 01417 blas.GEMV(trans, m, n, 1.0, A, lda, x, incx, 0.0, &gemv_Ax[0], 01418 OrdinalType(1)); 01419 for (OrdinalType i=0; i<n_alpha_dot; i++) 01420 blas.AXPY(n_y_rows, alpha_dot[i], &gemv_Ax[0], OrdinalType(1), 01421 y_dot+i*incy_dot*n_y_rows, incy_dot); 01422 } 01423 01424 // Compute [yd_1 .. yd_n] += alpha*[Ad_1*x .. Ad_n*x] 01425 for (OrdinalType i=0; i<n_A_dot; i++) 01426 blas.GEMV(trans, m, n, alpha, A_dot+i*lda_dot*n, lda_dot, x, incx, 1.0, 01427 y_dot+i*incy_dot*n_y_rows, incy_dot); 01428 01429 // Compute [yd_1 .. yd_n] += diag([betad_1 .. betad_n])*y 01430 for (OrdinalType i=0; i<n_beta_dot; i++) 01431 blas.AXPY(n_y_rows, beta_dot[i], y, incy, y_dot+i*incy_dot*n_y_rows, 01432 incy_dot); 01433 01434 // Compute y = alpha*A*x + beta*y 01435 if (n_alpha_dot > 0) { 01436 blas.SCAL(n_y_rows, beta, y, incy); 01437 blas.AXPY(n_y_rows, alpha, &gemv_Ax[0], OrdinalType(1), y, incy); 01438 } 01439 else 01440 blas.GEMV(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); 01441 } 01442 01443 template <typename OrdinalType, typename FadType> 01444 template <typename alpha_type, typename x_type, typename y_type> 01445 void 01446 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01447 Fad_GER(const OrdinalType m, 01448 const OrdinalType n, 01449 const alpha_type& alpha, 01450 const OrdinalType n_alpha_dot, 01451 const alpha_type* alpha_dot, 01452 const x_type* x, 01453 const OrdinalType incx, 01454 const OrdinalType n_x_dot, 01455 const x_type* x_dot, 01456 const OrdinalType incx_dot, 01457 const y_type* y, 01458 const OrdinalType incy, 01459 const OrdinalType n_y_dot, 01460 const y_type* y_dot, 01461 const OrdinalType incy_dot, 01462 ValueType* A, 01463 const OrdinalType lda, 01464 const OrdinalType n_A_dot, 01465 ValueType* A_dot, 01466 const OrdinalType lda_dot, 01467 const OrdinalType n_dot) const 01468 { 01469 #ifdef SACADO_DEBUG 01470 // Check sizes are consistent 01471 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 01472 (n_A_dot != n_dot && n_A_dot != 0) || 01473 (n_x_dot != n_dot && n_x_dot != 0) || 01474 (n_y_dot != n_dot && n_y_dot != 0), 01475 std::logic_error, 01476 "BLAS::Fad_GER(): All arguments must have " << 01477 "the same number of derivative components, or none"); 01478 #endif 01479 01480 // Compute [Ad_1 .. Ad_n] += [alphad_1*x*y^T .. alphad_n*x*y^T] 01481 for (OrdinalType i=0; i<n_alpha_dot; i++) 01482 blas.GER(m, n, alpha_dot[i], x, incx, y, incy, A_dot+i*lda_dot*n, lda_dot); 01483 01484 // Compute [Ad_1 .. Ad_n] += alpha*[xd_1*y^T .. xd_n*y^T] 01485 for (OrdinalType i=0; i<n_x_dot; i++) 01486 blas.GER(m, n, alpha, x_dot+i*incx_dot*m, incx_dot, y, incy, 01487 A_dot+i*lda_dot*n, lda_dot); 01488 01489 // Compute [Ad_1 .. Ad_n] += alpha*x*[yd_1 .. yd_n] 01490 if (n_y_dot > 0) 01491 blas.GER(m, n*n_y_dot, alpha, x, incx, y_dot, incy_dot, A_dot, lda_dot); 01492 01493 // Compute A = alpha*x*y^T + A 01494 blas.GER(m, n, alpha, x, incx, y, incy, A, lda); 01495 } 01496 01497 template <typename OrdinalType, typename FadType> 01498 template <typename alpha_type, typename A_type, typename B_type, 01499 typename beta_type> 01500 void 01501 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01502 Fad_GEMM(Teuchos::ETransp transa, 01503 Teuchos::ETransp transb, 01504 const OrdinalType m, 01505 const OrdinalType n, 01506 const OrdinalType k, 01507 const alpha_type& alpha, 01508 const OrdinalType n_alpha_dot, 01509 const alpha_type* alpha_dot, 01510 const A_type* A, 01511 const OrdinalType lda, 01512 const OrdinalType n_A_dot, 01513 const A_type* A_dot, 01514 const OrdinalType lda_dot, 01515 const B_type* B, 01516 const OrdinalType ldb, 01517 const OrdinalType n_B_dot, 01518 const B_type* B_dot, 01519 const OrdinalType ldb_dot, 01520 const beta_type& beta, 01521 const OrdinalType n_beta_dot, 01522 const beta_type* beta_dot, 01523 ValueType* C, 01524 const OrdinalType ldc, 01525 const OrdinalType n_C_dot, 01526 ValueType* C_dot, 01527 const OrdinalType ldc_dot, 01528 const OrdinalType n_dot) const 01529 { 01530 #ifdef SACADO_DEBUG 01531 // Check sizes are consistent 01532 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 01533 (n_A_dot != n_dot && n_A_dot != 0) || 01534 (n_B_dot != n_dot && n_B_dot != 0) || 01535 (n_beta_dot != n_dot && n_beta_dot != 0) || 01536 (n_C_dot != n_dot && n_C_dot != 0), 01537 std::logic_error, 01538 "BLAS::Fad_GEMM(): All arguments must have " << 01539 "the same number of derivative components, or none"); 01540 #endif 01541 OrdinalType n_A_rows = m; 01542 OrdinalType n_A_cols = k; 01543 if (transa != Teuchos::NO_TRANS) { 01544 n_A_rows = k; 01545 n_A_cols = m; 01546 } 01547 01548 OrdinalType n_B_rows = k; 01549 OrdinalType n_B_cols = n; 01550 if (transb != Teuchos::NO_TRANS) { 01551 n_B_rows = n; 01552 n_B_cols = k; 01553 } 01554 01555 // Compute [Cd_1 .. Cd_n] = beta*[Cd_1 .. Cd_n] 01556 if (n_C_dot > 0) { 01557 if (ldc_dot == m) 01558 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1)); 01559 else 01560 for (OrdinalType i=0; i<n_C_dot; i++) 01561 for (OrdinalType j=0; j<n; j++) 01562 blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1)); 01563 } 01564 01565 // Compute [Cd_1 .. Cd_n] += alpha*A*[Bd_1 .. Bd_n] 01566 for (OrdinalType i=0; i<n_B_dot; i++) 01567 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B_dot+i*ldb_dot*n_B_cols, 01568 ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot); 01569 01570 // Compute [Cd_1 .. Cd_n] += [alphad_1*A*B .. alphad_n*A*B] 01571 if (n_alpha_dot > 0) { 01572 if (gemm_AB.size() != std::size_t(m*n)) 01573 gemm_AB.resize(m*n); 01574 blas.GEMM(transa, transb, m, n, k, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0], 01575 OrdinalType(m)); 01576 if (ldc_dot == m) 01577 for (OrdinalType i=0; i<n_alpha_dot; i++) 01578 blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1), 01579 C_dot+i*ldc_dot*n, OrdinalType(1)); 01580 else 01581 for (OrdinalType i=0; i<n_alpha_dot; i++) 01582 for (OrdinalType j=0; j<n; j++) 01583 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1), 01584 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1)); 01585 } 01586 01587 // Compute [Cd_1 .. Cd_n] += alpha*[Ad_1*B .. Ad_n*B] 01588 for (OrdinalType i=0; i<n_A_dot; i++) 01589 blas.GEMM(transa, transb, m, n, k, alpha, A_dot+i*lda_dot*n_A_cols, 01590 lda_dot, B, ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot); 01591 01592 // Compute [Cd_1 .. Cd_n] += [betad_1*C .. betad_n*C] 01593 if (ldc == m && ldc_dot == m) 01594 for (OrdinalType i=0; i<n_beta_dot; i++) 01595 blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n, 01596 OrdinalType(1)); 01597 else 01598 for (OrdinalType i=0; i<n_beta_dot; i++) 01599 for (OrdinalType j=0; j<n; j++) 01600 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1), 01601 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1)); 01602 01603 // Compute C = alpha*A*B + beta*C 01604 if (n_alpha_dot > 0) { 01605 if (ldc == m) { 01606 blas.SCAL(m*n, beta, C, OrdinalType(1)); 01607 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1)); 01608 } 01609 else 01610 for (OrdinalType j=0; j<n; j++) { 01611 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1)); 01612 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc, 01613 OrdinalType(1)); 01614 } 01615 } 01616 else 01617 blas.GEMM(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); 01618 } 01619 01620 template <typename OrdinalType, typename FadType> 01621 template <typename alpha_type, typename A_type, typename B_type, 01622 typename beta_type> 01623 void 01624 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01625 Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, 01626 const OrdinalType m, 01627 const OrdinalType n, 01628 const alpha_type& alpha, 01629 const OrdinalType n_alpha_dot, 01630 const alpha_type* alpha_dot, 01631 const A_type* A, 01632 const OrdinalType lda, 01633 const OrdinalType n_A_dot, 01634 const A_type* A_dot, 01635 const OrdinalType lda_dot, 01636 const B_type* B, 01637 const OrdinalType ldb, 01638 const OrdinalType n_B_dot, 01639 const B_type* B_dot, 01640 const OrdinalType ldb_dot, 01641 const beta_type& beta, 01642 const OrdinalType n_beta_dot, 01643 const beta_type* beta_dot, 01644 ValueType* C, 01645 const OrdinalType ldc, 01646 const OrdinalType n_C_dot, 01647 ValueType* C_dot, 01648 const OrdinalType ldc_dot, 01649 const OrdinalType n_dot) const 01650 { 01651 #ifdef SACADO_DEBUG 01652 // Check sizes are consistent 01653 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 01654 (n_A_dot != n_dot && n_A_dot != 0) || 01655 (n_B_dot != n_dot && n_B_dot != 0) || 01656 (n_beta_dot != n_dot && n_beta_dot != 0) || 01657 (n_C_dot != n_dot && n_C_dot != 0), 01658 std::logic_error, 01659 "BLAS::Fad_SYMM(): All arguments must have " << 01660 "the same number of derivative components, or none"); 01661 #endif 01662 OrdinalType n_A_rows = m; 01663 OrdinalType n_A_cols = m; 01664 if (side == Teuchos::RIGHT_SIDE) { 01665 n_A_rows = n; 01666 n_A_cols = n; 01667 } 01668 01669 // Compute [Cd_1 .. Cd_n] = beta*[Cd_1 .. Cd_n] 01670 if (n_C_dot > 0) { 01671 if (ldc_dot == m) 01672 blas.SCAL(m*n*n_C_dot, beta, C_dot, OrdinalType(1)); 01673 else 01674 for (OrdinalType i=0; i<n_C_dot; i++) 01675 for (OrdinalType j=0; j<n; j++) 01676 blas.SCAL(m, beta, C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1)); 01677 } 01678 01679 // Compute [Cd_1 .. Cd_n] += alpha*A*[Bd_1 .. Bd_n] 01680 for (OrdinalType i=0; i<n_B_dot; i++) 01681 blas.SYMM(side, uplo, m, n, alpha, A, lda, B_dot+i*ldb_dot*n, 01682 ldb_dot, 1.0, C_dot+i*ldc_dot*n, ldc_dot); 01683 01684 // Compute [Cd_1 .. Cd_n] += [alphad_1*A*B .. alphad_n*A*B] 01685 if (n_alpha_dot > 0) { 01686 if (gemm_AB.size() != std::size_t(m*n)) 01687 gemm_AB.resize(m*n); 01688 blas.SYMM(side, uplo, m, n, 1.0, A, lda, B, ldb, 0.0, &gemm_AB[0], 01689 OrdinalType(m)); 01690 if (ldc_dot == m) 01691 for (OrdinalType i=0; i<n_alpha_dot; i++) 01692 blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1), 01693 C_dot+i*ldc_dot*n, OrdinalType(1)); 01694 else 01695 for (OrdinalType i=0; i<n_alpha_dot; i++) 01696 for (OrdinalType j=0; j<n; j++) 01697 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1), 01698 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1)); 01699 } 01700 01701 // Compute [Cd_1 .. Cd_n] += alpha*[Ad_1*B .. Ad_n*B] 01702 for (OrdinalType i=0; i<n_A_dot; i++) 01703 blas.SYMM(side, uplo, m, n, alpha, A_dot+i*lda_dot*n_A_cols, lda_dot, B, 01704 ldb, 1.0, C_dot+i*ldc_dot*n, ldc_dot); 01705 01706 // Compute [Cd_1 .. Cd_n] += [betad_1*C .. betad_n*C] 01707 if (ldc == m && ldc_dot == m) 01708 for (OrdinalType i=0; i<n_beta_dot; i++) 01709 blas.AXPY(m*n, beta_dot[i], C, OrdinalType(1), C_dot+i*ldc_dot*n, 01710 OrdinalType(1)); 01711 else 01712 for (OrdinalType i=0; i<n_beta_dot; i++) 01713 for (OrdinalType j=0; j<n; j++) 01714 blas.AXPY(m, beta_dot[i], C+j*ldc, OrdinalType(1), 01715 C_dot+i*ldc_dot*n+j*ldc_dot, OrdinalType(1)); 01716 01717 // Compute C = alpha*A*B + beta*C 01718 if (n_alpha_dot > 0) { 01719 if (ldc == m) { 01720 blas.SCAL(m*n, beta, C, OrdinalType(1)); 01721 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), C, OrdinalType(1)); 01722 } 01723 else 01724 for (OrdinalType j=0; j<n; j++) { 01725 blas.SCAL(m, beta, C+j*ldc, OrdinalType(1)); 01726 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), C+j*ldc, 01727 OrdinalType(1)); 01728 } 01729 } 01730 else 01731 blas.SYMM(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); 01732 } 01733 01734 template <typename OrdinalType, typename FadType> 01735 template <typename alpha_type, typename A_type> 01736 void 01737 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01738 Fad_TRMM(Teuchos::ESide side, 01739 Teuchos::EUplo uplo, 01740 Teuchos::ETransp transa, 01741 Teuchos::EDiag diag, 01742 const OrdinalType m, 01743 const OrdinalType n, 01744 const alpha_type& alpha, 01745 const OrdinalType n_alpha_dot, 01746 const alpha_type* alpha_dot, 01747 const A_type* A, 01748 const OrdinalType lda, 01749 const OrdinalType n_A_dot, 01750 const A_type* A_dot, 01751 const OrdinalType lda_dot, 01752 ValueType* B, 01753 const OrdinalType ldb, 01754 const OrdinalType n_B_dot, 01755 ValueType* B_dot, 01756 const OrdinalType ldb_dot, 01757 const OrdinalType n_dot) const 01758 { 01759 #ifdef SACADO_DEBUG 01760 // Check sizes are consistent 01761 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 01762 (n_A_dot != n_dot && n_A_dot != 0) || 01763 (n_B_dot != n_dot && n_B_dot != 0), 01764 std::logic_error, 01765 "BLAS::Fad_TRMM(): All arguments must have " << 01766 "the same number of derivative components, or none"); 01767 #endif 01768 OrdinalType n_A_rows = m; 01769 OrdinalType n_A_cols = m; 01770 if (side == Teuchos::RIGHT_SIDE) { 01771 n_A_rows = n; 01772 n_A_cols = n; 01773 } 01774 01775 // Compute [Bd_1 .. Bd_n] = alpha*A*[Bd_1 .. Bd_n] 01776 for (OrdinalType i=0; i<n_B_dot; i++) 01777 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B_dot+i*ldb_dot*n, 01778 ldb_dot); 01779 01780 // Compute [Bd_1 .. Bd_n] += [alphad_1*A*B .. alphad_n*A*B] 01781 if (n_alpha_dot > 0) { 01782 if (gemm_AB.size() != std::size_t(m*n)) 01783 gemm_AB.resize(m*n); 01784 if (ldb == m) 01785 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1)); 01786 else 01787 for (OrdinalType j=0; j<n; j++) 01788 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1)); 01789 blas.TRMM(side, uplo, transa, diag, m, n, 1.0, A, lda, &gemm_AB[0], 01790 OrdinalType(m)); 01791 if (ldb_dot == m) 01792 for (OrdinalType i=0; i<n_alpha_dot; i++) 01793 blas.AXPY(m*n, alpha_dot[i], &gemm_AB[0], OrdinalType(1), 01794 B_dot+i*ldb_dot*n, OrdinalType(1)); 01795 else 01796 for (OrdinalType i=0; i<n_alpha_dot; i++) 01797 for (OrdinalType j=0; j<n; j++) 01798 blas.AXPY(m, alpha_dot[i], &gemm_AB[j*m], OrdinalType(1), 01799 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1)); 01800 } 01801 01802 // Compute [Bd_1 .. Bd_n] += alpha*[Ad_1*B .. Ad_n*B] 01803 if (n_A_dot > 0) { 01804 if (gemm_AB.size() != std::size_t(m*n)) 01805 gemm_AB.resize(m*n); 01806 for (OrdinalType i=0; i<n_A_dot; i++) { 01807 if (ldb == m) 01808 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1)); 01809 else 01810 for (OrdinalType j=0; j<n; j++) 01811 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1)); 01812 blas.TRMM(side, uplo, transa, Teuchos::NON_UNIT_DIAG, m, n, alpha, 01813 A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0], 01814 OrdinalType(m)); 01815 if (ldb_dot == m) 01816 blas.AXPY(m*n, 1.0, &gemm_AB[0], OrdinalType(1), 01817 B_dot+i*ldb_dot*n, OrdinalType(1)); 01818 else 01819 for (OrdinalType j=0; j<n; j++) 01820 blas.AXPY(m, 1.0, &gemm_AB[j*m], OrdinalType(1), 01821 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1)); 01822 } 01823 } 01824 01825 // Compute B = alpha*A*B 01826 if (n_alpha_dot > 0 && n_A_dot == 0) { 01827 if (ldb == m) { 01828 blas.SCAL(m*n, 0.0, B, OrdinalType(1)); 01829 blas.AXPY(m*n, alpha, &gemm_AB[0], OrdinalType(1), B, OrdinalType(1)); 01830 } 01831 else 01832 for (OrdinalType j=0; j<n; j++) { 01833 blas.SCAL(m, 0.0, B+j*ldb, OrdinalType(1)); 01834 blas.AXPY(m, alpha, &gemm_AB[j*m], OrdinalType(1), B+j*ldb, 01835 OrdinalType(1)); 01836 } 01837 } 01838 else 01839 blas.TRMM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); 01840 } 01841 01842 template <typename OrdinalType, typename FadType> 01843 template <typename alpha_type, typename A_type> 01844 void 01845 Sacado::Fad::BLAS<OrdinalType,FadType>:: 01846 Fad_TRSM(Teuchos::ESide side, 01847 Teuchos::EUplo uplo, 01848 Teuchos::ETransp transa, 01849 Teuchos::EDiag diag, 01850 const OrdinalType m, 01851 const OrdinalType n, 01852 const alpha_type& alpha, 01853 const OrdinalType n_alpha_dot, 01854 const alpha_type* alpha_dot, 01855 const A_type* A, 01856 const OrdinalType lda, 01857 const OrdinalType n_A_dot, 01858 const A_type* A_dot, 01859 const OrdinalType lda_dot, 01860 ValueType* B, 01861 const OrdinalType ldb, 01862 const OrdinalType n_B_dot, 01863 ValueType* B_dot, 01864 const OrdinalType ldb_dot, 01865 const OrdinalType n_dot) const 01866 { 01867 #ifdef SACADO_DEBUG 01868 // Check sizes are consistent 01869 TEUCHOS_TEST_FOR_EXCEPTION((n_alpha_dot != n_dot && n_alpha_dot != 0) || 01870 (n_A_dot != n_dot && n_A_dot != 0) || 01871 (n_B_dot != n_dot && n_B_dot != 0), 01872 std::logic_error, 01873 "BLAS::Fad_TRSM(): All arguments must have " << 01874 "the same number of derivative components, or none"); 01875 #endif 01876 OrdinalType n_A_rows = m; 01877 OrdinalType n_A_cols = m; 01878 if (side == Teuchos::RIGHT_SIDE) { 01879 n_A_rows = n; 01880 n_A_cols = n; 01881 } 01882 01883 // Compute [Bd_1 .. Bd_n] = alpha*[Bd_1 .. Bd_n] 01884 if (n_B_dot > 0) { 01885 if (ldb_dot == m) 01886 blas.SCAL(m*n*n_B_dot, alpha, B_dot, OrdinalType(1)); 01887 else 01888 for (OrdinalType i=0; i<n_B_dot; i++) 01889 for (OrdinalType j=0; j<n; j++) 01890 blas.SCAL(m, alpha, B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1)); 01891 } 01892 01893 // Compute [Bd_1 .. Bd_n] += [alphad_1*B .. alphad_n*B] 01894 if (n_alpha_dot > 0) { 01895 if (ldb == m && ldb_dot == m) 01896 for (OrdinalType i=0; i<n_alpha_dot; i++) 01897 blas.AXPY(m*n, alpha_dot[i], B, OrdinalType(1), 01898 B_dot+i*ldb_dot*n, OrdinalType(1)); 01899 else 01900 for (OrdinalType i=0; i<n_alpha_dot; i++) 01901 for (OrdinalType j=0; j<n; j++) 01902 blas.AXPY(m, alpha_dot[i], B+j*ldb, OrdinalType(1), 01903 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1)); 01904 } 01905 01906 // Solve A*X = alpha*B 01907 blas.TRSM(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); 01908 01909 // Compute [Bd_1 .. Bd_n] -= [Ad_1*X .. Ad_n*X] 01910 if (n_A_dot > 0) { 01911 if (gemm_AB.size() != std::size_t(m*n)) 01912 gemm_AB.resize(m*n); 01913 for (OrdinalType i=0; i<n_A_dot; i++) { 01914 if (ldb == m) 01915 blas.COPY(m*n, B, OrdinalType(1), &gemm_AB[0], OrdinalType(1)); 01916 else 01917 for (OrdinalType j=0; j<n; j++) 01918 blas.COPY(m, B+j*ldb, OrdinalType(1), &gemm_AB[j*m], OrdinalType(1)); 01919 blas.TRMM(side, uplo, transa, Teuchos::NON_UNIT_DIAG, m, n, 1.0, 01920 A_dot+i*lda_dot*n_A_cols, lda_dot, &gemm_AB[0], 01921 OrdinalType(m)); 01922 if (ldb_dot == m) 01923 blas.AXPY(m*n, -1.0, &gemm_AB[0], OrdinalType(1), 01924 B_dot+i*ldb_dot*n, OrdinalType(1)); 01925 else 01926 for (OrdinalType j=0; j<n; j++) 01927 blas.AXPY(m, -1.0, &gemm_AB[j*m], OrdinalType(1), 01928 B_dot+i*ldb_dot*n+j*ldb_dot, OrdinalType(1)); 01929 } 01930 } 01931 01932 // Solve A*[Xd_1 .. Xd_n] = [Bd_1 .. Bd_n] 01933 if (side == Teuchos::LEFT_SIDE) 01934 blas.TRSM(side, uplo, transa, diag, m, n*n_dot, 1.0, A, lda, B_dot, 01935 ldb_dot); 01936 else 01937 for (OrdinalType i=0; i<n_dot; i++) 01938 blas.TRSM(side, uplo, transa, diag, m, n, 1.0, A, lda, B_dot+i*ldb_dot*n, 01939 ldb_dot); 01940 }
1.7.4