fad_expr.cpp

Go to the documentation of this file.
00001 // $Id$ 
00002 // $Source$ 
00003 // @HEADER
00004 // ***********************************************************************
00005 // 
00006 //                           Sacado Package
00007 //                 Copyright (2006) Sandia Corporation
00008 // 
00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00010 // the U.S. Government retains certain rights in this software.
00011 // 
00012 // This library is free software; you can redistribute it and/or modify
00013 // it under the terms of the GNU Lesser General Public License as
00014 // published by the Free Software Foundation; either version 2.1 of the
00015 // License, or (at your option) any later version.
00016 //  
00017 // This library is distributed in the hope that it will be useful, but
00018 // WITHOUT ANY WARRANTY; without even the implied warranty of
00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00020 // Lesser General Public License for more details.
00021 //  
00022 // You should have received a copy of the GNU Lesser General Public
00023 // License along with this library; if not, write to the Free Software
00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00025 // USA
00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00027 // (etphipp@sandia.gov).
00028 // 
00029 // ***********************************************************************
00030 // @HEADER
00031 
00032 #include "Sacado_Random.hpp"
00033 #include "Sacado.hpp"
00034 #include "Sacado_CacheFad_DFad.hpp"
00035 
00036 #include "Fad/fad.h"
00037 #include "TinyFadET/tfad.h"
00038 
00039 #include "Teuchos_Time.hpp"
00040 #include "Teuchos_CommandLineProcessor.hpp"
00041 
00042 // A simple performance test that computes the derivative of a simple
00043 // expression using many variants of Fad.
00044 
00045 template <>
00046 Sacado::Fad::MemPool* Sacado::Fad::MemPoolStorage<double>::defaultPool_ = NULL;
00047 
00048 void FAD::error(const char *msg) {
00049   std::cout << msg << std::endl;
00050 }
00051 
/// Benchmark kernel: stores x1*x2 + sin(x1)/x2 into y.
///
/// Kept as a single assignment so expression-template AD types evaluate it
/// without named temporaries. `sin` is called unqualified so that the
/// overload matching T is selected.
///
/// @param x1  first independent variable
/// @param x2  second independent variable (divisor)
/// @param y   receives the result
template <typename T>
inline void
func1(const T& x1, const T& x2, T& y)
{
  y = x1 * x2 + sin(x1) / x2;
}
00057 
/// Hand-coded value and forward-mode derivative of func1:
///   y       = x1*x2 + sin(x1)/x2
///   ydot[i] = (x2 + cos(x1)/x2) * x1dot[i] + (x1 - sin(x1)/x2^2) * x2dot[i]
///
/// @param n      number of derivative components; x1dot, x2dot and ydot must
///               each provide at least n elements
/// @param x1     value of the first independent variable
/// @param x2     value of the second independent variable (divisor)
/// @param x1dot  derivative components of x1 (read only)
/// @param x2dot  derivative components of x2 (read only)
/// @param y      receives the function value
/// @param ydot   receives the n derivative components of y
inline void
func1_and_deriv(int n, double x1, double x2, const double* x1dot,
                const double* x2dot, double& y, double* ydot)
{
  const double s = sin(x1);
  const double c = cos(x1);
  const double t = s / x2;             // sin(x1)/x2, shared by y and dy/dx2
  const double dy_dx1 = x2 + c / x2;
  const double dy_dx2 = x1 - t / x2;
  y = x1 * x2 + t;
  // Iterate over exactly n components; a fixed bound here would overrun the
  // caller's arrays for n < bound and leave entries unset for n > bound.
  for (int i = 0; i < n; ++i)
    ydot[i] = dy_dx1 * x1dot[i] + dy_dx2 * x2dot[i];
}
00070 
00071 template <typename FadType>
00072 double
00073 do_time(int nderiv, int nloop)
00074 {
00075   FadType x1, x2, y;
00076   Sacado::Random<double> urand(0.0, 1.0);
00077 
00078   x1 = FadType(nderiv,  urand.number());
00079   x2 = FadType(nderiv,  urand.number());
00080   y = 0.0;
00081   for (int j=0; j<nderiv; j++) {
00082     x1.fastAccessDx(j) = urand.number();
00083     x2.fastAccessDx(j) = urand.number();
00084   }
00085   
00086   Teuchos::Time timer("mult", false);
00087   timer.start(true);
00088   for (int j=0; j<nloop; j++) {
00089     func1(x1, x2, y);
00090   }
00091   timer.stop();
00092 
00093   return timer.totalElapsedTime() / nloop;
00094 }
00095 
00096 double
00097 do_time_analytic(int nderiv, int nloop)
00098 {
00099   double x1, x2, y;
00100   double *x1dot, *x2dot, *ydot;
00101   Sacado::Random<double> urand(0.0, 1.0);
00102 
00103   x1 = urand.number();
00104   x2 = urand.number();
00105   y = 0.0;
00106   x1dot = new double[nderiv];
00107   x2dot = new double[nderiv];
00108   ydot = new double[nderiv];
00109   for (int j=0; j<nderiv; j++) {
00110     x1dot[j] = urand.number();
00111     x2dot[j] = urand.number();
00112   }
00113   
00114   Teuchos::Time timer("mult", false);
00115   timer.start(true);
00116   for (int j=0; j<nloop; j++) {
00117     func1_and_deriv(nderiv, x1, x2, x1dot, x2dot, y, ydot);
00118   }
00119   timer.stop();
00120 
00121   return timer.totalElapsedTime() / nloop;
00122 }
00123 
00124 int main(int argc, char* argv[]) {
00125   int ierr = 0;
00126 
00127   try {
00128     double t, ta;
00129     int p = 2;
00130     int w = p+7;
00131 
00132     // Set up command line options
00133     Teuchos::CommandLineProcessor clp;
00134     clp.setDocString("This program tests the speed of various forward mode AD implementations for a single multiplication operation");
00135     int nderiv = 10;
00136     clp.setOption("nderiv", &nderiv, "Number of derivative components");
00137     int nloop = 1000000;
00138     clp.setOption("nloop", &nloop, "Number of loops");
00139 
00140     // Parse options
00141     Teuchos::CommandLineProcessor::EParseCommandLineReturn
00142       parseReturn= clp.parse(argc, argv);
00143     if(parseReturn != Teuchos::CommandLineProcessor::PARSE_SUCCESSFUL)
00144       return 1;
00145 
00146     // Memory pool & manager
00147     Sacado::Fad::MemPoolManager<double> poolManager(10);
00148     Sacado::Fad::MemPool* pool = poolManager.getMemoryPool(nderiv);
00149     Sacado::Fad::DMFad<double>::setDefaultPool(pool);
00150 
00151     std::cout.setf(std::ios::scientific);
00152     std::cout.precision(p);
00153     std::cout << "Times (sec) for nderiv = " << nderiv 
00154         << " nloop =  " << nloop << ":  " << std::endl;
00155 
00156     ta = do_time_analytic(nderiv, nloop);
00157     std::cout << "Analytic:  " << std::setw(w) << ta << std::endl;
00158 
00159     t = do_time< FAD::TFad<10,double> >(nderiv, nloop);
00160     std::cout << "TFad:      " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00161 
00162     t = do_time< FAD::Fad<double> >(nderiv, nloop);
00163     std::cout << "Fad:       " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00164 
00165     t = do_time< Sacado::Fad::SFad<double,10> >(nderiv, nloop);
00166     std::cout << "SFad:      " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00167 
00168     t = do_time< Sacado::Fad::SLFad<double,10> >(nderiv, nloop);
00169     std::cout << "SLFad:     " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00170     
00171     t = do_time< Sacado::Fad::DFad<double> >(nderiv, nloop);
00172     std::cout << "DFad:      " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00173 
00174     t = do_time< Sacado::Fad::DMFad<double> >(nderiv, nloop);
00175     std::cout << "DMFad:     " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 
00176 
00177     t = do_time< Sacado::ELRFad::SFad<double,10> >(nderiv, nloop);
00178     std::cout << "ELRSFad:   " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00179 
00180     t = do_time< Sacado::ELRFad::SLFad<double,10> >(nderiv, nloop);
00181     std::cout << "ELRSLFad:  " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00182 
00183     t = do_time< Sacado::ELRFad::DFad<double> >(nderiv, nloop);
00184     std::cout << "ELRDFad:   " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00185     
00186     t = do_time< Sacado::CacheFad::DFad<double> >(nderiv, nloop);
00187     std::cout << "CacheFad:  " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00188 
00189     t = do_time< Sacado::Fad::DVFad<double> >(nderiv, nloop);
00190     std::cout << "DVFad:     " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00191     
00192   }
00193   catch (std::exception& e) {
00194     cout << e.what() << endl;
00195     ierr = 1;
00196   }
00197   catch (const char *s) {
00198     cout << s << endl;
00199     ierr = 1;
00200   }
00201   catch (...) {
00202     cout << "Caught unknown exception!" << endl;
00203     ierr = 1;
00204   }
00205 
00206   return ierr;
00207 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
Generated on Wed Apr 13 10:19:30 2011 for Sacado Package Browser (Single Doxygen Collection) by  doxygen 1.6.3