-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
133 lines (107 loc) · 4.35 KB
/
utils.py
File metadata and controls
133 lines (107 loc) · 4.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import cv2
import numpy as np
import pathlib
import matplotlib.pyplot as plt
def freeze_parameters(net):
for _, params in net.named_parameters():
params.requires_grad = False
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
slope = (end_e - start_e) / duration
return max(slope * t + start_e, end_e)
def save_video(frames, path, name):
"""
Saves a video containing frames.
"""
# Preprocessing
if np.max(frames) <= 1:
frames = (frames * 255).clip(0, 255).astype('uint8')
if frames.shape[-1] != 3:
frames = frames.transpose(0, 3, 1, 2)
_, H, W, _ = frames.shape
writer = cv2.VideoWriter(
str(pathlib.Path(path)/f'{name}.mp4'),
cv2.VideoWriter_fourcc(*'mp4v'), 25., (W, H), True
)
for frame in frames[..., ::-1]:
writer.write(frame)
writer.release()
def plot_images(data_path: str, labels: list, save_path: str):
# Parameters
n_rows = 5
n_cols = 20
img_w, img_h = 256, 256
pad_vertical = 50
# Margins
left_margin = 120
bottom_margin = 160
right_margin = 50
top_margin = 40
data = np.load(data_path)
sequences = [data[k] for k in data.files]
sequences_new = []
# Re-adjust indices
for s in sequences[1:]:
s = list(s)
s.insert(0, data["real"][0])
s.pop(-1)
sequences_new.append(np.array(s))
sequences_new.insert(0, sequences[0])
sequences = sequences_new
# Compute figure size in pixels
fig_width_px = left_margin + n_cols * img_w + right_margin
fig_height_px = bottom_margin + n_rows * img_h + pad_vertical * (n_rows - 1) + top_margin
# Create figure
fig = plt.figure(figsize=(fig_width_px / 100, fig_height_px / 100))
# Plot images
for row_idx in range(n_rows):
for col_idx in range(n_cols):
x0 = left_margin + col_idx * img_w
y0 = bottom_margin + (n_rows - 1 - row_idx) * (img_h + pad_vertical)
# Convert to relative coordinates
left = x0 / fig_width_px
bottom = y0 / fig_height_px
width = img_w / fig_width_px
height = img_h / fig_height_px
ax = fig.add_axes([left, bottom, width, height])
ax.imshow(sequences[row_idx][col_idx])
ax.axis('off')
for i, label in enumerate(labels):
y_center_px = bottom_margin + (n_rows - 1 - i) * (img_h + pad_vertical) + img_h / 2
fig.text((left_margin - 20) / fig_width_px, y_center_px / fig_height_px,
label, va='center', ha='right', rotation=90, fontsize=42)
# Add x-axis label ("Time")
fig.text(0.5, (bottom_margin - 140) / fig_height_px, "Timestep", ha='center', fontsize=56)
# Add ticks every 5th image with large font
for i in range(n_cols):
if i % 5 == 0:
x_center_px = left_margin + i * img_w + img_w / 2
fig.text(x_center_px / fig_width_px, (bottom_margin - 70) / fig_height_px,
str(i), ha='center', fontsize=60)
for row_idx in [0, 1]: # Between row 0-1 and 1-2
y_px = bottom_margin + (n_rows - row_idx) * (img_h + pad_vertical) - pad_vertical / 2
y_rel = y_px / fig_height_px
line = Line2D(
[left_margin / fig_width_px, (left_margin + n_cols * img_w) / fig_width_px],
[y_rel, y_rel],
color='black',
linewidth=3
)
fig.add_artist(line)
fig.savefig(save_path)
plt.close(fig)
if __name__ == '__main__':
import os
from matplotlib.lines import Line2D
experiment = "trajectories_ablation"
env = "super_tux_kart"
tracks = [x for x in os.listdir(f"experiments/{experiment}/{env}/") if ".png" not in x]
labels = ["Real", "SSM", "Rollout", "L2", "KL"]
for t in tracks:
print(f"Working on {t}")
for l in [5, 6, 7, 8, 9, 10, 15, 20, 25, 30]:
for samples_dir in ["one_step", "long_term"]:
samples_path = f"experiments/{experiment}/{env}/{t}/{l}/samples/{samples_dir}"
os.makedirs(f"{samples_path}/images/{samples_dir}", exist_ok=True)
for i, el in enumerate([x for x in os.listdir(samples_path) if x.endswith(".npz")]):
print(f"[{t}] Working on {el}")
plot_images(f"{samples_path}/{el}", labels, save_path=f"{samples_path}/images/{samples_dir}/sample_{i}.png")