Sacado Package Browser (Single Doxygen Collection) Version of the Day
Sacado_Fad_ScalarTraitsImp.hpp
Go to the documentation of this file.
00001 // $Id$ 
00002 // $Source$ 
00003 // @HEADER
00004 // ***********************************************************************
00005 // 
00006 //                           Sacado Package
00007 //                 Copyright (2006) Sandia Corporation
00008 // 
00009 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00010 // the U.S. Government retains certain rights in this software.
00011 // 
00012 // This library is free software; you can redistribute it and/or modify
00013 // it under the terms of the GNU Lesser General Public License as
00014 // published by the Free Software Foundation; either version 2.1 of the
00015 // License, or (at your option) any later version.
00016 //  
00017 // This library is distributed in the hope that it will be useful, but
00018 // WITHOUT ANY WARRANTY; without even the implied warranty of
00019 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00020 // Lesser General Public License for more details.
00021 //  
00022 // You should have received a copy of the GNU Lesser General Public
00023 // License along with this library; if not, write to the Free Software
00024 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00025 // USA
00026 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
00027 // (etphipp@sandia.gov).
00028 // 
00029 // ***********************************************************************
00030 // @HEADER
00031 
00032 #ifndef SACADO_FAD_SCALARTRAITSIMP_HPP
00033 #define SACADO_FAD_SCALARTRAITSIMP_HPP
00034 
00035 #ifdef HAVE_SACADO_TEUCHOS
00036 
00037 #include "Teuchos_ScalarTraits.hpp"
00038 #include "Teuchos_SerializationTraits.hpp"
00039 #include "Teuchos_SerializationTraitsHelpers.hpp"
00040 #include "Teuchos_Assert.hpp"
00041 #include "Teuchos_RCP.hpp"
00042 #include "Teuchos_Array.hpp"
00043 #include "Sacado_mpl_apply.hpp"
00044 
00045 #include <iterator>
00046 
00047 namespace Sacado {
00048 
00049   namespace Fad {
00050 
00052     template <typename FadType>
00053     struct ScalarTraitsImp {
00054       typedef typename Sacado::ValueType<FadType>::type ValueT;
00055 
00056       typedef typename mpl::apply<FadType,typename Teuchos::ScalarTraits<ValueT>::magnitudeType>::type magnitudeType;
00057       typedef typename mpl::apply<FadType,typename Teuchos::ScalarTraits<ValueT>::halfPrecision>::type halfPrecision;
00058       typedef typename mpl::apply<FadType,typename Teuchos::ScalarTraits<ValueT>::doublePrecision>::type doublePrecision;
00059 
00060       static const bool isComplex = Teuchos::ScalarTraits<ValueT>::isComplex;
00061       static const bool isOrdinal = Teuchos::ScalarTraits<ValueT>::isOrdinal;
00062       static const bool isComparable = 
00063   Teuchos::ScalarTraits<ValueT>::isComparable;
00064       static const bool hasMachineParameters = 
00065   Teuchos::ScalarTraits<ValueT>::hasMachineParameters;
00066       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType eps() {
00067   return Teuchos::ScalarTraits<ValueT>::eps();
00068       }
00069       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType sfmin() {
00070   return Teuchos::ScalarTraits<ValueT>::sfmin();
00071       }
00072       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType base()  {
00073   return Teuchos::ScalarTraits<ValueT>::base();
00074       }
00075       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType prec()  {
00076   return Teuchos::ScalarTraits<ValueT>::prec();
00077       }
00078       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType t()     {
00079   return Teuchos::ScalarTraits<ValueT>::t();
00080       }
00081       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType rnd()   {
00082   return Teuchos::ScalarTraits<ValueT>::rnd();
00083       }
00084       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType emin()  {
00085   return Teuchos::ScalarTraits<ValueT>::emin();
00086       }
00087       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType rmin()  {
00088   return Teuchos::ScalarTraits<ValueT>::rmin();
00089       }
00090       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType emax()  {
00091   return Teuchos::ScalarTraits<ValueT>::emax();
00092       }
00093       static typename Teuchos::ScalarTraits<ValueT>::magnitudeType rmax()  {
00094   return Teuchos::ScalarTraits<ValueT>::rmax();
00095       }
00096       static magnitudeType magnitude(const FadType& a) {
00097 #ifdef TEUCHOS_DEBUG
00098   TEUCHOS_SCALAR_TRAITS_NAN_INF_ERR(
00099     a, "Error, the input value to magnitude(...) a = " << a << 
00100     " can not be NaN!" );
00101   TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(a) == false, std::runtime_error,
00102          "Complex magnitude is not a differentiable "
00103          "function of complex inputs.");
00104 #endif
00105   //return std::fabs(a); 
00106   magnitudeType b(a.size(), 
00107       Teuchos::ScalarTraits<ValueT>::magnitude(a.val()));
00108   if (Teuchos::ScalarTraits<ValueT>::real(a.val()) >= 0)
00109     for (int i=0; i<a.size(); i++)
00110       b.fastAccessDx(i) = 
00111         Teuchos::ScalarTraits<ValueT>::magnitude(a.fastAccessDx(i));
00112   else
00113     for (int i=0; i<a.size(); i++)
00114       b.fastAccessDx(i) = 
00115         -Teuchos::ScalarTraits<ValueT>::magnitude(a.fastAccessDx(i));
00116   return b;
00117       }
00118       static ValueT zero()  { 
00119   return ValueT(0.0); 
00120       }
00121       static ValueT one()   { 
00122   return ValueT(1.0); 
00123       }
00124       
00125       // Conjugate is only defined for real derivative components
00126       static FadType conjugate(const FadType& x) {
00127 #ifdef TEUCHOS_DEBUG
00128   TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(x) == false, std::runtime_error,
00129          "Complex conjugate is not a differentiable "
00130          "function of complex inputs.");
00131 #endif
00132   FadType y = x;
00133   y.val() = Teuchos::ScalarTraits<ValueT>::conjugate(x.val());
00134   return y;
00135       }   
00136 
00137       // Real part is only defined for real derivative components
00138       static FadType real(const FadType& x) { 
00139 #ifdef TEUCHOS_DEBUG
00140   TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(x) == false, std::runtime_error,
00141          "Real component is not a differentiable "
00142          "function of complex inputs.");
00143 #endif
00144   FadType y = x;
00145   y.val() = Teuchos::ScalarTraits<ValueT>::real(x.val());
00146   return y;
00147       }
00148 
00149       // Imaginary part is only defined for real derivative components
00150       static FadType imag(const FadType& x) { 
00151 #ifdef TEUCHOS_DEBUG
00152   TEUCHOS_TEST_FOR_EXCEPTION(is_fad_real(x) == false, std::runtime_error,
00153          "Imaginary component is not a differentiable "
00154          "function of complex inputs.");
00155 #endif
00156   return FadType(Teuchos::ScalarTraits<ValueT>::imag(x.val()));
00157       }
00158 
00159       static ValueT nan() {
00160   return Teuchos::ScalarTraits<ValueT>::nan(); 
00161       }
00162       static bool isnaninf(const FadType& x) { 
00163   if (Teuchos::ScalarTraits<ValueT>::isnaninf(x.val()))
00164     return true;
00165   for (int i=0; i<x.size(); i++)
00166     if (Teuchos::ScalarTraits<ValueT>::isnaninf(x.dx(i)))
00167       return true;
00168   return false;
00169       }
00170       static void seedrandom(unsigned int s) { 
00171   Teuchos::ScalarTraits<ValueT>::seedrandom(s); 
00172       }
00173       static ValueT random() { 
00174   return Teuchos::ScalarTraits<ValueT>::random(); 
00175       }
00176       static std::string name() { 
00177   return Sacado::StringName<FadType>::eval(); 
00178       }
00179       static FadType squareroot(const FadType& x) {
00180 #ifdef TEUCHOS_DEBUG
00181   TEUCHOS_SCALAR_TRAITS_NAN_INF_ERR(
00182     x, "Error, the input value to squareroot(...) a = " << x << 
00183     " can not be NaN!" );
00184 #endif
00185   return std::sqrt(x); 
00186       }
00187       static FadType pow(const FadType& x, const FadType& y) { 
00188   return std::pow(x,y); 
00189       }
00190 
00191       // Helper function to determine whether a complex value is real
00192       static bool is_complex_real(const ValueT& x) {
00193   return 
00194     Teuchos::ScalarTraits<ValueT>::magnitude(x-Teuchos::ScalarTraits<ValueT>::real(x)) == 0;
00195       }
00196 
00197       // Helper function to determine whether a Fad type is real
00198       static bool is_fad_real(const FadType& x) {
00199   if (x.size() == 0)
00200     return true;
00201   if (Teuchos::ScalarTraits<ValueT>::isComplex) {
00202     if (!is_complex_real(x.val()))
00203       return false;
00204     for (int i=0; i<x.size(); i++)
00205       if (!is_complex_real(x.fastAccessDx(i)))
00206         return false;
00207   }
00208   return true;
00209       }
00210 
00211     }; // class ScalarTraitsImp
00212 
00214     template <typename Ordinal, typename FadType, typename Serializer>
00215     struct SerializationImp {
00216 
00217     private:
00218 
00220       typedef Teuchos::SerializationTraits<Ordinal,int> iSerT;
00221 
00223       typedef Teuchos::SerializationTraits<Ordinal,Ordinal> oSerT;
00224 
00226       typedef typename Sacado::ValueType<FadType>::type value_type;
00227 
00228     public:
00229 
00231       static const bool supportsDirectSerialization = false;
00232 
00234 
00235 
00237       static Ordinal fromCountToIndirectBytes(const Serializer& vs,
00238                 const Ordinal count, 
00239                 const FadType buffer[],
00240                 const Ordinal sz = 0) { 
00241   Ordinal bytes = 0;
00242         FadType *x = NULL;
00243         const FadType *cx;
00244   for (Ordinal i=0; i<count; i++) {
00245     int my_sz = buffer[i].size();
00246     int tot_sz = sz;
00247     if (sz == 0) tot_sz = my_sz;
00248     Ordinal b1 = iSerT::fromCountToIndirectBytes(1, &tot_sz);
00249     Ordinal b2 = vs.fromCountToIndirectBytes(1, &(buffer[i].val()));
00250     Ordinal b3 = oSerT::fromCountToIndirectBytes(1, &b2);
00251           Ordinal b4;
00252     if (tot_sz != my_sz) {
00253             if (x == NULL)
00254               x = new FadType(tot_sz, 0.0);
00255       *x = buffer[i];
00256             x->expand(tot_sz);
00257             cx = x;
00258     }
00259           else 
00260             cx = &(buffer[i]);
00261     b4 = vs.fromCountToIndirectBytes(tot_sz, cx->dx());
00262     Ordinal b5 = oSerT::fromCountToIndirectBytes(1, &b4);
00263     bytes += b1+b2+b3+b4+b5;
00264   }
00265         if (x != NULL)
00266           delete x;
00267   return bytes;
00268       }
00269 
00271       static void serialize (const Serializer& vs,
00272            const Ordinal count, 
00273            const FadType buffer[], 
00274            const Ordinal bytes, 
00275            char charBuffer[],
00276            const Ordinal sz = 0) { 
00277         FadType *x = NULL;
00278         const FadType *cx;
00279   for (Ordinal i=0; i<count; i++) {
00280     // First serialize size
00281     int my_sz = buffer[i].size();
00282     int tot_sz = sz;
00283     if (sz == 0) tot_sz = my_sz;
00284     Ordinal b1 = iSerT::fromCountToIndirectBytes(1, &tot_sz);
00285     iSerT::serialize(1, &tot_sz, b1, charBuffer);
00286     charBuffer += b1;
00287   
00288     // Next serialize value
00289     Ordinal b2 = vs.fromCountToIndirectBytes(1, &(buffer[i].val()));
00290     Ordinal b3 = oSerT::fromCountToIndirectBytes(1, &b2);
00291     oSerT::serialize(1, &b2, b3, charBuffer); 
00292     charBuffer += b3;
00293     vs.serialize(1, &(buffer[i].val()), b2, charBuffer);
00294     charBuffer += b2;
00295   
00296     // Next serialize derivative array
00297           Ordinal b4;
00298           if (tot_sz != my_sz) {
00299             if (x == NULL)
00300               x = new FadType(tot_sz, 0.0);
00301       *x = buffer[i];
00302             x->expand(tot_sz);
00303             cx = x;
00304     }
00305           else 
00306             cx = &(buffer[i]);
00307     b4 = vs.fromCountToIndirectBytes(tot_sz, cx->dx());
00308     Ordinal b5 = oSerT::fromCountToIndirectBytes(1, &b4);
00309     oSerT::serialize(1, &b4, b5, charBuffer); 
00310     charBuffer += b5;
00311     vs.serialize(tot_sz, cx->dx(), b4, charBuffer);
00312     charBuffer += b4;
00313   }
00314         if (x != NULL)
00315           delete x;
00316       }
00317 
00319       static Ordinal fromIndirectBytesToCount(const Serializer& vs,
00320                 const Ordinal bytes, 
00321                 const char charBuffer[],
00322                 const Ordinal sz = 0) {
00323   Ordinal count = 0;
00324   Ordinal bytes_used = 0;
00325   while (bytes_used < bytes) {
00326   
00327     // Bytes for size
00328     Ordinal b1 = iSerT::fromCountToDirectBytes(1);
00329     bytes_used += b1;
00330     charBuffer += b1;
00331   
00332     // Bytes for value
00333     Ordinal b3 = oSerT::fromCountToDirectBytes(1);
00334     const Ordinal *b2 = oSerT::convertFromCharPtr(charBuffer);
00335     bytes_used += b3;
00336     charBuffer += b3;
00337     bytes_used += *b2;
00338     charBuffer += *b2;
00339   
00340     // Bytes for derivative array
00341     Ordinal b5 = oSerT::fromCountToDirectBytes(1);
00342     const Ordinal *b4 = oSerT::convertFromCharPtr(charBuffer);
00343     bytes_used += b5;
00344     charBuffer += b5;
00345     bytes_used += *b4;
00346     charBuffer += *b4;
00347   
00348     ++count;
00349   }
00350   return count;
00351       }
00352 
00354       static void deserialize (const Serializer& vs,
00355              const Ordinal bytes, 
00356              const char charBuffer[], 
00357              const Ordinal count, 
00358              FadType buffer[],
00359              const Ordinal sz = 0) { 
00360   for (Ordinal i=0; i<count; i++) {
00361   
00362     // Deserialize size
00363     Ordinal b1 = iSerT::fromCountToDirectBytes(1);
00364     const int *my_sz = iSerT::convertFromCharPtr(charBuffer);
00365     charBuffer += b1;
00366   
00367     // Create empty Fad object of given size
00368     int tot_sz = sz;
00369     if (sz == 0) tot_sz = *my_sz;
00370     buffer[i] = FadType(tot_sz, 0.0);
00371   
00372     // Deserialize value
00373     Ordinal b3 = oSerT::fromCountToDirectBytes(1);
00374     const Ordinal *b2 = oSerT::convertFromCharPtr(charBuffer);
00375     charBuffer += b3;
00376     vs.deserialize(*b2, charBuffer, 1, &(buffer[i].val()));
00377     charBuffer += *b2;
00378   
00379     // Deserialize derivative array
00380     Ordinal b5 = oSerT::fromCountToDirectBytes(1);
00381     const Ordinal *b4 = oSerT::convertFromCharPtr(charBuffer);
00382     charBuffer += b5;
00383     vs.deserialize(*b4, charBuffer, *my_sz, 
00384            &(buffer[i].fastAccessDx(0)));
00385     charBuffer += *b4;
00386   }
00387       
00388       }
00389   
00391       
00392     };
00393 
00395     template <typename Ordinal, typename FadType>
00396     struct SerializationTraitsImp {
00397 
00398     private:
00399 
00401       typedef typename Sacado::ValueType<FadType>::type ValueT;
00402 
00404       typedef Teuchos::DefaultSerializer<Ordinal,ValueT> DS;
00405 
00407       typedef typename DS::DefaultSerializerType ValueSerializer;
00408 
00410       typedef SerializationImp<Ordinal,FadType,ValueSerializer> Imp;
00411 
00412     public:
00413 
00415       static const bool supportsDirectSerialization = 
00416   Imp::supportsDirectSerialization;
00417 
00419 
00420 
00422       static Ordinal fromCountToIndirectBytes(const Ordinal count, 
00423                 const FadType buffer[]) { 
00424   return Imp::fromCountToIndirectBytes(
00425     DS::getDefaultSerializer(), count, buffer);
00426       }
00427 
00429       static void serialize (const Ordinal count, 
00430            const FadType buffer[], 
00431            const Ordinal bytes, 
00432            char charBuffer[]) { 
00433   Imp::serialize(
00434     DS::getDefaultSerializer(), count, buffer, bytes, charBuffer);
00435       }
00436 
00438       static Ordinal fromIndirectBytesToCount(const Ordinal bytes, 
00439                 const char charBuffer[]) {
00440   return Imp::fromIndirectBytesToCount(
00441     DS::getDefaultSerializer(), bytes, charBuffer);
00442       }
00443 
00445       static void deserialize (const Ordinal bytes, 
00446              const char charBuffer[], 
00447              const Ordinal count, 
00448              FadType buffer[]) { 
00449   Imp::deserialize(
00450     DS::getDefaultSerializer(), bytes, charBuffer, count, buffer);
00451       }
00452   
00454       
00455     };
00456 
00458     template <typename Ordinal, typename FadType>
00459     struct StaticSerializationTraitsImp {
00460       typedef typename Sacado::ValueType<FadType>::type ValueT;
00461       typedef Teuchos::SerializationTraits<Ordinal,ValueT> vSerT;
00462       typedef Teuchos::DirectSerializationTraits<Ordinal,FadType> DSerT;
00463       typedef Sacado::Fad::SerializationTraitsImp<Ordinal,FadType> STI;
00464 
00466       static const bool supportsDirectSerialization = 
00467   vSerT::supportsDirectSerialization;
00468 
00470 
00471 
00473       static Ordinal fromCountToDirectBytes(const Ordinal count) { 
00474   return DSerT::fromCountToDirectBytes(count);
00475       }
00476 
00478       static char* convertToCharPtr( FadType* ptr ) { 
00479   return DSerT::convertToCharPtr(ptr);
00480       }
00481       
00483       static const char* convertToCharPtr( const FadType* ptr ) { 
00484   return DSerT::convertToCharPtr(ptr);
00485       }
00486       
00488       static Ordinal fromDirectBytesToCount(const Ordinal bytes) { 
00489   return DSerT::fromDirectBytesToCount(bytes);
00490       }
00491       
00493       static FadType* convertFromCharPtr( char* ptr ) { 
00494   return DSerT::convertFromCharPtr(ptr);
00495       }
00496       
00498       static const FadType* convertFromCharPtr( const char* ptr ) { 
00499   return DSerT::convertFromCharPtr(ptr);
00500       }
00501 
00503 
00505 
00506 
00508       static Ordinal fromCountToIndirectBytes(const Ordinal count, 
00509                 const FadType buffer[]) { 
00510   if (supportsDirectSerialization)
00511     return DSerT::fromCountToIndirectBytes(count, buffer);
00512   else
00513     return STI::fromCountToIndirectBytes(count, buffer);
00514       }
00515 
00517       static void serialize (const Ordinal count, 
00518            const FadType buffer[], 
00519            const Ordinal bytes, 
00520            char charBuffer[]) { 
00521   if (supportsDirectSerialization)
00522     return DSerT::serialize(count, buffer, bytes, charBuffer);
00523   else
00524     return STI::serialize(count, buffer, bytes, charBuffer);
00525       }
00526 
00528       static Ordinal fromIndirectBytesToCount(const Ordinal bytes, 
00529                 const char charBuffer[]) {
00530   if (supportsDirectSerialization)
00531     return DSerT::fromIndirectBytesToCount(bytes, charBuffer);
00532   else
00533     return STI::fromIndirectBytesToCount(bytes, charBuffer);
00534       }
00535 
00537       static void deserialize (const Ordinal bytes, 
00538              const char charBuffer[], 
00539              const Ordinal count, 
00540              FadType buffer[]) { 
00541   if (supportsDirectSerialization)
00542     return DSerT::deserialize(bytes, charBuffer, count, buffer);
00543   else
00544     return STI::deserialize(bytes, charBuffer, count, buffer);
00545       }
00546 
00548       
00549     };
00550 
00552     template <typename Ordinal, typename FadType, typename ValueSerializer>
00553     class SerializerImp {
00554 
00555     private:
00556 
00558       typedef SerializationImp<Ordinal,FadType,ValueSerializer> Imp;
00559 
00561       Teuchos::RCP<const ValueSerializer> vs;
00562 
00564       Ordinal sz;
00565 
00566     public:
00567 
00569       typedef ValueSerializer value_serializer_type;
00570 
00572       static const bool supportsDirectSerialization = 
00573   Imp::supportsDirectSerialization;
00574 
00576       SerializerImp(const Teuchos::RCP<const ValueSerializer>& vs_,
00577         Ordinal sz_ = 0) :
00578   vs(vs_), sz(sz_) {}
00579 
00581       Ordinal getSerializerSize() const { return sz; }
00582       
00584       Teuchos::RCP<const value_serializer_type> getValueSerializer() const { 
00585   return vs; }
00586 
00588 
00589 
00591       Ordinal fromCountToIndirectBytes(const Ordinal count, 
00592                const FadType buffer[]) const { 
00593   return Imp::fromCountToIndirectBytes(*vs, count, buffer, sz);
00594       }
00595 
00597       void serialize (const Ordinal count, 
00598           const FadType buffer[], 
00599           const Ordinal bytes, 
00600           char charBuffer[]) const { 
00601   Imp::serialize(*vs, count, buffer, bytes, charBuffer, sz);
00602       }
00603 
00605       Ordinal fromIndirectBytesToCount(const Ordinal bytes, 
00606                const char charBuffer[]) const {
00607   return Imp::fromIndirectBytesToCount(*vs, bytes, charBuffer, sz);
00608       }
00609 
00611       void deserialize (const Ordinal bytes, 
00612       const char charBuffer[], 
00613       const Ordinal count, 
00614       FadType buffer[]) const { 
00615   return Imp::deserialize(*vs, bytes, charBuffer, count, buffer, sz);
00616       }
00617   
00619       
00620     };
00621 
00622   } // namespace Fad
00623 
00624 } // namespace Sacado
00625 
00626 #endif // HAVE_SACADO_TEUCHOS
00627 
00628 #endif // SACADO_FAD_SCALARTRAITSIMP_HPP
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines