trad_dfad_example.cpp

Go to the documentation of this file.
00001 // $Id$ 
00002 // $Source$ 
00003 // @HEADER
00004 // ***********************************************************************
00005 // 
00006 //                           Sacado Package
00007 //                 Copyright (2006) Sandia Corporation
00008 // 
00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00010 // the U.S. Government retains certain rights in this software.
00011 // 
00012 // This library is free software; you can redistribute it and/or modify
00013 // it under the terms of the GNU Lesser General Public License as
00014 // published by the Free Software Foundation; either version 2.1 of the
00015 // License, or (at your option) any later version.
00016 //  
00017 // This library is distributed in the hope that it will be useful, but
00018 // WITHOUT ANY WARRANTY; without even the implied warranty of
00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00020 // Lesser General Public License for more details.
00021 //  
00022 // You should have received a copy of the GNU Lesser General Public
00023 // License along with this library; if not, write to the Free Software
00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00025 // USA
00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00027 // (etphipp@sandia.gov).
00028 // 
00029 // ***********************************************************************
00030 // @HEADER
00031 
00032 // trad_dfad_example
00033 //
00034 //  usage: 
00035 //     trad_dfad_example
00036 //
00037 //  output:  
00038 //     prints the results of computing the second derivatives of a simple function
00039 //     with reverse mode AD applied to forward mode AD values, using the Sacado::Rad::ADvar and Sacado::Fad::DFad classes.
00040 
00041 #include <iostream>
00042 #include <iomanip>
00043 
00044 #include "Sacado.hpp"
00045 
// Function being differentiated: r(a, b, c) = c * log(b + 1) / sin(a).
// Written generically so the same expression can be evaluated with plain
// doubles and with AD types that provide arithmetic with double and
// std::log / std::sin overloads.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c)
{
  return c*std::log(b+1.)/std::sin(a);
}
00052 
// Hand-derived first partial derivatives of r = c*log(b+1)/sin(a), used as
// the reference values for the AD results:
//   drda = -c*log(b+1)*cos(a) / sin(a)^2
//   drdb =  c / ((b+1)*sin(a))
// Results are written through the two output references (drda first).
void func_deriv(double a, double b, double c, double& drda, double& drdb)
{
  const double sin_a = std::sin(a);
  const double b_plus_1 = b+1.;

  drda = -(c*std::log(b_plus_1)/std::pow(sin_a,2))*std::cos(a);
  drdb = c / (b_plus_1*sin_a);
}
00059 
// Hand-derived second partial derivatives of r = c*log(b+1)/sin(a), used as
// the reference values for the AD results:
//   d2rda2  = c*log(b+1)/sin(a) + 2*c*log(b+1)*cos(a)^2 / sin(a)^3
//   d2rdb2  = -c / ((b+1)^2 * sin(a))
//   d2rdadb = -c*cos(a) / ((b+1)*sin(a)^2)   (mixed partial)
// Results are written through the three output references in that order.
void func_deriv2(double a, double b, double c, double& d2rda2, double& d2rdb2,
                 double& d2rdadb)
{
  const double sin_a = std::sin(a);
  const double cos_a = std::cos(a);
  const double b_plus_1 = b+1.;
  const double c_log = c*std::log(b_plus_1);

  d2rda2 = c_log/sin_a + 2.*(c_log/std::pow(sin_a,3))*std::pow(cos_a,2);
  d2rdb2 = -c / (std::pow(b_plus_1,2)*sin_a);
  d2rdadb = -c / (b_plus_1*std::pow(sin_a,2))*cos_a;
}
00068 
00069 int main(int argc, char **argv)
00070 {
00071   double pi = std::atan(1.0)*4.0;
00072 
00073   // Values of function arguments
00074   double a = pi/4;
00075   double b = 2.0;
00076   double c = 3.0;
00077 
00078   // Number of independent variables
00079   int num_deriv = 2;
00080 
00081   // Fad objects
00082   Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > arad = 
00083     Sacado::Fad::DFad<double>(num_deriv, 0, a);
00084   Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > brad = 
00085     Sacado::Fad::DFad<double>(num_deriv, 1, b);
00086   Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > crad = c;
00087   Sacado::Rad::ADvar< Sacado::Fad::DFad<double> > rrad;
00088 
00089   // Compute function
00090   double r = func(a, b, c);
00091 
00092   // Compute derivative analytically
00093   double drda, drdb;
00094   func_deriv(a, b, c, drda, drdb);
00095 
00096   // Compute second derivative analytically
00097   double d2rda2, d2rdb2, d2rdadb;
00098   func_deriv2(a, b, c, d2rda2, d2rdb2, d2rdadb);
00099 
00100   // Compute function and derivative with AD
00101   rrad = func(arad, brad, crad);
00102 
00103   Sacado::Rad::ADvar< Sacado::Fad::DFad<double> >::Gradcomp();
00104 
00105   // Extract value and derivatives
00106   double r_ad = rrad.val().val();       // r
00107   double drda_ad = arad.adj().val();    // dr/da
00108   double drdb_ad = brad.adj().val();    // dr/db
00109   double d2rda2_ad = arad.adj().dx(0);  // d^2r/da^2
00110   double d2rdadb_ad = arad.adj().dx(1); // d^2r/dadb
00111   double d2rdbda_ad = brad.adj().dx(0); // d^2r/dbda
00112   double d2rdb2_ad = brad.adj().dx(1);  // d^2/db^2
00113 
00114   // Print the results
00115   int p = 4;
00116   int w = p+7;
00117   std::cout.setf(std::ios::scientific);
00118   std::cout.precision(p);
00119   std::cout << "        r = " << std::setw(w) << r << " (original) == " 
00120       << std::setw(w) << r_ad << " (AD) Error = " << std::setw(w) 
00121       << r - r_ad << std::endl
00122       << "    dr/da = " << std::setw(w) << drda << " (analytic) == " 
00123       << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w) 
00124       << drda - drda_ad << std::endl
00125       << "    dr/db = " << std::setw(w) << drdb << " (analytic) == " 
00126       << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w) 
00127       << drdb - drdb_ad << std::endl
00128       << "d^2r/da^2 = " << std::setw(w) << d2rda2 << " (analytic) == " 
00129       << std::setw(w) << d2rda2_ad << " (AD) Error = " << std::setw(w) 
00130       << d2rda2 - d2rda2_ad << std::endl
00131       << "d^2r/db^2 = " << std::setw(w) << d2rdb2 << " (analytic) == " 
00132       << std::setw(w) << d2rdb2_ad << " (AD) Error = " << std::setw(w) 
00133       << d2rdb2 - d2rdb2_ad << std::endl
00134       << "d^2r/dadb = " << std::setw(w) << d2rdadb << " (analytic) == " 
00135       << std::setw(w) << d2rdadb_ad << " (AD) Error = " << std::setw(w) 
00136       << d2rdadb - d2rdadb_ad << std::endl
00137       << "d^2r/dbda = " << std::setw(w) << d2rdadb << " (analytic) == " 
00138       << std::setw(w) << d2rdbda_ad << " (AD) Error = " << std::setw(w) 
00139       << d2rdadb - d2rdbda_ad << std::endl;
00140 
00141   double tol = 1.0e-14;
00142   if (std::fabs(r - r_ad)             < tol &&
00143       std::fabs(drda - drda_ad)       < tol &&
00144       std::fabs(drdb - drdb_ad)       < tol &&
00145       std::fabs(d2rda2 - d2rda2_ad)   < tol &&
00146       std::fabs(d2rdb2 - d2rdb2_ad)   < tol &&
00147       std::fabs(d2rdadb - d2rdadb_ad) < tol) {
00148     std::cout << "\nExample passed!" << std::endl;
00149     return 0;
00150   }
00151   else {
00152     std::cout <<"\nSomething is wrong, example failed!" << std::endl;
00153     return 1;
00154   }
00155 }

Generated on Wed May 12 21:39:39 2010 for Sacado Package Browser (Single Doxygen Collection) by  doxygen 1.4.7