Teuchos_BLAS.hpp

Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //                    Teuchos: Common Tools Package
00005 //                 Copyright (2004) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 // Kris
00030 // 06.16.03 -- Start over from scratch
00031 // 06.16.03 -- Initial templatization (Tpetra_BLAS.cpp is no longer needed)
00032 // 06.18.03 -- Changed xxxxx_() function calls to XXXXX_F77()
00033 //          -- Added warning messages for generic calls
00034 // 07.08.03 -- Move into Teuchos package/namespace
00035 // 07.24.03 -- The first iteration of BLAS generics is nearing completion. Caveats:
00036 //             * TRSM isn't finished yet; it works for one case at the moment (left side, upper tri., no transpose, no unit diag.)
00037 //             * Many of the generic implementations are quite inefficient, ugly, or both. I wrote these to be easy to debug, not for efficiency or legibility. The next iteration will improve both of these aspects as much as possible.
00038 //             * Very little verification of input parameters is done, save for the character-type arguments (TRANS, etc.) which is quite robust.
00039 //             * All of the routines that make use of both an incx and incy parameter (which includes much of the L1 BLAS) are set up to work iff incx == incy && incx > 0. Allowing for differing/negative values of incx/incy should be relatively trivial.
00040 //             * All of the L2/L3 routines assume that the entire matrix is being used (that is, if A is mxn, lda = m); they don't work on submatrices yet. This *should* be a reasonably trivial thing to fix, as well.
00041 //          -- Removed warning messages for generic calls
00042 // 08.08.03 -- TRSM now works for all cases where SIDE == L and DIAG == N. DIAG == U is implemented but does not work correctly; SIDE == R is not yet implemented.
00043 // 08.14.03 -- TRSM now works for all cases and accepts (and uses) leading-dimension information.
00044 // 09.26.03 -- character input replaced with enumerated input to cause compiling errors and not run-time errors ( suggested by RAB ).
00045 
00046 #ifndef _TEUCHOS_BLAS_HPP_
00047 #define _TEUCHOS_BLAS_HPP_
00048 
/* CHAR_MACRO expands a single-character BLAS option argument for a Fortran
   call.  On INTEL_CXML builds it also appends the value `one` after the
   character's address (presumably the hidden Fortran string-length
   argument -- confirm against the CXML calling convention).  On all other
   builds it expands to just the address of the character.
*/
#if defined (INTEL_CXML)
// `static const` gives every translation unit its own read-only copy.  A
// plain (external-linkage) definition in a header is multiply defined at
// link time as soon as two source files include this header.
static const unsigned int one = 1;
#endif

