|
Teuchos - Trilinos Tools Package Version of the Day
|
00001 // @HEADER 00002 // *********************************************************************** 00003 // 00004 // Teuchos: Common Tools Package 00005 // Copyright (2004) Sandia Corporation 00006 // 00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive 00008 // license for use of this work by or on behalf of the U.S. Government. 00009 // 00010 // Redistribution and use in source and binary forms, with or without 00011 // modification, are permitted provided that the following conditions are 00012 // met: 00013 // 00014 // 1. Redistributions of source code must retain the above copyright 00015 // notice, this list of conditions and the following disclaimer. 00016 // 00017 // 2. Redistributions in binary form must reproduce the above copyright 00018 // notice, this list of conditions and the following disclaimer in the 00019 // documentation and/or other materials provided with the distribution. 00020 // 00021 // 3. Neither the name of the Corporation nor the names of the 00022 // contributors may be used to endorse or promote products derived from 00023 // this software without specific prior written permission. 00024 // 00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY 00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE 00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00036 // 00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 00038 // 00039 // *********************************************************************** 00040 // @HEADER 00041 00042 // Kris 00043 // 06.16.03 -- Start over from scratch 00044 // 06.16.03 -- Initial templatization (Tpetra_BLAS.cpp is no longer needed) 00045 // 06.18.03 -- Changed xxxxx_() function calls to XXXXX_F77() 00046 // -- Added warning messages for generic calls 00047 // 07.08.03 -- Move into Teuchos package/namespace 00048 // 07.24.03 -- The first iteration of BLAS generics is nearing completion. Caveats: 00049 // * TRSM isn't finished yet; it works for one case at the moment (left side, upper tri., no transpose, no unit diag.) 00050 // * Many of the generic implementations are quite inefficient, ugly, or both. I wrote these to be easy to debug, not for efficiency or legibility. The next iteration will improve both of these aspects as much as possible. 00051 // * Very little verification of input parameters is done, save for the character-type arguments (TRANS, etc.) which is quite robust. 00052 // * All of the routines that make use of both an incx and incy parameter (which includes much of the L1 BLAS) are set up to work iff incx == incy && incx > 0. Allowing for differing/negative values of incx/incy should be relatively trivial. 00053 // * All of the L2/L3 routines assume that the entire matrix is being used (that is, if A is mxn, lda = m); they don't work on submatrices yet. This *should* be a reasonably trivial thing to fix, as well. 00054 // -- Removed warning messages for generic calls 00055 // 08.08.03 -- TRSM now works for all cases where SIDE == L and DIAG == N. DIAG == U is implemented but does not work correctly; SIDE == R is not yet implemented. 00056 // 08.14.03 -- TRSM now works for all cases and accepts (and uses) leading-dimension information. 00057 // 09.26.03 -- character input replaced with enumerated input to cause compiling errors and not run-time errors ( suggested by RAB ). 00058 00059 #ifndef _TEUCHOS_BLAS_HPP_ 00060 #define _TEUCHOS_BLAS_HPP_ 00061 00069 /* for INTEL_CXML, the second arg may need to be changed to 'one'. If so 00070 the appropriate declaration of one will need to be added back into 00071 functions that include the macro: 00072 */ 00073 #if defined (INTEL_CXML) 00074 unsigned int one=1; 00075 #endif 00076 00077 #ifdef CHAR_MACRO 00078 #undef CHAR_MACRO 00079 #endif 00080 #if defined (INTEL_CXML) 00081 #define CHAR_MACRO(char_var) &char_var, one 00082 #else 00083 #define CHAR_MACRO(char_var) &char_var 00084 #endif 00085 00086 #include "Teuchos_ConfigDefs.hpp" 00087 #include "Teuchos_ScalarTraits.hpp" 00088 #include "Teuchos_OrdinalTraits.hpp" 00089 #include "Teuchos_BLAS_types.hpp" 00090 #include "Teuchos_Assert.hpp" 00091 00125 namespace Teuchos 00126 { 00127 extern TEUCHOS_LIB_DLL_EXPORT const char ESideChar[]; 00128 extern TEUCHOS_LIB_DLL_EXPORT const char ETranspChar[]; 00129 extern TEUCHOS_LIB_DLL_EXPORT const char EUploChar[]; 00130 extern TEUCHOS_LIB_DLL_EXPORT const char EDiagChar[]; 00131 00133 00138 template<typename OrdinalType, typename ScalarType> 00139 class DefaultBLASImpl 00140 { 00141 00142 typedef typename Teuchos::ScalarTraits<ScalarType>::magnitudeType MagnitudeType; 00143 00144 public: 00146 00147 00149 inline DefaultBLASImpl(void) {} 00150 00152 inline DefaultBLASImpl(const DefaultBLASImpl<OrdinalType, ScalarType>& /*BLAS_source*/) {} 00153 00155 inline virtual ~DefaultBLASImpl(void) {} 00157 00159 00160 00162 void ROTG(ScalarType* da, ScalarType* db, MagnitudeType* c, ScalarType* s) const; 00163 00165 void ROT(const OrdinalType n, ScalarType* dx, const OrdinalType incx, ScalarType* dy, const OrdinalType incy, MagnitudeType* c, ScalarType* s) const; 00166 00168 void SCAL(const OrdinalType n, const ScalarType alpha, ScalarType* x, const OrdinalType incx) const; 00169 00171 void COPY(const OrdinalType n, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const; 00172 00174 template <typename alpha_type, typename x_type> 00175 void AXPY(const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const; 00176 00178 typename ScalarTraits<ScalarType>::magnitudeType ASUM(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const; 00179 00181 template <typename x_type, typename y_type> 00182 ScalarType DOT(const OrdinalType n, const x_type* x, const OrdinalType incx, const y_type* y, const OrdinalType incy) const; 00183 00185 typename ScalarTraits<ScalarType>::magnitudeType NRM2(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const; 00186 00188 OrdinalType IAMAX(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const; 00189 00191 00193 00194 00196 template <typename alpha_type, typename A_type, typename x_type, typename beta_type> 00197 void GEMV(ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, 00198 const OrdinalType lda, const x_type* x, const OrdinalType incx, const beta_type beta, ScalarType* y, const OrdinalType incy) const; 00199 00201 template <typename A_type> 00202 void TRMV(EUplo uplo, ETransp trans, EDiag diag, const OrdinalType n, const A_type* A, 00203 const OrdinalType lda, ScalarType* x, const OrdinalType incx) const; 00204 00206 template <typename alpha_type, typename x_type, typename y_type> 00207 void GER(const OrdinalType m, const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, 00208 const y_type* y, const OrdinalType incy, ScalarType* A, const OrdinalType lda) const; 00210 00212 00213 00215 template <typename alpha_type, typename A_type, typename B_type, typename beta_type> 00216 void GEMM(ETransp transa, ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const; 00217 00219 template <typename alpha_type, typename A_type, typename B_type, typename beta_type> 00220 void SYMM(ESide side, EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const; 00221 00223 template <typename alpha_type, typename A_type, typename beta_type> 00224 void SYRK(EUplo uplo, ETransp trans, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const beta_type beta, ScalarType* C, const OrdinalType ldc) const; 00225 00227 template <typename alpha_type, typename A_type> 00228 void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, 00229 const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const; 00230 00232 template <typename alpha_type, typename A_type> 00233 void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, 00234 const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const; 00236 }; 00237 00238 template<typename OrdinalType, typename ScalarType> 00239 class TEUCHOS_LIB_DLL_EXPORT BLAS : public DefaultBLASImpl<OrdinalType,ScalarType> 00240 { 00241 00242 typedef typename Teuchos::ScalarTraits<ScalarType>::magnitudeType MagnitudeType; 00243 00244 public: 00246 00247 00249 inline BLAS(void) {} 00250 00252 inline BLAS(const BLAS<OrdinalType, ScalarType>& /*BLAS_source*/) {} 00253 00255 inline virtual ~BLAS(void) {} 00257 }; 00258 00259 //------------------------------------------------------------------------------------------ 00260 // LEVEL 1 BLAS ROUTINES 00261 //------------------------------------------------------------------------------------------ 00262 00270 namespace details { 00271 00272 template<typename ScalarType, bool isComplex> 00273 class GivensRotator { 00274 public: 00275 void 00276 ROTG (ScalarType* a, 00277 ScalarType* b, 00278 typename ScalarTraits<ScalarType>::magnitudeType* c, 00279 ScalarType* s) const; 00280 }; 00281 00282 // Complex-arithmetic specialization. 00283 template<typename ScalarType> 00284 class GivensRotator<ScalarType, true> { 00285 public: 00286 void 00287 ROTG (ScalarType* ca, 00288 ScalarType* cb, 00289 typename ScalarTraits<ScalarType>::magnitudeType* c, 00290 ScalarType* s) const; 00291 }; 00292 00293 // Real-arithmetic specialization. 00294 template<typename ScalarType> 00295 class GivensRotator<ScalarType, false> { 00296 public: 00297 void 00298 ROTG (ScalarType* da, 00299 ScalarType* db, 00300 ScalarType* c, 00301 ScalarType* s) const; 00302 00303 private: 00316 ScalarType SIGN (ScalarType x, ScalarType y) const { 00317 typedef ScalarTraits<ScalarType> STS; 00318 00319 if (y > STS::zero()) { 00320 return STS::magnitude (x); 00321 } else if (y < STS::zero()) { 00322 return -STS::magnitude (x); 00323 } else { // y == STS::zero() 00324 // Suppose that ScalarType implements signed zero, as IEEE 00325 // 754 - compliant floating-point numbers should. You can't 00326 // use == to test for signed zero, since +0 == -0. However, 00327 // 1/0 = Inf > 0 and 1/-0 = -Inf < 0. Let's hope ScalarType 00328 // supports Inf... we don't need to test for Inf, just see 00329 // if it's greater than or less than zero. 00330 // 00331 // NOTE: This ONLY works if ScalarType is real. Complex 00332 // infinity doesn't have a sign, so we can't compare it with 00333 // zero. That's OK, because finite complex numbers don't 00334 // have a sign either; they have an angle. 00335 ScalarType signedInfinity = STS::one() / y; 00336 if (signedInfinity > STS::zero()) { 00337 return STS::magnitude (x); 00338 } else { 00339 // Even if ScalarType doesn't implement signed zero, 00340 // Fortran's SIGN intrinsic returns -ABS(X) if the second 00341 // argument Y is zero. We imitate this behavior here. 00342 return -STS::magnitude (x); 00343 } 00344 } 00345 } 00346 }; 00347 00348 // Implementation of complex-arithmetic specialization. 00349 template<typename ScalarType> 00350 void 00351 GivensRotator<ScalarType, true>:: 00352 ROTG (ScalarType* ca, 00353 ScalarType* cb, 00354 typename ScalarTraits<ScalarType>::magnitudeType* c, 00355 ScalarType* s) const 00356 { 00357 typedef ScalarTraits<ScalarType> STS; 00358 typedef typename STS::magnitudeType MagnitudeType; 00359 typedef ScalarTraits<MagnitudeType> STM; 00360 00361 // This is a straightforward translation into C++ of the 00362 // reference BLAS' implementation of ZROTG. You can get 00363 // the Fortran 77 source code of ZROTG here: 00364 // 00365 // http://www.netlib.org/blas/zrotg.f 00366 // 00367 // I used the following rules to translate Fortran types and 00368 // intrinsic functions into C++: 00369 // 00370 // DOUBLE PRECISION -> MagnitudeType 00371 // DOUBLE COMPLEX -> ScalarType 00372 // CDABS -> STS::magnitude 00373 // DCMPLX -> ScalarType constructor (assuming that ScalarType 00374 // is std::complex<MagnitudeType>) 00375 // DCONJG -> STS::conjugate 00376 // DSQRT -> STM::squareroot 00377 ScalarType alpha; 00378 MagnitudeType norm, scale; 00379 00380 if (STS::magnitude (*ca) == STM::zero()) { 00381 *c = STM::zero(); 00382 *s = STS::one(); 00383 *ca = *cb; 00384 } else { 00385 scale = STS::magnitude (*ca) + STS::magnitude (*cb); 00386 { // I introduced temporaries into the translated BLAS code in 00387 // order to make the expression easier to read and also save a 00388 // few floating-point operations. 00389 const MagnitudeType ca_scaled = 00390 STS::magnitude (*ca / ScalarType(scale, STM::zero())); 00391 const MagnitudeType cb_scaled = 00392 STS::magnitude (*cb / ScalarType(scale, STM::zero())); 00393 norm = scale * 00394 STM::squareroot (ca_scaled*ca_scaled + cb_scaled*cb_scaled); 00395 } 00396 alpha = *ca / STS::magnitude (*ca); 00397 *c = STS::magnitude (*ca) / norm; 00398 *s = alpha * STS::conjugate (*cb) / norm; 00399 *ca = alpha * norm; 00400 } 00401 } 00402 00403 // Implementation of real-arithmetic specialization. 00404 template<typename ScalarType> 00405 void 00406 GivensRotator<ScalarType, false>:: 00407 ROTG (ScalarType* da, 00408 ScalarType* db, 00409 ScalarType* c, 00410 ScalarType* s) const 00411 { 00412 typedef ScalarTraits<ScalarType> STS; 00413 00414 // This is a straightforward translation into C++ of the 00415 // reference BLAS' implementation of DROTG. You can get 00416 // the Fortran 77 source code of DROTG here: 00417 // 00418 // http://www.netlib.org/blas/drotg.f 00419 // 00420 // I used the following rules to translate Fortran types and 00421 // intrinsic functions into C++: 00422 // 00423 // DOUBLE PRECISION -> ScalarType 00424 // DABS -> STS::magnitude 00425 // DSQRT -> STM::squareroot 00426 // DSIGN -> SIGN (see below) 00427 // 00428 // DSIGN(x,y) (the old DOUBLE PRECISION type-specific form of 00429 // the Fortran type-generic SIGN intrinsic) required special 00430 // translation, which we did in a separate utility function in 00431 // the specializaton of GivensRotator for real arithmetic. 00432 // (ROTG for complex arithmetic doesn't require this function.) 00433 // C99 provides a copysign() math library function, but we are 00434 // not able to rely on the existence of C99 functions here. 00435 ScalarType r, roe, scale, z; 00436 00437 roe = *db; 00438 if (STS::magnitude (*da) > STS::magnitude (*db)) { 00439 roe = *da; 00440 } 00441 scale = STS::magnitude (*da) + STS::magnitude (*db); 00442 if (scale == STS::zero()) { 00443 *c = STS::one(); 00444 *s = STS::zero(); 00445 r = STS::zero(); 00446 z = STS::zero(); 00447 } else { 00448 // I introduced temporaries into the translated BLAS code in 00449 // order to make the expression easier to read and also save 00450 // a few floating-point operations. 00451 const ScalarType da_scaled = *da / scale; 00452 const ScalarType db_scaled = *db / scale; 00453 r = scale * STS::squareroot (da_scaled*da_scaled + db_scaled*db_scaled); 00454 r = SIGN (STS::one(), roe) * r; 00455 *c = *da / r; 00456 *s = *db / r; 00457 z = STS::one(); 00458 if (STS::magnitude (*da) > STS::magnitude (*db)) { 00459 z = *s; 00460 } 00461 if (STS::magnitude (*db) >= STS::magnitude (*da) && *c != STS::zero()) { 00462 z = STS::one() / *c; 00463 } 00464 } 00465 00466 *da = r; 00467 *db = z; 00468 } 00469 } // namespace details 00470 00471 template<typename OrdinalType, typename ScalarType> 00472 void 00473 DefaultBLASImpl<OrdinalType, ScalarType>:: 00474 ROTG (ScalarType* da, 00475 ScalarType* db, 00476 MagnitudeType* c, 00477 ScalarType* s) const 00478 { 00479 typedef ScalarTraits<ScalarType> STS; 00480 details::GivensRotator<ScalarType, STS::isComplex> rotator; 00481 rotator.ROTG (da, db, c, s); 00482 } 00483 00484 template<typename OrdinalType, typename ScalarType> 00485 void DefaultBLASImpl<OrdinalType,ScalarType>::ROT(const OrdinalType n, ScalarType* dx, const OrdinalType incx, ScalarType* dy, const OrdinalType incy, MagnitudeType* c, ScalarType* s) const 00486 { 00487 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00488 ScalarType sconj = Teuchos::ScalarTraits<ScalarType>::conjugate(*s); 00489 if (n <= 0) return; 00490 if (incx==1 && incy==1) { 00491 for(OrdinalType i=0; i<n; ++i) { 00492 ScalarType temp = *c*dx[i] + sconj*dy[i]; 00493 dy[i] = *c*dy[i] - sconj*dx[i]; 00494 dx[i] = temp; 00495 } 00496 } 00497 else { 00498 OrdinalType ix = 0, iy = 0; 00499 if (incx < izero) ix = (-n+1)*incx; 00500 if (incy < izero) iy = (-n+1)*incy; 00501 for(OrdinalType i=0; i<n; ++i) { 00502 ScalarType temp = *c*dx[ix] + sconj*dy[iy]; 00503 dy[iy] = *c*dy[iy] - sconj*dx[ix]; 00504 dx[ix] = temp; 00505 ix += incx; 00506 iy += incy; 00507 } 00508 } 00509 } 00510 00511 template<typename OrdinalType, typename ScalarType> 00512 void DefaultBLASImpl<OrdinalType, ScalarType>::SCAL(const OrdinalType n, const ScalarType alpha, ScalarType* x, const OrdinalType incx) const 00513 { 00514 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00515 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00516 OrdinalType i, ix = izero; 00517 if ( n > izero ) { 00518 // Set the initial index (ix). 00519 if (incx < izero) { ix = (-n+ione)*incx; } 00520 // Scale the std::vector. 00521 for(i = izero; i < n; i++) 00522 { 00523 x[ix] *= alpha; 00524 ix += incx; 00525 } 00526 } 00527 } /* end SCAL */ 00528 00529 template<typename OrdinalType, typename ScalarType> 00530 void DefaultBLASImpl<OrdinalType, ScalarType>::COPY(const OrdinalType n, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const 00531 { 00532 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00533 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00534 OrdinalType i, ix = izero, iy = izero; 00535 if ( n > izero ) { 00536 // Set the initial indices (ix, iy). 00537 if (incx < izero) { ix = (-n+ione)*incx; } 00538 if (incy < izero) { iy = (-n+ione)*incy; } 00539 00540 for(i = izero; i < n; i++) 00541 { 00542 y[iy] = x[ix]; 00543 ix += incx; 00544 iy += incy; 00545 } 00546 } 00547 } /* end COPY */ 00548 00549 template<typename OrdinalType, typename ScalarType> 00550 template <typename alpha_type, typename x_type> 00551 void DefaultBLASImpl<OrdinalType, ScalarType>::AXPY(const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const 00552 { 00553 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00554 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00555 OrdinalType i, ix = izero, iy = izero; 00556 if( n > izero && alpha != ScalarTraits<alpha_type>::zero()) 00557 { 00558 // Set the initial indices (ix, iy). 00559 if (incx < izero) { ix = (-n+ione)*incx; } 00560 if (incy < izero) { iy = (-n+ione)*incy; } 00561 00562 for(i = izero; i < n; i++) 00563 { 00564 y[iy] += alpha * x[ix]; 00565 ix += incx; 00566 iy += incy; 00567 } 00568 } 00569 } /* end AXPY */ 00570 00571 template<typename OrdinalType, typename ScalarType> 00572 typename ScalarTraits<ScalarType>::magnitudeType DefaultBLASImpl<OrdinalType, ScalarType>::ASUM(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const 00573 { 00574 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00575 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00576 typename ScalarTraits<ScalarType>::magnitudeType result = 00577 ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero(); 00578 OrdinalType i, ix = izero; 00579 if( n > izero ) { 00580 // Set the initial indices 00581 if (incx < izero) { ix = (-n+ione)*incx; } 00582 00583 for(i = izero; i < n; i++) 00584 { 00585 result += ScalarTraits<ScalarType>::magnitude(x[ix]); 00586 ix += incx; 00587 } 00588 } 00589 return result; 00590 } /* end ASUM */ 00591 00592 template<typename OrdinalType, typename ScalarType> 00593 template <typename x_type, typename y_type> 00594 ScalarType DefaultBLASImpl<OrdinalType, ScalarType>::DOT(const OrdinalType n, const x_type* x, const OrdinalType incx, const y_type* y, const OrdinalType incy) const 00595 { 00596 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00597 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00598 ScalarType result = ScalarTraits<ScalarType>::zero(); 00599 OrdinalType i, ix = izero, iy = izero; 00600 if( n > izero ) 00601 { 00602 // Set the initial indices (ix,iy). 00603 if (incx < izero) { ix = (-n+ione)*incx; } 00604 if (incy < izero) { iy = (-n+ione)*incy; } 00605 00606 for(i = izero; i < n; i++) 00607 { 00608 result += ScalarTraits<x_type>::conjugate(x[ix]) * y[iy]; 00609 ix += incx; 00610 iy += incy; 00611 } 00612 } 00613 return result; 00614 } /* end DOT */ 00615 00616 template<typename OrdinalType, typename ScalarType> 00617 typename ScalarTraits<ScalarType>::magnitudeType DefaultBLASImpl<OrdinalType, ScalarType>::NRM2(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const 00618 { 00619 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00620 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00621 typename ScalarTraits<ScalarType>::magnitudeType result = 00622 ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero(); 00623 OrdinalType i, ix = izero; 00624 if ( n > izero ) 00625 { 00626 // Set the initial index. 00627 if (incx < izero) { ix = (-n+ione)*incx; } 00628 00629 for(i = izero; i < n; i++) 00630 { 00631 result += ScalarTraits<ScalarType>::magnitude(ScalarTraits<ScalarType>::conjugate(x[ix]) * x[ix]); 00632 ix += incx; 00633 } 00634 result = ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::squareroot(result); 00635 } 00636 return result; 00637 } /* end NRM2 */ 00638 00639 template<typename OrdinalType, typename ScalarType> 00640 OrdinalType DefaultBLASImpl<OrdinalType, ScalarType>::IAMAX(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const 00641 { 00642 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00643 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00644 OrdinalType result = izero, ix = izero, i; 00645 typename ScalarTraits<ScalarType>::magnitudeType maxval = 00646 ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero(); 00647 00648 if ( n > izero ) 00649 { 00650 if (incx < izero) { ix = (-n+ione)*incx; } 00651 maxval = ScalarTraits<ScalarType>::magnitude(x[ix]); 00652 ix += incx; 00653 for(i = ione; i < n; i++) 00654 { 00655 if(ScalarTraits<ScalarType>::magnitude(x[ix]) > maxval) 00656 { 00657 result = i; 00658 maxval = ScalarTraits<ScalarType>::magnitude(x[ix]); 00659 } 00660 ix += incx; 00661 } 00662 } 00663 return result + 1; // the BLAS I?AMAX functions return 1-indexed (Fortran-style) values 00664 } /* end IAMAX */ 00665 00666 //------------------------------------------------------------------------------------------ 00667 // LEVEL 2 BLAS ROUTINES 00668 //------------------------------------------------------------------------------------------ 00669 template<typename OrdinalType, typename ScalarType> 00670 template <typename alpha_type, typename A_type, typename x_type, typename beta_type> 00671 void DefaultBLASImpl<OrdinalType, ScalarType>::GEMV(ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, const x_type* x, const OrdinalType incx, const beta_type beta, ScalarType* y, const OrdinalType incy) const 00672 { 00673 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00674 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00675 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 00676 beta_type beta_zero = ScalarTraits<beta_type>::zero(); 00677 x_type x_zero = ScalarTraits<x_type>::zero(); 00678 ScalarType y_zero = ScalarTraits<ScalarType>::zero(); 00679 beta_type beta_one = ScalarTraits<beta_type>::one(); 00680 bool noConj = true; 00681 bool BadArgument = false; 00682 00683 // Quick return if there is nothing to do! 00684 if( m == izero || n == izero || ( alpha == alpha_zero && beta == beta_one ) ){ return; } 00685 00686 // Otherwise, we need to check the argument list. 00687 if( m < izero ) { 00688 std::cout << "BLAS::GEMV Error: M == " << m << std::endl; 00689 BadArgument = true; 00690 } 00691 if( n < izero ) { 00692 std::cout << "BLAS::GEMV Error: N == " << n << std::endl; 00693 BadArgument = true; 00694 } 00695 if( lda < m ) { 00696 std::cout << "BLAS::GEMV Error: LDA < MAX(1,M)"<< std::endl; 00697 BadArgument = true; 00698 } 00699 if( incx == izero ) { 00700 std::cout << "BLAS::GEMV Error: INCX == 0"<< std::endl; 00701 BadArgument = true; 00702 } 00703 if( incy == izero ) { 00704 std::cout << "BLAS::GEMV Error: INCY == 0"<< std::endl; 00705 BadArgument = true; 00706 } 00707 00708 if(!BadArgument) { 00709 OrdinalType i, j, lenx, leny, ix, iy, jx, jy; 00710 OrdinalType kx = izero, ky = izero; 00711 ScalarType temp; 00712 00713 // Determine the lengths of the vectors x and y. 00714 if(ETranspChar[trans] == 'N') { 00715 lenx = n; 00716 leny = m; 00717 } else { 00718 lenx = m; 00719 leny = n; 00720 } 00721 00722 // Determine if this is a conjugate tranpose 00723 noConj = (ETranspChar[trans] == 'T'); 00724 00725 // Set the starting pointers for the vectors x and y if incx/y < 0. 00726 if (incx < izero ) { kx = (ione - lenx)*incx; } 00727 if (incy < izero ) { ky = (ione - leny)*incy; } 00728 00729 // Form y = beta*y 00730 ix = kx; iy = ky; 00731 if(beta != beta_one) { 00732 if (incy == ione) { 00733 if (beta == beta_zero) { 00734 for(i = izero; i < leny; i++) { y[i] = y_zero; } 00735 } else { 00736 for(i = izero; i < leny; i++) { y[i] *= beta; } 00737 } 00738 } else { 00739 if (beta == beta_zero) { 00740 for(i = izero; i < leny; i++) { 00741 y[iy] = y_zero; 00742 iy += incy; 00743 } 00744 } else { 00745 for(i = izero; i < leny; i++) { 00746 y[iy] *= beta; 00747 iy += incy; 00748 } 00749 } 00750 } 00751 } 00752 00753 // Return if we don't have to do anything more. 00754 if(alpha == alpha_zero) { return; } 00755 00756 if( ETranspChar[trans] == 'N' ) { 00757 // Form y = alpha*A*y 00758 jx = kx; 00759 if (incy == ione) { 00760 for(j = izero; j < n; j++) { 00761 if (x[jx] != x_zero) { 00762 temp = alpha*x[jx]; 00763 for(i = izero; i < m; i++) { 00764 y[i] += temp*A[j*lda + i]; 00765 } 00766 } 00767 jx += incx; 00768 } 00769 } else { 00770 for(j = izero; j < n; j++) { 00771 if (x[jx] != x_zero) { 00772 temp = alpha*x[jx]; 00773 iy = ky; 00774 for(i = izero; i < m; i++) { 00775 y[iy] += temp*A[j*lda + i]; 00776 iy += incy; 00777 } 00778 } 00779 jx += incx; 00780 } 00781 } 00782 } else { 00783 jy = ky; 00784 if (incx == ione) { 00785 for(j = izero; j < n; j++) { 00786 temp = y_zero; 00787 if ( noConj ) { 00788 for(i = izero; i < m; i++) { 00789 temp += A[j*lda + i]*x[i]; 00790 } 00791 } else { 00792 for(i = izero; i < m; i++) { 00793 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[i]; 00794 } 00795 } 00796 y[jy] += alpha*temp; 00797 jy += incy; 00798 } 00799 } else { 00800 for(j = izero; j < n; j++) { 00801 temp = y_zero; 00802 ix = kx; 00803 if ( noConj ) { 00804 for (i = izero; i < m; i++) { 00805 temp += A[j*lda + i]*x[ix]; 00806 ix += incx; 00807 } 00808 } else { 00809 for (i = izero; i < m; i++) { 00810 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[ix]; 00811 ix += incx; 00812 } 00813 } 00814 y[jy] += alpha*temp; 00815 jy += incy; 00816 } 00817 } 00818 } 00819 } /* if (!BadArgument) */ 00820 } /* end GEMV */ 00821 00822 template<typename OrdinalType, typename ScalarType> 00823 template <typename A_type> 00824 void DefaultBLASImpl<OrdinalType, ScalarType>::TRMV(EUplo uplo, ETransp trans, EDiag diag, const OrdinalType n, const A_type* A, const OrdinalType lda, ScalarType* x, const OrdinalType incx) const 00825 { 00826 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 00827 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 00828 ScalarType zero = ScalarTraits<ScalarType>::zero(); 00829 bool BadArgument = false; 00830 bool noConj = true; 00831 00832 // Quick return if there is nothing to do! 00833 if( n == izero ){ return; } 00834 00835 // Otherwise, we need to check the argument list. 00836 if( n < izero ) { 00837 std::cout << "BLAS::TRMV Error: N == " << n << std::endl; 00838 BadArgument = true; 00839 } 00840 if( lda < n ) { 00841 std::cout << "BLAS::TRMV Error: LDA < MAX(1,N)"<< std::endl; 00842 BadArgument = true; 00843 } 00844 if( incx == izero ) { 00845 std::cout << "BLAS::TRMV Error: INCX == 0"<< std::endl; 00846 BadArgument = true; 00847 } 00848 00849 if(!BadArgument) { 00850 OrdinalType i, j, ix, jx, kx = izero; 00851 ScalarType temp; 00852 bool noUnit = (EDiagChar[diag] == 'N'); 00853 00854 // Determine if this is a conjugate tranpose 00855 noConj = (ETranspChar[trans] == 'T'); 00856 00857 // Set the starting pointer for the std::vector x if incx < 0. 00858 if (incx < izero) { kx = (-n+ione)*incx; } 00859 00860 // Start the operations for a nontransposed triangular matrix 00861 if (ETranspChar[trans] == 'N') { 00862 /* Compute x = A*x */ 00863 if (EUploChar[uplo] == 'U') { 00864 /* A is an upper triangular matrix */ 00865 if (incx == ione) { 00866 for (j=izero; j<n; j++) { 00867 if (x[j] != zero) { 00868 temp = x[j]; 00869 for (i=izero; i<j; i++) { 00870 x[i] += temp*A[j*lda + i]; 00871 } 00872 if ( noUnit ) 00873 x[j] *= A[j*lda + j]; 00874 } 00875 } 00876 } else { 00877 jx = kx; 00878 for (j=izero; j<n; j++) { 00879 if (x[jx] != zero) { 00880 temp = x[jx]; 00881 ix = kx; 00882 for (i=izero; i<j; i++) { 00883 x[ix] += temp*A[j*lda + i]; 00884 ix += incx; 00885 } 00886 if ( noUnit ) 00887 x[jx] *= A[j*lda + j]; 00888 } 00889 jx += incx; 00890 } 00891 } /* if (incx == ione) */ 00892 } else { /* A is a lower triangular matrix */ 00893 if (incx == ione) { 00894 for (j=n-ione; j>-ione; j--) { 00895 if (x[j] != zero) { 00896 temp = x[j]; 00897 for (i=n-ione; i>j; i--) { 00898 x[i] += temp*A[j*lda + i]; 00899 } 00900 if ( noUnit ) 00901 x[j] *= A[j*lda + j]; 00902 } 00903 } 00904 } else { 00905 kx += (n-ione)*incx; 00906 jx = kx; 00907 for (j=n-ione; j>-ione; j--) { 00908 if (x[jx] != zero) { 00909 temp = x[jx]; 00910 ix = kx; 00911 for (i=n-ione; i>j; i--) { 00912 x[ix] += temp*A[j*lda + i]; 00913 ix -= incx; 00914 } 00915 if ( noUnit ) 00916 x[jx] *= A[j*lda + j]; 00917 } 00918 jx -= incx; 00919 } 00920 } 00921 } /* if (EUploChar[uplo]=='U') */ 00922 } else { /* A is transposed/conjugated */ 00923 /* Compute x = A'*x */ 00924 if (EUploChar[uplo]=='U') { 00925 /* A is an upper triangular matrix */ 00926 if (incx == ione) { 00927 for (j=n-ione; j>-ione; j--) { 00928 temp = x[j]; 00929 if ( noConj ) { 00930 if ( noUnit ) 00931 temp *= A[j*lda + j]; 00932 for (i=j-ione; i>-ione; i--) { 00933 temp += A[j*lda + i]*x[i]; 00934 } 00935 } else { 00936 if ( noUnit ) 00937 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 00938 for (i=j-ione; i>-ione; i--) { 00939 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[i]; 00940 } 00941 } 00942 x[j] = temp; 00943 } 00944 } else { 00945 jx = kx + (n-ione)*incx; 00946 for (j=n-ione; j>-ione; j--) { 00947 temp = x[jx]; 00948 ix = jx; 00949 if ( noConj ) { 00950 if ( noUnit ) 00951 temp *= A[j*lda + j]; 00952 for (i=j-ione; i>-ione; i--) { 00953 ix -= incx; 00954 temp += A[j*lda + i]*x[ix]; 00955 } 00956 } else { 00957 if ( noUnit ) 00958 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 00959 for (i=j-ione; i>-ione; i--) { 00960 ix -= incx; 00961 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[ix]; 00962 } 00963 } 00964 x[jx] = temp; 00965 jx -= incx; 00966 } 00967 } 00968 } else { 00969 /* A is a lower triangular matrix */ 00970 if (incx == ione) { 00971 for (j=izero; j<n; j++) { 00972 temp = x[j]; 00973 if ( noConj ) { 00974 if ( noUnit ) 00975 temp *= A[j*lda + j]; 00976 for (i=j+ione; i<n; i++) { 00977 temp += A[j*lda + i]*x[i]; 00978 } 00979 } else { 00980 if ( noUnit ) 00981 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 00982 for (i=j+ione; i<n; i++) { 00983 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[i]; 00984 } 00985 } 00986 x[j] = temp; 00987 } 00988 } else { 00989 jx = kx; 00990 for (j=izero; j<n; j++) { 00991 temp = x[jx]; 00992 ix = jx; 00993 if ( noConj ) { 00994 if ( noUnit ) 00995 temp *= A[j*lda + j]; 00996 for (i=j+ione; i<n; i++) { 00997 ix += incx; 00998 temp += A[j*lda + i]*x[ix]; 00999 } 01000 } else { 01001 if ( noUnit ) 01002 temp *= ScalarTraits<A_type>::conjugate(A[j*lda + j]); 01003 for (i=j+ione; i<n; i++) { 01004 ix += incx; 01005 temp += ScalarTraits<A_type>::conjugate(A[j*lda + i])*x[ix]; 01006 } 01007 } 01008 x[jx] = temp; 01009 jx += incx; 01010 } 01011 } 01012 } /* if (EUploChar[uplo]=='U') */ 01013 } /* if (ETranspChar[trans]=='N') */ 01014 } /* if (!BadArgument) */ 01015 } /* end TRMV */ 01016 01017 template<typename OrdinalType, typename ScalarType> 01018 template <typename alpha_type, typename x_type, typename y_type> 01019 void DefaultBLASImpl<OrdinalType, ScalarType>::GER(const OrdinalType m, const OrdinalType n, const alpha_type alpha, const x_type* x, const OrdinalType incx, const y_type* y, const OrdinalType incy, ScalarType* A, const OrdinalType lda) const 01020 { 01021 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01022 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 01023 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01024 y_type y_zero = ScalarTraits<y_type>::zero(); 01025 bool BadArgument = false; 01026 01027 TEUCHOS_TEST_FOR_EXCEPTION(Teuchos::ScalarTraits<ScalarType>::isComplex, std::logic_error, 01028 "Teuchos::BLAS::GER() does not currently support complex data types."); 01029 01030 // Quick return if there is nothing to do! 01031 if( m == izero || n == izero || alpha == alpha_zero ){ return; } 01032 01033 // Otherwise, we need to check the argument list. 01034 if( m < izero ) { 01035 std::cout << "BLAS::GER Error: M == " << m << std::endl; 01036 BadArgument = true; 01037 } 01038 if( n < izero ) { 01039 std::cout << "BLAS::GER Error: N == " << n << std::endl; 01040 BadArgument = true; 01041 } 01042 if( lda < m ) { 01043 std::cout << "BLAS::GER Error: LDA < MAX(1,M)"<< std::endl; 01044 BadArgument = true; 01045 } 01046 if( incx == 0 ) { 01047 std::cout << "BLAS::GER Error: INCX == 0"<< std::endl; 01048 BadArgument = true; 01049 } 01050 if( incy == 0 ) { 01051 std::cout << "BLAS::GER Error: INCY == 0"<< std::endl; 01052 BadArgument = true; 01053 } 01054 01055 if(!BadArgument) { 01056 OrdinalType i, j, ix, jy = izero, kx = izero; 01057 ScalarType temp; 01058 01059 // Set the starting pointers for the vectors x and y if incx/y < 0. 01060 if (incx < izero) { kx = (-m+ione)*incx; } 01061 if (incy < izero) { jy = (-n+ione)*incy; } 01062 01063 // Start the operations for incx == 1 01064 if( incx == ione ) { 01065 for( j=izero; j<n; j++ ) { 01066 if ( y[jy] != y_zero ) { 01067 temp = alpha*y[jy]; 01068 for ( i=izero; i<m; i++ ) { 01069 A[j*lda + i] += x[i]*temp; 01070 } 01071 } 01072 jy += incy; 01073 } 01074 } 01075 else { // Start the operations for incx != 1 01076 for( j=izero; j<n; j++ ) { 01077 if ( y[jy] != y_zero ) { 01078 temp = alpha*y[jy]; 01079 ix = kx; 01080 for( i=izero; i<m; i++ ) { 01081 A[j*lda + i] += x[ix]*temp; 01082 ix += incx; 01083 } 01084 } 01085 jy += incy; 01086 } 01087 } 01088 } /* if(!BadArgument) */ 01089 } /* end GER */ 01090 01091 //------------------------------------------------------------------------------------------ 01092 // LEVEL 3 BLAS ROUTINES 01093 //------------------------------------------------------------------------------------------ 01094 01095 template<typename OrdinalType, typename ScalarType> 01096 template <typename alpha_type, typename A_type, typename B_type, typename beta_type> 01097 void DefaultBLASImpl<OrdinalType, ScalarType>::GEMM(ETransp transa, ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const 01098 { 01099 01100 typedef TypeNameTraits<OrdinalType> OTNT; 01101 typedef TypeNameTraits<ScalarType> STNT; 01102 01103 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01104 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01105 beta_type beta_zero = ScalarTraits<beta_type>::zero(); 01106 B_type B_zero = ScalarTraits<B_type>::zero(); 01107 ScalarType C_zero = ScalarTraits<ScalarType>::zero(); 01108 beta_type beta_one = ScalarTraits<beta_type>::one(); 01109 OrdinalType i, j, p; 01110 OrdinalType NRowA = m, NRowB = k; 01111 ScalarType temp; 01112 bool BadArgument = false; 01113 bool conjA = false, conjB = false; 01114 01115 // Change dimensions of matrix if either matrix is transposed 01116 if( !(ETranspChar[transa]=='N') ) { 01117 NRowA = k; 01118 } 01119 if( !(ETranspChar[transb]=='N') ) { 01120 NRowB = n; 01121 } 01122 01123 // Quick return if there is nothing to do! 01124 if( (m==izero) || (n==izero) || (((alpha==alpha_zero)||(k==izero)) && (beta==beta_one)) ){ return; } 01125 if( m < izero ) { 01126 std::cout << "BLAS::GEMM Error: M == " << m << std::endl; 01127 BadArgument = true; 01128 } 01129 if( n < izero ) { 01130 std::cout << "BLAS::GEMM Error: N == " << n << std::endl; 01131 BadArgument = true; 01132 } 01133 if( k < izero ) { 01134 std::cout << "BLAS::GEMM Error: K == " << k << std::endl; 01135 BadArgument = true; 01136 } 01137 if( lda < NRowA ) { 01138 std::cout << "BLAS::GEMM Error: LDA < "<<NRowA<<std::endl; 01139 BadArgument = true; 01140 } 01141 if( ldb < NRowB ) { 01142 std::cout << "BLAS::GEMM Error: LDB < "<<NRowB<<std::endl; 01143 BadArgument = true; 01144 } 01145 if( ldc < m ) { 01146 std::cout << "BLAS::GEMM Error: LDC < MAX(1,M)"<< std::endl; 01147 BadArgument = true; 01148 } 01149 01150 if(!BadArgument) { 01151 01152 // Determine if this is a conjugate tranpose 01153 conjA = (ETranspChar[transa] == 'C'); 01154 conjB = (ETranspChar[transb] == 'C'); 01155 01156 // Only need to scale the resulting matrix C. 01157 if( alpha == alpha_zero ) { 01158 if( beta == beta_zero ) { 01159 for (j=izero; j<n; j++) { 01160 for (i=izero; i<m; i++) { 01161 C[j*ldc + i] = C_zero; 01162 } 01163 } 01164 } else { 01165 for (j=izero; j<n; j++) { 01166 for (i=izero; i<m; i++) { 01167 C[j*ldc + i] *= beta; 01168 } 01169 } 01170 } 01171 return; 01172 } 01173 // 01174 // Now start the operations. 01175 // 01176 if ( ETranspChar[transb]=='N' ) { 01177 if ( ETranspChar[transa]=='N' ) { 01178 // Compute C = alpha*A*B + beta*C 01179 for (j=izero; j<n; j++) { 01180 if( beta == beta_zero ) { 01181 for (i=izero; i<m; i++){ 01182 C[j*ldc + i] = C_zero; 01183 } 01184 } else if( beta != beta_one ) { 01185 for (i=izero; i<m; i++){ 01186 C[j*ldc + i] *= beta; 01187 } 01188 } 01189 for (p=izero; p<k; p++){ 01190 if (B[j*ldb + p] != B_zero ){ 01191 temp = alpha*B[j*ldb + p]; 01192 for (i=izero; i<m; i++) { 01193 C[j*ldc + i] += temp*A[p*lda + i]; 01194 } 01195 } 01196 } 01197 } 01198 } else if ( conjA ) { 01199 // Compute C = alpha*conj(A')*B + beta*C 01200 for (j=izero; j<n; j++) { 01201 for (i=izero; i<m; i++) { 01202 temp = C_zero; 01203 for (p=izero; p<k; p++) { 01204 temp += ScalarTraits<A_type>::conjugate(A[i*lda + p])*B[j*ldb + p]; 01205 } 01206 if (beta == beta_zero) { 01207 C[j*ldc + i] = alpha*temp; 01208 } else { 01209 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01210 } 01211 } 01212 } 01213 } else { 01214 // Compute C = alpha*A'*B + beta*C 01215 for (j=izero; j<n; j++) { 01216 for (i=izero; i<m; i++) { 01217 temp = C_zero; 01218 for (p=izero; p<k; p++) { 01219 temp += A[i*lda + p]*B[j*ldb + p]; 01220 } 01221 if (beta == beta_zero) { 01222 C[j*ldc + i] = alpha*temp; 01223 } else { 01224 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01225 } 01226 } 01227 } 01228 } 01229 } else if ( ETranspChar[transa]=='N' ) { 01230 if ( conjB ) { 01231 // Compute C = alpha*A*conj(B') + beta*C 01232 for (j=izero; j<n; j++) { 01233 if (beta == beta_zero) { 01234 for (i=izero; i<m; i++) { 01235 C[j*ldc + i] = C_zero; 01236 } 01237 } else if ( beta != beta_one ) { 01238 for (i=izero; i<m; i++) { 01239 C[j*ldc + i] *= beta; 01240 } 01241 } 01242 for (p=izero; p<k; p++) { 01243 if (B[p*ldb + j] != B_zero) { 01244 temp = alpha*ScalarTraits<B_type>::conjugate(B[p*ldb + j]); 01245 for (i=izero; i<m; i++) { 01246 C[j*ldc + i] += temp*A[p*lda + i]; 01247 } 01248 } 01249 } 01250 } 01251 } else { 01252 // Compute C = alpha*A*B' + beta*C 01253 for (j=izero; j<n; j++) { 01254 if (beta == beta_zero) { 01255 for (i=izero; i<m; i++) { 01256 C[j*ldc + i] = C_zero; 01257 } 01258 } else if ( beta != beta_one ) { 01259 for (i=izero; i<m; i++) { 01260 C[j*ldc + i] *= beta; 01261 } 01262 } 01263 for (p=izero; p<k; p++) { 01264 if (B[p*ldb + j] != B_zero) { 01265 temp = alpha*B[p*ldb + j]; 01266 for (i=izero; i<m; i++) { 01267 C[j*ldc + i] += temp*A[p*lda + i]; 01268 } 01269 } 01270 } 01271 } 01272 } 01273 } else if ( conjA ) { 01274 if ( conjB ) { 01275 // Compute C = alpha*conj(A')*conj(B') + beta*C 01276 for (j=izero; j<n; j++) { 01277 for (i=izero; i<m; i++) { 01278 temp = C_zero; 01279 for (p=izero; p<k; p++) { 01280 temp += ScalarTraits<A_type>::conjugate(A[i*lda + p]) 01281 * ScalarTraits<B_type>::conjugate(B[p*ldb + j]); 01282 } 01283 if (beta == beta_zero) { 01284 C[j*ldc + i] = alpha*temp; 01285 } else { 01286 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01287 } 01288 } 01289 } 01290 } else { 01291 // Compute C = alpha*conj(A')*B' + beta*C 01292 for (j=izero; j<n; j++) { 01293 for (i=izero; i<m; i++) { 01294 temp = C_zero; 01295 for (p=izero; p<k; p++) { 01296 temp += ScalarTraits<A_type>::conjugate(A[i*lda + p]) 01297 * B[p*ldb + j]; 01298 } 01299 if (beta == beta_zero) { 01300 C[j*ldc + i] = alpha*temp; 01301 } else { 01302 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01303 } 01304 } 01305 } 01306 } 01307 } else { 01308 if ( conjB ) { 01309 // Compute C = alpha*A'*conj(B') + beta*C 01310 for (j=izero; j<n; j++) { 01311 for (i=izero; i<m; i++) { 01312 temp = C_zero; 01313 for (p=izero; p<k; p++) { 01314 temp += A[i*lda + p] 01315 * ScalarTraits<B_type>::conjugate(B[p*ldb + j]); 01316 } 01317 if (beta == beta_zero) { 01318 C[j*ldc + i] = alpha*temp; 01319 } else { 01320 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01321 } 01322 } 01323 } 01324 } else { 01325 // Compute C = alpha*A'*B' + beta*C 01326 for (j=izero; j<n; j++) { 01327 for (i=izero; i<m; i++) { 01328 temp = C_zero; 01329 for (p=izero; p<k; p++) { 01330 temp += A[i*lda + p]*B[p*ldb + j]; 01331 } 01332 if (beta == beta_zero) { 01333 C[j*ldc + i] = alpha*temp; 01334 } else { 01335 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01336 } 01337 } 01338 } 01339 } // end if (ETranspChar[transa]=='N') ... 01340 } // end if (ETranspChar[transb]=='N') ... 01341 } // end if (!BadArgument) ... 01342 } // end of GEMM 01343 01344 01345 template<typename OrdinalType, typename ScalarType> 01346 template <typename alpha_type, typename A_type, typename B_type, typename beta_type> 01347 void DefaultBLASImpl<OrdinalType, ScalarType>::SYMM(ESide side, EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, const B_type* B, const OrdinalType ldb, const beta_type beta, ScalarType* C, const OrdinalType ldc) const 01348 { 01349 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01350 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 01351 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01352 beta_type beta_zero = ScalarTraits<beta_type>::zero(); 01353 ScalarType C_zero = ScalarTraits<ScalarType>::zero(); 01354 beta_type beta_one = ScalarTraits<beta_type>::one(); 01355 OrdinalType i, j, k, NRowA = m; 01356 ScalarType temp1, temp2; 01357 bool BadArgument = false; 01358 bool Upper = (EUploChar[uplo] == 'U'); 01359 if (ESideChar[side] == 'R') { NRowA = n; } 01360 01361 // Quick return. 01362 if ( (m==izero) || (n==izero) || ( (alpha==alpha_zero)&&(beta==beta_one) ) ) { return; } 01363 if( m < izero ) { 01364 std::cout << "BLAS::SYMM Error: M == "<< m << std::endl; 01365 BadArgument = true; } 01366 if( n < izero ) { 01367 std::cout << "BLAS::SYMM Error: N == "<< n << std::endl; 01368 BadArgument = true; } 01369 if( lda < NRowA ) { 01370 std::cout << "BLAS::SYMM Error: LDA < "<<NRowA<<std::endl; 01371 BadArgument = true; } 01372 if( ldb < m ) { 01373 std::cout << "BLAS::SYMM Error: LDB < MAX(1,M)"<<std::endl; 01374 BadArgument = true; } 01375 if( ldc < m ) { 01376 std::cout << "BLAS::SYMM Error: LDC < MAX(1,M)"<<std::endl; 01377 BadArgument = true; } 01378 01379 if(!BadArgument) { 01380 01381 // Only need to scale C and return. 01382 if (alpha == alpha_zero) { 01383 if (beta == beta_zero ) { 01384 for (j=izero; j<n; j++) { 01385 for (i=izero; i<m; i++) { 01386 C[j*ldc + i] = C_zero; 01387 } 01388 } 01389 } else { 01390 for (j=izero; j<n; j++) { 01391 for (i=izero; i<m; i++) { 01392 C[j*ldc + i] *= beta; 01393 } 01394 } 01395 } 01396 return; 01397 } 01398 01399 if ( ESideChar[side] == 'L') { 01400 // Compute C = alpha*A*B + beta*C 01401 01402 if (Upper) { 01403 // The symmetric part of A is stored in the upper triangular part of the matrix. 01404 for (j=izero; j<n; j++) { 01405 for (i=izero; i<m; i++) { 01406 temp1 = alpha*B[j*ldb + i]; 01407 temp2 = C_zero; 01408 for (k=izero; k<i; k++) { 01409 C[j*ldc + k] += temp1*A[i*lda + k]; 01410 temp2 += B[j*ldb + k]*A[i*lda + k]; 01411 } 01412 if (beta == beta_zero) { 01413 C[j*ldc + i] = temp1*A[i*lda + i] + alpha*temp2; 01414 } else { 01415 C[j*ldc + i] = beta*C[j*ldc + i] + temp1*A[i*lda + i] + alpha*temp2; 01416 } 01417 } 01418 } 01419 } else { 01420 // The symmetric part of A is stored in the lower triangular part of the matrix. 01421 for (j=izero; j<n; j++) { 01422 for (i=m-ione; i>-ione; i--) { 01423 temp1 = alpha*B[j*ldb + i]; 01424 temp2 = C_zero; 01425 for (k=i+ione; k<m; k++) { 01426 C[j*ldc + k] += temp1*A[i*lda + k]; 01427 temp2 += B[j*ldb + k]*A[i*lda + k]; 01428 } 01429 if (beta == beta_zero) { 01430 C[j*ldc + i] = temp1*A[i*lda + i] + alpha*temp2; 01431 } else { 01432 C[j*ldc + i] = beta*C[j*ldc + i] + temp1*A[i*lda + i] + alpha*temp2; 01433 } 01434 } 01435 } 01436 } 01437 } else { 01438 // Compute C = alpha*B*A + beta*C. 01439 for (j=izero; j<n; j++) { 01440 temp1 = alpha*A[j*lda + j]; 01441 if (beta == beta_zero) { 01442 for (i=izero; i<m; i++) { 01443 C[j*ldc + i] = temp1*B[j*ldb + i]; 01444 } 01445 } else { 01446 for (i=izero; i<m; i++) { 01447 C[j*ldc + i] = beta*C[j*ldc + i] + temp1*B[j*ldb + i]; 01448 } 01449 } 01450 for (k=izero; k<j; k++) { 01451 if (Upper) { 01452 temp1 = alpha*A[j*lda + k]; 01453 } else { 01454 temp1 = alpha*A[k*lda + j]; 01455 } 01456 for (i=izero; i<m; i++) { 01457 C[j*ldc + i] += temp1*B[k*ldb + i]; 01458 } 01459 } 01460 for (k=j+ione; k<n; k++) { 01461 if (Upper) { 01462 temp1 = alpha*A[k*lda + j]; 01463 } else { 01464 temp1 = alpha*A[j*lda + k]; 01465 } 01466 for (i=izero; i<m; i++) { 01467 C[j*ldc + i] += temp1*B[k*ldb + i]; 01468 } 01469 } 01470 } 01471 } // end if (ESideChar[side]=='L') ... 01472 } // end if(!BadArgument) ... 01473 } // end SYMM 01474 01475 template<typename OrdinalType, typename ScalarType> 01476 template <typename alpha_type, typename A_type, typename beta_type> 01477 void DefaultBLASImpl<OrdinalType, ScalarType>::SYRK(EUplo uplo, ETransp trans, const OrdinalType n, const OrdinalType k, const alpha_type alpha, const A_type* A, const OrdinalType lda, const beta_type beta, ScalarType* C, const OrdinalType ldc) const 01478 { 01479 typedef TypeNameTraits<OrdinalType> OTNT; 01480 typedef TypeNameTraits<ScalarType> STNT; 01481 01482 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01483 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01484 beta_type beta_zero = ScalarTraits<beta_type>::zero(); 01485 beta_type beta_one = ScalarTraits<beta_type>::one(); 01486 A_type temp, A_zero = ScalarTraits<A_type>::zero(); 01487 ScalarType C_zero = ScalarTraits<ScalarType>::zero(); 01488 OrdinalType i, j, l, NRowA = n; 01489 bool BadArgument = false; 01490 bool Upper = (EUploChar[uplo] == 'U'); 01491 01492 TEUCHOS_TEST_FOR_EXCEPTION( 01493 Teuchos::ScalarTraits<ScalarType>::isComplex 01494 && (trans == CONJ_TRANS), 01495 std::logic_error, 01496 "Teuchos::BLAS<"<<OTNT::name()<<","<<STNT::name()<<">::SYRK()" 01497 " does not support CONJ_TRANS for complex data types." 01498 ); 01499 01500 // Change dimensions of A matrix is transposed 01501 if( !(ETranspChar[trans]=='N') ) { 01502 NRowA = k; 01503 } 01504 01505 // Quick return. 01506 if ( n==izero ) { return; } 01507 if ( ( (alpha==alpha_zero) || (k==izero) ) && (beta==beta_one) ) { return; } 01508 if( n < izero ) { 01509 std::cout << "BLAS::SYRK Error: N == "<< n <<std::endl; 01510 BadArgument = true; } 01511 if( k < izero ) { 01512 std::cout << "BLAS::SYRK Error: K == "<< k <<std::endl; 01513 BadArgument = true; } 01514 if( lda < NRowA ) { 01515 std::cout << "BLAS::SYRK Error: LDA < "<<NRowA<<std::endl; 01516 BadArgument = true; } 01517 if( ldc < n ) { 01518 std::cout << "BLAS::SYRK Error: LDC < MAX(1,N)"<<std::endl; 01519 BadArgument = true; } 01520 01521 if(!BadArgument) { 01522 01523 // Scale C when alpha is zero 01524 if (alpha == alpha_zero) { 01525 if (Upper) { 01526 if (beta==beta_zero) { 01527 for (j=izero; j<n; j++) { 01528 for (i=izero; i<=j; i++) { 01529 C[j*ldc + i] = C_zero; 01530 } 01531 } 01532 } 01533 else { 01534 for (j=izero; j<n; j++) { 01535 for (i=izero; i<=j; i++) { 01536 C[j*ldc + i] *= beta; 01537 } 01538 } 01539 } 01540 } 01541 else { 01542 if (beta==beta_zero) { 01543 for (j=izero; j<n; j++) { 01544 for (i=j; i<n; i++) { 01545 C[j*ldc + i] = C_zero; 01546 } 01547 } 01548 } 01549 else { 01550 for (j=izero; j<n; j++) { 01551 for (i=j; i<n; i++) { 01552 C[j*ldc + i] *= beta; 01553 } 01554 } 01555 } 01556 } 01557 return; 01558 } 01559 01560 // Now we can start the computation 01561 01562 if ( ETranspChar[trans]=='N' ) { 01563 01564 // Form C <- alpha*A*A' + beta*C 01565 if (Upper) { 01566 for (j=izero; j<n; j++) { 01567 if (beta==beta_zero) { 01568 for (i=izero; i<=j; i++) { 01569 C[j*ldc + i] = C_zero; 01570 } 01571 } 01572 else if (beta!=beta_one) { 01573 for (i=izero; i<=j; i++) { 01574 C[j*ldc + i] *= beta; 01575 } 01576 } 01577 for (l=izero; l<k; l++) { 01578 if (A[l*lda + j] != A_zero) { 01579 temp = alpha*A[l*lda + j]; 01580 for (i = izero; i <=j; i++) { 01581 C[j*ldc + i] += temp*A[l*lda + i]; 01582 } 01583 } 01584 } 01585 } 01586 } 01587 else { 01588 for (j=izero; j<n; j++) { 01589 if (beta==beta_zero) { 01590 for (i=j; i<n; i++) { 01591 C[j*ldc + i] = C_zero; 01592 } 01593 } 01594 else if (beta!=beta_one) { 01595 for (i=j; i<n; i++) { 01596 C[j*ldc + i] *= beta; 01597 } 01598 } 01599 for (l=izero; l<k; l++) { 01600 if (A[l*lda + j] != A_zero) { 01601 temp = alpha*A[l*lda + j]; 01602 for (i=j; i<n; i++) { 01603 C[j*ldc + i] += temp*A[l*lda + i]; 01604 } 01605 } 01606 } 01607 } 01608 } 01609 } 01610 else { 01611 01612 // Form C <- alpha*A'*A + beta*C 01613 if (Upper) { 01614 for (j=izero; j<n; j++) { 01615 for (i=izero; i<=j; i++) { 01616 temp = A_zero; 01617 for (l=izero; l<k; l++) { 01618 temp += A[i*lda + l]*A[j*lda + l]; 01619 } 01620 if (beta==beta_zero) { 01621 C[j*ldc + i] = alpha*temp; 01622 } 01623 else { 01624 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01625 } 01626 } 01627 } 01628 } 01629 else { 01630 for (j=izero; j<n; j++) { 01631 for (i=j; i<n; i++) { 01632 temp = A_zero; 01633 for (l=izero; l<k; ++l) { 01634 temp += A[i*lda + l]*A[j*lda + l]; 01635 } 01636 if (beta==beta_zero) { 01637 C[j*ldc + i] = alpha*temp; 01638 } 01639 else { 01640 C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i]; 01641 } 01642 } 01643 } 01644 } 01645 } 01646 } /* if (!BadArgument) */ 01647 } /* END SYRK */ 01648 01649 template<typename OrdinalType, typename ScalarType> 01650 template <typename alpha_type, typename A_type> 01651 void DefaultBLASImpl<OrdinalType, ScalarType>::TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const 01652 { 01653 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01654 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 01655 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01656 A_type A_zero = ScalarTraits<A_type>::zero(); 01657 ScalarType B_zero = ScalarTraits<ScalarType>::zero(); 01658 ScalarType one = ScalarTraits<ScalarType>::one(); 01659 OrdinalType i, j, k, NRowA = m; 01660 ScalarType temp; 01661 bool BadArgument = false; 01662 bool LSide = (ESideChar[side] == 'L'); 01663 bool noUnit = (EDiagChar[diag] == 'N'); 01664 bool Upper = (EUploChar[uplo] == 'U'); 01665 bool noConj = (EUploChar[transa] == 'T'); 01666 01667 if(!LSide) { NRowA = n; } 01668 01669 // Quick return. 01670 if (n==izero || m==izero) { return; } 01671 if( m < izero ) { 01672 std::cout << "BLAS::TRMM Error: M == "<< m <<std::endl; 01673 BadArgument = true; } 01674 if( n < izero ) { 01675 std::cout << "BLAS::TRMM Error: N == "<< n <<std::endl; 01676 BadArgument = true; } 01677 if( lda < NRowA ) { 01678 std::cout << "BLAS::TRMM Error: LDA < "<<NRowA<<std::endl; 01679 BadArgument = true; } 01680 if( ldb < m ) { 01681 std::cout << "BLAS::TRMM Error: LDB < MAX(1,M)"<<std::endl; 01682 BadArgument = true; } 01683 01684 if(!BadArgument) { 01685 01686 // B only needs to be zeroed out. 01687 if( alpha == alpha_zero ) { 01688 for( j=izero; j<n; j++ ) { 01689 for( i=izero; i<m; i++ ) { 01690 B[j*ldb + i] = B_zero; 01691 } 01692 } 01693 return; 01694 } 01695 01696 // Start the computations. 01697 if ( LSide ) { 01698 // A is on the left side of B. 01699 01700 if ( ETranspChar[transa]=='N' ) { 01701 // Compute B = alpha*A*B 01702 01703 if ( Upper ) { 01704 // A is upper triangular 01705 for( j=izero; j<n; j++ ) { 01706 for( k=izero; k<m; k++) { 01707 if ( B[j*ldb + k] != B_zero ) { 01708 temp = alpha*B[j*ldb + k]; 01709 for( i=izero; i<k; i++ ) { 01710 B[j*ldb + i] += temp*A[k*lda + i]; 01711 } 01712 if ( noUnit ) 01713 temp *=A[k*lda + k]; 01714 B[j*ldb + k] = temp; 01715 } 01716 } 01717 } 01718 } else { 01719 // A is lower triangular 01720 for( j=izero; j<n; j++ ) { 01721 for( k=m-ione; k>-ione; k-- ) { 01722 if( B[j*ldb + k] != B_zero ) { 01723 temp = alpha*B[j*ldb + k]; 01724 B[j*ldb + k] = temp; 01725 if ( noUnit ) 01726 B[j*ldb + k] *= A[k*lda + k]; 01727 for( i=k+ione; i<m; i++ ) { 01728 B[j*ldb + i] += temp*A[k*lda + i]; 01729 } 01730 } 01731 } 01732 } 01733 } 01734 } else { 01735 // Compute B = alpha*A'*B or B = alpha*conj(A')*B 01736 if( Upper ) { 01737 for( j=izero; j<n; j++ ) { 01738 for( i=m-ione; i>-ione; i-- ) { 01739 temp = B[j*ldb + i]; 01740 if ( noConj ) { 01741 if( noUnit ) 01742 temp *= A[i*lda + i]; 01743 for( k=izero; k<i; k++ ) { 01744 temp += A[i*lda + k]*B[j*ldb + k]; 01745 } 01746 } else { 01747 if( noUnit ) 01748 temp *= ScalarTraits<A_type>::conjugate(A[i*lda + i]); 01749 for( k=izero; k<i; k++ ) { 01750 temp += ScalarTraits<A_type>::conjugate(A[i*lda + k])*B[j*ldb + k]; 01751 } 01752 } 01753 B[j*ldb + i] = alpha*temp; 01754 } 01755 } 01756 } else { 01757 for( j=izero; j<n; j++ ) { 01758 for( i=izero; i<m; i++ ) { 01759 temp = B[j*ldb + i]; 01760 if ( noConj ) { 01761 if( noUnit ) 01762 temp *= A[i*lda + i]; 01763 for( k=i+ione; k<m; k++ ) { 01764 temp += A[i*lda + k]*B[j*ldb + k]; 01765 } 01766 } else { 01767 if( noUnit ) 01768 temp *= ScalarTraits<A_type>::conjugate(A[i*lda + i]); 01769 for( k=i+ione; k<m; k++ ) { 01770 temp += ScalarTraits<A_type>::conjugate(A[i*lda + k])*B[j*ldb + k]; 01771 } 01772 } 01773 B[j*ldb + i] = alpha*temp; 01774 } 01775 } 01776 } 01777 } 01778 } else { 01779 // A is on the right hand side of B. 01780 01781 if( ETranspChar[transa] == 'N' ) { 01782 // Compute B = alpha*B*A 01783 01784 if( Upper ) { 01785 // A is upper triangular. 01786 for( j=n-ione; j>-ione; j-- ) { 01787 temp = alpha; 01788 if( noUnit ) 01789 temp *= A[j*lda + j]; 01790 for( i=izero; i<m; i++ ) { 01791 B[j*ldb + i] *= temp; 01792 } 01793 for( k=izero; k<j; k++ ) { 01794 if( A[j*lda + k] != A_zero ) { 01795 temp = alpha*A[j*lda + k]; 01796 for( i=izero; i<m; i++ ) { 01797 B[j*ldb + i] += temp*B[k*ldb + i]; 01798 } 01799 } 01800 } 01801 } 01802 } else { 01803 // A is lower triangular. 01804 for( j=izero; j<n; j++ ) { 01805 temp = alpha; 01806 if( noUnit ) 01807 temp *= A[j*lda + j]; 01808 for( i=izero; i<m; i++ ) { 01809 B[j*ldb + i] *= temp; 01810 } 01811 for( k=j+ione; k<n; k++ ) { 01812 if( A[j*lda + k] != A_zero ) { 01813 temp = alpha*A[j*lda + k]; 01814 for( i=izero; i<m; i++ ) { 01815 B[j*ldb + i] += temp*B[k*ldb + i]; 01816 } 01817 } 01818 } 01819 } 01820 } 01821 } else { 01822 // Compute B = alpha*B*A' or B = alpha*B*conj(A') 01823 01824 if( Upper ) { 01825 for( k=izero; k<n; k++ ) { 01826 for( j=izero; j<k; j++ ) { 01827 if( A[k*lda + j] != A_zero ) { 01828 if ( noConj ) 01829 temp = alpha*A[k*lda + j]; 01830 else 01831 temp = alpha*ScalarTraits<A_type>::conjugate(A[k*lda + j]); 01832 for( i=izero; i<m; i++ ) { 01833 B[j*ldb + i] += temp*B[k*ldb + i]; 01834 } 01835 } 01836 } 01837 temp = alpha; 01838 if( noUnit ) { 01839 if ( noConj ) 01840 temp *= A[k*lda + k]; 01841 else 01842 temp *= ScalarTraits<A_type>::conjugate(A[k*lda + k]); 01843 } 01844 if( temp != one ) { 01845 for( i=izero; i<m; i++ ) { 01846 B[k*ldb + i] *= temp; 01847 } 01848 } 01849 } 01850 } else { 01851 for( k=n-ione; k>-ione; k-- ) { 01852 for( j=k+ione; j<n; j++ ) { 01853 if( A[k*lda + j] != A_zero ) { 01854 if ( noConj ) 01855 temp = alpha*A[k*lda + j]; 01856 else 01857 temp = alpha*ScalarTraits<A_type>::conjugate(A[k*lda + j]); 01858 for( i=izero; i<m; i++ ) { 01859 B[j*ldb + i] += temp*B[k*ldb + i]; 01860 } 01861 } 01862 } 01863 temp = alpha; 01864 if( noUnit ) { 01865 if ( noConj ) 01866 temp *= A[k*lda + k]; 01867 else 01868 temp *= ScalarTraits<A_type>::conjugate(A[k*lda + k]); 01869 } 01870 if( temp != one ) { 01871 for( i=izero; i<m; i++ ) { 01872 B[k*ldb + i] *= temp; 01873 } 01874 } 01875 } 01876 } 01877 } // end if( ETranspChar[transa] == 'N' ) ... 01878 } // end if ( LSide ) ... 01879 } // end if (!BadArgument) 01880 } // end TRMM 01881 01882 template<typename OrdinalType, typename ScalarType> 01883 template <typename alpha_type, typename A_type> 01884 void DefaultBLASImpl<OrdinalType, ScalarType>::TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type alpha, const A_type* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const 01885 { 01886 OrdinalType izero = OrdinalTraits<OrdinalType>::zero(); 01887 OrdinalType ione = OrdinalTraits<OrdinalType>::one(); 01888 alpha_type alpha_zero = ScalarTraits<alpha_type>::zero(); 01889 A_type A_zero = ScalarTraits<A_type>::zero(); 01890 ScalarType B_zero = ScalarTraits<ScalarType>::zero(); 01891 alpha_type alpha_one = ScalarTraits<alpha_type>::one(); 01892 ScalarType B_one = ScalarTraits<ScalarType>::one(); 01893 ScalarType temp; 01894 OrdinalType NRowA = m; 01895 bool BadArgument = false; 01896 bool noUnit = (EDiagChar[diag]=='N'); 01897 bool noConj = (EUploChar[transa] == 'T'); 01898 01899 if (!(ESideChar[side] == 'L')) { NRowA = n; } 01900 01901 // Quick return. 01902 if (n == izero || m == izero) { return; } 01903 if( m < izero ) { 01904 std::cout << "BLAS::TRSM Error: M == "<<m<<std::endl; 01905 BadArgument = true; } 01906 if( n < izero ) { 01907 std::cout << "BLAS::TRSM Error: N == "<<n<<std::endl; 01908 BadArgument = true; } 01909 if( lda < NRowA ) { 01910 std::cout << "BLAS::TRSM Error: LDA < "<<NRowA<<std::endl; 01911 BadArgument = true; } 01912 if( ldb < m ) { 01913 std::cout << "BLAS::TRSM Error: LDB < MAX(1,M)"<<std::endl; 01914 BadArgument = true; } 01915 01916 if(!BadArgument) 01917 { 01918 int i, j, k; 01919 // Set the solution to the zero std::vector. 01920 if(alpha == alpha_zero) { 01921 for(j = izero; j < n; j++) { 01922 for( i = izero; i < m; i++) { 01923 B[j*ldb+i] = B_zero; 01924 } 01925 } 01926 } 01927 else 01928 { // Start the operations. 01929 if(ESideChar[side] == 'L') { 01930 // 01931 // Perform computations for OP(A)*X = alpha*B 01932 // 01933 if(ETranspChar[transa] == 'N') { 01934 // 01935 // Compute B = alpha*inv( A )*B 01936 // 01937 if(EUploChar[uplo] == 'U') { 01938 // A is upper triangular. 01939 for(j = izero; j < n; j++) { 01940 // Perform alpha*B if alpha is not 1. 01941 if(alpha != alpha_one) { 01942 for( i = izero; i < m; i++) { 01943 B[j*ldb+i] *= alpha; 01944 } 01945 } 01946 // Perform a backsolve for column j of B. 01947 for(k = (m - ione); k > -ione; k--) { 01948 // If this entry is zero, we don't have to do anything. 01949 if (B[j*ldb + k] != B_zero) { 01950 if ( noUnit ) { 01951 B[j*ldb + k] /= A[k*lda + k]; 01952 } 01953 for(i = izero; i < k; i++) { 01954 B[j*ldb + i] -= B[j*ldb + k] * A[k*lda + i]; 01955 } 01956 } 01957 } 01958 } 01959 } 01960 else 01961 { // A is lower triangular. 01962 for(j = izero; j < n; j++) { 01963 // Perform alpha*B if alpha is not 1. 01964 if(alpha != alpha_one) { 01965 for( i = izero; i < m; i++) { 01966 B[j*ldb+i] *= alpha; 01967 } 01968 } 01969 // Perform a forward solve for column j of B. 01970 for(k = izero; k < m; k++) { 01971 // If this entry is zero, we don't have to do anything. 01972 if (B[j*ldb + k] != B_zero) { 01973 if ( noUnit ) { 01974 B[j*ldb + k] /= A[k*lda + k]; 01975 } 01976 for(i = k+ione; i < m; i++) { 01977 B[j*ldb + i] -= B[j*ldb + k] * A[k*lda + i]; 01978 } 01979 } 01980 } 01981 } 01982 } // end if (uplo == 'U') 01983 } // if (transa =='N') 01984 else { 01985 // 01986 // Compute B = alpha*inv( A' )*B 01987 // or B = alpha*inv( conj(A') )*B 01988 // 01989 if(EUploChar[uplo] == 'U') { 01990 // A is upper triangular. 01991 for(j = izero; j < n; j++) { 01992 for( i = izero; i < m; i++) { 01993 temp = alpha*B[j*ldb+i]; 01994 if ( noConj ) { 01995 for(k = izero; k < i; k++) { 01996 temp -= A[i*lda + k] * B[j*ldb + k]; 01997 } 01998 if ( noUnit ) { 01999 temp /= A[i*lda + i]; 02000 } 02001 } else { 02002 for(k = izero; k < i; k++) { 02003 temp -= ScalarTraits<A_type>::conjugate(A[i*lda + k]) 02004 * B[j*ldb + k]; 02005 } 02006 if ( noUnit ) { 02007 temp /= ScalarTraits<A_type>::conjugate(A[i*lda + i]); 02008 } 02009 } 02010 B[j*ldb + i] = temp; 02011 } 02012 } 02013 } 02014 else 02015 { // A is lower triangular. 02016 for(j = izero; j < n; j++) { 02017 for(i = (m - ione) ; i > -ione; i--) { 02018 temp = alpha*B[j*ldb+i]; 02019 if ( noConj ) { 02020 for(k = i+ione; k < m; k++) { 02021 temp -= A[i*lda + k] * B[j*ldb + k]; 02022 } 02023 if ( noUnit ) { 02024 temp /= A[i*lda + i]; 02025 } 02026 } else { 02027 for(k = i+ione; k < m; k++) { 02028 temp -= ScalarTraits<A_type>::conjugate(A[i*lda + k]) 02029 * B[j*ldb + k]; 02030 } 02031 if ( noUnit ) { 02032 temp /= ScalarTraits<A_type>::conjugate(A[i*lda + i]); 02033 } 02034 } 02035 B[j*ldb + i] = temp; 02036 } 02037 } 02038 } 02039 } 02040 } // if (side == 'L') 02041 else { 02042 // side == 'R' 02043 // 02044 // Perform computations for X*OP(A) = alpha*B 02045 // 02046 if (ETranspChar[transa] == 'N') { 02047 // 02048 // Compute B = alpha*B*inv( A ) 02049 // 02050 if(EUploChar[uplo] == 'U') { 02051 // A is upper triangular. 02052 // Perform a backsolve for column j of B. 02053 for(j = izero; j < n; j++) { 02054 // Perform alpha*B if alpha is not 1. 02055 if(alpha != alpha_one) { 02056 for( i = izero; i < m; i++) { 02057 B[j*ldb+i] *= alpha; 02058 } 02059 } 02060 for(k = izero; k < j; k++) { 02061 // If this entry is zero, we don't have to do anything. 02062 if (A[j*lda + k] != A_zero) { 02063 for(i = izero; i < m; i++) { 02064 B[j*ldb + i] -= A[j*lda + k] * B[k*ldb + i]; 02065 } 02066 } 02067 } 02068 if ( noUnit ) { 02069 temp = B_one/A[j*lda + j]; 02070 for(i = izero; i < m; i++) { 02071 B[j*ldb + i] *= temp; 02072 } 02073 } 02074 } 02075 } 02076 else 02077 { // A is lower triangular. 02078 for(j = (n - ione); j > -ione; j--) { 02079 // Perform alpha*B if alpha is not 1. 02080 if(alpha != alpha_one) { 02081 for( i = izero; i < m; i++) { 02082 B[j*ldb+i] *= alpha; 02083 } 02084 } 02085 // Perform a forward solve for column j of B. 02086 for(k = j+ione; k < n; k++) { 02087 // If this entry is zero, we don't have to do anything. 02088 if (A[j*lda + k] != A_zero) { 02089 for(i = izero; i < m; i++) { 02090 B[j*ldb + i] -= A[j*lda + k] * B[k*ldb + i]; 02091 } 02092 } 02093 } 02094 if ( noUnit ) { 02095 temp = B_one/A[j*lda + j]; 02096 for(i = izero; i < m; i++) { 02097 B[j*ldb + i] *= temp; 02098 } 02099 } 02100 } 02101 } // end if (uplo == 'U') 02102 } // if (transa =='N') 02103 else { 02104 // 02105 // Compute B = alpha*B*inv( A' ) 02106 // or B = alpha*B*inv( conj(A') ) 02107 // 02108 if(EUploChar[uplo] == 'U') { 02109 // A is upper triangular. 02110 for(k = (n - ione); k > -ione; k--) { 02111 if ( noUnit ) { 02112 if ( noConj ) 02113 temp = B_one/A[k*lda + k]; 02114 else 02115 temp = B_one/ScalarTraits<A_type>::conjugate(A[k*lda + k]); 02116 for(i = izero; i < m; i++) { 02117 B[k*ldb + i] *= temp; 02118 } 02119 } 02120 for(j = izero; j < k; j++) { 02121 if (A[k*lda + j] != A_zero) { 02122 if ( noConj ) 02123 temp = A[k*lda + j]; 02124 else 02125 temp = ScalarTraits<A_type>::conjugate(A[k*lda + j]); 02126 for(i = izero; i < m; i++) { 02127 B[j*ldb + i] -= temp*B[k*ldb + i]; 02128 } 02129 } 02130 } 02131 if (alpha != alpha_one) { 02132 for (i = izero; i < m; i++) { 02133 B[k*ldb + i] *= alpha; 02134 } 02135 } 02136 } 02137 } 02138 else 02139 { // A is lower triangular. 02140 for(k = izero; k < n; k++) { 02141 if ( noUnit ) { 02142 if ( noConj ) 02143 temp = B_one/A[k*lda + k]; 02144 else 02145 temp = B_one/ScalarTraits<A_type>::conjugate(A[k*lda + k]); 02146 for (i = izero; i < m; i++) { 02147 B[k*ldb + i] *= temp; 02148 } 02149 } 02150 for(j = k+ione; j < n; j++) { 02151 if(A[k*lda + j] != A_zero) { 02152 if ( noConj ) 02153 temp = A[k*lda + j]; 02154 else 02155 temp = ScalarTraits<A_type>::conjugate(A[k*lda + j]); 02156 for(i = izero; i < m; i++) { 02157 B[j*ldb + i] -= temp*B[k*ldb + i]; 02158 } 02159 } 02160 } 02161 if (alpha != alpha_one) { 02162 for (i = izero; i < m; i++) { 02163 B[k*ldb + i] *= alpha; 02164 } 02165 } 02166 } 02167 } 02168 } 02169 } 02170 } 02171 } 02172 } 02173 02174 // Explicit instantiation for template<int,float> 02175 02176 template <> 02177 class TEUCHOS_LIB_DLL_EXPORT BLAS<int, float> 02178 { 02179 public: 02180 inline BLAS(void) {} 02181 inline BLAS(const BLAS<int, float>& /*BLAS_source*/) {} 02182 inline virtual ~BLAS(void) {} 02183 void ROTG(float* da, float* db, float* c, float* s) const; 02184 void ROT(const int n, float* dx, const int incx, float* dy, const int incy, float* c, float* s) const; 02185 float ASUM(const int n, const float* x, const int incx) const; 02186 void AXPY(const int n, const float alpha, const float* x, const int incx, float* y, const int incy) const; 02187 void COPY(const int n, const float* x, const int incx, float* y, const int incy) const; 02188 float DOT(const int n, const float* x, const int incx, const float* y, const int incy) const; 02189 float NRM2(const int n, const float* x, const int incx) const; 02190 void SCAL(const int n, const float alpha, float* x, const int incx) const; 02191 int IAMAX(const int n, const float* x, const int incx) const; 02192 void GEMV(ETransp trans, const int m, const int n, const float alpha, const float* A, const int lda, const float* x, const int incx, const float beta, float* y, const int incy) const; 02193 void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const float* A, const int lda, float* x, const int incx) const; 02194 void GER(const int m, const int n, const float alpha, const float* x, const int incx, const float* y, const int incy, float* A, const int lda) const; 02195 void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) const; 02196 void SYMM(ESide side, EUplo uplo, const int m, const int n, const float alpha, const float* A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc) const; 02197 void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const float alpha, const float* A, const int lda, const float beta, float* C, const int ldc) const; 02198 void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const float alpha, const float* A, const int lda, float* B, const int ldb) const; 02199 void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const float alpha, const float* A, const int lda, float* B, const int ldb) const; 02200 }; 02201 02202 // Explicit instantiation for template<int,double> 02203 02204 template<> 02205 class TEUCHOS_LIB_DLL_EXPORT BLAS<int, double> 02206 { 02207 public: 02208 inline BLAS(void) {} 02209 inline BLAS(const BLAS<int, double>& /*BLAS_source*/) {} 02210 inline virtual ~BLAS(void) {} 02211 void ROTG(double* da, double* db, double* c, double* s) const; 02212 void ROT(const int n, double* dx, const int incx, double* dy, const int incy, double* c, double* s) const; 02213 double ASUM(const int n, const double* x, const int incx) const; 02214 void AXPY(const int n, const double alpha, const double* x, const int incx, double* y, const int incy) const; 02215 void COPY(const int n, const double* x, const int incx, double* y, const int incy) const; 02216 double DOT(const int n, const double* x, const int incx, const double* y, const int incy) const; 02217 double NRM2(const int n, const double* x, const int incx) const; 02218 void SCAL(const int n, const double alpha, double* x, const int incx) const; 02219 int IAMAX(const int n, const double* x, const int incx) const; 02220 void GEMV(ETransp trans, const int m, const int n, const double alpha, const double* A, const int lda, const double* x, const int incx, const double beta, double* y, const int incy) const; 02221 void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const double* A, const int lda, double* x, const int incx) const; 02222 void GER(const int m, const int n, const double alpha, const double* x, const int incx, const double* y, const int incy, double* A, const int lda) const; 02223 void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) const; 02224 void SYMM(ESide side, EUplo uplo, const int m, const int n, const double alpha, const double* A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc) const; 02225 void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const double alpha, const double* A, const int lda, const double beta, double* C, const int ldc) const; 02226 void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const double alpha, const double* A, const int lda, double* B, const int ldb) const; 02227 void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const double alpha, const double* A, const int lda, double* B, const int ldb) const; 02228 }; 02229 02230 // Explicit instantiation for template<int,complex<float> > 02231 02232 template<> 02233 class TEUCHOS_LIB_DLL_EXPORT BLAS<int, std::complex<float> > 02234 { 02235 public: 02236 inline BLAS(void) {} 02237 inline BLAS(const BLAS<int, std::complex<float> >& /*BLAS_source*/) {} 02238 inline virtual ~BLAS(void) {} 02239 void ROTG(std::complex<float>* da, std::complex<float>* db, float* c, std::complex<float>* s) const; 02240 void ROT(const int n, std::complex<float>* dx, const int incx, std::complex<float>* dy, const int incy, float* c, std::complex<float>* s) const; 02241 float ASUM(const int n, const std::complex<float>* x, const int incx) const; 02242 void AXPY(const int n, const std::complex<float> alpha, const std::complex<float>* x, const int incx, std::complex<float>* y, const int incy) const; 02243 void COPY(const int n, const std::complex<float>* x, const int incx, std::complex<float>* y, const int incy) const; 02244 std::complex<float> DOT(const int n, const std::complex<float>* x, const int incx, const std::complex<float>* y, const int incy) const; 02245 float NRM2(const int n, const std::complex<float>* x, const int incx) const; 02246 void SCAL(const int n, const std::complex<float> alpha, std::complex<float>* x, const int incx) const; 02247 int IAMAX(const int n, const std::complex<float>* x, const int incx) const; 02248 void GEMV(ETransp trans, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float>* x, const int incx, const std::complex<float> beta, std::complex<float>* y, const int incy) const; 02249 void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const std::complex<float>* A, const int lda, std::complex<float>* x, const int incx) const; 02250 void GER(const int m, const int n, const std::complex<float> alpha, const std::complex<float>* x, const int incx, const std::complex<float>* y, const int incy, std::complex<float>* A, const int lda) const; 02251 void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float>* B, const int ldb, const std::complex<float> beta, std::complex<float>* C, const int ldc) const; 02252 void SYMM(ESide side, EUplo uplo, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float> *B, const int ldb, const std::complex<float> beta, std::complex<float> *C, const int ldc) const; 02253 void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const std::complex<float> alpha, const std::complex<float>* A, const int lda, const std::complex<float> beta, std::complex<float>* C, const int ldc) const; 02254 void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, std::complex<float>* B, const int ldb) const; 02255 void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<float> alpha, const std::complex<float>* A, const int lda, std::complex<float>* B, const int ldb) const; 02256 }; 02257 02258 // Explicit instantiation for template<int,complex<double> > 02259 02260 template<> 02261 class TEUCHOS_LIB_DLL_EXPORT BLAS<int, std::complex<double> > 02262 { 02263 public: 02264 inline BLAS(void) {} 02265 inline BLAS(const BLAS<int, std::complex<double> >& /*BLAS_source*/) {} 02266 inline virtual ~BLAS(void) {} 02267 void ROTG(std::complex<double>* da, std::complex<double>* db, double* c, std::complex<double>* s) const; 02268 void ROT(const int n, std::complex<double>* dx, const int incx, std::complex<double>* dy, const int incy, double* c, std::complex<double>* s) const; 02269 double ASUM(const int n, const std::complex<double>* x, const int incx) const; 02270 void AXPY(const int n, const std::complex<double> alpha, const std::complex<double>* x, const int incx, std::complex<double>* y, const int incy) const; 02271 void COPY(const int n, const std::complex<double>* x, const int incx, std::complex<double>* y, const int incy) const; 02272 std::complex<double> DOT(const int n, const std::complex<double>* x, const int incx, const std::complex<double>* y, const int incy) const; 02273 double NRM2(const int n, const std::complex<double>* x, const int incx) const; 02274 void SCAL(const int n, const std::complex<double> alpha, std::complex<double>* x, const int incx) const; 02275 int IAMAX(const int n, const std::complex<double>* x, const int incx) const; 02276 void GEMV(ETransp trans, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double>* x, const int incx, const std::complex<double> beta, std::complex<double>* y, const int incy) const; 02277 void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const std::complex<double>* A, const int lda, std::complex<double>* x, const int incx) const; 02278 void GER(const int m, const int n, const std::complex<double> alpha, const std::complex<double>* x, const int incx, const std::complex<double>* y, const int incy, std::complex<double>* A, const int lda) const; 02279 void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double>* B, const int ldb, const std::complex<double> beta, std::complex<double>* C, const int ldc) const; 02280 void SYMM(ESide side, EUplo uplo, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double> *B, const int ldb, const std::complex<double> beta, std::complex<double> *C, const int ldc) const; 02281 void SYRK(EUplo uplo, ETransp trans, const int n, const int k, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double> beta, std::complex<double>* C, const int ldc) const; 02282 void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, std::complex<double>* B, const int ldb) const; 02283 void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, std::complex<double>* B, const int ldb) const; 02284 }; 02285 02286 } // namespace Teuchos 02287 02288 #endif // _TEUCHOS_BLAS_HPP_
1.7.4