-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessors.py
More file actions
153 lines (120 loc) · 5.23 KB
/
preprocessors.py
File metadata and controls
153 lines (120 loc) · 5.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Suggested Preprocessors."""
import numpy as np
from PIL import Image
from core import Preprocessor
class HistoryPreprocessor(Preprocessor):
    """Keeps the last k states.

    Useful for domains where you need velocities, but the state
    contains only positions.

    When the environment starts, this will just fill the initial
    sequence values with zeros k times.

    Parameters
    ----------
    history_length: int
        Number of previous states to prepend to state being processed.
    """

    def __init__(self, history_length=1):
        self.history_length = history_length
        # Oldest-first buffer holding at most ``history_length`` states.
        self.queue = []

    def process_state_for_network(self, state, toappend=True):
        """Return the current history window, left-padded with zeros.

        You only want history when you're deciding the current action
        to take.

        Parameters
        ----------
        state:
            Newest observed state.
        toappend: bool
            When True (default), push ``state`` into the buffer after
            taking the history snapshot; pass False for a read-only peek.

        Returns
        -------
        list of length ``history_length``; slots for which no state has
        been observed yet are filled with 0.
        """
        missing = self.history_length - len(self.queue)
        if missing > 0:
            toreturn = [0] * missing + self.queue
        else:
            toreturn = self.queue
        if toappend:
            # BUG FIX: only evict the oldest entry once the buffer is
            # actually full. The original trimmed unconditionally before
            # appending, so the buffer could never grow past one state
            # and the returned history was always mostly zeros.
            if len(self.queue) >= self.history_length:
                # Slicing rebinds to a fresh list, so ``toreturn`` (which
                # may alias the old full buffer) is not mutated.
                self.queue = self.queue[1:]
            self.queue.append(state)
        return toreturn

    def reset(self):
        """Reset the history sequence.

        Useful when you start a new episode.
        """
        self.queue = []

    def get_config(self):
        """Return the constructor parameters as a dict."""
        return {'history_length': self.history_length}
class AtariPreprocessor(Preprocessor):
    """Converts images to greyscale and downscales.

    Based on the preprocessing step described in:

    @article{mnih15_human_level_contr_throug_deep_reinf_learn,
      author = {Volodymyr Mnih and Koray Kavukcuoglu and David
              Silver and Andrei A. Rusu and Joel Veness and Marc
              G. Bellemare and Alex Graves and Martin Riedmiller
              and Andreas K. Fidjeland and Georg Ostrovski and
              Stig Petersen and Charles Beattie and Amir Sadik and
              Ioannis Antonoglou and Helen King and Dharshan
              Kumaran and Daan Wierstra and Shane Legg and Demis
              Hassabis},
      title = {Human-Level Control Through Deep Reinforcement
              Learning},
      journal = {Nature},
      volume = 518,
      number = 7540,
      pages = {529-533},
      year = 2015,
      doi = {10.1038/nature14236},
      url = {http://dx.doi.org/10.1038/nature14236},
    }

    You may also want to max over frames to remove flickering. Some
    games require this (based on animations and the limited sprite
    drawing capabilities of the original Atari).

    Parameters
    ----------
    new_size: 2 element tuple
        The size that each image in the state should be scaled to. e.g
        (84, 84) will make each image in the output have shape (84, 84).
    """

    def __init__(self, new_size):
        # NOTE(review): stored but currently unused — the processing
        # methods below hard-code the 84x84 Atari pipeline.
        self.new_size = new_size

    def process_state_for_memory(self, state):
        """Crop, downsample and binarize a Pong frame for the replay memory.

        We don't want to save floating point numbers in the replay
        memory. We get the same resolution as uint8, but use a quarter
        to an eigth of the bytes (depending on float32 or float64).

        Assumes a 210x160x3 Atari (Pong) frame; the magic numbers crop
        the playing field and erase the two background colors.
        """
        # BUG FIX: work on a copy. The slicing below produces views into
        # ``state``, so the boolean-mask assignments would otherwise
        # overwrite the caller's original frame in place.
        I = np.array(state, copy=True)
        I = I[35:195]        # crop the playing field
        I = I[::2, ::2, 0]   # downsample by factor of 2, keep one channel
        I[I == 144] = 0      # erase background (background type 1)
        I[I == 109] = 0      # erase background (background type 2)
        I[I != 0] = 1        # everything else (paddles, ball) set to 1
        return I

    def process_state_for_memory2(self, state):
        """Greyscale, resize to 84x110 and crop to an 84x84x1 uint8 image."""
        img = np.reshape(state, [210, 160, 3]).astype(np.float32)
        # ITU-R 601 luma weights for RGB -> greyscale.
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        img = Image.fromarray(img)
        resized_screen = img.resize((84, 110), Image.BILINEAR)
        resized_screen = np.array(resized_screen)
        # Drop the score/border rows, keeping the central 84 rows.
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)

    def process_state_for_network(self, state):
        """Scale, convert to greyscale and return as float32.

        Basically same as process state for memory, but this time
        outputs float32 images.
        """
        I = self.process_state_for_memory2(state)
        return I.astype(np.float32)

    def process_batch(self, samples):
        """The batches from replay memory will be uint8, convert to float32.

        Same as process_state_for_network but works on a batch of
        samples from the replay memory.

        BUG FIX: the original was a ``pass`` stub returning None.
        NOTE(review): assumes ``samples`` is array-like of uint8 frames;
        if samples are structured (state, next_state) records, each
        field needs converting separately — confirm against the caller.
        """
        return np.asarray(samples, dtype=np.float32)

    def process_reward(self, reward):
        """Clip reward to {-1, 0, 1} via its sign."""
        return np.sign(reward)
class PreprocessorSequence(Preprocessor):
    """Stacks multiple preprocessors (such as the History and the
    AtariPreprocessor) by calling each one in succession.

    For example, with a sequence of AtariPreprocessor followed by
    HistoryPreprocessor, ``process_state_for_network`` behaves like:

        state = atari.process_state_for_network(state)
        return history.process_state_for_network(state)

    Parameters
    ----------
    preprocessors: iterable of Preprocessor
        Applied in iteration order.
    """

    def __init__(self, preprocessors):
        # BUG FIX: the original was ``pass`` and discarded its argument,
        # leaving the sequence unusable.
        self.preprocessors = list(preprocessors)

    def process_state_for_network(self, state):
        """Run ``state`` through every preprocessor in order."""
        for preprocessor in self.preprocessors:
            state = preprocessor.process_state_for_network(state)
        return state