# ch5_monte_carlo/examples/mc_control_es_gridworld.py
# Monte Carlo control with Exploring Starts (ES) on a 4x4 GridWorld.
# Robust: no reliance on env.P's shape nor on env.is_terminal being present.
44
55from __future__ import annotations
66import numpy as np
77
88__all__ = ["mc_es_control" , "generate_episode_es" , "ACTIONS" ]
99
# Tests expect ACTIONS to be action *indices* usable as env.P[s_idx][a] keys.
ACTIONS = [0, 1, 2, 3]  # exported for tests
# Internal grid geometry for each action index: Right, Left, Down, Up.
DIRECTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]
1213
13- # ------------ helpers that do not assume specific env attributes -------------
14-
15- def _goal (env ):
16- return getattr (env , "goal" , (0 , 3 ))
17-
18- def _n (env ):
19- # prefer env.n; otherwise infer from |S|
20- return getattr (env , "n" , int (round (len (env .S ) ** 0.5 )))
21-
22- def _step_reward (env ):
23- return float (getattr (env , "step_reward" , - 1.0 ))
14+ def _goal (env ): return getattr (env , "goal" , (0 , 3 ))
15+ def _n (env ): return getattr (env , "n" , int (round (len (env .S ) ** 0.5 )))
16+ def _step_reward (env ): return float (getattr (env , "step_reward" , - 1.0 ))
2417
2518def _is_terminal (env , s ) -> bool :
2619 if hasattr (env , "is_terminal" ):
2720 return bool (env .is_terminal (s ))
2821 st = s if isinstance (s , tuple ) else env .i2s [int (s )]
2922 return st == _goal (env )
3023
31- def _step (env , s , a ):
32- """Robust step that uses env.step if available ; else uses grid geometry ."""
24+ def _step (env , s , a_idx : int ):
25+ """Use env.step if present ; else geometric fallback using DIRECTIONS ."""
3326 if hasattr (env , "step" ):
34- return env .step (s , a )
27+ return env .step (s , a_idx )
3528 st = s if isinstance (s , tuple ) else env .i2s [int (s )]
3629 i , j = st
37- di , dj = ACTIONS [ a ]
30+ di , dj = DIRECTIONS [ a_idx ]
3831 n = _n (env )
3932 ni , nj = i + di , j + dj
4033 if not (0 <= ni < n and 0 <= nj < n ):
41- ni , nj = i , j # wall -> stay
34+ ni , nj = i , j
4235 sp = (ni , nj )
4336 r = 0.0 if sp == _goal (env ) else _step_reward (env )
4437 return sp , r
4538
4639def _greedy_action (q_row : np .ndarray ) -> int :
4740 return int (np .argmax (q_row ))
4841
49- # ------------------------------- core logic ----------------------------------
50-
5142def generate_episode_es (env , Q : np .ndarray , gamma : float , max_steps : int = 10_000 ):
5243 """
53- Exploring starts:
54- - start from a random NON-terminal state
55- - start with a random action
56- - thereafter follow greedy policy w.r.t. Q
57- Returns:
58- states: list of states (tuples), length T = number of actions
59- actions: list of action indices, length T
60- returns: list/array of returns G_t, length T
44+ Exploring starts: start from random non-terminal state & random action,
45+ then follow greedy policy w.r.t. Q.
46+ Returns aligned (states, actions, returns) of length T = #actions.
6147 """
6248 rng = np .random .default_rng ()
6349 non_terminal = [s for s in env .S if not _is_terminal (env , s )]
6450 s = non_terminal [rng .integers (len (non_terminal ))]
65- a = int (rng .integers (len (env .A )))
51+ a = int (rng .integers (len (env .A ))) # int action index
6652
6753 states = [s ]
6854 actions = [a ]
69- rewards = [0.0 ] # align indexing so rewards[t+1] corresponds to action at t
55+ rewards = [0.0 ] # rewards[t+1] corresponds to action at t
7056
7157 steps = 0
7258 while not _is_terminal (env , s ) and steps < max_steps :
@@ -80,23 +66,17 @@ def generate_episode_es(env, Q: np.ndarray, gamma: float, max_steps: int = 10_00
8066 actions .append (a )
8167 steps += 1
8268
83- # ---- returns over number of actions (T) ----
69+ # Compute returns over T = len(actions); guard rewards indexing just in case.
8470 T = len (actions )
8571 G = 0.0
8672 returns = np .zeros (T , dtype = float )
8773 for t in range (T - 1 , - 1 , - 1 ):
88- G = rewards [t + 1 ] + gamma * G
74+ r_tp1 = rewards [t + 1 ] if (t + 1 ) < len (rewards ) else 0.0
75+ G = r_tp1 + gamma * G
8976 returns [t ] = G
90- # keep states length consistent with actions/returns
9177 return states [:T ], actions , returns
9278
9379def mc_es_control (env , episodes : int = 1500 , gamma : float | None = None , seed : int | None = None ):
94- """
95- On-policy Monte Carlo control with Exploring Starts (ES).
96- Returns:
97- Q: (S,A) action-value table
98- pi: (S,A) deterministic greedy policy derived from Q
99- """
10080 if seed is not None :
10181 np .random .seed (seed )
10282 if gamma is None :
@@ -113,14 +93,13 @@ def mc_es_control(env, episodes: int = 1500, gamma: float | None = None, seed: i
11393 s_idx = env .s2i [s ]
11494 key = (s_idx , a )
11595 if key in seen :
116- continue # first-visit MC
96+ continue
11797 seen .add (key )
11898 G = returns [t ]
11999 N [s_idx , a ] += 1.0
120- alpha = 1.0 / N [s_idx , a ]
121- Q [s_idx , a ] += alpha * (G - Q [s_idx , a ])
100+ Q [s_idx , a ] += (G - Q [s_idx , a ]) / N [s_idx , a ]
122101
123- # greedy deterministic policy from Q
102+ # deterministic greedy policy over action indices
124103 pi = np .zeros ((S , A ), dtype = float )
125104 pi [np .arange (S ), np .argmax (Q , axis = 1 )] = 1.0
126105 return Q , pi