-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
91 lines (79 loc) · 3.66 KB
/
test.py
File metadata and controls
91 lines (79 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import torch
import yaml
from DataSet.TigerDataSet import TigerTripleImage
from nets.TripletNet import TripletNet
from nets.backbone.EfficientNetV2 import efficientnet_v2_s
from nets.loss_function import TripletMarginAndCosineLoss
from nets.similarity_module.dist_block import CosineDistance
from utils.logger import setting_logging
from model.tools import test_model
# Config file evaluated when this script is run directly.
DEFAULT_CONFIG_PATH = os.path.join('model', 'EfficientNetV2S_NoCosineLoss', 'left', 'config.yaml')


def _strip_dataparallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

    Models wrapped in ``torch.nn.DataParallel`` save every key as ``module.<name>``.
    Only that leading prefix is removed, so submodule names that merely contain the
    substring ``module.`` somewhere inside them are left intact.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _build_model(model_config):
    """Instantiate the TripletNet described by the ``model`` section of the config.

    Args:
        model_config: the ``config['model']`` mapping; reads ``in_channels``,
            ``backbone``, ``dist_block``, ``include_adapter`` and ``se_dis``.

    Returns:
        An un-trained ``TripletNet`` instance.

    Raises:
        ValueError: if ``backbone`` or ``dist_block`` names an unsupported component.
    """
    backbone = model_config['backbone']
    if backbone == 'EfficientNetV2S':
        # NOTE(review): PyTorch-pretrained weights are requested here even though a saved
        # checkpoint (if present) overwrites every parameter afterwards; kept as-is to
        # preserve how the backbone is constructed.
        model_backbone, backbone_out_channels = efficientnet_v2_s(
            in_channels=model_config['in_channels'],
            is_load_pytorch_pretrained=True)
    else:
        raise ValueError(f'当前不支持模型 {backbone} !')

    dist_block = model_config['dist_block']
    if dist_block == 'CosineDistance':
        model_dist_block = CosineDistance()
    else:
        raise ValueError(f'当前不支持 {dist_block} !')

    return TripletNet(backbone_out_channels=backbone_out_channels,
                      backbone=model_backbone,
                      dist_block=model_dist_block,
                      include_adapter=model_config['include_adapter'],
                      se_dis=model_config['se_dis'])


def _load_checkpoint(model, checkpoint_path, logger):
    """Load saved weights from *checkpoint_path* into *model* if the file exists.

    The checkpoint is a state dict that may carry an extra ``max_val_acc`` entry
    (presumably written by the training script — TODO confirm); that entry is
    removed before the weights are loaded.

    Returns:
        The stored ``max_val_acc``, or 0.0 if the file or the entry is missing.
    """
    if not os.path.exists(checkpoint_path):
        # Evaluation still runs, but on weights that were never fine-tuned.
        logger.warning('Checkpoint %s not found; evaluating un-fine-tuned weights.', checkpoint_path)
        return 0.0
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
    # load_state_dict copies the tensors into the model's existing parameters anyway.
    # torch.load unpickles the file, so only point it at trusted checkpoints.
    saved_state_dict = torch.load(checkpoint_path, map_location='cpu')
    max_val_acc = saved_state_dict.pop('max_val_acc', 0.0)
    model.load_state_dict(_strip_dataparallel_prefix(saved_state_dict))
    return max_val_acc


def main(config_yaml_file_path=DEFAULT_CONFIG_PATH):
    """Evaluate the configured model on the test split and store the metrics.

    Reads the YAML config, builds the model, loads the saved checkpoint (if any),
    runs ``test_model`` on the ``mode='test'`` split of ``TigerTripleImage`` and
    writes ``test_acc``/``test_pre``/``test_recall``/``test_f1`` back into the
    ``model`` section of the same config file.

    Args:
        config_yaml_file_path: path of the YAML config to read and update.
    """
    with open(config_yaml_file_path, 'r', encoding='utf-8') as stream:
        config = yaml.safe_load(stream)
    train_config = config['train']
    model_config = config['model']

    logger = setting_logging(train_config['logger_name'])

    model = _build_model(model_config)
    checkpoint_path = os.path.join(model_config['model_save_path'], model_config['model_save_name'])
    max_val_acc = _load_checkpoint(model, checkpoint_path, logger)
    logger.info('Checkpoint max_val_acc: %s', max_val_acc)

    # input_shape is used for both resize dimensions (square input).
    input_shape = model_config['input_shape']
    test_database = TigerTripleImage(file_path=train_config['dataset_path'],
                                     resize_x=input_shape,
                                     resize_y=input_shape,
                                     mode='test',
                                     cosine_label_mode=train_config['cosine_label_mode'])
    acc, pre, recall, f1 = test_model(model=model,
                                      threshold_margin=train_config['criterion']['cosine_margin'],
                                      test_database=test_database,
                                      batch_size=train_config['batch_size'],
                                      logger=logger)

    # float() turns numpy/torch scalars (test_model's return types are not visible
    # here) into plain floats; otherwise yaml.dump would emit python-object tags
    # that yaml.safe_load refuses to read on the next run.
    model_config['test_acc'] = float(acc)
    model_config['test_pre'] = float(pre)
    model_config['test_recall'] = float(recall)
    model_config['test_f1'] = float(f1)

    # Save the metrics back into the config. NOTE: rewriting the file drops any
    # YAML comments it contained; key order and non-ASCII text are preserved.
    with open(config_yaml_file_path, 'w', encoding='utf-8') as stream:
        yaml.dump(config, stream, allow_unicode=True, sort_keys=False)


if __name__ == '__main__':
    main()