Sacado Package Browser (Single Doxygen Collection) Version of the Day
fad_lj_grad.cpp
Go to the documentation of this file.
00001 // $Id$ 
00002 // $Source$ 
00003 // @HEADER
00004 // ***********************************************************************
00005 // 
00006 //                           Sacado Package
00007 //                 Copyright (2006) Sandia Corporation
00008 // 
00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00010 // the U.S. Government retains certain rights in this software.
00011 // 
00012 // This library is free software; you can redistribute it and/or modify
00013 // it under the terms of the GNU Lesser General Public License as
00014 // published by the Free Software Foundation; either version 2.1 of the
00015 // License, or (at your option) any later version.
00016 //  
00017 // This library is distributed in the hope that it will be useful, but
00018 // WITHOUT ANY WARRANTY; without even the implied warranty of
00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00020 // Lesser General Public License for more details.
00021 //  
00022 // You should have received a copy of the GNU Lesser General Public
00023 // License along with this library; if not, write to the Free Software
00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00025 // USA
00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00027 // (etphipp@sandia.gov).
00028 // 
00029 // ***********************************************************************
00030 // @HEADER
00031 
#include <exception>
#include <iomanip>
#include <iostream>

#include "Sacado_Random.hpp"
#include "Sacado.hpp"
#include "Sacado_CacheFad_DFad.hpp"

#include "Fad/fad.h"
#include "TinyFadET/tfad.h"

