blas_example.cpp

Go to the documentation of this file.
00001 // $Id$ 
00002 // $Source$ 
00003 // @HEADER
00004 // ***********************************************************************
00005 // 
00006 //                           Sacado Package
00007 //                 Copyright (2006) Sandia Corporation
00008 // 
00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00010 // the U.S. Government retains certain rights in this software.
00011 // 
00012 // This library is free software; you can redistribute it and/or modify
00013 // it under the terms of the GNU Lesser General Public License as
00014 // published by the Free Software Foundation; either version 2.1 of the
00015 // License, or (at your option) any later version.
00016 //  
00017 // This library is distributed in the hope that it will be useful, but
00018 // WITHOUT ANY WARRANTY; without even the implied warranty of
00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00020 // Lesser General Public License for more details.
00021 //  
00022 // You should have received a copy of the GNU Lesser General Public
00023 // License along with this library; if not, write to the Free Software
00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00025 // USA
00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00027 // (etphipp@sandia.gov).
00028 // 
00029 // ***********************************************************************
00030 // @HEADER
00031 
00032 // blas_example
00033 //
00034 //  usage: 
00035 //     blas_example
00036 //
00037 //  output:  
00038 //     prints the results of differentiating a BLAS routine with forward
00039 //     mode AD using the Sacado::Fad::DFad class (uses dynamic memory
00040 //     allocation for number of derivative components).
00041 
00042 #include <iostream>
00043 #include <iomanip>
00044 
00045 #include "Sacado.hpp"
00046 #include "Teuchos_BLAS.hpp"
00047 #include "Sacado_Fad_BLAS.hpp"
00048 
00049 typedef Sacado::Fad::DFad<double> FadType;  // Forward-mode AD scalar whose number of derivative components is chosen at run time (dynamic allocation).
00050 
00051 // Driver for blas_example: differentiates the BLAS matrix-vector product
00052 // c = a*b (GEMV) with respect to b using forward-mode AD, prints the
00053 // inputs and results, and checks the Fad values against a plain
00054 // double-precision GEMV and the derivatives dC[i]/dB[j] against a(i,j).
00055 // Command-line arguments are ignored. Returns 0 if every check passes and
00056 // 1 otherwise.
00057 int main(int /* argc */, char ** /* argv */)
00058 {
00059   const unsigned int n = 5;
00060 
00061   // Matrices are stored column-major: entry (i,j) lives at index i + j*n.
00062   std::vector<double> a(n*n), b(n), c(n);
00063   std::vector<FadType> A(n*n), B(n), C(n);
00064   for (unsigned int i=0; i<n; i++) {
00065     for (unsigned int j=0; j<n; j++)
00066       a[i+j*n] = Teuchos::ScalarTraits<double>::random();
00067     b[i] = Teuchos::ScalarTraits<double>::random();
00068     c[i] = 0.0;
00069 
00070     // A and C are built from values only; B[i] is seeded as derivative
00071     // component i of n, so that C[i].dx(j) ends up holding dC[i]/dB[j].
00072     for (unsigned int j=0; j<n; j++)
00073       A[i+j*n] = FadType(a[i+j*n]);
00074     B[i] = FadType(n, i, b[i]);
00075     C[i] = FadType(c[i]);
00076   }
00077 
00078   // Reference result in plain double precision: c = 1.0*a*b + 0.0*c.
00079   Teuchos::BLAS<int,double> blas;
00080   blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &a[0], n, &b[0], 1, 0.0, &c[0], 1);
00081 
00082   // Teuchos::BLAS<int,FadType> fad_blas;
00083   // fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
00084 
00085   // Same product on Fad scalars: C = 1.0*A*B + 0.0*C.
00086   // NOTE(review): confirm against Sacado_Fad_BLAS.hpp what the two boolean
00087   // flags mean and why the size argument is 3*n*n+2*n (presumably a
00088   // workspace for the Fad specialization).
00089   Teuchos::BLAS<int,FadType> sacado_fad_blas(false,false,3*n*n+2*n);
00090   sacado_fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1);
00091 
00092   // Print the results: p digits after the point in scientific notation,
00093   // each entry right-aligned in a field of width w.
00094   const int p = 4;
00095   const int w = p+7;
00096   std::cout.setf(std::ios::scientific);
00097   std::cout.precision(p);
00098 
00099   std::cout << "BLAS GEMV calculation:" << std::endl;
00100   std::cout << "a = " << std::endl;
00101   for (unsigned int i=0; i<n; i++) {
00102     for (unsigned int j=0; j<n; j++)
00103       std::cout << " " << std::setw(w) << a[i+j*n];
00104     std::cout << std::endl;
00105   }
00106   std::cout << "b = " << std::endl;
00107   for (unsigned int i=0; i<n; i++) {
00108     std::cout << " " << std::setw(w) << b[i];
00109   }
00110   std::cout << std::endl;
00111   std::cout << "c = " << std::endl;
00112   for (unsigned int i=0; i<n; i++) {
00113     std::cout << " " << std::setw(w) << c[i];
00114   }
00115   std::cout << std::endl << std::endl;
00116 
00117   std::cout << "FAD BLAS GEMV calculation:" << std::endl;
00118   std::cout << "A.val() (should = a) = " << std::endl;
00119   for (unsigned int i=0; i<n; i++) {
00120     for (unsigned int j=0; j<n; j++)
00121       std::cout << " " << std::setw(w) << A[i+j*n].val();
00122     std::cout << std::endl;
00123   }
00124   std::cout << "B.val() (should = b) = " << std::endl;
00125   for (unsigned int i=0; i<n; i++) {
00126     std::cout << " " << std::setw(w) << B[i].val();
00127   }
00128   std::cout << std::endl;
00129   std::cout << "C.val() (should = c) = " << std::endl;
00130   for (unsigned int i=0; i<n; i++) {
00131     std::cout << " " << std::setw(w) << C[i].val();
00132   }
00133   std::cout << std::endl;
00134   std::cout << "C.dx() ( = dc/db, should = a) = " << std::endl;
00135   for (unsigned int i=0; i<n; i++) {
00136     for (unsigned int j=0; j<n; j++)
00137       std::cout << " " << std::setw(w) << C[i].dx(j);
00138     std::cout << std::endl;
00139   }
00140 
00141   // Verify the Fad results. The tolerance scales with the magnitude of the
00142   // reference value (never dropping below tol itself), so last-bit rounding
00143   // differences that may arise between the two GEMV code paths are not
00144   // reported as failures when an entry is larger than 1 in magnitude.
00145   const double tol = 1.0e-14;
00146   bool failed = false;
00147   for (unsigned int i=0; i<n; i++) {
00148     if (std::fabs(C[i].val() - c[i]) > tol*(1.0 + std::fabs(c[i])))
00149       failed = true;
00150     for (unsigned int j=0; j<n; j++) {
00151       if (std::fabs(C[i].dx(j) - a[i+j*n]) > tol*(1.0 + std::fabs(a[i+j*n])))
00152         failed = true;
00153     }
00154   }
00155   if (!failed) {
00156     std::cout << "\nExample passed!" << std::endl;
00157     return 0;
00158   }
00159   else {
00160     std::cout <<"\nSomething is wrong, example failed!" << std::endl;
00161     return 1;
00162   }
00163 }

Generated on Wed May 12 21:39:32 2010 for Sacado Package Browser (Single Doxygen Collection) by  doxygen 1.4.7