1- # ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
2- import numpy as np
1+ import numpy as np
32from ch4_dynamic_programming .gridworld import GridWorld4x4
43
# Public API of this module.
__all__ = ["mc_control_onpolicy", "ACTIONS", "generate_episode_onpolicy"]

# Action deltas as (d_row, d_col); order must match the environment's
# action ordering (see the deleted upstream comment): Right, Left, Down, Up.
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # R, L, D, U
97
10- def _epsilon_greedy (Q_row : np .ndarray , epsilon : float , rng : np .random .Generator ) -> int :
11- if rng .random () < epsilon :
12- return int (rng .integers (len (Q_row )))
13- return int (np .argmax (Q_row ))
8+ def _is_terminal (env : GridWorld4x4 , s ) -> bool :
9+ if hasattr (env , "is_terminal" ):
10+ return bool (env .is_terminal (s ))
11+ st = s if isinstance (s , tuple ) else env .i2s [int (s )]
12+ return st == env .goal
13+
14+ def _step (env : GridWorld4x4 , s , a ):
15+ if hasattr (env , "step" ):
16+ return env .step (s , a )
17+ s_idx = env .s2i [s ] if isinstance (s , tuple ) else int (s )
18+ probs = env .P [s_idx , a ]
19+ sp_idx = int (np .argmax (probs ))
20+ r = float (env .R [s_idx , a , sp_idx ])
21+ return env .i2s [sp_idx ], r
22+
23+ def _epsilon_greedy (q_row : np .ndarray , epsilon : float , rng : np .random .Generator ) -> int :
24+ return int (rng .integers (len (q_row ))) if rng .random () < epsilon else int (np .argmax (q_row ))
1425
1526def generate_episode_onpolicy (env : GridWorld4x4 , Q : np .ndarray , epsilon : float ,
16- rng : np .random .Generator , max_steps : int = 10_000 ):
17- """Start from a random non-terminal state; follow ε-greedy w.r.t. Q throughout ."""
18- non_terminal = [s for s in env .S if not env . is_terminal ( s )]
27+ rng : np .random .Generator , max_steps : int = 10000 ):
28+ """Start from a random non-terminal state; follow ε-greedy w.r.t. Q."""
29+ non_terminal = [s for s in env .S if not _is_terminal ( env , s )]
1930 s = non_terminal [rng .integers (len (non_terminal ))]
20- S , A = len (env .S ), len (env .A )
2131
2232 states , actions , rewards = [s ], [], [0.0 ]
2333 steps = 0
24- while not env . is_terminal ( s ) and steps < max_steps :
34+ while not _is_terminal ( env , s ) and steps < max_steps :
2535 a = _epsilon_greedy (Q [env .s2i [s ]], epsilon , rng )
2636 actions .append (a )
27- sp , r = env . step ( s , a )
37+ sp , r = _step ( env , s , a )
2838 rewards .append (float (r ))
2939 s = sp
3040 states .append (s )
3141 steps += 1
3242
3343 # first-visit returns
34- gamma = env . gamma
44+ gamma = float ( getattr ( env , " gamma" , 1.0 ))
3545 G = 0.0
3646 returns = np .zeros (len (actions ), dtype = float )
3747 for t in range (len (actions ) - 1 , - 1 , - 1 ):
@@ -43,15 +53,15 @@ def mc_control_onpolicy(env: GridWorld4x4, episodes: int = 5000,
4353 epsilon : float = 0.1 , gamma : float | None = None ,
4454 seed : int | None = None ):
4555 """
46- On-policy Monte Carlo control using ε-greedy behavior/target policy (no exploring starts ).
56+ On-policy MC control using ε-greedy behavior/target policy (no ES ).
4757 Returns:
48- Q: (S,A) table
49- pi: (S,A) deterministic greedy policy derived from Q
58+ Q: (S,A)
59+ pi: (S,A) deterministic greedy policy
5060 """
5161 rng = np .random .default_rng (seed )
5262 S , A = len (env .S ), len (env .A )
5363 if gamma is None :
54- gamma = float (env . gamma )
64+ gamma = float (getattr ( env , " gamma" , 1.0 ) )
5565
5666 Q = np .zeros ((S , A ), dtype = float )
5767 N = np .zeros ((S , A ), dtype = float )
@@ -63,21 +73,13 @@ def mc_control_onpolicy(env: GridWorld4x4, episodes: int = 5000,
6373 s_idx = env .s2i [s ]
6474 key = (s_idx , a )
6575 if key in seen :
66- continue # first-visit MC
76+ continue
6777 seen .add (key )
6878 G = returns [t ]
6979 N [s_idx , a ] += 1.0
7080 alpha = 1.0 / N [s_idx , a ]
7181 Q [s_idx , a ] += alpha * (G - Q [s_idx , a ])
7282
73- # deterministic greedy policy
7483 pi = np .zeros ((S , A ), dtype = float )
7584 pi [np .arange (S ), np .argmax (Q , axis = 1 )] = 1.0
7685 return Q , pi
77-
78- if __name__ == "__main__" :
79- env = GridWorld4x4 (step_reward = - 1.0 , goal = (0 , 3 ), gamma = 1.0 )
80- Q , pi = mc_control_onpolicy (env , episodes = 3000 , epsilon = 0.1 , seed = 0 )
81- s0 = env .s2i [(0 , 0 )]
82- print ("Q(start):" , Q [s0 ])
83- print ("Greedy action at start:" , int (np .argmax (pi [s0 ])))
0 commit comments