Skip to content

Commit 7fc313b

Browse files
ch5: align ACTIONS with env.P keys (0..3); robust ES returns; ε-soft dict policy keyed by action index
1 parent 90db243 commit 7fc313b

File tree

2 files changed

+39
-56
lines changed

2 files changed

+39
-56
lines changed
Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,58 @@
11
# ch5_monte_carlo/examples/mc_control_es_gridworld.py
22
# Monte Carlo control with Exploring Starts (ES) on a 4x4 GridWorld.
3-
# Robust to different GridWorld implementations: does not rely on env.P or env.is_terminal.
3+
# Robust: no reliance on env.P shape nor env.is_terminal presence.
44

55
from __future__ import annotations
66
import numpy as np
77

88
__all__ = ["mc_es_control", "generate_episode_es", "ACTIONS"]
99

10-
# Must match the environment's action ordering everywhere in the repo
11-
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # Right, Left, Down, Up
10+
# Tests expect ACTIONS to be action *indices* usable as env.P[s_idx][a] keys.
11+
ACTIONS = [0, 1, 2, 3] # exported for tests
12+
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U (internal geometry)
1213

13-
# ------------ helpers that do not assume specific env attributes -------------
14-
15-
def _goal(env):
16-
return getattr(env, "goal", (0, 3))
17-
18-
def _n(env):
19-
# prefer env.n; otherwise infer from |S|
20-
return getattr(env, "n", int(round(len(env.S) ** 0.5)))
21-
22-
def _step_reward(env):
23-
return float(getattr(env, "step_reward", -1.0))
14+
def _goal(env): return getattr(env, "goal", (0, 3))
15+
def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5)))
16+
def _step_reward(env): return float(getattr(env, "step_reward", -1.0))
2417

2518
def _is_terminal(env, s) -> bool:
2619
if hasattr(env, "is_terminal"):
2720
return bool(env.is_terminal(s))
2821
st = s if isinstance(s, tuple) else env.i2s[int(s)]
2922
return st == _goal(env)
3023

31-
def _step(env, s, a):
32-
"""Robust step that uses env.step if available; else uses grid geometry."""
24+
def _step(env, s, a_idx: int):
25+
"""Use env.step if present; else geometric fallback using DIRECTIONS."""
3326
if hasattr(env, "step"):
34-
return env.step(s, a)
27+
return env.step(s, a_idx)
3528
st = s if isinstance(s, tuple) else env.i2s[int(s)]
3629
i, j = st
37-
di, dj = ACTIONS[a]
30+
di, dj = DIRECTIONS[a_idx]
3831
n = _n(env)
3932
ni, nj = i + di, j + dj
4033
if not (0 <= ni < n and 0 <= nj < n):
41-
ni, nj = i, j # wall -> stay
34+
ni, nj = i, j
4235
sp = (ni, nj)
4336
r = 0.0 if sp == _goal(env) else _step_reward(env)
4437
return sp, r
4538

4639
def _greedy_action(q_row: np.ndarray) -> int:
4740
return int(np.argmax(q_row))
4841

49-
# ------------------------------- core logic ----------------------------------
50-
5142
def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_000):
5243
"""
53-
Exploring starts:
54-
- start from a random NON-terminal state
55-
- start with a random action
56-
- thereafter follow greedy policy w.r.t. Q
57-
Returns:
58-
states: list of states (tuples), length T = number of actions
59-
actions: list of action indices, length T
60-
returns: list/array of returns G_t, length T
44+
Exploring starts: start from random non-terminal state & random action,
45+
then follow greedy policy w.r.t. Q.
46+
Returns aligned (states, actions, returns) of length T = #actions.
6147
"""
6248
rng = np.random.default_rng()
6349
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
6450
s = non_terminal[rng.integers(len(non_terminal))]
65-
a = int(rng.integers(len(env.A)))
51+
a = int(rng.integers(len(env.A))) # int action index
6652

6753
states = [s]
6854
actions = [a]
69-
rewards = [0.0] # align indexing so rewards[t+1] corresponds to action at t
55+
rewards = [0.0] # rewards[t+1] corresponds to action at t
7056

7157
steps = 0
7258
while not _is_terminal(env, s) and steps < max_steps:
@@ -80,23 +66,17 @@ def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_00
8066
actions.append(a)
8167
steps += 1
8268

83-
# ---- returns over number of actions (T) ----
69+
# Compute returns over T = len(actions); guard rewards indexing just in case.
8470
T = len(actions)
8571
G = 0.0
8672
returns = np.zeros(T, dtype=float)
8773
for t in range(T - 1, -1, -1):
88-
G = rewards[t + 1] + gamma * G
74+
r_tp1 = rewards[t + 1] if (t + 1) < len(rewards) else 0.0
75+
G = r_tp1 + gamma * G
8976
returns[t] = G
90-
# keep states length consistent with actions/returns
9177
return states[:T], actions, returns
9278

