Skip to content

Commit f0e2151

Browse files
committed
Add Q-Learning algorithm implementation with epsilon-greedy policy and grid world demo code
1 parent 2b09382 commit f0e2151

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed

machine_learning/q_learning.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""
2+
Q-Learning is a widely-used model-free algorithm in reinforcement learning that
3+
learns the optimal action-value function Q(s, a), which tells an agent the expected
4+
utility of taking action a in state s and then following the optimal policy after.
5+
It is able to find the best policy for any given finite Markov decision process (MDP)
6+
without requiring a model of the environment.
7+
8+
See: [https://en.wikipedia.org/wiki/Q-learning](https://en.wikipedia.org/wiki/Q-learning)
9+
"""
10+
11+
from collections import defaultdict
12+
import random
13+
14+
# Hyperparameters for Q-Learning
15+
LEARNING_RATE = 0.1
16+
DISCOUNT_FACTOR = 0.97
17+
EPSILON = 0.2
18+
EPSILON_DECAY = 0.995
19+
EPSILON_MIN = 0.01
20+
21+
# Global Q-table to store state-action values
22+
q_table = defaultdict(lambda: defaultdict(float))
23+
24+
# Environment variables for simple grid world
25+
SIZE = 4
26+
GOAL = (SIZE - 1, SIZE - 1)
27+
current_state = (0, 0)
28+
29+
30+
def get_q_value(state, action):
31+
"""
32+
Get Q-value for a given state-action pair.
33+
34+
>>> get_q_value((0, 0), 2)
35+
0.0
36+
"""
37+
return q_table[state][action]
38+
39+
40+
def get_best_action(state, available_actions):
41+
"""
42+
Get the action with maximum Q-value in the given state.
43+
44+
>>> q_table[(0, 0)][1] = 0.7
45+
>>> q_table[(0, 0)][2] = 0.7
46+
>>> q_table[(0, 0)][3] = 0.5
47+
>>> get_best_action((0, 0), [1, 2, 3]) in [1, 2]
48+
True
49+
"""
50+
if not available_actions:
51+
raise ValueError("No available actions provided")
52+
max_q = max(q_table[state][a] for a in available_actions)
53+
best = [a for a in available_actions if q_table[state][a] == max_q]
54+
return random.choice(best)
55+
56+
57+
def choose_action(state, available_actions):
58+
"""
59+
Choose action using epsilon-greedy policy.
60+
61+
>>> EPSILON = 0.0
62+
>>> q_table[(0, 0)][1] = 1.0
63+
>>> q_table[(0, 0)][2] = 0.5
64+
>>> choose_action((0, 0), [1, 2])
65+
1
66+
"""
67+
global EPSILON
68+
if not available_actions:
69+
raise ValueError("No available actions provided")
70+
if random.random() < EPSILON:
71+
return random.choice(available_actions)
72+
return get_best_action(state, available_actions)
73+
74+
75+
def update(state, action, reward, next_state, next_available_actions, done=False):
76+
"""
77+
Perform Q-value update for a transition using the Q-learning rule.
78+
79+
Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
80+
81+
>>> LEARNING_RATE = 0.5
82+
>>> DISCOUNT_FACTOR = 0.9
83+
>>> update((0,0), 1, 1.0, (0,1), [1,2], done=True)
84+
>>> get_q_value((0,0), 1)
85+
0.5
86+
"""
87+
global LEARNING_RATE, DISCOUNT_FACTOR
88+
max_q_next = 0.0 if done or not next_available_actions else max(
89+
get_q_value(next_state, a) for a in next_available_actions
90+
)
91+
old_q = get_q_value(state, action)
92+
new_q = (1 - LEARNING_RATE) * old_q + LEARNING_RATE * (
93+
reward + DISCOUNT_FACTOR * max_q_next
94+
)
95+
q_table[state][action] = new_q
96+
97+
98+
def get_policy():
99+
"""
100+
Extract a deterministic policy from the Q-table.
101+
102+
>>> q_table[(1,2)][1] = 2.0
103+
>>> q_table[(1,2)][2] = 1.0
104+
>>> get_policy()[(1,2)]
105+
1
106+
"""
107+
policy = {}
108+
for s, a_dict in q_table.items():
109+
if a_dict:
110+
policy[s] = max(a_dict, key=a_dict.get)
111+
return policy
112+
113+
114+
def reset_env():
115+
"""
116+
Reset the environment to initial state.
117+
"""
118+
global current_state
119+
current_state = (0, 0)
120+
return current_state
121+
122+
123+
def get_available_actions_env():
124+
"""
125+
Get available actions in the current environment state.
126+
"""
127+
return [0, 1, 2, 3]
128+
129+
130+
def step_env(action):
131+
"""
132+
Take a step in the environment with the given action.
133+
"""
134+
global current_state
135+
x, y = current_state
136+
if action == 0: # up
137+
x = max(0, x - 1)
138+
elif action == 1: # right
139+
y = min(SIZE - 1, y + 1)
140+
elif action == 2: # down
141+
x = min(SIZE - 1, x + 1)
142+
elif action == 3: # left
143+
y = max(0, y - 1)
144+
next_state = (x, y)
145+
reward = 10.0 if next_state == GOAL else -1.0
146+
done = next_state == GOAL
147+
current_state = next_state
148+
return next_state, reward, done
149+
150+
151+
def run_q_learning():
152+
"""
153+
Run Q-Learning on the simple grid world environment.
154+
"""
155+
global EPSILON
156+
episodes = 200
157+
for episode in range(episodes):
158+
state = reset_env()
159+
done = False
160+
while not done:
161+
actions = get_available_actions_env()
162+
action = choose_action(state, actions)
163+
next_state, reward, done = step_env(action)
164+
next_actions = get_available_actions_env()
165+
update(state, action, reward, next_state, next_actions, done)
166+
state = next_state
167+
EPSILON = max(EPSILON * EPSILON_DECAY, EPSILON_MIN)
168+
policy = get_policy()
169+
print("Learned Policy (state: action):")
170+
for s, a in sorted(policy.items()):
171+
print(f"{s}: {a}")
172+
173+
174+
if __name__ == "__main__":
175+
import doctest
176+
doctest.testmod()
177+
run_q_learning()

0 commit comments

Comments
 (0)