-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train.py
More file actions
76 lines (56 loc) · 2.13 KB
/
train.py
File metadata and controls
76 lines (56 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
from config_default import DefaultConfig
from dataloader import *
from models.Model import GazeNet
from utility_functions.load_model import *
from utility_functions.save_model import *
# Module-level training setup shared by train(): hyper-parameters come from
# DefaultConfig, and everything runs on the first CUDA device when one is
# available, otherwise on the CPU.
config = DefaultConfig()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
EPOCHS = 10

# Instantiate the network and move it to the selected device in one step
# (nn.Module.to returns the module itself).
model = GazeNet().to(device)

# The optimizer and loss are stored as attributes on the model object so that
# the training loop can reach them as model.optimizer / model.criterion.
model.optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)
model.criterion = nn.MSELoss(reduction='mean').to(device)
def scale_xy_to_ab(value_xy, x, y, a, b):
    """Linearly map ``value_xy`` from the interval [x, y] onto [a, b].

    ``x`` maps to ``a`` and ``y`` maps to ``b``; values outside [x, y] are
    extrapolated along the same line. Works with anything supporting the
    arithmetic operators (Python numbers, tensors, arrays).

    Note: ``y`` must differ from ``x``; otherwise the division by ``y - x``
    fails for Python scalars (or yields inf/nan for tensors).
    """
    # Relative position of value_xy inside the source interval (0 at x, 1 at y).
    fraction = (value_xy - x) / (y - x)
    return fraction * (b - a) + a
def train(*, data_dir='./processedFrames/train01/step007_image_MIT-i2277207572',
          epochs=None, batch_size=1, num_workers=1):
    """Train the module-level ``model`` on processed frames from ``data_dir``.

    Each batch from ``EVEProcessedDataset`` is unpacked as
    ``(face, left_eye, right_eye, face_grid, screen, pog)``. Only the two eye
    patches, the screen frame and the ground truth ``pog`` are used. The model
    is called as ``model(input_dict, output_dict)``, and its prediction is then
    read from ``output_dict['PoG_px_final']`` and regressed onto ``pog`` with
    ``model.criterion`` (assumes the model's forward writes that key into
    ``output_dict`` — confirm against GazeNet). After every epoch the weights
    of ``model.eye_net`` and ``model.refine_net`` are saved with
    ``save_weights_for_instance``, and the summed epoch loss is printed.

    Args:
        data_dir: Directory passed to ``EVEProcessedDataset``.
        epochs: Number of passes over the data; ``None`` means ``EPOCHS``.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.
    """
    if epochs is None:
        epochs = EPOCHS
    dataset = EVEProcessedDataset(data_dir)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            num_workers=num_workers)
    # Explicitly enable training-mode behaviour (dropout / batch-norm, if any).
    model.train()
    for epoch in range(epochs):
        print('Epoch ', epoch, 'going on')
        # Accumulate a plain Python float: summing the loss *tensor* would keep
        # every iteration's autograd graph alive until the end of the epoch.
        total_loss = 0.0
        print('Total length:', len(dataloader))
        print('Currently completed:', end=' ')
        for i, (_face, left_eye, right_eye, _face_grid, screen, pog) in enumerate(dataloader):
            print(i, end=' ')
            # The model lives on `device`; batch tensors arrive on the CPU, so
            # move them over before the forward pass and the loss computation.
            pog = pog.to(device)
            input_dict = {
                'left_eye_patch': left_eye.to(device),
                'right_eye_patch': right_eye.to(device),
                'screen_frame': screen.to(device),
            }
            output_dict = {'pog': pog}

            model.optimizer.zero_grad()
            model(input_dict, output_dict)
            loss = model.criterion(output_dict['PoG_px_final'].float(), pog.float())
            loss.backward()
            model.optimizer.step()
            total_loss += loss.item()
        # Terminate the "Currently completed:" progress line.
        print()
        # Save the sub-networks' weights after every epoch.
        save_weights_for_instance(model.eye_net)
        save_weights_for_instance(model.refine_net)
        print('Total Loss at end of ', epoch, 'is ', total_loss)
if __name__ == '__main__':
    # Start training only when this file is executed as a script; importing
    # the module does not call train() (the module-level setup still runs).
    train()