-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
43 lines (37 loc) · 1.26 KB
/
utils.py
File metadata and controls
43 lines (37 loc) · 1.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import json
import os
import time
import torch
def _load_monitored_state(filename):
    """Return the JSON object stored in ``filename``, or ``{}`` if unusable."""
    try:
        with open(filename) as f:
            state = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # No previous snapshot, or one that is empty / not valid JSON:
        # start from a clean state instead of failing.
        return {}
    # Only a JSON object can hold the keyed snapshots written below.
    return state if isinstance(state, dict) else {}


def _head_history_to_list(heads, num_head_history):
    """Convert a list of head tensors into a fixed-length nested list.

    Keeps the last ``num_head_history`` tensors, pads with zero tensors
    shaped like the first kept entry, and concatenates along dimension 0.
    Returns ``None`` for an empty list, since there is no shape to pad with.
    """
    recent = [h.detach().cpu() for h in heads][-num_head_history:]
    if not recent:
        return None
    padding = [torch.zeros(recent[0].shape)] * (num_head_history - len(recent))
    return torch.cat(recent + padding, 0).numpy().tolist()


def update_monitored_state(memory=None,
                           read_head=None, write_head=None,
                           filename='data.json',
                           num_head_history=10, delay=0.3):
    """Write the current memory/head state to a JSON file.

    Only the state with respect to the first item of a batch is monitored:
    a non-list value is indexed with ``[0]`` before being serialized.

    For each of ``memory``, ``read_head`` and ``write_head`` the value
    previously stored in ``filename`` is moved to ``<key>_prev`` (``None``
    if there was none) and the new value is stored under ``<key>``.

    Args:
        memory, read_head, write_head: ``None`` (stored as ``null``), a
            tensor (its first item is stored), or a list of tensors (the
            last ``num_head_history`` entries are stored, zero-padded to
            exactly ``num_head_history`` entries and concatenated along
            dimension 0; an empty list is stored as ``null``).
        filename: path of the JSON file to read from and write to.
        num_head_history: number of list entries kept per head.
        delay: seconds to sleep after writing; ``0`` disables the pause.

    Raises:
        ValueError: if ``num_head_history`` is smaller than 1.
    """
    if num_head_history < 1:
        # A slice of [-0:] would silently keep the whole history.
        raise ValueError('num_head_history must be at least 1')

    # Read from the same path that is written to (not a fixed 'data.json').
    state = _load_monitored_state(filename)

    tracked = (('memory', memory), ('read_head', read_head),
               ('write_head', write_head))
    for key, val in tracked:
        state[key + '_prev'] = state.get(key)
        if val is None:
            state[key] = None
        elif isinstance(val, list):
            # read/write heads
            state[key] = _head_history_to_list(val, num_head_history)
        else:
            state[key] = val[0].detach().cpu().numpy().tolist()

    with open(filename, 'w') as f:
        json.dump(state, f)

    if delay > 0:
        time.sleep(delay)