Skip to content

Commit 548b73b

Browse files
ch5: robust MC (no env.is_terminal dependency)
1 parent 104f306 commit 548b73b

File tree

2 files changed

+58
-48
lines changed

2 files changed

+58
-48
lines changed

ch5_monte_carlo/examples/mc_control_es_gridworld.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,53 @@
1-
# ch5_monte_carlo/examples/mc_control_es_gridworld.py
2-
import numpy as np
1+
import numpy as np
32
from ch4_dynamic_programming.gridworld import GridWorld4x4
43

5-
__all__ = ["mc_es_control", "generate_episode_es"]
4+
__all__ = ["mc_es_control", "generate_episode_es", "ACTIONS"]
65

7-
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R,L,D,U (must match ch4 env)
6+
# Must match env's action ordering
7+
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U
8+
9+
def _is_terminal(env: GridWorld4x4, s) -> bool:
    """Terminal-state test that tolerates envs lacking an ``is_terminal`` method.

    When the env has no ``is_terminal``, fall back to comparing the state
    against ``env.goal``; an integer state index is first mapped to its
    (row, col) tuple via ``env.i2s``.
    """
    if not hasattr(env, "is_terminal"):
        state = s if isinstance(s, tuple) else env.i2s[int(s)]
        return state == env.goal
    return bool(env.is_terminal(s))
15+
16+
def _step(env: GridWorld4x4, s, a):
    """One environment transition, preferring ``env.step`` when it exists.

    Fallback path uses the tabular model: the successor is the most probable
    next state under ``env.P`` (exact for deterministic envs) and the reward
    is read from ``env.R``. Returns ``(next_state, reward)``.
    """
    if not hasattr(env, "step"):
        idx = int(s) if not isinstance(s, tuple) else env.s2i[s]
        nxt = int(np.argmax(env.P[idx, a]))
        reward = float(env.R[idx, a, nxt])
        return env.i2s[nxt], reward
    return env.step(s, a)
825

926
def _greedy_action(q_row: np.ndarray) -> int:
1027
return int(np.argmax(q_row))
1128

1229
def generate_episode_es(env: GridWorld4x4, Q: np.ndarray, gamma: float, max_steps: int = 10000):
1330
"""
1431
Exploring starts: start from a random non-terminal state and random action,
15-
then follow greedy policy w.r.t. Q thereafter. Returns (states, actions, returns).
32+
then follow greedy policy thereafter. Returns (states, actions, returns).
1633
"""
1734
rng = np.random.default_rng()
18-
non_terminal = [s for s in env.S if not env.is_terminal(s)]
35+
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
1936
s = non_terminal[rng.integers(len(non_terminal))]
2037
a = int(rng.integers(len(env.A)))
2138

2239
states = [s]
2340
actions = [a]
24-
rewards = [0.0] # so rewards[t+1] aligns with action taken at t
41+
rewards = [0.0] # so rewards[t+1] aligns with action at t
2542

2643
steps = 0
27-
while not env.is_terminal(s) and steps < max_steps:
28-
sp, r = env.step(s, a)
44+
while not _is_terminal(env, s) and steps < max_steps:
45+
sp, r = _step(env, s, a)
2946
rewards.append(float(r))
3047
s = sp
31-
if env.is_terminal(s):
48+
if _is_terminal(env, s):
3249
break
33-
s_idx = env.s2i[s]
34-
a = _greedy_action(Q[s_idx])
50+
a = _greedy_action(Q[env.s2i[s]])
3551
states.append(s)
3652
actions.append(a)
3753
steps += 1
@@ -67,7 +83,7 @@ def mc_es_control(env: GridWorld4x4, episodes: int = 1500, gamma: float | None =
6783
s_idx = env.s2i[s]
6884
key = (s_idx, a)
6985
if key in seen:
70-
continue # first-visit MC
86+
continue
7187
seen.add(key)
7288
G = returns[t]
7389
N[s_idx, a] += 1.0
@@ -77,11 +93,3 @@ def mc_es_control(env: GridWorld4x4, episodes: int = 1500, gamma: float | None =
7793
pi = np.zeros((S, A), dtype=float)
7894
pi[np.arange(S), np.argmax(Q, axis=1)] = 1.0
7995
return Q, pi
80-
81-
# Optional: run this file directly for a quick check
82-
if __name__ == "__main__":
83-
env = GridWorld4x4(step_reward=-1.0, goal=(0, 3), gamma=1.0)
84-
Q, pi = mc_es_control(env, episodes=2000, seed=0)
85-
start = env.s2i[(0, 0)]
86-
print("Q(start):", Q[start])
87-
print("Greedy action at start:", int(np.argmax(pi[start])))

ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,47 @@
1-
# ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py
2-
import numpy as np
1+
import numpy as np
32
from ch4_dynamic_programming.gridworld import GridWorld4x4
43

54
__all__ = ["mc_control_onpolicy", "ACTIONS", "generate_episode_onpolicy"]
65

7-
# Must match the environment's action ordering
86
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)] # R, L, D, U
97

10-
def _epsilon_greedy(Q_row: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
11-
if rng.random() < epsilon:
12-
return int(rng.integers(len(Q_row)))
13-
return int(np.argmax(Q_row))
8+
def _is_terminal(env: GridWorld4x4, s) -> bool:
    """Return True when ``s`` is terminal, working even without ``env.is_terminal``.

    Without that method, an integer index is translated through ``env.i2s``
    and the resulting (row, col) tuple is compared to ``env.goal``.
    """
    if not hasattr(env, "is_terminal"):
        state = s if isinstance(s, tuple) else env.i2s[int(s)]
        return state == env.goal
    return bool(env.is_terminal(s))
13+
14+
def _step(env: GridWorld4x4, s, a):
    """Advance one transition; use ``env.step`` if present, else the P/R tables.

    The table fallback picks the most probable successor from ``env.P``
    (deterministic envs only) and its reward from ``env.R``.
    Returns ``(next_state, reward)``.
    """
    if not hasattr(env, "step"):
        idx = int(s) if not isinstance(s, tuple) else env.s2i[s]
        nxt = int(np.argmax(env.P[idx, a]))
        reward = float(env.R[idx, a, nxt])
        return env.i2s[nxt], reward
    return env.step(s, a)
22+
23+
def _epsilon_greedy(q_row: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
24+
return int(rng.integers(len(q_row))) if rng.random() < epsilon else int(np.argmax(q_row))
1425

1526
def generate_episode_onpolicy(env: GridWorld4x4, Q: np.ndarray, epsilon: float,
16-
rng: np.random.Generator, max_steps: int = 10_000):
17-
"""Start from a random non-terminal state; follow ε-greedy w.r.t. Q throughout."""
18-
non_terminal = [s for s in env.S if not env.is_terminal(s)]
27+
rng: np.random.Generator, max_steps: int = 10000):
28+
"""Start from a random non-terminal state; follow ε-greedy w.r.t. Q."""
29+
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
1930
s = non_terminal[rng.integers(len(non_terminal))]
20-
S, A = len(env.S), len(env.A)
2131

2232
states, actions, rewards = [s], [], [0.0]
2333
steps = 0
24-
while not env.is_terminal(s) and steps < max_steps:
34+
while not _is_terminal(env, s) and steps < max_steps:
2535
a = _epsilon_greedy(Q[env.s2i[s]], epsilon, rng)
2636
actions.append(a)
27-
sp, r = env.step(s, a)
37+
sp, r = _step(env, s, a)
2838
rewards.append(float(r))
2939
s = sp
3040
states.append(s)
3141
steps += 1
3242

3343
# first-visit returns
34-
gamma = env.gamma
44+
gamma = float(getattr(env, "gamma", 1.0))
3545
G = 0.0
3646
returns = np.zeros(len(actions), dtype=float)
3747
for t in range(len(actions) - 1, -1, -1):
@@ -43,15 +53,15 @@ def mc_control_onpolicy(env: GridWorld4x4, episodes: int = 5000,
4353
epsilon: float = 0.1, gamma: float | None = None,
4454
seed: int | None = None):
4555
"""
46-
On-policy Monte Carlo control using ε-greedy behavior/target policy (no exploring starts).
56+
On-policy MC control using ε-greedy behavior/target policy (no ES).
4757
Returns:
48-
Q: (S,A) table
49-
pi: (S,A) deterministic greedy policy derived from Q
58+
Q: (S,A)
59+
pi: (S,A) deterministic greedy policy
5060
"""
5161
rng = np.random.default_rng(seed)
5262
S, A = len(env.S), len(env.A)
5363
if gamma is None:
54-
gamma = float(env.gamma)
64+
gamma = float(getattr(env, "gamma", 1.0))
5565

5666
Q = np.zeros((S, A), dtype=float)
5767
N = np.zeros((S, A), dtype=float)
@@ -63,21 +73,13 @@ def mc_control_onpolicy(env: GridWorld4x4, episodes: int = 5000,
6373
s_idx = env.s2i[s]
6474
key = (s_idx, a)
6575
if key in seen:
66-
continue # first-visit MC
76+
continue
6777
seen.add(key)
6878
G = returns[t]
6979
N[s_idx, a] += 1.0
7080
alpha = 1.0 / N[s_idx, a]
7181
Q[s_idx, a] += alpha * (G - Q[s_idx, a])
7282

73-
# deterministic greedy policy
7483
pi = np.zeros((S, A), dtype=float)
7584
pi[np.arange(S), np.argmax(Q, axis=1)] = 1.0
7685
return Q, pi
77-
78-
if __name__ == "__main__":
79-
env = GridWorld4x4(step_reward=-1.0, goal=(0, 3), gamma=1.0)
80-
Q, pi = mc_control_onpolicy(env, episodes=3000, epsilon=0.1, seed=0)
81-
s0 = env.s2i[(0, 0)]
82-
print("Q(start):", Q[s0])
83-
print("Greedy action at start:", int(np.argmax(pi[s0])))

0 commit comments

Comments
 (0)