|
Sacado Package Browser (Single Doxygen Collection) Version of the Day
|
00001 // $Id$ 00002 // $Source$ 00003 // @HEADER 00004 // *********************************************************************** 00005 // 00006 // Sacado Package 00007 // Copyright (2006) Sandia Corporation 00008 // 00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, 00010 // the U.S. Government retains certain rights in this software. 00011 // 00012 // This library is free software; you can redistribute it and/or modify 00013 // it under the terms of the GNU Lesser General Public License as 00014 // published by the Free Software Foundation; either version 2.1 of the 00015 // License, or (at your option) any later version. 00016 // 00017 // This library is distributed in the hope that it will be useful, but 00018 // WITHOUT ANY WARRANTY; without even the implied warranty of 00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00020 // Lesser General Public License for more details. 00021 // 00022 // You should have received a copy of the GNU Lesser General Public 00023 // License along with this library; if not, write to the Free Software 00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 00025 // USA 00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps 00027 // (etphipp@sandia.gov). 00028 // 00029 // *********************************************************************** 00030 // @HEADER 00031 00032 #include "Sacado_Random.hpp" 00033 #include "Sacado.hpp" 00034 #include "Sacado_CacheFad_DFad.hpp" 00035 00036 #include "Fad/fad.h" 00037 #include "TinyFadET/tfad.h" 00038 00039 #include "Teuchos_Time.hpp" 00040 #include "Teuchos_CommandLineProcessor.hpp" 00041 00042 // A simple performance test that computes the derivative of a simple 00043 // expression using many variants of Fad. 00044 00045 template <> 00046 Sacado::Fad::MemPool* Sacado::Fad::MemPoolStorage<double>::defaultPool_ = NULL; 00047 00048 void FAD::error(const char *msg) { 00049 std::cout << msg << std::endl; 00050 } 00051 00052 namespace { 00053 double xi[3], xj[3], pa[4], f[3], delr[3]; 00054 } 00055 00056 template <typename T> 00057 inline T 00058 vec3_distsq(const T xi[], const double xj[]) { 00059 T delr0 = xi[0]-xj[0]; 00060 T delr1 = xi[1]-xj[1]; 00061 T delr2 = xi[2]-xj[2]; 00062 return delr0*delr0 + delr1*delr1 + delr2*delr2; 00063 } 00064 00065 template <typename T> 00066 inline T 00067 vec3_distsq(const T xi[], const double xj[], T delr[]) { 00068 delr[0] = xi[0]-xj[0]; 00069 delr[1] = xi[1]-xj[1]; 00070 delr[2] = xi[2]-xj[2]; 00071 return delr[0]*delr[0] + delr[1]*delr[1] + delr[2]*delr[2]; 00072 } 00073 00074 template <typename T> 00075 inline void 00076 lj(const T xi[], const double xj[], T& energy) { 00077 T delr2 = vec3_distsq(xi,xj); 00078 T delr_2 = 1.0/delr2; 00079 T delr_6 = delr_2*delr_2*delr_2; 00080 energy = (pa[1]*delr_6 - pa[2])*delr_6 - pa[3]; 00081 } 00082 00083 inline void 00084 lj_and_grad(const double xi[], const double xj[], double& energy, 00085 double f[]) { 00086 double delr2 = vec3_distsq(xi,xj,delr); 00087 double delr_2 = 1.0/delr2; 00088 double delr_6 = delr_2*delr_2*delr_2; 00089 energy = (pa[1]*delr_6 - pa[2])*delr_6 - pa[3]; 00090 double tmp = (-12.0*pa[1]*delr_6 - 6.0*pa[2])*delr_6*delr_2; 00091 f[0] = delr[0]*tmp; 00092 f[1] = delr[1]*tmp; 00093 f[2] = delr[2]*tmp; 00094 } 00095 00096 template <typename FadType> 00097 double 00098 do_time(int nloop) 00099 { 00100 Teuchos::Time timer("lj", false); 00101 FadType xi_fad[3], energy; 00102 00103 for (int i=0; i<3; i++) { 00104 xi_fad[i] = FadType(3, i, xi[i]); 00105 } 00106 00107 timer.start(true); 00108 for (int j=0; j<nloop; j++) { 00109 00110 lj(xi_fad, xj, energy); 00111 00112 for (int i=0; i<3; i++) 00113 f[i] += -energy.fastAccessDx(i); 00114 } 00115 timer.stop(); 00116 00117 return timer.totalElapsedTime() / nloop; 00118 } 00119 00120 double 00121 do_time_analytic(int nloop) 00122 { 00123 Teuchos::Time timer("lj", false); 00124 double energy, ff[3]; 00125 00126 timer.start(true); 00127 for (int j=0; j<nloop; j++) { 00128 00129 lj_and_grad(xi, xj, energy, ff); 00130 00131 for (int i=0; i<3; i++) 00132 f[i] += -ff[i]; 00133 00134 } 00135 timer.stop(); 00136 00137 return timer.totalElapsedTime() / nloop; 00138 } 00139 00140 int main(int argc, char* argv[]) { 00141 int ierr = 0; 00142 00143 try { 00144 double t, ta; 00145 int p = 2; 00146 int w = p+7; 00147 00148 // Set up command line options 00149 Teuchos::CommandLineProcessor clp; 00150 clp.setDocString("This program tests the speed of various forward mode AD implementations for a single multiplication operation"); 00151 int nloop = 1000000; 00152 clp.setOption("nloop", &nloop, "Number of loops"); 00153 00154 // Parse options 00155 Teuchos::CommandLineProcessor::EParseCommandLineReturn 00156 parseReturn= clp.parse(argc, argv); 00157 if(parseReturn != Teuchos::CommandLineProcessor::PARSE_SUCCESSFUL) 00158 return 1; 00159 00160 // Memory pool & manager 00161 Sacado::Fad::MemPoolManager<double> poolManager(3); 00162 Sacado::Fad::MemPool* pool = poolManager.getMemoryPool(3); 00163 Sacado::Fad::DMFad<double>::setDefaultPool(pool); 00164 00165 std::cout.setf(std::ios::scientific); 00166 std::cout.precision(p); 00167 std::cout << "Times (sec) nloop = " << nloop << ": " << std::endl; 00168 00169 Sacado::Random<double> urand(0.0, 1.0); 00170 for (int i=0; i<3; i++) { 00171 xi[i] = urand.number(); 00172 xj[i] = urand.number(); 00173 pa[i] = urand.number(); 00174 } 00175 pa[3] = urand.number(); 00176 00177 ta = do_time_analytic(nloop); 00178 std::cout << "Analytic: " << std::setw(w) << ta << std::endl; 00179 00180 t = do_time< FAD::TFad<3,double> >(nloop); 00181 std::cout << "TFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00182 00183 t = do_time< FAD::Fad<double> >(nloop); 00184 std::cout << "Fad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00185 00186 t = do_time< Sacado::Fad::SFad<double,3> >(nloop); 00187 std::cout << "SFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00188 00189 t = do_time< Sacado::Fad::SLFad<double,3> >(nloop); 00190 std::cout << "SLFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00191 00192 t = do_time< Sacado::Fad::DFad<double> >(nloop); 00193 std::cout << "DFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00194 00195 t = do_time< Sacado::Fad::DMFad<double> >(nloop); 00196 std::cout << "DMFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00197 00198 t = do_time< Sacado::ELRFad::SFad<double,3> >(nloop); 00199 std::cout << "ELRSFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00200 00201 t = do_time< Sacado::ELRFad::SLFad<double,3> >(nloop); 00202 std::cout << "ELRSLFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00203 00204 t = do_time< Sacado::ELRFad::DFad<double> >(nloop); 00205 std::cout << "ELRDFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00206 00207 t = do_time< Sacado::CacheFad::DFad<double> >(nloop); 00208 std::cout << "CacheFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00209 00210 t = do_time< Sacado::Fad::DVFad<double> >(nloop); 00211 std::cout << "DVFad: " << std::setw(w) << t << "\t" << std::setw(w) << t/ta << std::endl; 00212 00213 } 00214 catch (std::exception& e) { 00215 cout << e.what() << endl; 00216 ierr = 1; 00217 } 00218 catch (const char *s) { 00219 cout << s << endl; 00220 ierr = 1; 00221 } 00222 catch (...) { 00223 cout << "Caught unknown exception!" << endl; 00224 ierr = 1; 00225 } 00226 00227 return ierr; 00228 }
1.7.4