vector_blas_example.cpp

Go to the documentation of this file.
00001 // $Id$ 
00002 // $Source$ 
00003 // @HEADER
00004 // ***********************************************************************
00005 // 
00006 //                           Sacado Package
00007 //                 Copyright (2006) Sandia Corporation
00008 // 
00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00010 // the U.S. Government retains certain rights in this software.
00011 // 
00012 // This library is free software; you can redistribute it and/or modify
00013 // it under the terms of the GNU Lesser General Public License as
00014 // published by the Free Software Foundation; either version 2.1 of the
00015 // License, or (at your option) any later version.
00016 //  
00017 // This library is distributed in the hope that it will be useful, but
00018 // WITHOUT ANY WARRANTY; without even the implied warranty of
00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00020 // Lesser General Public License for more details.
00021 //  
00022 // You should have received a copy of the GNU Lesser General Public
00023 // License along with this library; if not, write to the Free Software
00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00025 // USA
00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00027 // (etphipp@sandia.gov).
00028 // 
00029 // ***********************************************************************
00030 // @HEADER
00031 
00032 // vector_blas_example
00033 //
00034 //  usage: 
00035 //     vector_blas_example
00036 //
00037 //  output:  
00038 //     prints the results of differentiating a BLAS routine with forward
00039 //     mode AD using the Sacado::Fad::DVFad class (uses dynamic memory
00040 //     allocation for number of derivative components stored in a contiguous
00041 //     array).
00042 
#include <cmath>
#include <iostream>
#include <iomanip>
#include <vector>

#include "Sacado.hpp"
#include "Teuchos_BLAS.hpp"
#include "Sacado_Fad_BLAS.hpp"
00049 
00050 typedef Sacado::Fad::DVFad<double> FadType;
00051 
00052 int main(int argc, char **argv)
00053 {
00054   const unsigned int n = 5;
00055   Sacado::Fad::Vector<unsigned int, FadType> A(n*n,0),B(n,n), C(n,n);
00056   for (unsigned int i=0; i<n; i++) {
00057     for (unsigned int j=0; j<n; j++)
00058       A[i+j*n] = FadType(Teuchos::ScalarTraits<double>::random());
00059     B[i] = FadType(n, Teuchos::ScalarTraits<double>::random());
00060     for (unsigned int j=0; j<n; j++)
00061       B[i].fastAccessDx(j) = Teuchos::ScalarTraits<double>::random();
00062     C[i] = 0.0;
00063   }
00064 
00065   double *a = A.vals();
00066   double *b = B.vals();
00067   double *bdx = B.dx();
00068   std::vector<double> c(n), cdx(n*n);
00069 
00070   Teuchos::BLAS<int,double> blas;
00071   blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &a[0], n, &b[0], 1, 0.0, &c[0], 1);
00072   blas.GEMM(Teuchos::NO_TRANS, Teuchos::NO_TRANS, n, n, n, 1.0, &a[0], n, &bdx[0], n, 0.0, &cdx[0], n);
00073 
00074   // Teuchos::BLAS<int,FadType> blas_fad;
00075   // blas_fad.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
00076 
00077   Teuchos::BLAS<int,FadType> sacado_fad_blas(false);
00078   sacado_fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
00079 
00080   // Print the results
00081   int p = 4;
00082   int w = p+7;
00083   std::cout.setf(std::ios::scientific);
00084   std::cout.precision(p);
00085 
00086   std::cout << "BLAS GEMV calculation:" << std::endl;
00087   std::cout << "a = " << std::endl;
00088   for (unsigned int i=0; i<n; i++) {
00089     for (unsigned int j=0; j<n; j++)
00090       std::cout << " " << std::setw(w) << a[i+j*n];
00091     std::cout << std::endl;
00092   }
00093   std::cout << "b = " << std::endl;
00094   for (unsigned int i=0; i<n; i++) {
00095     std::cout << " " << std::setw(w) << b[i];
00096   }
00097   std::cout << std::endl;
00098   std::cout << "bdot = " << std::endl;
00099   for (unsigned int i=0; i<n; i++) {
00100     for (unsigned int j=0; j<n; j++)
00101       std::cout << " " << std::setw(w) << bdx[i+j*n];
00102     std::cout << std::endl;
00103   }
00104   std::cout << "c = " << std::endl;
00105   for (unsigned int i=0; i<n; i++) {
00106     std::cout << " " << std::setw(w) << c[i];
00107   }
00108   std::cout << std::endl;
00109   std::cout << "cdot = " << std::endl;
00110   for (unsigned int i=0; i<n; i++) {
00111     for (unsigned int j=0; j<n; j++)
00112       std::cout << " " << std::setw(w) << cdx[i+j*n];
00113     std::cout << std::endl;
00114   }
00115   std::cout << std::endl << std::endl;
00116 
00117   std::cout << "FAD BLAS GEMV calculation:" << std::endl;
00118   std::cout << "A.val() (should = a) = " << std::endl;
00119   for (unsigned int i=0; i<n; i++) {
00120     for (unsigned int j=0; j<n; j++)
00121       std::cout << " " << std::setw(w) << A[i+j*n].val();
00122     std::cout << std::endl;
00123   }
00124   std::cout << "B.val() (should = b) = " << std::endl;
00125   for (unsigned int i=0; i<n; i++) {
00126     std::cout << " " << std::setw(w) << B[i].val();
00127   }
00128   std::cout << std::endl;
00129   std::cout << "B.dx() (should = bdot) = " << std::endl;
00130   double *Bdx = B.dx();
00131   for (unsigned int i=0; i<n; i++) {
00132     for (unsigned int j=0; j<n; j++)
00133       std::cout << " " << std::setw(w) << Bdx[i+j*n];
00134     std::cout << std::endl;
00135   }
00136   std::cout << "C.val() (should = c) = " << std::endl;
00137   for (unsigned int i=0; i<n; i++) {
00138     std::cout << " " << std::setw(w) << C[i].val();
00139   }
00140   std::cout << std::endl;
00141   std::cout << "C.dx() (should = cdot) = " << std::endl;
00142   double *Cdx = C.dx();
00143   for (unsigned int i=0; i<n; i++) {
00144     for (unsigned int j=0; j<n; j++)
00145       std::cout << " " << std::setw(w) << Cdx[i+j*n];
00146     std::cout << std::endl;
00147   }
00148 
00149   double tol = 1.0e-14;
00150   bool failed = false;
00151   for (unsigned int i=0; i<n; i++) {
00152     if (std::fabs(C[i].val() - c[i]) > tol)
00153       failed = true;
00154     for (unsigned int j=0; j<n; j++) {
00155       if (std::fabs(C[i].dx(j) - cdx[i+j*n]) > tol) 
00156   failed = true;
00157     }
00158   }
00159   if (!failed) {
00160     std::cout << "\nExample passed!" << std::endl;
00161     return 0;
00162   }
00163   else {
00164     std::cout <<"\nSomething is wrong, example failed!" << std::endl;
00165     return 1;
00166   }
00167 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
Generated on Wed Apr 13 10:19:40 2011 for Sacado Package Browser (Single Doxygen Collection) by  doxygen 1.6.3