Rythmos_AdjointModelEvaluator.hpp

00001 //@HEADER
00002 // ***********************************************************************
00003 //
00004 //                           Rythmos Package
00005 //                 Copyright (2006) Sandia Corporation
00006 //
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 //
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Todd S. Coffey (tscoffe@sandia.gov)
00025 //
00026 // ***********************************************************************
00027 //@HEADER
00028 
00029 #ifndef RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
00030 #define RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
00031 
00032 
00033 #include "Rythmos_IntegratorBase.hpp"
00034 #include "Thyra_ModelEvaluator.hpp" // Interface
00035 #include "Thyra_StateFuncModelEvaluatorBase.hpp" // Implementation
00036 #include "Thyra_DefaultScaledAdjointLinearOp.hpp"
00037 #include "Thyra_DefaultAdjointLinearOpWithSolve.hpp"
00038 #include "Teuchos_implicit_cast.hpp"
00039 #include "Teuchos_Assert.hpp"
00040 
00041 
00042 namespace Rythmos {
00043 
00044 
00167 template<class Scalar>
00168 class AdjointModelEvaluator
00169   : virtual public Thyra::StateFuncModelEvaluatorBase<Scalar>
00170 {
00171 public:
00172 
00175 
00177   AdjointModelEvaluator();
00178 
00180   void setFwdStateModel(
00181     const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel );
00182 
00186   void setFwdTimeRange( const TimeRange<Scalar> &fwdTimeRange );
00187 
00196   void setFwdStateSolutionBuffer(
00197     const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer );
00198   
00200 
00203 
00205   RCP<const Thyra::VectorSpaceBase<Scalar> > get_x_space() const;
00207   RCP<const Thyra::VectorSpaceBase<Scalar> > get_f_space() const;
00209   Thyra::ModelEvaluatorBase::InArgs<Scalar> getNominalValues() const;
00211   RCP<Thyra::LinearOpWithSolveBase<Scalar> > create_W() const;
00213   RCP<Thyra::LinearOpBase<Scalar> > create_W_op() const;
00215   Thyra::ModelEvaluatorBase::InArgs<Scalar> createInArgs() const;
00216 
00218 
00219 private:
00220 
00223 
00225   Thyra::ModelEvaluatorBase::OutArgs<Scalar> createOutArgsImpl() const;
00227   void evalModelImpl(
00228     const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
00229     const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
00230     ) const;
00231 
00233 
00234 private:
00235 
00236   // /////////////////////////
00237   // Private data members
00238 
00239   RCP<const Thyra::ModelEvaluator<Scalar> > fwdStateModel_;
00240   TimeRange<Scalar> fwdTimeRange_;
00241   RCP<const InterpolationBufferBase<Scalar> > fwdStateSolutionBuffer_;
00242 
00243   mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> prototypeInArgs_bar_;
00244   mutable Thyra::ModelEvaluatorBase::OutArgs<Scalar> prototypeOutArgs_bar_;
00245   mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> adjointNominalValues_;
00246   mutable RCP<Thyra::LinearOpBase<Scalar> > my_W_bar_adj_op_;
00247   mutable RCP<Thyra::LinearOpBase<Scalar> > my_d_f_d_x_dot_op_;
00248 
00249   // /////////////////////////
00250   // Private member functions
00251 
00252   // Just-in-time initialization function
00253   void initialize() const;
00254 
00255 };
00256 
00257 
00262 template<class Scalar>
00263 RCP<AdjointModelEvaluator<Scalar> >
00264 adjointModelEvaluator(
00265   const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
00266   const TimeRange<Scalar> &fwdTimeRange
00267   )
00268 {
00269   RCP<AdjointModelEvaluator<Scalar> >
00270     adjointModel = Teuchos::rcp(new AdjointModelEvaluator<Scalar>);
00271   adjointModel->setFwdStateModel(fwdStateModel);
00272   adjointModel->setFwdTimeRange(fwdTimeRange);
00273   return adjointModel;
00274 }
00275 
00276 
00277 // /////////////////////////////////
00278 // Implementations
00279 
00280 
00281 // Constructors/Initializers/Accessors
00282 
00283 
00284 template<class Scalar>
00285 AdjointModelEvaluator<Scalar>::AdjointModelEvaluator()
00286 {}
00287 
00288 
00289 template<class Scalar>
00290 void AdjointModelEvaluator<Scalar>::setFwdStateModel(
00291   const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel )
00292 {
00293   TEST_FOR_EXCEPT(is_null(fwdStateModel));
00294   fwdStateModel_ = fwdStateModel;
00295 }
00296 
00297 
00298 template<class Scalar>
00299 void AdjointModelEvaluator<Scalar>::setFwdTimeRange(
00300   const TimeRange<Scalar> &fwdTimeRange )
00301 {
00302   fwdTimeRange_ = fwdTimeRange;
00303 }
00304 
00305 
00306 template<class Scalar>
00307 void AdjointModelEvaluator<Scalar>::setFwdStateSolutionBuffer(
00308   const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer )
00309 {
00310   TEST_FOR_EXCEPT(is_null(fwdStateSolutionBuffer));
00311   fwdStateSolutionBuffer_ = fwdStateSolutionBuffer;
00312 }
00313 
00314 
00315 // Public functions overridden from ModelEvaluator
00316 
00317 
00318 template<class Scalar>
00319 RCP<const Thyra::VectorSpaceBase<Scalar> >
00320 AdjointModelEvaluator<Scalar>::get_x_space() const
00321 {
00322   initialize();
00323   return fwdStateModel_->get_f_space();
00324 }
00325 
00326 
00327 template<class Scalar>
00328 RCP<const Thyra::VectorSpaceBase<Scalar> >
00329 AdjointModelEvaluator<Scalar>::get_f_space() const
00330 {
00331   initialize();
00332   return fwdStateModel_->get_x_space();
00333 }
00334 
00335 
00336 template<class Scalar>
00337 Thyra::ModelEvaluatorBase::InArgs<Scalar>
00338 AdjointModelEvaluator<Scalar>::getNominalValues() const
00339 {
00340   initialize();
00341   return adjointNominalValues_;
00342 }
00343 
00344 
00345 template<class Scalar>
00346 RCP<Thyra::LinearOpWithSolveBase<Scalar> >
00347 AdjointModelEvaluator<Scalar>::create_W() const
00348 {
00349   initialize();
00350   return Thyra::nonconstAdjointLows<Scalar>(fwdStateModel_->create_W());
00351 }
00352 
00353 
00354 template<class Scalar>
00355 RCP<Thyra::LinearOpBase<Scalar> >
00356 AdjointModelEvaluator<Scalar>::create_W_op() const
00357 {
00358   initialize();
00359   return Thyra::nonconstAdjoint<Scalar>(fwdStateModel_->create_W_op());
00360 }
00361 
00362 
00363 template<class Scalar>
00364 Thyra::ModelEvaluatorBase::InArgs<Scalar>
00365 AdjointModelEvaluator<Scalar>::createInArgs() const
00366 {
00367   initialize();
00368   return prototypeInArgs_bar_;
00369 }
00370 
00371 
00372 // Private functions overridden from ModelEvaluatorDefaultBase
00373 
00374 
00375 template<class Scalar>
00376 Thyra::ModelEvaluatorBase::OutArgs<Scalar>
00377 AdjointModelEvaluator<Scalar>::createOutArgsImpl() const
00378 {
00379   initialize();
00380   return prototypeOutArgs_bar_;
00381 }
00382 
00383 
00384 template<class Scalar>
00385 void AdjointModelEvaluator<Scalar>::evalModelImpl(
00386   const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
00387   const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
00388   ) const
00389 {
00390 
00391   using Teuchos::rcp_dynamic_cast;
00392   using Teuchos::describe;
00393   typedef Teuchos::ScalarTraits<Scalar> ST;
00394   typedef Thyra::ModelEvaluatorBase MEB;
00395   typedef Thyra::DefaultScaledAdjointLinearOp<Scalar> DSALO;
00396   typedef Thyra::DefaultAdjointLinearOpWithSolve<Scalar> DALOWS;
00397   typedef Teuchos::VerboseObjectTempState<Thyra::ModelEvaluatorBase> VOTSME;
00398 
00399   //
00400   // A) Header stuff
00401   //
00402 
00403   THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_GEN_BEGIN(
00404     "AdjointModelEvaluator", inArgs_bar, outArgs_bar, Teuchos::null );
00405 
00406   initialize();
00407 
00408   VOTSME fwdStateModel_outputTempState(fwdStateModel_,out,verbLevel);
00409 
00410   //const bool trace = includesVerbLevel(verbLevel, Teuchos::VERB_LOW);
00411   const bool dumpAll = includesVerbLevel(localVerbLevel, Teuchos::VERB_EXTREME);
00412 
00413   //
00414   // B) Unpack the input and output arguments to see what we have to compute
00415   //
00416 
00417   // B.1) InArgs
00418 
00419   const RCP<const Thyra::VectorBase<Scalar> >
00420     lambda_rev_dot = inArgs_bar.get_x_dot().assert_not_null(), // x_bar_dot
00421     lambda = inArgs_bar.get_x().assert_not_null(); // x_bar
00422   const Scalar alpha_bar = inArgs_bar.get_alpha();
00423   const Scalar beta_bar = inArgs_bar.get_beta();
00424   // const Scalar t_bar = inArgs_bar.get_t(); // Don't need yet!
00425 
00426   if (dumpAll) {
00427     *out << "\nlambda_rev_dot = " << describe(*lambda_rev_dot, Teuchos::VERB_EXTREME);
00428     *out << "\nlambda = " << describe(*lambda, Teuchos::VERB_EXTREME);
00429     *out << "\nalpha_bar = " << alpha_bar << "\n";
00430     *out << "\nbeta_bar = " << beta_bar << "\n";
00431   }
00432 
00433   // B.2) OutArgs
00434 
00435   const RCP<Thyra::VectorBase<Scalar> > f_bar = outArgs_bar.get_f();
00436 
00437   RCP<DALOWS> W_bar;
00438   if (outArgs_bar.supports(MEB::OUT_ARG_W))
00439     W_bar = rcp_dynamic_cast<DALOWS>(outArgs_bar.get_W(), true);
00440 
00441   RCP<DSALO> W_bar_op;
00442   if (outArgs_bar.supports(MEB::OUT_ARG_W_op))
00443     W_bar_op = rcp_dynamic_cast<DSALO>(outArgs_bar.get_W_op(), true);
00444   
00445   //
00446   // C) Evaluate the needed quantities from the underlying forward Model
00447   //
00448 
00449   MEB::InArgs<Scalar> fwdInArgs = fwdStateModel_->createInArgs();
00450 
00451   // C.1) Set the required input arguments
00452 
00453   if (!is_null(fwdStateSolutionBuffer_)) {
00454     TEST_FOR_EXCEPT_MSG(true, "ToDo: Implement getting the x and x_dot from IB");
00455   }
00456   else {
00457     // If we don't have an IB object to get the state from, we will assume
00458     // that the problem is linear and, therefore, we can pass in any old value
00459     // of x, x_dot, and t and get the W_bar_adj object that we need
00460     fwdInArgs = fwdStateModel_->getNominalValues();
00461     // 2008/05/14: rabartl: ToDo: Implement real variable dependancy
00462     // communication support to make sure that this is okay!
00463   }
00464 
00465   // C.2) Evaluate W_bar_adj if needed
00466 
00467   RCP<Thyra::LinearOpWithSolveBase<Scalar> > W_bar_adj;
00468   RCP<Thyra::LinearOpBase<Scalar> > W_bar_adj_op;
00469   {
00470 
00471     MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
00472     
00473     // Get or create W_bar_adj or W_bar_adj_op if needed
00474     if (!is_null(W_bar)) {
00475       // If we have W_bar, the W_bar_adj was already created in
00476       // this->create_W()
00477       W_bar_adj = W_bar->getNonconstOp();
00478       W_bar_adj_op = W_bar_adj;
00479     }
00480     else if (!is_null(W_bar_op)) {
00481       TEST_FOR_EXCEPT_MSG(true, "ToDo: Unit test this code!");
00482       // If we have W_bar_op, the W_bar_adj_op was already created in
00483       // this->create_W_op()
00484       W_bar_adj_op = W_bar_op->getNonconstOp();
00485     }
00486     else if (!is_null(f_bar)) {
00487       TEST_FOR_EXCEPT_MSG(true, "ToDo: Unit test this code!");
00488       // If the user did not pass in W_bar or W_bar_op, then we need to create
00489       // our own local LOB form W_bar_adj_op of W_bar_adj in order to evaluate
00490       // the residual f_bar
00491       if (is_null(my_W_bar_adj_op_)) {
00492         my_W_bar_adj_op_ = fwdStateModel_->create_W_op();
00493       }
00494       W_bar_adj_op = my_W_bar_adj_op_;
00495     }
00496     
00497     // Set W_bar_adj or W_bar_adj_op on the OutArgs object
00498     if (!is_null(W_bar_adj)) {
00499       fwdOutArgs.set_W(W_bar_adj);
00500     }
00501     else if (!is_null(W_bar_adj_op)) {
00502       fwdOutArgs.set_W_op(W_bar_adj);
00503     }
00504     
00505     // Set alpha and beta on OutArgs object
00506     if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
00507       fwdInArgs.set_alpha(alpha_bar);
00508       fwdInArgs.set_beta(beta_bar);
00509     }
00510     
00511     // Evaluate the model
00512     if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
00513       fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
00514     }
00515     
00516     // Print the objects if requested
00517     if (!is_null(W_bar_adj) && dumpAll)
00518       *out << "\nW_bar_adj = " << describe(*W_bar_adj, Teuchos::VERB_EXTREME);
00519     if (!is_null(W_bar_adj_op) && dumpAll)
00520       *out << "\nW_bar_adj_op = " << describe(*W_bar_adj_op, Teuchos::VERB_EXTREME);
00521 
00522   }
00523   
00524   // C.3) Evaluate d(f)/d(x_dot) if needed
00525 
00526   RCP<Thyra::LinearOpBase<Scalar> > d_f_d_x_dot_op;
00527   if (!is_null(f_bar)) {
00528     if (is_null(my_d_f_d_x_dot_op_)) {
00529       my_d_f_d_x_dot_op_ = fwdStateModel_->create_W_op();
00530     }
00531     d_f_d_x_dot_op = my_d_f_d_x_dot_op_;
00532     MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
00533     fwdOutArgs.set_W_op(d_f_d_x_dot_op);
00534     fwdInArgs.set_alpha(ST::one());
00535     fwdInArgs.set_beta(ST::zero());
00536     fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
00537   }
00538 
00539   //
00540   // D) Evaluate the adjoint equation residual:
00541   //
00542   //   f_bar = d(f)/d(x_dot)^T * lambda_hat + 1/beta_bar * W_bar_adj^T * lambda
00543   //           - d(g)/d(x)^T
00544   //
00545 
00546   if (!is_null(f_bar)) {
00547 
00548     // D.1) lambda_hat = lambda_rev_dot - alpha_bar/beta_bar * lambda
00549     const RCP<Thyra::VectorBase<Scalar> >
00550       lambda_hat = createMember(lambda_rev_dot->space());
00551     Thyra::V_VpStV<Scalar>( outArg(*lambda_hat),
00552       *lambda_rev_dot, -alpha_bar/beta_bar, *lambda );
00553     if (dumpAll)
00554       *out << "\nlambda_hat = " << describe(*lambda_hat, Teuchos::VERB_EXTREME);
00555 
00556     // D.2) f_bar = d(f)/d(x_dot)^T * lambda_hat
00557     Thyra::apply<Scalar>( *d_f_d_x_dot_op, Thyra::CONJTRANS, *lambda_hat,
00558       outArg(*f_bar) );
00559 
00560     // D.3) f_bar += 1/beta_bar * W_bar_adj^T * lambda
00561     Thyra::apply<Scalar>( *W_bar_adj_op, Thyra::CONJTRANS, *lambda,
00562       outArg(*f_bar), 1.0/beta_bar, ST::one() );
00563 
00564     // D.4) f_bar += - d(g)/d(x)^T
00565     // 2008/05/15: rabart: ToDo: Implement once we add support for
00566     // distributed response functions
00567 
00568     if (dumpAll)
00569       *out << "\nf_bar = " << describe(*f_bar, Teuchos::VERB_EXTREME);
00570 
00571   }
00572 
00573 
00574   //
00575   // E) Do any remaining post processing
00576   //
00577 
00578   THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_END();
00579 
00580 }
00581 
00582 
00583 // private
00584 
00585 
00586 template<class Scalar>
00587 void AdjointModelEvaluator<Scalar>::initialize() const
00588 {
00589 
00590   typedef Thyra::ModelEvaluatorBase MEB;
00591 
00592   //
00593   // A) Validate the that forward Model is of the correct form!
00594   //
00595 
00596   MEB::InArgs<Scalar> fwdStateModelInArgs = fwdStateModel_->createInArgs();
00597   MEB::OutArgs<Scalar> fwdStateModelOutArgs = fwdStateModel_->createOutArgs();
00598 
00599 #ifdef TEUCHOS_DEBUG
00600   TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x_dot) );
00601   TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x) );
00602   TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_t) );
00603   TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_alpha) );
00604   TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_beta) );
00605   TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_f) );
00606   TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) );
00607 #endif
00608 
00609   //
00610   // B) Set up the prototypical InArgs and OutArgs
00611   //
00612 
00613   {
00614     MEB::InArgsSetup<Scalar> inArgs_bar;
00615     inArgs_bar.setModelEvalDescription(this->description());
00616     inArgs_bar.setSupports( MEB::IN_ARG_x_dot );
00617     inArgs_bar.setSupports( MEB::IN_ARG_x );
00618     inArgs_bar.setSupports( MEB::IN_ARG_t );
00619     inArgs_bar.setSupports( MEB::IN_ARG_alpha );
00620     inArgs_bar.setSupports( MEB::IN_ARG_beta );
00621     prototypeInArgs_bar_ = inArgs_bar;
00622   }
00623 
00624   {
00625     MEB::OutArgsSetup<Scalar> outArgs_bar;
00626     outArgs_bar.setModelEvalDescription(this->description());
00627     outArgs_bar.setSupports(MEB::OUT_ARG_f);
00628     if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) ) {
00629       outArgs_bar.setSupports(MEB::OUT_ARG_W);
00630       outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
00631     }
00632     if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W_op) ) {
00633       outArgs_bar.setSupports(MEB::OUT_ARG_W_op);
00634       outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
00635     }
00636     prototypeOutArgs_bar_ = outArgs_bar;
00637   }
00638 
00639   //
00640   // D) Set up the nominal values for the adjoint
00641   //
00642 
00643   // Copy structure
00644   adjointNominalValues_ = prototypeInArgs_bar_;
00645   // Just set a zero initial condition for the adjoint
00646   const RCP<Thyra::VectorBase<Scalar> > zero_lambda_vec =
00647     createMember(fwdStateModel_->get_f_space());
00648   V_S( zero_lambda_vec.ptr(), ScalarTraits<Scalar>::zero() );
00649   adjointNominalValues_.set_x_dot(zero_lambda_vec);
00650   adjointNominalValues_.set_x(zero_lambda_vec);
00651 
00652   //
00653   // E) Wipe out other cached objects
00654   //
00655 
00656   my_W_bar_adj_op_ = Teuchos::null;
00657   my_d_f_d_x_dot_op_ = Teuchos::null;
00658 
00659 }
00660 
00661 
00662 } // namespace Rythmos
00663 
00664 
00665 #endif // RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends

Generated on Tue Oct 20 10:24:08 2009 for Rythmos - Transient Integration for Differential Equations by  doxygen 1.6.1