-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate.py
More file actions
98 lines (72 loc) · 3.18 KB
/
evaluate.py
File metadata and controls
98 lines (72 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
""" ensLoss: test.py"""
# Authors: Ben Dai
# License: MIT License
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
import losses
import numpy as np
from torcheval.metrics import BinaryAccuracy, BinaryAUROC
import os
import time
class Tester(object):
    """Evaluate a single binary classifier on a test data loader.

    The model output is treated as a real-valued score: a score strictly
    greater than 0 is counted as the positive class for accuracy, and the
    raw scores are passed to the AUROC metric.

    Args:
        model: Model exposing ``eval()`` and callable on a feature batch.
        device: Device that every feature/label batch is moved to.
        test_loader: Iterable yielding ``(X_batch, y_batch)`` pairs.
    """

    def __init__(self, model, device, test_loader):
        self.model = model
        self.device = device
        self.test_loader = test_loader

    def test(self):
        """Run one pass over ``test_loader`` and report accuracy and AUROC.

        The progress-bar description is refreshed after every batch with the
        metrics accumulated so far.

        Returns:
            tuple: ``(accuracy, auroc)`` accumulated over all batches, or
            ``(0, 0)`` when the loader yields no batches.
        """
        print('\n###### TEST ######')
        epoch_acc_val = 0
        epoch_auc_val = 0
        self.model.eval()
        with torch.no_grad():
            acc_metric = BinaryAccuracy()
            auc_metric = BinaryAUROC()
            tbar = tqdm(self.test_loader, ncols=120)
            for X_batch, y_batch in tbar:
                X_batch = X_batch.to(self.device)
                # Flatten the targets as well so that labels shaped (N, 1)
                # line up with the flattened predictions; this is a no-op
                # for labels that are already 1-D.
                y_batch = y_batch.to(self.device).flatten()
                y_pred = self.model(X_batch).flatten()
                # Scores are thresholded at 0 to obtain hard 0/1 labels.
                acc_metric.update((y_pred > 0).float(), y_batch)
                auc_metric.update(y_pred, y_batch)
                epoch_acc_val = acc_metric.compute().item()
                epoch_auc_val = auc_metric.compute().item()
                tbar.set_description(
                    f'TEST | Acc: {epoch_acc_val:.3f}; AUC: {epoch_auc_val:.3f}')
        return epoch_acc_val, epoch_auc_val
class Tester_bag(object):
    """Evaluate a bag (ensemble) of binary classifiers on a test data loader.

    Every model in the bag scores each batch; the per-model scores are then
    combined according to ``strategy`` before the metrics are updated.

    Args:
        model_bag: Sequence of models, each exposing ``eval()`` and callable
            on a feature batch.
        device: Device that every feature/label batch is moved to.
        test_loader: Iterable yielding ``(X_batch, y_batch)`` pairs.
        strategy: ``'average'`` takes the mean of the model scores;
            ``'voting'`` takes the element-wise mode of the signs of the
            model scores.
    """

    def __init__(self, model_bag, device, test_loader, strategy='average'):
        self.model_bag = model_bag
        self.device = device
        self.test_loader = test_loader
        self.num_bag = len(model_bag)
        self.strategy = strategy

    def test(self):
        """Run one pass over ``test_loader`` and report accuracy and AUROC.

        The progress-bar description is refreshed after every batch with the
        metrics accumulated so far.

        Returns:
            tuple: ``(accuracy, auroc)`` accumulated over all batches, or
            ``(0, 0)`` when the loader yields no batches.

        Raises:
            ValueError: If ``self.strategy`` is neither ``'average'`` nor
                ``'voting'``.
        """
        # Fail fast: reject an unknown strategy before any forward pass is
        # run, instead of discovering it inside the batch loop.
        if self.strategy not in ('average', 'voting'):
            raise ValueError('Unknown strategy: {}'.format(self.strategy))

        print('\n###### TEST ######')
        epoch_acc_val = 0
        epoch_auc_val = 0
        for model in self.model_bag:
            model.eval()
        with torch.no_grad():
            acc_metric = BinaryAccuracy()
            auc_metric = BinaryAUROC()
            tbar = tqdm(self.test_loader, ncols=120)
            for X_batch, y_batch in tbar:
                X_batch = X_batch.to(self.device)
                # Flatten the targets as well so that labels shaped (N, 1)
                # line up with the flattened predictions; this is a no-op
                # for labels that are already 1-D.
                y_batch = y_batch.to(self.device).flatten()
                # Stack per-model outputs along a new leading "model" axis.
                y_pred_bag = torch.stack(
                    [model(X_batch) for model in self.model_bag], dim=0)
                if self.strategy == 'average':
                    y_pred = torch.mean(y_pred_bag, dim=0)
                else:
                    # 'voting': most frequent sign across the models.
                    y_pred, _ = torch.mode(torch.sign(y_pred_bag), dim=0)
                y_pred = y_pred.flatten()
                # Combined scores are thresholded at 0 for hard 0/1 labels.
                acc_metric.update((y_pred > 0).float(), y_batch)
                auc_metric.update(y_pred, y_batch)
                epoch_acc_val = acc_metric.compute().item()
                epoch_auc_val = auc_metric.compute().item()
                tbar.set_description(
                    f'TEST | Acc: {epoch_acc_val:.3f}; AUC: {epoch_auc_val:.3f}')
        return epoch_acc_val, epoch_auc_val