# Grid movement deltas as (row, col): Right, Left, Down, Up.
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # R, L, D, U
88
def _is_terminal(env: GridWorld4x4, s) -> bool:
    """Return True if state *s* is terminal.

    Prefers ``env.is_terminal`` when the environment provides it;
    otherwise falls back to comparing the (row, col) state against
    ``env.goal``. *s* may be a (row, col) tuple or an integer index
    into ``env.i2s``.
    """
    if hasattr(env, "is_terminal"):
        return bool(env.is_terminal(s))
    # Normalize an integer state index to its (row, col) tuple.
    st = s if isinstance(s, tuple) else env.i2s[int(s)]
    return st == env.goal
1514
def _step(env: GridWorld4x4, s, a):
    """Take action *a* in state *s*; return ``(next_state, reward)``.

    Uses ``env.step`` when present. Otherwise simulates the move
    geometrically, with no dependency on transition matrices env.P/env.R:
    an off-grid move bounces back to the current cell ("wall"), and
    entering the goal yields reward 0.0 while every other transition
    costs the environment's step reward (default -1.0).
    """
    if hasattr(env, "step"):
        return env.step(s, a)
    # Normalize an integer state index to its (row, col) tuple.
    st = s if isinstance(s, tuple) else env.i2s[int(s)]
    i, j = st
    di, dj = ACTIONS[a]
    # Infer grid side length; assumes a square grid if env.n is absent.
    n = getattr(env, "n", int(round(len(env.S) ** 0.5)))
    ni, nj = i + di, j + dj
    if not (0 <= ni < n and 0 <= nj < n):
        ni, nj = i, j  # wall -> stay in place
    sp = (ni, nj)
    # Reward: step cost unless entering the goal, then 0.0 (matches the
    # ch4 tests this helper was written against).
    step_reward = float(getattr(env, "step_reward", -1.0))
    r = 0.0 if sp == getattr(env, "goal", (0, 3)) else step_reward
    return sp, r
2532
2633def _greedy_action (q_row : np .ndarray ) -> int :
2734 return int (np .argmax (q_row ))
2835
2936def generate_episode_es (env : GridWorld4x4 , Q : np .ndarray , gamma : float , max_steps : int = 10000 ):
30- """
31- Exploring starts: start from a random non-terminal state and random action,
32- then follow greedy policy thereafter. Returns (states, actions, returns).
33- """
3437 rng = np .random .default_rng ()
3538 non_terminal = [s for s in env .S if not _is_terminal (env , s )]
3639 s = non_terminal [rng .integers (len (non_terminal ))]
3740 a = int (rng .integers (len (env .A )))
3841
3942 states = [s ]
4043 actions = [a ]
41- rewards = [0.0 ] # so rewards[t+1] aligns with action at t
44+ rewards = [0.0 ]
4245
4346 steps = 0
4447 while not _is_terminal (env , s ) and steps < max_steps :
# NOTE(review): diff hunk header — several lines of generate_episode_es's
# while-loop body (the step/greedy-action updates) are elided at this point.
5255 actions .append (a )
5356 steps += 1
5457
55- # first-visit returns
5658 G = 0.0
5759 returns = np .zeros (len (states ), dtype = float )
5860 for t in range (len (states ) - 1 , - 1 , - 1 ):
@@ -61,30 +63,23 @@ def generate_episode_es(env: GridWorld4x4, Q: np.ndarray, gamma: float, max_step
6163 return states , actions , returns
6264
6365def mc_es_control (env : GridWorld4x4 , episodes : int = 1500 , gamma : float | None = None , seed : int | None = None ):
64- """
65- On-policy Monte Carlo control with Exploring Starts (ES).
66- Returns:
67- Q: (S,A) action-value table
68- pi: (S,A) deterministic greedy policy derived from Q
69- """
7066 if seed is not None :
7167 np .random .seed (seed )
7268 if gamma is None :
73- gamma = float (env . gamma )
69+ gamma = float (getattr ( env , " gamma" , 1.0 ) )
7470
7571 S , A = len (env .S ), len (env .A )
7672 Q = np .zeros ((S , A ), dtype = float )
77- N = np .zeros ((S , A ), dtype = float ) # first-visit counts
73+ N = np .zeros ((S , A ), dtype = float )
7874
7975 for _ in range (episodes ):
8076 states , actions , returns = generate_episode_es (env , Q , gamma )
8177 seen = set ()
8278 for t , (s , a ) in enumerate (zip (states , actions )):
8379 s_idx = env .s2i [s ]
84- key = (s_idx , a )
85- if key in seen :
80+ if (s_idx , a ) in seen :
8681 continue
87- seen .add (key )
82+ seen .add (( s_idx , a ) )
8883 G = returns [t ]
8984 N [s_idx , a ] += 1.0
9085 alpha = 1.0 / N [s_idx , a ]
# (removed GitHub diff-page footer artifact: "0 commit comments")