#ifdef CHAR_MACRO
#undef CHAR_MACRO
#endif
#if defined (INTEL_CXML)
#define CHAR_MACRO(char_var) &char_var, one
#else
#define CHAR_MACRO(char_var) &char_var
#endif
00072 
00073 #include "Teuchos_ConfigDefs.hpp"
00074 #include "Teuchos_ScalarTraits.hpp"
00075 #include "Teuchos_OrdinalTraits.hpp"
00076 #include "Teuchos_BLAS_types.hpp"
00077 
00111 namespace Teuchos
00112 {
00113   extern const char ESideChar[];
00114   extern const char ETranspChar[];
00115   extern const char EUploChar[];
00116   extern const char EDiagChar[];
00117 
00118   template<typename OrdinalType, typename ScalarType>
00119   class BLAS
00120   {    
00121 
00122     typedef typename Teuchos::ScalarTraits<ScalarType>::magnitudeType MagnitudeType;
00123     
00124   public:
00126 
00127     
00129     inline BLAS(void) {}
00130 
00132     inline BLAS(const BLAS<OrdinalType, ScalarType>& /*BLAS_source*/) {}
00133 
00135     inline virtual ~BLAS(void) {}
00137 
00139 
00140 
00142     void ROTG(ScalarType* da, ScalarType* db, MagnitudeType* c, ScalarType* s) const;
00143 
00145     void ROT(const OrdinalType n, ScalarType* dx, const OrdinalType incx, ScalarType* dy, const OrdinalType incy, MagnitudeType* c, ScalarType* s) const;
00146 
00148     void SCAL(const OrdinalType n, const ScalarType alpha, ScalarType* x, const OrdinalType incx) const;
00149 
00151     void COPY(const OrdinalType n, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const;
00152 
00154     void AXPY(const OrdinalType n, const ScalarType alpha, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const;
00155 
00157     typename ScalarTraits<ScalarType>::magnitudeType ASUM(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const;
00158 
00160     ScalarType DOT(const OrdinalType n, const ScalarType* x, const OrdinalType incx, const ScalarType* y, const OrdinalType incy) const;
00161 
00163     typename ScalarTraits<ScalarType>::magnitudeType NRM2(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const;
00164 
00166     OrdinalType IAMAX(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const;
00167 
00169 
00171 
00172 
00174     void GEMV(ETransp trans, const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* A, 
00175         const OrdinalType lda, const ScalarType* x, const OrdinalType incx, const ScalarType beta, ScalarType* y, const OrdinalType incy) const;
00176 
00178     void TRMV(EUplo uplo, ETransp trans, EDiag diag, const OrdinalType n, const ScalarType* A, 
00179         const OrdinalType lda, ScalarType* x, const OrdinalType incx) const;
00180 
00182     void GER(const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* x, const OrdinalType incx, 
00183        const ScalarType* y, const OrdinalType incy, ScalarType* A, const OrdinalType lda) const;
00185     
00187 
00188 
00190     void GEMM(ETransp transa, ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const ScalarType alpha, const ScalarType* A, const OrdinalType lda, const ScalarType* B, const OrdinalType ldb, const ScalarType beta, ScalarType* C, const OrdinalType ldc) const;
00191 
00193     void SYMM(ESide side, EUplo uplo, const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* A, const OrdinalType lda, const ScalarType* B, const OrdinalType ldb, const ScalarType beta, ScalarType* C, const OrdinalType ldc) const;
00194 
00196     void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n,
00197                 const ScalarType alpha, const ScalarType* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const;
00198 
00200     void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n,
00201                 const ScalarType alpha, const ScalarType* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const;
00203   };
00204 
00205 //------------------------------------------------------------------------------------------
00206 //      LEVEL 1 BLAS ROUTINES  
00207 //------------------------------------------------------------------------------------------
00208     
00209   template<typename OrdinalType, typename ScalarType>
00210   void BLAS<OrdinalType, ScalarType>::ROTG(ScalarType* da, ScalarType* db, MagnitudeType* c, ScalarType* s) const
00211   {
00212     ScalarType scale, r;
00213     ScalarType roe = ScalarTraits<ScalarType>::zero();
00214     ScalarType zero = ScalarTraits<ScalarType>::zero();
00215     ScalarType one = ScalarTraits<ScalarType>::one();
00216 
00217     if ( ScalarTraits<ScalarType>::magnitude( *da ) > ScalarTraits<ScalarType>::magnitude( *db ) ) { roe = *da; }
00218     scale = ScalarTraits<ScalarType>::magnitude( *da ) + ScalarTraits<ScalarType>::magnitude( *db );
00219     if ( scale == zero ) // There is nothing to do.
00220     {
00221       *c = one;
00222       *s = zero;
00223       *da = zero; *db = zero;
00224     } else { // Compute the Givens rotation.
00225       r = scale*ScalarTraits<ScalarType>::squareroot( ( *da/scale)*(*da/scale) + (*db/scale)*(*db/scale) );
00226       if ( roe < zero ) { r *= -one; }
00227       *c = *da / r;
00228       *s = *db / r;
00229       *db = ScalarTraits<ScalarType>::one();
00230       if( ScalarTraits<ScalarType>::magnitude( *da ) > ScalarTraits<ScalarType>::magnitude( *db ) ){ *db = *s; }
00231       if( ScalarTraits<ScalarType>::magnitude( *db ) >= ScalarTraits<ScalarType>::magnitude( *da ) &&
00232      *c != ScalarTraits<ScalarType>::zero() ) { *db = one / *c; }
00233       *da = r;
00234     }
00235   } /* end ROTG */
00236       
00237   template<typename OrdinalType, typename ScalarType>
00238   void BLAS<OrdinalType,ScalarType>::ROT(const OrdinalType n, ScalarType* dx, const OrdinalType incx, ScalarType* dy, const OrdinalType incy, MagnitudeType* c, ScalarType* s) const
00239   {
00240     // ToDo:  Implement this.
00241   }
00242 
00243   template<typename OrdinalType, typename ScalarType>
00244   void BLAS<OrdinalType, ScalarType>::SCAL(const OrdinalType n, const ScalarType alpha, ScalarType* x, const OrdinalType incx) const
00245   {
00246     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00247     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00248     OrdinalType i, ix = izero;
00249     if ( n > izero ) {
00250         // Set the initial index (ix).
00251         if (incx < izero) { ix = (-n+ione)*incx; } 
00252         // Scale the std::vector.
00253         for(i = izero; i < n; i++)
00254         {
00255             x[ix] *= alpha;
00256             ix += incx;
00257         }
00258     }
00259   } /* end SCAL */
00260 
00261   template<typename OrdinalType, typename ScalarType>
00262   void BLAS<OrdinalType, ScalarType>::COPY(const OrdinalType n, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const
00263   {
00264     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00265     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00266     OrdinalType i, ix = izero, iy = izero;
00267     if ( n > izero ) {
00268   // Set the initial indices (ix, iy).
00269       if (incx < izero) { ix = (-n+ione)*incx; }
00270       if (incy < izero) { iy = (-n+ione)*incy; }
00271 
00272         for(i = izero; i < n; i++)
00273           {
00274       y[iy] = x[ix];
00275       ix += incx;
00276       iy += incy;
00277           }
00278       }
00279   } /* end COPY */
00280 
00281   template<typename OrdinalType, typename ScalarType>
00282   void BLAS<OrdinalType, ScalarType>::AXPY(const OrdinalType n, const ScalarType alpha, const ScalarType* x, const OrdinalType incx, ScalarType* y, const OrdinalType incy) const
00283   {
00284     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00285     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00286     OrdinalType i, ix = izero, iy = izero;
00287     if( n > izero && alpha != ScalarTraits<ScalarType>::zero())
00288       {
00289   // Set the initial indices (ix, iy).
00290       if (incx < izero) { ix = (-n+ione)*incx; }
00291       if (incy < izero) { iy = (-n+ione)*incy; }
00292 
00293         for(i = izero; i < n; i++)
00294           {
00295       y[iy] += alpha * x[ix];
00296       ix += incx;
00297       iy += incy;
00298           }
00299       }
00300   } /* end AXPY */
00301 
00302   template<typename OrdinalType, typename ScalarType>
00303   typename ScalarTraits<ScalarType>::magnitudeType BLAS<OrdinalType, ScalarType>::ASUM(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const
00304   {
00305     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00306     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00307     typename ScalarTraits<ScalarType>::magnitudeType result = 
00308       ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero();
00309     OrdinalType i, ix = izero;
00310     if( n > izero ) {
00311   // Set the initial indices
00312   if (incx < izero) { ix = (-n+ione)*incx; }
00313 
00314       for(i = izero; i < n; i++)
00315           {
00316       result += ScalarTraits<ScalarType>::magnitude(x[ix]);
00317       ix += incx;
00318           }
00319     } 
00320    return result;
00321   } /* end ASUM */
00322   
00323   template<typename OrdinalType, typename ScalarType>
00324   ScalarType BLAS<OrdinalType, ScalarType>::DOT(const OrdinalType n, const ScalarType* x, const OrdinalType incx, const ScalarType* y, const OrdinalType incy) const
00325   {
00326     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00327     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00328     ScalarType result = ScalarTraits<ScalarType>::zero();
00329     OrdinalType i, ix = izero, iy = izero;
00330     if( n > izero ) 
00331       {
00332   // Set the initial indices (ix,iy).       
00333   if (incx < izero) { ix = (-n+ione)*incx; }
00334   if (incy < izero) { iy = (-n+ione)*incy; }
00335 
00336   for(i = izero; i < n; i++)
00337     {
00338       result += ScalarTraits<ScalarType>::conjugate(x[ix]) * y[iy];
00339       ix += incx;
00340       iy += incy;
00341     }
00342       }
00343     return result;
00344   } /* end DOT */
00345   
00346   template<typename OrdinalType, typename ScalarType>
00347   typename ScalarTraits<ScalarType>::magnitudeType BLAS<OrdinalType, ScalarType>::NRM2(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const
00348   {
00349     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00350     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00351     typename ScalarTraits<ScalarType>::magnitudeType result = 
00352       ScalarTraits<typename ScalarTraits<ScalarType>::magnitudeType>::zero();
00353     OrdinalType i, ix = izero;
00354     if ( n > izero ) 
00355       {
00356   // Set the initial index.
00357   if (incx < izero) { ix = (-n+ione)*incx; }  
00358     
00359   for(i = izero; i < n; i++)
00360           {
00361       result += ScalarTraits<ScalarType>::conjugate(x[ix]) * x[ix];
00362       ix += incx;
00363           }
00364       result = ScalarTraits<ScalarType>::squareroot(result);
00365       } 
00366     return result;
00367   } /* end NRM2 */
00368   
00369   template<typename OrdinalType, typename ScalarType>
00370   OrdinalType BLAS<OrdinalType, ScalarType>::IAMAX(const OrdinalType n, const ScalarType* x, const OrdinalType incx) const
00371   {
00372     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00373     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00374     OrdinalType result = izero, ix = izero, i;
00375     ScalarType maxval;
00376 
00377     if ( n > izero ) 
00378       {
00379   if (incx < izero) { ix = (-n+ione)*incx; }
00380   maxval = ScalarTraits<ScalarType>::magnitude(x[ix]);
00381   ix += incx;
00382       for(i = ione; i < n; i++)
00383           {
00384       if(ScalarTraits<ScalarType>::magnitude(x[ix]) > maxval)
00385         {
00386         result = i;
00387           maxval = ScalarTraits<ScalarType>::magnitude(x[ix]);
00388         }
00389       ix += incx;
00390     }
00391       }
00392     return result + 1; // the BLAS I?AMAX functions return 1-indexed (Fortran-style) values
00393   } /* end IAMAX */
00394 
00395 //------------------------------------------------------------------------------------------
00396 //      LEVEL 2 BLAS ROUTINES
00397 //------------------------------------------------------------------------------------------
00398 
00399   template<typename OrdinalType, typename ScalarType>
00400   void BLAS<OrdinalType, ScalarType>::GEMV(ETransp trans, const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* A, const OrdinalType lda, const ScalarType* x, const OrdinalType incx, const ScalarType beta, ScalarType* y, const OrdinalType incy) const
00401   {
00402     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00403     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00404     ScalarType zero = ScalarTraits<ScalarType>::zero();
00405     ScalarType one = ScalarTraits<ScalarType>::one();
00406     bool BadArgument = false;
00407 
00408     // Quick return if there is nothing to do!
00409     if( m == izero || n == izero || ( alpha == zero && beta == one ) ){ return; }
00410     
00411     // Otherwise, we need to check the argument list.
00412     if( m < izero ) { 
00413   std::cout << "BLAS::GEMV Error: M == " << m << std::endl;     
00414   BadArgument = true;
00415     }
00416     if( n < izero ) { 
00417   std::cout << "BLAS::GEMV Error: N == " << n << std::endl;     
00418   BadArgument = true;
00419     }
00420     if( lda < m ) { 
00421   std::cout << "BLAS::GEMV Error: LDA < MAX(1,M)"<< std::endl;      
00422   BadArgument = true;
00423     }
00424     if( incx == izero ) {
00425   std::cout << "BLAS::GEMV Error: INCX == 0"<< std::endl;
00426   BadArgument = true;
00427     }
00428     if( incy == izero ) {
00429   std::cout << "BLAS::GEMV Error: INCY == 0"<< std::endl;
00430   BadArgument = true;
00431     }
00432 
00433     if(!BadArgument) {
00434       OrdinalType i, j, lenx, leny, ix, iy, jx, jy; 
00435       OrdinalType kx = izero, ky = izero;
00436       ScalarType temp;
00437 
00438       // Determine the lengths of the vectors x and y.
00439       if(ETranspChar[trans] == 'N') {
00440   lenx = n;
00441   leny = m;
00442       } else {
00443   lenx = m;
00444   leny = n;
00445       }
00446 
00447       // Set the starting pointers for the vectors x and y if incx/y < 0.
00448       if (incx < izero ) { kx =  (ione - lenx)*incx; }
00449       if (incy < izero ) { ky =  (ione - leny)*incy; }
00450 
00451       // Form y = beta*y
00452       ix = kx; iy = ky;
00453       if(beta != one) {
00454   if (incy == ione) {
00455     if (beta == zero) {
00456       for(i = izero; i < leny; i++) { y[i] = zero; }
00457     } else {
00458       for(i = izero; i < leny; i++) { y[i] *= beta; }
00459     }
00460   } else {
00461     if (beta == zero) {
00462       for(i = izero; i < leny; i++) {
00463         y[iy] = zero;
00464         iy += incy;
00465       }
00466     } else {
00467       for(i = izero; i < leny; i++) {
00468         y[iy] *= beta;
00469         iy += incy;
00470       }
00471     }
00472   }
00473       }
00474   
00475       // Return if we don't have to do anything more.
00476       if(alpha == zero) { return; }
00477 
00478       if( ETranspChar[trans] == 'N' ) {
00479   // Form y = alpha*A*y
00480   jx = kx;
00481   if (incy == ione) {
00482     for(j = izero; j < n; j++) {
00483       if (x[jx] != zero) {
00484         temp = alpha*x[jx];
00485         for(i = izero; i < m; i++) {
00486     y[i] += temp*A[j*lda + i];
00487         }
00488       }
00489       jx += incx;
00490     }
00491   } else {
00492     for(j = izero; j < n; j++) {
00493       if (x[jx] != zero) {
00494         temp = alpha*x[jx];
00495         iy = ky;
00496         for(i = izero; i < m; i++) {
00497     y[iy] += temp*A[j*lda + i];
00498     iy += incy;
00499         }
00500       }
00501       jx += incx;
00502     }
00503   }
00504       } else {
00505   jy = ky;
00506   if (incx == ione) {
00507     for(j = izero; j < n; j++) {
00508       temp = zero;
00509       for(i = izero; i < m; i++) {
00510         temp += A[j*lda + i]*x[i];
00511       }
00512       y[jy] += alpha*temp;
00513       jy += incy;
00514     }
00515   } else {
00516     for(j = izero; j < n; j++) {
00517       temp = zero;
00518       ix = kx;
00519       for (i = izero; i < m; i++) {
00520         temp += A[j*lda + i]*x[ix];
00521         ix += incx;
00522       }
00523       y[jy] += alpha*temp;
00524       jy += incy;
00525     }
00526   }
00527       }
00528     } /* if (!BadArgument) */
00529   } /* end GEMV */
00530 
00531  template<typename OrdinalType, typename ScalarType>
00532  void BLAS<OrdinalType, ScalarType>::TRMV(EUplo uplo, ETransp trans, EDiag diag, const OrdinalType n, const ScalarType* A, const OrdinalType lda, ScalarType* x, const OrdinalType incx) const
00533   {
00534     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00535     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00536     ScalarType zero = ScalarTraits<ScalarType>::zero();
00537     bool BadArgument = false;
00538 
00539     // Quick return if there is nothing to do!
00540     if( n == izero ){ return; }
00541     
00542     // Otherwise, we need to check the argument list.
00543     if( n < izero ) { 
00544       std::cout << "BLAS::TRMV Error: N == " << n << std::endl;     
00545       BadArgument = true;
00546     }
00547     if( lda < n ) { 
00548       std::cout << "BLAS::TRMV Error: LDA < MAX(1,N)"<< std::endl;      
00549       BadArgument = true;
00550     }
00551     if( incx == izero ) {
00552       std::cout << "BLAS::TRMV Error: INCX == 0"<< std::endl;
00553       BadArgument = true;
00554     }
00555 
00556     if(!BadArgument) {
00557       OrdinalType i, j, ix, jx, kx = izero;
00558       ScalarType temp;
00559       bool NoUnit = (EDiagChar[diag] == 'N');
00560 
00561       // Set the starting pointer for the std::vector x if incx < 0.
00562       if (incx < izero) { kx = (-n+ione)*incx; }
00563 
00564       // Start the operations for a nontransposed triangular matrix 
00565       if (ETranspChar[trans] == 'N') {
00566   /* Compute x = A*x */
00567   if (EUploChar[uplo] == 'U') {
00568     /* A is an upper triangular matrix */
00569     if (incx == ione) {
00570       for (j=izero; j<n; j++) {
00571         if (x[j] != zero) {
00572     temp = x[j];
00573     for (i=izero; i<j; i++) {
00574       x[i] += temp*A[j*lda + i];
00575     }
00576     if (NoUnit) 
00577       x[j] *= A[j*lda + j];
00578         }
00579       }
00580     } else {
00581       jx = kx;
00582       for (j=izero; j<n; j++) {
00583         if (x[jx] != zero) {
00584     temp = x[jx];
00585     ix = kx;
00586     for (i=izero; i<j; i++) {
00587       x[ix] += temp*A[j*lda + i];
00588       ix += incx;
00589     }
00590     if (NoUnit)
00591       x[jx] *= A[j*lda + j];
00592         }
00593         jx += incx;
00594       }
00595     } /* if (incx == ione) */
00596   } else { /* A is a lower triangular matrix */
00597     if (incx == ione) {
00598       for (j=n-ione; j>-ione; j--) {
00599         if (x[j] != zero) {
00600     temp = x[j];
00601     for (i=n-ione; i>j; i--) {
00602       x[i] += temp*A[j*lda + i];
00603     }
00604     if (NoUnit)
00605       x[j] *= A[j*lda + j];
00606         }
00607       }
00608     } else {
00609       kx += (n-ione)*incx;
00610       jx = kx;
00611       for (j=n-ione; j>-ione; j--) {
00612         if (x[jx] != zero) {
00613     temp = x[jx];
00614     ix = kx;
00615     for (i=n-ione; i>j; i--) {
00616       x[ix] += temp*A[j*lda + i];
00617       ix -= incx;
00618     }
00619     if (NoUnit) 
00620       x[jx] *= A[j*lda + j];
00621         }
00622         jx -= incx;
00623       }
00624     }
00625   } /* if (EUploChar[uplo]=='U') */
00626       } else { /* A is transposed/conjugated */
00627   /* Compute x = A'*x */
00628   if (EUploChar[uplo]=='U') {
00629     /* A is an upper triangular matrix */
00630     if (incx == ione) {
00631       for (j=n-ione; j>-ione; j--) {
00632         temp = x[j];
00633         if (NoUnit)
00634     temp *= A[j*lda + j];
00635         for (i=j-ione; i>-ione; i--) {
00636     temp += A[j*lda + i]*x[i];
00637         }
00638         x[j] = temp;
00639       }
00640     } else {
00641       jx = kx + (n-ione)*incx;
00642       for (j=n-ione; j>-ione; j--) {
00643         temp = x[jx];
00644         ix = jx;
00645         if (NoUnit)
00646     temp *= A[j*lda + j];
00647         for (i=j-ione; i>-ione; i--) {
00648     ix -= incx;
00649     temp += A[j*lda + i]*x[ix];
00650         }
00651         x[jx] = temp;
00652         jx -= incx;
00653       }
00654     }
00655   } else {
00656     /* A is a lower triangular matrix */
00657     if (incx == ione) {
00658       for (j=izero; j<n; j++) {
00659         temp = x[j];
00660         if (NoUnit)
00661     temp *= A[j*lda + j];
00662         for (i=j+ione; i<n; i++) {
00663     temp += A[j*lda + i]*x[i];
00664         }
00665         x[j] = temp;
00666       }
00667     } else {
00668       jx = kx;
00669       for (j=izero; j<n; j++) {
00670         temp = x[jx];
00671         ix = jx;
00672         if (NoUnit) 
00673     temp *= A[j*lda + j];
00674         for (i=j+ione; i<n; i++) {
00675     ix += incx;
00676     temp += A[j*lda + i]*x[ix];
00677         }
00678         x[jx] = temp;
00679         jx += incx;       
00680       }
00681     }
00682   } /* if (EUploChar[uplo]=='U') */
00683       } /* if (ETranspChar[trans]=='N') */
00684     } /* if (!BadArgument) */
00685   } /* end TRMV */
00686         
00687   template<typename OrdinalType, typename ScalarType>
00688   void BLAS<OrdinalType, ScalarType>::GER(const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* x, const OrdinalType incx, const ScalarType* y, const OrdinalType incy, ScalarType* A, const OrdinalType lda) const
00689   {
00690     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00691     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00692     ScalarType zero = ScalarTraits<ScalarType>::zero();
00693     bool BadArgument = false;
00694 
00695     // Quick return if there is nothing to do!
00696     if( m == izero || n == izero || alpha == zero ){ return; }
00697     
00698     // Otherwise, we need to check the argument list.
00699     if( m < izero ) { 
00700   std::cout << "BLAS::GER Error: M == " << m << std::endl;      
00701   BadArgument = true;
00702     }
00703     if( n < izero ) { 
00704   std::cout << "BLAS::GER Error: N == " << n << std::endl;      
00705   BadArgument = true;
00706     }
00707     if( lda < m ) { 
00708   std::cout << "BLAS::GER Error: LDA < MAX(1,M)"<< std::endl;     
00709   BadArgument = true;
00710     }
00711     if( incx == 0 ) {
00712   std::cout << "BLAS::GER Error: INCX == 0"<< std::endl;
00713   BadArgument = true;
00714     }
00715     if( incy == 0 ) {
00716   std::cout << "BLAS::GER Error: INCY == 0"<< std::endl;
00717   BadArgument = true;
00718     }
00719 
00720     if(!BadArgument) {
00721       OrdinalType i, j, ix, jy = izero, kx = izero;
00722       ScalarType temp;
00723 
00724       // Set the starting pointers for the vectors x and y if incx/y < 0.
00725       if (incx < izero) { kx = (-m+ione)*incx; }
00726       if (incy < izero) { jy = (-n+ione)*incy; }
00727 
00728       // Start the operations for incx == 1
00729       if( incx == ione ) {
00730   for( j=izero; j<n; j++ ) {
00731     if ( y[jy] != zero ) {
00732       temp = alpha*y[jy];
00733       for ( i=izero; i<m; i++ ) {
00734         A[j*lda + i] += x[i]*temp;
00735       }
00736     }
00737     jy += incy;
00738   }
00739       } 
00740       else { // Start the operations for incx != 1
00741   for( j=izero; j<n; j++ ) {
00742     if ( y[jy] != zero ) {
00743       temp = alpha*y[jy];
00744       ix = kx;
00745       for( i=izero; i<m; i++ ) {
00746         A[j*lda + i] += x[ix]*temp;
00747         ix += incx;
00748       }
00749     }
00750     jy += incy;
00751   }
00752       }
00753     } /* if(!BadArgument) */
00754   } /* end GER */
00755   
00756 //------------------------------------------------------------------------------------------
00757 //      LEVEL 3 BLAS ROUTINES
00758 //------------------------------------------------------------------------------------------
00759         
00760   template<typename OrdinalType, typename ScalarType>
00761   void BLAS<OrdinalType, ScalarType>::GEMM(ETransp transa, ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const ScalarType alpha, const ScalarType* A, const OrdinalType lda, const ScalarType* B, const OrdinalType ldb, const ScalarType beta, ScalarType* C, const OrdinalType ldc) const
00762   {
00763     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00764     ScalarType zero = ScalarTraits<ScalarType>::zero();
00765     ScalarType one = ScalarTraits<ScalarType>::one();
00766     OrdinalType i, j, p;
00767     OrdinalType NRowA = m, NRowB = k;
00768     ScalarType temp;
00769     bool BadArgument = false;
00770 
00771     // Change dimensions of matrix if either matrix is transposed
00772     if( !(ETranspChar[transa]=='N') ) {
00773       NRowA = k;
00774     }
00775     if( !(ETranspChar[transb]=='N') ) {
00776       NRowB = n;
00777     }
00778 
00779     // Quick return if there is nothing to do!
00780     if( (m==izero) || (n==izero) || (((alpha==zero)||(k==izero)) && (beta==one)) ){ return; }
00781     if( m < izero ) { 
00782       std::cout << "BLAS::GEMM Error: M == " << m << std::endl;     
00783       BadArgument = true;
00784     }
00785     if( n < izero ) { 
00786       std::cout << "BLAS::GEMM Error: N == " << n << std::endl;     
00787       BadArgument = true;
00788     }
00789     if( k < izero ) { 
00790       std::cout << "BLAS::GEMM Error: K == " << k << std::endl;     
00791       BadArgument = true;
00792     }
00793     if( lda < NRowA ) { 
00794       std::cout << "BLAS::GEMM Error: LDA < MAX(1,M)"<< std::endl;      
00795       BadArgument = true;
00796     }
00797     if( ldb < NRowB ) { 
00798       std::cout << "BLAS::GEMM Error: LDB < MAX(1,K)"<< std::endl;      
00799       BadArgument = true;
00800     }
00801      if( ldc < m ) { 
00802       std::cout << "BLAS::GEMM Error: LDC < MAX(1,M)"<< std::endl;      
00803       BadArgument = true;
00804     }
00805 
00806     if(!BadArgument) {
00807 
00808       // Only need to scale the resulting matrix C.
00809       if( alpha == zero ) {
00810   if( beta == zero ) {
00811     for (j=izero; j<n; j++) {
00812       for (i=izero; i<m; i++) {
00813         C[j*ldc + i] = zero;
00814       }
00815     }
00816   } else {
00817     for (j=izero; j<n; j++) {
00818       for (i=izero; i<m; i++) {
00819         C[j*ldc + i] *= beta;
00820       }
00821     }
00822   }
00823   return;
00824       }
00825       //
00826       // Now start the operations.
00827       //
00828       if ( ETranspChar[transb]=='N' ) {
00829   if ( ETranspChar[transa]=='N' ) {
00830     // Compute C = alpha*A*B + beta*C
00831     for (j=izero; j<n; j++) {
00832       if( beta == zero ) {
00833         for (i=izero; i<m; i++){
00834     C[j*ldc + i] = zero;
00835         }
00836       } else if( beta != one ) {
00837         for (i=izero; i<m; i++){
00838     C[j*ldc + i] *= beta;
00839         }
00840       }
00841       for (p=izero; p<k; p++){
00842         if (B[j*ldb + p] != zero ){
00843     temp = alpha*B[j*ldb + p];
00844     for (i=izero; i<m; i++) {
00845       C[j*ldc + i] += temp*A[p*lda + i];
00846     }
00847         }
00848       }
00849     }
00850   } else {
00851     // Compute C = alpha*A'*B + beta*C
00852     for (j=izero; j<n; j++) {
00853       for (i=izero; i<m; i++) {
00854         temp = zero;
00855         for (p=izero; p<k; p++) {
00856     temp += A[i*lda + p]*B[j*ldb + p];
00857         }
00858         if (beta == zero) {
00859     C[j*ldc + i] = alpha*temp;
00860         } else {
00861     C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
00862         }
00863       }
00864     }
00865   }
00866       } else {
00867   if ( ETranspChar[transa]=='N' ) {
00868     // Compute C = alpha*A*B' + beta*C
00869     for (j=izero; j<n; j++) {
00870       if (beta == zero) {
00871         for (i=izero; i<m; i++) {
00872     C[j*ldc + i] = zero;
00873         } 
00874       } else if ( beta != one ) {
00875         for (i=izero; i<m; i++) {
00876     C[j*ldc + i] *= beta;
00877         }
00878       }
00879       for (p=izero; p<k; p++) {
00880         if (B[p*ldb + j] != zero) {
00881     temp = alpha*B[p*ldb + j];
00882     for (i=izero; i<m; i++) {
00883       C[j*ldc + i] += temp*A[p*lda + i];
00884     }
00885         }
00886       }
00887     }
00888   } else {
00889     // Compute C += alpha*A'*B' + beta*C
00890     for (j=izero; j<n; j++) {
00891       for (i=izero; i<m; i++) {
00892         temp = zero;
00893         for (p=izero; p<k; p++) {
00894     temp += A[i*lda + p]*B[p*ldb + j];
00895         }
00896         if (beta == zero) {
00897     C[j*ldc + i] = alpha*temp;
00898         } else {
00899     C[j*ldc + i] = alpha*temp + beta*C[j*ldc + i];
00900         }
00901       }
00902     }
00903   } // end if (ETranspChar[transa]=='N') ...
00904       } // end if (ETranspChar[transb]=='N') ...
00905     } // end if (!BadArgument) ...
00906   } // end of GEMM
00907 
00908 
00909   template<typename OrdinalType, typename ScalarType>
00910   void BLAS<OrdinalType, ScalarType>::SYMM(ESide side, EUplo uplo, const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* A, const OrdinalType lda, const ScalarType* B, const OrdinalType ldb, const ScalarType beta, ScalarType* C, const OrdinalType ldc) const
00911   {
00912     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
00913     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
00914     ScalarType zero = ScalarTraits<ScalarType>::zero();
00915     ScalarType one = ScalarTraits<ScalarType>::one();
00916     OrdinalType i, j, k, NRowA = m;
00917     ScalarType temp1, temp2;
00918     bool BadArgument = false;
00919     bool Upper = (EUploChar[uplo] == 'U');
00920     if (ESideChar[side] == 'R') { NRowA = n; }
00921     
00922     // Quick return.
00923     if ( (m==izero) || (n==izero) || ( (alpha==zero)&&(beta==one) ) ) { return; }
00924     if( m < 0 ) { 
00925       std::cout << "BLAS::SYMM Error: M == "<< m << std::endl;
00926       BadArgument = true; }
00927     if( n < 0 ) {
00928       std::cout << "BLAS::SYMM Error: N == "<< n << std::endl;
00929       BadArgument = true; }
00930     if( lda < NRowA ) {
00931       std::cout << "BLAS::SYMM Error: LDA == "<<lda<<std::endl;
00932       BadArgument = true; }
00933     if( ldb < m ) {
00934       std::cout << "BLAS::SYMM Error: LDB == "<<ldb<<std::endl;
00935       BadArgument = true; }
00936     if( ldc < m ) {
00937       std::cout << "BLAS::SYMM Error: LDC == "<<ldc<<std::endl;
00938       BadArgument = true; }
00939 
00940     if(!BadArgument) {
00941 
00942       // Only need to scale C and return.
00943       if (alpha == zero) {
00944   if (beta == zero ) {
00945     for (j=izero; j<n; j++) {
00946       for (i=izero; i<m; i++) {
00947         C[j*ldc + i] = zero;
00948       }
00949     }
00950   } else {
00951     for (j=izero; j<n; j++) {
00952       for (i=izero; i<m; i++) {
00953         C[j*ldc + i] *= beta;
00954       }
00955     }
00956   }
00957   return;
00958       }
00959 
00960       if ( ESideChar[side] == 'L') {
00961   // Compute C = alpha*A*B + beta*C
00962 
00963   if (Upper) {
00964     // The symmetric part of A is stored in the upper triangular part of the matrix.
00965     for (j=izero; j<n; j++) {
00966       for (i=izero; i<m; i++) {
00967         temp1 = alpha*B[j*ldb + i];
00968         temp2 = zero;
00969         for (k=izero; k<i; k++) {
00970     C[j*ldc + k] += temp1*A[i*lda + k];
00971     temp2 += B[j*ldb + k]*A[i*lda + k];
00972         }
00973         if (beta == zero) {
00974     C[j*ldc + i] = temp1*A[i*lda + i] + alpha*temp2;
00975         } else {
00976     C[j*ldc + i] = beta*C[j*ldc + i] + temp1*A[i*lda + i] + alpha*temp2;
00977         }
00978       }
00979     }
00980   } else {
00981     // The symmetric part of A is stored in the lower triangular part of the matrix.
00982     for (j=izero; j<n; j++) {
00983       for (i=m-ione; i>-ione; i--) {
00984         temp1 = alpha*B[j*ldb + i];
00985         temp2 = zero;
00986         for (k=i+ione; k<m; k++) {
00987     C[j*ldc + k] += temp1*A[i*lda + k];
00988     temp2 += B[j*ldb + k]*A[i*lda + k];
00989         }
00990         if (beta == zero) {
00991     C[j*ldc + i] = temp1*A[i*lda + i] + alpha*temp2;
00992         } else {
00993     C[j*ldc + i] = beta*C[j*ldc + i] + temp1*A[i*lda + i] + alpha*temp2;
00994         }
00995       }
00996     }
00997   }
00998       } else {
00999   // Compute C = alpha*B*A + beta*C.
01000   for (j=izero; j<n; j++) {
01001     temp1 = alpha*A[j*lda + j];
01002     if (beta == zero) {
01003       for (i=izero; i<m; i++) {
01004         C[j*ldc + i] = temp1*B[j*ldb + i];
01005       }
01006     } else {
01007       for (i=izero; i<m; i++) {
01008         C[j*ldc + i] = beta*C[j*ldc + i] + temp1*B[j*ldb + i];
01009       }
01010     }
01011     for (k=izero; k<j; k++) {
01012       if (Upper) {
01013         temp1 = alpha*A[j*lda + k];
01014       } else {
01015         temp1 = alpha*A[k*lda + j];
01016       }
01017       for (i=izero; i<m; i++) {
01018         C[j*ldc + i] += temp1*B[k*ldb + i];
01019       }
01020     }
01021     for (k=j+ione; k<n; k++) {
01022       if (Upper) {
01023         temp1 = alpha*A[k*lda + j];
01024       } else {
01025         temp1 = alpha*A[j*lda + k];
01026       }
01027       for (i=izero; i<m; i++) {
01028         C[j*ldc + i] += temp1*B[k*ldb + i];
01029       }
01030     }
01031   }
01032       } // end if (ESideChar[side]=='L') ...
01033     } // end if(!BadArgument) ...
01034 } // end SYMM
01035   
01036   template<typename OrdinalType, typename ScalarType>
01037   void BLAS<OrdinalType, ScalarType>::TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const
01038   {
01039     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
01040     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
01041     ScalarType zero = ScalarTraits<ScalarType>::zero();
01042     ScalarType one = ScalarTraits<ScalarType>::one();
01043     OrdinalType i, j, k, NRowA = m;
01044     ScalarType temp;
01045     bool BadArgument = false;
01046     bool LSide = (ESideChar[side] == 'L');
01047     bool NoUnit = (EDiagChar[diag] == 'N');
01048     bool Upper = (EUploChar[uplo] == 'U');
01049 
01050     if(!LSide) { NRowA = n; }
01051 
01052     // Quick return.
01053     if (n==izero || m==izero) { return; }
01054     if( m < 0 ) {
01055       std::cout << "BLAS::TRMM Error: M == "<< m <<std::endl;
01056       BadArgument = true; }
01057     if( n < 0 ) {
01058       std::cout << "BLAS::TRMM Error: N == "<< n <<std::endl;
01059       BadArgument = true; }
01060     if( lda < NRowA ) {
01061       std::cout << "BLAS::TRMM Error: LDA == "<< lda << std::endl;
01062       BadArgument = true; }
01063     if( ldb < m ) {
01064       std::cout << "BLAS::TRMM Error: M == "<< ldb << std::endl;
01065       BadArgument = true; }
01066 
01067     if(!BadArgument) {
01068 
01069       // B only needs to be zeroed out.
01070       if( alpha == zero ) {
01071   for( j=izero; j<n; j++ ) {
01072     for( i=izero; i<m; i++ ) {
01073       B[j*ldb + i] = zero;
01074     }
01075   }
01076   return;
01077       }
01078       
01079       //  Start the computations. 
01080       if ( LSide ) {
01081   // A is on the left side of B.
01082   
01083   if ( ETranspChar[transa]=='N' ) {
01084     // Compute B = alpha*A*B
01085 
01086     if ( Upper ) {
01087       // A is upper triangular
01088       for( j=izero; j<n; j++ ) {
01089         for( k=izero; k<m; k++) {
01090     if ( B[j*ldb + k] != zero ) {
01091       temp = alpha*B[j*ldb + k];
01092       for( i=izero; i<k; i++ ) {
01093         B[j*ldb + i] += temp*A[k*lda + i];
01094       }
01095       if ( NoUnit )
01096         temp *=A[k*lda + k];
01097       B[j*ldb + k] = temp;
01098     }
01099         }
01100       }
01101     } else {
01102       // A is lower triangular
01103       for( j=izero; j<n; j++ ) {
01104         for( k=m-ione; k>-ione; k-- ) {
01105     if( B[j*ldb + k] != zero ) {
01106       temp = alpha*B[j*ldb + k];
01107       B[j*ldb + k] = temp;
01108       if ( NoUnit )
01109         B[j*ldb + k] *= A[k*lda + k];
01110       for( i=k+ione; i<m; i++ ) {
01111         B[j*ldb + i] += temp*A[k*lda + i];
01112       }
01113     }
01114         }
01115       }
01116     }
01117   } else {
01118     // Compute B = alpha*A'*B
01119     if( Upper ) {
01120       for( j=izero; j<n; j++ ) {
01121         for( i=m-ione; i>-ione; i-- ) {
01122     temp = B[j*ldb + i];
01123     if( NoUnit )
01124       temp *= A[i*lda + i];
01125     for( k=izero; k<i; k++ ) {
01126       temp += A[i*lda + k]*B[j*ldb + k];
01127     }
01128     B[j*ldb + i] = alpha*temp;
01129         }
01130       }
01131     } else {
01132       for( j=izero; j<n; j++ ) {
01133         for( i=izero; i<m; i++ ) {
01134     temp = B[j*ldb + i];
01135     if( NoUnit ) 
01136       temp *= A[i*lda + i];
01137     for( k=i+ione; k<m; k++ ) {
01138       temp += A[i*lda + k]*B[j*ldb + k];
01139     }
01140     B[j*ldb + i] = alpha*temp;
01141         }
01142       }
01143     }
01144   }
01145       } else {
01146   // A is on the right hand side of B.
01147   
01148   if( ETranspChar[transa] == 'N' ) {
01149     // Compute B = alpha*B*A
01150 
01151     if( Upper ) {
01152       // A is upper triangular.
01153       for( j=n-ione; j>-ione; j-- ) {
01154         temp = alpha;
01155         if( NoUnit )
01156     temp *= A[j*lda + j];
01157         for( i=izero; i<m; i++ ) {
01158     B[j*ldb + i] *= temp;
01159         }
01160         for( k=izero; k<j; k++ ) {
01161     if( A[j*lda + k] != zero ) {
01162       temp = alpha*A[j*lda + k];
01163       for( i=izero; i<m; i++ ) {
01164         B[j*ldb + i] += temp*B[k*ldb + i];
01165       }
01166     }
01167         }
01168       }
01169     } else {
01170       // A is lower triangular.
01171       for( j=izero; j<n; j++ ) {
01172         temp = alpha;
01173         if( NoUnit )
01174     temp *= A[j*lda + j];
01175         for( i=izero; i<m; i++ ) {
01176     B[j*ldb + i] *= temp;
01177         }
01178         for( k=j+ione; k<n; k++ ) {
01179     if( A[j*lda + k] != zero ) {
01180       temp = alpha*A[j*lda + k];
01181       for( i=izero; i<m; i++ ) {
01182         B[j*ldb + i] += temp*B[k*ldb + i];
01183       }
01184     }
01185         }
01186       }
01187     }
01188   } else {
01189     // Compute B = alpha*B*A'
01190 
01191     if( Upper ) {
01192       for( k=izero; k<n; k++ ) {
01193         for( j=izero; j<k; j++ ) {
01194     if( A[k*lda + j] != zero ) {
01195       temp = alpha*A[k*lda + j];
01196       for( i=izero; i<m; i++ ) {
01197         B[j*ldb + i] += temp*B[k*ldb + i];
01198       }
01199     }
01200         }
01201         temp = alpha;
01202         if( NoUnit ) 
01203     temp *= A[k*lda + k];
01204         if( temp != one ) {
01205     for( i=izero; i<m; i++ ) {
01206       B[k*ldb + i] *= temp;
01207     }
01208         }
01209       }
01210     } else {
01211       for( k=n-ione; k>-ione; k-- ) {
01212         for( j=k+ione; j<n; j++ ) {
01213     if( A[k*lda + j] != zero ) {
01214       temp = alpha*A[k*lda + j];
01215       for( i=izero; i<m; i++ ) {
01216         B[j*ldb + i] += temp*B[k*ldb + i];
01217       }
01218     }
01219         }
01220         temp = alpha;
01221         if( NoUnit )
01222     temp *= A[k*lda + k];
01223         if( temp != one ) {
01224     for( i=izero; i<m; i++ ) {
01225       B[k*ldb + i] *= temp;
01226     }
01227         }
01228       }
01229     }
01230   } // end if( ETranspChar[transa] == 'N' ) ...
01231       } // end if ( LSide ) ...
01232     } // end if (!BadArgument)
01233   } // end TRMM
01234   
01235   template<typename OrdinalType, typename ScalarType>
01236   void BLAS<OrdinalType, ScalarType>::TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const OrdinalType m, const OrdinalType n, const ScalarType alpha, const ScalarType* A, const OrdinalType lda, ScalarType* B, const OrdinalType ldb) const
01237   {
01238     OrdinalType izero = OrdinalTraits<OrdinalType>::zero();
01239     OrdinalType ione = OrdinalTraits<OrdinalType>::one();
01240     ScalarType zero = ScalarTraits<ScalarType>::zero();
01241     ScalarType one = ScalarTraits<ScalarType>::one();
01242     ScalarType temp;
01243     OrdinalType NRowA = m;
01244     bool BadArgument = false;
01245     bool NoUnit = (EDiagChar[diag]=='N');
01246     
01247     if (!(ESideChar[side] == 'L')) { NRowA = n; }
01248 
01249     // Quick return.
01250     if (n == izero || m == izero) { return; }
01251     if( m < izero ) {
01252       std::cout << "BLAS::TRSM Error: M == "<<m<<std::endl;
01253       BadArgument = true; }
01254     if( n < izero ) {
01255       std::cout << "BLAS::TRSM Error: N == "<<n<<std::endl;
01256       BadArgument = true; }
01257     if( lda < NRowA ) {
01258       std::cout << "BLAS::TRSM Error: LDA == "<<lda<<std::endl;
01259       BadArgument = true; }
01260     if( ldb < m ) {
01261       std::cout << "BLAS::TRSM Error: LDB == "<<ldb<<std::endl;
01262       BadArgument = true; }
01263 
01264     if(!BadArgument)
01265       {
01266   int i, j, k;
01267   // Set the solution to the zero std::vector.
01268   if(alpha == zero) {
01269       for(j = izero; j < n; j++) {
01270         for( i = izero; i < m; i++) {
01271         B[j*ldb+i] = zero;
01272           }
01273       }
01274   }
01275   else 
01276   { // Start the operations.
01277       if(ESideChar[side] == 'L') {
01278     //
01279         // Perform computations for OP(A)*X = alpha*B     
01280     //
01281     if(ETranspChar[transa] == 'N') {
01282         //
01283         //  Compute B = alpha*inv( A )*B
01284         //
01285         if(EUploChar[uplo] == 'U') { 
01286       // A is upper triangular.
01287       for(j = izero; j < n; j++) {
01288               // Perform alpha*B if alpha is not 1.
01289               if(alpha != one) {
01290                 for( i = izero; i < m; i++) {
01291                 B[j*ldb+i] *= alpha;
01292             }
01293           }
01294           // Perform a backsolve for column j of B.
01295           for(k = (m - ione); k > -ione; k--) {
01296         // If this entry is zero, we don't have to do anything.
01297         if (B[j*ldb + k] != zero) {
01298             if (NoUnit) {
01299           B[j*ldb + k] /= A[k*lda + k];
01300             }
01301             for(i = izero; i < k; i++) {
01302           B[j*ldb + i] -= B[j*ldb + k] * A[k*lda + i];
01303             }
01304         }
01305           }
01306       }
01307         }
01308         else 
01309         { // A is lower triangular.
01310                         for(j = izero; j < n; j++) {
01311                             // Perform alpha*B if alpha is not 1.
01312                             if(alpha != one) {
01313                                 for( i = izero; i < m; i++) {
01314                                     B[j*ldb+i] *= alpha;
01315                                 }
01316                             }
01317                             // Perform a forward solve for column j of B.
01318                             for(k = izero; k < m; k++) {
01319                                 // If this entry is zero, we don't have to do anything.
01320                                 if (B[j*ldb + k] != zero) {   
01321                                     if (NoUnit) {
01322                                         B[j*ldb + k] /= A[k*lda + k];
01323                                     }
01324                                     for(i = k+ione; i < m; i++) {
01325                                         B[j*ldb + i] -= B[j*ldb + k] * A[k*lda + i];
01326                                     }
01327                                 }
01328                             }
01329                         }
01330         } // end if (uplo == 'U')
01331     }  // if (transa =='N') 
01332         else { 
01333         //
01334         //  Compute B = alpha*inv( A' )*B
01335         //
01336         if(EUploChar[uplo] == 'U') { 
01337       // A is upper triangular.
01338       for(j = izero; j < n; j++) {
01339                   for( i = izero; i < m; i++) {
01340             temp = alpha*B[j*ldb+i];
01341             for(k = izero; k < i; k++) {
01342             temp -= A[i*lda + k] * B[j*ldb + k];
01343         }
01344         if (NoUnit) {
01345             temp /= A[i*lda + i];
01346         }
01347         B[j*ldb + i] = temp;
01348           }
01349       }
01350         }
01351         else
01352         { // A is lower triangular.
01353                         for(j = izero; j < n; j++) {
01354                             for(i = (m - ione) ; i > -ione; i--) {
01355                                 temp = alpha*B[j*ldb+i];
01356                               for(k = i+ione; k < m; k++) {
01357             temp -= A[i*lda + k] * B[j*ldb + k];
01358         }
01359         if (NoUnit) {
01360             temp /= A[i*lda + i];
01361         }
01362         B[j*ldb + i] = temp;
01363                             }
01364                         }
01365         }
01366     }
01367       }  // if (side == 'L')
01368       else { 
01369          // side == 'R'
01370          //
01371          // Perform computations for X*OP(A) = alpha*B      
01372          //
01373         if (ETranspChar[transa] == 'N') {
01374         //
01375         //  Compute B = alpha*B*inv( A )
01376         //
01377         if(EUploChar[uplo] == 'U') { 
01378       // A is upper triangular.
01379           // Perform a backsolve for column j of B.
01380       for(j = izero; j < n; j++) {
01381               // Perform alpha*B if alpha is not 1.
01382               if(alpha != one) {
01383                 for( i = izero; i < m; i++) {
01384                 B[j*ldb+i] *= alpha;
01385             }
01386           }
01387           for(k = izero; k < j; k++) {
01388         // If this entry is zero, we don't have to do anything.
01389         if (A[j*lda + k] != zero) {
01390             for(i = izero; i < m; i++) {
01391           B[j*ldb + i] -= A[j*lda + k] * B[k*ldb + i];
01392             }
01393         }
01394           }
01395           if (NoUnit) {
01396         temp = one/A[j*lda + j];
01397         for(i = izero; i < m; i++) {
01398             B[j*ldb + i] *= temp;
01399         }
01400           }
01401       }
01402         }
01403         else 
01404         { // A is lower triangular.
01405                         for(j = (n - ione); j > -ione; j--) {
01406                             // Perform alpha*B if alpha is not 1.
01407                             if(alpha != one) {
01408                                 for( i = izero; i < m; i++) {
01409                                     B[j*ldb+i] *= alpha;
01410                                 }
01411                             }
01412                             // Perform a forward solve for column j of B.
01413                             for(k = j+ione; k < n; k++) {
01414                                 // If this entry is zero, we don't have to do anything.
01415         if (A[j*lda + k] != zero) {
01416             for(i = izero; i < m; i++) {
01417                                         B[j*ldb + i] -= A[j*lda + k] * B[k*ldb + i]; 
01418                                     }
01419                                 } 
01420                             }
01421           if (NoUnit) {
01422         temp = one/A[j*lda + j];
01423         for(i = izero; i < m; i++) {
01424             B[j*ldb + i] *= temp;
01425         }
01426           }     
01427                         }
01428         } // end if (uplo == 'U')
01429     }  // if (transa =='N') 
01430         else { 
01431         //
01432         //  Compute B = alpha*B*inv( A' )
01433         //
01434         if(EUploChar[uplo] == 'U') { 
01435       // A is upper triangular.
01436       for(k = (n - ione); k > -ione; k--) {
01437           if (NoUnit) {
01438         temp = one/A[k*lda + k];
01439                     for(i = izero; i < m; i++) {
01440                 B[k*ldb + i] *= temp;
01441         }
01442           }
01443           for(j = izero; j < k; j++) {
01444         if (A[k*lda + j] != zero) {
01445             temp = A[k*lda + j];
01446             for(i = izero; i < m; i++) {
01447           B[j*ldb + i] -= temp*B[k*ldb + i];
01448             }
01449         }
01450           }
01451           if (alpha != one) {
01452         for (i = izero; i < m; i++) {
01453             B[k*ldb + i] *= alpha;
01454         }
01455           }
01456       }
01457         }
01458         else
01459         { // A is lower triangular.
01460       for(k = izero; k < n; k++) {
01461           if (NoUnit) {
01462         temp = one/A[k*lda + k];
01463         for (i = izero; i < m; i++) {
01464             B[k*ldb + i] *= temp;
01465         }
01466           }
01467           for(j = k+ione; j < n; j++) {
01468         if(A[k*lda + j] != zero) {
01469             temp = A[k*lda + j];
01470             for(i = izero; i < m; i++) {
01471           B[j*ldb + i] -= temp*B[k*ldb + i];
01472             }
01473         }
01474           }
01475           if (alpha != one) {
01476         for (i = izero; i < m; i++) {
01477             B[k*ldb + i] *= alpha;
01478         }
01479           }
01480                         }
01481         }
01482     }   
01483       }
01484   }
01485     }
01486   }
01487 
01488   // Explicit instantiation for template<int,float>
01489 
01490 #ifdef HAVE_TEUCHOS_BLASFLOAT
01491 
01492   template <>
01493   class BLAS<int, float>
01494   {    
01495   public:
01496     inline BLAS(void) {}
01497     inline BLAS(const BLAS<int, float>& /*BLAS_source*/) {}
01498     inline virtual ~BLAS(void) {}
01499     void ROTG(float* da, float* db, float* c, float* s) const;
01500     void ROT(const int n, float* dx, const int incx, float* dy, const int incy, float* c, float* s) const;
01501     float ASUM(const int n, const float* x, const int incx) const;
01502     void AXPY(const int n, const float alpha, const float* x, const int incx, float* y, const int incy) const;
01503     void COPY(const int n, const float* x, const int incx, float* y, const int incy) const;
01504     float DOT(const int n, const float* x, const int incx, const float* y, const int incy) const;
01505     float NRM2(const int n, const float* x, const int incx) const;
01506     void SCAL(const int n, const float alpha, float* x, const int incx) const;
01507     int IAMAX(const int n, const float* x, const int incx) const;
01508     void GEMV(ETransp trans, const int m, const int n, const float alpha, const float* A, const int lda, const float* x, const int incx, const float beta, float* y, const int incy) const;
01509     void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const float* A, const int lda, float* x, const int incx) const;
01510     void GER(const int m, const int n, const float alpha, const float* x, const int incx, const float* y, const int incy, float* A, const int lda) const;
01511     void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) const;
01512     void SYMM(ESide side, EUplo uplo, const int m, const int n, const float alpha, const float* A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc) const;
01513     void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const float alpha, const float* A, const int lda, float* B, const int ldb) const;
01514     void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const float alpha, const float* A, const int lda, float* B, const int ldb) const;
01515   };
01516 
01517 #endif // HAVE_TEUCHOS_BLASFLOAT
01518 
01519   // Explicit instantiation for template<int,double>
01520 
01521   template<>
01522   class BLAS<int, double>
01523   {    
01524   public:
01525     inline BLAS(void) {}
01526     inline BLAS(const BLAS<int, double>& /*BLAS_source*/) {}
01527     inline virtual ~BLAS(void) {}
01528     void ROTG(double* da, double* db, double* c, double* s) const;
01529     void ROT(const int n, double* dx, const int incx, double* dy, const int incy, double* c, double* s) const;
01530     double ASUM(const int n, const double* x, const int incx) const;
01531     void AXPY(const int n, const double alpha, const double* x, const int incx, double* y, const int incy) const;
01532     void COPY(const int n, const double* x, const int incx, double* y, const int incy) const;
01533     double DOT(const int n, const double* x, const int incx, const double* y, const int incy) const;
01534     double NRM2(const int n, const double* x, const int incx) const;
01535     void SCAL(const int n, const double alpha, double* x, const int incx) const;
01536     int IAMAX(const int n, const double* x, const int incx) const;
01537     void GEMV(ETransp trans, const int m, const int n, const double alpha, const double* A, const int lda, const double* x, const int incx, const double beta, double* y, const int incy) const;
01538     void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const double* A, const int lda, double* x, const int incx) const;
01539     void GER(const int m, const int n, const double alpha, const double* x, const int incx, const double* y, const int incy, double* A, const int lda) const;
01540     void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) const;
01541     void SYMM(ESide side, EUplo uplo, const int m, const int n, const double alpha, const double* A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc) const;
01542     void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const double alpha, const double* A, const int lda, double* B, const int ldb) const;
01543     void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const double alpha, const double* A, const int lda, double* B, const int ldb) const;
01544   };
01545 
#ifdef HAVE_TEUCHOS_BLASFLOAT

  // Full (explicit) specialization of BLAS for int ordinals and single-precision
  // complex scalars. Only declarations appear here; the member functions are
  // not defined in this header (presumably they wrap the Fortran BLAS
  // elsewhere -- confirm). Note the real-valued `c` in ROTG/ROT and the
  // real-valued results of ASUM/NRM2.

  template<>
  class BLAS<int, std::complex<float> >
  {
  public:
    BLAS() {}
    BLAS(const BLAS& /*BLAS_source*/) {}
    virtual ~BLAS() {}

    // Level 1 (vector) routines.
    void ROTG(std::complex<float>* da, std::complex<float>* db,
              float* c, std::complex<float>* s) const;
    void ROT(const int n, std::complex<float>* dx, const int incx,
             std::complex<float>* dy, const int incy,
             float* c, std::complex<float>* s) const;
    float ASUM(const int n, const std::complex<float>* x, const int incx) const;
    void AXPY(const int n, const std::complex<float> alpha,
              const std::complex<float>* x, const int incx,
              std::complex<float>* y, const int incy) const;
    void COPY(const int n, const std::complex<float>* x, const int incx,
              std::complex<float>* y, const int incy) const;
    std::complex<float> DOT(const int n, const std::complex<float>* x, const int incx,
                            const std::complex<float>* y, const int incy) const;
    float NRM2(const int n, const std::complex<float>* x, const int incx) const;
    void SCAL(const int n, const std::complex<float> alpha,
              std::complex<float>* x, const int incx) const;
    int IAMAX(const int n, const std::complex<float>* x, const int incx) const;

    // Level 2 (matrix-vector) routines.
    void GEMV(ETransp trans, const int m, const int n, const std::complex<float> alpha,
              const std::complex<float>* A, const int lda,
              const std::complex<float>* x, const int incx,
              const std::complex<float> beta, std::complex<float>* y, const int incy) const;
    void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n,
              const std::complex<float>* A, const int lda,
              std::complex<float>* x, const int incx) const;
    void GER(const int m, const int n, const std::complex<float> alpha,
             const std::complex<float>* x, const int incx,
             const std::complex<float>* y, const int incy,
             std::complex<float>* A, const int lda) const;

    // Level 3 (matrix-matrix) routines.
    void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k,
              const std::complex<float> alpha, const std::complex<float>* A, const int lda,
              const std::complex<float>* B, const int ldb,
              const std::complex<float> beta, std::complex<float>* C, const int ldc) const;
    void SYMM(ESide side, EUplo uplo, const int m, const int n,
              const std::complex<float> alpha, const std::complex<float>* A, const int lda,
              const std::complex<float>* B, const int ldb,
              const std::complex<float> beta, std::complex<float>* C, const int ldc) const;
    void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag,
              const int m, const int n, const std::complex<float> alpha,
              const std::complex<float>* A, const int lda,
              std::complex<float>* B, const int ldb) const;
    void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag,
              const int m, const int n, const std::complex<float> alpha,
              const std::complex<float>* A, const int lda,
              std::complex<float>* B, const int ldb) const;
  };

