import warnings
from distutils.version import LooseVersion

import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq

import preprocess

# Load the word ids and lookup tables produced by the preprocessing step
int_text, vocab_to_int, int_to_vocab, token_dict = preprocess.load_preprocess()

# Check TensorFlow version
assert LooseVersion(tf.__version__) >= LooseVersion('1.3'), 'Please use TensorFlow version 1.3 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))

# Check for a GPU
if not tf.test.gpu_device_name():
    warnings.warn('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))

def get_inputs():
    """
    Create TF placeholders for input, targets, and learning rate.
    :return: Tuple (input, targets, learning rate)
    """
    # Both dimensions (batch size, sequence length) are left as None so the
    # same graph can later be fed one word at a time when generating text
    inputs = tf.placeholder(tf.int32, [None, None], name='input')
    targets = tf.placeholder(tf.int32, [None, None], name='targets')
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
    return inputs, targets, learning_rate

def get_init_cell(batch_size, rnn_size):
    """
    Create an RNN cell and initialize it.
    :param batch_size: Size of batches
    :param rnn_size: Size of RNNs
    :return: Tuple (cell, initial state)
    """
    # Build a separate BasicLSTMCell per layer; reusing one cell object in
    # MultiRNNCell would share weights between layers and fails in TF >= 1.1
    cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(rnn_size) for _ in range(2)])
    # Initial state of all zeros (LSTM states are float32, not int32)
    initial_state = cell.zero_state(batch_size, tf.float32)
    initial_state = tf.identity(initial_state, name='initial_state')
    return cell, initial_state

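# Note: wrapping the zero state in tf.identity packs the per-layer
# LSTMStateTuples into a single named tensor, so a separate generation
# script can recover it later via graph.get_tensor_by_name('initial_state:0').
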
def get_embed(input_data, vocab_size, embed_dim):
    """
    Create embedding for <input_data>.
    :param input_data: TF placeholder for text input.
    :param vocab_size: Number of words in vocabulary.
    :param embed_dim: Number of embedding dimensions
    :return: Embedded input.
    """
    embedding = tf.Variable(tf.random_uniform([vocab_size, embed_dim]))
    embed = tf.nn.embedding_lookup(embedding, input_data)
    return embed

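# Shape sketch (illustrative numbers only): with vocab_size=100 and
# embed_dim=32, embedding is a (100, 32) matrix, and embedding_lookup maps a
# [batch_size, seq_length] tensor of word ids to a [batch_size, seq_length, 32]
# float tensor -- one learned vector per word id.
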
def build_rnn(cell, inputs):
    """
    Create an RNN using an RNN cell.
    :param cell: RNN cell
    :param inputs: Input text data
    :return: Tuple (Outputs, Final State)
    """
    # dtype=tf.float32 lets dynamic_rnn create its own float32 zero state
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
    final_state = tf.identity(final_state, name='final_state')
    return outputs, final_state

def build_nn(cell, rnn_size, input_data, vocab_size, embed_dim):
    """
    Build part of the neural network
    :param cell: RNN cell
    :param rnn_size: Size of RNNs
    :param input_data: Input data
    :param vocab_size: Vocabulary size
    :param embed_dim: Number of embedding dimensions
    :return: Tuple (Logits, FinalState)
    """
    embed = get_embed(input_data, vocab_size, embed_dim)
    outputs, final_state = build_rnn(cell, embed)
    logits = tf.contrib.layers.fully_connected(outputs, vocab_size, activation_fn=None)
    return logits, final_state

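# Note: fully_connected applies one shared dense layer at every time step,
# so logits has shape [batch_size, seq_length, vocab_size]; activation_fn=None
# keeps the outputs as raw scores for sequence_loss / softmax below.
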
def get_batches(int_text, batch_size, seq_length):
    """
    Return batches of input and target
    :param int_text: Text with the words replaced by their ids
    :param batch_size: The size of batch
    :param seq_length: The length of sequence
    :return: Batches as a Numpy array of shape (n_batches, 2, batch_size, seq_length)
    """
    n_batches = len(int_text) // (batch_size * seq_length)
    # Keep only full batches and append the first word, so the very last
    # target wraps around to the start of the text (this also avoids an
    # out-of-range slice when the text length is an exact multiple)
    words = int_text[:n_batches * batch_size * seq_length] + [int_text[0]]
    # result[b][0] holds the inputs of batch b, result[b][1] the targets
    result = np.ndarray(shape=(n_batches, 2, batch_size, seq_length), dtype=int)
    # Rows of one batch sit skipdistance words apart, so batch b+1 continues
    # row-for-row where batch b left off
    skipdistance = n_batches * seq_length
    # The outer loop steps through the text seq_length words at a time
    for b in range(n_batches):
        batch_idx = b * seq_length
        for bb in range(batch_size):
            idx = batch_idx + bb * skipdistance  # starting index of this row
            # Each target is the word that follows its input
            result[b][0][bb] = words[idx: idx + seq_length]
            result[b][1][bb] = words[idx + 1: idx + 1 + seq_length]
    return result

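# A small sanity check of the batch layout on a toy corpus (illustrative
# numbers only): 13 word ids with batch_size=2 and seq_length=3 give 2
# batches, rows of a batch sit skipdistance words apart, and the final
# target wraps around to the first word.
_check = get_batches(list(range(1, 14)), 2, 3)
assert _check.shape == (2, 2, 2, 3)
assert _check[0][0].tolist() == [[1, 2, 3], [7, 8, 9]]    # batch 0 inputs
assert _check[1][1].tolist() == [[5, 6, 7], [11, 12, 1]]  # batch 1 targets
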
# Number of Epochs
num_epochs = 200
# Batch Size
batch_size = 512
# RNN Size
rnn_size = 500
# Embedding Dimension Size
embed_dim = 500
# Sequence Length
seq_length = 30
# Learning Rate
learning_rate = 0.001
# Show training stats every n batches
show_every_n_batches = 100
save_dir = './save'
# Build the training graph
train_graph = tf.Graph()
with train_graph.as_default():
    vocab_size = len(int_to_vocab)
    input_text, targets, lr = get_inputs()
    input_data_shape = tf.shape(input_text)
    cell, initial_state = get_init_cell(input_data_shape[0], rnn_size)
    logits, final_state = build_nn(cell, rnn_size, input_text, vocab_size, embed_dim)

    # Probabilities for generating words
    probs = tf.nn.softmax(logits, name='probs')

    # Loss function
    cost = seq2seq.sequence_loss(
        logits,
        targets,
        tf.ones([input_data_shape[0], input_data_shape[1]]))
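    # sequence_loss is per-word softmax cross-entropy, weighted by the
    # all-ones mask above (no padding in this corpus) and averaged over both
    # the batch and the time steps, so cost is a single scalar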
    # Optimizer
    optimizer = tf.train.AdamOptimizer(lr)

    # Gradient clipping: clip every gradient to [-1, 1] to avoid exploding
    # gradients in the RNN
    gradients = optimizer.compute_gradients(cost)
    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var)
                        for grad, var in gradients if grad is not None]
    train_op = optimizer.apply_gradients(capped_gradients)
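
# Sketch of how a companion generation script might reload this graph from
# the checkpoint written below (hypothetical usage, not executed here):
#
#   loader = tf.train.import_meta_graph(save_dir + '.meta')
#   with tf.Session() as sess:
#       loader.restore(sess, save_dir)
#       input_t = sess.graph.get_tensor_by_name('input:0')
#       probs_t = sess.graph.get_tensor_by_name('probs:0')
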
batches = get_batches(int_text, batch_size, seq_length)

with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())

    for epoch_i in range(num_epochs):
        state = sess.run(initial_state, {input_text: batches[0][0]})

        for batch_i, (x, y) in enumerate(batches):
            feed = {
                input_text: x,
                targets: y,
                initial_state: state,
                lr: learning_rate}
            train_loss, state, _ = sess.run([cost, final_state, train_op], feed)

            # Show stats every <show_every_n_batches> batches
            if (epoch_i * len(batches) + batch_i) % show_every_n_batches == 0:
                print('Epoch {:>3} Batch {:>4}/{} train_loss = {:.3f}'.format(
                    epoch_i,
                    batch_i,
                    len(batches),
                    train_loss))

    # Save the trained model
    saver = tf.train.Saver()
    saver.save(sess, save_dir)
    print('Model Trained and Saved')

# Persist seq_length and save_dir so the generation script can reload them
preprocess.save_params((seq_length, save_dir))