-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidate.py
More file actions
executable file
·83 lines (63 loc) · 3.89 KB
/
validate.py
File metadata and controls
executable file
·83 lines (63 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import sys,os
from tqdm import tqdm
import torch
import torch.nn as nn
from metavision_ml.detection.anchors import Anchors
from Model.feature_extractor import ConditionalInput_EMinGRU_ReLUFuseDownsampleConv_ConditionalConv
from Model.ssd_head import BoxHead
from Model.detection import inference_step
from Model.detection import evaluate
from utils.dataloader import seq_dataloader
##########################################################################################
# Validation settings: dataset location/type, the directory holding the saved
# checkpoints, and the inclusive range of epoch checkpoints to evaluate.
dataset_path = '/media/shenqi/data/Gen4_multi_timesurface_FromDat'
dataset_type = 'gen4'
dataloader = seq_dataloader(
    dataset_path=dataset_path,
    dataset_type=dataset_type,
    num_tbins=8,
    batch_size=4,
    channels=6,
)
saved_model_path = './Saved_Model/gen4/CSSL_MGU/'
cout = 256
threshold_group = 1  # not referenced anywhere in the visible part of this script
validate_epoch_start = 0
validate_epoch_end = 35
##########################################################################################

# Everything below runs on the GPU; checkpoints are mapped onto it when loaded.
device = torch.device('cuda')

# Build the feature extractor, the anchor box coder and the detection head once;
# only their weights are swapped for each evaluated epoch.
net = ConditionalInput_EMinGRU_ReLUFuseDownsampleConv_ConditionalConv(
    dataloader.channels,
    base=16,
    cout=cout,
    dataset=dataset_type,
    pruning=False,
)
box_coder = Anchors(num_levels=net.levels, anchor_list='PSEE_ANCHORS', variances=[0.1, 0.2])
# The head gets one more class than there are wanted keys
# (presumably a background class -- TODO confirm against BoxHead).
ssd_head = BoxHead(net.cout, box_coder.num_anchors, len(dataloader.wanted_keys) + 1, n_layers=0)
net.eval().to(device)
ssd_head.eval().to(device)

for epoch in range(validate_epoch_start, validate_epoch_end + 1):
    # Per-epoch state: collected detections plus running sums of the activity
    # statistics returned by inference_step.
    output_val_list = []
    cnt_val = 0
    mean_activity_egru_ave = [0] * net.levels
    mean_activity_conv1_ave = [0] * 7
    mean_activity_egruRelu_ave = [0] * net.levels
    box_hid_mean_ave = [0] * net.levels
    cls_hid_mean_ave = [0] * net.levels

    # Checkpoints are stored as '<epoch>_model.pth' (backbone) and
    # '<epoch>_pd.pth' (detection head) inside saved_model_path.
    checkpoint_prefix = saved_model_path + str(epoch)
    net.load_state_dict(torch.load(checkpoint_prefix + '_model.pth', map_location=device))
    ssd_head.load_state_dict(torch.load(checkpoint_prefix + '_pd.pth', map_location=device))

    # Peek at one validation batch only to obtain the shape of the memory mask,
    # then reset the network with an all-zero mask of that shape
    # (presumably this clears the recurrent state -- TODO confirm in net.reset).
    first_batch = next(iter(dataloader.seq_dataloader_val))
    net.reset(torch.zeros_like(first_batch['mask_keep_memory']).to(device))

    with tqdm(total=len(dataloader.seq_dataloader_val), desc='Validation', ncols=120) as pbar:
        for data in dataloader.seq_dataloader_val:
            sys.stdout.flush()
            with torch.no_grad():
                cnt_val += 1
                data['inputs'] = data['inputs'].to(device)
                (
                    output_val,
                    mean_activity,
                    mean_activity_conv1,
                    output_gates_val,
                    mean_activity_egru_relu,
                    box_hid_mean,
                    cls_hid_mean,
                ) = inference_step(data, net, ssd_head, box_coder)
                output_val_list.append(output_val)

                # Accumulate the per-batch statistics as Python floats.
                for level in range(net.levels):
                    mean_activity_egru_ave[level] += mean_activity[level].item()
                for stage in range(7):
                    mean_activity_conv1_ave[stage] += mean_activity_conv1[stage].item()
                for level in range(net.levels):
                    mean_activity_egruRelu_ave[level] += mean_activity_egru_relu[level].item()
                    box_hid_mean_ave[level] += box_hid_mean[level].item()
                    cls_hid_mean_ave[level] += cls_hid_mean[level].item()
            pbar.update(1)

    # Turn the running sums into means over the number of validation batches.
    mean_activity_egru_ave = [total / cnt_val for total in mean_activity_egru_ave]
    mean_activity_conv1_ave = [total / cnt_val for total in mean_activity_conv1_ave]
    mean_activity_egruRelu_ave = [total / cnt_val for total in mean_activity_egruRelu_ave]
    box_hid_mean_ave = [total / cnt_val for total in box_hid_mean_ave]
    cls_hid_mean_ave = [total / cnt_val for total in cls_hid_mean_ave]

    print('\n epoch is:', epoch)
    evaluate(output_val_list, dataloader)
    print(
        'mean_activity_egru_ave:', mean_activity_egru_ave,
        '\n mean_activity_conv1', mean_activity_conv1_ave,
        '\n mean_activity_egruRelu', mean_activity_egruRelu_ave,
        '\n box_hid_mean_ave', box_hid_mean_ave,
        '\n cls_hid_mean_ave', cls_hid_mean_ave,
    )