Commit e63f068
Add Chapter 8: Eligibility Traces and TD(λ) code, tests, examples, and workflow
1 parent 72bb180 commit e63f068

12 files changed

Lines changed: 561 additions & 0 deletions

.github/workflows/ch8.yml

Lines changed: 41 additions & 0 deletions
```yaml
name: "Chapter 8: Eligibility Traces and TD(λ)"

on:
  push:
    paths:
      - 'ch8_td_lambda/**'
      - '.github/workflows/ch8.yml'
  pull_request:
    paths:
      - 'ch8_td_lambda/**'
      - '.github/workflows/ch8.yml'

jobs:
  test:
    uses: ./.github/workflows/_chapter-tests.yml
    with:
      chapter: ch8_td_lambda

  examples:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install root requirements
        run: |
          python -m pip install --upgrade pip
          if [ -f requirements.txt ]; then pip install -r requirements.txt; else pip install numpy matplotlib; fi
          pip install pytest

      - name: Run TD(λ) demo
        run: python ch8_td_lambda/examples/TD_Lambda_Demo.py

      - name: Run SARSA(λ) demo
        run: python ch8_td_lambda/examples/SARSA_Lambda_Demo.py
```

ch8_td_lambda/README.md

Lines changed: 137 additions & 0 deletions
# ch8_td_lambda · Eligibility Traces and TD(λ)

Reference implementations and experiments for **Chapter 8** of
_Reinforcement Learning Fundamentals: From Theory to Practice_.

This chapter covers the forward/backward views, TD(λ) prediction, SARSA(λ) control, and true-online TD(λ) with linear function approximation.

---

## Folder layout

```
ch8_td_lambda/
├─ gridworld_small.py           # 4×4 tabular gridworld (start=(3,0), goal=(0,3))
├─ td_lambda.py                 # TD(λ) prediction (backward view; accumulating/replacing)
├─ sarsa_lambda.py              # SARSA(λ) control with ε-greedy
├─ true_online_td_lambda.py     # True Online TD(λ) for linear FA
├─ plot_tdlambda_learning.py    # Produces learning curves for λ ∈ {0, 0.5, 1}
├─ tests/
│  └─ test_forward_backward_equiv.py  # Forward ↔ backward numerical check
```

---

## Quick start

> Assumes Python ≥ 3.9 and `matplotlib`, `numpy`, `pytest`.

### 1) Run the unit test (forward ↔ backward equivalence)

```bash
pytest ch8_td_lambda/tests -q
```

Expected:
```
.                                                                    [100%]
1 passed in ~0.02s
```

### 2) Generate learning curves (SARSA(λ) on gridworld)

```bash
python ch8_td_lambda/plot_tdlambda_learning.py
```

Artifacts written to the project root (figure under `figs/`):
- `ch8_tdlambda_learning.csv`
- `figs/ch8_tdlambda_learning.png`

The plot compares success rates for **λ ∈ {0.0, 0.5, 1.0}**.
(Intermediate λ typically balances speed and stability in this task.)

---

## Minimal examples

### TD(λ) prediction (tabular; backward view)
```python
import numpy as np
from ch8_td_lambda.gridworld_small import GridworldSmall
from ch8_td_lambda.td_lambda import td_lambda_prediction

env = GridworldSmall(seed=0)

def random_policy(s: int):
    return np.ones(env.n_actions) / env.n_actions  # uniform

V = td_lambda_prediction(env, random_policy, gamma=0.99, alpha=0.1, lam=0.9, episodes=200)
print(V.reshape(env.n_rows, env.n_cols))
```

### SARSA(λ) control (ε-greedy)
```python
import numpy as np
from ch8_td_lambda.gridworld_small import GridworldSmall
from ch8_td_lambda.sarsa_lambda import sarsa_lambda_control

env = GridworldSmall(seed=0)
Q = sarsa_lambda_control(env, gamma=0.99, alpha=0.1, lam=0.8, epsilon=0.1, episodes=1000)
print(Q.argmax(axis=1).reshape(env.n_rows, env.n_cols))  # greedy policy
```

### True Online TD(λ) (linear FA; one-hot features)
```python
import numpy as np
from ch8_td_lambda.gridworld_small import GridworldSmall
from ch8_td_lambda.true_online_td_lambda import true_online_td_lambda_linear

env = GridworldSmall(seed=0)

def phi(s: int):
    x = np.zeros(env.n_states, dtype=float)
    x[s] = 1.0
    return x

w = true_online_td_lambda_linear(env, phi, gamma=0.99, alpha=0.15, lam=0.8, episodes=800, seed=0)
print(w.reshape(env.n_rows, env.n_cols))  # value estimates
```

---

## Expected outputs

- **Learning curves:** `ch8_tdlambda_learning.png` — success rate vs. episodes for λ=0, 0.5, 1.0.
- **CSV:** `ch8_tdlambda_learning.csv` — columns: `episodes, lambda_0.0, lambda_0.5, lambda_1.0`.

---

## LaTeX snippet (embed figure in the book)

After generating the figure, move/commit it under `figs/` and include:

```latex
\begin{figure}[h!]
  \centering
  \includegraphics[width=0.75\linewidth]{figs/ch8_tdlambda_learning.png}
  \caption{Learning curves for TD($\lambda$) on a $4\times4$ gridworld under SARSA($\lambda$). Intermediate $\lambda$ values (e.g., $0.5$) often balance speed and stability.}
  \label{fig:tdlambda-learning}
\end{figure}
```

---

## Notes

- `sarsa_lambda_control` supports `trace_type="accumulating"` or `"replacing"` (the learning-curve script defaults to replacing, which is more stable when states repeat within an episode).
- For reproducibility, seeds are set inside scripts; you can adjust α, ε, and λ from the script/CLI if desired.

---

## References

- Sutton, R. S. (1988). *Learning to Predict by the Methods of Temporal Differences*.
- Sutton, R. S., & Barto, A. G. (2018). *Reinforcement Learning: An Introduction* (2nd ed.).
- van Seijen, H., & Sutton, R. S. (2014). *True Online TD(λ)*.
- Tesauro, G. (1995). *TD-Gammon*.
- Schulman, J., et al. (2016). *Generalized Advantage Estimation*.
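`true_online_td_lambda.py` itself is not shown in this diff. As a reference point, a minimal sketch of the van Seijen–Sutton true-online update, consistent with the README's call signature, might look like the following (the committed implementation may differ; in particular, actions are assumed to be chosen uniformly at random, since the signature takes no policy argument, and the parameter defaults here are guesses):

```python
import numpy as np

def true_online_td_lambda_linear(env, phi, gamma=0.99, alpha=0.1, lam=0.8,
                                 episodes=500, seed=None, max_steps=500):
    """True Online TD(lambda) prediction with linear function approximation.

    Assumes a uniform-random behavior policy (the committed version may
    accept an explicit policy).
    """
    rng = np.random.default_rng(seed)
    n_features = phi(env.reset()).shape[0]
    w = np.zeros(n_features)
    for _ in range(episodes):
        s = env.reset()
        x = phi(s)
        z = np.zeros(n_features)  # dutch-style eligibility trace
        v_old = 0.0
        for _ in range(max_steps):
            a = int(rng.integers(env.n_actions))
            s_next, r, done, _ = env.step(a)
            x_next = np.zeros_like(x) if done else phi(s_next)
            v = w @ x
            v_next = w @ x_next
            delta = r + gamma * v_next - v
            # dutch trace: extra correction term keeps the online update
            # exactly equivalent to the forward (lambda-return) view
            z = gamma * lam * z + (1.0 - alpha * gamma * lam * (z @ x)) * x
            w += alpha * (delta + v - v_old) * z - alpha * (v - v_old) * x
            v_old = v_next
            x = x_next
            if done:
                break
    return w
```

With one-hot features (as in the README example), `w` is simply a per-state value table, so reshaping it to the grid is meaningful.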

ch8_td_lambda/__init__.py

Whitespace-only changes.
ch8_td_lambda/examples/SARSA_Lambda_Demo.py

Lines changed: 13 additions & 0 deletions
1+
from __future__ import annotations
2+
import numpy as np
3+
from ch8_td_lambda.gridworld_small import GridworldSmall
4+
from ch8_td_lambda.sarsa_lambda import sarsa_lambda_control
5+
6+
def main():
7+
env = GridworldSmall(seed=0)
8+
Q = sarsa_lambda_control(env, gamma=0.99, alpha=0.1, lam=0.8, epsilon=0.1, episodes=1500, seed=0)
9+
greedy = Q.argmax(axis=1).reshape(env.n_rows, env.n_cols)
10+
print('Greedy policy (0:up,1:right,2:down,3:left):\n', greedy)
11+
12+
if __name__ == '__main__':
13+
main()
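`sarsa_lambda.py` is part of this commit but not visible in the diff above. For orientation, a minimal tabular SARSA(λ) consistent with the demo's call (a sketch, not the committed implementation; the `max_steps` cap and defaults are assumptions) could be:

```python
import numpy as np

def sarsa_lambda_control(env, gamma=0.99, alpha=0.1, lam=0.8, epsilon=0.1,
                         episodes=1000, trace_type="replacing", seed=None,
                         n_states=None, n_actions=None, max_steps=500):
    """On-policy SARSA(lambda) control with epsilon-greedy exploration."""
    rng = np.random.default_rng(seed)
    nS = n_states if n_states is not None else env.n_states
    nA = n_actions if n_actions is not None else env.n_actions
    Q = np.zeros((nS, nA))

    def eps_greedy(s):
        if rng.random() < epsilon:
            return int(rng.integers(nA))
        return int(np.argmax(Q[s]))

    for _ in range(episodes):
        z = np.zeros((nS, nA))  # eligibility traces over (state, action)
        s = env.reset()
        a = eps_greedy(s)
        for _ in range(max_steps):
            s_next, r, done, _ = env.step(a)
            a_next = eps_greedy(s_next)
            target = r if done else r + gamma * Q[s_next, a_next]
            delta = target - Q[s, a]
            if trace_type == "replacing":
                z[s, a] = 1.0   # cap the trace when a pair repeats
            else:
                z[s, a] += 1.0  # accumulating
            Q += alpha * delta * z  # credit all recently visited pairs
            z *= gamma * lam
            s, a = s_next, a_next
            if done:
                break
    return Q
```

Replacing traces cap each pair's eligibility at 1, which is why the learning-curve script prefers them on this gridworld, where the random policy revisits states often.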
ch8_td_lambda/examples/TD_Lambda_Demo.py

Lines changed: 16 additions & 0 deletions
1+
from __future__ import annotations
2+
import numpy as np
3+
from ch8_td_lambda.gridworld_small import GridworldSmall
4+
from ch8_td_lambda.td_lambda import td_lambda_prediction
5+
6+
def main():
7+
env = GridworldSmall(seed=0)
8+
9+
def random_policy(s: int):
10+
return np.ones(env.n_actions) / env.n_actions # uniform
11+
12+
V = td_lambda_prediction(env, random_policy, gamma=0.99, alpha=0.1, lam=0.9, episodes=300, seed=0)
13+
print('Value estimates (4x4):\n', V.reshape(env.n_rows, env.n_cols))
14+
15+
if __name__ == '__main__':
16+
main()
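`td_lambda.py` is likewise referenced but not shown in the visible diff. A minimal backward-view TD(λ) matching the demo's call signature (a sketch; the `trace_type` and `max_steps` defaults are assumptions about the committed code) might be:

```python
import numpy as np

def td_lambda_prediction(env, policy, gamma=0.99, alpha=0.1, lam=0.9,
                         episodes=200, trace_type="accumulating",
                         seed=None, max_steps=500):
    """Tabular backward-view TD(lambda) state-value prediction."""
    rng = np.random.default_rng(seed)
    V = np.zeros(env.n_states)
    for _ in range(episodes):
        s = env.reset()
        z = np.zeros(env.n_states)  # eligibility traces
        for _ in range(max_steps):
            a = rng.choice(env.n_actions, p=policy(s))
            s_next, r, done, _ = env.step(a)
            # TD error; the bootstrap term is zero at terminal states
            delta = r + (0.0 if done else gamma * V[s_next]) - V[s]
            if trace_type == "replacing":
                z[s] = 1.0
            else:
                z[s] += 1.0  # accumulating
            V += alpha * delta * z  # update every state by its eligibility
            z *= gamma * lam        # decay all traces
            s = s_next
            if done:
                break
    return V
```

At λ=0 this reduces to one-step TD(0); at λ=1 with accumulating traces it approaches every-visit Monte Carlo, which is the equivalence the chapter's forward ↔ backward test checks numerically.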

ch8_td_lambda/gridworld_small.py

Lines changed: 52 additions & 0 deletions
```python
# Minimal 4x4 gridworld with gym-like API (tabular states 0..15).
# Start at (3,0), goal at (0,3). Step reward -0.01, goal reward +1.
from __future__ import annotations

from typing import Tuple, Optional

import numpy as np


class GridworldSmall:
    def __init__(self, seed: Optional[int] = None):
        self.n_rows = 4
        self.n_cols = 4
        self.n_states = self.n_rows * self.n_cols
        self.n_actions = 4  # 0:up, 1:right, 2:down, 3:left
        self.start = (3, 0)
        self.goal = (0, 3)
        self.step_reward = -0.01
        self.goal_reward = 1.0
        self._rng = np.random.default_rng(seed)
        self.s = self._to_state(self.start)

    def _to_state(self, rc: Tuple[int, int]) -> int:
        r, c = rc
        return r * self.n_cols + c

    def _to_rc(self, s: int) -> Tuple[int, int]:
        return divmod(s, self.n_cols)

    def reset(self) -> int:
        self.s = self._to_state(self.start)
        return self.s

    def step(self, a: int):
        r, c = self._to_rc(self.s)
        if a == 0:    # up
            r = max(0, r - 1)
        elif a == 1:  # right
            c = min(self.n_cols - 1, c + 1)
        elif a == 2:  # down
            r = min(self.n_rows - 1, r + 1)
        elif a == 3:  # left
            c = max(0, c - 1)
        s_next = self._to_state((r, c))
        done = (r, c) == self.goal
        reward = self.goal_reward if done else self.step_reward
        self.s = s_next
        return s_next, reward, done, {}

    def render(self) -> None:
        r, c = self._to_rc(self.s)
        board = np.full((self.n_rows, self.n_cols), '.', dtype=object)
        board[self.goal] = 'G'
        board[r, c] = 'A'
        print('\n'.join(' '.join(row) for row in board))
```
ch8_td_lambda/plot_tdlambda_learning.py

Lines changed: 68 additions & 0 deletions
1+
from __future__ import annotations
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
from ch8_td_lambda.plot_utils import moving_average, style_axes
5+
from ch8_td_lambda.gridworld_small import GridworldSmall
6+
from ch8_td_lambda.sarsa_lambda import sarsa_lambda_control
7+
8+
def eval_success_rate(env, Q, episodes=200, max_steps=200, seed=123) -> float:
9+
rng = np.random.default_rng(seed)
10+
succ = 0
11+
for _ in range(episodes):
12+
s = env.reset()
13+
for _ in range(max_steps):
14+
a = int(np.argmax(Q[s]))
15+
s, r, done, *_ = env.step(a)
16+
if done:
17+
succ += 1
18+
break
19+
return succ / episodes
20+
21+
def main():
22+
seeds = [0, 1, 2]
23+
lambdas = [0.0, 0.5, 1.0]
24+
episodes_per_seed = 3000
25+
eval_every = 100
26+
alphas = {0.0: 0.15, 0.5: 0.12, 1.0: 0.08} # gentle tuning
27+
28+
curves = {lam: [] for lam in lambdas}
29+
xs = []
30+
31+
for lam in lambdas:
32+
agg = []
33+
for sd in seeds:
34+
env = GridworldSmall(seed=sd)
35+
# train in chunks, evaluate periodically
36+
Q = np.zeros((env.n_states, env.n_actions))
37+
for ep0 in range(0, episodes_per_seed, eval_every):
38+
Q = sarsa_lambda_control(
39+
env=env, gamma=0.99, alpha=alphas[lam], lam=lam,
40+
epsilon=0.1, episodes=eval_every, trace_type='replacing',
41+
seed=sd, n_states=env.n_states, n_actions=env.n_actions
42+
)
43+
sr = eval_success_rate(env, Q, episodes=100, seed=sd)
44+
agg.append(sr)
45+
if lam == lambdas[0]:
46+
xs.append(ep0 + eval_every)
47+
curves[lam] = np.array(agg).reshape(len(seeds), -1).mean(axis=0)
48+
49+
# plot
50+
plt.figure(figsize=(7.2, 4.2))
51+
for lam in lambdas:
52+
plt.plot(xs, moving_average(curves[lam], w=5), label=f'λ={lam}')
53+
style_axes(plt.gca(), xlabel='Episodes', ylabel='Success rate (greedy)', ylim=(0,1.02), legend_loc='lower right')
54+
import os
55+
os.makedirs('figs', exist_ok=True)
56+
plt.savefig('figs/ch8_tdlambda_learning.png', dpi=160)
57+
58+
# also dump CSV
59+
import csv
60+
with open('ch8_tdlambda_learning.csv', 'w', newline='') as f:
61+
w = csv.writer(f)
62+
w.writerow(['episodes'] + [f'lambda_{lam}' for lam in lambdas])
63+
for i, x in enumerate(xs):
64+
row = [x] + [float(curves[lam][i]) for lam in lambdas]
65+
w.writerow(row)
66+
67+
if __name__ == '__main__':
68+
main()

ch8_td_lambda/plot_utils.py

Lines changed: 24 additions & 0 deletions
```python
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np


def moving_average(x, w=5):
    if w <= 1:
        return np.asarray(x)
    x = np.asarray(x, dtype=float)
    if x.size < w:
        return x.copy()
    c = np.cumsum(np.insert(x, 0, 0.0))
    y = (c[w:] - c[:-w]) / float(w)
    # pad to original length (left pad with first value)
    pad = np.full(w - 1, y[0])
    return np.concatenate([pad, y])


def style_axes(ax, xlabel=None, ylabel=None, grid=True, ylim=None, legend_loc="best"):
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if grid:
        ax.grid(True, alpha=0.35)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend(loc=legend_loc)
    plt.tight_layout()
    return ax
```
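As a quick check of the padding behavior: `moving_average` (copied verbatim here so the snippet is self-contained) keeps the output the same length as the input by left-padding with the first smoothed value, so the plotted curves stay aligned with the episode axis.

```python
import numpy as np

def moving_average(x, w=5):
    # same cumulative-sum implementation as in plot_utils.py
    if w <= 1:
        return np.asarray(x)
    x = np.asarray(x, dtype=float)
    if x.size < w:
        return x.copy()
    c = np.cumsum(np.insert(x, 0, 0.0))
    y = (c[w:] - c[:-w]) / float(w)
    pad = np.full(w - 1, y[0])  # left-pad to the original length
    return np.concatenate([pad, y])

print(moving_average([1, 2, 3, 4, 5, 6], w=3))  # → [2. 2. 2. 3. 4. 5.]
```

The first `w - 1` entries repeat the earliest window mean rather than shrinking the window, a simple choice that avoids startup artifacts in the learning-curve plot.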
