-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
164 lines (150 loc) · 7.19 KB
/
train.py
File metadata and controls
164 lines (150 loc) · 7.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
import torch
import yaml
from DataSet.TigerDataSet import TigerTripleImage
from nets.TripletNet import TripletNet
from nets.backbone.EfficientNetV2 import efficientnet_v2_s
from nets.loss_function import TripletMarginAndCosineLoss, EuclideanDistanceLoss, TripletMarginLoss
from nets.similarity_module.dist_block import CosineDistance, EuclideanDistance
from utils.logger import setting_logging
from model.tools import train_model, test_model
# Expose all eight GPUs to this process (torch will see them as devices 0-7).
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3, 4, 5, 6, 7"
# Fix torch's global RNG seed so weight init / shuffling are reproducible.
torch.manual_seed(2024)
if __name__ == '__main__':
# 配置文件路径
config_yaml_file_path = os.path.join('model', 'EfficientNetV2S_TripletAndCosineLoss', 'left', 'config.yaml')
# 从配置文件加载配置
with open(config_yaml_file_path, 'r') as stream:
config = yaml.safe_load(stream)
logger_name = config['train']['logger_name']
# 设置日志器
logger = setting_logging(logger_name)
# 总的迭代次数
Epochs = config['train']['Epochs']
# 批处理大小
batch_size = config['train']['batch_size']
stop_epoch = config['train']['stop_epoch']
# 模型参数保存路径
model_saved_path = config['model']['model_save_path']
model_saved_name = config['model']['model_save_name']
in_channels = config['model']['in_channels']
backbone = config['model']['backbone']
# 实例化模型backbone
if backbone == 'EfficientNetV2S':
model_backbone, backbone_out_channels = efficientnet_v2_s(in_channels=in_channels,
is_load_pytorch_pretrained=True)
else:
raise ValueError(f'当前不支持模型 {backbone} !')
criterion_name = config['train']['criterion']['name']
if criterion_name == 'TripletMarginAndCosineLoss':
p = config['train']['criterion']['p']
cosine_margin = config['train']['criterion']['cosine_margin']
triplet_margin = config['train']['criterion']['triplet_margin']
criterion = TripletMarginAndCosineLoss(p=p, triplet_margin=triplet_margin, cosine_margin=cosine_margin)
threshold_margin = cosine_margin
elif criterion_name == 'EuclideanDistanceLoss':
p = config['train']['criterion']['p']
criterion = EuclideanDistanceLoss(p=p)
threshold_margin = config['train']['criterion']['threshold_margin']
elif criterion_name == 'TripletMarginLoss':
p = config['train']['criterion']['p']
triplet_margin = config['train']['criterion']['triplet_margin']
criterion = TripletMarginLoss(p=p, triplet_margin=triplet_margin)
threshold_margin = triplet_margin
else:
raise ValueError(f'当前不支持 {criterion_name} !')
dist_block = config['model']['dist_block']
if dist_block == 'CosineDistance':
model_dist_block = CosineDistance()
elif dist_block == 'EuclideanDistance':
model_dist_block = EuclideanDistance(p=p)
else:
raise ValueError(f'当前不支持 {dist_block} !')
include_adapter = config['model']['include_adapter']
se_dis = config['model']['se_dis']
input_shape = config['model']['input_shape']
model = TripletNet(backbone_out_channels=backbone_out_channels,
backbone=model_backbone,
dist_block=model_dist_block,
include_adapter=include_adapter,
se_dis=se_dis)
# 加载模型参数
if os.path.exists(os.path.join(model_saved_path, model_saved_name)):
saved_state_dict = torch.load(os.path.join(model_saved_path, model_saved_name))
max_val_acc = saved_state_dict['max_val_acc']
saved_state_dict.pop('max_val_acc')
# 通过自定义加载过程匹配键
new_model_state_dict = {}
for key, value in saved_state_dict.items():
# 修改键的方式以匹配新模型的结构
new_key = key.replace("module.", "")
new_model_state_dict[new_key] = value
model.load_state_dict(new_model_state_dict)
else:
max_val_acc = 0.0
weight_decay = config['train']['weight_decay']
lr = config['train']['lr']
cosine_label_mode = config['train']['cosine_label_mode']
fp_16 = config['train']['fp_16']
dataset_path = config['train']['dataset_path']
train_database = TigerTripleImage(file_path=dataset_path,
resize_x=input_shape,
resize_y=input_shape,
mode='train',
cosine_label_mode=cosine_label_mode)
val_database = TigerTripleImage(file_path=dataset_path,
resize_x=input_shape,
resize_y=input_shape,
mode='val',
cosine_label_mode=cosine_label_mode)
test_database = TigerTripleImage(file_path=dataset_path,
resize_x=input_shape,
resize_y=input_shape,
mode='test',
cosine_label_mode=cosine_label_mode)
save_log_dir = config['train']['save_log_dir']
train_model(model=model,
model_save_path=model_saved_path,
model_save_name=model_saved_name,
criterion=criterion,
threshold_margin=threshold_margin,
train_database=train_database,
val_database=val_database,
batch_size=batch_size,
lr=lr,
weight_decay=weight_decay,
Epochs=Epochs,
max_val_acc=max_val_acc,
input_shape=[1, 3, in_channels, input_shape, input_shape],
save_log_dir=save_log_dir,
logger=logger,
fp_16=fp_16,
stop_epoch=stop_epoch,
cosine_label_mode=cosine_label_mode)
# 加载最优模型参数
if os.path.exists(os.path.join(model_saved_path, model_saved_name)):
saved_state_dict = torch.load(os.path.join(model_saved_path, model_saved_name))
max_val_acc = saved_state_dict['max_val_acc']
saved_state_dict.pop('max_val_acc')
# 通过自定义加载过程匹配键
new_model_state_dict = {}
for key, value in saved_state_dict.items():
# 修改键的方式以匹配新模型的结构
new_key = key.replace("module.", "")
new_model_state_dict[new_key] = value
model.load_state_dict(new_model_state_dict)
else:
max_val_acc = 0.0
acc, pre, recall, f1 = test_model(model=model,
threshold_margin=threshold_margin,
test_database=test_database,
batch_size=batch_size,
logger=logger,
cosine_label_mode=cosine_label_mode)
config['model']['test_acc'] = acc
config['model']['test_pre'] = pre
config['model']['test_recall'] = recall
config['model']['test_f1'] = f1
# 保存指标
with open(config_yaml_file_path, 'w') as stream:
yaml.dump(config, stream)