-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
156 lines (132 loc) · 6.04 KB
/
predict.py
File metadata and controls
156 lines (132 loc) · 6.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-
# @Author : Hao Fan
# @Time : 2024/11/4
import os
import cv2
import torch
import yaml
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import transforms
from nets.TripletNet import TripletNet
from nets.backbone.EfficientNetV2 import efficientnet_v2_s
from nets.similarity_module.dist_block import CosineDistance, EuclideanDistance
from utils.logger import setting_logging
torch.manual_seed(2024)  # fixed seed for reproducible inference
row = 100  # NOTE(review): appears unused in this script — confirm before removing

if __name__ == '__main__':
    # Path to the YAML configuration produced by training.
    config_yaml_file_path = os.path.join('model', 'EfficientNetV2S_TripletAndCosineLoss', 'left', 'config.yaml')
    # Load configuration from the config file.
    with open(config_yaml_file_path, 'r') as stream:
        config = yaml.safe_load(stream)
    logger_name = config['train']['logger_name']
    # Set up the logger.
    logger = setting_logging(logger_name)
    # Location of the saved model weights.
    model_saved_path = config['model']['model_save_path']
    model_saved_name = config['model']['model_save_name']
    input_shape = config['model']['input_shape']
    in_channels = config['model']['in_channels']
    backbone = config['model']['backbone']
    # Instantiate the model backbone.
    if backbone == 'EfficientNetV2S':
        model_backbone, backbone_out_channels = efficientnet_v2_s(in_channels=in_channels,
                                                                  is_load_pytorch_pretrained=True)
    else:
        raise ValueError(f'当前不支持模型 {backbone} !')
    # Criterion-specific hyper-parameters; threshold_margin is the decision
    # margin used at inference, p is the norm degree for Euclidean distance.
    criterion_name = config['train']['criterion']['name']
    if criterion_name == 'TripletMarginAndCosineLoss':
        p = config['train']['criterion']['p']
        cosine_margin = config['train']['criterion']['cosine_margin']
        threshold_margin = cosine_margin
    elif criterion_name == 'EuclideanDistanceLoss':
        p = config['train']['criterion']['p']
        threshold_margin = config['train']['criterion']['threshold_margin']
    elif criterion_name == 'TripletMarginLoss':
        p = config['train']['criterion']['p']
        triplet_margin = config['train']['criterion']['triplet_margin']
        threshold_margin = triplet_margin
    else:
        raise ValueError(f'当前不支持 {criterion_name} !')
    # Distance head used to compare the two embeddings.
    dist_block = config['model']['dist_block']
    if dist_block == 'CosineDistance':
        model_dist_block = CosineDistance()
    elif dist_block == 'EuclideanDistance':
        model_dist_block = EuclideanDistance(p=p)
    else:
        raise ValueError(f'当前不支持 {dist_block} !')
    include_adapter = config['model']['include_adapter']
    se_dis = config['model']['se_dis']
    model = TripletNet(backbone_out_channels=backbone_out_channels,
                       backbone=model_backbone,
                       dist_block=model_dist_block,
                       include_adapter=include_adapter,
                       se_dis=se_dis)
    print(model)
    # Pick the device BEFORE loading weights so a checkpoint saved on GPU can
    # still be restored on a CPU-only machine (via map_location below).
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Load the trained weights.
    model_file = os.path.join(model_saved_path, model_saved_name)
    if os.path.exists(model_file):
        saved_state_dict = torch.load(model_file, map_location=device)
        # The checkpoint stores the best validation accuracy alongside the
        # weights; it must be removed before load_state_dict (not a tensor).
        max_val_acc = saved_state_dict.pop('max_val_acc')
        # Strip the DataParallel "module." prefix so keys match the bare model.
        new_model_state_dict = {key.replace("module.", ""): value
                                for key, value in saved_state_dict.items()}
        model.load_state_dict(new_model_state_dict)
    else:
        raise OSError(f'{os.path.join(model_saved_path, model_saved_name)} is not exist!')
    # Move the model to the selected device and switch to evaluation mode.
    model = model.to(device)
    model.eval()
    print("欢迎来到东北虎个体相似度预测软件,请分别输入两张图片路径")
    result_image_name = 'result_1'
    # First image (interactive input kept as a commented alternative).
    # image1_path = input('请输入第一张图片路径(注意,需要是彩色图片):')
    image1_path = r'data/zebra_stripe_data_100/left/IBEIS_PZ_0500/000000002274_crop_3221.png'
    # cv2.imread does NOT raise on a bad path — it returns None — so validate
    # explicitly instead of wrapping it in try/except.
    image1 = cv2.imread(image1_path, cv2.IMREAD_GRAYSCALE)
    if image1 is None:
        raise ValueError("image1图片路径输入有误!")
    type_1 = image1_path.split('/')[-2]                 # individual id (parent dir)
    code_1 = image1_path.split('/')[-1].split('.')[0]   # image code (file stem)
    # Second image.
    # image2_path = input('请输入第二张图片路径(注意,需要是彩色图片):')
    image2_path = r'data/zebra_stripe_data_100/left/IBEIS_PZ_0513/000000004086_crop_5747.png'
    image2 = cv2.imread(image2_path, cv2.IMREAD_GRAYSCALE)
    if image2 is None:
        raise ValueError("image2图片路径输入有误!")
    type_2 = image2_path.split('/')[-2]
    code_2 = image2_path.split('/')[-1].split('.')[0]
    print("图像正在处理中,请稍后...")
    # Preprocessing: resize to the model's square input resolution, then to tensor.
    image_tf = transforms.Compose([
        transforms.Resize([input_shape, input_shape]),
        transforms.ToTensor()
    ])
    print("图像正在识别中,请稍后...")
    # Show both inputs stacked vertically, labelled with identity and code.
    fig, axs = plt.subplots(nrows=2, ncols=1)
    axs[0].imshow(image1, cmap='gray')
    axs[0].set_title(f'image1_stripe-type:{type_1}/code:{code_1}')
    axs[1].imshow(image2, cmap='gray')
    axs[1].set_title(f'image2_stripe-type:{type_2}/code:{code_2}')
    # Inference: no gradient tracking needed.
    with torch.no_grad():
        image1_in = image_tf(Image.open(image1_path)).to(device).unsqueeze(0)
        image2_in = image_tf(Image.open(image2_path)).to(device).unsqueeze(0)
        output = model.predict(image1_in, image2_in)
    fig.suptitle('Similarity:%.3f' % output[0].item())
    # Tidy sub-plot spacing before saving.
    plt.tight_layout()
    if not os.path.isdir('./results'):
        os.mkdir('./results')
    plt.savefig(f'results/{result_image_name}.pdf', format='pdf')
    plt.show()
    print("识别完成,请查看图像结果输出!")