-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
47 lines (37 loc) · 1.34 KB
/
train.py
File metadata and controls
47 lines (37 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gymnasium as gym
import time
import numpy as np
# Build the Taxi-v3 environment and the Q-table it will be trained on.
env = gym.make("Taxi-v3")
obs, info = env.reset()
# Q-table: one row per discrete state, one column per action, initialized to
# all zeros so every state/action pair starts with an equal (zero) estimate.
# (Taxi-v3 also carries a built-in reward/transition table per state:
# [(probability, nextstate, reward, done)].)
q_table = np.zeros((env.observation_space.n, env.action_space.n))
# HYPERPARAMETERS for tabular Q-learning
epsilon = 0.1  # exploration rate: probability of taking a random action
alpha = 0.1    # learning rate: how far each update moves the Q-value
gamma = 0.6    # discount factor: weight given to future reward
# Train for 100k episodes with epsilon-greedy tabular Q-learning.
for i in range(100000):
    # Gymnasium reset() returns (observation, info); keep only the state.
    state, _ = env.reset()
    steps, penalties, reward, done = 0, 0, 0, False
    while not done:
        # Epsilon-greedy: explore with probability epsilon, otherwise
        # exploit the best known action from the Q-table for this state.
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            action = np.argmax(q_table[state])
        # Gymnasium step() returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, info = env.step(action)
        # BUG FIX: the episode must also end on truncation (the TimeLimit
        # wrapper caps Taxi-v3 at 200 steps). The original only checked
        # `terminated`, so a truncated episode looped forever.
        done = terminated or truncated
        # Bellman update: blend the old estimate with the observed reward
        # plus the discounted best value of the next state.
        old_q_value = q_table[state, action]
        next_max = np.max(q_table[next_state])
        new_value = (1 - alpha) * old_q_value + alpha * (reward + gamma * next_max)
        q_table[state, action] = new_value
        if reward == -10:  # -10 is Taxi-v3's illegal pickup/dropoff penalty
            penalties += 1
        state = next_state
        steps += 1
    # Lightweight progress report every 100 episodes.
    if i % 100 == 0:
        print(f"Episode: {i}")
        print(f"num penalties: {penalties}")
        print(f"num steps: {steps}")
# Persist the learned Q-table to disk for later evaluation/reuse.
np.save("q_table_taxi.npy", q_table)
print("Training finished.\n")