Skip to content

Commit 47f6fc7

Browse files
Add Chapter 11 (Policy Gradient REINFORCE) with Python 3.9-compatible typing and CI
1 parent 00778b3 commit 47f6fc7

9 files changed

Lines changed: 71 additions & 29 deletions

File tree

.github/workflows/ch11.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: ch11
1+
name: ch11 — Policy Gradient (REINFORCE)
22
on:
33
push:
44
paths: ['ch11_policy_gradient/**', '.github/workflows/ch11.yml']
@@ -8,7 +8,8 @@ jobs:
88
test:
99
runs-on: ubuntu-latest
1010
strategy:
11-
matrix: { python-version: ['3.9','3.10','3.11'] }
11+
matrix:
12+
python-version: ['3.8','3.9','3.10','3.11']
1213
steps:
1314
- uses: actions/checkout@v4
1415
- uses: actions/setup-python@v5
@@ -18,5 +19,6 @@ jobs:
1819
- run: |
1920
python -m pip install -U pip
2021
pip install -r ch11_policy_gradient/requirements.txt
21-
- env: { PYTHONPATH: . }
22+
- env:
23+
PYTHONPATH: .
2224
run: pytest -q ch11_policy_gradient/tests

ch11_policy_gradient/README_ch11.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Chapter 11 — Policy Gradient Fundamentals (REINFORCE)
2+
23
Quickstart:
34
```bash
45
pip install -r ch11_policy_gradient/requirements.txt

ch11_policy_gradient/agents/reinforce.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,39 @@ class Reinforce:
1515
alpha: float = 0.05
1616
normalize_adv: bool = True
1717
baseline_fn: Optional[Callable[[object], float]] = None
18-
seed: int | None = None
18+
seed: Optional[int] = None
19+
1920
def __post_init__(self):
    """Create the per-instance RNG from the configured seed."""
    self.rng = np.random.default_rng(self.seed)
22+
2123
def run_episode_discrete(self, env, policy, feature_fn: Callable[[object], np.ndarray]):
    """Roll out one full episode with a discrete-action policy.

    At each step the state is featurized, an action is sampled, its
    log-probability recorded, and the environment advanced until done.
    Returns a Trajectory of (states, actions, rewards, log-probs).
    """
    states, actions, rewards, logps = [], [], [], []
    state = env.reset()
    done = False
    while not done:
        features = feature_fn(state)
        action = policy.sample(features)
        logp, _ = policy.logprob_and_grad(features, action)
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        logps.append(logp)
        state = next_state
    return Trajectory(states, actions, rewards, logps)
35+
2936
def update_discrete(self, traj: Trajectory, policy, feature_fn: Callable[[object], np.ndarray]):
    """One REINFORCE gradient-ascent step from a single trajectory.

    Computes returns-to-go, subtracts the optional state baseline,
    optionally standardizes the advantages, then applies the summed
    score-function gradient scaled by the learning rate alpha.
    Returns {"G": returns, "adv": advantages} for logging.
    """
    G = returns_to_go(traj.rewards, self.gamma)
    if self.baseline_fn is None:
        adv = G.copy()
    else:
        baseline = np.array([self.baseline_fn(s) for s in traj.states], dtype=float)
        adv = G - baseline
    # Standardizing a constant (or single-step) advantage vector would
    # zero the update, so only standardize when there is real variability.
    if self.normalize_adv and len(adv) >= 2 and np.std(adv) > 1e-8:
        adv = standardize(adv)
    total_grad = np.zeros_like(policy.theta)
    for state, action, weight in zip(traj.states, traj.actions, adv):
        _, grad = policy.logprob_and_grad(feature_fn(state), action)
        total_grad += weight * grad
    # Gradient ascent on expected return.
    policy.theta += self.alpha * total_grad
    return {"G": G, "adv": adv}

ch11_policy_gradient/envs/bandit.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import numpy as np
22
from dataclasses import dataclass
3+
from typing import Optional, Tuple
34

45
@dataclass
56
class TwoArmedBandit:
6-
q_star: tuple[float, float] = (1.0, 1.5)
7-
seed: int | None = None
7+
q_star: Tuple[float, float] = (1.0, 1.5)
8+
seed: Optional[int] = None
9+
810
def __post_init__(self):
    """Seed the RNG used to draw noisy rewards."""
    self.rng = np.random.default_rng(self.seed)
12+
1013
@property
def nA(self) -> int:
    """Number of arms (this bandit is fixed at two)."""
    return 2
16+
1217
def reset(self):
    """Return the constant feature vector x(s) = [1.0] (stateless bandit)."""
    return np.ones(1, dtype=float)
19+
1420
def step(self, a: int):
1521
assert a in (0,1)
1622
r = float(self.rng.normal(self.q_star[a], 1.0))

ch11_policy_gradient/examples/bandit_softmax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ def run(episodes=200, seed=0):
88
x = np.array([1.0], dtype=float)
99
policy = SoftmaxPolicy(nA=2, d=1, seed=seed)
1010
algo = Reinforce(gamma=1.0, alpha=0.05, normalize_adv=True, baseline_fn=None, seed=seed)
11+
1112
probs_hist = []
13+
1214
class EPEnv:
    """Episodic wrapper around the bandit: every step ends the episode."""

    def reset(self):
        # Constant feature vector captured from the enclosing scope.
        return x

    def step(self, a):
        _, reward, _, _ = env.step(a)
        # Force a one-step episode regardless of the wrapped env's flag.
        return None, reward, True, {}
19+
1720
for _ in range(episodes):
1821
traj = algo.run_episode_discrete(EPEnv(), policy, lambda s: s)
1922
algo.update_discrete(traj, policy, lambda s: s)

ch11_policy_gradient/policies/gaussian.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
import numpy as np
22
from dataclasses import dataclass
3+
from typing import Optional
34

45
@dataclass
56
class GaussianPolicy1D:
67
mu: float = 0.0
78
log_sigma: float = 0.0
8-
seed: int | None = None
9+
seed: Optional[int] = None
10+
911
def __post_init__(self):
    """Create the sampling RNG from the configured seed."""
    self.rng = np.random.default_rng(self.seed)
13+
1114
@property
def sigma(self) -> float:
    """Standard deviation; parameterized via log_sigma so it stays positive."""
    return float(np.exp(self.log_sigma))
17+
1418
def sample(self, _x=None) -> float:
    """Draw an action from N(mu, sigma); the feature argument is ignored."""
    return float(self.rng.normal(self.mu, self.sigma))
20+
1621
def logprob_and_grad(self, a: float, _x=None):
1722
sigma2 = self.sigma ** 2
1823
logp = -0.5 * ((a - self.mu) ** 2 / sigma2 + np.log(2*np.pi*sigma2))
Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
import numpy as np
22
from dataclasses import dataclass
3+
from typing import Optional
34

45
@dataclass
class SoftmaxPolicy:
    """Linear softmax policy: preference h(a) = theta[a] . x per action."""

    nA: int
    d: int
    theta: Optional[np.ndarray] = None
    seed: Optional[int] = None

    def __post_init__(self):
        # Zero-initialized parameters give a uniform initial policy.
        if self.theta is None:
            self.theta = np.zeros((self.nA, self.d), dtype=float)
        self.rng = np.random.default_rng(self.seed)

    def prefs(self, x: np.ndarray) -> np.ndarray:
        """Action-preference vector h = theta @ x."""
        return self.theta @ x

    def probs(self, x: np.ndarray) -> np.ndarray:
        """Softmax over preferences, shifted by the max for numerical stability."""
        h = self.prefs(x)
        exps = np.exp(h - h.max())
        return exps / exps.sum()

    def sample(self, x: np.ndarray) -> int:
        """Draw an action index from the softmax distribution."""
        return int(self.rng.choice(self.nA, p=self.probs(x)))

    def logprob_and_grad(self, x: np.ndarray, a: int):
        """Log-probability of action a and its gradient w.r.t. theta.

        Score function for linear softmax: grad[k] = (1{k==a} - p_k) * x.
        The 1e-12 floor keeps the log finite when p[a] underflows.
        """
        p = self.probs(x)
        logp = float(np.log(p[a] + 1e-12))
        grad = -np.outer(p, x)
        grad[a] += x
        return logp, grad

ch11_policy_gradient/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Ensure repo root on sys.path so `import ch11_policy_gradient` works from any cwd
1+
# Ensure repo root is on sys.path so `import ch11_policy_gradient` works
22
import os, sys
33
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
44
if ROOT not in sys.path:
Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import numpy as np
2+
23
def returns_to_go(rewards, gamma: float) -> np.ndarray:
    """Discounted return G_t = sum_{k>=t} gamma^(k-t) * r_k for every t.

    Accumulates backwards so each step costs O(1); empty input yields an
    empty array.
    """
    n = len(rewards)
    out = np.zeros(n, dtype=float)
    acc = 0.0
    for t in range(n - 1, -1, -1):
        acc = rewards[t] + gamma * acc
        out[t] = acc
    return out
10+
711
def standardize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Shift and scale x to zero mean and (approximately) unit std.

    eps keeps the division finite when x is constant (std == 0).
    """
    return (x - x.mean()) / (x.std() + eps)

0 commit comments

Comments
 (0)