Skip to content

Commit aec2644

Browse files
Add Chapter 4 Dynamic Programming code (policy evaluation, policy iteration, value iteration)
1 parent 32bcd92 commit aec2644

File tree

11 files changed

+396
-0
lines changed

11 files changed

+396
-0
lines changed

ch4_dp/.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
.venv/
2+
__pycache__/
3+
*.pyc
4+
artifacts/latex/*.tex

ch4_dp/Makefile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
ART4=artifacts/ch4_4x4
2+
ART6=artifacts/ch4_6x6
3+
4+
.PHONY: ch4-artifacts ch4-tables
5+
ch4-artifacts:
6+
python examples/generate_artifacts.py --env 4x4 --outdir $(ART4)
7+
python examples/generate_artifacts.py --env 6x6 --outdir $(ART6)
8+
9+
ch4-tables: ch4-artifacts
10+
python examples/csv_to_latex.py $(ART4) --outdir artifacts/latex --no-wrap --round 0

ch4_dp/README.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Chapter 4 – Dynamic Programming (DP) Code
2+
3+
This repo contains clean, runnable reference code to accompany **Chapter 4: Dynamic Programming Approaches**.
4+
5+
## Contents
6+
7+
- `src/rldp/dp.py` – Policy Evaluation, Policy Iteration, Value Iteration
8+
- `src/rldp/gridworld.py` – Simple deterministic GridWorld (4×4, 6×6)
9+
- `src/rldp/latex.py` – CSV → LaTeX table helper (booktabs-ready)
10+
- `examples/generate_artifacts.py` – Reproduces tables/plots for the chapter
11+
- `examples/csv_to_latex.py` – Convert CSV matrices to LaTeX tables
12+
- `Makefile` – Convenience targets
13+
- `requirements.txt` – Python deps
14+
15+
## Quickstart
16+
17+
```bash
18+
python -m venv .venv && source .venv/bin/activate # (Windows: .venv\Scripts\activate)
19+
pip install -r requirements.txt
20+
21+
# Generate artifacts for 4×4 and 6×6 worlds
22+
python examples/generate_artifacts.py --env 4x4 --outdir artifacts/ch4_4x4
23+
python examples/generate_artifacts.py --env 6x6 --outdir artifacts/ch4_6x6
24+
25+
# Convert CSV → LaTeX tabular (booktabs)
26+
python examples/csv_to_latex.py artifacts/ch4_4x4/vi_values_4x4_k2.csv --outdir artifacts/latex --caption "Value iteration estimates (k=2) on the $4\\times4$ gridworld." --label tab:vi-4x4-k2 --float-format ".0f"
27+
```
28+
29+
Then include in LaTeX:
30+
31+
```latex
32+
\usepackage{booktabs} % in preamble
33+
% ...
34+
\input{artifacts/latex/vi_values_4x4_k2.tex}
35+
```
36+
37+
## Make targets
38+
39+
```bash
40+
make ch4-artifacts # build default artifacts
41+
make ch4-tables # CSV → LaTeX for default directory
42+
```
43+
44+
## License
45+
46+
MIT for the code snippets here. Attribution appreciated.

ch4_dp/artifacts/.gitkeep

Whitespace-only changes.

ch4_dp/examples/csv_to_latex.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
import argparse, os, sys
3+
from pathlib import Path
4+
from rldp.latex import grid_csv_to_tabular
5+
6+
def convert_one(csv_path: str, out_dir: str, caption: str | None, label: str | None,
7+
colfmt: str | None, float_format: str | None, index: bool,
8+
wrap_table: bool, round_digits: int | None,
9+
transpose: bool, suffix: str) -> str:
10+
tex = grid_csv_to_tabular(csv_path, caption, label, colfmt, float_format,
11+
index, True, wrap_table, round_digits, transpose)
12+
name = Path(csv_path).stem + (suffix if suffix else "") + ".tex"
13+
os.makedirs(out_dir, exist_ok=True)
14+
out_path = str(Path(out_dir) / name)
15+
with open(out_path, "w", encoding="utf-8") as f:
16+
f.write(tex)
17+
return out_path
18+
19+
def main():
20+
p = argparse.ArgumentParser(description="Convert CSV grids (values/policies) to LaTeX tables.")
21+
p.add_argument("inputs", nargs="+", help="CSV files or a directory (will convert all *.csv).")
22+
p.add_argument("--outdir", default="artifacts/latex", help="Output directory for .tex tables.")
23+
p.add_argument("--caption", default=None, help="Caption to use (optional).")
24+
p.add_argument("--label", default=None, help="LaTeX label (e.g., tab:vi-iterations).")
25+
p.add_argument("--colfmt", default=None, help="LaTeX column format, e.g., 'cccc'.")
26+
p.add_argument("--float-format", default=None, help="Python format, e.g., '.0f' or '{:.2f}'.")
27+
p.add_argument("--index", action="store_true", help="Include DataFrame index.")
28+
p.add_argument("--no-wrap", action="store_true", help="Emit only tabular (no table environment).")
29+
p.add_argument("--round", type=int, default=None, help="Round all numbers to N decimals.")
30+
p.add_argument("--transpose", action="store_true", help="Transpose before rendering.")
31+
p.add_argument("--suffix", default="", help="Append to output filename stem (e.g., '_nice').")
32+
args = p.parse_args()
33+
34+
# Expand inputs: if a directory is given, take all CSVs in it
35+
files = []
36+
for item in args.inputs:
37+
pth = Path(item)
38+
if pth.is_dir():
39+
files.extend(str(p) for p in pth.glob("*.csv"))
40+
elif pth.suffix.lower() == ".csv":
41+
files.append(str(pth))
42+
else:
43+
print(f"Skipping non-CSV: {item}", file=sys.stderr)
44+
45+
if not files:
46+
print("No CSV files found.", file=sys.stderr)
47+
sys.exit(1)
48+
49+
created = []
50+
for csv in sorted(files):
51+
out = convert_one(
52+
csv_path=csv,
53+
out_dir=args.outdir,
54+
caption=args.caption,
55+
label=args.label,
56+
colfmt=args.colfmt,
57+
float_format=args.float_format,
58+
index=args.index,
59+
wrap_table=not args.no_wrap,
60+
round_digits=args.round,
61+
transpose=args.transpose,
62+
suffix=args.suffix,
63+
)
64+
created.append(out)
65+
print(f"Wrote: {out}")
66+
67+
if __name__ == "__main__":
68+
main()
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
import argparse, os
3+
import numpy as np
4+
import pandas as pd
5+
import matplotlib.pyplot as plt
6+
7+
from rldp.gridworld import make_gridworld, unravel_index, ACTIONS, arrows_from_policy
8+
from rldp.dp import policy_evaluation, policy_iteration, value_iteration
9+
10+
def save_grid_csv(V, n, out_csv):
11+
M = np.zeros((n, n))
12+
for s in range(n*n):
13+
i, j = unravel_index(s, n)
14+
M[i, j] = V[s]
15+
df = pd.DataFrame(M)
16+
os.makedirs(os.path.dirname(out_csv), exist_ok=True)
17+
df.to_csv(out_csv, index=False)
18+
19+
def save_policy_csv(pi, n, out_csv):
20+
arr = arrows_from_policy(pi).reshape(n, n)
21+
df = pd.DataFrame(arr)
22+
os.makedirs(os.path.dirname(out_csv), exist_ok=True)
23+
df.to_csv(out_csv, index=False)
24+
25+
def plot_values(V, n, out_png, title=None):
26+
M = np.zeros((n, n))
27+
for s in range(n*n):
28+
i, j = unravel_index(s, n)
29+
M[i, j] = V[s]
30+
fig = plt.figure()
31+
plt.imshow(M, interpolation='nearest')
32+
plt.colorbar()
33+
if title:
34+
plt.title(title)
35+
for i in range(n):
36+
for j in range(n):
37+
plt.text(j, i, f"{M[i,j]:.0f}", ha='center', va='center')
38+
os.makedirs(os.path.dirname(out_png), exist_ok=True)
39+
plt.savefig(out_png, bbox_inches='tight', dpi=160)
40+
plt.close(fig)
41+
42+
def main():
43+
ap = argparse.ArgumentParser()
44+
ap.add_argument('--env', default='4x4', choices=['4x4','6x6'])
45+
ap.add_argument('--gamma', type=float, default=1.0)
46+
ap.add_argument('--theta', type=float, default=1e-6)
47+
ap.add_argument('--outdir', default='artifacts/ch4_4x4')
48+
args = ap.parse_args()
49+
50+
n = 4 if args.env == '4x4' else 6
51+
states, actions, P, R, meta = make_gridworld(n=n)
52+
# Policy Iteration
53+
pi_pi, V_pi = policy_iteration(states, actions, P, R, gamma=args.gamma, theta=args.theta)
54+
# Value Iteration
55+
pi_vi, V_vi = value_iteration(states, actions, P, R, gamma=args.gamma, theta=args.theta)
56+
57+
os.makedirs(args.outdir, exist_ok=True)
58+
59+
# Save values (final)
60+
save_grid_csv(V_pi, n, os.path.join(args.outdir, f'pi_values_{args.env}.csv'))
61+
save_grid_csv(V_vi, n, os.path.join(args.outdir, f'vi_values_{args.env}.csv'))
62+
plot_values(V_pi, n, os.path.join(args.outdir, f'pi_values_{args.env}.png'), 'Policy Iteration Values')
63+
plot_values(V_vi, n, os.path.join(args.outdir, f'vi_values_{args.env}.png'), 'Value Iteration Values')
64+
65+
# Save policies
66+
save_policy_csv(pi_pi, n, os.path.join(args.outdir, f'pi_policy_{args.env}.csv'))
67+
save_policy_csv(pi_vi, n, os.path.join(args.outdir, f'vi_policy_{args.env}.csv'))
68+
69+
print('Artifacts written to:', args.outdir)
70+
71+
if __name__ == '__main__':
72+
main()

ch4_dp/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
numpy
2+
pandas
3+
matplotlib

ch4_dp/src/rldp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__all__ = ["dp", "gridworld", "latex"]

ch4_dp/src/rldp/dp.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
4+
def policy_evaluation(states, actions, P, R, pi, gamma: float = 1.0, theta: float = 1e-6):
5+
"""Iterative policy evaluation.
6+
states: list-like of states (indices 0..S-1)
7+
actions: list-like of actions (indices 0..A-1)
8+
P: shape [S, A, S] transition probabilities
9+
R: shape [S, A, S] expected rewards
10+
pi: shape [S, A] policy (row-stochastic)
11+
"""
12+
S = len(states)
13+
V = np.zeros(S, dtype=float)
14+
while True:
15+
delta = 0.0
16+
for s in range(S):
17+
v_old = V[s]
18+
V[s] = sum(
19+
pi[s, a] * sum(P[s, a, s2] * (R[s, a, s2] + gamma * V[s2]) for s2 in range(S))
20+
for a in range(len(actions))
21+
)
22+
delta = max(delta, abs(v_old - V[s]))
23+
if delta < theta:
24+
break
25+
return V
26+
27+
def policy_iteration(states, actions, P, R, gamma: float = 1.0, theta: float = 1e-6):
28+
"""Howard's policy iteration."""
29+
S, A = len(states), len(actions)
30+
pi = np.ones((S, A)) / A
31+
V = np.zeros(S, dtype=float)
32+
stable = False
33+
34+
while not stable:
35+
V = policy_evaluation(states, actions, P, R, pi, gamma, theta)
36+
stable = True
37+
for s in range(S):
38+
old_action = np.argmax(pi[s])
39+
q_values = [
40+
sum(P[s, a, s2] * (R[s, a, s2] + gamma * V[s2]) for s2 in range(S))
41+
for a in range(A)
42+
]
43+
best = int(np.argmax(q_values))
44+
pi[s] = np.eye(A)[best]
45+
if best != old_action:
46+
stable = False
47+
return pi, V
48+
49+
def value_iteration(states, actions, P, R, gamma: float = 1.0, theta: float = 1e-6):
50+
"""Bellman optimality updates until convergence."""
51+
S, A = len(states), len(actions)
52+
V = np.zeros(S, dtype=float)
53+
while True:
54+
delta = 0.0
55+
for s in range(S):
56+
v_old = V[s]
57+
q_values = [
58+
sum(P[s, a, s2] * (R[s, a, s2] + gamma * V[s2]) for s2 in range(S))
59+
for a in range(A)
60+
]
61+
V[s] = max(q_values)
62+
delta = max(delta, abs(v_old - V[s]))
63+
if delta < theta:
64+
break
65+
# Derive greedy policy
66+
pi = np.zeros((S, A))
67+
for s in range(S):
68+
q_values = [
69+
sum(P[s, a, s2] * (R[s, a, s2] + gamma * V[s2]) for s2 in range(S))
70+
for a in range(A)
71+
]
72+
best = int(np.argmax(q_values))
73+
pi[s] = np.eye(A)[best]
74+
return pi, V

ch4_dp/src/rldp/gridworld.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
4+
ACTIONS = ['U','R','D','L'] # up, right, down, left
5+
A_DELTA = {'U':(-1,0), 'R':(0,1), 'D':(1,0), 'L':(0,-1)}
6+
7+
def make_gridworld(n: int = 4, step_reward: float = -1.0, terminal: tuple[int,int] | None = None):
8+
"""Deterministic gridworld (n×n). Terminal default is (0, n-1)."""
9+
if terminal is None:
10+
terminal = (0, n-1)
11+
S = n*n
12+
A = len(ACTIONS)
13+
P = np.zeros((S, A, S), dtype=float)
14+
R = np.full((S, A, S), 0.0, dtype=float)
15+
16+
def idx(i,j): return i*n + j
17+
term_idx = idx(*terminal)
18+
19+
for i in range(n):
20+
for j in range(n):
21+
s = idx(i,j)
22+
for a_id, a in enumerate(ACTIONS):
23+
if s == term_idx:
24+
P[s, a_id, s] = 1.0
25+
R[s, a_id, s] = 0.0
26+
continue
27+
di, dj = A_DELTA[a]
28+
ni, nj = i+di, j+dj
29+
if ni < 0 or ni >= n or nj < 0 or nj >= n:
30+
ns = s # bump into wall
31+
else:
32+
ns = idx(ni, nj)
33+
P[s, a_id, ns] = 1.0
34+
R[s, a_id, ns] = step_reward if ns != term_idx else 0.0
35+
states = list(range(S))
36+
actions = list(range(A))
37+
return states, actions, P, R, (n, terminal, term_idx)
38+
39+
def unravel_index(s: int, n: int):
40+
return (s // n, s % n)
41+
42+
def arrows_from_policy(pi):
43+
"""Convert one-hot deterministic policy (S×A) to symbol grid of U/R/D/L."""
44+
idx = np.argmax(pi, axis=1)
45+
return np.array([['U','R','D','L'][k] for k in idx])

0 commit comments

Comments
 (0)