|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
// $Id$
// $Source$
// @HEADER
// ***********************************************************************
//
//                           Sacado Package
//                 Copyright (2006) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
// (etphipp@sandia.gov).
//
// ***********************************************************************
// @HEADER

// blas_example
//
//  usage:
//     blas_example
//
//  output:
//     prints the results of differentiating a BLAS routine with forward
//     mode AD using the Sacado::Fad::DFad class (uses dynamic memory
//     allocation for number of derivative components).
00041 00042 #include <iostream> 00043 #include <iomanip> 00044 00045 #include "Sacado.hpp" 00046 #include "Teuchos_BLAS.hpp" 00047 #include "Sacado_Fad_BLAS.hpp" 00048 00049 typedef Sacado::Fad::DFad<double> FadType; 00050 00051 int main(int argc, char **argv) 00052 { 00053 const unsigned int n = 5; 00054 std::vector<double> a(n*n), b(n), c(n); 00055 std::vector<FadType> A(n*n), B(n), C(n); 00056 for (unsigned int i=0; i<n; i++) { 00057 for (unsigned int j=0; j<n; j++) 00058 a[i+j*n] = Teuchos::ScalarTraits<double>::random(); 00059 b[i] = Teuchos::ScalarTraits<double>::random(); 00060 c[i] = 0.0; 00061 00062 for (unsigned int j=0; j<n; j++) 00063 A[i+j*n] = FadType(a[i+j*n]); 00064 B[i] = FadType(n, i, b[i]); 00065 C[i] = FadType(c[i]); 00066 } 00067 00068 Teuchos::BLAS<int,double> blas; 00069 blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &a[0], n, &b[0], 1, 0.0, &c[0], 1); 00070 00071 // Teuchos::BLAS<int,FadType> fad_blas; 00072 // fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1); 00073 00074 Teuchos::BLAS<int,FadType> sacado_fad_blas(false,false,3*n*n+2*n); 00075 sacado_fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1); 00076 00077 // Print the results 00078 int p = 4; 00079 int w = p+7; 00080 std::cout.setf(std::ios::scientific); 00081 std::cout.precision(p); 00082 00083 std::cout << "BLAS GEMV calculation:" << std::endl; 00084 std::cout << "a = " << std::endl; 00085 for (unsigned int i=0; i<n; i++) { 00086 for (unsigned int j=0; j<n; j++) 00087 std::cout << " " << std::setw(w) << a[i+j*n]; 00088 std::cout << std::endl; 00089 } 00090 std::cout << "b = " << std::endl; 00091 for (unsigned int i=0; i<n; i++) { 00092 std::cout << " " << std::setw(w) << b[i]; 00093 } 00094 std::cout << std::endl; 00095 std::cout << "c = " << std::endl; 00096 for (unsigned int i=0; i<n; i++) { 00097 std::cout << " " << std::setw(w) << c[i]; 00098 } 00099 std::cout << std::endl << std::endl; 00100 00101 std::cout << "FAD BLAS GEMV 
calculation:" << std::endl; 00102 std::cout << "A.val() (should = a) = " << std::endl; 00103 for (unsigned int i=0; i<n; i++) { 00104 for (unsigned int j=0; j<n; j++) 00105 std::cout << " " << std::setw(w) << A[i+j*n].val(); 00106 std::cout << std::endl; 00107 } 00108 std::cout << "B.val() (should = b) = " << std::endl; 00109 for (unsigned int i=0; i<n; i++) { 00110 std::cout << " " << std::setw(w) << B[i].val(); 00111 } 00112 std::cout << std::endl; 00113 std::cout << "C.val() (should = c) = " << std::endl; 00114 for (unsigned int i=0; i<n; i++) { 00115 std::cout << " " << std::setw(w) << C[i].val(); 00116 } 00117 std::cout << std::endl; 00118 std::cout << "C.dx() ( = dc/db, should = a) = " << std::endl; 00119 for (unsigned int i=0; i<n; i++) { 00120 for (unsigned int j=0; j<n; j++) 00121 std::cout << " " << std::setw(w) << C[i].dx(j); 00122 std::cout << std::endl; 00123 } 00124 00125 double tol = 1.0e-14; 00126 bool failed = false; 00127 for (unsigned int i=0; i<n; i++) { 00128 if (std::fabs(C[i].val() - c[i]) > tol) 00129 failed = true; 00130 for (unsigned int j=0; j<n; j++) { 00131 if (std::fabs(C[i].dx(j) - a[i+j*n]) > tol) 00132 failed = true; 00133 } 00134 } 00135 if (!failed) { 00136 std::cout << "\nExample passed!" << std::endl; 00137 return 0; 00138 } 00139 else { 00140 std::cout <<"\nSomething is wrong, example failed!" << std::endl; 00141 return 1; 00142 } 00143 }
1.7.4