|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
// $Id$
// $Source$
// @HEADER
// ***********************************************************************
//
//                           Sacado Package
//                 Copyright (2006) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
// (etphipp@sandia.gov).
//
// ***********************************************************************
// @HEADER

// dmfad_example
//
//  usage:
//     dmfad_example
//
//  output:
//     prints the results of differentiating a simple function with
//     forward-mode AD using the Sacado::Fad::DMFad class, which allocates
//     storage for its derivative components dynamically through a custom
//     memory manager.
00042 00043 #include <iostream> 00044 #include <iomanip> 00045 00046 #include "Sacado.hpp" 00047 00048 template <> 00049 Sacado::Fad::MemPool* Sacado::Fad::MemPoolStorage<double>::defaultPool_ = NULL; 00050 00051 // The function to differentiate 00052 template <typename ScalarT> 00053 ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) { 00054 ScalarT r = c*std::log(b+1.)/std::sin(a); 00055 00056 return r; 00057 } 00058 00059 // The analytic derivative of func(a,b,c) with respect to a and b 00060 void func_deriv(double a, double b, double c, double& drda, double& drdb) 00061 { 00062 drda = -(c*std::log(b+1.)/std::pow(std::sin(a),2))*std::cos(a); 00063 drdb = c / ((b+1.)*std::sin(a)); 00064 } 00065 00066 int main(int argc, char **argv) 00067 { 00068 double pi = std::atan(1.0)*4.0; 00069 00070 // Values of function arguments 00071 double a = pi/4; 00072 double b = 2.0; 00073 double c = 3.0; 00074 00075 // Number of independent variables 00076 int num_deriv = 2; 00077 00078 // Memory pool & manager 00079 Sacado::Fad::MemPoolManager<double> poolManager(10); 00080 Sacado::Fad::MemPool* pool = poolManager.getMemoryPool(num_deriv); 00081 Sacado::Fad::DMFad<double>::setDefaultPool(pool); 00082 00083 // Fad objects 00084 Sacado::Fad::DMFad<double> afad(num_deriv, 0, a); // First (0) indep. var 00085 Sacado::Fad::DMFad<double> bfad(num_deriv, 1, b); // Second (1) indep. 
var 00086 Sacado::Fad::DMFad<double> cfad(c); // Passive variable 00087 Sacado::Fad::DMFad<double> rfad; // Result 00088 00089 // Compute function 00090 double r = func(a, b, c); 00091 00092 // Compute derivative analytically 00093 double drda, drdb; 00094 func_deriv(a, b, c, drda, drdb); 00095 00096 // Compute function and derivative with AD 00097 rfad = func(afad, bfad, cfad); 00098 00099 // Extract value and derivatives 00100 double r_ad = rfad.val(); // r 00101 double drda_ad = rfad.dx(0); // dr/da 00102 double drdb_ad = rfad.dx(1); // dr/db 00103 00104 // Print the results 00105 int p = 4; 00106 int w = p+7; 00107 std::cout.setf(std::ios::scientific); 00108 std::cout.precision(p); 00109 std::cout << " r = " << r << " (original) == " << std::setw(w) << r_ad 00110 << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl 00111 << "dr/da = " << std::setw(w) << drda << " (analytic) == " 00112 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w) 00113 << drda - drda_ad << std::endl 00114 << "dr/db = " << std::setw(w) << drdb << " (analytic) == " 00115 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w) 00116 << drdb - drdb_ad << std::endl; 00117 00118 double tol = 1.0e-14; 00119 if (std::fabs(r - r_ad) < tol && 00120 std::fabs(drda - drda_ad) < tol && 00121 std::fabs(drdb - drdb_ad) < tol) { 00122 std::cout << "\nExample passed!" << std::endl; 00123 return 0; 00124 } 00125 else { 00126 std::cout <<"\nSomething is wrong, example failed!" << std::endl; 00127 return 1; 00128 } 00129 }
1.7.4