Skip to content

Commit 815dcca

Browse files
ch5: geometry-based step; remove dependency on env.P shape
1 parent 548b73b commit 815dcca

2 files changed

Lines changed: 33 additions & 40 deletions

File tree

ch5_monte_carlo/examples/mc_control_es_gridworld.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,41 @@
77
# (row, col) displacements, indexed by action id; order is part of the contract
# (Q-tables and policies index actions by this position).
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # 0: right, 1: left, 2: down, 3: up
88

99
def _is_terminal(env: GridWorld4x4, s) -> bool:
    """Report whether state *s* is terminal in *env*.

    Uses the environment's own ``is_terminal`` hook when available;
    otherwise falls back to comparing the (row, col) state — decoding an
    integer index through ``env.i2s`` if needed — against ``env.goal``.
    """
    if not hasattr(env, "is_terminal"):
        cell = s if isinstance(s, tuple) else env.i2s[int(s)]
        return cell == env.goal
    return bool(env.is_terminal(s))
1514

1615
def _step(env: GridWorld4x4, s, a):
    """Advance one step using grid geometry; never reads ``env.P``.

    Delegates to ``env.step`` when the environment provides one.  Otherwise
    the next cell is computed from the (row, col) displacement of action
    ``a``; a move that would leave the board keeps the agent in place.
    The reward is ``env.step_reward`` (default -1.0), except that landing
    on the goal yields 0.0.
    """
    if hasattr(env, "step"):
        return env.step(s, a)
    row, col = s if isinstance(s, tuple) else env.i2s[int(s)]
    d_row, d_col = ACTIONS[a]
    # Side length of the board: prefer env.n, else assume |S| is a square.
    side = getattr(env, "n", int(round(len(env.S) ** 0.5)))
    nxt_row, nxt_col = row + d_row, col + d_col
    if not (0 <= nxt_row < side and 0 <= nxt_col < side):
        nxt_row, nxt_col = row, col  # bumped a wall: stay put
    landed = (nxt_row, nxt_col)
    penalty = float(getattr(env, "step_reward", -1.0))
    # Entering the goal pays 0.0 (keeps parity with the ch4 test suite).
    reward = 0.0 if landed == getattr(env, "goal", (0, 3)) else penalty
    return landed, reward
2532

2633
def _greedy_action(q_row: np.ndarray) -> int:
2734
return int(np.argmax(q_row))
2835

2936
def generate_episode_es(env: GridWorld4x4, Q: np.ndarray, gamma: float, max_steps: int = 10000):
30-
"""
31-
Exploring starts: start from a random non-terminal state and random action,
32-
then follow greedy policy thereafter. Returns (states, actions, returns).
33-
"""
3437
rng = np.random.default_rng()
3538
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
3639
s = non_terminal[rng.integers(len(non_terminal))]
3740
a = int(rng.integers(len(env.A)))
3841

3942
states = [s]
4043
actions = [a]
41-
rewards = [0.0] # so rewards[t+1] aligns with action at t
44+
rewards = [0.0]
4245

4346
steps = 0
4447
while not _is_terminal(env, s) and steps < max_steps:
@@ -52,7 +55,6 @@ def generate_episode_es(env: GridWorld4x4, Q: np.ndarray, gamma: float, max_step
5255
actions.append(a)
5356
steps += 1
5457

55-
# first-visit returns
5658
G = 0.0
5759
returns = np.zeros(len(states), dtype=float)
5860
for t in range(len(states) - 1, -1, -1):
@@ -61,30 +63,23 @@ def generate_episode_es(env: GridWorld4x4, Q: np.ndarray, gamma: float, max_step
6163
return states, actions, returns
6264

6365
def mc_es_control(env: GridWorld4x4, episodes: int = 1500, gamma: float | None = None, seed: int | None = None):
64-
"""
65-
On-policy Monte Carlo control with Exploring Starts (ES).
66-
Returns:
67-
Q: (S,A) action-value table
68-
pi: (S,A) deterministic greedy policy derived from Q
69-
"""
7066
if seed is not None:
7167
np.random.seed(seed)
7268
if gamma is None:
73-
gamma = float(env.gamma)
69+
gamma = float(getattr(env, "gamma", 1.0))
7470

7571
S, A = len(env.S), len(env.A)
7672
Q = np.zeros((S, A), dtype=float)
77-
N = np.zeros((S, A), dtype=float) # first-visit counts
73+
N = np.zeros((S, A), dtype=float)
7874

7975
for _ in range(episodes):
8076
states, actions, returns = generate_episode_es(env, Q, gamma)
8177
seen = set()
8278
for t, (s, a) in enumerate(zip(states, actions)):
8379
s_idx = env.s2i[s]
84-
key = (s_idx, a)
85-
if key in seen:
80+
if (s_idx, a) in seen:
8681
continue
87-
seen.add(key)
82+
seen.add((s_idx, a))
8883
G = returns[t]
8984
N[s_idx, a] += 1.0
9085
alpha = 1.0 / N[s_idx, a]

ch5_monte_carlo/examples/mc_control_onpolicy_gridworld.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,26 @@ def _is_terminal(env: GridWorld4x4, s) -> bool:
1212
return st == env.goal
1313

1414
def _step(env: GridWorld4x4, s, a):
    """One transition computed from grid geometry, with no ``env.P`` access.

    If the environment exposes ``step``, that is authoritative.  Otherwise:
    decode *s* to a (row, col) cell, apply the displacement for action *a*,
    clamp out-of-bounds moves back to the current cell, and pay
    ``env.step_reward`` (default -1.0) unless the resulting cell is the
    goal, which pays 0.0.
    """
    if hasattr(env, "step"):
        return env.step(s, a)
    cell = s if isinstance(s, tuple) else env.i2s[int(s)]
    delta = ACTIONS[a]
    # Board side: env.n when present, else sqrt(|S|) for a square grid.
    size = getattr(env, "n", int(round(len(env.S) ** 0.5)))
    cand = (cell[0] + delta[0], cell[1] + delta[1])
    inside = 0 <= cand[0] < size and 0 <= cand[1] < size
    nxt = cand if inside else cell
    base = float(getattr(env, "step_reward", -1.0))
    return nxt, (0.0 if nxt == getattr(env, "goal", (0, 3)) else base)
2229

2330
def _epsilon_greedy(q_row: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
2431
return int(rng.integers(len(q_row))) if rng.random() < epsilon else int(np.argmax(q_row))
2532

2633
def generate_episode_onpolicy(env: GridWorld4x4, Q: np.ndarray, epsilon: float,
2734
rng: np.random.Generator, max_steps: int = 10000):
28-
"""Start from a random non-terminal state; follow ε-greedy w.r.t. Q."""
2935
non_terminal = [s for s in env.S if not _is_terminal(env, s)]
3036
s = non_terminal[rng.integers(len(non_terminal))]
3137

@@ -40,7 +46,6 @@ def generate_episode_onpolicy(env: GridWorld4x4, Q: np.ndarray, epsilon: float,
4046
states.append(s)
4147
steps += 1
4248

43-
# first-visit returns
4449
gamma = float(getattr(env, "gamma", 1.0))
4550
G = 0.0
4651
returns = np.zeros(len(actions), dtype=float)
@@ -52,12 +57,6 @@ def generate_episode_onpolicy(env: GridWorld4x4, Q: np.ndarray, epsilon: float,
5257
def mc_control_onpolicy(env: GridWorld4x4, episodes: int = 5000,
5358
epsilon: float = 0.1, gamma: float | None = None,
5459
seed: int | None = None):
55-
"""
56-
On-policy MC control using ε-greedy behavior/target policy (no ES).
57-
Returns:
58-
Q: (S,A)
59-
pi: (S,A) deterministic greedy policy
60-
"""
6160
rng = np.random.default_rng(seed)
6261
S, A = len(env.S), len(env.A)
6362
if gamma is None:
@@ -71,10 +70,9 @@ def mc_control_onpolicy(env: GridWorld4x4, episodes: int = 5000,
7170
seen = set()
7271
for t, (s, a) in enumerate(zip(states, actions)):
7372
s_idx = env.s2i[s]
74-
key = (s_idx, a)
75-
if key in seen:
73+
if (s_idx, a) in seen:
7674
continue
77-
seen.add(key)
75+
seen.add((s_idx, a))
7876
G = returns[t]
7977
N[s_idx, a] += 1.0
8078
alpha = 1.0 / N[s_idx, a]

0 commit comments

Comments
 (0)