-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
112 lines (92 loc) · 3.42 KB
/
utils.py
File metadata and controls
112 lines (92 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import sys
from datetime import datetime
import copy
import torch
import shutil
import importlib
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import defaultdict
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def load_config(config_filename):
config_path = "configs.{}".format(config_filename.split('.')[0])
module = importlib.import_module(config_path)
return module.Config()
def splitprint():
print("#" * 100)
def runid_checker(opts, if_syn=False):
rootpath = opts.train_collection
if if_syn:
rootpath = opts.syn_collection
valset_name = os.path.split(opts.val_collection)[-1]
config_filename = opts.model_configs + "_" + opts.model_id
run_id = opts.run_id
target_path = os.path.join(rootpath, "models", valset_name, config_filename, "run_" + str(run_id))
if os.path.exists(target_path):
if opts.overwrite:
shutil.rmtree(target_path)
else:
print("'{}' exists!".format(target_path))
return False
os.makedirs(target_path)
print("checkpoints are saved in '{}'".format(target_path))
return True
def predict_dataloader(model, loader, device, net_name="mm-model", if_test=False):
model.eval()
predicts, predicts_fine = [], []
scores = []
scores_fine = []
expects, expects_fine = [], []
eye_level_predict = defaultdict(list)
eye_level_expect = defaultdict(list)
for i, (inputs, labels_onehot, imagenames) in enumerate(loader):
with torch.no_grad():
outputs, _ = model(inputs.to(device))
outputs = torch.nn.Sigmoid()(outputs)
eye_id = '-'.join(imagenames[0].split('-')[0:2])
eye_level_predict[eye_id].extend(outputs.cpu().numpy().tolist())
if eye_id not in eye_level_expect:
eye_level_expect[eye_id] = labels_onehot.cpu().numpy()
scores_fine.extend(outputs.cpu().numpy().astype(np.int64).tolist())
predict_fine = torch.round(outputs).cpu().numpy().astype(np.int64).tolist()
predicts_fine.extend(predict_fine)
expects_fine.extend(labels_onehot.cpu().numpy().tolist())
for eye_id in eye_level_predict:
predict = np.array([np.max(np.array(eye_level_predict[eye_id])[:, i]) for i in range(11)])
scores.append(predict)
predict = np.int64(torch.from_numpy(predict).squeeze(0).cpu().numpy() >= 0.5).tolist()
predicts.append(predict)
expects.extend(eye_level_expect[eye_id])
return predicts, scores, expects, predicts_fine, scores_fine
class Logger:
def __init__(self, filename):
self.terminal = sys.stdout
self.log = open(filename, 'w')
def write(self, message):
self.terminal.write(message)
self.log.write(message)
self.log.flush()
def flush(self):
self.terminal.flush()
self.log.flush()
def setup_logging():
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f'training_log_{timestamp}.txt'
sys.stdout = Logger(log_filename)
return log_filename