-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathTrain_Step2.py
More file actions
115 lines (83 loc) · 3.91 KB
/
Train_Step2.py
File metadata and controls
115 lines (83 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import numpy as np
import resource
from configuration import Configuration
import yaml
from utils import *
from SuperPoint import SuperPoint
from Database import Database
from torch.utils.data import Dataset, DataLoader
from FanClass import FAN_Model
from eval import test_stage2 as evalModel
def train(config):
    """Train the second stage of the FAN landmark-discovery pipeline.

    Restores the trained first-stage FAN model, runs it over the dataset to
    produce pseudo-labels (keypoints) and the horizontal-flip correspondence,
    then trains a fresh second-stage FAN model on those pseudo-labels.

    Parameters
    ----------
    config : object
        Configuration namespace holding all hyperparameters
        (experiment_name, dataset_name, K, batchSize, lr, weight_decay,
        num_workers, total_iterations_stage2, ...).

    Side effects: creates /Logs and /CheckPoints directories under the
    configured log path, writes log files, and saves model checkpoints.
    """
    stage = 2
    with open('paths.yml') as file:
        paths = yaml.load(file, Loader=yaml.FullLoader)
    check_paths(paths)
    log_path = paths['log_path']
    metadata = paths['metadata']

    # This function will create the directories /Logs and /CheckPoints at log_path
    initialize_log_dirs(config.experiment_name, log_path)

    log_text(f"Experiment Name {config.experiment_name}\n"
             f"Database {config.dataset_name}\n"
             "Training Parameters: \n"
             f"Number of discovered landmarks K: {config.K} \n"
             f"Batch size {config.batchSize} \n"
             f"Learning rate {config.lr} \n"
             f"Weight Decay {config.weight_decay} \n"
             , config.experiment_name, log_path)
    # BUG FIX: this script trains the *second* stage (stage=2, Train_stage2
    # below); the previous message wrongly announced the first stage.
    log_text("Training of Second Stage begins", config.experiment_name, log_path)

    criterion = JointsMSELoss(True).cuda()

    # The first-stage model is restored only to generate pseudo-labels
    # (keypoints) and the flipping correspondence for flip augmentation.
    FAN = FAN_Model(criterion,
                    config.experiment_name,
                    config.confidence_thres_FAN,
                    log_path,
                    1)
    FAN.init_firststage(config.lr,
                        config.weight_decay,
                        config.M,
                        config.bootstrapping_iterations,
                        config.iterations_per_round,
                        config.K,
                        config.nms_thres_FAN,
                        config.lr_step_schedual_stage1)

    # Inference-only pass over the full dataset (no shuffling, keep all samples).
    cluster_dataset = Database(config.dataset_name,
                               metadata,
                               function_for_dataloading=Database.get_FAN_inference)
    cluster_dataloader = DataLoader(cluster_dataset, batch_size=config.batchSize, shuffle=False, num_workers=config.num_workers, drop_last=False)

    path_to_checkpoint = GetPathsResumeFirstStage(config.experiment_name, log_path)
    FAN.load_trained_fiststage_model(path_to_checkpoint)
    _, keypoints, flipppingCorrespondance = FAN.Update_pseudoLabels(cluster_dataloader)

    # Fresh model instance for the second stage, warm-started from the
    # first-stage checkpoint inside init_secondstage.
    FAN = FAN_Model(criterion,
                    config.experiment_name,
                    config.confidence_thres_FAN,
                    log_path,
                    stage)
    FAN.init_secondstage(config.lr,
                         config.weight_decay,
                         config.K,
                         config.lr_step_schedual_stage2,
                         config.save_checkpoint_frequency,
                         path_to_checkpoint,
                         flipppingCorrespondance)

    # Training dataset carries the pseudo-labels produced above; one heatmap
    # channel per discovered landmark (K channels).
    train_dataset = Database(config.dataset_name,
                             metadata,
                             image_keypoints=keypoints,
                             function_for_dataloading=Database.get_FAN_secondStage_train,
                             useflip=config.useflip,
                             flipppingCorrespondance=flipppingCorrespondance,
                             number_of_channels=config.K)
    train_dataloader = DataLoader(train_dataset, batch_size=config.batchSize, shuffle=True, num_workers=config.num_workers, drop_last=True)
    log_text(f'Dataset Number of Images:{len(train_dataset.files)}', config.experiment_name, log_path)

    # Each loop iteration is one training round; FAN.iterations is advanced
    # here so progress survives across rounds.
    while FAN.iterations < config.total_iterations_stage2:
        log_text(f'Training for iteration {FAN.iterations} begins', config.experiment_name, log_path)
        FAN.Train_stage2(train_dataloader)
        FAN.iterations += 1
if __name__=="__main__":
torch.manual_seed(1993)
torch.cuda.manual_seed_all(1993)
np.random.seed(1993)
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
config=Configuration().params
train(config)