#endif // HAVE_TEUCHOS_BLASFLOAT
01576 
01577   // Explicit instantiation for template<int,complex<double> >
01578 
01579   template<>
01580   class BLAS<int, std::complex<double> >
01581   {    
01582   public:
01583     inline BLAS(void) {}
01584     inline BLAS(const BLAS<int, std::complex<double> >& /*BLAS_source*/) {}
01585     inline virtual ~BLAS(void) {}
01586     void ROTG(std::complex<double>* da, std::complex<double>* db, double* c, std::complex<double>* s) const;
01587     void ROT(const int n, std::complex<double>* dx, const int incx, std::complex<double>* dy, const int incy, double* c, std::complex<double>* s) const;
01588     double ASUM(const int n, const std::complex<double>* x, const int incx) const;
01589     void AXPY(const int n, const std::complex<double> alpha, const std::complex<double>* x, const int incx, std::complex<double>* y, const int incy) const;
01590     void COPY(const int n, const std::complex<double>* x, const int incx, std::complex<double>* y, const int incy) const;
01591     std::complex<double> DOT(const int n, const std::complex<double>* x, const int incx, const std::complex<double>* y, const int incy) const;
01592     double NRM2(const int n, const std::complex<double>* x, const int incx) const;
01593     void SCAL(const int n, const std::complex<double> alpha, std::complex<double>* x, const int incx) const;
01594     int IAMAX(const int n, const std::complex<double>* x, const int incx) const;
01595     void GEMV(ETransp trans, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double>* x, const int incx, const std::complex<double> beta, std::complex<double>* y, const int incy) const;
01596     void TRMV(EUplo uplo, ETransp trans, EDiag diag, const int n, const std::complex<double>* A, const int lda, std::complex<double>* x, const int incx) const;
01597     void GER(const int m, const int n, const std::complex<double> alpha, const std::complex<double>* x, const int incx, const std::complex<double>* y, const int incy, std::complex<double>* A, const int lda) const;
01598     void GEMM(ETransp transa, ETransp transb, const int m, const int n, const int k, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double>* B, const int ldb, const std::complex<double> beta, std::complex<double>* C, const int ldc) const;
01599     void SYMM(ESide side, EUplo uplo, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, const std::complex<double> *B, const int ldb, const std::complex<double> beta, std::complex<double> *C, const int ldc) const;
01600     void TRMM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, std::complex<double>* B, const int ldb) const;
01601     void TRSM(ESide side, EUplo uplo, ETransp transa, EDiag diag, const int m, const int n, const std::complex<double> alpha, const std::complex<double>* A, const int lda, std::complex<double>* B, const int ldb) const;
01602   };
01603 
01604 } // namespace Teuchos
01605 
01606 #endif // _TEUCHOS_BLAS_HPP_

Generated on Wed May 12 21:40:31 2010 for Teuchos - Trilinos Tools Package by  doxygen 1.4.7