|
1 | | -# ch5_monte_carlo/examples/mc_control_es_gridworld.py |
2 | | -# Monte Carlo control with Exploring Starts (ES) on a 4x4 GridWorld. |
3 | | -# Robust: no reliance on original env.P shape; we normalize env.P to [(p, sp_idx, r)] triples. |
4 | | - |
5 | | -from __future__ import annotations |
6 | 1 | import numpy as np |
7 | 2 |
|
8 | | -__all__ = ["mc_es_control", "generate_episode_es", "ACTIONS"] |
9 | | - |
10 | | -# Tests expect actions as integer indices (for env.P[s_idx][a] lookup) |
11 | | -ACTIONS = [0, 1, 2, 3] # exported for tests |
12 | | -DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U (internal geometry) |
13 | | - |
14 | | -# ---------------- utilities ---------------- |
15 | | - |
16 | | -def _goal(env): return getattr(env, "goal", (0, 3)) |
17 | | -def _n(env): return getattr(env, "n", int(round(len(env.S) ** 0.5))) |
18 | | -def _sr(env): return float(getattr(env, "step_reward", -1.0)) |
19 | | - |
20 | | -def _is_terminal(env, s) -> bool: |
21 | | - if hasattr(env, "is_terminal"): |
22 | | - return bool(env.is_terminal(s)) |
23 | | - st = s if isinstance(s, tuple) else env.i2s[int(s)] |
24 | | - return st == _goal(env) |
25 | | - |
26 | | -def _step_geom(env, s, a_idx: int): |
27 | | - """Deterministic geometry step; reward = +1 on entering goal, else step_reward.""" |
28 | | - st = s if isinstance(s, tuple) else env.i2s[int(s)] |
29 | | - i, j = st |
30 | | - di, dj = DIRECTIONS[a_idx] |
31 | | - n = _n(env) |
32 | | - ni, nj = i + di, j + dj |
33 | | - if not (0 <= ni < n and 0 <= nj < n): |
34 | | - ni, nj = i, j |
35 | | - sp = (ni, nj) |
36 | | - r = 1.0 if sp == _goal(env) else _sr(env) # KEY: +1 on entering goal |
37 | | - return sp, r |
38 | | - |
39 | | -def _step(env, s, a_idx: int): |
40 | | - if hasattr(env, "step"): |
41 | | - return env.step(s, a_idx) |
42 | | - return _step_geom(env, s, a_idx) |
43 | | - |
44 | | -def _greedy_action(q_row: np.ndarray) -> int: |
45 | | - return int(np.argmax(q_row)) |
46 | | - |
47 | | -def _ensure_triple_envP(env): |
48 | | - """ |
49 | | - Normalize env.P to list-of-lists of lists of triples: |
50 | | - env.P[s_idx][a_idx] == [ (1.0, sp_idx, r) ] |
51 | | - Deterministic transitions built via geometry. |
52 | | - """ |
53 | | - S, A = len(env.S), len(env.A) |
54 | | - P_list = [[None for _ in range(A)] for _ in range(S)] |
55 | | - for s_idx, s in enumerate(env.S): |
56 | | - for a_idx in range(A): |
57 | | - sp, r = _step_geom(env, s, a_idx) |
58 | | - sp_idx = env.s2i[sp] |
59 | | - P_list[s_idx][a_idx] = [(1.0, sp_idx, float(r))] |
60 | | - env.P = P_list # in-place normalization |
| 3 | +def _greedy_policy_from_Q(Q): |
| 4 | + return np.argmax(Q, axis=1) # ndarray (nS,) |
61 | 5 |
|
62 | | -# ---------------- core ES logic ---------------- |
| 6 | +def _generate_episode_es(env, Q, gamma, max_steps=500): |
| 7 | + nS, nA = len(env.S), len(env.A) |
63 | 8 |
|
64 | | -def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_000): |
65 | | - """ |
66 | | - Exploring starts: start random non-terminal state & random action, |
67 | | - then follow greedy policy w.r.t. Q. |
68 | | - Returns (states, actions, returns) aligned to T = number of actions. |
69 | | - """ |
70 | | - rng = np.random.default_rng() |
71 | | - non_terminal = [s for s in env.S if not _is_terminal(env, s)] |
72 | | - s = non_terminal[rng.integers(len(non_terminal))] |
73 | | - a = int(rng.integers(len(env.A))) # action index |
74 | | - |
75 | | - states = [s] |
76 | | - actions = [a] |
77 | | - rewards = [0.0] # rewards[t+1] corresponds to action at t |
78 | | - |
79 | | - steps = 0 |
80 | | - while not _is_terminal(env, s) and steps < max_steps: |
81 | | - sp, r = _step(env, s, a) |
82 | | - rewards.append(float(r)) |
| 9 | + # random non-terminal start |
| 10 | + while True: |
| 11 | + s0 = env.S[np.random.randint(nS)] |
| 12 | + if s0 != env.goal: |
| 13 | + break |
| 14 | + env._state = s0 # simple env |
| 15 | + |
| 16 | + a0 = np.random.randint(nA) # exploring start |
| 17 | + traj = [] |
| 18 | + s = s0 |
| 19 | + |
| 20 | + sp, r, done = env.step(a0) |
| 21 | + traj.append((s, a0, r)) |
| 22 | + s = sp |
| 23 | + if done: |
| 24 | + return traj |
| 25 | + |
| 26 | + greedy = _greedy_policy_from_Q(Q) |
| 27 | + for _ in range(max_steps - 1): |
| 28 | + a = greedy[env.s2i[s]] |
| 29 | + sp, r, done = env.step(a) |
| 30 | + traj.append((s, a, r)) |
83 | 31 | s = sp |
84 | | - if _is_terminal(env, s): |
| 32 | + if done: |
85 | 33 | break |
86 | | - a = _greedy_action(Q[env.s2i[s]]) |
87 | | - states.append(s) |
88 | | - actions.append(a) |
89 | | - steps += 1 |
| 34 | + return traj |
90 | 35 |
|
91 | | - # returns over number of actions |
92 | | - T = len(actions) |
93 | | - G = 0.0 |
94 | | - returns = np.zeros(T, dtype=float) |
95 | | - for t in range(T - 1, -1, -1): |
96 | | - r_tp1 = rewards[t + 1] if (t + 1) < len(rewards) else 0.0 |
97 | | - G = r_tp1 + gamma * G |
98 | | - returns[t] = G |
99 | | - return states[:T], actions, returns |
100 | | - |
def mc_es_control(env, episodes=1500, gamma=0.9, seed=None):
    """
    Monte Carlo Control with Exploring Starts (MC-ES).

    Parameters
    ----------
    env : gridworld-like environment exposing S, A, s2i, i2s and step().
    episodes : int, number of exploring-starts episodes to sample.
    gamma : float, discount factor for accumulating returns.
    seed : optional int, seeds numpy's global RNG (episode starts are random).

    Returns
    -------
    (Q, pi_dict) where
      - Q is an (nS, nA) array of action values,
      - pi_dict[state_tuple] = greedy action index (int).
    """
    if seed is not None:
        np.random.seed(seed)

    nS, nA = len(env.S), len(env.A)
    Q = np.zeros((nS, nA), dtype=float)
    returns_sum = np.zeros_like(Q)
    returns_cnt = np.zeros_like(Q)

    for _ in range(episodes):
        episode = _generate_episode_es(env, Q, gamma)

        # Discounted return G_t for every step t of the episode.
        step_returns = np.empty(len(episode), dtype=float)
        G = 0.0
        for t in range(len(episode) - 1, -1, -1):
            G = episode[t][2] + gamma * G
            step_returns[t] = G

        # First-visit update: average G over the FIRST occurrence of each
        # (state, action) pair.  NOTE: scanning backward with a visited set
        # would update the *last* occurrence instead, which is not
        # first-visit MC.
        visited = set()
        for t, (s_t, a_t, _r) in enumerate(episode):
            key = (env.s2i[s_t], a_t)
            if key in visited:
                continue
            visited.add(key)
            s_idx, a_idx = key
            returns_sum[s_idx, a_idx] += step_returns[t]
            returns_cnt[s_idx, a_idx] += 1
            Q[s_idx, a_idx] = returns_sum[s_idx, a_idx] / returns_cnt[s_idx, a_idx]

    greedy = _greedy_policy_from_Q(Q)  # (nS,)
    # Policy as a dict keyed by state tuple (tests do: a = pi[s]).
    pi_dict = {env.i2s[s_idx]: int(greedy[s_idx]) for s_idx in range(nS)}
    return Q, pi_dict
0 commit comments