#include "Teuchos_Time.hpp"
#include "Teuchos_CommandLineProcessor.hpp"
00041 
// A performance test that times the gradient of a pair energy (see lj())
// with respect to xi for many variants of Fad, and reports each time
// relative to a hand-coded analytic gradient.
00044 
00045 template <>
00046 Sacado::Fad::MemPool* Sacado::Fad::MemPoolStorage<double>::defaultPool_ = NULL;
00047 
00048 void FAD::error(const char *msg) {
00049   std::cout << msg << std::endl;
00050 }
00051 
// File-local state shared by the timing routines below.
namespace {
  double xi[3];    // first point; the variables differentiated against
  double xj[3];    // second point; held fixed
  double pa[4];    // energy coefficients (only pa[1..3] are read by lj)
  double f[3];     // accumulator for the negated gradient across iterations
  double delr[3];  // scratch: xi - xj as written by lj_and_grad
}
00055 
// Squared Euclidean distance between the 3-vectors xi and xj.
//
// xi holds values of the (possibly AD) type T, xj holds plain doubles; the
// result is returned as a T.
template <typename T>
inline T
vec3_distsq(const T xi[], const double xj[]) {
  const T dx = xi[0] - xj[0];
  const T dy = xi[1] - xj[1];
  const T dz = xi[2] - xj[2];
  return dx*dx + dy*dy + dz*dz;
}
00064 
// Squared Euclidean distance between the 3-vectors xi and xj that also
// hands back the difference vector.
//
// On return delr[k] == xi[k] - xj[k] for k = 0..2, and the returned value is
// the sum of the squares of those three components.
template <typename T>
inline T
vec3_distsq(const T xi[], const double xj[], T delr[]) {
  for (int k = 0; k < 3; ++k)
    delr[k] = xi[k] - xj[k];
  return delr[0]*delr[0] + delr[1]*delr[1] + delr[2]*delr[2];
}
00073 
00074 template <typename T>
00075 inline void
00076 lj(const T xi[], const double xj[], T& energy) {
00077   T delr2 = vec3_distsq(xi,xj);
00078   T delr_2 = 1.0/delr2;
00079   T delr_6 = delr_2*delr_2*delr_2;
00080   energy = (pa[1]*delr_6 - pa[2])*delr_6 - pa[3];
00081 }
00082 
00083 inline void
00084 lj_and_grad(const double xi[], const double xj[], double& energy,
00085       double f[]) {
00086   double delr2 = vec3_distsq(xi,xj,delr);
00087   double delr_2 = 1.0/delr2;
00088   double delr_6 = delr_2*delr_2*delr_2;
00089   energy = (pa[1]*delr_6 - pa[2])*delr_6 - pa[3];
00090   double tmp = (-12.0*pa[1]*delr_6 - 6.0*pa[2])*delr_6*delr_2;
00091   f[0] = delr[0]*tmp;
00092   f[1] = delr[1]*tmp;
00093   f[2] = delr[2]*tmp;
00094 }
00095 
00096 template <typename FadType>
00097 double
00098 do_time(int nloop)
00099 {
00100   Teuchos::Time timer("lj", false);
00101   FadType xi_fad[3], energy;
00102 
00103   for (int i=0; i<3; i++) {
00104     xi_fad[i] = FadType(3, i, xi[i]);
00105   }
00106   
00107   timer.start(true);
00108   for (int j=0; j<nloop; j++) {
00109 
00110     lj(xi_fad, xj, energy);
00111 
00112     for (int i=0; i<3; i++)
00113       f[i] += -energy.fastAccessDx(i);
00114   }
00115   timer.stop();
00116 
00117   return timer.totalElapsedTime() / nloop;
00118 }
00119 
00120 double
00121 do_time_analytic(int nloop)
00122 {
00123   Teuchos::Time timer("lj", false);
00124   double energy, ff[3];
00125 
00126   timer.start(true);
00127   for (int j=0; j<nloop; j++) {
00128 
00129     lj_and_grad(xi, xj, energy, ff);
00130 
00131     for (int i=0; i<3; i++)
00132       f[i] += -ff[i];
00133 
00134   }
00135   timer.stop();
00136 
00137   return timer.totalElapsedTime() / nloop;
00138 }
00139 
00140 int main(int argc, char* argv[]) {
00141   int ierr = 0;
00142 
00143   try {
00144     double t, ta;
00145     int p = 2;
00146     int w = p+7;
00147 
00148     // Set up command line options
00149     Teuchos::CommandLineProcessor clp;
00150     clp.setDocString("This program tests the speed of various forward mode AD implementations for a single multiplication operation");
00151     int nloop = 1000000;
00152     clp.setOption("nloop", &nloop, "Number of loops");
00153 
00154     // Parse options
00155     Teuchos::CommandLineProcessor::EParseCommandLineReturn
00156       parseReturn= clp.parse(argc, argv);
00157     if(parseReturn != Teuchos::CommandLineProcessor::PARSE_SUCCESSFUL)
00158       return 1;
00159 
00160     // Memory pool & manager
00161     Sacado::Fad::MemPoolManager<double> poolManager(3);
00162     Sacado::Fad::MemPool* pool = poolManager.getMemoryPool(3);
00163     Sacado::Fad::DMFad<double>::setDefaultPool(pool);
00164 
00165     std::cout.setf(std::ios::scientific);
00166     std::cout.precision(p);
00167     std::cout << "Times (sec) nloop =  " << nloop << ":  " << std::endl;
00168 
00169     Sacado::Random<double> urand(0.0, 1.0);
00170     for (int i=0; i<3; i++) {
00171       xi[i] = urand.number();
00172       xj[i] = urand.number();
00173       pa[i] = urand.number();
00174     }
00175     pa[3] = urand.number();
00176 
00177     ta = do_time_analytic(nloop);
00178     std::cout << "Analytic:  " << std::setw(w) << ta << std::endl;
00179 
00180     t = do_time< FAD::TFad<3,double> >(nloop);
00181     std::cout << "TFad:      " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00182 
00183     t = do_time< FAD::Fad<double> >(nloop);
00184     std::cout << "Fad:       " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00185 
00186     t = do_time< Sacado::Fad::SFad<double,3> >(nloop);
00187     std::cout << "SFad:      " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00188 
00189     t = do_time< Sacado::Fad::SLFad<double,3> >(nloop);
00190     std::cout << "SLFad:     " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00191     
00192     t = do_time< Sacado::Fad::DFad<double> >(nloop);
00193     std::cout << "DFad:      " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00194 
00195     t = do_time< Sacado::Fad::DMFad<double> >(nloop);
00196     std::cout << "DMFad:     " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 
00197 
00198     t = do_time< Sacado::ELRFad::SFad<double,3> >(nloop);
00199     std::cout << "ELRSFad:   " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00200 
00201     t = do_time< Sacado::ELRFad::SLFad<double,3> >(nloop);
00202     std::cout << "ELRSLFad:  " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00203 
00204     t = do_time< Sacado::ELRFad::DFad<double> >(nloop);
00205     std::cout << "ELRDFad:   " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00206     
00207     t = do_time< Sacado::CacheFad::DFad<double> >(nloop);
00208     std::cout << "CacheFad:  " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00209 
00210     t = do_time< Sacado::Fad::DVFad<double> >(nloop);
00211     std::cout << "DVFad:     " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl;
00212     
00213   }
00214   catch (std::exception& e) {
00215     cout << e.what() << endl;
00216     ierr = 1;
00217   }
00218   catch (const char *s) {
00219     cout << s << endl;
00220     ierr = 1;
00221   }
00222   catch (...) {
00223     cout << "Caught unknown exception!" << endl;
00224     ierr = 1;
00225   }
00226 
00227   return ierr;
00228 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines