-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmdp.v
More file actions
2800 lines (2404 loc) · 90.4 KB
/
mdp.v
File metadata and controls
2800 lines (2404 loc) · 90.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(**Require Import MLCert.axioms.
From mathcomp Require Import all_ssreflect.
Require Import List. Import ListNotations.
Require Import ZArith.
Require Import MLCert.float32.
Require Import MLCert.extraction_ocaml.
Require Import Coq.Logic.FunctionalExtensionality.
**)
(**Require Import MLCert.extraction_ocaml.**)
Require Import List. Import ListNotations.
Require Import QArith.
Require Import OUVerT.extrema.
Require Import OUVerT.banach.
Require Import OUVerT.orderedtypes.
(**Require Import FunctionalExtensionality.**)
(**Open Scope f32_scope.**)
Require Import OUVerT.numerics.
Require Import OUVerT.generalized_bigops.
Require Import OUVerT.dyadic.
Require Import ProofIrrelevance.
Require Import OUVerT.compile.
Require Import OUVerT.enumerables.
Require Import Reals.SeqProp.
Require Import Reals.Rseries.
Require Import Rbase.
Require Import Reals.Rfunctions.
Require Import Reals.Rpower.
Import OUVerT.numerics.Numerics.
Import OUVerT.extrema.num_Extrema.
Require Import mathcomp.algebra.matrix.
Require Import Reals.Rcomplete.
Require Import Psatz.
From mathcomp Require Import all_ssreflect.
Require Import mathcomp.ssreflect.ssreflect.
(**Lemma total_order_max: forall (T : Type) (P : T->T->Prop),
(forall (t1 t2 t3 : T), P t1 t2 -> P t2 t3 -> P t1 t3) ->
(forall (t1 t2 : T), P t1 t2 \/ P t2 t1) ->
(forall (t : T), P t t) ->
(forall (t1 t2 : T), {t1 = t2} + {t1 <> t2})->
(forall (l : list T), length l <> O ->
exists t1 : T, (forall t2 : T, In t2 l -> P t1 t2)).
Proof.
intros.
induction l.
{ exfalso. apply H2. auto. }
destruct l.
{
exists a.
intros.
inversion H3.
{ rewrite H4; auto. }
inversion H4.
}
destruct IHl.
{ simpl. unfold not. intros. inversion H3. }
destruct H0 with x a.
{
exists x.
intros.
destruct X with a t2.
{ rewrite <- e; auto. }
inversion H5.
{ rewrite H6 in n. exfalso; apply n; auto. }
apply H3; auto.
}
exists a.
intros.
destruct X with a t2.
{ rewrite e; auto. }
apply H with x; auto.
apply H3.
inversion H5; auto.
exfalso. apply n; auto.
Qed.
Lemma in_iff: forall (T : Type) (l : list T) (a : T), SetoidList.InA (fun x y : T => x = y) a l <-> In a l.
Proof.
intros.
split; intros.
{
induction l.
inversion H.
inversion H.
{
rewrite H1.
constructor.
auto.
}
constructor 2.
apply IHl.
auto.
}
induction l; inversion H.
{
rewrite H0.
constructor.
auto.
}
constructor 2.
apply IHl.
auto.
Qed.
Lemma no_dup_iff: forall (T : Type) (l : list T), SetoidList.NoDupA (fun x y : T => x = y) l <-> NoDup l.
Proof.
intros.
split.
{
intros.
induction l.
constructor.
constructor.
{
rewrite <- in_iff.
inversion H.
auto.
}
apply IHl.
inversion H.
auto.
}
intros.
induction l.
constructor.
constructor.
{
rewrite in_iff.
apply NoDup_cons_iff; auto.
}
apply IHl.
inversion H.
auto.
Qed.**)
Delimit Scope Numeric_scope with Num.
Local Open Scope Num.
(** Type classes describing a Markov decision process whose transition and
    reward values live in a numeric type [Nt]. *)
Module MDP.
Section mdp_numeric.
(* Nt is the numeric carrier; it must come with a Numerics.Numeric instance. *)
Context {Nt:Type} `{Numerics.Numeric Nt}.
(** Raw data of an MDP: a state type, an action type, and two functions of
    shape state -> action -> state -> Nt. *)
Class mdp : Type :=
{
(* State type. *)
St : Type;
(* Action type. *)
A : Type;
(* Transition weight for (source state, action, target state);
   presumably a probability -- see enum_trans_func below for the constraints. *)
Trans : St->A->St->Nt;
(* Reward attached to (source state, action, target state). *)
Reward : St->A->St->Nt;
}.
(** Finiteness data for an MDP: enumerations of the states and of the
    actions, each with a proof that its length is not zero. *)
Class mdp_fin (t : mdp) : Type :=
{
SEnum : Enumerable St;
AEnum : Enumerable A;
SEnum_nonempty : O <> length SEnum;
AEnum_nonempty : O <> length AEnum;
}.
(** Constraints on a weight function [f] over an enumerated type [T]:
    its sum over the enumeration is at most 1 and every value is
    non-negative.  Note that the sum is bounded by 1, not required to be
    equal to 1, despite the field name TransSum1. *)
Class enum_trans_func {T : Type} {T_enum : Enumerable T} {T_enum_ok : @Enum_ok T T_enum} (f : T->Nt) : Type :=
{
TransSum1 : big_sum T_enum f <= 1;
TransNonneg: forall t : T, 0 <= f t
}.
(** Decidable equality on the states and on the actions of an MDP. *)
Class mdp_dec (t : mdp) : Type :=
{
SDec : forall s1 s2 :St, {s1 = s2} + {s1 <> s2};
ADec : forall a1 a2 :A, {a1 = a2} + {a1 <> a2};
}.
(** Well-formedness of a finite MDP: both enumerations satisfy [Enum_ok]
    and, for every state [s] and action [a], the function [Trans s a]
    satisfies [enum_trans_func] over the state enumeration. *)
Class mdp_fin_ok (t : mdp) {fin : mdp_fin t} :=
{
AEnum_ok : @Enum_ok A AEnum;
SEnum_ok : @Enum_ok St SEnum;
TransEnumFunc : forall (s : St) (a : A), enum_trans_func (T_enum_ok := SEnum_ok) (Trans s a);
}.
End mdp_numeric.
End MDP.
Module MDP_algorithms.
(** Transport of an MDP over [Nt], together with its finiteness and
    decidability instances, to an MDP over the reals [R]. *)
Section mdp_to_R.
Context {Nt : Type} `{Nt_numeric : Numerics.Numeric Nt}.
(** Build the real-valued MDP: same state and action types, with Trans and
    Reward post-composed with [Numerics.to_R]. *)
Definition mdp_to_R (mdp : @MDP.mdp Nt) : (@MDP.mdp R) :=
@MDP.Build_mdp R (@MDP.St Nt mdp) (@MDP.A Nt mdp) (fun st a st' => Numerics.to_R (@MDP.Trans Nt mdp st a st'))
(fun st a st' => Numerics.to_R (MDP.Reward st a st')).
(** Finiteness instance for [mdp_to_R mdp], reusing the enumerations and
    non-emptiness proofs of the original instance unchanged.
    Closed with Defined (transparent); later lemmas in this file unfold it. *)
Definition mdp_fin_to_R (mdp : @MDP.mdp Nt) (mdp_fin : @MDP.mdp_fin Nt mdp)
: (@MDP.mdp_fin R (mdp_to_R mdp)).
(* Expose the record fields of both arguments before rebuilding the record. *)
destruct mdp_fin.
destruct mdp.
simpl in *.
exact (@MDP.Build_mdp_fin R
(mdp_to_R {| MDP.St := St; MDP.A := A; MDP.Trans := Trans; MDP.Reward := Reward |})
SEnum AEnum SEnum_nonempty AEnum_nonempty).
Defined.
(** Decidable-equality instance for [mdp_to_R mdp], reusing SDec and ADec
    of the original instance (the state and action types are unchanged). *)
Definition mdp_dec_to_R (mdp : @MDP.mdp Nt) (mdp_dec : @MDP.mdp_dec Nt mdp)
: (@MDP.mdp_dec R (mdp_to_R mdp)).
destruct mdp.
destruct mdp_dec.
unfold mdp_to_R.
simpl in *.
(* The record literal below is mdp_to_R unfolded on the destructed mdp. *)
exact (@MDP.Build_mdp_dec R ({|
MDP.St := St;
MDP.A := A;
MDP.Trans := fun (st : St) (a : A) (st' : St) => to_R (Trans st a st');
MDP.Reward := fun (st : St) (a : A) (st' : St) => to_R (Reward st a st') |})
SDec ADec).
Defined.
End mdp_to_R.
(** Transport of the well-formedness instance [mdp_fin_ok] to the
    real-valued MDP.  Requires [Numeric_Props] for the to_R lemmas used. *)
Section mdp_to_R_ok.
Context {Nt : Type} `{Nt_numeric : Numerics.Numeric Nt} `{Nt_props : Numerics.Numeric_Props Nt}.
(** Given a well-formed finite MDP over [Nt], produce the [mdp_fin_ok]
    instance for [mdp_to_R mdp] with finiteness instance [mdp_fin_to_R mdp fin]. *)
Definition mdp_fin_ok_to_R (mdp : @MDP.mdp Nt) {fin : MDP.mdp_fin mdp} (ok : MDP.mdp_fin_ok mdp)
: (@MDP.mdp_fin_ok _ _ (mdp_to_R mdp) (mdp_fin_to_R mdp fin)).
destruct mdp.
destruct fin.
destruct ok.
unfold mdp_to_R.
(* NOTE(review): rmdp is introduced here but not referenced by name below. *)
remember ({|
MDP.St := St;
MDP.A := A;
MDP.Trans := fun (st : St) (a : A) (st' : St) => to_R (Trans st a st');
MDP.Reward := fun (st : St) (a : A) (st' : St) => to_R (Reward st a st') |}) as rmdp.
(* Main obligation: the real-valued transition function still satisfies
   enum_trans_func for every state and action. *)
assert(Rtrans_fun : forall (s : St) (a : A),
MDP.enum_trans_func (@MDP.Trans R
{|
MDP.St := St;
MDP.A := A;
MDP.Trans := fun (st : St) (a : A) (st' : St) => to_R (Trans st a st');
MDP.Reward := fun (st : St) (a : A) (st' : St) => to_R (Reward st a st') |}
s a)).
{
simpl.
intros.
destruct TransEnumFunc with s a.
constructor.
{
(* Sum bound: move the inequality TransSum1 through to_R. *)
rewrite -> to_R_le in TransSum1.
rewrite <- to_R_big_sum in TransSum1.
rewrite to_R_mult_id in TransSum1.
apply TransSum1.
}
(* Non-negativity: rewrite the real 0 as to_R 0, then reuse TransNonneg. *)
intros.
assert(to_R 0 = 0). apply to_R_plus_id.
rewrite <- H.
rewrite <- to_R_le.
apply TransNonneg.
}
(* All constructor arguments are left as holes for Coq to fill in. *)
exact (@MDP.Build_mdp_fin_ok _ _ _
_ _ _ _
).
Defined.
End mdp_to_R_ok.
Section mdp_fin_dec.
(* Section parameters: a numeric type, an MDP over it with finiteness and
   decidability instances, and a discount factor. *)
