00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
#include <cmath>
#include <iomanip>
#include <iostream>

#include "Sacado.hpp"
00045
00046
// Model function r(a, b, c) = c * log(b + 1) / sin(a).
//
// Templated on the scalar type so the same source can be evaluated with plain
// doubles, flop-counting scalars, or AD scalars.  The expression is written in
// exactly this operator order/grouping because main() reports the flop counts
// that evaluating it produces.
template <typename ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c) {
  ScalarT value =
    c * std::log(b + 1.) / std::sin(a);

  return value;
}
00053
00054
// Hand-coded partial derivatives of func(a, b, c) = c * log(b + 1) / sin(a):
//   drda = -(c * log(b + 1) / sin(a)^2) * cos(a)
//   drdb =   c / ((b + 1) * sin(a))
// Results are written through the output references drda and drdb; no
// derivative with respect to c is computed.  The operations (including the
// std::pow call) are kept as-is so the "analytic derivative" flop counts
// printed by main() are unchanged.
template <typename ScalarT>
void func_deriv(const ScalarT& a, const ScalarT& b, const ScalarT& c,
                ScalarT& drda, ScalarT& drdb)
{
  // Quotient rule on 1/sin(a): d/da [1/sin(a)] = -cos(a) / sin(a)^2.
  drda = -(c * std::log(b + 1.) / std::pow(std::sin(a), 2)) * std::cos(a);

  // Chain rule on log(b + 1): d/db [log(b + 1)] = 1 / (b + 1).
  drdb = c / ((b + 1.) * std::sin(a));
}
00062
// Scalar type wrapping a double whose operations are tallied by the
// flop-counter pack (main() uses SFC::resetCounters/finalizeCounters/
// printCounters to report them).
typedef Sacado::FlopCounterPack::ScalarFlopCounter<double> SFC;
// Forward-mode AD type built on the counting scalar, so the work done on both
// values and derivative components is counted.  The number of derivative
// components is given at construction time (see num_deriv in main()).
typedef Sacado::Fad::DFad<SFC> FAD_SFC;
00065
00066 int main(int argc, char **argv)
00067 {
00068 double pi = std::atan(1.0)*4.0;
00069
00070
00071 double a = pi/4;
00072 double b = 2.0;
00073 double c = 3.0;
00074
00075
00076 int num_deriv = 2;
00077
00078
00079 SFC as(a);
00080 SFC bs(b);
00081 SFC cs(c);
00082 SFC::resetCounters();
00083 SFC rs = func(as, bs, cs);
00084 SFC::finalizeCounters();
00085
00086 std::cout << "Flop counts for function evaluation:";
00087 SFC::printCounters(std::cout);
00088
00089
00090 SFC drdas, drdbs;
00091 SFC::resetCounters();
00092 func_deriv(as, bs, cs, drdas, drdbs);
00093 SFC::finalizeCounters();
00094
00095 std::cout << "\nFlop counts for analytic derivative evaluation:";
00096 SFC::printCounters(std::cout);
00097
00098
00099 FAD_SFC afad(num_deriv, 0, a);
00100 FAD_SFC bfad(num_deriv, 1, b);
00101 FAD_SFC cfad(c);
00102 SFC::resetCounters();
00103 FAD_SFC rfad = func(afad, bfad, cfad);
00104 SFC::finalizeCounters();
00105
00106 std::cout << "\nFlop counts for AD function and derivative evaluation:";
00107 SFC::printCounters(std::cout);
00108
00109
00110 double r = rs.val();
00111 double drda = drdas.val();
00112 double drdb = drdbs.val();
00113
00114 double r_ad = rfad.val().val();
00115 double drda_ad = rfad.dx(0).val();
00116 double drdb_ad = rfad.dx(1).val();
00117
00118
00119 int p = 4;
00120 int w = p+7;
00121 std::cout.setf(std::ios::scientific);
00122 std::cout.precision(p);
00123 std::cout << "\nValues/derivatives of computation" << std::endl
00124 << " r = " << r << " (original) == " << std::setw(w) << r_ad
00125 << " (AD) Error = " << std::setw(w) << r - r_ad << std::endl
00126 << "dr/da = " << std::setw(w) << drda << " (analytic) == "
00127 << std::setw(w) << drda_ad << " (AD) Error = " << std::setw(w)
00128 << drda - drda_ad << std::endl
00129 << "dr/db = " << std::setw(w) << drdb << " (analytic) == "
00130 << std::setw(w) << drdb_ad << " (AD) Error = " << std::setw(w)
00131 << drdb - drdb_ad << std::endl;
00132
00133 double tol = 1.0e-14;
00134 Sacado::FlopCounterPack::FlopCounts fc = SFC::getCounters();
00135
00136
00137
00138 if (std::fabs(r - r_ad) < tol &&
00139 std::fabs(drda - drda_ad) < tol &&
00140 std::fabs(drdb - drdb_ad) < tol &&
00141 (fc.totalFlopCount == 48 || fc.totalFlopCount == 51)) {
00142 std::cout << "\nExample passed!" << std::endl;
00143 return 0;
00144 }
00145 else {
00146 std::cout <<"\nSomething is wrong, example failed!" << std::endl;
00147 return 1;
00148 }
00149 }