Skip to content

Commit aed9ef4

Browse files
Fix Chapter 5: GridWorld with P and goal reward, on-policy MC returns dict with (s,a) keys, ES returns dict with state->action
1 parent 582155e commit aed9ef4

6 files changed

Lines changed: 219 additions & 227 deletions

File tree

ch5_monte_carlo/examples/__init__.py

Whitespace-only changes.
Lines changed: 56 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,129 +1,69 @@
1-
# ch5_monte_carlo/examples/mc_control_es_gridworld.py
2-
# Monte Carlo control with Exploring Starts (ES) on a 4x4 GridWorld.
3-
# Robust: no reliance on original env.P shape; we normalize env.P to [(p, sp_idx, r)] triples.
4-
5-
from __future__ import annotations
61
import numpy as np
72

8-
__all__ = ["mc_es_control", "generate_episode_es", "ACTIONS"]
9-
10-
# Tests expect actions as integer indices (for env.P[s_idx][a] lookup)
11-
ACTIONS = [0, 1, 2, 3] # exported for tests
12-
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U (internal geometry)
13-
14-
# ---------------- utilities ----------------
15-
16-
def _goal(env): return getattr(env, "goal", (0, 3))
17-
def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5)))
18-
def _sr(env): return float(getattr(env, "step_reward", -1.0))
19-
20-
def _is_terminal(env, s) -> bool:
21-
if hasattr(env, "is_terminal"):
22-
return bool(env.is_terminal(s))
23-
st = s if isinstance(s, tuple) else env.i2s[int(s)]
24-
return st == _goal(env)
25-
26-
def _step_geom(env, s, a_idx: int):
27-
"""Deterministic geometry step; reward = +1 on entering goal, else step_reward."""
28-
st = s if isinstance(s, tuple) else env.i2s[int(s)]
29-
i, j = st
30-
di, dj = DIRECTIONS[a_idx]
31-
n = _n(env)
32-
ni, nj = i + di, j + dj
33-
if not (0 <= ni < n and 0 <= nj < n):
34-
ni, nj = i, j
35-
sp = (ni, nj)
36-
r = 1.0 if sp == _goal(env) else _sr(env) # KEY: +1 on entering goal
37-
return sp, r
38-
39-
def _step(env, s, a_idx: int):
40-
if hasattr(env, "step"):
41-
return env.step(s, a_idx)
42-
return _step_geom(env, s, a_idx)
43-
44-
def _greedy_action(q_row: np.ndarray) -> int:
45-
return int(np.argmax(q_row))
46-
47-
def _ensure_triple_envP(env):
48-
"""
49-
Normalize env.P to list-of-lists of lists of triples:
50-
env.P[s_idx][a_idx] == [ (1.0, sp_idx, r) ]
51-
Deterministic transitions built via geometry.
52-
"""
53-
S, A = len(env.S), len(env.A)
54-
P_list = [[None for _ in range(A)] for _ in range(S)]
55-
for s_idx, s in enumerate(env.S):
56-
for a_idx in range(A):
57-
sp, r = _step_geom(env, s, a_idx)
58-
sp_idx = env.s2i[sp]
59-
P_list[s_idx][a_idx] = [(1.0, sp_idx, float(r))]
60-
env.P = P_list # in-place normalization
3+
def _greedy_policy_from_Q(Q):
4+
return np.argmax(Q, axis=1) # ndarray (nS,)
615

62-
# ---------------- core ES logic ----------------
6+
def _generate_episode_es(env, Q, gamma, max_steps=500):
7+
nS, nA = len(env.S), len(env.A)
638

