-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.py
More file actions
156 lines (138 loc) · 5.69 KB
/
model.py
File metadata and controls
156 lines (138 loc) · 5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from itertools import combinations
import tensorflow as tf
class DEIMOS_Model(tf.keras.Model):
    """Convolutional clustering model.

    A four-conv-layer CNN followed by two dense layers; the softmax head
    assigns each input to one of ``n_clusters`` clusters.  Supports a
    supervised pretraining phase (``pretrain_setup`` / ``call_pretrain`` /
    ``loss_pretrain``) and an unsupervised phase driven by the pairwise
    similarity loss ``loss_w``, whose upper/lower similarity bounds tighten
    over time via ``loss_l_update``.

    Dimensionality check (recomputed from the layer parameters; the original
    comment listed 112/12544 — TODO(review): confirm against a real trace):
        Initial:        (b, 227, 227, 1)
        After conv1:    (b, 113, 113, 32)   # k=3, s=2, valid
        After pool1:    (b, 56, 56, 32)
        After conv2:    (b, 56, 56, 32)
        After pool2:    (b, 28, 28, 32)
        After conv3:    (b, 28, 28, 64)
        After pool3:    (b, 14, 14, 64)
        After conv4:    (b, 14, 14, 64)
        After pool4:    (b, 7, 7, 64)
        After flatten:  (b, 3136)
        After fc1:      (b, n_clusters * 8)
        After fc2:      (b, n_clusters)
    """

    def __init__(self, n_clusters):
        """Build layers and optimization hyper-parameters.

        Args:
            n_clusters: number of clusters (width of the softmax output).
        """
        super().__init__()
        # --- static (non-trained) parameters ---
        self.lr = 0.5e-3
        self.optimizer = tf.keras.optimizers.Adam(self.lr)
        self.u_coeff = 0.5   # rate at which the upper similarity bound tightens
        self.l_coeff = 1     # rate at which the lower similarity bound tightens
        self.ul_lr = 0.02    # step size applied by loss_l_update
        self.n_clusters = n_clusters
        self.max_pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2)
        self.relu = tf.keras.layers.ReLU()
        # Created once here instead of a fresh Flatten() per forward call.
        self.flatten = tf.keras.layers.Flatten()
        # --- trainable layers / schedule state ---
        self.lamb = 0  # bound-schedule position, advanced by loss_l_update
        self.conv1 = tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu', padding='valid')
        # NOTE(review): Conv2D defaults to channels_last, whose channel axis
        # is -1; axis=1 normalizes over the height axis instead.  Kept as-is
        # so existing checkpoints still load — confirm whether axis=-1 was
        # intended.
        self.batch_norm1 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv2 = tf.keras.layers.Conv2D(32, 3, strides=1, activation='relu', padding='same')
        self.batch_norm2 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv3 = tf.keras.layers.Conv2D(64, 3, strides=1, activation='relu', padding='same')
        self.batch_norm3 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv4 = tf.keras.layers.Conv2D(64, 3, strides=1, activation='relu', padding='same')
        self.batch_norm4 = tf.keras.layers.BatchNormalization(axis=1)
        self.fc1 = tf.keras.layers.Dense(self.n_clusters * 8)
        self.batch_norm5 = tf.keras.layers.BatchNormalization(axis=1)
        self.fc2 = tf.keras.layers.Dense(self.n_clusters, activation='softmax')

    def pretrain_setup(self, n_classes):
        """Attach a pretraining classification head and its optimizer.

        Args:
            n_classes: number of pretraining classes (head width).
        """
        self.n_pretrain_classes = n_classes
        self.pretrain_fc_out = tf.keras.layers.Dense(n_classes, name='pretrain_output')
        self.pretrain_fc_out.trainable = True
        self.pretrain_lr = 0.002
        self.pretrain_optimizer = tf.keras.optimizers.Adam(self.pretrain_lr)

    def _conv_stack(self, inputs, training=None):
        """Shared conv trunk through fc1 (pre-activation features).

        Factored out of call / call_pretrain / call_feat_output, which
        previously duplicated this sequence verbatim.  NOTE(review): stage 1
        runs batch-norm before pooling while stages 2-4 pool first — this
        asymmetry is preserved from the original; confirm it is intentional.

        Args:
            inputs: image batch.
            training: forwarded to every BatchNormalization layer; None keeps
                each layer's default resolution, matching callers that
                originally omitted the kwarg.

        Returns:
            fc1 output (no activation applied).
        """
        outs = self.conv1(inputs)
        outs = self.batch_norm1(outs, training=training)
        outs = self.max_pool(outs)
        for conv, bnorm in ((self.conv2, self.batch_norm2),
                            (self.conv3, self.batch_norm3),
                            (self.conv4, self.batch_norm4)):
            outs = conv(outs)
            outs = self.max_pool(outs)
            outs = bnorm(outs, training=training)
        outs = self.flatten(outs)
        return self.fc1(outs)

    def call(self, inputs, training=False, mask=None):
        """Full forward pass: conv trunk, ReLU+BN head, softmax clusters."""
        outs = self._conv_stack(inputs, training=training)
        outs = self.relu(outs)
        outs = self.batch_norm5(outs, training=training)
        return self.fc2(outs)

    def get_clusters(self, inputs):
        """Return the argmax cluster index per row.

        Args:
            inputs: tensor of shape (batch, n_clusters), e.g. call() output.
        """
        return tf.argmax(inputs, axis=1)

    def upper_bound(self):
        """Current upper similarity bound; decreases as lamb grows."""
        return 0.95 - self.u_coeff * self.lamb

    def lower_bound(self):
        """Current lower similarity bound; increases as lamb grows."""
        return 0.5 + self.l_coeff * self.lamb

    def loss_w(self, feats):
        """Pairwise similarity loss over a batch of feature vectors.

        Rows are L2-normalized; for each pair, the dot product (cosine
        similarity) contributes -log(1 - dot) when it falls below the lower
        bound and -log(dot) when it exceeds the upper bound.  Pairs inside
        the band contribute nothing.

        Returns:
            (loss, count) where count is the number of contributing pairs,
            or None when no pair contributed (original sentinel preserved).
        """
        feats, _ = tf.linalg.normalize(feats, axis=1)
        loss = 0
        ct = 0
        for tens_1, tens_2 in combinations(feats, 2):
            dot_prod = tf.reduce_sum(tens_1 * tens_2)
            if dot_prod < self.lower_bound():
                loss -= tf.math.log(1 - dot_prod)
                ct += 1
            elif dot_prod > self.upper_bound():
                loss -= tf.math.log(dot_prod)
                ct += 1
        # Test the pair count rather than `loss == 0`: once any pair
        # contributes, `loss` is a Tensor and `== 0` becomes a tensor
        # comparison instead of the intended emptiness check.
        if ct == 0:
            return None
        return loss, ct

    def loss_l_update(self):
        """Advance lamb one step, tightening both similarity bounds."""
        self.lamb += self.ul_lr * (self.u_coeff + self.l_coeff)
        print(f'New upper bound: {self.upper_bound()}')
        print(f'New lower bound: {self.lower_bound()}')

    # Pretrain call and loss
    def call_pretrain(self, inputs):
        """Forward pass through the shared trunk plus the pretraining head."""
        outs = self._conv_stack(inputs)
        outs = self.relu(outs)
        outs = self.batch_norm5(outs)
        return self.pretrain_fc_out(outs)

    def loss_pretrain(self, logits, labels):
        """Mean categorical cross-entropy for pretraining.

        Args:
            logits: raw pretrain_fc_out outputs (head has no activation,
                hence from_logits=True).
            labels: integer class ids, one-hot encoded against
                n_pretrain_classes.
        """
        labels = tf.one_hot(labels, self.n_pretrain_classes)
        loss = tf.keras.losses.categorical_crossentropy(labels, logits, from_logits=True)
        return tf.reduce_mean(loss)

    def call_feat_output(self, inputs):
        """Inference-mode feature extractor: fc1 output with BN frozen."""
        return self._conv_stack(inputs, training=False)