# sync_net.py — embedding networks and triplet/siamese loss functions.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchsummary import summary
import numpy as np
from itertools import combinations
from data_loader import get_datasets
# from trainer import fit
from torch.utils.data import DataLoader
from utils import pairwise_distances
# Module-level flag: True when a CUDA device is available; consulted by the
# (currently commented-out) training script at the bottom of this file.
cuda = torch.cuda.is_available()
class TripletLoss(nn.Module):
    """
    Classic margin-based triplet loss on squared Euclidean distances.

    Given embeddings of an anchor, a positive and a negative sample, penalizes
    triplets where the anchor-negative distance does not exceed the
    anchor-positive distance by at least `margin`.
    Adapted from https://github.com/adambielski/siamese-triplet/blob/master/losses.py
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        # Squared L2 distances per triplet (no sqrt — standard for this loss).
        dist_ap = torch.sum((anchor - positive) ** 2, dim=1)
        dist_an = torch.sum((anchor - negative) ** 2, dim=1)
        hinge = F.relu(dist_ap - dist_an + self.margin)
        if size_average:
            return hinge.mean()
        return hinge.sum()
class CosineSimilarityTripletLoss(nn.Module):
    """
    Triplet loss expressed in cosine *distance* (1 - cosine similarity).

    Embeddings are L2-normalized, pairwise similarities are computed as batched
    dot products, and the usual margin hinge is applied to the distances.
    """

    def __init__(self, margin):
        super(CosineSimilarityTripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        batch, dim = anchor.shape

        def _unit(t):
            # Row-wise L2 normalization (assumes no zero-norm rows).
            return t / torch.norm(t, dim=-1).view(batch, 1)

        unit_a = _unit(anchor)
        unit_p = _unit(positive)
        unit_n = _unit(negative)
        # bmm of (batch,1,dim) x (batch,dim,1) acts as a batched dot product.
        sim_ap = torch.bmm(unit_a.view(batch, 1, dim), unit_p.view(batch, dim, 1))
        sim_an = torch.bmm(unit_a.view(batch, 1, dim), unit_n.view(batch, dim, 1))
        # Cosine distance: 0 for identical direction, 2 for opposite.
        hinge = F.relu((1 - sim_ap) - (1 - sim_an) + self.margin)
        return hinge.mean() if size_average else hinge.sum()
class LosslessTripletLoss(nn.Module):
    """
    Lossless Triplet loss (log-shaped, bounded-embedding triplet loss).

    Adapted (and modified) from:
    "Lossless Triplet loss — A more efficient loss function for Siamese NN"
    by Marc-Olivier Arsenault, Feb 15, 2018
    https://towardsdatascience.com/lossless-triplet-loss-7e932f990b24

    Assumes embedding coordinates lie in [0, 1] (e.g. sigmoid outputs), so a
    squared distance over N dimensions stays within [0, N].

    Parameters
        N: the number of embedding dimensions
        beta: the scaling factor; N is recommended (and used when beta is None)
        epsilon: small value to prevent ln(0)
    """

    def __init__(self, N=3, beta=None, epsilon=1e-8):
        super(LosslessTripletLoss, self).__init__()
        self.N = N
        self.beta = N if beta is None else beta
        self.epsilon = epsilon

    def forward(self, anchor, positive, negative, size_average=True):
        # Squared distance between the anchor and the positive.
        pos_dist = torch.sum(torch.pow(anchor - positive, 2), 1)
        # Squared distance between the anchor and the negative.
        neg_dist = torch.sum(torch.pow(anchor - negative, 2), 1)
        # -ln(-x/N + 1). BUG FIX (the old "sometimes returns nan" TODO):
        # when a distance exceeds beta/N the log argument goes negative and
        # log() returned nan. Clamp the argument to epsilon; within the
        # intended [0, N] range the result is identical to before.
        pos_arg = torch.clamp(-(pos_dist / self.beta) + 1 + self.epsilon, min=self.epsilon)
        neg_arg = torch.clamp(-((self.N - neg_dist) / self.beta) + 1 + self.epsilon, min=self.epsilon)
        pos_dist = -torch.log(pos_arg)
        neg_dist = -torch.log(neg_arg)
        # Total loss per triplet.
        losses = neg_dist + pos_dist
        return losses.mean() if size_average else losses.sum()
class OnlineTripletLoss(nn.Module):
"""
Online Triplets loss
Takes a batch of embeddings and corresponding labels.
Triplets are generated using triplet_selector object that take embeddings and targets and return indices of
triplets
Taken from https://github.com/adambielski/siamese-triplet/blob/master/losses.py
"""
def __init__(self, margin, triplet_selector):
super(OnlineTripletLoss, self).__init__()
self.margin = margin
self.triplet_selector = triplet_selector
def forward(self, embeddings, target):
triplets = self.triplet_selector.get_triplets(embeddings, target)
if embeddings.is_cuda:
triplets = triplets.cuda()
ap_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 1]]).pow(2).sum(1) # .pow(.5)
an_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 2]]).pow(2).sum(1) # .pow(.5)
losses = F.relu(ap_distances - an_distances + self.margin)
return losses.mean(), len(triplets)
class MultiSiameseCosineSimilarityLoss(nn.Module):
    """
    Multi-siamese cosine similarity loss.

    Takes a batch of embeddings plus boolean-style masks marking positive and
    negative pairs, and pushes positive pairs toward cosine similarity +1 and
    negative pairs toward -1. Useful to cover many pairs with few images.

    Parameters
        embeddings: matrix of size (batch_size, embedding_size)
        positive_matrix: matrix of positive pairs, size (batch_size, batch_size)
        negative_matrix: matrix of negative pairs, size (batch_size, batch_size)
    Returns
        loss: scalar in [0, 4]; 0 means positives perfectly similar AND
            negatives perfectly dissimilar, 4 is the exact opposite.
        average_positive_similarity: mean normalized dot product over positive pairs
        average_negative_similarity: mean normalized dot product over negative pairs
    """

    def forward(self, embeddings, positive_matrix, negative_matrix):
        batch, dim = embeddings.shape
        # Row-wise L2 normalization so dot products become cosine similarities.
        unit = embeddings / torch.norm(embeddings, dim=-1).view(batch, 1)
        # Full (batch x batch) cosine similarity table, moved to CPU like the masks.
        sims = torch.bmm(unit.view(1, batch, dim),
                         unit.t().view(1, dim, batch)).cpu()

        def _masked_mean(mask):
            # Average similarity over the pairs selected by `mask`
            # (guarding against an empty selection).
            selected = sims * mask
            count = mask.sum()
            return selected.sum() / (count if count > 0 else 1)

        average_positive_similarity = _masked_mean(positive_matrix)
        average_negative_similarity = _masked_mean(negative_matrix)
        # Positives should approach +1, negatives should approach -1.
        loss = (1 - average_positive_similarity) + (1 + average_negative_similarity)
        return loss, average_positive_similarity, average_negative_similarity
class SoftMultiSiameseCosineSimilarityLoss(nn.Module):
    """
    Soft multi-siamese cosine similarity loss.

    Computes the cosine similarity for each pair of embeddings and compares it
    with a soft ground-truth pair-similarity matrix. Useful to get a loss over
    several pairs with few images.

    Parameters
        embeddings: one batch of size (batch_size, embedding_size), or two
            batches stacked along dim 0 (detected when the number of embedding
            rows differs from the similarity matrix's row length)
        similarity_matrix: ground-truth pair similarities; indexed with
            shape[1]/shape[2], so the two-batch path expects it batched as
            (1, batch_size_a, batch_size_b) — NOTE(review): confirm against callers
        masks: 0/1 matrix selecting which pairs contribute to the loss
    Returns
        (loss, loss, 0) where loss is a scalar in [0, 1]; 0 means the predicted
        pair similarities match the ground truth perfectly.
    """

    def forward(self, embeddings, similarity_matrix, masks):
        if len(embeddings) != len(similarity_matrix[0]):
            # Two stacked batches: rows [0:Ba] vs rows [Ba:Ba+Bb].
            embedding_size = embeddings.shape[1]
            batch_size_a = similarity_matrix.shape[1]
            batch_size_b = similarity_matrix.shape[2]
            # Normalize each sub-batch row-wise.
            normalized_embeddings_a = embeddings[:batch_size_a] / torch.norm(embeddings[:batch_size_a], dim=-1).view(batch_size_a, 1)
            normalized_embeddings_b = embeddings[batch_size_a:] / torch.norm(embeddings[batch_size_a:], dim=-1).view(batch_size_b, 1)
            # Cosine similarity for every cross-batch combination.
            cosine_similarities = torch.bmm(normalized_embeddings_a.view(1, batch_size_a, embedding_size),
                                            normalized_embeddings_b.t().view(1, embedding_size, batch_size_b)).cpu()
        else:
            batch_size, embedding_size = embeddings.shape
            # Normalize embeddings row-wise.
            normalized_embeddings = embeddings / torch.norm(embeddings, dim=-1).view(batch_size, 1)
            # Cosine similarity for every within-batch combination.
            cosine_similarities = torch.bmm(normalized_embeddings.view(1, batch_size, embedding_size),
                                            normalized_embeddings.t().view(1, embedding_size, batch_size)).cpu()
        # Map similarities from [-1, 1] to [0, 1] to match the ground truth scale.
        cosine_similarities = (cosine_similarities + 1) / 2
        # Apply the masks to ignore some pairs.
        # BUG FIX: the old code did `similarity_matrix *= masks`, mutating the
        # CALLER's ground-truth tensor in place; use out-of-place products.
        cosine_similarities = cosine_similarities * masks
        masked_targets = similarity_matrix * masks
        diff = (cosine_similarities - masked_targets) ** 2
        # Average only over pairs with a non-zero error (zero-error pairs and
        # masked-out pairs are indistinguishable here, both contribute 0).
        count = torch.nonzero(diff).shape[0]
        # BUG FIX: return a float tensor (was torch.tensor(0), an int tensor)
        # so the zero-loss case matches the dtype of the normal path.
        similarity_loss = diff.sum() / count if count > 0 else torch.tensor(0.)
        return similarity_loss, similarity_loss, 0
class TripletNet(nn.Module):
    """
    Runs a shared embedding network over triplets.

    A 4-D input (batch, channels, h, w) is treated as a plain batch and embedded
    directly; otherwise dim 1 is assumed to stack (anchor, positive, negative)
    and three embeddings are returned.
    Adapted from https://github.com/adambielski/siamese-triplet/blob/master/networks.py
    """

    def __init__(self, embedding_net):
        super(TripletNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x):
        # Plain batch: no triplet axis present.
        if len(x.shape) == 4:
            return self.embedding_net(x)
        # Triplet batch: split along dim 1, embed each part with shared weights.
        anchor = x[:, 0]
        positive = x[:, 1]
        negative = x[:, 2]
        return (self.embedding_net(anchor),
                self.embedding_net(positive),
                self.embedding_net(negative))

    def get_embedding(self, x):
        """Embed a single batch (convenience accessor)."""
        return self.embedding_net(x)
class MultiSiameseNet(nn.Module):
    """Thin wrapper that embeds an entire batch with the given embedding network."""

    def __init__(self, embedding_net):
        super(MultiSiameseNet, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, x):
        # x: (batch_size, channels, width, height) — forwarded unchanged.
        return self.embedding_net(x)
def reset_first_and_last_layers(model, out_features=16):
    """
    Re-initialize the first conv layer and replace the final fully connected layer.

    The first convolution is reset because the input data has 3 channels that
    are not RGB, so pretrained RGB filters carry no meaning. The classification
    head is replaced with a fresh Linear that produces an embedding useful for
    our sequences rather than image-class scores.

    Parameters
        model: a torchvision-style model exposing `conv1` and `fc`
        out_features: embedding size of the new head (default 16, matching the
            previously hard-coded value — backward compatible)
    """
    reset_first_layer(model)
    # Use the existing head's input width instead of hard-coding 2048, so this
    # also works for backbones whose feature dimension differs from ResNet-50.
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, out_features)
    nn.init.xavier_uniform_(model.fc.weight)
def reset_first_layer(model):
    """Re-initialize the first convolution's weights (Xavier uniform).

    Needed because our data has 3 channels that are not RGB, so the pretrained
    RGB filters are not meaningful.
    """
    first_conv = model.conv1
    nn.init.xavier_uniform_(first_conv.weight)
def replace_last_layer(model, out_features, dropout=0):
    """
    Replace the model's classification head with a fresh Linear layer
    (optionally preceded by Dropout), Xavier-initialized.

    The head is replaced because we want to learn an embedding useful for our
    sequences, not to classify images.

    Parameters
        model: a model exposing one of `classifier`, `_fc` or `fc`
        out_features: output size of the new head
        dropout: if non-zero, wrap the Linear in Sequential(Dropout, Linear)
    """
    # Preserve the old head's input width (weight shape is (out, in)).
    in_features = get_fc_weights(model).shape[-1]
    if dropout == 0:
        fc = nn.Linear(in_features, out_features)
    else:
        fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, out_features)
        )
    if hasattr(model, 'classifier'):
        model.classifier = fc
    elif hasattr(model, '_fc'):
        # BUG FIX: get_fc_weights checks `_fc` before `fc`. Previously a model
        # exposing `_fc` had the new head assigned to a fresh `fc` attribute
        # while the OLD `_fc` layer got re-initialized below.
        model._fc = fc
    else:
        model.fc = fc
    nn.init.xavier_uniform_(get_fc_weights(model))
def get_fc_weights(model):
    """
    Return the weight tensor of the model's final fully connected layer.

    Checks, in order, the head attribute names used by common architectures:
    `classifier` (e.g. VGG/DenseNet), `_fc` (e.g. EfficientNet), `fc`
    (e.g. ResNet). When the head is a Sequential (e.g. Dropout + Linear built
    by replace_last_layer), the weight of its last submodule is returned.

    Raises AttributeError if the model exposes none of these attributes.
    """
    fc = None
    if hasattr(model, 'classifier'):
        fc = model.classifier
    elif hasattr(model, '_fc'):
        fc = model._fc
    elif hasattr(model, 'fc'):
        fc = model.fc
    # isinstance instead of an exact type comparison, so Sequential subclasses
    # are handled too (idiomatic type check).
    if isinstance(fc, nn.Sequential):
        return list(fc.modules())[-1].weight
    return fc.weight
def add_sigmoid_activation(model):
    """Return `model` with a sigmoid squashing appended to its output."""
    squash = nn.modules.Sigmoid()
    return nn.Sequential(model, squash)
def stop_running_var(layer):
    """Disable running-statistics tracking on BatchNorm2d layers (no-op otherwise).

    Intended for use with Module.apply(), which visits every submodule.
    """
    if not isinstance(layer, nn.BatchNorm2d):
        return
    layer.track_running_stats = False
def freeze_model(model):
    """Freeze the model: disable gradient computation for all its parameters."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)
if __name__ == "__main__":
    # Smoke test: run the multi-siamese cosine loss on a tiny hand-made batch.
    # (Earlier experiments here — TripletNet over ResNet-50, triplet mining,
    # summary() inspection — were commented out; recover them from VCS history.)
    loss_fn = MultiSiameseCosineSimilarityLoss()
    batch = torch.FloatTensor([[1, 1], [0.5, 1], [-1, -1], [-0.5, -1]])
    positive_pairs = torch.FloatTensor([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
    negative_pairs = torch.FloatTensor([[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0], [1, 1, 0, 0]])
    print("loss", loss_fn.forward(batch, positive_pairs, negative_pairs))