Sacado Package Browser (Single Doxygen Collection) Version of the Day
trad_sfc_example.cpp
Go to the documentation of this file.
00001 // $Id$ 
00002 // $Source$ 
00003 // @HEADER
00004 // ***********************************************************************
00005 // 
00006 //                           Sacado Package
00007 //                 Copyright (2006) Sandia Corporation
00008 // 
00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00010 // the U.S. Government retains certain rights in this software.
00011 // 
00012 // This library is free software; you can redistribute it and/or modify
00013 // it under the terms of the GNU Lesser General Public License as
00014 // published by the Free Software Foundation; either version 2.1 of the
00015 // License, or (at your option) any later version.
00016 //  
00017 // This library is distributed in the hope that it will be useful, but
00018 // WITHOUT ANY WARRANTY; without even the implied warranty of
00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00020 // Lesser General Public License for more details.
00021 //  
00022 // You should have received a copy of the GNU Lesser General Public
00023 // License along with this library; if not, write to the Free Software
00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00025 // USA
00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00027 // (etphipp@sandia.gov).
00028 // 
00029 // ***********************************************************************
00030 // @HEADER
00031 
00032 // trad_sfc_example
00033 //
00034 //  usage: 
00035 //     trad_sfc_example
00036 //
00037 //  output:  
00038 //     Uses the scalar flop counter to count the flops for a derivative
00039 //     of a simple function using Sacado::Rad::ADvar (reverse-mode AD)
00040 
00041 #include <iostream>
00042 #include <iomanip>
00043 
00044 #include "Sacado.hpp"
00045 
00046 // The function to differentiate
// The function to differentiate:
//
//   func(a, b, c) = c * log(b + 1) / sin(a)
//
// ScalarT may be a plain floating-point type or a flop-counting / AD type
// (main() calls it with SFC and RAD_SFC).  For the counting types, the counts
// printed by main() depend on exactly which operations appear here, so the
// expression is intentionally left in this form.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c)
{
  ScalarT value = c*std::log(b+1.)/std::sin(a);
  return value;
}
00053 
00054 // The analytic derivative of func(a,b,c) with respect to a and b
// The analytic derivative of func(a,b,c) = c*log(b+1)/sin(a) with respect
// to a and b:
//
//   dr/da = -c * log(b + 1) * cos(a) / sin(a)^2
//   dr/db =  c / ((b + 1) * sin(a))
//
// Results are written to drda and drdb.  main() runs this with the flop
// counter type as a hand-coded baseline to compare against AD.  The exact
// operations written here (std::pow for the square, sin(a) evaluated in both
// formulas) therefore determine the counts reported for this step.
// NOTE(review): reusing sin(a) or squaring by multiplication would lower the
// reported counts -- intentional change to the example's output if done.
template <typename ScalarT>
void func_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c, 
    ScalarT& drda, ScalarT& drdb)
{
  drda = -(c*std::log(b+1.)/std::pow(std::sin(a),2))*std::cos(a);
  drdb = c / ((b+1.)*std::sin(a));
}
00062 
// Scalar type wrapping double that counts the floating-point operations
// performed on it (counts are read via SFC::getCounters()/printCounters()).
typedef Sacado::FlopCounterPack::ScalarFlopCounter<double> SFC;
// Reverse-mode AD variable (Sacado::Rad) whose underlying scalar is SFC.
// Its values and adjoints are SFC (main() reads them with .val()), so the
// arithmetic the AD tool performs on them is counted as well.
typedef Sacado::Rad::ADvar<SFC> RAD_SFC;
00065 
00066 int main(int argc, char **argv)
00067 {
00068   double pi = std::atan(1.0)*4.0;
00069 
00070   // Values of function arguments
00071   double a = pi/4;
00072   double b = 2.0;
00073   double c = 3.0;
00074 
00075   // Compute function
00076   SFC as(a);
00077   SFC bs(b);
00078   SFC cs(c);
00079   SFC::resetCounters();
00080   SFC rs = func(as, bs, cs);
00081   SFC::finalizeCounters();
00082 
00083   std::cout << "Flop counts for function evaluation:";
00084   SFC::printCounters(std::cout);
00085 
00086   // Compute derivative analytically
00087   SFC drdas, drdbs;
00088   SFC::resetCounters();
00089   func_deriv(as, bs, cs, drdas, drdbs);
00090   SFC::finalizeCounters();
00091 
00092   std::cout << "\nFlop counts for analytic derivative evaluation:";
00093   SFC::printCounters(std::cout);
00094 
00095   // Compute function and derivative with AD
00096   RAD_SFC arad(a); 
00097   RAD_SFC brad(b); 
00098   RAD_SFC crad(c);               
00099   SFC::resetCounters();
00100   RAD_SFC rrad = func(arad, brad, crad);
00101   RAD_SFC::Gradcomp();
00102   SFC::finalizeCounters();
00103 
00104   std::cout << "\nFlop counts for AD function and derivative evaluation:";
00105   SFC::printCounters(std::cout);
00106 
00107   // Extract value and derivatives
00108   double r = rs.val();               // r
00109   double drda = drdas.val();         // dr/da
00110   double drdb = drdbs.val();         // dr/db
00111 
00112   double r_ad = rrad.val().val();     // r
00113   double drda_ad = arad.adj().val();  // dr/da
00114   double drdb_ad = brad.adj().val();  // dr/db
00115 
00116   // Print the results
00117   int p = 4;
00118   int w = p+7;
00119   std::cout.setf(std::ios::scientific);
00120   std::cout.precision(p);
00121   std::cout << "\nValues/derivatives of computation" << std::endl
00122       << "    r =  " << r << " (original) == " << std::setw(w) << r_ad
00123       << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl
00124       << "dr/da = " << std::setw(w) << drda << " (analytic) == " 
00125       << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w) 
00126       << drda - drda_ad << std::endl
00127       << "dr/db = " << std::setw(w) << drdb << " (analytic) == " 
00128       << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w) 
00129       << drdb - drdb_ad << std::endl;
00130 
00131   double tol = 1.0e-14;
00132   Sacado::FlopCounterPack::FlopCounts fc = SFC::getCounters();
00133   // The Solaris and Irix CC compilers get higher counts for operator+=
00134   // and operator* than does g++.
00135   // The test on fc.totalFlopCount allows for this variation.
00136   if (std::fabs(r - r_ad)       < tol &&
00137       std::fabs(drda - drda_ad) < tol &&
00138       std::fabs(drdb - drdb_ad) < tol&&
00139       (fc.totalFlopCount == 27 || fc.totalFlopCount == 29)) {
00140     std::cout << "\nExample passed!" << std::endl;
00141     return 0;
00142   }
00143   else {
00144     std::cout <<"\nSomething is wrong, example failed!" << std::endl;
00145     return 1;
00146   }
00147 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines