-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathnormalization.py
More file actions
95 lines (78 loc) · 3.85 KB
/
normalization.py
File metadata and controls
95 lines (78 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import tensorflow as tf
from models.generative.ops import *
from models.generative.activations import *
def batch_norm(inputs, training, c=None, spectral=False, scope=False):
    """Plain batch normalization over `inputs`.

    `c`, `spectral`, and `scope` are accepted only so every normalization
    wrapper in this module shares the same call signature; they are ignored.
    """
    return tf.layers.batch_normalization(inputs=inputs, training=training)
def instance_norm(inputs, training, c=None, spectral=False, scope=False):
    """Instance normalization over `inputs`.

    `training`, `c`, `spectral`, and `scope` are ignored; they exist only to
    keep a uniform signature across this module's normalization wrappers.
    """
    return tf.contrib.layers.instance_norm(inputs=inputs)
def layer_norm(inputs, training, c=None, spectral=False, scope=False):
    """Layer normalization over `inputs`.

    `training`, `c`, `spectral`, and `scope` are ignored; they exist only to
    keep a uniform signature across this module's normalization wrappers.

    Bug fix: the original passed the literal ``scope=False`` through to
    ``tf.contrib.layers.layer_norm``, whose ``scope`` parameter expects a
    string (or None) for ``tf.variable_scope``. The kwarg is dropped so the
    default (None) is used, matching the sibling instance/group wrappers.
    """
    return tf.contrib.layers.layer_norm(inputs=inputs)
def group_norm(inputs, training, c=None, spectral=False, scope=False):
    """Group normalization over `inputs`.

    `training`, `c`, `spectral`, and `scope` are ignored; they exist only to
    keep a uniform signature across this module's normalization wrappers.
    """
    return tf.contrib.layers.group_norm(inputs=inputs)
def conditional_instance_norm(inputs, training, c, scope, spectral=False):
    """Conditional instance normalization.

    Normalizes each sample of `inputs` with its own mean/variance (taken over
    the spatial axes for 4-D input, or the feature axis for 2-D input), then
    scales and shifts with gamma/beta predicted from the conditioning tensor
    `c` by a small two-layer MLP.

    Args:
        inputs: 4-D tensor (batch, height, width, channels) or 2-D tensor
            (batch, channels).
        training: Unused; kept for a uniform wrapper signature.
        c: Conditioning tensor, shape (batch, cond_dim).
        scope: Suffix appended to the variable-scope name.
        spectral: If True, the MLP's dense layers use spectral normalization.

    Returns:
        Tensor with the same shape as `inputs`.

    Cleanup vs. original: removed the unused local `decay` and the unused
    `batch`/`height`/`width` unpacking; merged the duplicated rank checks.
    """
    input_dims = inputs.shape.as_list()
    channels = input_dims[-1]
    with tf.variable_scope('conditional_instance_norm_%s' % scope):
        epsilon = 1e-5
        # MLP for gamma and beta; hidden width is halfway between the
        # conditioning dimension and the channel count.
        inter_dim = int((channels + c.shape.as_list()[-1]) / 2)
        net = dense(inputs=c, out_dim=inter_dim, scope=1, spectral=spectral, display=False)
        net = ReLU(net)
        # ReLU keeps the predicted scale non-negative.
        gamma = dense(inputs=net, out_dim=channels, scope='gamma', spectral=spectral, display=False)
        gamma = ReLU(gamma)
        beta = dense(inputs=net, out_dim=channels, scope='beta', spectral=spectral, display=False)
        if len(input_dims) == 4:
            # Broadcast (batch, channels) -> (batch, 1, 1, channels).
            gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1)
            beta = tf.expand_dims(tf.expand_dims(beta, 1), 1)
            moment_axes = [1, 2]
        else:
            moment_axes = [1]
        # Per-sample statistics (instance norm); keep_dims so they broadcast.
        batch_mean, batch_variance = tf.nn.moments(inputs, axes=moment_axes, keep_dims=True)
        return tf.nn.batch_normalization(inputs, batch_mean, batch_variance, beta, gamma, epsilon)
def conditional_batch_norm(inputs, training, c, scope, spectral=False):
    """Conditional batch normalization.

    Normalizes `inputs` with batch statistics (tracked via an exponential
    moving average for inference), then scales and shifts with gamma/beta
    predicted from the conditioning tensor `c` by a two-layer MLP.

    Args:
        inputs: 4-D tensor (batch, height, width, channels) or 2-D tensor
            (batch, channels).
        training: Python bool evaluated at graph-construction time — True
            builds the batch-stats + EMA-update path, False builds the
            inference path using the stored moving averages. NOTE(review):
            a tf.Tensor here would always be truthy; callers presumably pass
            a plain bool — confirm.
        c: Conditioning tensor, shape (batch, cond_dim).
        scope: Suffix appended to the variable-scope name.
        spectral: If True, the MLP's dense layers use spectral normalization.

    Returns:
        Tensor with the same shape as `inputs`.
    """
    input_dims = inputs.shape.as_list()
    if len(input_dims) == 4:
        batch, height, width, channels = input_dims
    else:
        batch, channels = input_dims
    with tf.variable_scope('conditional_batch_norm_%s' % scope) :
        # EMA decay for the population statistics used at inference time.
        decay = 0.9
        epsilon = 1e-5
        # Non-trainable population statistics, updated only on the training path.
        test_mean = tf.get_variable("pop_mean", shape=[channels], dtype=tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
        test_variance = tf.get_variable("pop_var", shape=[channels], dtype=tf.float32, initializer=tf.constant_initializer(1.0), trainable=False)
        # MLP for gamma, and beta.
        inter_dim = int((channels+c.shape.as_list()[-1])/2)
        net = dense(inputs=c, out_dim=inter_dim, scope=1, spectral=spectral, display=False)
        net = ReLU(net)
        gamma = dense(inputs=net, out_dim=channels, scope='gamma', spectral=spectral, display=False)
        # ReLU keeps the predicted scale non-negative.
        gamma = ReLU(gamma)
        beta = dense(inputs=net, out_dim=channels, scope='beta', spectral=spectral, display=False)
        if len(input_dims) == 4:
            # Broadcast (batch, channels) -> (batch, 1, 1, channels).
            gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1)
            beta = tf.expand_dims(tf.expand_dims(beta, 1), 1)
        if training:
            # Batch statistics over all axes except channels.
            if len(input_dims) == 4:
                batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1, 2])
                # batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1, 2], keep_dims=True)
            else:
                batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1])
                # batch_mean, batch_variance = tf.nn.moments(inputs, axes=[0, 1], keep_dims=True)
            # Fold the batch statistics into the population EMA.
            ema_mean = tf.assign(test_mean, test_mean * decay + batch_mean * (1 - decay))
            ema_variance = tf.assign(test_variance, test_variance * decay + batch_variance * (1 - decay))
            # control_dependencies forces the EMA updates to run whenever the
            # normalized output is evaluated; do not reorder these statements.
            with tf.control_dependencies([ema_mean, ema_variance]):
                batch_norm_output = tf.nn.batch_normalization(inputs, batch_mean, batch_variance, beta, gamma, epsilon)
        else:
            # Inference: normalize with the stored population statistics.
            batch_norm_output = tf.nn.batch_normalization(inputs, test_mean, test_variance, beta, gamma, epsilon)
        return batch_norm_output