00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
00030 #define RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP
00031
00032
00033 #include "Rythmos_IntegratorBase.hpp"
00034 #include "Thyra_ModelEvaluator.hpp"
00035 #include "Thyra_StateFuncModelEvaluatorBase.hpp"
00036 #include "Thyra_ModelEvaluatorDelegatorBase.hpp"
00037 #include "Thyra_DefaultScaledAdjointLinearOp.hpp"
00038 #include "Thyra_DefaultAdjointLinearOpWithSolve.hpp"
00039 #include "Teuchos_implicit_cast.hpp"
00040 #include "Teuchos_Assert.hpp"
00041
00042
00043 namespace Rythmos {
00044
00045
00170 template<class Scalar>
00171 class AdjointModelEvaluator
00172 : virtual public Thyra::StateFuncModelEvaluatorBase<Scalar>
00173 {
00174 public:
00175
00178
00180 AdjointModelEvaluator();
00181
00183 void setFwdStateModel(
00184 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
00185 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &basePoint );
00186
00190 void setFwdTimeRange( const TimeRange<Scalar> &fwdTimeRange );
00191
00204 void setFwdStateSolutionBuffer(
00205 const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer );
00206
00208
00211
00213 RCP<const Thyra::VectorSpaceBase<Scalar> > get_x_space() const;
00215 RCP<const Thyra::VectorSpaceBase<Scalar> > get_f_space() const;
00217 Thyra::ModelEvaluatorBase::InArgs<Scalar> getNominalValues() const;
00219 RCP<Thyra::LinearOpWithSolveBase<Scalar> > create_W() const;
00221 RCP<Thyra::LinearOpBase<Scalar> > create_W_op() const;
00223 Thyra::ModelEvaluatorBase::InArgs<Scalar> createInArgs() const;
00224
00226
00227 private:
00228
00231
00233 Thyra::ModelEvaluatorBase::OutArgs<Scalar> createOutArgsImpl() const;
00235 void evalModelImpl(
00236 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
00237 const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
00238 ) const;
00239
00241
00242 private:
00243
00244
00245
00246
00247 RCP<const Thyra::ModelEvaluator<Scalar> > fwdStateModel_;
00248 Thyra::ModelEvaluatorBase::InArgs<Scalar> basePoint_;
00249 TimeRange<Scalar> fwdTimeRange_;
00250 RCP<const InterpolationBufferBase<Scalar> > fwdStateSolutionBuffer_;
00251
00252 mutable bool isInitialized_;
00253 mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> prototypeInArgs_bar_;
00254 mutable Thyra::ModelEvaluatorBase::OutArgs<Scalar> prototypeOutArgs_bar_;
00255 mutable Thyra::ModelEvaluatorBase::InArgs<Scalar> adjointNominalValues_;
00256 mutable RCP<Thyra::LinearOpBase<Scalar> > my_W_bar_adj_op_;
00257 mutable RCP<Thyra::LinearOpBase<Scalar> > my_d_f_d_x_dot_op_;
00258
00259
00260
00261
00262
00263 void initialize() const;
00264
00265 };
00266
00267
00272 template<class Scalar>
00273 RCP<AdjointModelEvaluator<Scalar> >
00274 adjointModelEvaluator(
00275 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
00276 const TimeRange<Scalar> &fwdTimeRange
00277 )
00278 {
00279 RCP<AdjointModelEvaluator<Scalar> >
00280 adjointModel = Teuchos::rcp(new AdjointModelEvaluator<Scalar>);
00281 adjointModel->setFwdStateModel(fwdStateModel, fwdStateModel->getNominalValues());
00282 adjointModel->setFwdTimeRange(fwdTimeRange);
00283 return adjointModel;
00284 }
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294 template<class Scalar>
00295 AdjointModelEvaluator<Scalar>::AdjointModelEvaluator()
00296 :isInitialized_(false)
00297 {}
00298
00299
00300 template<class Scalar>
00301 void AdjointModelEvaluator<Scalar>::setFwdStateModel(
00302 const RCP<const Thyra::ModelEvaluator<Scalar> > &fwdStateModel,
00303 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &basePoint
00304 )
00305 {
00306 TEST_FOR_EXCEPT(is_null(fwdStateModel));
00307 fwdStateModel_ = fwdStateModel;
00308 basePoint_ = basePoint;
00309 isInitialized_ = false;
00310 }
00311
00312
00313 template<class Scalar>
00314 void AdjointModelEvaluator<Scalar>::setFwdTimeRange(
00315 const TimeRange<Scalar> &fwdTimeRange )
00316 {
00317 fwdTimeRange_ = fwdTimeRange;
00318 }
00319
00320
00321 template<class Scalar>
00322 void AdjointModelEvaluator<Scalar>::setFwdStateSolutionBuffer(
00323 const RCP<const InterpolationBufferBase<Scalar> > &fwdStateSolutionBuffer )
00324 {
00325 TEST_FOR_EXCEPT(is_null(fwdStateSolutionBuffer));
00326 fwdStateSolutionBuffer_ = fwdStateSolutionBuffer;
00327 }
00328
00329
00330
00331
00332
00333 template<class Scalar>
00334 RCP<const Thyra::VectorSpaceBase<Scalar> >
00335 AdjointModelEvaluator<Scalar>::get_x_space() const
00336 {
00337 initialize();
00338 return fwdStateModel_->get_f_space();
00339 }
00340
00341
00342 template<class Scalar>
00343 RCP<const Thyra::VectorSpaceBase<Scalar> >
00344 AdjointModelEvaluator<Scalar>::get_f_space() const
00345 {
00346 initialize();
00347 return fwdStateModel_->get_x_space();
00348 }
00349
00350
00351 template<class Scalar>
00352 Thyra::ModelEvaluatorBase::InArgs<Scalar>
00353 AdjointModelEvaluator<Scalar>::getNominalValues() const
00354 {
00355 initialize();
00356 return adjointNominalValues_;
00357 }
00358
00359
00360 template<class Scalar>
00361 RCP<Thyra::LinearOpWithSolveBase<Scalar> >
00362 AdjointModelEvaluator<Scalar>::create_W() const
00363 {
00364 initialize();
00365 return Thyra::nonconstAdjointLows<Scalar>(fwdStateModel_->create_W());
00366 }
00367
00368
00369 template<class Scalar>
00370 RCP<Thyra::LinearOpBase<Scalar> >
00371 AdjointModelEvaluator<Scalar>::create_W_op() const
00372 {
00373 initialize();
00374 return Thyra::nonconstAdjoint<Scalar>(fwdStateModel_->create_W_op());
00375 }
00376
00377
00378 template<class Scalar>
00379 Thyra::ModelEvaluatorBase::InArgs<Scalar>
00380 AdjointModelEvaluator<Scalar>::createInArgs() const
00381 {
00382 initialize();
00383 return prototypeInArgs_bar_;
00384 }
00385
00386
00387
00388
00389
00390 template<class Scalar>
00391 Thyra::ModelEvaluatorBase::OutArgs<Scalar>
00392 AdjointModelEvaluator<Scalar>::createOutArgsImpl() const
00393 {
00394 initialize();
00395 return prototypeOutArgs_bar_;
00396 }
00397
00398
00399 template<class Scalar>
00400 void AdjointModelEvaluator<Scalar>::evalModelImpl(
00401 const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs_bar,
00402 const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs_bar
00403 ) const
00404 {
00405
00406 using Teuchos::rcp_dynamic_cast;
00407 using Teuchos::describe;
00408 typedef Teuchos::ScalarTraits<Scalar> ST;
00409 typedef Thyra::ModelEvaluatorBase MEB;
00410 typedef Thyra::DefaultScaledAdjointLinearOp<Scalar> DSALO;
00411 typedef Thyra::DefaultAdjointLinearOpWithSolve<Scalar> DALOWS;
00412 typedef Teuchos::VerboseObjectTempState<Thyra::ModelEvaluatorBase> VOTSME;
00413
00414
00415
00416
00417
00418 THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_GEN_BEGIN(
00419 "AdjointModelEvaluator", inArgs_bar, outArgs_bar, Teuchos::null );
00420
00421 initialize();
00422
00423 VOTSME fwdStateModel_outputTempState(fwdStateModel_,out,verbLevel);
00424
00425
00426 const bool dumpAll = includesVerbLevel(localVerbLevel, Teuchos::VERB_EXTREME);
00427
00428
00429
00430
00431
00432
00433
00434 const Scalar t_bar = inArgs_bar.get_t();
00435 const RCP<const Thyra::VectorBase<Scalar> >
00436 lambda_rev_dot = inArgs_bar.get_x_dot().assert_not_null(),
00437 lambda = inArgs_bar.get_x().assert_not_null();
00438 const Scalar alpha_bar = inArgs_bar.get_alpha();
00439 const Scalar beta_bar = inArgs_bar.get_beta();
00440
00441 if (dumpAll) {
00442 *out << "\nlambda_rev_dot = " << describe(*lambda_rev_dot, Teuchos::VERB_EXTREME);
00443 *out << "\nlambda = " << describe(*lambda, Teuchos::VERB_EXTREME);
00444 *out << "\nalpha_bar = " << alpha_bar << "\n";
00445 *out << "\nbeta_bar = " << beta_bar << "\n";
00446 }
00447
00448
00449
00450 const RCP<Thyra::VectorBase<Scalar> > f_bar = outArgs_bar.get_f();
00451
00452 RCP<DALOWS> W_bar;
00453 if (outArgs_bar.supports(MEB::OUT_ARG_W))
00454 W_bar = rcp_dynamic_cast<DALOWS>(outArgs_bar.get_W(), true);
00455
00456 RCP<DSALO> W_bar_op;
00457 if (outArgs_bar.supports(MEB::OUT_ARG_W_op))
00458 W_bar_op = rcp_dynamic_cast<DSALO>(outArgs_bar.get_W_op(), true);
00459
00460 if (dumpAll) {
00461 if (!is_null(W_bar)) {
00462 *out << "\nW_bar = " << describe(*W_bar, Teuchos::VERB_EXTREME);
00463 }
00464 if (!is_null(W_bar_op)) {
00465 *out << "\nW_bar_op = " << describe(*W_bar_op, Teuchos::VERB_EXTREME);
00466 }
00467 }
00468
00469
00470
00471
00472
00473 MEB::InArgs<Scalar> fwdInArgs = fwdStateModel_->createInArgs();
00474
00475
00476
00477 fwdInArgs = basePoint_;
00478
00479 if (!is_null(fwdStateSolutionBuffer_)) {
00480 const Scalar t = fwdTimeRange_.length() - t_bar;
00481 RCP<const Thyra::VectorBase<Scalar> > x, x_dot;
00482 get_x_and_x_dot<Scalar>( *fwdStateSolutionBuffer_, t,
00483 outArg(x), outArg(x_dot) );
00484 fwdInArgs.set_x(x);
00485 fwdInArgs.set_x_dot(x);
00486 }
00487 else {
00488
00489
00490
00491
00492
00493
00494
00495
00496
00497 }
00498
00499
00500
00501
00502 RCP<Thyra::LinearOpWithSolveBase<Scalar> > W_bar_adj;
00503 RCP<Thyra::LinearOpBase<Scalar> > W_bar_adj_op;
00504 {
00505
00506 MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
00507
00508
00509 if (!is_null(W_bar)) {
00510
00511
00512 W_bar_adj = W_bar->getNonconstOp();
00513 W_bar_adj_op = W_bar_adj;
00514 }
00515 else if (!is_null(W_bar_op)) {
00516
00517
00518 W_bar_adj_op = W_bar_op->getNonconstOp();
00519 }
00520 else if (!is_null(f_bar)) {
00521 TEST_FOR_EXCEPT_MSG(true, "ToDo: Unit test this code!");
00522
00523
00524
00525 if (is_null(my_W_bar_adj_op_)) {
00526 my_W_bar_adj_op_ = fwdStateModel_->create_W_op();
00527 }
00528 W_bar_adj_op = my_W_bar_adj_op_;
00529 }
00530
00531
00532 if (!is_null(W_bar_adj)) {
00533 fwdOutArgs.set_W(W_bar_adj);
00534 }
00535 else if (!is_null(W_bar_adj_op)) {
00536 fwdOutArgs.set_W_op(W_bar_adj_op);
00537 }
00538
00539
00540 if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
00541 fwdInArgs.set_alpha(alpha_bar);
00542 fwdInArgs.set_beta(beta_bar);
00543 }
00544
00545
00546 if (!is_null(W_bar_adj) || !is_null(W_bar_adj_op)) {
00547 fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
00548 }
00549
00550
00551 if (!is_null(W_bar_adj) && dumpAll)
00552 *out << "\nW_bar_adj = " << describe(*W_bar_adj, Teuchos::VERB_EXTREME);
00553 if (!is_null(W_bar_adj_op) && dumpAll)
00554 *out << "\nW_bar_adj_op = " << describe(*W_bar_adj_op, Teuchos::VERB_EXTREME);
00555
00556 }
00557
00558
00559
00560 RCP<Thyra::LinearOpBase<Scalar> > d_f_d_x_dot_op;
00561 if (!is_null(f_bar)) {
00562 if (is_null(my_d_f_d_x_dot_op_)) {
00563 my_d_f_d_x_dot_op_ = fwdStateModel_->create_W_op();
00564 }
00565 d_f_d_x_dot_op = my_d_f_d_x_dot_op_;
00566 MEB::OutArgs<Scalar> fwdOutArgs = fwdStateModel_->createOutArgs();
00567 fwdOutArgs.set_W_op(d_f_d_x_dot_op);
00568 fwdInArgs.set_alpha(ST::one());
00569 fwdInArgs.set_beta(ST::zero());
00570 fwdStateModel_->evalModel( fwdInArgs, fwdOutArgs );
00571 if (dumpAll) {
00572 *out << "\nd_f_d_x_dot_op = " << describe(*d_f_d_x_dot_op, Teuchos::VERB_EXTREME);
00573 }
00574 }
00575
00576
00577
00578
00579
00580
00581
00582
00583 if (!is_null(f_bar)) {
00584
00585
00586 const RCP<Thyra::VectorBase<Scalar> >
00587 lambda_hat = createMember(lambda_rev_dot->space());
00588 Thyra::V_VpStV<Scalar>( outArg(*lambda_hat),
00589 *lambda_rev_dot, -alpha_bar/beta_bar, *lambda );
00590 if (dumpAll)
00591 *out << "\nlambda_hat = " << describe(*lambda_hat, Teuchos::VERB_EXTREME);
00592
00593
00594 Thyra::apply<Scalar>( *d_f_d_x_dot_op, Thyra::CONJTRANS, *lambda_hat,
00595 outArg(*f_bar) );
00596
00597
00598 Thyra::apply<Scalar>( *W_bar_adj_op, Thyra::CONJTRANS, *lambda,
00599 outArg(*f_bar), 1.0/beta_bar, ST::one() );
00600
00601
00602
00603
00604
00605 if (dumpAll)
00606 *out << "\nf_bar = " << describe(*f_bar, Teuchos::VERB_EXTREME);
00607
00608 }
00609
00610 if (dumpAll) {
00611 if (!is_null(W_bar)) {
00612 *out << "\nW_bar = " << describe(*W_bar, Teuchos::VERB_EXTREME);
00613 }
00614 if (!is_null(W_bar_op)) {
00615 *out << "\nW_bar_op = " << describe(*W_bar_op, Teuchos::VERB_EXTREME);
00616 }
00617 }
00618
00619
00620
00621
00622
00623
00624 THYRA_MODEL_EVALUATOR_DECORATOR_EVAL_MODEL_END();
00625
00626 }
00627
00628
00629
00630
00631
00632 template<class Scalar>
00633 void AdjointModelEvaluator<Scalar>::initialize() const
00634 {
00635
00636 typedef Thyra::ModelEvaluatorBase MEB;
00637
00638 if (isInitialized_)
00639 return;
00640
00641
00642
00643
00644
00645 MEB::InArgs<Scalar> fwdStateModelInArgs = fwdStateModel_->createInArgs();
00646 MEB::OutArgs<Scalar> fwdStateModelOutArgs = fwdStateModel_->createOutArgs();
00647
00648 #ifdef RYTHMOS_DEBUG
00649 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x_dot) );
00650 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_x) );
00651 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_t) );
00652 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_alpha) );
00653 TEUCHOS_ASSERT( fwdStateModelInArgs.supports(MEB::IN_ARG_beta) );
00654 TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_f) );
00655 TEUCHOS_ASSERT( fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) );
00656 #endif
00657
00658
00659
00660
00661
00662 {
00663 MEB::InArgsSetup<Scalar> inArgs_bar;
00664 inArgs_bar.setModelEvalDescription(this->description());
00665 inArgs_bar.setSupports( MEB::IN_ARG_x_dot );
00666 inArgs_bar.setSupports( MEB::IN_ARG_x );
00667 inArgs_bar.setSupports( MEB::IN_ARG_t );
00668 inArgs_bar.setSupports( MEB::IN_ARG_alpha );
00669 inArgs_bar.setSupports( MEB::IN_ARG_beta );
00670 prototypeInArgs_bar_ = inArgs_bar;
00671 }
00672
00673 {
00674 MEB::OutArgsSetup<Scalar> outArgs_bar;
00675 outArgs_bar.setModelEvalDescription(this->description());
00676 outArgs_bar.setSupports(MEB::OUT_ARG_f);
00677 if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W) ) {
00678 outArgs_bar.setSupports(MEB::OUT_ARG_W);
00679 outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
00680 }
00681 if (fwdStateModelOutArgs.supports(MEB::OUT_ARG_W_op) ) {
00682 outArgs_bar.setSupports(MEB::OUT_ARG_W_op);
00683 outArgs_bar.set_W_properties(fwdStateModelOutArgs.get_W_properties());
00684 }
00685 prototypeOutArgs_bar_ = outArgs_bar;
00686 }
00687
00688
00689
00690
00691
00692
00693 adjointNominalValues_ = prototypeInArgs_bar_;
00694
00695 const RCP<Thyra::VectorBase<Scalar> > zero_lambda_vec =
00696 createMember(fwdStateModel_->get_f_space());
00697 V_S( zero_lambda_vec.ptr(), ScalarTraits<Scalar>::zero() );
00698 adjointNominalValues_.set_x_dot(zero_lambda_vec);
00699 adjointNominalValues_.set_x(zero_lambda_vec);
00700
00701
00702
00703
00704
00705 my_W_bar_adj_op_ = Teuchos::null;
00706 my_d_f_d_x_dot_op_ = Teuchos::null;
00707
00708
00709
00710
00711
00712 isInitialized_ = true;
00713
00714 }
00715
00716
00717 }
00718
00719
00720 #endif // RYTHMOS_ADJOINT_MODEL_EVALUATOR_HPP