-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy pathtrain_operation.py
More file actions
65 lines (51 loc) · 2.54 KB
/
train_operation.py
File metadata and controls
65 lines (51 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# encoding: utf-8
import tensorflow as tf
import model
import settings
FLAGS = settings.FLAGS
import numpy as np
# Hyperparameters for tf.train.RMSPropOptimizer. Within this module they are
# referenced only by the commented-out RMSProp setup in train(), which
# currently builds an Adam optimizer instead.
RMSPROP_DECAY = 0.9  # Decay term for RMSProp.
RMSPROP_MOMENTUM = 0.9  # Momentum in RMSProp.
RMSPROP_EPSILON = 1.0  # Epsilon term for RMSProp.
def train(total_loss, global_step, summaries, batchnorm_updates):
    """Build the op that performs one training step.

    Sets up a staircase exponentially-decayed learning rate, computes and
    applies Adam gradients for ``total_loss``, maintains exponential moving
    averages of the trainable and moving-average variables, and runs the
    supplied batch-normalization updates.

    Args:
        total_loss: Loss tensor whose gradients are computed and applied.
        global_step: Step counter passed to the decay schedule, to
            ``apply_gradients`` and to the moving-average tracker.
        summaries: List extended in place with a ``learning_rate`` scalar
            summary and one ``<var>/gradients`` histogram for every variable
            that has a gradient.
        batchnorm_updates: Iterable of update ops, unpacked into a single
            grouped op.

    Returns:
        One op grouping the gradient application, the moving-average update
        and the batch-norm updates.
    """
    # Number of steps between two learning-rate decays:
    # (batches per epoch) * (epochs per decay), truncated to an int.
    steps_per_decay = int(
        FLAGS.num_examples_per_epoch_for_train / FLAGS.batch_size
        * FLAGS.num_epochs_per_decay)

    # staircase=True makes the rate drop in discrete steps.
    learning_rate = tf.train.exponential_decay(
        FLAGS.initial_learning_rate,
        global_step,
        steps_per_decay,
        FLAGS.learning_rate_decay_factor,
        staircase=True)

    # Alternative optimizer, currently disabled:
    # optimizer = tf.train.RMSPropOptimizer(learning_rate, RMSPROP_DECAY,
    #                                       momentum=RMSPROP_MOMENTUM,
    #                                       epsilon=RMSPROP_EPSILON)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    grads_and_vars = optimizer.compute_gradients(total_loss)

    # Record the learning rate, then a histogram per available gradient.
    summaries.append(tf.summary.scalar('learning_rate', learning_rate))
    summaries.extend(
        tf.summary.histogram(variable.op.name + '/gradients', gradient)
        for gradient, variable in grads_and_vars
        if gradient is not None)

    # Update the shared variables with the computed gradients.
    apply_grads = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Keep moving averages of all trainable variables. The existing
    # moving-average variables are averaged again as well (a
    # "double-average" of the batch-normalization global statistics); this
    # is more complicated than it needs to be but is kept for backward
    # compatibility with previous models.
    ema = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    # Another possibility is to use tf.slim.get_variables().
    averaged_vars = tf.trainable_variables() + tf.moving_average_variables()
    update_averages = ema.apply(averaged_vars)

    # Bundle every update into a single train op.
    update_batchnorm = tf.group(*batchnorm_updates)
    return tf.group(apply_grads, update_averages, update_batchnorm)