-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemoCart.py
More file actions
38 lines (29 loc) · 1.04 KB
/
demoCart.py
File metadata and controls
38 lines (29 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import random
import gym
import numpy as np
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from rl.agents import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory
# Train a DQN agent (keras-rl) on Gym's CartPole-v1 environment and print the
# mean per-episode reward recorded during training.
#
# NOTE(review): keras-rl / keras-rl2 targets the pre-0.26 Gym API
# (reset() -> obs, step() -> 4-tuple); confirm the installed gym version.
env = gym.make('CartPole-v1')
states = env.observation_space.shape[0]  # length of the observation vector
actions = env.action_space.n  # number of discrete actions

# Q-network: maps an observation to one Q-value per action.
model = Sequential()
# SequentialMemory below uses window_length=1, so each network input is a
# stack of 1 observation -> shape (1, states); Flatten turns it into a vector.
model.add(Flatten(input_shape=(1, states)))
model.add(Dense(24, activation='relu'))  # relu = rectified linear unit (activation function)
model.add(Dense(24, activation='relu'))
# Linear output layer: Q-values are unbounded regression targets.
model.add(Dense(actions, activation='linear'))

agent = DQNAgent(
    model=model,
    memory=SequentialMemory(limit=50000, window_length=1),
    policy=BoltzmannQPolicy(),
    nb_actions=actions,
    nb_steps_warmup=10,
    # keras-rl treats a value < 1 as a soft (Polyak) target-network update rate.
    target_model_update=0.01,
)
agent.compile(Adam(learning_rate=1e-3), metrics=['mae'])

try:
    # fit() returns a Keras-style History object whose `history` dict holds
    # per-episode logs, including 'episode_reward'. It must be kept so the
    # summary below can read it.
    # visualize=True renders every step, which slows training considerably.
    res = agent.fit(env, nb_steps=100000, visualize=True, verbose=1)  # Training for 100000 environment steps
    print(np.mean(res.history['episode_reward']))
finally:
    # Release the environment (and its render window) even if training is
    # interrupted or fails.
    env.close()