Skip to content

Commit 90db243

Browse files
ch5: ES returns over actions; ε-soft dict policy; geometry-based step
1 parent 9e53a01 commit 90db243

1 file changed

Lines changed: 9 additions & 22 deletions

File tree

ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
2-
# On-policy Monte Carlo control with ε-greedy behavior/target policy (no ES).
3-
# Returns an ε-soft policy as a dict keyed by (state_tuple, action_tuple), as required by tests.
2+
# On-policy MC control with ε-greedy behavior/target; returns an ε-soft dict policy.
43

54
from __future__ import annotations
65
import numpy as np
@@ -9,16 +8,9 @@
98

109
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # Right, Left, Down, Up
1110

12-
# ------------ robust helpers (no dependence on env.P or env.is_terminal) ------------
13-
14-
def _goal(env):
15-
return getattr(env, "goal", (0, 3))
16-
17-
def _n(env):
18-
return getattr(env, "n", int(round(len(env.S) ** 0.5)))
19-
20-
def _step_reward(env):
21-
return float(getattr(env, "step_reward", -1.0))
11+
def _goal(env): return getattr(env, "goal", (0, 3))
12+
def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5)))
13+
def _step_reward(env): return float(getattr(env, "step_reward", -1.0))
2214

2315
def _is_terminal(env, s) -> bool:
2416
if hasattr(env, "is_terminal"):
@@ -43,11 +35,8 @@ def _step(env, s, a):
4335
def _epsilon_greedy(q_row: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
4436
return int(rng.integers(len(q_row))) if rng.random() < epsilon else int(np.argmax(q_row))
4537

46-
# ---------------------------------- core --------------------------------------
47-
4838
def generate_episode_onpolicy(env, Q: np.ndarray, epsilon: float,
4939
rng: np.random.Generator, max_steps: int = 10_000):
50-
"""Start from a random non-terminal state; follow ε-greedy w.r.t. Q."""
5140
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
5241
s = non_terminal[rng.integers(len(non_terminal))]
5342

@@ -62,7 +51,6 @@ def generate_episode_onpolicy(env, Q: np.ndarray, epsilon: float,
6251
states.append(s)
6352
steps += 1
6453

65-
# returns over number of actions
6654
gamma = float(getattr(env, "gamma", 1.0))
6755
T = len(actions)
6856
G = 0.0
@@ -76,10 +64,10 @@ def mc_control_onpolicy(env, episodes: int = 5000,
7664
epsilon: float = 0.1, gamma: float | None = None,
7765
seed: int | None = None):
7866
"""
79-
On-policy MC control using ε-greedy behavior = target policy (ε-soft).
8067
Returns:
81-
Q: (S,A) action-value table
82-
pi_soft: dict mapping (state_tuple, action_tuple) -> probability
68+
Q: (S,A)
69+
pi_soft: dict mapping (state_tuple, action_tuple) -> probability
70+
(ε-soft, so tests can do pi_soft[(s, a_tup)])
8371
"""
8472
rng = np.random.default_rng(seed)
8573
S, A = len(env.S), len(env.A)
@@ -100,10 +88,9 @@ def mc_control_onpolicy(env, episodes: int = 5000,
10088
seen.add(key)
10189
G = returns[t]
10290
N[s_idx, a] += 1.0
103-
alpha = 1.0 / N[s_idx, a]
104-
Q[s_idx, a] += alpha * (G - Q[s_idx, a])
91+
Q[s_idx, a] += (G - Q[s_idx, a]) / N[s_idx, a]
10592

106-
# Build ε-soft policy as a dict keyed by (state_tuple, action_tuple)
93+
# Build ε-soft dict policy keyed by (state_tuple, action_tuple)
10794
pi_soft = {}
10895
for s_idx, s in enumerate(env.S):
10996
a_star = int(np.argmax(Q[s_idx]))

0 commit comments

Comments (0)