-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
175 lines (145 loc) · 5.74 KB
/
eval.py
File metadata and controls
175 lines (145 loc) · 5.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
''' Evaluate 1D ResNet models on the MalNet signal dataset.
'''
import torch as T
from numpy import ndarray
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
import argparse
import os
from pprint import pprint
from model.model_factory import model_factory
from training_fns.malnet_loader import make_malnet_1d
def parse_option():
    """Parse CLI arguments for evaluation and normalize all dataset paths.

    Returns:
        argparse.Namespace with `~` expanded in every path option and the
        `<task>` placeholder in the split paths replaced by the chosen task.
    """
    parser = argparse.ArgumentParser('argument for training')

    # model config
    parser.add_argument('--model', type=str, default='resnet1dv2_152d_se', help='name of cnn model')
    parser.add_argument('--activation', type=str, default='gelu', help='name of activation function')

    # data config
    parser.add_argument('--task', type=str, default='type', help='classification granularity. can be binary, family, or type')
    parser.add_argument('--data_path', type=str, default='dataset/signals', help='path to dataset dir')
    parser.add_argument('--train_split_path', type=str, default='dataset/split_info/<task>/1.0/train.txt', help='path to text file containing training split')
    parser.add_argument('--val_split_path', type=str, default='dataset/split_info/<task>/1.0/val.txt', help='path to text file containing validation split')
    parser.add_argument('--test_split_path', type=str, default='dataset/split_info/<task>/1.0/test.txt', help='path to text file containing test split')
    parser.add_argument('--exclude_path', type=str, default='dataset/split_info/exclude.txt', help='path to text file containing sample to exclude from dataset')
    parser.add_argument('--n_workers', type=int, default=8, help='number of dataloader workers')

    # misc
    parser.add_argument('--device', type=str, default='cuda', help='device to run evaluation on')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='path to checkpoint')
    parser.add_argument('--chunk_size', type=int, default=64, help='inference batch size')

    opt = parser.parse_args()

    # Expand '~' in every user-supplied path option.
    path_attrs = ('data_path', 'train_split_path', 'val_split_path',
                  'test_split_path', 'exclude_path')
    for attr in path_attrs:
        setattr(opt, attr, os.path.expanduser(getattr(opt, attr)))

    # Substitute the selected task into the split-path templates.
    for attr in ('train_split_path', 'val_split_path', 'test_split_path'):
        setattr(opt, attr, getattr(opt, attr).replace("<task>", opt.task))

    return opt
def load_data(opt):
    """Build MalNet-1D dataloaders for the task selected in `opt`.

    Returns whatever `make_malnet_1d` produces (train/val/test loaders).
    Normalization statistics are fixed per-task constants.
    """
    # Per-task signal normalization statistics (mean/std), keyed by task.
    # NOTE: slightly updated 'type' stats exist (mean=0.2130117416381836,
    # std=0.12885624170303345) — better, use if retraining models.
    stats_by_task = {
        'binary': dict(
            mean = 0.21301524341106415,
            std = 0.12883393466472626,
        ),
        'type': dict(
            mean = 0.21299883723258972,
            std = 0.12885181605815887,
        ),
        'family': dict(
            mean = 0.2130156308412552,
            std = 0.12882858514785767,
        ),
    }
    # Any task other than binary/type falls back to the family statistics,
    # mirroring the original if/elif/else chain.
    transform = stats_by_task.get(opt.task, stats_by_task['family'])

    # Assemble the loader configuration and build the dataset.
    loader_cfg = dict(
        data_path = opt.data_path,
        train_split_path = opt.train_split_path,
        val_split_path = opt.val_split_path,
        test_split_path = opt.test_split_path,
        exclude_path = opt.exclude_path,
        n_workers = opt.n_workers,
        task = opt.task,
        batch_size = opt.chunk_size,
        train_transform = transform,
        val_transform = transform,
        train_drop_last = True,
        shuffle_train = True,
        collate_fn = None,
        pin_mem = False,
    )
    return make_malnet_1d(loader_cfg)
def load_model(opt):
    """Build the CNN for `opt.task`, load checkpoint weights via the
    factory, and return it in eval mode on `opt.device`.
    """
    # Output classes per classification granularity; anything other than
    # binary/type is treated as family-level classification.
    n_classes_by_task = {'binary': 2, 'type': 47}
    n_classes = n_classes_by_task.get(opt.task, 696)

    model_cfg = dict(
        name = opt.model,
        n_classes = n_classes,
        in_channels = 1,
        act_layer = opt.activation,
        load_weights = opt.checkpoint_path,
    )
    model = model_factory(config = model_cfg, verbose = True)

    model = model.to(opt.device)
    model.eval()
    return model
@T.no_grad()
def get_preds(
    model,
    dl,
    opt,
):
    """Run `model` over every batch in `dl` and collect labels/predictions.

    Returns:
        (y_true, y_pred): 1-D numpy arrays of ground-truth labels and
        argmax predictions, moved to the CPU.
    """
    true_chunks = []
    pred_chunks = []
    n_batches = len(dl)
    for idx, (x, y) in enumerate(dl):
        # Lightweight self-overwriting progress indicator.
        print(f'batch {idx} / {n_batches}', end = '\r')
        x = x.to(opt.device)
        true_chunks.append(y.to(opt.device))
        # Class prediction = argmax over the final (logits) dimension.
        pred_chunks.append(T.argmax(model(x), dim = -1))

    y_true = T.cat(true_chunks, dim = 0).detach().cpu().numpy()
    y_pred = T.cat(pred_chunks, dim = 0).detach().cpu().numpy()
    return y_true, y_pred
def get_metrics(y_true: ndarray, y_pred: ndarray) -> dict:
    """Compute standard classification metrics from label arrays.

    Args:
        y_true: ground-truth class labels.
        y_pred: predicted class labels.

    Returns:
        dict mapping metric name to score.
    """
    # Macro averaging weighs every class equally regardless of support;
    # zero_division=0.0 silences warnings for classes never predicted.
    macro_kwargs = dict(average="macro", zero_division=0.0)
    return {
        'Accuracy': accuracy_score(y_true, y_pred),
        'Macro_Precision': precision_score(y_true, y_pred, **macro_kwargs),
        'Macro_Recall': recall_score(y_true, y_pred, **macro_kwargs),
        'Macro_F1': f1_score(y_true, y_pred, **macro_kwargs),
        'Micro_F1': f1_score(y_true, y_pred, average="micro", zero_division=0.0),
        'Weighted_F1': f1_score(y_true, y_pred, average="weighted", zero_division=0.0),
    }
def main():
    """Entry point: evaluate the checkpointed model on the test split and
    print the resulting metrics."""
    opt = parse_option()
    # Only the test loader is needed for evaluation.
    _, _, test_dl = load_data(opt)
    model = load_model(opt)
    y_true, y_pred = get_preds(
        model = model,
        dl = test_dl,
        opt = opt,
    )
    pprint(get_metrics(y_true, y_pred))


if __name__ == '__main__':
    main()