64-
def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_000):
65-
"""
66-
Exploring starts: start random non-terminal state & random action,
67-
then follow greedy policy w.r.t. Q.
68-
Returns (states, actions, returns) aligned to T = number of actions.
69-
"""
70-
rng = np.random.default_rng()
71-
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
72-
s = non_terminal[rng.integers(len(non_terminal))]
73-
a = int(rng.integers(len(env.A))) # action index
74-
75-
states = [s]
76-
actions = [a]
77-
rewards = [0.0] # rewards[t+1] corresponds to action at t
78-
79-
steps = 0
80-
while not _is_terminal(env, s) and steps < max_steps:
81-
sp, r = _step(env, s, a)
82-
rewards.append(float(r))
9+
# random non-terminal start
10+
while True:
11+
s0 = env.S[np.random.randint(nS)]
12+
if s0 != env.goal:
13+
break
14+
env._state = s0 # simple env
15+
16+
a0 = np.random.randint(nA) # exploring start
17+
traj = []
18+
s = s0
19+
20+
sp, r, done = env.step(a0)
21+
traj.append((s, a0, r))
22+
s = sp
23+
if done:
24+
return traj
25+
26+
greedy = _greedy_policy_from_Q(Q)
27+
for _ in range(max_steps - 1):
28+
a = greedy[env.s2i[s]]
29+
sp, r, done = env.step(a)
30+
traj.append((s, a, r))
8331
s = sp
84-
if _is_terminal(env, s):
32+
if done:
8533
break
86-
a = _greedy_action(Q[env.s2i[s]])
87-
states.append(s)
88-
actions.append(a)
89-
steps += 1
34+
return traj
9035

91-
# returns over number of actions
92-
T = len(actions)
93-
G = 0.0
94-
returns = np.zeros(T, dtype=float)
95-
for t in range(T - 1, -1, -1):
96-
r_tp1 = rewards[t + 1] if (t + 1) < len(rewards) else 0.0
97-
G = r_tp1 + gamma * G
98-
returns[t] = G
99-
return states[:T], actions, returns
100-
101-
def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: int | None = None):
36+
def mc_es_control(env, episodes=1500, gamma=0.9, seed=None):
    """
    Monte Carlo Control with Exploring Starts (MC-ES).

    Samples `episodes` exploring-start episodes (random non-terminal start
    state, random first action, greedy w.r.t. Q afterwards) and applies
    first-visit Monte Carlo updates to the action-value estimates.

    Args:
        env: GridWorld-like environment exposing S, A, s2i, i2s, goal, step().
        episodes: number of episodes to sample.
        gamma: discount factor.
        seed: optional seed for numpy's global RNG (reproducibility).

    Returns:
        Q: ndarray (nS, nA) of action-value estimates.
        pi_dict: dict mapping state tuple -> greedy action index (int).
    """
    if seed is not None:
        np.random.seed(seed)

    nS, nA = len(env.S), len(env.A)
    Q = np.zeros((nS, nA), dtype=float)
    returns_sum = np.zeros_like(Q)
    returns_cnt = np.zeros_like(Q)

    for _ in range(episodes):
        episode = _generate_episode_es(env, Q, gamma)

        # Backward pass: accumulate the discounted return and OVERWRITE the
        # per-(s, a) entry at every visit, so after the loop each key holds
        # the return from its FIRST (earliest) occurrence in the episode.
        # (Guarding the update with a "visited" set inside a reversed loop
        # would keep the LAST occurrence's return instead — a subtle
        # first-visit MC bug.)
        G = 0.0
        first_returns = {}
        for t in reversed(range(len(episode))):
            s_t, a_t, r_t = episode[t]
            G = r_t + gamma * G
            first_returns[(env.s2i[s_t], a_t)] = G

        # Incremental-mean update of Q from each first-visit return.
        for (s_idx, a_idx), ret in first_returns.items():
            returns_sum[s_idx, a_idx] += ret
            returns_cnt[s_idx, a_idx] += 1.0
            Q[s_idx, a_idx] = returns_sum[s_idx, a_idx] / returns_cnt[s_idx, a_idx]

    greedy = _greedy_policy_from_Q(Q)  # (nS,)
    # Policy as a dict keyed by state tuple (tests do: a = pi[s]).
    pi_dict = {env.i2s[s_idx]: int(greedy[s_idx]) for s_idx in range(nS)}
    return Q, pi_dict
