-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcusotm_atari_preprocessing.py
More file actions
148 lines (125 loc) · 6.03 KB
/
cusotm_atari_preprocessing.py
File metadata and controls
148 lines (125 loc) · 6.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import numpy as np
import gym
from gym.spaces import Box
from gym.wrappers import TimeLimit
try:
import cv2
except ImportError:
cv2 = None
class AtariPreprocessing(gym.Wrapper):
    r"""Atari 2600 preprocessing wrapper with built-in frame stacking.

    Specifically:
    * NoopReset: obtain initial state by taking a random number of no-ops on reset.
    * Frame skipping: 4 by default.
    * Max-pooling: over the most recent two raw observations (removes sprite flicker).
    * Termination signal when a life is lost: turned off by default.
      Not recommended by Machado et al. (2018).
    * Resize to a square image: 84x84 by default.
    * Grayscale observation (RGB is not yet supported by this wrapper).
    * Scale observation to [0, 1]: optional.

    Unlike the stock gym wrapper, step()/reset() return a stack of the last
    `frame_skip` processed frames (the newest frame is written to channel
    index 0, the oldest to index frame_skip - 1).

    Args:
        env (Env): environment to wrap; its own frame-skipping must be disabled
            (env.frameskip == 1) when frame_skip > 1.
        noop_max (int): max number of no-ops taken on reset.
        frame_skip (int): number of raw frames each action is repeated for;
            also the depth of the returned frame stack.
        screen_size (int): side length the Atari frame is resized to.
        terminal_on_life_loss (bool): if True, step() returns done=True whenever
            a life is lost.
        grayscale_obs (bool): must currently be True; RGB returns are not
            implemented yet (see the TODOs below).
        scale_obs (bool): if True, observations are float32 in [0, 1]; this
            forfeits the memory benefits uint8 storage gives FrameStack-style
            wrappers.
    """

    def __init__(self, env, noop_max=30, frame_skip=4, screen_size=84, terminal_on_life_loss=False, grayscale_obs=True,
                 scale_obs=False):
        super().__init__(env)
        assert cv2 is not None, \
            "opencv-python package not installed! Try running pip install gym[atari] to get dependencies for atari"
        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        # RGB support is unfinished (frame_stack is only built for grayscale),
        # so fail fast.  (Typo fixed: 'obersvation retruns'.)
        assert grayscale_obs, 'This Wrapper does not yet support RGB observation returns'
        if frame_skip > 1:
            assert env.frameskip == 1, 'disable frame-skipping in the original env. for more than one' \
                                       ' frame-skip as it will be done by the wrapper'
        self.noop_max = noop_max
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.grayscale_obs = grayscale_obs
        self.scale_obs = scale_obs
        # Three raw-resolution buffers: [0] newest frame, [1] previous frame,
        # [2] scratch space for the in-place max-pool in _get_obs().
        raw_shape = env.observation_space.shape[:2] if grayscale_obs else env.observation_space.shape
        self.obs_buffer = [np.empty(raw_shape, dtype=np.uint8) for _ in range(3)]
        self.ale = env.unwrapped.ale
        self.lives = 0
        self.game_over = False
        _low, _high, _obs_dtype = (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
        if grayscale_obs:
            self.observation_space = Box(low=_low, high=_high,
                                         shape=(screen_size, screen_size, frame_skip), dtype=_obs_dtype)
        else:
            self.observation_space = Box(low=_low, high=_high,
                                         shape=(screen_size, screen_size, 3, frame_skip), dtype=_obs_dtype)
        # Persistent stack of the last `frame_skip` processed frames, matching
        # the observation_space dtype.
        if grayscale_obs:
            self.frame_stack = np.zeros((screen_size, screen_size, frame_skip),
                                        dtype='float32' if scale_obs else 'uint8')
        else:
            pass  # TODO implement rgb frame stacking

    def step(self, action):
        """Repeat `action` over up to frame_skip raw frames.

        Returns:
            (frame_stack, summed_reward, done, info) where frame_stack holds
            the newest processed frame at channel index 0.
        """
        R = 0.0
        for t in range(self.frame_skip):
            _, reward, done, info = self.env.step(action)
            R += reward
            self.game_over = done
            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                done = done or new_lives < self.lives
                self.lives = new_lives
            if done:
                # NOTE(review): the terminal frame is never captured, so the
                # returned stack can contain stale channels from the previous
                # step; callers typically discard the terminal observation.
                break
            # Rotate the two raw-frame buffers instead of aliasing them.
            # BUG FIX: the original `obs_buffer[1] = obs_buffer[0]` made both
            # entries the SAME ndarray (ALE fills the passed array in place),
            # so the max-pool in _get_obs() was a no-op and flicker survived.
            self.obs_buffer[0], self.obs_buffer[1] = self.obs_buffer[1], self.obs_buffer[0]
            if self.grayscale_obs:
                self.ale.getScreenGrayscale(self.obs_buffer[0])
            else:
                self.ale.getScreenRGB2(self.obs_buffer[0])
            self.frame_stack[:, :, -(t + 1)] = self._get_obs()
        return self.frame_stack, R, done, info

    def reset(self, **kwargs):
        """NoopReset: reset, take 1..noop_max random no-ops, then fill the
        frame stack with frame_skip copies of the first processed frame."""
        self.env.reset(**kwargs)
        noops = self.env.unwrapped.np_random.randint(1, self.noop_max + 1) if self.noop_max > 0 else 0
        for _ in range(noops):
            _, _, done, _ = self.env.step(0)
            if done:
                self.env.reset(**kwargs)
        self.lives = self.ale.lives()
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(self.obs_buffer[0])
        else:
            self.ale.getScreenRGB2(self.obs_buffer[0])
        # Zero the "previous frame" buffer so the first max-pool reduces to
        # the freshly captured frame.
        self.obs_buffer[1].fill(0)
        obs = self._get_obs()
        self.frame_stack = np.stack([obs] * self.frame_skip, axis=-1)
        return self.frame_stack

    def _get_obs(self):
        """Max-pool the two most recent raw frames, resize, optionally scale.

        BUG FIX: when frame_skip == 1 the original resized obs_buffer[2],
        which is never written in that case and holds uninitialized memory
        (np.empty); use the freshly captured frame directly instead.
        """
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[2])
            pooled = self.obs_buffer[2]
        else:
            pooled = self.obs_buffer[0]
        obs = cv2.resize(pooled, (self.screen_size, self.screen_size), interpolation=cv2.INTER_AREA)
        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)
        return obs