tradvec_example.cpp

Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 //
00004 //                           Sacado Package
00005 //                 Copyright (2006) Sandia Corporation
00006 //
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 //
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00025 // (etphipp@sandia.gov).
00026 //
00027 // ***********************************************************************
00028 // @HEADER
00029 
00030 // tradvec_example
00031 //
00032 //  usage:
00033 //     tradvec_example
00034 //
00035 //  output:
00036 //     prints the results of differentiating two simple functions and their
00037 //     sum with reverse mode AD using functions Outvar_Gradcomp and
00038 //     Weighted_GradcompVec in the Sacado::RadVec::ADvar class.
00039 
00040 /* Simple test of Outvar_Gradcomp and Weighted_GradcompVec. */
00041 
00042 #include "Sacado_tradvec.hpp"
00043 #include <stdio.h>
00044 
00045 #ifdef _MSC_VER
00046 # define snprintf _snprintf
00047 #endif
00048 
00049 typedef Sacado::RadVec::ADvar<double> ADVar;
00050 
00051  ADVar
00052 foo(double d, ADVar x, ADVar y)
00053 { return d*x*y; }
00054 
00055  ADVar
00056 goo(double d, ADVar x, ADVar y)
00057 { return d*(x*x + y*y); }
00058 
/* Expected results for one checked quantity at d = 1; acheck compares
 * each field after multiplying it by the current d.
 *
 * Declared as a plain struct.  A "typedef struct ... ;" with no typedef
 * name adds nothing and only draws "typedef ignored" warnings.  In C++ the
 * struct tag is already usable as a type name.
 */
struct ExpectedAnswer {
  const char *name;   // label used in failure messages, e.g. "a"
  double v;           // value of name
  double dvdx;        // partial w.r.t. x
  double dvdy;        // partial w.r.t. y
  };
00066 
00067  static ExpectedAnswer expected[4] = {
00068   { "a", 6., 3., 2. },
00069   { "b", 13., 4., 6. },
00070   { "c", 19., 7., 8. },
00071   { "(a + b + c)", 38., 14., 16. }};
00072 
00073  int
00074 botch(ExpectedAnswer *e, const char *partial, double got, double wanted)
00075 {
00076   char buf[32];
00077   const char *what;
00078 
00079   what = e->name;
00080   if (partial) {
00081     snprintf(buf, sizeof(buf), "d%s/d%s", what, partial);
00082     what = buf;
00083     }
00084   fprintf(stderr, "Expected %s = %g, but got %g\n", what, wanted, got);
00085   return 1;
00086   }
00087 
00088  int
00089 acheck(int k, double d, double v, double dvdx, double dvdy)
00090 {
00091   ExpectedAnswer *e = &expected[k];
00092   int nbad = 0;
00093 
00094   /* There should be no round-off error in this simple example, so we */
00095   /* use exact comparisons, rather than, say, relative differences. */
00096 
00097   if (v != d*e->v)
00098     nbad += botch(e, 0, v, d*e->v);
00099   if (dvdx != d*e->dvdx)
00100     nbad += botch(e, "x", dvdx, d*e->dvdx);
00101   if (dvdy != d*e->dvdy)
00102     nbad += botch(e, "y", dvdy, d*e->dvdy);
00103   return nbad;
00104   }
00105 
00106  int
00107 main(void)
00108 {
00109   double d, z[4];
00110   int i, nbad;
00111 
00112   static ADVar a, b, c, x, y, *v[3] = {&a, &b, &c};
00113   static ADVar **V[4] = {v, v+1, v+2, v};
00114   static size_t np[4] = {1, 1, 1, 3};
00115   static double w[3] = { 1., 1., 1. };
00116   static double *W[4] = {w, w, w, w};
00117 
00118   nbad = 0;
00119   for(d = 1.; d <= 2.; ++d) {
00120     printf("\nd = %g\n", d);
00121     x = 2;
00122     y = 3;
00123     a = foo(d,x,y);
00124     b = goo(d,x,y);
00125     c = a + b;
00126 
00127     ADVar::Outvar_Gradcomp(a);
00128     printf("a = %g\n", a.val());
00129     printf("da/dx = %g\n", x.adj());
00130     printf("da/dy = %g\n", y.adj());
00131     nbad += acheck(0, d, a.val(), x.adj(), y.adj());
00132     z[0] = a.val();
00133 
00134     ADVar::Outvar_Gradcomp(b);
00135     printf("b = %g\n", b.val());
00136     printf("db/dx = %g\n", x.adj());
00137     printf("db/dy = %g\n", y.adj());
00138     nbad += acheck(1, d, b.val(), x.adj(), y.adj());
00139     z[1] = b.val();
00140 
00141     ADVar::Outvar_Gradcomp(c);
00142     printf("c = %g (should be a + b)\n", c.val());
00143     printf("dc/dx = %g\n", x.adj());
00144     printf("dc/dy = %g\n", y.adj());
00145     nbad += acheck(2, d, c.val(), x.adj(), y.adj());
00146     z[2] = c.val();
00147     z[3] = z[0] + z[1] + z[2];
00148 
00149     ADVar::Weighted_GradcompVec(4,np,V,W);
00150     for(i = 0; i < 4; ++i) {
00151       printf("w %d:\td/dx = %g\td/dy = %g\n", i, x.adj(i), y.adj(i));
00152       nbad += acheck(i, d, z[i], x.adj(i), y.adj(i));
00153       }
00154     }
00155   if (nbad == 0)
00156     printf("\nExample passed!\n");
00157   else
00158     printf("\nSomething is wrong, example failed!\n");
00159   return nbad > 0;
00160   }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
Generated on Wed Apr 13 10:19:40 2011 for Sacado Package Browser (Single Doxygen Collection) by  doxygen 1.6.3