"""Main DQN agent."""
from keras import optimizers
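

# NOTE: the compile() docstring below refers to a mean_huber_loss function
# that the assignment provides elsewhere. For reference only, a standard
# mean Huber loss written against the Keras backend looks roughly like the
# sketch below; it is an illustration, not the course-provided implementation.
from keras import backend as K


def mean_huber_loss(y_true, y_pred, delta=1.0):
    """Mean Huber loss: quadratic for small errors, linear beyond delta."""
    err = K.abs(y_true - y_pred)
    quadratic = K.minimum(err, delta)
    linear = err - quadratic
    return K.mean(0.5 * K.square(quadratic) + delta * linear)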


class DQNAgent:
    """Class implementing DQN.

    This is a basic outline of the functions/parameters you will need
    in order to implement the DQNAgent. This is just to get you
    started. You may need to tweak the parameters, add new ones, etc.
    Feel free to change the functions and function parameters that the
    class provides.

    We have provided docstrings to go along with our suggested API.

    Parameters
    ----------
    q_network: keras.models.Model
      Your Q-network model.
    targetq_network: keras.models.Model
      The target Q-network model, updated less frequently than the
      online network.
    preprocessor: deeprl_hw2.core.Preprocessor
      The preprocessor class. See the associated classes for more
      details.
    memory: deeprl_hw2.core.Memory
      Your replay memory.
    policy: Policy
      The action-selection policy (e.g. an epsilon-greedy policy).
    gamma: float
      Discount factor.
    target_update_freq: float
      Frequency to update the target network. You can either provide a
      number representing a soft target update (see utils.py) or a
      hard target update (see utils.py and the Atari paper).
    num_burn_in: int
      Before you begin updating the Q-network, your replay memory has
      to be filled up with some number of samples. This number says
      how many.
    train_freq: int
      How often you actually update your Q-network. Sometimes
      stability is improved if you collect a couple of samples for your
      replay memory for every Q-network update that you run.
    batch_size: int
      How many samples are in each minibatch.
    """

    def __init__(self,
                 q_network,
                 targetq_network,
                 preprocessor,
                 memory,
                 policy,
                 gamma,
                 target_update_freq,
                 num_burn_in,
                 train_freq,
                 batch_size):
        self.q_network = q_network
        self.preprocessor = preprocessor
        self.memory = memory
        self.policy = policy
        self.gamma = gamma
        self.target_update_freq = target_update_freq
        self.num_burn_in = num_burn_in
        self.train_freq = train_freq
        self.batch_size = batch_size
        # Convenience aliases: the online network and the (initially
        # identical) target network used for bootstrapped targets.
        self.Q = q_network
        self.targetQ = targetq_network

    def compile(self, optimizer, loss_func):
        """Setup all of the TF graph variables/ops.

        This is inspired by the compile method on the
        keras.models.Model class.

        This is a good place to create the target network, setup your
        loss function and any placeholders you might need.

        You should use the mean_huber_loss function as your
        loss_function (an illustrative sketch appears near the imports
        at the top of this file). You can also experiment with MSE and
        other losses.

        The optimizer can be whatever class you want. We used the
        keras.optimizers.Optimizer class. Specifically the Adam
        optimizer.
        """
        # Compile both networks with the same loss/optimizer; only the
        # online network is actually trained, the target network is just
        # used for predictions.
        self.Q.compile(loss=loss_func, optimizer=optimizer)
        self.targetQ.compile(loss=loss_func, optimizer=optimizer)

    def calc_q_values(self, state):
        """Given a state (or batch of states) calculate the Q-values.

        Basically run your network on these states.

        Return
        ------
        Q-values for the state(s)
        """
        pass
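        # A minimal sketch below the original stub (assuming `state` is
        # already preprocessed and carries a leading batch dimension):
        # Q-values are a single forward pass through the online network.
        return self.Q.predict(state)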

    def select_action(self, state, **kwargs):
        """Select the action based on the current state.

        You will probably want to vary your behavior here based on
        which stage of training you're in. For example, if you're still
        collecting random samples you might want to use a
        UniformRandomPolicy.

        If you're testing, you might want to use a GreedyEpsilonPolicy
        with a low epsilon.

        If you're training, you might want to use the
        LinearDecayGreedyEpsilonPolicy.

        This would also be a good place to call
        process_state_for_network in your preprocessor.

        Returns
        --------
        selected action
        """
        pass
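        # A minimal sketch below the original stub: compute Q-values for the
        # (already-processed) state and let the injected policy object handle
        # exploration (uniform, epsilon-greedy, linear decay, etc.).
        q_values = self.calc_q_values(state)
        return self.policy.select_action(q_values)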

    def update_policy(self):
        """Update your policy.

        Behavior may differ based on what stage of training you're
        in. If you're in training mode then you should check if you
        should update your network parameters based on the current
        step and the value you set for train_freq.

        Inside, you'll want to sample a minibatch, calculate the
        target values, update your network, and then update your
        target values.

        You might want to return the loss and other metrics as an
        output. They can help you monitor how training is going.
        """
        pass
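        # A minimal sketch below the original stub, reusing the memory API the
        # way fit() does (memory.sample returns state references, memory.phi
        # turns one into a batch-of-one network input). Targets are
        # bootstrapped from the target network and only the taken action's
        # entry is overwritten, so the other actions contribute zero error.
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = \
            self.memory.sample(self.batch_size)
        state_inputs = np.concatenate([self.memory.phi(o) for o in obs_batch])
        next_inputs = np.concatenate([self.memory.phi(o) for o in next_obs_batch])
        targets = self.Q.predict(state_inputs)
        next_qvals = self.targetQ.predict(next_inputs)
        for j in range(len(obs_batch)):
            if done_batch[j]:
                targets[j][action_batch[j]] = reward_batch[j]
            else:
                targets[j][action_batch[j]] = (
                    reward_batch[j] + self.gamma * np.max(next_qvals[j]))
        history = self.Q.fit(state_inputs, targets, verbose=0)
        return history.history['loss'][-1]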

    def fit(self, env, num_iterations, max_episode_length=None):
        """Fit your model to the provided environment.

        It's a good idea to print out things like loss, average reward,
        Q-values, etc. to see if your agent is actually improving.

        You should probably also periodically save your network
        weights and any other useful info.

        This is where you should sample actions from your network,
        collect experience samples and add them to your replay memory,
        and update your network parameters.

        Parameters
        ----------
        env: gym.Env
          This is your Atari environment. You should wrap the
          environment using the wrap_atari_env function in utils.py.
        num_iterations: int
          How many samples/updates to perform.
        max_episode_length: int
          How long a single episode should last before the agent
          resets. Can help exploration.
        """
        for i in range(num_iterations):
            # Start a new episode; store the processed initial frame and keep
            # track of the current state via the replay memory's hash.
            start = env.reset()
            start_processed = self.preprocessor.process_state_for_memory2(start)
            start_hash = self.memory.hashfunc(start_processed)
            st = start_hash
            for play in range(max_episode_length):
                # Act from the online network, or uniformly at random while
                # the replay memory is still being burned in.
                phi_st = self.memory.phi(st)
                qvals = self.Q.predict(phi_st)
                at = self.policy.select_action(qvals)
                if len(self.memory.experience) < self.num_burn_in:
                    at = env.action_space.sample()
                # Gym's step returns (observation, reward, done, info).
                next_tuple = env.step(at)
                rt = next_tuple[1]
                isterminal = next_tuple[2]
                st_hash, st1_hash = self.memory.append(st, at, rt, next_tuple)
                st = st1_hash
                if isterminal:
                    break
            if i % self.target_update_freq == 0:
                # Sample a minibatch and bootstrap the targets from the target
                # network: y = r for terminal transitions, otherwise
                # y = r + gamma * max_a' Q_target(s', a').
                obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = \
                    self.memory.sample(self.batch_size)
                y_true = []
                for j in range(len(obs_batch)):
                    if done_batch[j]:
                        y_true.append(reward_batch[j])
                    else:
                        phi_j1 = self.memory.phi(next_obs_batch[j])
                        qvals = self.targetQ.predict(phi_j1)
                        y_true.append(reward_batch[j] + self.gamma * np.max(qvals))
                # NOTE: depending on what memory.sample returns, obs_batch may
                # still need to go through memory.phi (and the targets may need
                # to be expanded to one value per action) before fitting.
                self.Q.fit(obs_batch, y_true)
                # Hard target update: copy the online weights into the target
                # network rather than aliasing the two models.
                self.targetQ.set_weights(self.Q.get_weights())

    def evaluate(self, env, num_episodes, max_episode_length=None):
        """Test your agent with a provided environment.

        You shouldn't update your network parameters here. Also if you
        have any layers that vary in behavior between train/test time
        (such as dropout or batch norm), you should set them to test.

        Basically run your policy on the environment and collect stats
        like cumulative reward, average episode length, etc.

        You can also call the render function here if you want to
        visually inspect your policy.
        """
        pass
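        # A minimal sketch below the original stub: run greedy rollouts with
        # no learning and report the mean episode reward. This assumes the
        # classic Gym step() API of (observation, reward, done, info) and that
        # preprocessor.process_state_for_network returns a batched,
        # network-ready input for a single observation.
        total_rewards = []
        for _ in range(num_episodes):
            obs = env.reset()
            episode_reward = 0.0
            steps = 0
            done = False
            while not done:
                state = self.preprocessor.process_state_for_network(obs)
                at = int(np.argmax(self.Q.predict(state)))
                obs, reward, done, _ = env.step(at)
                episode_reward += reward
                steps += 1
                if max_episode_length is not None and steps >= max_episode_length:
                    break
            total_rewards.append(episode_reward)
        return float(np.mean(total_rewards))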