main.py (forked from darinmoore/grape_expectations)
#! /usr/bin/env python
import os
import argparse
import datetime
import gensim
import re
import torch
import torchtext.data as data
import torchtext.datasets as datasets
import model
import train
import winedataset
parser = argparse.ArgumentParser(description='CNN text classifier')
# learning
parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate [default: 0.001]')
parser.add_argument('-epochs', type=int, default=256, help='number of epochs for train [default: 256]')
parser.add_argument('-batch-size', type=int, default=64, help='batch size for training [default: 64]')
parser.add_argument('-log-interval', type=int, default=1, help='how many steps to wait before logging training status [default: 1]')
parser.add_argument('-test-interval', type=int, default=100, help='how many steps to wait before testing [default: 100]')
parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default: 500]')
parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
parser.add_argument('-early-stop', type=int, default=1000, help='number of steps to wait without performance improvement before stopping early [default: 1000]')
parser.add_argument('-save-best', type=bool, default=True, help='whether to save a snapshot when the best performance is reached [default: True]')
# data
parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data every epoch')
# model
parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
parser.add_argument('-max-norm', type=float, default=3.0, help='l2 constraint of parameters [default: 3.0]')
parser.add_argument('-embed-dim', type=int, default=100, help='number of embedding dimensions [default: 100]')
parser.add_argument('-kernel-num', type=int, default=100, help='number of each kind of kernel')
parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel size to use for convolution')
parser.add_argument('-static', action='store_true', default=False, help='fix the embedding')
# device
parser.add_argument('-device', type=int, default=-1, help='device to use for iterating data, -1 means CPU [default: -1]')
parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the GPU')
# option
parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot [default: None]')
parser.add_argument('-predict', type=str, default=None, help='predict the sentence given')
parser.add_argument('-test', action='store_true', default=False, help='train or test')
parser.add_argument('-dataset', type=int, default=0, help='which dataset to use: 0 is wine_color labelling, 1 is wine_type [default: 0]')
args = parser.parse_args()
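# Example invocations (illustrative only; the snapshot path and review text below are
# made-up placeholders, not files shipped with this repository):
#   ./main.py -dataset 0 -epochs 50 -shuffle                  # train the red/white classifier
#   ./main.py -dataset 1 -lr 0.0005 -kernel-sizes 2,3,4       # train the wine-type classifier
#   ./main.py -snapshot snapshot/<timestamp>/<model>.pt -predict "aromas of cherry and oak"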
def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    # Replace every non-alphabetic character with a space. This also removes the
    # punctuation and contractions targeted by the substitutions below, so those
    # lines (kept from the original) are effectively no-ops here.
    string = re.sub("[^a-zA-Z]", " ", string)
    #string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", r" \( ", string)
    string = re.sub(r"\)", r" \) ", string)
    string = re.sub(r"\?", r" \? ", string)
    # Collapse runs of whitespace into a single space.
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip()
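# For example (illustrative): clean_str("Isn't this Riesling crisp?!") returns
# "Isn t this Riesling crisp" -- every non-letter becomes a space and whitespace is
# collapsed; lowercasing happens later via the torchtext Field(lower=True).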
# load SST dataset
def sst(text_field, label_field, **kargs):
    train_data, dev_data, test_data = datasets.SST.splits(text_field, label_field, fine_grained=True)
    text_field.build_vocab(train_data, dev_data, test_data)
    label_field.build_vocab(train_data, dev_data, test_data)
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train_data, dev_data, test_data),
        batch_sizes=(args.batch_size,
                     len(dev_data),
                     len(test_data)),
        **kargs)
    return train_iter, dev_iter, test_iter
def wine(text_field, label_field, **kargs):
    train_data, dev_data = winedataset.WINE_TYPE.splits(text_field, label_field)
    text_field.build_vocab(train_data, dev_data)
    label_field.build_vocab(train_data, dev_data)
    train_iter, dev_iter = data.Iterator.splits(
        (train_data, dev_data),
        batch_sizes=(args.batch_size, len(dev_data)),
        **kargs)
    return train_iter, dev_iter
def rw(text_field, label_field, **kargs):
    train_data, dev_data = winedataset.RED_WHITE.splits(text_field, label_field)
    text_field.build_vocab(train_data, dev_data)
    label_field.build_vocab(train_data, dev_data)
    train_iter, dev_iter = data.Iterator.splits(
        (train_data, dev_data),
        batch_sizes=(args.batch_size, len(dev_data)),
        **kargs)
    return train_iter, dev_iter
# load data
print("\nLoading data...")
text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)
#train_iter, dev_iter = mr(text_field, label_field, device=-1, repeat=False)
#train_iter, dev_iter, test_iter = sst(text_field, label_field, device=-1, repeat=False)
#train_iter, dev_iter = wine(text_field, label_field, device=-1, repeat=False)
if args.dataset == 0:
    train_iter, dev_iter = rw(text_field, label_field, device=-1, repeat=False)
else:
    train_iter, dev_iter = wine(text_field, label_field, device=-1, repeat=False)
# update vocab for word2vec embeddings
#w2v = gensim.models.Word2Vec.load('wine2vec.model')
#text_field.vocab = w2v.wv.vocab
#label_field.vocab = w2v.wv.vocab
# update args and print
args.embed_num = len(text_field.vocab)
args.class_num = len(label_field.vocab) - 1  # drop the '<unk>' entry torchtext adds to the label vocab
args.cuda = (not args.no_cuda) and torch.cuda.is_available()
del args.no_cuda
args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
args.vocab = text_field.vocab
print("\nParameters:")
for attr, value in sorted(args.__dict__.items()):
    print("\t{}={}".format(attr.upper(), value))
# model
cnn = model.CNN_Text(args)
if args.snapshot is not None:
    print('\nLoading model from {}...'.format(args.snapshot))
    cnn.load_state_dict(torch.load(args.snapshot))
if args.cuda:
    torch.cuda.set_device(args.device)
    cnn = cnn.cuda()
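# A possible way to seed the embedding layer from the word2vec model referenced in the
# commented-out block above (illustrative sketch only: it assumes the 'wine2vec.model'
# file mentioned there exists, that its vector size equals args.embed_dim, and that
# model.CNN_Text exposes its embedding layer as an attribute named 'embed', which is
# an assumption, not guaranteed by this repository):
#
#   w2v = gensim.models.Word2Vec.load('wine2vec.model')
#   weights = torch.zeros(len(text_field.vocab), args.embed_dim)
#   for i, token in enumerate(text_field.vocab.itos):
#       if token in w2v.wv:
#           weights[i] = torch.from_numpy(w2v.wv[token].copy())
#   cnn.embed.weight.data.copy_(weights)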
# train or predict
if args.predict is not None:
    label = train.predict(clean_str(args.predict), cnn, text_field, label_field, args.cuda)
    print('\n[Text] {}\n[Label] {}\n'.format(args.predict, label))
elif args.test:
    try:
        # test_iter is only defined when the SST loader above is used, so for the
        # wine datasets this falls through to the message below.
        train.eval(test_iter, cnn, args)
    except Exception as e:
        print("\nSorry. The test dataset doesn't exist.\n")
else:
    print()
    try:
        train.train(train_iter, dev_iter, cnn, args)
    except KeyboardInterrupt:
        print('\n' + '-' * 89)
        print('Exiting from training early')