|
1 | 1 | import numpy as np |
2 | 2 |
|
def lambda_return(b, c, lam):
    """Return the λ-return from s0 for a fixed three-step episode.

    The episode has rewards r1 = 0, r2 = 0, r3 = 1 and gamma = 1, with
    V(s1) = b and V(s2) = c; the episode terminates after the third step.

    Args:
        b: Current value estimate V(s1).
        c: Current value estimate V(s2).
        lam: Trace-decay parameter λ, expected in [0, 1].

    Returns:
        The λ-weighted average of the n-step returns from s0.
    """
    # n-step returns from s0 (gamma = 1, so rewards simply add up).
    G1 = b      # r1 + V(s1)
    G2 = c      # r1 + r2 + V(s2)
    G3 = 1.0    # r1 + r2 + r3: the full return, episode has terminated

    # Truncated n-step returns are weighted (1 - lam) * lam**(n - 1); the
    # terminal return takes all the remaining weight lam**2, so the weights
    # sum to 1. Putting G3 inside the (1 - lam) factor would drop lam**3 of
    # weight and give 0 (instead of the Monte Carlo return) at lam = 1.
    return (1 - lam) * (G1 + lam * G2) + lam**2 * G3
7 | 8 |
|
def test_forward_backward_equivalence():
    """Forward λ-return update for s0 must equal the backward-view update.

    Uses a three-step episode whose final reward is 1 (earlier rewards 0)
    with gamma = 1. The forward update is G^λ - V(s0); the backward update
    is the sum of the one-step TD errors weighted by s0's eligibility trace.
    """
    lam = 0.5
    v_s0, v_s1, v_s2 = 0.5, 0.3, 0.2

    # Forward view: λ-return target minus the current estimate of s0.
    forward_update = lambda_return(v_s1, v_s2, lam) - v_s0

    # Backward view: one-step TD errors along the episode
    # (the terminal state has value 0 and the last reward is 1).
    td_errors = (v_s1 - v_s0, v_s2 - v_s1, 1.0 - v_s2)

    # Eligibility of s0 at t = 0, 1, 2: it decays by lam each step (γ = 1).
    traces = (1.0, lam, lam**2)

    backward_update = (
        td_errors[0] * traces[0]
        + td_errors[1] * traces[1]
        + td_errors[2] * traces[2]
    )

    assert np.isclose(forward_update, backward_update, atol=1e-12)
0 commit comments