Lines changed: 57 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,120 +1,68 @@
1-
# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
2-
# On-policy MC control with ε-greedy behavior/target.
3-
# Normalizes env.P to [(p, sp_idx, r)] triples; returns ε-soft dict policy keyed by (state_tuple, action_index).
4-
5-
from __future__ import annotations
1+
# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
62
import numpy as np
73

8-
__all__ = ["mc_control_onpolicy", "ACTIONS", "generate_episode_onpolicy"]
9-
10-
ACTIONS = [0, 1, 2, 3] # action indices (tests index env.P[s][a])
11-
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U (geometry)
12-
13-
def _goal(env): return getattr(env, "goal", (0, 3))
14-
def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5)))
15-
def _sr(env): return float(getattr(env, "step_reward", -1.0))
16-
17-
def _is_terminal(env, s) -> bool:
18-
if hasattr(env, "is_terminal"):
19-
return bool(env.is_terminal(s))
20-
st = s if isinstance(s, tuple) else env.i2s[int(s)]
21-
return st == _goal(env)
22-
23-
def _step_geom(env, s, a_idx: int):
24-
st = s if isinstance(s, tuple) else env.i2s[int(s)]
25-
i, j = st
26-
di, dj = DIRECTIONS[a_idx]
27-
n = _n(env)
28-
ni, nj = i + di, j + dj
29-
if not (0 <= ni < n and 0 <= nj < n):
30-
ni, nj = i, j
31-
sp = (ni, nj)
32-
r = 1.0 if sp == _goal(env) else _sr(env) # KEY: +1 on entering goal
33-
return sp, r
34-
35-
def _step(env, s, a_idx: int):
36-
if hasattr(env, "step"):
37-
return env.step(s, a_idx)
38-
return _step_geom(env, s, a_idx)
39-
40-
def _epsilon_greedy(q_row: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
41-
return int(rng.integers(len(q_row))) if rng.random() < epsilon else int(np.argmax(q_row))
4+
# Tests import ACTIONS from here and use it as a list of action *indices*.
5+
ACTIONS = [0, 1, 2, 3]
426

43-
def _ensure_triple_envP(env):
44-
S, A = len(env.S), len(env.A)
45-
P_list = [[None for _ in range(A)] for _ in range(S)]
46-
for s_idx, s in enumerate(env.S):
47-
for a_idx in range(A):
48-
sp, r = _step_geom(env, s, a_idx)
49-
sp_idx = env.s2i[sp]
50-
P_list[s_idx][a_idx] = [(1.0, sp_idx, float(r))]
51-
env.P = P_list
7+
def _epsilon_soft_from_Q(Q, epsilon):
8+
nS, nA = Q.shape
9+
pi = np.full((nS, nA), epsilon / nA, dtype=float)
10+
best = Q.argmax(axis=1)
11+
pi[np.arange(nS), best] += 1.0 - epsilon
12+
return pi # ndarray (nS, nA)
5213

53-
def generate_episode_onpolicy(env, Q: np.ndarray, epsilon: float,
54-
rng: np.random.Generator, max_steps: int = 10_000):
55-
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
56-
s = non_terminal[rng.integers(len(non_terminal))]
57-
58-
states, actions, rewards = [s], [], [0.0]
59-
steps = 0
60-
while not _is_terminal(env, s) and steps < max_steps:
61-
a = _epsilon_greedy(Q[env.s2i[s]], epsilon, rng)
62-
actions.append(a)
63-
sp, r = _step(env, s, a)
64-
rewards.append(float(r))
65-
s = sp
66-
states.append(s)
67-
steps += 1
68-
69-
gamma = float(getattr(env, "gamma", 1.0))
70-
T = len(actions)
71-
G = 0.0
72-
returns = np.zeros(T, dtype=float)
73-
for t in range(T - 1, -1, -1):
74-
r_tp1 = rewards[t + 1] if (t + 1) < len(rewards) else 0.0
75-
G = r_tp1 + gamma * G
76-
returns[t] = G
77-
return states[:T], actions, returns
78-
79-
def mc_control_onpolicy(env, episodes: int = 5000,
80-
epsilon: float = 0.1, gamma: float | None = None,
81-
seed: int | None = None):
14+
def mc_control_onpolicy(env, episodes=2000, gamma=0.9, epsilon=0.1, seed=None, max_steps=500):
    """
    First-visit on-policy MC control with ε-soft policies.

    Args:
        env: environment exposing S, A, s2i, i2s, reset(), step().
        episodes: number of episodes to sample.
        gamma: discount factor.
        epsilon: exploration parameter of the ε-soft policy.
        seed: optional seed for numpy's global RNG.
        max_steps: per-episode safety cap against non-terminating rollouts.

    Returns:
        Q : ndarray (nS, nA)
        pi_soft_dict : dict keyed by (state_tuple, action_index) -> π(a|s)
    """
    if seed is not None:
        np.random.seed(seed)

    nS, nA = len(env.S), len(env.A)

    def _soft_policy(q):
        # ε-soft matrix: ε/nA everywhere, extra (1 - ε) on the greedy action.
        p = np.full((nS, nA), epsilon / nA, dtype=float)
        p[np.arange(nS), q.argmax(axis=1)] += 1.0 - epsilon
        return p

    Q = np.zeros((nS, nA), dtype=float)
    returns_sum = np.zeros_like(Q)
    returns_cnt = np.zeros_like(Q)
    pi = _soft_policy(Q)

    for _ in range(episodes):
        # Generate one episode under the current ε-soft policy.
        s = env.reset()
        traj = []
        for _ in range(max_steps):
            s_idx = env.s2i[s]
            a = np.random.choice(nA, p=pi[s_idx])
            sp, r, done = env.step(a)
            traj.append((s, a, r))
            s = sp
            if done:
                break

        # First-visit MC updates: sweep backwards accumulating G and
        # OVERWRITE the per-(s, a) entry at every visit, so each key ends up
        # holding the return from its EARLIEST visit. (A "visited" set
        # checked inside the reversed loop would keep the LAST visit's
        # return instead — a subtle first-visit bug.)
        G = 0.0
        first_returns = {}
        for t in reversed(range(len(traj))):
            s_t, a_t, r_t = traj[t]
            G = r_t + gamma * G
            first_returns[(env.s2i[s_t], a_t)] = G
        for (s_idx, a_idx), ret in first_returns.items():
            returns_sum[s_idx, a_idx] += ret
            returns_cnt[s_idx, a_idx] += 1.0
            Q[s_idx, a_idx] = returns_sum[s_idx, a_idx] / returns_cnt[s_idx, a_idx]

        # Policy improvement (keep it ε-soft).
        pi = _soft_policy(Q)

    # Convert the ε-soft matrix to a dict keyed by (state_tuple, action_index).
    pi_soft_dict = {}
    for s_idx in range(nS):
        s_tuple = env.i2s[s_idx]
        for a_idx in range(nA):
            pi_soft_dict[(s_tuple, a_idx)] = float(pi[s_idx, a_idx])

    return Q, pi_soft_dict
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
3+
def greedy_from_soft(pi_soft):
    """Greedy action index per state from a soft policy matrix (nS, nA)."""
    return np.asarray(pi_soft).argmax(axis=1)
5+
6+
def rollout_greedy_from_soft(env, pi_soft, max_steps=200):
    """Follow the greedy arm of a soft policy matrix from env.reset().

    Returns (reached_terminal, steps_taken); gives up after max_steps.
    """
    greedy_actions = np.argmax(pi_soft, axis=1)
    state = env.reset()
    taken = 0
    while taken < max_steps:
        next_state, _, done = env.step(greedy_actions[env.s2i[state]])
        taken += 1
        state = next_state
        if done:
            return True, taken
    return False, taken
18+
19+
def rollout_greedy_es(env, Q, max_steps=200):
    """Act greedily w.r.t. action values Q from env.reset().

    Returns (reached_terminal, steps_taken); gives up after max_steps.
    """
    state = env.reset()
    taken = 0
    while taken < max_steps:
        action = np.argmax(Q[env.s2i[state]])
        next_state, _, done = env.step(action)
        taken += 1
        state = next_state
        if done:
            return True, taken
    return False, taken

0 commit comments

Comments
 (0)