-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
65 lines (48 loc) · 1.85 KB
/
train.py
File metadata and controls
65 lines (48 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import config
from ext import now
from model import make_model, train_on
from model import load_model, save_model
from data import load_data, split_data, batchify_data
##
def main(disp_text=True):
    """Load or create a model, prepare the training data, and train.

    Args:
        disp_text: when True, print progress information to stdout.

    Returns:
        The trained model object.
    """
    # -- model setup: start fresh when configured, else resume if possible --
    if config.fresh_model:
        config.all_losses = []
        save_model(make_model())
        model = load_model()
        if disp_text: print('created model.')
    else:
        model = load_model()
        if not model:
            # no saved model on disk: create and persist one
            save_model(make_model())
            model = load_model()
            if disp_text: print('created model.')
        else:
            if disp_text: print('loaded model.')

    # -- data preparation --
    data = load_data()
    data, data_dev = split_data(data)
    # NOTE(review): hard-coded index subset — presumably a debugging
    # selection; confirm these indices are intentional before shipping.
    data = [d for i, d in enumerate(data) if i in [8, 10, 13, 14]]
    seq_lens = [len(d) for d in data]
    if disp_text:  # fix: these prints previously ignored disp_text
        print(f'seq lens: {seq_lens}')
    min_seq_len = min(seq_lens)
    if disp_text:  # fix: previously ignored disp_text
        print(f'min seq len: {min_seq_len}')
    # Cap max_seq_len at the shortest sequence so every sequence can be
    # truncated to one common length.
    if not config.max_seq_len or config.max_seq_len > min_seq_len:
        config.max_seq_len = min_seq_len
    data = [d[:config.max_seq_len] for d in data]

    # -- batch size resolution: absolute count, or fraction of the dataset --
    if not config.batch_size or config.batch_size >= len(data):
        config.batch_size = len(data)
    elif config.batch_size < 1:
        # A value in (0, 1) means "fraction of the dataset". Clamp to >= 1
        # so batchify_data never receives a zero batch size (fix).
        config.batch_size = max(1, int(len(data) * config.batch_size))

    if disp_text: print(f'hm data: {len(data)}, hm dev: {len(data_dev)}, bs: {config.batch_size}, lr: {config.learning_rate}, \ntraining started @ {now()}')

    # -- training loop: epoch/batch indices are not needed, iterate directly --
    for _ in range(config.hm_epochs):
        for batch in batchify_data(data):
            train_on(model, batch)
    return model
##
if __name__ == '__main__':
    # Train, then persist the final weights alongside the regular checkpoint.
    trained_model = main()
    save_model(trained_model, config.model_path + '_final')