|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
00001 // $Id$ 00002 // $Source$ 00003 // @HEADER 00004 // *********************************************************************** 00005 // 00006 // Sacado Package 00007 // Copyright (2006) Sandia Corporation 00008 // 00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, 00010 // the U.S. Government retains certain rights in this software. 00011 // 00012 // This library is free software; you can redistribute it and/or modify 00013 // it under the terms of the GNU Lesser General Public License as 00014 // published by the Free Software Foundation; either version 2.1 of the 00015 // License, or (at your option) any later version. 00016 // 00017 // This library is distributed in the hope that it will be useful, but 00018 // WITHOUT ANY WARRANTY; without even the implied warranty of 00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00020 // Lesser General Public License for more details. 00021 // 00022 // You should have received a copy of the GNU Lesser General Public 00023 // License along with this library; if not, write to the Free Software 00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00025 // USA 00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps 00027 // (etphipp@sandia.gov). 00028 // 00029 // *********************************************************************** 00030 // @HEADER 00031 00032 // trad_dfad_example 00033 // 00034 // usage: 00035 // trad_dfad_example 00036 // 00037 // output: 00038 // prints the results of computing the second derivative of a simple function with nested forward and reverse mode AD using the 00039 // Sacado::Fad::DFad and Sacado::Rad::ADvar classes. 
00040 00041 #include <iostream> 00042 #include <iomanip> 00043 00044 #include "Sacado.hpp" 00045 00046 // The function to differentiate 00047 template <typename ScalarT> 00048 ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) { 00049 ScalarT r = c*std::log(b+1.)/std::sin(a); 00050 return r; 00051 } 00052 00053 // The analytic derivative of func(a,b,c) with respect to a and b 00054 void func_deriv(double a, double b, double c, double& drda, double& drdb) 00055 { 00056 drda = -(c*std::log(b+1.)/std::pow(std::sin(a),2))*std::cos(a); 00057 drdb = c / ((b+1.)*std::sin(a)); 00058 } 00059 00060 // The analytic second derivative of func(a,b,c) with respect to a and b 00061 void func_deriv2(double a, double b, double c, double& d2rda2, double& d2rdb2, 00062 double& d2rdadb) 00063 { 00064 d2rda2 = c*std::log(b+1.)/std::sin(a) + 2.*(c*std::log(b+1.)/std::pow(std::sin(a),3))*std::pow(std::cos(a),2); 00065 d2rdb2 = -c / (std::pow(b+1.,2)*std::sin(a)); 00066 d2rdadb = -c / ((b+1.)*std::pow(std::sin(a),2))*std::cos(a); 00067 } 00068 00069 int main(int argc, char **argv) 00070 { 00071 double pi = std::atan(1.0)*4.0; 00072 00073 // Values of function arguments 00074 double a = pi/4; 00075 double b = 2.0; 00076 double c = 3.0; 00077 00078 // Number of independent variables 00079 int num_deriv = 2; 00080 00081 // Fad objects 00082 Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > arad = 00083 Sacado::Fad::DFad<double>(num_deriv, 0, a); 00084 Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > brad = 00085 Sacado::Fad::DFad<double>(num_deriv, 1, b); 00086 Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > crad = c; 00087 Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > rrad; 00088 00089 // Compute function 00090 double r = func(a, b, c); 00091 00092 // Compute derivative analytically 00093 double drda, drdb; 00094 func_deriv(a, b, c, drda, drdb); 00095 00096 // Compute second derivative analytically 00097 double d2rda2, d2rdb2, d2rdadb; 00098 func_deriv2(a, b, c, d2rda2, 
d2rdb2, d2rdadb); 00099 00100 // Compute function and derivative with AD 00101 rrad = func(arad, brad, crad); 00102 00103 Sacado::Rad::ADvar< Sacado::Fad::DFad<double> >::Gradcomp(); 00104 00105 // Extract value and derivatives 00106 double r_ad = rrad.val().val(); // r 00107 double drda_ad = arad.adj().val(); // dr/da 00108 double drdb_ad = brad.adj().val(); // dr/db 00109 double d2rda2_ad = arad.adj().dx(0); // d^2r/da^2 00110 double d2rdadb_ad = arad.adj().dx(1); // d^2r/dadb 00111 double d2rdbda_ad = brad.adj().dx(0); // d^2r/dbda 00112 double d2rdb2_ad = brad.adj().dx(1); // d^2/db^2 00113 00114 // Print the results 00115 int p = 4; 00116 int w = p+7; 00117 std::cout.setf(std::ios::scientific); 00118 std::cout.precision(p); 00119 std::cout << " r = " << std::setw(w) << r << " (original) == " 00120 << std::setw(w) << r_ad << " (AD) Error = " << std::setw(w) 00121 << r - r_ad << std::endl 00122 << " dr/da = " << std::setw(w) << drda << " (analytic) == " 00123 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w) 00124 << drda - drda_ad << std::endl 00125 << " dr/db = " << std::setw(w) << drdb << " (analytic) == " 00126 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w) 00127 << drdb - drdb_ad << std::endl 00128 << "d^2r/da^2 = " << std::setw(w) << d2rda2 << " (analytic) == " 00129 << std::setw(w) << d2rda2_ad << " (AD) Error = " << std::setw(w) 00130 << d2rda2 - d2rda2_ad << std::endl 00131 << "d^2r/db^2 = " << std::setw(w) << d2rdb2 << " (analytic) == " 00132 << std::setw(w) << d2rdb2_ad << " (AD) Error = " << std::setw(w) 00133 << d2rdb2 - d2rdb2_ad << std::endl 00134 << "d^2r/dadb = " << std::setw(w) << d2rdadb << " (analytic) == " 00135 << std::setw(w) << d2rdadb_ad << " (AD) Error = " << std::setw(w) 00136 << d2rdadb - d2rdadb_ad << std::endl 00137 << "d^2r/dbda = " << std::setw(w) << d2rdadb << " (analytic) == " 00138 << std::setw(w) << d2rdbda_ad << " (AD) Error = " << std::setw(w) 00139 << d2rdadb - d2rdbda_ad << std::endl; 00140 
00141 double tol = 1.0e-14; 00142 if (std::fabs(r - r_ad) < tol && 00143 std::fabs(drda - drda_ad) < tol && 00144 std::fabs(drdb - drdb_ad) < tol && 00145 std::fabs(d2rda2 - d2rda2_ad) < tol && 00146 std::fabs(d2rdb2 - d2rdb2_ad) < tol && 00147 std::fabs(d2rdadb - d2rdadb_ad) < tol) { 00148 std::cout << "\nExample passed!" << std::endl; 00149 return 0; 00150 } 00151 else { 00152 std::cout <<"\nSomething is wrong, example failed!" << std::endl; 00153 return 1; 00154 } 00155 }
1.7.4