-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcheckpoint.py
More file actions
65 lines (56 loc) · 2.51 KB
/
checkpoint.py
File metadata and controls
65 lines (56 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import os
import utils
# Note on the two ways to persist a model:
'''
# Save and load the entire model (structure + weights):
torch.save(model, "model.pkl")
model = torch.load("model.pkl")
# Save and load only the model parameters (recommended):
torch.save(model.state_dict(), "params.pkl")
model.load_state_dict(torch.load("params.pkl"))
'''
class CheckPoint(object):
    """Save and restore training checkpoints.

    Three workflows are supported:
      * ``retrainmodel`` -- load only model weights from ``opt.retrain``.
      * ``resumemodel``  -- load a full checkpoint (model, optimizer state,
        epoch) from ``opt.resume``.
      * ``savemodel``    -- write the current training state into a
        ``model/`` subfolder of ``opt.save_path``.
    """

    def __init__(self, opt):
        # opt is expected to provide: resume, resumeEpoch, retrain,
        # save_path (attribute-style access) -- TODO confirm against caller.
        self.resume = opt.resume
        self.resumeEpoch = opt.resumeEpoch
        self.retrain = opt.retrain
        # Checkpoints go into a "model/" subdirectory of save_path.
        self.save_path = opt.save_path + "model/"
        self.check_point_params = {'model': None,
                                   'opts': None,
                                   'resume_epoch': None}

    def retrainmodel(self):
        """Load model weights for retraining.

        Returns:
            dict: ``check_point_params`` with only the ``'model'`` entry set.

        Raises:
            FileNotFoundError: if ``self.retrain`` is not an existing file.
                (Was ``assert False, "file not exits"`` -- asserts are
                stripped under ``python -O`` and the message had a typo.)
        """
        if not os.path.isfile(self.retrain):
            raise FileNotFoundError(
                "retrain file not found: %s" % self.retrain)
        print("|===>Retrain model from:", self.retrain)
        retrain_data = torch.load(self.retrain)
        # Keep only the weights; optimizer state and epoch are discarded
        # on purpose -- retraining starts fresh.
        self.check_point_params['model'] = retrain_data['model']
        return self.check_point_params

    def resumemodel(self):
        """Resume a full training checkpoint.

        Returns:
            dict: the loaded checkpoint; ``'resume_epoch'`` is overridden
            by ``self.resumeEpoch`` when that value is non-zero.

        Raises:
            FileNotFoundError: if ``self.resume`` is not an existing file.
        """
        if not os.path.isfile(self.resume):
            raise FileNotFoundError(
                "resume file not found: %s" % self.resume)
        print("|===>Resume check point from:", self.resume)
        self.check_point_params = torch.load(self.resume)
        if self.resumeEpoch != 0:
            # Explicit command-line epoch wins over the stored one.
            self.check_point_params['resume_epoch'] = self.resumeEpoch
        return self.check_point_params

    def savemodel(self, epoch=None, model=None, opts=None, best_flag=False):
        """Write the current training state to disk.

        Args:
            epoch: epoch number stored under ``'resume_epoch'``.
            model: model object to store (the whole model, not just
                ``state_dict()`` -- see note below).
            opts: optimizer object to store.
            best_flag: when True, additionally save ``best_model.pkl``.

        Note: if we add a hook to the grad via ``register_hook(hook)``,
        the hook function cannot be pickled, so only ``state_dict()``
        could be saved in that case. Saving the whole model is kept here
        because it also preserves the network structure, so the network
        need not be re-created on load.
        """
        # model = utils.list2sequential(model).state_dict()
        # opts = opts.state_dict()
        if not os.path.isdir(self.save_path):
            # makedirs (not mkdir): save_path is nested and its parent
            # may not exist yet.
            os.makedirs(self.save_path)
        # self.check_point_params['model'] = utils.list2sequential(model).state_dict()
        self.check_point_params['model'] = model
        self.check_point_params['opts'] = opts
        self.check_point_params['resume_epoch'] = epoch
        torch.save(self.check_point_params, self.save_path + "checkpoint.pkl")
        if best_flag:
            # best_model = {'model': utils.list2sequential(model).state_dict()}
            best_model = {'model': model}
            torch.save(best_model, self.save_path + "best_model.pkl")