9379
def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: int | None = None):
94-
"""
95-
On-policy Monte Carlo control with Exploring Starts (ES).
96-
Returns:
97-
Q: (S,A) action-value table
98-
pi: (S,A) deterministic greedy policy derived from Q
99-
"""
10080
if seed is not None:
10181
np.random.seed(seed)
10282
if gamma is None:
@@ -113,14 +93,13 @@ def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: i
11393
s_idx = env.s2i[s]
11494
key = (s_idx, a)
11595
if key in seen:
116-
continue # first-visit MC
96+
continue
11797
seen.add(key)
11898
G = returns[t]
11999
N[s_idx, a] += 1.0
120-
alpha = 1.0 / N[s_idx, a]
121-
Q[s_idx, a] += alpha * (G - Q[s_idx, a])
100+
Q[s_idx, a] += (G - Q[s_idx, a]) / N[s_idx, a]
122101

123-
# greedy deterministic policy from Q
102+
# deterministic greedy policy over action indices
124103
pi = np.zeros((S, A), dtype=float)
125104
pi[np.arange(S), np.argmax(Q, axis=1)] = 1.0
126105
return Q, pi

ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
2-
# On-policy MC control with ε-greedy behavior/target; returns an ε-soft dict policy.
2+
# On-policy MC control with ε-greedy behavior/target.
3+
# Returns an ε-soft dict policy keyed by (state_tuple, action_index).
34

45
from __future__ import annotations
56
import numpy as np
67

78
__all__ = ["mc_control_onpolicy", "ACTIONS", "generate_episode_onpolicy"]
89

9-
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # Right, Left, Down, Up
10+
# Tests iterate over ACTIONS and then index env.P[s_idx][a],
11+
# so ACTIONS must be action indices (0..3), not direction vectors.
12+
ACTIONS = [0, 1, 2, 3] # exported
13+
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # internal geometry
1014

1115
def _goal(env): return getattr(env, "goal", (0, 3))
1216
def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5)))
@@ -18,12 +22,12 @@ def _is_terminal(env, s) -> bool:
1822
st = s if isinstance(s, tuple) else env.i2s[int(s)]
1923
return st == _goal(env)
2024

21-
def _step(env, s, a):
25+
def _step(env, s, a_idx: int):
2226
if hasattr(env, "step"):
23-
return env.step(s, a)
27+
return env.step(s, a_idx)
2428
st = s if isinstance(s, tuple) else env.i2s[int(s)]
2529
i, j = st
26-
di, dj = ACTIONS[a]
30+
di, dj = DIRECTIONS[a_idx]
2731
n = _n(env)
2832
ni, nj = i + di, j + dj
2933
if not (0 <= ni < n and 0 <= nj < n):
@@ -44,7 +48,7 @@ def generate_episode_onpolicy(env, Q: np.ndarray, epsilon: float,
4448
steps = 0
4549
while not _is_terminal(env, s) and steps < max_steps:
4650
a = _epsilon_greedy(Q[env.s2i[s]], epsilon, rng)
47-
actions.append(a)
51+
actions.append(a) # action index
4852
sp, r = _step(env, s, a)
4953
rewards.append(float(r))
5054
s = sp
@@ -56,7 +60,8 @@ def generate_episode_onpolicy(env, Q: np.ndarray, epsilon: float,
5660
G = 0.0
5761
returns = np.zeros(T, dtype=float)
5862
for t in range(T - 1, -1, -1):
59-
G = rewards[t + 1] + gamma * G
63+
r_tp1 = rewards[t + 1] if (t + 1) < len(rewards) else 0.0
64+
G = r_tp1 + gamma * G
6065
returns[t] = G
6166
return states[:T], actions, returns
6267

@@ -66,8 +71,7 @@ def mc_control_onpolicy(env, episodes: int = 5000,
6671
"""
6772
Returns:
6873
Q: (S,A)
69-
pi_soft: dict mapping (state_tuple, action_tuple) -> probability
70-
(ε-soft, so tests can do pi_soft[(s, a_tup)])
74+
pi_soft: dict mapping (state_tuple, action_index) -> probability
7175
"""
7276
rng = np.random.default_rng(seed)
7377
S, A = len(env.S), len(env.A)
@@ -90,13 +94,13 @@ def mc_control_onpolicy(env, episodes: int = 5000,
9094
N[s_idx, a] += 1.0
9195
Q[s_idx, a] += (G - Q[s_idx, a]) / N[s_idx, a]
9296

93-
# Build ε-soft dict policy keyed by (state_tuple, action_tuple)
97+
# Build ε-soft dict policy keyed by (state_tuple, action_index)
9498
pi_soft = {}
9599
for s_idx, s in enumerate(env.S):
96100
a_star = int(np.argmax(Q[s_idx]))
97-
for a_idx, a_tup in enumerate(ACTIONS):
101+
for a_idx in ACTIONS:
98102
prob = (1.0 - epsilon) if a_idx == a_star else 0.0
99103
prob += epsilon / A
100-
pi_soft[(s, a_tup)] = prob
104+
pi_soft[(s, a_idx)] = prob
101105

102106
return Q, pi_soft

0 commit comments

Comments
 (0)