Context {Nt:Type} `{Numerics.Numeric Nt}.
Context (mdp : @MDP.mdp Nt) `{mdp_fin : @MDP.mdp_fin Nt mdp} `{mdp_dec : @MDP.mdp_dec Nt mdp}.
Variable discount : Nt.
(* Short names for the state and action types of the section MDP. *)
Definition St : Type := (@MDP.St Nt mdp).
Definition A : Type := (@MDP.A Nt mdp).
(* A policy picks an action for every state. *)
Definition policy := St->A.
(* A value function assigns a number to every state. *)
Definition value_func := St -> Nt.
(* Enumerations of states and actions, taken from the mdp_fin instance. *)
Definition s_enum := @enumerable_fun St (@MDP.SEnum Nt mdp _).
Definition a_enum := @enumerable_fun A (@MDP.AEnum Nt mdp _).
(* Short names for the transition and reward functions of the section MDP. *)
Definition trans := (@MDP.Trans Nt mdp).
Definition reward := (@MDP.Reward Nt mdp).
(* Non-emptiness proofs for the two enumerations. *)
Definition a_enum_nonempty := @MDP.AEnum_nonempty Nt mdp _.
Definition s_enum_nonempty := @MDP.SEnum_nonempty Nt mdp _.
(**Definition eqSt := (@MDP.SDec Nt mdp _) .**)
(* Enum_ok projections.  No mdp_fin_ok instance is in scope at this point,
   so both are applied later to one (see uses such as s_enum_ok _). *)
Definition s_enum_ok := @MDP.SEnum_ok Nt _ mdp _.
Definition a_enum_ok := @MDP.AEnum_ok Nt _ mdp _.
(* Table representations, indexed by the state enumeration, of value
   functions and of policies. *)
Definition value_table : Type := @Enum_table.table St Nt s_enum.
Definition policy_table : Type := @Enum_table.table St A s_enum.
(** Read the action stored for state [s] in a policy table, using the
    decidable equality on states from the mdp_dec instance. *)
Definition policy_table_lookup (v : policy_table) (s : St) : A :=
@Enum_table.lookup St A s_enum s_enum_nonempty MDP.SDec v s.
(** Tabulate a value function over the state enumeration. *)
Definition value_func_to_table (v : value_func) : value_table :=
@Enum_table.map_to_table St Nt (@MDP.SEnum Nt mdp _) v.
(** Read the value stored for state [s] in a value table. *)
Definition value_table_lookup (v : value_table) (s : St) : Nt :=
@Enum_table.lookup St Nt s_enum s_enum_nonempty MDP.SDec v s.
(** A value function and a value table agree when the function is
    pointwise equal (=1) to lookup in the table. *)
Definition value_func_table_eq (vf : value_func) (vt : value_table) := vf =1 value_table_lookup vt.
(** Same agreement relation for policies and policy tables.
    The binder names vf and vt are reused here for a policy and its table. *)
Definition policy_func_table_eq (vf : policy) (vt : policy_table) := vf =1 policy_table_lookup vt.
(** Value of taking action [a] in state [s] given value function [v]:
    the sum over successor states s' of
    trans s a s' * (reward s a s' + discount * v s'). *)
Definition discounted_reward (v : value_func) (s : St) (a : A) : Nt :=
big_sum s_enum (fun s' => (trans s a s') * (reward s a s' + discount * (v s'))).
(** Policy derived from a value function: in each state, the result of
    argmax_ne over the actions of [discounted_reward v s]
    (argmax_ne is defined outside this file; presumably a maximising action). *)
Definition value_func_policy (v: value_func) : policy :=
(fun s => argmax_ne (discounted_reward v s) a_enum_nonempty).
(** Table-based counterpart of [discounted_reward]: the value of each
    successor state is read with [value_table_lookup]. *)
Definition discounted_reward_tb (v : value_table) (s : St) (a : A) : Nt :=
big_sum s_enum (fun s' => (trans s a s') * (reward s a s' + discount * value_table_lookup v s')).
(** Table-based counterpart of [value_func_policy]: tabulates, per state,
    the argmax_ne action with respect to [discounted_reward_tb]. *)
Definition value_table_policy (v : value_table) : policy_table :=
@Enum_table.map_to_table St A s_enum (fun s =>
(argmax_ne (fun a => discounted_reward_tb v s a) a_enum_nonempty )).
(** One step of value iteration: the new value of state [s] is mapmax_ne
    over the actions of [discounted_reward v s]
    (mapmax_ne is defined outside this file; presumably the maximum value). *)
Definition value_iteration_step (v : value_func) : value_func :=
(fun (s : St) =>
mapmax_ne (fun a => discounted_reward v s a) a_enum_nonempty
).
(** Apply [value_iteration_step] n times, starting from [v]. *)
Fixpoint value_iteration_rec (v : value_func) (n : nat):=
match n with
| O => v
| S n' => value_iteration_step (value_iteration_rec v n')
end.
(** One step of policy evaluation: the new value of state [s] is the
    discounted reward of the action chosen by [pol] in [s]. *)
Definition evaluate_policy_step (pol : policy) (v : value_func) : value_func :=
(fun s => discounted_reward v s (pol s)).
(** Apply [evaluate_policy_step pol] n times, starting from [v]. *)
Fixpoint evaluate_policy_rec (pol : policy) (v : value_func) (n : nat):=
match n with
| O => v
| S n' => evaluate_policy_step pol (evaluate_policy_rec pol v n')
end.
(** Table-based value iteration step: computes the per-state mapmax_ne of
    [discounted_reward_tb] and tabulates the result. *)
Definition value_iteration_step_tb (v : value_table) : value_table :=
value_func_to_table (fun s => mapmax_ne (fun a : A => discounted_reward_tb v s a) a_enum_nonempty).
(** Table-based policy evaluation step: the action for each state is read
    from the policy table and the result is tabulated. *)
Definition evaluate_policy_step_tb (pol : policy_table) (v : value_table) : value_table :=
value_func_to_table (fun s => discounted_reward_tb v s (policy_table_lookup pol s)).
(** Apply [value_iteration_step_tb] n times, starting from [v]. *)
Fixpoint value_iteration_rec_tb (v : value_table) (n : nat):=
match n with
| O => v
| S n' => value_iteration_step_tb (value_iteration_rec_tb v n')
end.
(** Apply [evaluate_policy_step_tb pol] n times, starting from [v]. *)
Fixpoint evaluate_policy_rec_tb (pol : policy_table) (v : value_table) (n : nat):=
match n with
| O => v
| S n' => evaluate_policy_step_tb pol (evaluate_policy_rec_tb pol v n')
end.
Section mdp_fin_dec_proofs.
(* Extra assumptions for the proofs: the MDP is well formed, the discount
   factor lies in the half-open interval from 0 (included) to 1 (excluded),
   and the numeric type satisfies Numeric_Props. *)
Context `{mdp_fin_ok : @MDP.mdp_fin_ok Nt _ mdp _}.
Hypothesis discount_nonneg : 0 <= discount.
Hypothesis discount_lt_1: discount < 1.
(* H is the auto-generated name of the Numeric instance declared by the
   enclosing section's Context. *)
Context `{Nt_props : Numeric_Props Nt (numeric_t := H)}.
(** Looking up a state in the table built from a value function returns
    the value of the function at that state. *)
Lemma value_table_lookup_inv_to_table: forall (v : value_func) (s : St), value_table_lookup (value_func_to_table v) s = v s.
Proof.
intros.
unfold value_table_lookup.
unfold value_func_to_table.
rewrite (@Enum_table.lookup_map _ _ _ _ _ _ _); auto.
(* Remaining side condition: the state enumeration satisfies Enum_ok. *)
apply s_enum_ok. apply mdp_fin_ok.
Qed.
(** If a value function and a value table agree ([value_func_table_eq]),
    then [discounted_reward] and [discounted_reward_tb] return the same
    value for every state and action. *)
Lemma discounted_reward_tb_same: forall (vf : value_func) (vt : value_table) (s : St) (a : A),
value_func_table_eq vf vt -> discounted_reward vf s a = discounted_reward_tb vt s a.
Proof.
(* The agreement hypothesis is named explicitly.  The auto-generated name
   (H0) depended on the name H being taken by the section's Numeric
   instance, which made the script sensitive to unrelated context changes. *)
intros vf vt s a Heq.
unfold discounted_reward.
unfold discounted_reward_tb.
(* Both sides are sums over s_enum; compare the summands pointwise. *)
apply big_sum_ext; auto. unfold value_func_table_eq in Heq.
unfold eqfun. intros.
rewrite <- Heq. auto.
Qed.
(** Extensionality of [discounted_reward] in the value function: pointwise
    equal value functions give pointwise equal results for every state. *)
Lemma discounted_reward_ext: forall v1 v2 : value_func, v1 =1 v2 -> forall s, discounted_reward v1 s =1 discounted_reward v2 s.
Proof.
unfold discounted_reward. unfold eqfun.
intros v1 v2 ext s a.
(* Compare the two sums summand by summand. *)
apply big_sum_ext; auto.
unfold eqfun.
intros s2.
rewrite ext. auto.
Qed.
(** Extensionality of [value_iteration_step]: pointwise equal inputs give
    pointwise equal outputs. *)
Lemma value_iteration_step_ext: forall v1 v2 : value_func, v1 =1 v2 -> value_iteration_step v1 =1 value_iteration_step v2.
Proof.
unfold eqfun.
intros v1 v2 Heq s.
unfold value_iteration_step.
(* Reduce to extensionality of the function under mapmax_ne. *)
apply mapmax_ne_ext.
apply discounted_reward_ext.
auto.
Qed.
(** Extensionality of [evaluate_policy_step] in both the policy and the
    value function. *)
Lemma evaluate_policy_step_ext: forall (p1 p2 : policy) (v1 v2 : value_func), p1 =1 p2 -> v1 =1 v2 -> evaluate_policy_step p1 v1 =1 evaluate_policy_step p2 v2.
Proof.
unfold eqfun.
intros p1 p2 v1 v2 HeqP HeqV s.
unfold evaluate_policy_step.
unfold discounted_reward.
apply big_sum_ext; auto.
unfold eqfun.
intros s'.
(* Peel the summand apart with f_equal; the leaves are closed by auto
   using the pointwise equalities HeqP and HeqV. *)
f_equal.
f_equal; auto.
f_equal.
f_equal; auto.
f_equal. auto.
Qed.
(** Special case of [evaluate_policy_step_ext] with a single fixed policy:
    extensionality in the value function only. *)
Lemma evaluate_policy_step_value_ext: forall (p : policy) (v1 v2 : value_func), v1 =1 v2 -> evaluate_policy_step p v1 =1 evaluate_policy_step p v2.
Proof. intros. apply evaluate_policy_step_ext; auto. Qed.
(** The function-based and table-based value iteration steps preserve
    agreement: if [vf] and [vt] agree, so do their one-step updates. *)
Lemma value_iteration_step_tb_same: forall (vf : value_func) (vt : value_table),
value_func_table_eq vf vt -> value_func_table_eq (value_iteration_step vf) (value_iteration_step_tb vt).
Proof.
unfold value_func_table_eq.
intros vs vt Heq.
unfold value_iteration_step_tb.
unfold eqfun. intros s2.
(* Lookup in a freshly built table returns the tabulated function. *)
rewrite value_table_lookup_inv_to_table.
unfold value_iteration_step.
apply mapmax_ne_ext. apply discounted_reward_ext. auto.
Qed.
(** Pointwise difference of two value functions, written as v1 s + - v2 s. *)
Definition value_diff (v1 v2 : value_func) : value_func :=
(fun s => v1 s + - v2 s).
(** Distance between two value functions: mapmax_ne over the states of the
    absolute value of their pointwise difference
    (presumably the maximum, i.e. a sup-norm distance; mapmax_ne is
    defined outside this file). *)
Definition value_dist (v1 v2 : value_func) : Nt :=
mapmax_ne (fun s => Numerics.abs ((value_diff v1 v2) s) ) s_enum_nonempty .
(** Extensionality of [value_dist]: replacing either argument by a
    pointwise equal value function does not change the distance. *)
Lemma value_dist_ext: forall (v1a v2a v1b v2b : value_func),
v1a =1 v1b -> v2a =1 v2b -> value_dist v1a v2a = value_dist v1b v2b.
Proof.
(* The two pointwise equalities are named explicitly.  The auto-generated
   names (H0, H1) depended on the name H being taken by the section's
   Numeric instance, which made the script sensitive to context changes. *)
intros v1a v2a v1b v2b Heq1 Heq2.
unfold value_dist.
unfold eqfun in *.
(* Reduce to equality of the functions under mapmax_ne. *)
apply mapmax_ne_ext.
intros.
unfold value_diff.
rewrite Heq1.
rewrite Heq2.
auto.
Qed.
(** A value iteration step can be moved from the start to the end:
    iterating n times from [value_iteration_step v] equals one more step
    applied to the n-fold iterate of [v]. *)
Lemma value_iteration_rec_reverse: forall (v : value_func) (n : nat), value_iteration_rec (value_iteration_step v) n = value_iteration_step (value_iteration_rec v n).
Proof.
intros v n.
(* Generalise v so that the induction hypothesis holds for every start. *)
generalize v.
induction n; intros; simpl; auto; rewrite IHn; auto.
Qed.
(** Policy evaluation is a contraction: for a fixed policy [pol], one
    [evaluate_policy_step] brings two value functions closer in
    [value_dist] by at least the factor [discount]. *)
Theorem evaluate_policy_contraction: forall (pol : policy) (v1 v2 : value_func),
value_dist (evaluate_policy_step pol v1) (evaluate_policy_step pol v2) <= discount * value_dist v1 v2.
Proof.
intros pol v1 v2.
(* Expose everything down to the underlying big_sum expressions. *)
unfold value_dist.
unfold value_diff.
unfold evaluate_policy_step.
unfold discounted_reward.
(* From here on the goal is stated for a single state s. *)
apply mapmax_ne_le_const.
intros s HIn.
(* Rewrite the v1 summand as trans * reward + discount * (trans * v1). *)
rewrite -> big_sum_ext with _ _ s_enum _
(fun s' => trans s (pol s) s' * reward s (pol s) s' + discount * (trans s (pol s) s' * v1 s')); auto.
2: {
unfold eqfun.
intros s'.
rewrite mult_plus_distr_l.
repeat rewrite mult_assoc.
rewrite -> mult_comm with _ discount. auto.
}
(* Same reshaping for the v2 summand. *)
rewrite -> big_sum_ext with _ _ s_enum
(fun s' : St => trans s (pol s) s' * (reward s (pol s) s' + discount * v2 s'))
(fun s' => trans s (pol s) s' * reward s (pol s) s' + discount * (trans s (pol s) s' * v2 s')); auto.
2:{
unfold eqfun.
intros s'.
rewrite mult_plus_distr_l.
repeat rewrite mult_assoc.
rewrite -> mult_comm with _ discount. auto.
}
(* Split the sums and cancel the reward parts, which are identical on
   both sides. *)
repeat rewrite big_sum_plus.
rewrite plus_neg_distr.
rewrite -> plus_comm with (big_sum s_enum (fun c : St => trans s (pol s) c * reward s (pol s) c)) _.
rewrite <- plus_assoc.
rewrite -> plus_assoc with (big_sum s_enum (fun c : St => trans s (pol s) c * reward s (pol s) c)) _ _.
rewrite plus_neg_r.
rewrite plus_id_l.
(* Factor discount out of the remaining sums and out of abs, then drop it
   from both sides of the inequality. *)
repeat rewrite <- big_sum_mult_left; auto.
rewrite neg_mult_distr_r.
rewrite <- mult_plus_distr_l.
rewrite abs_mult_pos_l; auto.
apply mult_le_compat_l; auto.
(* Merge the two sums into one sum of trans * (v1 - v2). *)
rewrite <- big_sum_nmul.
rewrite <- big_sum_plus.
rewrite -> big_sum_ext with _ _ s_enum _ (fun s' : St => trans s (pol s) s' * (v1 s' + - v2 s')); auto.
2:{
unfold eqfun.
intros s'.
rewrite mult_plus_distr_l.
rewrite <- neg_mult_distr_r.
auto.
}
(* Bound the absolute value of the sum by the sum of absolute values and
   pull the non-negative trans factor out of abs. *)
eapply le_trans.
eapply big_sum_le_abs.
erewrite -> big_sum_ext.
2:{ reflexivity. }
2:{
unfold eqfun.
intros n.
apply abs_mult_pos_l.
eapply MDP.TransNonneg.
}
(* Bound the weighted sum using the maximum of the second factor, then
   use that the transition weights sum to at most 1 (TransSum1). *)
eapply le_trans.
{
eapply big_sum_func_leq_max_l.
intros.
eapply MDP.TransNonneg.
}
rewrite <- mult_id_l.
apply mult_le_compat_r.
2: { eapply MDP.TransSum1. }
apply mapmax_ne_ge_const.
exists s.
split; auto.
apply abs_ge_0.
(* Discharge the existential variables left behind by the eapply and
   erewrite steps above. *)
Unshelve.
eapply MDP.SEnum.
eapply MDP.SEnum_ok.
eapply MDP.TransEnumFunc.
eapply MDP.SEnum.
eapply MDP.SEnum_ok.
eapply MDP.TransEnumFunc.
eapply MDP.SEnum_ok.
eapply MDP.TransEnumFunc.
Qed.
(** Value iteration is a contraction: one [value_iteration_step] brings
    two value functions closer in [value_dist] by at least the factor
    [discount]. *)
Theorem value_iteration_contraction: forall (v1 v2 : value_func),
value_dist (value_iteration_step v1) (value_iteration_step v2) <= (discount) * (value_dist v1 v2).
Proof.
intros.
unfold value_dist.
unfold value_iteration_step.
unfold value_diff.
(* From here on the goal is stated for a single state s'. *)
apply mapmax_ne_le_const.
intros s' _.
(* Move discount inside the mapmax_ne on the right-hand side and then
   inside abs. *)
rewrite mapmax_ne_mult_pos_l; auto.
erewrite mapmax_ne_ext with _ (fun n : St => discount * abs (v1 n + - v2 n)) _ _ _.
2:{
intros s''.
rewrite <- abs_mult_pos_l; auto.
rewrite mult_plus_distr_l.
rewrite <- neg_mult_distr_r.
reflexivity.
}
(* Replace the difference of the two maxima over actions with a maximum
   of per-action differences (lemma mapmax_ne_abs_dist_le). *)
eapply le_trans.
apply mapmax_ne_abs_dist_le.
unfold discounted_reward.
(* For each action a, simplify the difference of the two sums: the reward
   parts cancel and discount is factored out of each summand. *)
erewrite mapmax_ne_ext.
2:{
intros a.
rewrite <- big_sum_nmul.
rewrite <- big_sum_plus.
erewrite big_sum_ext.
reflexivity.
reflexivity.
unfold eqfun.
intros s''.
rewrite neg_mult_distr_r.
rewrite plus_neg_distr.
repeat rewrite mult_plus_distr_l.
rewrite plus_assoc.
rewrite -> plus_comm with (trans s' a s'' * reward s' a s'' ) _.
rewrite <- plus_assoc with (trans s' a s'' * (discount * v1 s'')) _ _.
rewrite <- neg_mult_distr_r.
rewrite plus_neg_r.
rewrite plus_id_r.
rewrite <- mult_plus_distr_l.
rewrite neg_mult_distr_r.
rewrite <- mult_plus_distr_l.
rewrite mult_assoc.
rewrite -> mult_comm with _ discount.
by rewrite <- mult_assoc.
}
(* From here on the goal is stated for a single action a. *)
apply mapmax_ne_le_const.
intros a in_a.
(* Factor discount out of the sum and out of abs on the left; undo the
   earlier rewriting on the right so discount can be cancelled. *)
rewrite <- big_sum_mult_left.
rewrite abs_mult_pos_l; auto.
erewrite mapmax_ne_ext.
2:{
intros s''.
rewrite neg_mult_distr_r.
rewrite <- mult_plus_distr_l.
by rewrite abs_mult_pos_l.
}
rewrite <- mapmax_ne_mult_pos_l; auto.
apply mult_le_compat_l; auto.
(* Bound the absolute value of the sum by the sum of absolute values and
   pull the non-negative trans factor out of abs. *)
eapply le_trans.
apply big_sum_le_abs.
simpl.
erewrite big_sum_ext.
2:{ reflexivity. }
2:{
unfold eqfun.
intros s''.
rewrite abs_mult_pos_l.
reflexivity.
destruct mdp_fin_ok.
by destruct TransEnumFunc with s' a.
}
(* Bound the weighted sum using the maximum of the second factor. *)
eapply le_trans.
{
apply big_sum_func_leq_max_l.
intros s'' in_s''.
destruct mdp_fin_ok.
by destruct TransEnumFunc with s' a.
}
(* Use that the transition weights are non-negative and sum to at most 1. *)
eapply le_trans.
{
eapply mult_le_compat.
4:{ unfold le. right. reflexivity. }
3:{
destruct mdp_fin_ok.
destruct TransEnumFunc with s' a.
apply TransSum1.
}
{
destruct mdp_fin.
destruct mdp_fin_ok.
apply big_sum_ge0'.
intros.
by destruct TransEnumFunc with s' a.
}
apply mapmax_ne_ge_const.
exists s'.
split.
2:{ apply abs_ge_0. }
destruct mdp_fin_ok.
by destruct SEnum_ok.
}
rewrite mult_id_l.
right. reflexivity.
Qed.
(** Package value iteration as a [banach.contraction_func], supplying the
    step function, its extensionality lemma and its contraction theorem.
    Arguments written as underscores are left to inference or to Program
    obligations (presumably the non-emptiness proof and the bounds on
    discount -- confirm against banach.contraction_mk). *)
Program Definition value_iteration_banach : banach.contraction_func :=
banach.contraction_mk
Nt _
St
s_enum
(s_enum_ok _)
discount
_
value_iteration_step
_
_
value_iteration_step_ext
value_iteration_contraction
.
(** Package policy evaluation under a fixed policy [pol] as a
    [banach.contraction_func].  Unlike value_iteration_banach, the
    non-emptiness proof s_enum_nonempty is passed explicitly here; the
    remaining underscores are left to inference or to Program obligations. *)
Program Definition evaluate_policy_banach (pol : policy) : banach.contraction_func :=
banach.contraction_mk
Nt _
St
s_enum
(s_enum_ok _)
discount
s_enum_nonempty
(evaluate_policy_step pol)
_ _
(evaluate_policy_step_value_ext pol)
(evaluate_policy_contraction pol)
.
(** Starting from the same value function, n steps of policy evaluation
    under any policy give, at every state, a value no larger than n steps
    of value iteration. *)
Lemma value_func_eval_ub: forall (s : St)(p : policy) (n : nat) (v : value_func),
(evaluate_policy_rec p v n) s <= (value_iteration_rec v n) s.
Proof.
intros s p n v.
(* The induction hypothesis must hold for every state. *)
generalize dependent s.
induction n; simpl.
by right.
intros s.
unfold evaluate_policy_step. unfold value_iteration_step.
(* Intermediate bound: the discounted reward of the policy action
   computed from the value iteration iterate. *)
apply le_trans with (discounted_reward (value_iteration_rec v n) s (p s)).
{
(* Monotonicity of the sum, using non-negative transition weights,
   non-negative discount and the induction hypothesis. *)
unfold discounted_reward.
apply big_sum_le'.
intros s'.
apply mult_le_compat_l.
{
destruct mdp_fin_ok.
by destruct TransEnumFunc with s (p s).
}
apply Numerics.plus_le_compat_l.
apply mult_le_compat_l; auto.
}
(* The value at the action p s is bounded by mapmax_ne over all actions. *)
apply mapmax_ne_correct.
destruct mdp_fin_ok.
by destruct AEnum_ok.
Qed.
(** [evaluate_policy_rec] coincides with the generic iteration
    [banach.rec_f] applied to [evaluate_policy_step p]. *)
Lemma evaluate_policy_rec_banach_rec: forall (v : St -> Nt) (n : nat) (p : policy),
evaluate_policy_rec p v n = banach.rec_f (evaluate_policy_step p ) v n.
Proof.
intros.
induction n; auto.
simpl.
by rewrite IHn.
Qed.
(** [value_iteration_rec] coincides with the generic iteration
    [banach.rec_f] applied to [value_iteration_step]. *)
Lemma value_iteration_rec_banach_rec: forall (v : St -> Nt) (n : nat),
value_iteration_rec v n = banach.rec_f (value_iteration_step ) v n.
Proof.
intros.
induction n; auto.
simpl.
by rewrite IHn.
Qed.
(** [value_dist] is the same as [banach.max_dist] over the state
    enumeration; the proof is closed by auto alone. *)
Lemma value_dist_banach_dist: forall (v1 v2 : value_func), value_dist v1 v2 = banach.max_dist _ s_enum_nonempty v1 v2.
Proof. auto. Qed.
(** Triangle inequality for [value_dist], obtained from
    [banach.dist_triangle] instantiated with value_iteration_banach. *)
Lemma value_dist_triangle: forall (v1 v2 v3: value_func), value_dist v1 v3 <=
value_dist v1 v2 + value_dist v2 v3.
Proof. intros. apply (banach.dist_triangle (value_iteration_banach)). Qed.
(** Triangle inequality through two intermediate value functions,
    obtained by applying [value_dist_triangle] twice. *)
Lemma value_dist_triangle2: forall (v1 v2 v3 v4 : value_func), value_dist v1 v4 <=
value_dist v1 v2 + value_dist v2 v3 + value_dist v3 v4.
Proof.
intros.
eapply le_trans.
eapply value_dist_triangle.
(* Reassociate so the first summand matches, then bound the rest. *)
rewrite <- plus_assoc.
eapply plus_le_compat_l.
apply value_dist_triangle.
Qed.
(** [value_dist] is non-negative. *)
Lemma value_dist_ge0: forall (v1 v2 : value_func), 0 <= value_dist v1 v2.
Proof.
intros.
unfold value_dist.
(* Rewrite 0 as mapmax_ne of the constant function 0, then compare the
   two functions under mapmax_ne pointwise. *)
rewrite <- mapmax_ne_const with _ _ 0 (s_enum_nonempty).
apply mapmax_ne_le_ext.
intros.
apply Numerics.abs_ge_0.
Qed.
(** [value_dist] is symmetric in its two arguments. *)
Lemma value_dist_comm: forall (v1 v2 : value_func), value_dist v1 v2 = value_dist v2 v1.
Proof.
intros.
unfold value_dist.
apply mapmax_ne_ext.
intros.
unfold value_diff.
(* abs is unchanged by negation; negating the difference and reordering
   the sum swaps the roles of v1 and v2. *)
rewrite <- abs_neg.
rewrite plus_neg_distr. rewrite double_neg. rewrite plus_comm. auto.
Qed.
(** A value iteration step on [v] equals, at every state, a policy
    evaluation step on [v] under the policy derived from [v]. *)
Lemma value_iteration_step_policy_eval_same: forall (v : value_func) s,
value_iteration_step v s = evaluate_policy_step (value_func_policy v ) v s.
Proof.
intros.
unfold value_iteration_step.
unfold evaluate_policy_step.
unfold value_func_policy.
(* Lemma relating the function value at argmax_ne to mapmax_ne. *)
rewrite argmax_ne_mapmax_ne. auto.
Qed.
(** Distance between a value iteration step on [v1] and a policy
    evaluation step on [v2] under the policy derived from [v1]: it is at
    most discount times the distance between [v1] and [v2]. *)
Lemma policy_eval_value_iteration_diff: forall (v1 v2 : value_func),
value_dist (value_iteration_step v1)
(evaluate_policy_step (value_func_policy v1) v2) <=
discount * value_dist v1 v2.
Proof.
intros.
unfold value_dist.
apply mapmax_ne_le_const.
(* NOTE(review): the names n and x used below are auto-generated by the
   anonymous intros calls; they presumably come from binders of library
   lemmas and definitions not shown in this file. *)
intros.
unfold value_diff.
(* Turn the value iteration step into a policy evaluation step so both
   sides use the same policy. *)
rewrite value_iteration_step_policy_eval_same.
unfold evaluate_policy_step.
unfold discounted_reward.
(* Merge the difference of sums into one sum and bound its absolute value
   by the sum of absolute values. *)
rewrite <- big_sum_nmul.
rewrite <- big_sum_plus.
eapply le_trans. eapply big_sum_le_abs.
erewrite big_sum_ext.
2: { reflexivity. }
2:{
(* Per summand: cancel the reward terms, then factor trans and discount
   out of abs. *)
unfold eqfun.
intros.
rewrite neg_mult_distr_r.
rewrite <- mult_plus_distr_l.
rewrite plus_neg_distr.
rewrite plus_assoc.
rewrite <- plus_assoc with _ (discount * v1 x) _.
rewrite -> plus_comm with (discount * v1 x) _.
rewrite plus_assoc.
rewrite plus_neg_r.
rewrite plus_id_l.
rewrite neg_mult_distr_r.
rewrite <- mult_plus_distr_l.
rewrite abs_mult_pos_l.
{
rewrite abs_mult_pos_l; auto.
rewrite mult_assoc. rewrite -> mult_comm with _ discount.
rewrite <- mult_assoc.
reflexivity.
}
apply (MDP.TransEnumFunc n).
}
(* Pull discount out of the sum and cancel it on both sides. *)
rewrite -> big_sum_scalar.
apply mult_le_compat_l; auto.
unfold value_dist.
unfold value_diff.
(* Bound the weighted sum using the maximum of the second factor, then
   use that the transition weights sum to at most 1. *)
eapply le_trans.
{ eapply big_sum_func_leq_max_l. intros. apply (MDP.TransEnumFunc n). }
eapply le_trans.
{
apply mult_le_compat_r.
apply mapmax_ne_ge_const.
exists n. split; auto.
apply abs_ge_0.
eapply MDP.TransSum1.
}
rewrite mult_id_l. right. auto.
(* Discharge the existential variables left behind by the eapply steps. *)
Unshelve.
apply MDP.SEnum_ok.
apply MDP.TransEnumFunc.
Qed.
(**Lemma policies_nonempty: O <> length enumerate_policies.
Proof.
unfold enumerate_policies.
apply enumerate_table_nonempty.
apply actions_nonempty.
Qed.**)
End mdp_fin_dec_proofs.
End mdp_fin_dec.
Section mdp_fin_dec_toR.
(* Section parameters: a numeric type with its properties, an MDP over it
   with finiteness and decidability instances, and a discount factor. *)
Context {Nt:Type} `{Nt_Numeric : Numerics.Numeric Nt} `{Nt_Props : Numerics.Numeric_Props Nt}.
Context (mdp : @MDP.mdp Nt) `{mdp_fin : @MDP.mdp_fin Nt mdp} `{mdp_dec : @MDP.mdp_dec Nt mdp}.
Variable discount : Nt.
(* Real-valued images of the discount factor, the MDP and its instances. *)
Definition Rdiscount := to_R discount.
Definition Rmdp := mdp_to_R mdp.
Definition Rmdp_fin : MDP.mdp_fin Rmdp := mdp_fin_to_R mdp mdp_fin.
(* NOTE(review): the name Rmpd_dec looks like a typo for Rmdp_dec; it is
   kept as is because it is part of the public interface. *)
Definition Rmpd_dec : MDP.mdp_dec Rmdp := mdp_dec_to_R mdp mdp_dec.
(** Map a value function over [Nt] to one over R by applying to_R pointwise. *)
Definition value_func_to_R (v : value_func mdp) : (value_func Rmdp) :=
(fun s => to_R (v s)).
(** The state enumeration of the real-valued MDP is the same as that of
    the original MDP. *)
Lemma s_enum_same: s_enum mdp = s_enum Rmdp (mdp_fin := Rmdp_fin).
Proof.
unfold s_enum.
simpl.
unfold St.
f_equal.
(* Unfold the R-transport definitions, then destruct the records so that
   both sides reduce to the same enumeration. *)
unfold Rmdp_fin.
unfold Rmdp.
unfold mdp_to_R.
unfold mdp_fin_to_R.
simpl.
destruct mdp_fin.
simpl.
destruct mdp.
auto.
Qed.
(** The action enumeration of the real-valued MDP is the same as that of
    the original MDP. *)
Lemma a_enum_same: a_enum mdp = a_enum Rmdp (mdp_fin := Rmdp_fin).
Proof.
unfold a_enum.
simpl.
unfold A.
f_equal.
(* Unfold the R-transport definitions, then destruct the records so that
   both sides reduce to the same enumeration. *)
unfold Rmdp_fin.
unfold Rmdp.
unfold mdp_to_R.
unfold mdp_fin_to_R.
simpl.
destruct mdp_fin.
simpl.
destruct mdp.
auto.
Qed.
(** The AEnum field of the real-valued finiteness instance equals the
    AEnum field of the original instance. *)
Lemma AEnum_same: MDP.AEnum (t := mdp) = @MDP.AEnum R Rmdp Rmdp_fin.
Proof.
(* Unfold the R-transport definitions, then destruct the records. *)
unfold Rmdp_fin.
unfold Rmdp.
unfold mdp_to_R.
unfold mdp_fin_to_R.
simpl.
destruct mdp_fin.
simpl.
destruct mdp.
auto.
Qed.
(** The SEnum field of the real-valued finiteness instance equals the
    SEnum field of the original instance. *)
Lemma SEnum_same: MDP.SEnum (t := mdp) = @MDP.SEnum R Rmdp Rmdp_fin.
Proof.
(* Unfold the R-transport definitions, then destruct the records. *)
unfold Rmdp_fin.
unfold Rmdp.
unfold mdp_to_R.
unfold mdp_fin_to_R.
simpl.
destruct mdp_fin.
simpl.
destruct mdp.
auto.
Qed.
(* Register s_enum_same as an Immediate hint.  The database is named
   explicitly: the form without a database implicitly targeted core, is
   deprecated, and is rejected by recent Coq releases. *)
Hint Immediate s_enum_same : core.
(** to_R commutes with [discounted_reward]: converting the result to R
    equals computing the discounted reward in the real-valued MDP with
    the converted discount and value function. *)
Lemma discounted_reward_to_R: forall (v : value_func mdp) (s : St mdp) (a : A mdp),
to_R (discounted_reward mdp discount v s a) =
discounted_reward (mdp_fin := Rmdp_fin) Rmdp Rdiscount (value_func_to_R v) s a.
Proof.
intros.
unfold discounted_reward.
(* Push to_R inside the sum, then compare summands pointwise. *)
rewrite <- to_R_big_sum.
apply big_sum_ext; auto.
unfold eqfun. intros.
(* to_R_distr is a tactic defined outside this file; presumably it
   distributes to_R over the arithmetic operations. *)
by to_R_distr.
Qed.
(** to_R commutes with [value_iteration_step], pointwise.
    NOTE(review): the lemma name contains the spelling itertion; it is
    kept as is because it is part of the public interface. *)
Lemma value_itertion_step_to_R: forall (v : value_func mdp),
value_func_to_R (value_iteration_step mdp discount v) =1
value_iteration_step (mdp_fin := Rmdp_fin) Rmdp Rdiscount (value_func_to_R v).
Proof.
unfold eqfun.
intros.
unfold value_func_to_R.
unfold value_iteration_step.
(* Push to_R inside mapmax_ne. *)
rewrite to_R_mapmax_ne.
simpl.
(* Compare under Some so that both sides can be rewritten with
   mapmax_ne_ok and then compared with mapmax_ext. *)
apply ssr.ssrfun.Some_inj.
repeat rewrite <- mapmax_ne_ok.
simpl.
(* Align the two action enumerations. *)
rewrite AEnum_same.
apply mapmax_ext.
intros.
by rewrite discounted_reward_to_R.
Qed.
(** to_R commutes with [evaluate_policy_step], pointwise, for any policy. *)
Lemma policy_evaluate_step_to_R: forall (v : value_func mdp) (p : policy mdp),
value_func_to_R (evaluate_policy_step mdp discount p v) =1
evaluate_policy_step (mdp_fin := Rmdp_fin) Rmdp Rdiscount p (value_func_to_R v).
Proof.
unfold eqfun.
unfold evaluate_policy_step.
intros.
(* Reduce to the corresponding statement for discounted_reward. *)
rewrite <- discounted_reward_to_R.
unfold value_func_to_R.
auto.
Qed.
End mdp_fin_dec_toR.
Section mdp_fin_dec_R.
(* Section parameters: a real discount factor in the half-open interval
   from 0 (included) to 1 (excluded), and a well-formed finite MDP over R
   with decidable equality. *)
Variable discount : R.
Hypothesis discount_nonneg : 0 <= discount.
Hypothesis discount_lt_1: discount < 1.
Context (mdp : @MDP.mdp R) `{mdp_fin : @MDP.mdp_fin R mdp} `{mdp_dec : @MDP.mdp_dec R mdp}.
Context `{mdp_fin_ok : @MDP.mdp_fin_ok R _ mdp _}.
(* Contraction packages for value iteration and for policy evaluation,
   instantiated at the real-valued MDP and discount of this section. *)
Definition R_vi_banach := value_iteration_banach mdp discount discount_nonneg discount_lt_1.
Definition R_eval_pol_banach := (evaluate_policy_banach mdp discount discount_nonneg discount_lt_1).
Lemma value_iteration_R_cauchy_crit_aux: forall (v : value_func mdp) (n m : nat) (e: R),
0 < e ->
0 < value_dist mdp v (value_iteration_step mdp discount v) ->
pow_nat discount n < e * (1 + - discount) * Rinv (value_dist mdp v (value_iteration_step mdp discount v)) ->
value_dist mdp (value_iteration_rec mdp discount v n) (value_iteration_rec mdp discount v (n + m)) < e.
Proof.