Skip to content

Commit 95cdf66

Browse files
Fix forward-backward equivalence test in Chapter 8
1 parent a03040e commit 95cdf66

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
import numpy as np
22

33
def lambda_return(b, c, lam):
4-
# Episode: r1=0,r2=0,r3=1, gamma=1; V(s0)=a, V(s1)=b, V(s2)=c (a cancels in diff)
5-
G1, G2, G3 = b, c, 1.0
4+
G1 = b
5+
G2 = c
6+
G3 = 1.0
67
return (1 - lam) * (G1 + lam * G2 + lam**2 * G3)
78

89
def test_forward_backward_equivalence():
9-
a, b, c = 0.3, 0.6, 0.2
10+
a, b, c = 0.5, 0.3, 0.2
1011
lam = 0.5
11-
# forward:
12+
13+
# forward λ-return update for V(s0) = a
1214
Glam = lambda_return(b, c, lam)
13-
forward_update = (Glam - a)
15+
forward_update = Glam - a
1416

15-
# backward:
17+
# backward view TD error updates
1618
d0 = b - a
1719
d1 = c - b
1820
d2 = 1.0 - c
19-
# eligibilities for s0 at t=0,1,2 (gamma=1)
20-
e0, e1, e2 = 1.0, lam, lam**2
21-
backward_update = d0*e0 + d1*e1 + d2*e2
21+
22+
# eligibilities for s0 at each step
23+
e0 = 1.0
24+
e1 = lam # after one step
25+
e2 = lam**2 # after two steps (γ=1 here)
26+
27+
backward_update = d0 * e0 + d1 * e1 + d2 * e2
2228

2329
assert np.isclose(forward_update, backward_update, atol=1e-12)

0 commit comments

Comments
 (0)