Rythmos_AdjointModelEvaluator.hpp
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
00030 #define RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
00031
00032
00033 #include "Rythmos_IntegratorBase.hpp"
00034 #include "Thyra_ModelEvaluator.hpp"
00035 #include "Thyra_StateFuncModelEvaluatorBase.hpp"
00036 #include "Thyra_DefaultScaledAdjointLinearOp.hpp"
00037 #include "Thyra_DefaultAdjointLinearOpWithSolve.hpp"
00038 #include "Teuchos_implicit_cast.hpp"
00039 #include "Teuchos_Assert.hpp"
00040
00041
00042 namespace Rythmos {
00043
00044
00167 template<class Scalar>
00168 class AdjointModelEvaluator
00169 : virtual public Thyra::StateFuncModelEvaluatorBase<Scalar>
00170 {
00171 public:
00172
00175
00177 AdjointModelEvaluator();
00178
00180 void setFwdStateModel(
00181 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel );
00182
00186 void setFwdTimeRange( const TimeRange<Scalar> &fwdTimeRange );
00187
00196 void setFwdStateSolutionBuffer(
00197 const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer );
00198
00200
00203
00205 RCP<const Thyra::VectorSpaceBase<Scalar> > get_x_space() const;
00207 RCP<const Thyra::VectorSpaceBase<Scalar> > get_f_space() const;
00209 Thyra::ModelEvaluatorBase::InArgs<Scalar> getNominalValues() const;
00211 RCP<Thyra::LinearOpWithSolveBase<Scalar> > create_W() const;
00213 RCP<Thyra::LinearOpBase<Scalar> > create_W_op() const;
00215 Thyra::ModelEvaluatorBase::InArgs<Scalar> createInArgs() const;
00216
00218
00219 private:
00220
00223
00225 Thyra::ModelEvaluatorBase::OutArgs<Scalar> createOutArgsImpl() const;
00227 void evalModelImpl(
00228 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
00229 const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
00230 ) const;
00231
00233
00234 private:
00235
00236
00237
00238
00239 RCP<const Thyra::ModelEvaluator<Scalar> > fwdStateModel_;
00240 TimeRange<Scalar> fwdTimeRange_;
00241 RCP<const InterpolationBufferBase<Scalar> > fwdStateSolutionBuffer_;
00242
00243 mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> prototypeInArgs_bar_;
00244 mutable Thyra::ModelEvaluatorBase::OutArgs<Scalar> prototypeOutArgs_bar_;
00245 mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> adjointNominalValues_;
00246 mutable RCP<Thyra::LinearOpBase<Scalar> > my_W_bar_adj_op_;
00247 mutable RCP<Thyra::LinearOpBase<Scalar> > my_d_f_d_x_dot_op_;
00248
00249
00250
00251
00252
00253 void initialize() const;
00254
00255 };
00256
00257
00262 template<class Scalar>
00263 RCP<AdjointModelEvaluator<Scalar> >
00264 adjointModelEvaluator(
00265 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
00266 const TimeRange<Scalar> &fwdTimeRange
00267 )
00268 {
00269 RCP<AdjointModelEvaluator<Scalar> >
00270 adjointModel = Teuchos::rcp(new AdjointModelEvaluator<Scalar>);
00271 adjointModel->setFwdStateModel(fwdStateModel);
00272 adjointModel->setFwdTimeRange(fwdTimeRange);
00273 return adjointModel;
00274 }
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284 template<class Scalar>
00285 AdjointModelEvaluator<Scalar>::AdjointModelEvaluator()
00286 {}
00287
00288
00289 template<class Scalar>
00290 void AdjointModelEvaluator<Scalar>::setFwdStateModel(
00291 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel )
00292 {
00293 TEST_FOR_EXCEPT(is_null(fwdStateModel));
00294 fwdStateModel_ = fwdStateModel;
00295 }
00296
00297
00298 template<class Scalar>
00299 void AdjointModelEvaluator<Scalar>::setFwdTimeRange(
00300 const TimeRange<Scalar> &fwdTimeRange )
00301 {
00302 fwdTimeRange_ = fwdTimeRange;
00303 }
00304
00305
00306 template<class Scalar>
00307 void AdjointModelEvaluator<Scalar>::setFwdStateSolutionBuffer(
00308 const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer )
00309 {
00310 TEST_FOR_EXCEPT(is_null(fwdStateSolutionBuffer));
00311 fwdStateSolutionBuffer_ = fwdStateSolutionBuffer;
00312 }
00313
00314
00315
00316
00317
00318 template<class Scalar>
00319 RCP<const Thyra::VectorSpaceBase<Scalar> >
00320 AdjointModelEvaluator<Scalar>::get_x_space() const
00321 {
00322 initialize();
00323 return fwdStateModel_->get_f_space();
00324 }
00325
00326
00327 template<class Scalar>
00328 RCP<const Thyra::VectorSpaceBase<Scalar> >
00329 AdjointModelEvaluator<Scalar>::get_f_space() const
00330 {
00331 initialize();
00332 return fwdStateModel_->get_x_space();
00333 }
00334
00335
00336 template<class Scalar>
00337 Thyra::ModelEvaluatorBase::InArgs<Scalar>
00338 AdjointModelEvaluator<Scalar>::getNominalValues() const
00339 {
00340 initialize();
00341 return adjointNominalValues_;
00342 }
00343
00344
00345 template<class Scalar>
00346 RCP<Thyra::LinearOpWithSolveBase<Scalar> >
00347 AdjointModelEvaluator<Scalar>::create_W() const
00348 {
00349 initialize();
00350 return Thyra::nonconstAdjointLows<Scalar>(fwdStateModel_->create_W());
00351 }
00352
00353
00354 template<class Scalar>
00355 RCP<Thyra::LinearOpBase<Scalar> >
00356 AdjointModelEvaluator<Scalar>::create_W_op() const
00357 {
00358 initialize();
00359 return Thyra::nonconstAdjoint<Scalar>(fwdStateModel_->create_W_op());
00360 }
00361
00362
00363 template<class Scalar>
00364 Thyra::ModelEvaluatorBase::InArgs<Scalar>
00365 AdjointModelEvaluator<Scalar>::createInArgs() const
00366 {
00367 initialize();
00368 return prototypeInArgs_bar_;
00369 }
00370
00371
00372
00373
00374
00375 template<class Scalar>
00376 Thyra::ModelEvaluatorBase::OutArgs<Scalar>
00377 AdjointModelEvaluator<Scalar>::createOutArgsImpl() const
00378 {
00379 initialize();
00380 return prototypeOutArgs_bar_;
00381 }
00382
00383
00384 template<class Scalar>
00385 void AdjointModelEvaluator<Scalar>::evalModelImpl(
00386 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
00387 const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
00388 ) const
00389 {
00390
00391 using Teuchos::rcp_dynamic_cast;
00392 using Teuchos::describe;
00393 typedef Teuchos::ScalarTraits<Scalar> ST;
00394 typedef Thyra::ModelEvaluatorBase MEB;
00395 typedef Thyra::DefaultScaledAdjointLinearOp<Scalar> DSALO;
00396 typedef Thyra::DefaultAdjointLinearOpWithSolve<Scalar> DALOWS;
00397 typedef Teuchos::VerboseObjectTempState<Thyra::ModelEvaluatorBase> VOTSME;
00398
00399
00400
00401
00402
00403 THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_GEN_BEGIN(
00404 "AdjointModelEvaluator", inArgs_bar, outArgs_bar, Teuchos::null );
00405
00406 initialize();
00407
00408 VOTSME fwdStateModel_outputTempState(fwdStateModel_,out,verbLevel);
00409
00410
00411 const bool dumpAll = includesVerbLevel(localVerbLevel, Teuchos::VERB_EXTREME);
00412
00413
00414
00415
00416
00417
00418
00419 const RCP<const Thyra::VectorBase<Scalar> >
00420 lambda_rev_dot = inArgs_bar.get_x_dot().assert_not_null(),
00421 lambda = inArgs_bar.get_x().assert_not_null();
00422 const Scalar alpha_bar = inArgs_bar.get_alpha();
00423 const Scalar beta_bar = inArgs_bar.get_beta();
00424
00425
00426 if (dumpAll) {
00427 *out << "\nlambda_rev_dot = " << describe(*lambda_rev_dot, Teuchos::VERB_EXTREME);
00428 *out << "\nlambda = " << describe(*lambda, Teuchos::VERB_EXTREME);
00429 *out << "\nalpha_bar = " << alpha_bar << "\n";
00430 *out << "\nbeta_bar = " << beta_bar << "\n";
00431 }
00432
00433
00434
00435 const RCP<Thyra::VectorBase<Scalar> > f_bar = outArgs_bar.get_f();
00436
00437 RCP<DALOWS> W_bar;
00438 if (outArgs_bar.supports(MEB::OUT_ARG_W))
00439 W_bar = rcp_dynamic_cast<DALOWS>(outArgs_bar.get_W(), true);
00440
00441 RCP<DSALO> W_bar_op;
00442 if (outArgs_bar.supports(MEB::OUT_ARG_W_op))
00443 W_bar_op = rcp_dynamic_cast<DSALO>(outArgs_bar.get_W_op(), true);
00444
00445
00446
00447
00448
00449 MEB::InArgs<Scalar> fwdInArgs = fwdStateModel_->createInArgs();
00450
00451
00452
00453 if (!is_null(fwdStateSolutionBuffer_)) {
00454 TEST_FOR_EXCEPT_MSG(true, "ToDo: Implement getting the x and x_dot from IB");
00455 }
00456 else {
00457
00458
00459
00460 fwdInArgs = fwdStateModel_->getNominalValues();
00461
00462
00463 }
00464
00465
00466
00467 RCP<Thyra::LinearOpWithSolveBase<Scalar> > W_bar_adj;
00468 RCP<Thyra::LinearOpBase<Scalar> > W_bar_adj_op;
00469 {
00470
00471 MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
00472
00473
00474 if (!is_null(W_bar)) {
00475
00476
00477 W_bar_adj = W_bar->getNonconstOp();
00478 W_bar_adj_op = W_bar_adj;
00479 }
00480 else if (!is_null(W_bar_op)) {
00481 TEST_FOR_EXCEPT_MSG(true, "ToDo: Unit test this code!");
00482
00483
00484 W_bar_adj_op = W_bar_op->getNonconstOp();
00485 }
00486 else if (!is_null(f_bar)) {
00487 TEST_FOR_EXCEPT_MSG(true, "ToDo: Unit test this code!");
00488
00489
00490
00491 if (is_null(my_W_bar_adj_op_)) {
00492 my_W_bar_adj_op_ = fwdStateModel_->create_W_op();
00493 }
00494 W_bar_adj_op = my_W_bar_adj_op_;
00495 }
00496
00497
00498 if (!is_null(W_bar_adj)) {
00499 fwdOutArgs.set_W(W_bar_adj);
00500 }
00501 else if (!is_null(W_bar_adj_op)) {
00502 fwdOutArgs.set_W_op(W_bar_adj);
00503 }
00504
00505
00506 if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
00507 fwdInArgs.set_alpha(alpha_bar);
00508 fwdInArgs.set_beta(beta_bar);
00509 }
00510
00511
00512 if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
00513 fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
00514 }
00515
00516
00517 if (!is_null(W_bar_adj) && dumpAll)
00518 *out << "\nW_bar_adj = " << describe(*W_bar_adj, Teuchos::VERB_EXTREME);
00519 if (!is_null(W_bar_adj_op) && dumpAll)
00520 *out << "\nW_bar_adj_op = " << describe(*W_bar_adj_op, Teuchos::VERB_EXTREME);
00521
00522 }
00523
00524
00525
00526 RCP<Thyra::LinearOpBase<Scalar> > d_f_d_x_dot_op;
00527 if (!is_null(f_bar)) {
00528 if (is_null(my_d_f_d_x_dot_op_)) {
00529 my_d_f_d_x_dot_op_ = fwdStateModel_->create_W_op();
00530 }
00531 d_f_d_x_dot_op = my_d_f_d_x_dot_op_;
00532 MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
00533 fwdOutArgs.set_W_op(d_f_d_x_dot_op);
00534 fwdInArgs.set_alpha(ST::one());
00535 fwdInArgs.set_beta(ST::zero());
00536 fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
00537 }
00538
00539
00540
00541
00542
00543
00544
00545
00546 if (!is_null(f_bar)) {
00547
00548
00549 const RCP<Thyra::VectorBase<Scalar> >
00550 lambda_hat = createMember(lambda_rev_dot->space());
00551 Thyra::V_VpStV<Scalar>( outArg(*lambda_hat),
00552 *lambda_rev_dot, -alpha_bar/beta_bar, *lambda );
00553 if (dumpAll)
00554 *out << "\nlambda_hat = " << describe(*lambda_hat, Teuchos::VERB_EXTREME);
00555
00556
00557 Thyra::apply<Scalar>( *d_f_d_x_dot_op, Thyra::CONJTRANS, *lambda_hat,
00558 outArg(*f_bar) );
00559
00560
00561 Thyra::apply<Scalar>( *W_bar_adj_op, Thyra::CONJTRANS, *lambda,
00562 outArg(*f_bar), 1.0/beta_bar, ST::one() );
00563
00564
00565
00566
00567
00568 if (dumpAll)
00569 *out << "\nf_bar = " << describe(*f_bar, Teuchos::VERB_EXTREME);
00570
00571 }
00572
00573
00574
00575
00576
00577
00578 THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_END();
00579
00580 }
00581
00582
00583
00584
00585
00586 template<class Scalar>
00587 void AdjointModelEvaluator<Scalar>::initialize() const
00588 {
00589
00590 typedef Thyra::ModelEvaluatorBase MEB;
00591
00592
00593
00594
00595
00596 MEB::InArgs<Scalar> fwdStateModelInArgs = fwdStateModel_->createInArgs();
00597 MEB::OutArgs<Scalar> fwdStateModelOutArgs = fwdStateModel_->createOutArgs();
00598
00599 #ifdef TEUCHOS_DEBUG
00600 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x_dot) );
00601 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x) );
00602 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_t) );
00603 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_alpha) );
00604 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_beta) );
00605 TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_f) );
00606 TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) );
00607 #endif
00608
00609
00610
00611
00612
00613 {
00614 MEB::InArgsSetup<Scalar> inArgs_bar;
00615 inArgs_bar.setModelEvalDescription(this->description());
00616 inArgs_bar.setSupports( MEB::IN_ARG_x_dot );
00617 inArgs_bar.setSupports( MEB::IN_ARG_x );
00618 inArgs_bar.setSupports( MEB::IN_ARG_t );
00619 inArgs_bar.setSupports( MEB::IN_ARG_alpha );
00620 inArgs_bar.setSupports( MEB::IN_ARG_beta );
00621 prototypeInArgs_bar_ = inArgs_bar;
00622 }
00623
00624 {
00625 MEB::OutArgsSetup<Scalar> outArgs_bar;
00626 outArgs_bar.setModelEvalDescription(this->description());
00627 outArgs_bar.setSupports(MEB::OUT_ARG_f);
00628 if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) ) {
00629 outArgs_bar.setSupports(MEB::OUT_ARG_W);
00630 outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
00631 }
00632 if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W_op) ) {
00633 outArgs_bar.setSupports(MEB::OUT_ARG_W_op);
00634 outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
00635 }
00636 prototypeOutArgs_bar_ = outArgs_bar;
00637 }
00638
00639
00640
00641
00642
00643
00644 adjointNominalValues_ = prototypeInArgs_bar_;
00645
00646 const RCP<Thyra::VectorBase<Scalar> > zero_lambda_vec =
00647 createMember(fwdStateModel_->get_f_space());
00648 V_S( zero_lambda_vec.ptr(), ScalarTraits<Scalar>::zero() );
00649 adjointNominalValues_.set_x_dot(zero_lambda_vec);
00650 adjointNominalValues_.set_x(zero_lambda_vec);
00651
00652
00653
00654
00655
00656 my_W_bar_adj_op_ = Teuchos::null;
00657 my_d_f_d_x_dot_op_ = Teuchos::null;
00658
00659 }
00660
00661
00662 }
00663
00664
00665 #endif // RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP