Skip to content

Commit bdd1f03

Browse files
Add corrected Chapter 3 (Multi-Armed Bandits) with UCB1 fix, plots, tests
1 parent c2717b2 commit bdd1f03

14 files changed

Lines changed: 191 additions & 182 deletions
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
from .bandit_env import MultiArmedBanditBernoulli
2-
from .algorithms import EpsilonGreedy, UCB1, ThompsonSamplingBeta, simulate
1+

ch3_multi_armed_bandits/algorithms.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

ch3_multi_armed_bandits/bandit_env.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

ch3_multi_armed_bandits/bandits.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
from dataclasses import dataclass
4+
5+
@dataclass
class BernoulliBandit:
    """K-armed Bernoulli bandit environment.

    Attributes:
        probs: 1-D array-like of per-arm success probabilities, each in [0, 1].
    """

    probs: np.ndarray

    def __post_init__(self):
        # Normalize any array-like input to a float ndarray.
        self.probs = np.array(self.probs, dtype=float)
        # Validate with real exceptions: `assert` is stripped under `python -O`,
        # which would silently admit malformed inputs.
        if self.probs.ndim != 1:
            raise ValueError("probs must be a 1-D array of arm probabilities")
        if self.probs.size == 0:
            raise ValueError("probs must contain at least one arm")
        if not ((self.probs >= 0).all() and (self.probs <= 1).all()):
            raise ValueError("probs must lie in [0, 1]")
        self.K = self.probs.shape[0]                      # number of arms
        self.opt_idx = int(np.argmax(self.probs))         # index of the best arm
        self.opt_mean = float(self.probs[self.opt_idx])   # best arm's mean reward

    def pull(self, arm: int, rng: np.random.Generator) -> float:
        """Sample a Bernoulli reward (0.0 or 1.0) from the given arm."""
        return float(rng.random() < self.probs[arm])

    def pseudo_regret(self, arm: int) -> float:
        """Instantaneous pseudo-regret: best arm's mean minus this arm's mean."""
        return self.opt_mean - float(self.probs[arm])
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
4+
class EpsilonGreedy:
5+
def __init__(self, K: int, epsilon: float = 0.1, rng: np.random.Generator | None = None):
6+
self.K = K
7+
self.epsilon = float(epsilon)
8+
self.rng = rng or np.random.default_rng()
9+
self.counts = np.zeros(K, dtype=int)
10+
self.values = np.zeros(K, dtype=float)
11+
12+
def select_arm(self) -> int:
13+
if self.rng.random() < self.epsilon:
14+
return int(self.rng.integers(self.K))
15+
return int(np.argmax(self.values))
16+
17+
def update(self, arm: int, reward: float):
18+
self.counts[arm] += 1
19+
n = self.counts[arm]
20+
self.values[arm] += (reward - self.values[arm]) / n

ch3_multi_armed_bandits/examples/demo_bandit.py

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
import os
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from .bandits import BernoulliBandit
6+
from .epsilon_greedy import EpsilonGreedy
7+
from .ucb import UCB1
8+
from .thompson import ThompsonSamplingBernoulli
9+
10+
def run_algorithm(env, algo, T: int, seed: int) -> dict:
    """Play ``algo`` on ``env`` for ``T`` rounds with a seeded RNG.

    Returns a dict with per-step rewards and pseudo-regret plus their
    cumulative sums (keys: "rewards", "cum_rewards", "regret", "cum_regret").
    """
    rng = np.random.default_rng(seed)
    reward_trace = []
    regret_trace = []
    for _ in range(T):
        arm = algo.select_arm()
        payoff = env.pull(arm, rng)
        algo.update(arm, payoff)
        reward_trace.append(payoff)
        regret_trace.append(env.pseudo_regret(arm))
    rewards = np.asarray(reward_trace, dtype=float)
    regret = np.asarray(regret_trace, dtype=float)
    return {
        "rewards": rewards,
        "cum_rewards": np.cumsum(rewards),
        "regret": regret,
        "cum_regret": np.cumsum(regret),
    }
26+
27+
def average_over_runs(env, algo_ctor, T: int, n_runs: int, base_seed: int = 0) -> dict:
    """Average cumulative pseudo-regret over ``n_runs`` independent runs.

    A fresh algorithm is built via ``algo_ctor()`` for every run; run ``i``
    uses seed ``base_seed + i``. Returns {"mean": ..., "se": ...} where "se"
    is the standard error of the mean across runs (sample std, ddof=1).
    """
    curves = np.array([
        run_algorithm(env, algo_ctor(), T, seed=base_seed + run)["cum_regret"]
        for run in range(n_runs)
    ])
    return {
        "mean": curves.mean(axis=0),
        "se": curves.std(axis=0, ddof=1) / np.sqrt(n_runs),
    }
37+
38+
def plot_regret(curves: dict, title: str, fname: str | None):
    """Plot mean cumulative pseudo-regret curves.

    Args:
        curves: mapping of label -> {"mean": array, ...} as produced by
            ``average_over_runs``.
        title: plot title.
        fname: output path; if given the figure is saved there (parent
            directories created as needed), otherwise it is shown.
    """
    fig, ax = plt.subplots()
    for label, stats in curves.items():
        ax.plot(stats["mean"], label=label)
    ax.set_xlabel("Time")
    ax.set_ylabel("Average cumulative pseudo-regret")
    ax.set_title(title)
    ax.legend()
    if fname:
        out_dir = os.path.dirname(fname)
        if out_dir:
            # exist_ok=True already tolerates an existing directory, so a
            # separate os.path.exists pre-check was redundant (and racy).
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(fname, bbox_inches="tight")
        # Close the saved figure; otherwise repeated calls leak open figures.
        plt.close(fig)
    else:
        plt.show()
53+
54+
def main():
    """Benchmark all bandit algorithms on a fixed Bernoulli testbed and save the regret plot."""
    arm_means = np.array([0.2, 0.25, 0.3, 0.35, 0.5])
    env = BernoulliBandit(probs=arm_means)
    T = 2000
    n_runs = 200
    # Each entry averages one algorithm's cumulative pseudo-regret over n_runs seeds.
    curves = {
        "ε-greedy(0.10)": average_over_runs(env, lambda: EpsilonGreedy(env.K, 0.10), T, n_runs, 123),
        "ε-greedy(0.01)": average_over_runs(env, lambda: EpsilonGreedy(env.K, 0.01), T, n_runs, 223),
        "UCB1(c=0.5)": average_over_runs(env, lambda: UCB1(env.K, c=0.5), T, n_runs, 323),
        "Thompson (Beta-Bernoulli)": average_over_runs(env, lambda: ThompsonSamplingBernoulli(env.K), T, n_runs, 423),
    }
    script_dir = os.path.dirname(__file__)
    out_path = os.path.join(script_dir, "plots", "regret_bernoulli.png")
    plot_regret(curves, "Multi-Armed Bandits: Average Cumulative Pseudo-Regret", out_path)
    print(f"Saved plot to {out_path}")


if __name__ == "__main__":
    main()
39.6 KB
Loading

ch3_multi_armed_bandits/requirements.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

ch3_multi_armed_bandits/tests/test_bandit_algorithms.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

0 commit comments

Comments
 (0)