forked from Kaixhin/EC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
85 lines (68 loc) · 3.09 KB
/
test.py
File metadata and controls
85 lines (68 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import numpy as np
import plotly
from plotly.graph_objs import Scatter
from plotly.graph_objs.scatter import Line
import torch
from envs import AtariEnv
# Globals: evaluation history kept at module level and updated by test()
Ts = []  # T value recorded for each call to test()
rewards = []  # Per-evaluation lists of episode rewards
Qs = []  # Per-evaluation lists of Q-values over the validation memory
best_avg_reward = -1e10  # Highest average evaluation reward seen so far
# Plots min, max and mean + standard deviation bars of a population over time
def lineplot(xs, ys_population, title, path='', xaxis='episode'):
    """Write an offline Plotly line chart of ``ys_population`` against ``xs``.

    When the first element of ``ys_population`` is a list or tuple, every
    element is treated as one population belonging to the matching ``xs``
    entry, and the chart shows its min, max, median and mean together with a
    shaded band one standard deviation either side of the mean. Otherwise
    ``ys_population`` is plotted directly as a single line.

    The chart is saved as ``<title>.html`` inside ``path`` and is not opened
    automatically. ``title`` also labels the y-axis; ``xaxis`` labels the
    x-axis.
    """
    outline_colour = 'rgb(0, 132, 180)'
    mean_colour = 'rgb(0, 172, 237)'
    band_colour = 'rgba(29, 202, 255, 0.2)'
    invisible = 'rgba(0, 0, 0, 0)'

    if not isinstance(ys_population[0], (list, tuple)):
        # Plain series: a single line, no population statistics
        traces = [Scatter(x=xs, y=ys_population, line=Line(color=mean_colour))]
    else:
        # One row per x value; statistics are taken across each row
        samples = np.asarray(ys_population, dtype=np.float32)
        mean = samples.mean(axis=1)
        std = samples.std(axis=1)
        # Order matters: each 'tonexty' fill shades up/down to the trace listed just before it
        traces = [
            Scatter(x=xs, y=mean + std, line=Line(color=invisible),
                    name='+1 Std. Dev.', showlegend=False),
            Scatter(x=xs, y=mean, fill='tonexty', fillcolor=band_colour,
                    line=Line(color=mean_colour), name='Mean'),
            Scatter(x=xs, y=mean - std, fill='tonexty', fillcolor=band_colour,
                    line=Line(color=invisible), name='-1 Std. Dev.', showlegend=False),
            Scatter(x=xs, y=samples.min(axis=1),
                    line=Line(color=outline_colour, dash='dash'), name='Min'),
            Scatter(x=xs, y=samples.max(axis=1),
                    line=Line(color=outline_colour, dash='dash'), name='Max'),
            Scatter(x=xs, y=np.median(samples, axis=1),
                    line=Line(color=outline_colour), name='Median'),
        ]

    figure = {
        'data': traces,
        'layout': {'title': title, 'xaxis': {'title': xaxis}, 'yaxis': {'title': title}},
    }
    output_file = os.path.join(path, title + '.html')
    plotly.offline.plot(figure, filename=output_file, auto_open=False)
# Test DQN
def test(args, T, agent, val_mem, results_dir, evaluate=False):
    """Evaluate the agent over several episodes and over the validation memory.

    Args:
        args: Options object; passed to ``AtariEnv`` and read for
            ``evaluation_episodes`` (number of episodes to play) and
            ``render`` (whether to render each step).
        T: Value appended to the global ``Ts`` and used as the x-axis
            ('Step') of the plots.
        agent: Object providing ``act(state)``, ``evaluate_q(state)`` and
            ``save(path)``.
        val_mem: Iterable of states whose Q-values are evaluated.
        results_dir: Directory that receives the plots and saved parameters.
        evaluate: When true, nothing is recorded in the global history,
            plotted or saved; the results are only returned.

    Returns:
        A tuple ``(T_rewards, T_Qs)``: the total reward of each episode and
        the value of ``agent.evaluate_q`` for each validation state.

    Raises:
        ZeroDivisionError: If no evaluation episode was played.
    """
    global Ts, rewards, Qs, best_avg_reward
    env = AtariEnv(args)
    try:
        env.eval()
        Ts.append(T)
        T_rewards = []

        # Test performance over several episodes
        for _ in range(args.evaluation_episodes):
            state, reward_sum, done = env.reset(), 0, False
            while not done:
                action = agent.act(state)  # Choose an action ε-greedily (default for eval mode)
                state, reward, done = env.step(action)  # Step
                reward_sum += reward
                if args.render:
                    env.render()
            T_rewards.append(reward_sum)
    finally:
        # Release the environment even if acting, stepping or rendering fails
        env.close()

    # Test Q-values over validation memory
    T_Qs = [agent.evaluate_q(state) for state in val_mem]

    avg_reward = sum(T_rewards) / len(T_rewards)
    if not evaluate:
        # Append to results
        rewards.append(T_rewards)
        Qs.append(T_Qs)

        # Plot
        lineplot(Ts, rewards, 'Reward', path=results_dir, xaxis='Step')
        lineplot(Ts, Qs, 'Q', path=results_dir, xaxis='Step')

        # Save model parameters if improved
        # NOTE(review): assumed to belong to the non-evaluate branch - confirm against upstream
        if avg_reward > best_avg_reward:
            best_avg_reward = avg_reward
            agent.save(results_dir)

    # Return rewards and Q-values
    return T_rewards, T_Qs