-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
executable file
·84 lines (62 loc) · 3.71 KB
/
test.py
File metadata and controls
executable file
·84 lines (62 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import sys,os
from tqdm import tqdm
import torch
import torch.nn as nn
from metavision_ml.detection.anchors import Anchors
from Model.feature_extractor import ConditionalInput_EMinGRU_ReLUFuseDownsampleConv_ConditionalConv
from Model.ssd_head import BoxHead
from Model.detection import inference_step
from Model.detection import evaluate
from utils.dataloader import seq_dataloader
import utils.data_augmentation as data_aug
##########################################################################################
# Evaluation script: build the detector, load the checkpoint of `test_epoch`, run
# inference over the test split, and report the detections via `evaluate` plus
# per-batch activity statistics averaged over all test batches.
#
# Configuration: test-set location, checkpoint to evaluate and model width.
dataset_path = '/media/shenqi/data/Gen4_multi_timesurface_FromDat'
dataset_type = 'gen4'
dataloader = seq_dataloader(dataset_path = dataset_path, dataset_type = dataset_type, num_tbins = 12, batch_size = 5, channels = 6)
saved_model_path = './Saved_Model/gen4/CSSL_MGU/'
test_epoch = 49
cout = 256
threshold_group = 1  # NOTE(review): not referenced anywhere in this script
# Single device handle reused for checkpoint loading, the models and the inputs.
device = torch.device('cuda')
# Number of conv1 activity statistics tracked; must not exceed the length of the
# `mean_activity_conv1` output of inference_step (TODO confirm against Model.detection).
num_conv1_stats = 7
##########################################################################################


def _accumulate(totals, values):
    """Add ``values[i].item()`` onto ``totals[i]`` for every tracked index, in place.

    Only the first ``len(totals)`` entries of *values* are read; *values* must
    have at least that many elements (otherwise IndexError is raised), and each
    element must provide ``.item()`` (e.g. a 0-d tensor).
    """
    for i in range(len(totals)):
        totals[i] += values[i].item()


def _average(totals, count):
    """Return a new list with every entry of *totals* divided by *count*."""
    return [total / count for total in totals]


net = ConditionalInput_EMinGRU_ReLUFuseDownsampleConv_ConditionalConv(dataloader.channels, base=int(cout/16), cout=cout, dataset = dataset_type, pruning = False)
box_coder = Anchors(num_levels=net.levels, anchor_list='PSEE_ANCHORS', variances=[0.1, 0.2])
# One output class per wanted label key plus one extra (presumably background -- confirm in BoxHead).
ssd_head = BoxHead(net.cout, box_coder.num_anchors, len(dataloader.wanted_keys)+1, n_layers=0)
# os.path.join keeps this working whether or not saved_model_path ends with a separator.
net.load_state_dict(torch.load(os.path.join(saved_model_path, f'{test_epoch}_model.pth'), map_location=device))
ssd_head.load_state_dict(torch.load(os.path.join(saved_model_path, f'{test_epoch}_pd.pth'), map_location=device))
net.eval().to(device)
ssd_head.eval().to(device)
# augment = data_aug.data_augmentation(dataset_type= dataset_type)
output_val_list = []
cnt_val = 0  # number of test batches processed
# Running sums of per-batch statistics; divided by cnt_val after the loop.
mean_activity_egru_ave = [0] * net.levels
mean_activity_conv1_ave = [0] * num_conv1_stats
mean_activity_egruRelu_ave = [0] * net.levels
box_hid_mean_ave = [0] * net.levels
cls_hid_mean_ave = [0] * net.levels
# Gradients are never needed during evaluation, so autograd is disabled once for the whole pass.
with torch.no_grad(), tqdm(total=len(dataloader.seq_dataloader_test), desc='Validation', ncols=120) as pbar:
    for data in dataloader.seq_dataloader_test:
        sys.stdout.flush()
        cnt_val += 1
        data['inputs'] = data['inputs'].to(device=device)
        # Disabled: augmentation of batches that contain labelled frames (needs `augment` above).
        # if data['frame_is_labeled'].sum().item() != 0:
        #     data = augment(data, only_vertical_move=True)
        (output_val, mean_activity, mean_activity_conv1, output_gates_val,
         mean_activity_egru_relu, box_hid_mean, cls_hid_mean) = inference_step(data, net, ssd_head, box_coder)
        output_val_list.append(output_val)
        _accumulate(mean_activity_egru_ave, mean_activity)
        _accumulate(mean_activity_conv1_ave, mean_activity_conv1)
        _accumulate(mean_activity_egruRelu_ave, mean_activity_egru_relu)
        _accumulate(box_hid_mean_ave, box_hid_mean)
        _accumulate(cls_hid_mean_ave, cls_hid_mean)
        pbar.update(1)
if cnt_val == 0:
    # Without this guard the averaging below fails with an uninformative ZeroDivisionError.
    raise RuntimeError('seq_dataloader_test yielded no batches; nothing to evaluate')
# Convert running sums into means over the number of test batches.
mean_activity_egru_ave = _average(mean_activity_egru_ave, cnt_val)
mean_activity_conv1_ave = _average(mean_activity_conv1_ave, cnt_val)
mean_activity_egruRelu_ave = _average(mean_activity_egruRelu_ave, cnt_val)
box_hid_mean_ave = _average(box_hid_mean_ave, cnt_val)
cls_hid_mean_ave = _average(cls_hid_mean_ave, cnt_val)
evaluate(output_val_list, dataloader)
print('mean_activity_egru_ave:', mean_activity_egru_ave, '\n mean_activity_conv1', mean_activity_conv1_ave, '\n mean_activity_egruRelu', mean_activity_egruRelu_ave, '\n box_hid_mean_ave', box_hid_mean_ave, '\n cls_hid_mean_ave', cls_hid_mean_ave)