|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
// $Id$
// $Source$
// @HEADER
// ***********************************************************************
//
//                           Sacado Package
//                 Copyright (2006) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
// (etphipp@sandia.gov).
//
// ***********************************************************************
// @HEADER

// vector_blas_example
//
//  usage:
//     vector_blas_example
//
//  output:
//     prints the results of differentiating a BLAS routine (GEMV) with
//     forward-mode AD using the Sacado::Fad::DVFad class, whose derivative
//     components are dynamically allocated and stored in a contiguous
//     array. The AD result is checked against a reference computed with
//     plain-double BLAS; the program exits with status 0 if they agree and
//     1 otherwise.
00042 00043 #include <iostream> 00044 #include <iomanip> 00045 00046 #include "Sacado.hpp" 00047 #include "Teuchos_BLAS.hpp" 00048 #include "Sacado_Fad_BLAS.hpp" 00049 00050 typedef Sacado::Fad::DVFad<double> FadType; 00051 00052 int main(int argc, char **argv) 00053 { 00054 const unsigned int n = 5; 00055 Sacado::Fad::Vector<unsigned int, FadType> A(n*n,0),B(n,n), C(n,n); 00056 for (unsigned int i=0; i<n; i++) { 00057 for (unsigned int j=0; j<n; j++) 00058 A[i+j*n] = FadType(Teuchos::ScalarTraits<double>::random()); 00059 B[i] = FadType(n, Teuchos::ScalarTraits<double>::random()); 00060 for (unsigned int j=0; j<n; j++) 00061 B[i].fastAccessDx(j) = Teuchos::ScalarTraits<double>::random(); 00062 C[i] = 0.0; 00063 } 00064 00065 double *a = A.vals(); 00066 double *b = B.vals(); 00067 double *bdx = B.dx(); 00068 std::vector<double> c(n), cdx(n*n); 00069 00070 Teuchos::BLAS<int,double> blas; 00071 blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &a[0], n, &b[0], 1, 0.0, &c[0], 1); 00072 blas.GEMM(Teuchos::NO_TRANS, Teuchos::NO_TRANS, n, n, n, 1.0, &a[0], n, &bdx[0], n, 0.0, &cdx[0], n); 00073 00074 // Teuchos::BLAS<int,FadType> blas_fad; 00075 // blas_fad.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1); 00076 00077 Teuchos::BLAS<int,FadType> sacado_fad_blas(false); 00078 sacado_fad_blas.GEMV(Teuchos::NO_TRANS, n, n, 1.0, &A[0], n, &B[0], 1, 0.0, &C[0], 1); 00079 00080 // Print the results 00081 int p = 4; 00082 int w = p+7; 00083 std::cout.setf(std::ios::scientific); 00084 std::cout.precision(p); 00085 00086 std::cout << "BLAS GEMV calculation:" << std::endl; 00087 std::cout << "a = " << std::endl; 00088 for (unsigned int i=0; i<n; i++) { 00089 for (unsigned int j=0; j<n; j++) 00090 std::cout << " " << std::setw(w) << a[i+j*n]; 00091 std::cout << std::endl; 00092 } 00093 std::cout << "b = " << std::endl; 00094 for (unsigned int i=0; i<n; i++) { 00095 std::cout << " " << std::setw(w) << b[i]; 00096 } 00097 std::cout << std::endl; 00098 std::cout << "bdot = " 
<< std::endl; 00099 for (unsigned int i=0; i<n; i++) { 00100 for (unsigned int j=0; j<n; j++) 00101 std::cout << " " << std::setw(w) << bdx[i+j*n]; 00102 std::cout << std::endl; 00103 } 00104 std::cout << "c = " << std::endl; 00105 for (unsigned int i=0; i<n; i++) { 00106 std::cout << " " << std::setw(w) << c[i]; 00107 } 00108 std::cout << std::endl; 00109 std::cout << "cdot = " << std::endl; 00110 for (unsigned int i=0; i<n; i++) { 00111 for (unsigned int j=0; j<n; j++) 00112 std::cout << " " << std::setw(w) << cdx[i+j*n]; 00113 std::cout << std::endl; 00114 } 00115 std::cout << std::endl << std::endl; 00116 00117 std::cout << "FAD BLAS GEMV calculation:" << std::endl; 00118 std::cout << "A.val() (should = a) = " << std::endl; 00119 for (unsigned int i=0; i<n; i++) { 00120 for (unsigned int j=0; j<n; j++) 00121 std::cout << " " << std::setw(w) << A[i+j*n].val(); 00122 std::cout << std::endl; 00123 } 00124 std::cout << "B.val() (should = b) = " << std::endl; 00125 for (unsigned int i=0; i<n; i++) { 00126 std::cout << " " << std::setw(w) << B[i].val(); 00127 } 00128 std::cout << std::endl; 00129 std::cout << "B.dx() (should = bdot) = " << std::endl; 00130 double *Bdx = B.dx(); 00131 for (unsigned int i=0; i<n; i++) { 00132 for (unsigned int j=0; j<n; j++) 00133 std::cout << " " << std::setw(w) << Bdx[i+j*n]; 00134 std::cout << std::endl; 00135 } 00136 std::cout << "C.val() (should = c) = " << std::endl; 00137 for (unsigned int i=0; i<n; i++) { 00138 std::cout << " " << std::setw(w) << C[i].val(); 00139 } 00140 std::cout << std::endl; 00141 std::cout << "C.dx() (should = cdot) = " << std::endl; 00142 double *Cdx = C.dx(); 00143 for (unsigned int i=0; i<n; i++) { 00144 for (unsigned int j=0; j<n; j++) 00145 std::cout << " " << std::setw(w) << Cdx[i+j*n]; 00146 std::cout << std::endl; 00147 } 00148 00149 double tol = 1.0e-14; 00150 bool failed = false; 00151 for (unsigned int i=0; i<n; i++) { 00152 if (std::fabs(C[i].val() - c[i]) > tol) 00153 failed = true; 00154 for 
(unsigned int j=0; j<n; j++) { 00155 if (std::fabs(C[i].dx(j) - cdx[i+j*n]) > tol) 00156 failed = true; 00157 } 00158 } 00159 if (!failed) { 00160 std::cout << "\nExample passed!" << std::endl; 00161 return 0; 00162 } 00163 else { 00164 std::cout <<"\nSomething is wrong, example failed!" << std::endl; 00165 return 1; 00166 } 00167 }
1.7.4