|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
00001 // $Id$ 00002 // $Source$ 00003 // @HEADER 00004 // *********************************************************************** 00005 // 00006 // Sacado Package 00007 // Copyright (2006) Sandia Corporation 00008 // 00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, 00010 // the U.S. Government retains certain rights in this software. 00011 // 00012 // This library is free software; you can redistribute it and/or modify 00013 // it under the terms of the GNU Lesser General Public License as 00014 // published by the Free Software Foundation; either version 2.1 of the 00015 // License, or (at your option) any later version. 00016 // 00017 // This library is distributed in the hope that it will be useful, but 00018 // WITHOUT ANY WARRANTY; without even the implied warranty of 00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00020 // Lesser General Public License for more details. 00021 // 00022 // You should have received a copy of the GNU Lesser General Public 00023 // License along with this library; if not, write to the Free Software 00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00025 // USA 00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps 00027 // (etphipp@sandia.gov). 00028 // 00029 // *********************************************************************** 00030 // @HEADER 00031 00032 // dfad_sfc_example 00033 // 00034 // usage: 00035 // dfad_sfc_example 00036 // 00037 // output: 00038 // Uses the scalar flop counter to count the flops for a derivative 00039 // of a simple function using Sacado::Rad::ADvar 00040 00041 #include <iostream> 00042 #include <iomanip> 00043 00044 #include "Sacado.hpp" 00045 00046 // The function to differentiate 00047 template <typename ScalarT> 00048 ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) { 00049 ScalarT r = c*std::log(b+1.)/std::sin(a); 00050 00051 return r; 00052 } 00053 00054 // The analytic derivative of func(a,b,c) with respect to a and b 00055 template <typename ScalarT> 00056 void func_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c, 00057 ScalarT& drda, ScalarT& drdb) 00058 { 00059 drda = -(c*std::log(b+1.)/std::pow(std::sin(a),2))*std::cos(a); 00060 drdb = c / ((b+1.)*std::sin(a)); 00061 } 00062 00063 typedef Sacado::FlopCounterPack::ScalarFlopCounter<double> SFC; 00064 typedef Sacado::Rad::ADvar<SFC> RAD_SFC; 00065 00066 int main(int argc, char **argv) 00067 { 00068 double pi = std::atan(1.0)*4.0; 00069 00070 // Values of function arguments 00071 double a = pi/4; 00072 double b = 2.0; 00073 double c = 3.0; 00074 00075 // Compute function 00076 SFC as(a); 00077 SFC bs(b); 00078 SFC cs(c); 00079 SFC::resetCounters(); 00080 SFC rs = func(as, bs, cs); 00081 SFC::finalizeCounters(); 00082 00083 std::cout << "Flop counts for function evaluation:"; 00084 SFC::printCounters(std::cout); 00085 00086 // Compute derivative analytically 00087 SFC drdas, drdbs; 00088 SFC::resetCounters(); 00089 func_deriv(as, bs, cs, drdas, drdbs); 00090 SFC::finalizeCounters(); 00091 00092 std::cout << "\nFlop counts for analytic derivative evaluation:"; 00093 SFC::printCounters(std::cout); 00094 00095 // Compute function and derivative with AD 00096 RAD_SFC arad(a); 00097 RAD_SFC brad(b); 00098 RAD_SFC crad(c); 00099 SFC::resetCounters(); 00100 RAD_SFC rrad = func(arad, brad, crad); 00101 RAD_SFC::Gradcomp(); 00102 SFC::finalizeCounters(); 00103 00104 std::cout << "\nFlop counts for AD function and derivative evaluation:"; 00105 SFC::printCounters(std::cout); 00106 00107 // Extract value and derivatives 00108 double r = rs.val(); // r 00109 double drda = drdas.val(); // dr/da 00110 double drdb = drdbs.val(); // dr/db 00111 00112 double r_ad = rrad.val().val(); // r 00113 double drda_ad = arad.adj().val(); // dr/da 00114 double drdb_ad = brad.adj().val(); // dr/db 00115 00116 // Print the results 00117 int p = 4; 00118 int w = p+7; 00119 std::cout.setf(std::ios::scientific); 00120 std::cout.precision(p); 00121 std::cout << "\nValues/derivatives of computation" << std::endl 00122 << " r = " << r << " (original) == " << std::setw(w) << r_ad 00123 << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl 00124 << "dr/da = " << std::setw(w) << drda << " (analytic) == " 00125 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w) 00126 << drda - drda_ad << std::endl 00127 << "dr/db = " << std::setw(w) << drdb << " (analytic) == " 00128 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w) 00129 << drdb - drdb_ad << std::endl; 00130 00131 double tol = 1.0e-14; 00132 Sacado::FlopCounterPack::FlopCounts fc = SFC::getCounters(); 00133 // The Solaris and Irix CC compilers get higher counts for operator+= 00134 // and operator* than does g++. 00135 // The test on fc.totalFlopCount allows for this variation. 00136 if (std::fabs(r - r_ad) < tol && 00137 std::fabs(drda - drda_ad) < tol && 00138 std::fabs(drdb - drdb_ad) < tol&& 00139 (fc.totalFlopCount == 27 || fc.totalFlopCount == 29)) { 00140 std::cout << "\nExample passed!" << std::endl; 00141 return 0; 00142 } 00143 else { 00144 std::cout <<"\nSomething is wrong, example failed!" << std::endl; 00145 return 1; 00146 } 00147 }
1.7.4