-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmodels.py
More file actions
329 lines (255 loc) · 15 KB
/
models.py
File metadata and controls
329 lines (255 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""
Code based on the official MAE Implementation available at:
https://github.com/facebookresearch/mae
"""
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer, PatchEmbed, Block
class ViTLayer(nn.Module):
    """Pre-norm transformer encoder layer.

    Computes ``x = x + Dropout(MSA(LN1(x)))`` followed by
    ``x = x + MLP(LN2(x))``, where the MLP is
    Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        num_heads: number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``embed_dim``.
        embed_dim: token width.
        encoder_mlp_hidden: hidden width of the MLP.
        dropout: dropout probability used on the attention weights, on the
            attention output and in both MLP dropout layers.
    """

    def __init__(self, num_heads, embed_dim, encoder_mlp_hidden, dropout=0.1):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.msa = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.attn_dropout = nn.Dropout(dropout)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, encoder_mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(encoder_mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, seq_len, embed_dim).

        ``batch_first=True`` on the attention module fixes this layout.
        Returns a tensor of the same shape.
        """
        normed = self.layernorm1(x)
        # The attention weights are discarded, so do not ask MHA to compute
        # and return them: this avoids materialising a (batch, seq, seq)
        # tensor and, per the PyTorch docs, enables the optimized
        # scaled_dot_product_attention path.
        attn, _ = self.msa(normed, normed, normed, need_weights=False)
        x = x + self.attn_dropout(attn)
        x = x + self.mlp(self.layernorm2(x))
        return x
class ViT(nn.Module):
    """Vision Transformer classifier built from ``ViTLayer`` encoder layers.

    The input image is cut into non-overlapping ``patch_dim x patch_dim``
    patches. Each patch is flattened and linearly projected to ``embed_dim``.
    A learnable class token is prepended, learnable position embeddings are
    added, and the sequence passes through ``num_layers`` encoder layers.
    The layer-normalised class-token output is mapped to ``num_classes``
    logits.

    Args:
        patch_dim: side length in pixels of a square patch.
        image_dim: side length in pixels of the square 3-channel input; the
            patch reshape in ``forward`` requires it to be a multiple of
            ``patch_dim``.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
        embed_dim: token width.
        encoder_mlp_hidden: hidden width of each layer's MLP.
        num_classes: number of output logits.
        dropout: dropout probability for the embeddings and inside each layer.
    """

    def __init__(self, patch_dim, image_dim, num_layers, num_heads, embed_dim, encoder_mlp_hidden, num_classes, dropout):
        super().__init__()
        self.num_layers = num_layers
        self.patch_dim = patch_dim
        self.image_dim = image_dim
        # Flattened length of one RGB patch.
        self.input_dim = self.patch_dim * self.patch_dim * 3
        self.patch_embedding = nn.Linear(self.input_dim, embed_dim)
        # One position per patch plus one for the class token (index 0).
        self.position_embedding = nn.Parameter(torch.zeros(1, (image_dim // patch_dim) ** 2 + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.embedding_dropout = nn.Dropout(dropout)
        self.encoder_layers = nn.ModuleList(
            [ViTLayer(num_heads, embed_dim, encoder_mlp_hidden, dropout) for _ in range(num_layers)]
        )
        self.mlp_head = nn.Linear(embed_dim, num_classes)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, images):
        """Classify ``images`` of shape (N, 3, image_dim, image_dim).

        Returns:
            Logits of shape (N, num_classes).
        """
        n = images.size(0)
        grid = self.image_dim // self.patch_dim
        # (N, 3, H, W) -> (N, grid*grid, p*p*3), patches in row-major order.
        patches = images.reshape(n, 3, grid, self.patch_dim, grid, self.patch_dim)
        patches = torch.einsum("nchpwq -> nhwpqc", patches)
        patches = patches.reshape(n, grid * grid, self.input_dim)

        tokens = self.patch_embedding(patches)
        tokens = torch.cat([self.cls_token.expand(n, -1, -1), tokens], dim=1)
        # Position embeddings broadcast over the batch dimension.
        out = self.embedding_dropout(tokens + self.position_embedding)
        for layer in self.encoder_layers:
            out = layer(out)

        # out[:, 0] is already (N, embed_dim); squeezing it would drop the
        # embedding axis when embed_dim == 1.
        cls_head = self.layernorm(out[:, 0])
        return self.mlp_head(cls_head)
class MAE_Timm(nn.Module):
    """Masked autoencoder (MAE) built from timm ``PatchEmbed`` / ``Block`` modules.

    ``encode`` embeds every patch, keeps a random ``1 - mask_ratio`` fraction
    of them per sample and runs the encoder transformer over the kept patches
    plus a class token. ``decode`` fills the dropped positions with a
    learnable mask token, restores the original patch order and predicts the
    pixels of every patch. ``forward`` returns the reconstruction loss on the
    dropped patches against per-patch normalised targets.

    Args:
        patch_dim: side length in pixels of a square patch.
        image_dim: side length in pixels of the square 3-channel input image.
        encoder_num_layers, encoder_num_heads, encoder_embed_dim:
            depth, heads and width of the encoder.
        decoder_num_layers, decoder_num_heads, decoder_embed_dim:
            depth, heads and width of the decoder.
        mlp_ratio: MLP hidden width of each block as a multiple of its width.
        dropout: probability passed to every timm ``Block`` as both ``drop``
            and ``attn_drop``.
    """

    def __init__(self, patch_dim, image_dim,
                 encoder_num_layers, encoder_num_heads, encoder_embed_dim,
                 decoder_num_layers, decoder_num_heads, decoder_embed_dim,
                 mlp_ratio, dropout):
        super().__init__()
        self.patch_dim = patch_dim
        self.image_dim = image_dim
        self.num_patches = (image_dim // patch_dim) * (image_dim // patch_dim)
        # Flattened length of one RGB patch (also the decoder's output width).
        self.input_dim = self.patch_dim * self.patch_dim * 3

        # ---- encoder ----
        self.patch_embedding = PatchEmbed(self.image_dim, self.patch_dim, 3, encoder_embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
        # Index 0 belongs to the class token, indices 1.. to the patches.
        self.encoder_position_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, encoder_embed_dim))
        self.encoder_num_layers = encoder_num_layers
        # NOTE(review): recent timm releases call the projection-dropout
        # argument `proj_drop` instead of `drop`; confirm against the pinned
        # timm version.
        self.encoder_layers = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=False, drop=dropout, attn_drop=dropout)
            for _ in range(self.encoder_num_layers)
        ])
        self.encoder_layernorm = nn.LayerNorm(encoder_embed_dim)

        # ---- decoder ----
        self.decoder_embedding = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_position_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim))
        self.decoder_num_layers = decoder_num_layers
        self.decoder_layers = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=False, drop=dropout, attn_drop=dropout)
            for _ in range(self.decoder_num_layers)
        ])
        self.decoder_layernorm = nn.LayerNorm(decoder_embed_dim)
        self.image_projection = nn.Linear(decoder_embed_dim, self.input_dim)
        self.init_weights()

    def patchify(self, images):
        """Split (N, 3, image_dim, image_dim) images into (N, num_patches, p*p*3) patches."""
        n = images.shape[0]
        grid = self.image_dim // self.patch_dim
        patches = images.reshape(n, 3, grid, self.patch_dim, grid, self.patch_dim)
        patches = torch.einsum("nchpwq -> nhwpqc", patches)
        return patches.reshape(n, self.num_patches, self.input_dim)

    def init_weights(self):
        """Initialise embeddings/tokens with N(0, 0.02) and Linear layers with Xavier-uniform."""
        w = self.patch_embedding.proj.weight.data
        # Treat the conv patch projection as a linear map for Xavier init.
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        torch.nn.init.normal_(self.encoder_position_embedding, std=.02)
        torch.nn.init.normal_(self.decoder_position_embedding, std=.02)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init: Xavier weights / zero bias for Linear, identity for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def encode(self, images, mask_ratio):
        """Randomly mask patches and encode the visible ones.

        Args:
            images: (N, 3, image_dim, image_dim) batch.
            mask_ratio: fraction of patches to drop per sample.

        Returns:
            latent: encoder output of shape (N, 1 + keep, encoder_embed_dim),
                class token first, where
                ``keep = int(num_patches * (1 - mask_ratio))``.
            mask: (N, num_patches) float mask in original patch order,
                1 = visible, 0 = masked.
            idx_unshuffle: (N, num_patches) indices that undo the per-sample
                shuffle; consumed by ``decode``.
        """
        tokens = self.patch_embedding(images) + self.encoder_position_embedding[:, 1:]
        n, _, d = tokens.shape

        # Per-sample random permutation; its first `keep` entries are the
        # patches that stay visible.
        noise = torch.rand(n, self.num_patches, device=images.device)
        idx_shuffle = torch.argsort(noise, dim=1)
        idx_unshuffle = torch.argsort(idx_shuffle, dim=1)
        keep = int(self.num_patches * (1 - mask_ratio))

        mask = torch.ones(n, self.num_patches, device=images.device)
        mask[:, keep:] = 0
        mask = torch.gather(mask, dim=1, index=idx_unshuffle)

        idx_keep = idx_shuffle[:, :keep]
        visible = torch.gather(tokens, dim=1, index=idx_keep.unsqueeze(-1).tile(1, 1, d))

        cls = (self.cls_token + self.encoder_position_embedding[:, :1]).expand(n, -1, -1)
        out = torch.cat([cls, visible], dim=1)
        for layer in self.encoder_layers:
            out = layer(out)
        out = self.encoder_layernorm(out)
        # Return the encoder output (not the raw patch embeddings): decode
        # expects a leading class token, and the encoder must be on the
        # gradient path.
        return out, mask, idx_unshuffle

    def decode(self, patches, idx_unshuffle):
        """Predict pixels of every patch from the encoder output.

        Args:
            patches: encoder output from ``encode`` (class token first).
            idx_unshuffle: un-shuffle indices from ``encode``.

        Returns:
            (N, num_patches, p*p*3) predicted (normalised) patch pixels.
        """
        tokens = self.decoder_embedding(patches)
        cls = tokens[:, :1]
        visible = tokens[:, 1:]
        n, num_visible, d = visible.shape
        # Append one mask token per dropped patch, then restore original order.
        fill = torch.tile(self.mask_token, (n, self.num_patches - num_visible, 1))
        full = torch.cat([visible, fill], dim=1)
        full = torch.gather(full, dim=1, index=idx_unshuffle.unsqueeze(-1).tile(1, 1, d))
        out = torch.cat([cls, full], dim=1) + self.decoder_position_embedding
        for layer in self.decoder_layers:
            out = layer(out)
        # Drop the class token before projecting to pixel space.
        out = self.decoder_layernorm(out)[:, 1:]
        return self.image_projection(out)

    def loss(self, images, pred_patches, mask):
        """Mean squared error on masked patches against per-patch normalised pixels.

        Args:
            images: original (N, 3, image_dim, image_dim) batch.
            pred_patches: (N, num_patches, p*p*3) predictions from ``decode``.
            mask: (N, num_patches) mask from ``encode`` (1 = visible).

        Returns:
            (N,) tensor; its sum is the mean per-patch squared error over all
            masked patches in the batch. It is all zeros when nothing is masked.
        """
        # Normalizing image pixels apparently leads to better representations
        patches = self.patchify(images)
        mean = patches.mean(dim=-1, keepdim=True)
        var = patches.var(dim=-1, keepdim=True)
        target = (patches - mean) / (var + 1.e-6) ** .5
        per_patch = ((target - pred_patches) ** 2).mean(dim=-1)
        removed = 1 - mask
        # clamp: with mask_ratio=0 (forward's default) no patch is removed and
        # the plain division would be 0 / 0 = NaN.
        return (per_patch * removed).sum(dim=1) / removed.sum().clamp(min=1)

    def recover_reconstructed(self, images, mask_ratio):
        """Return (masked_images, reconstructed) for visualisation.

        Both are (N, 3, image_dim, image_dim). ``masked_images`` is the input
        with masked patches zeroed. ``reconstructed`` is the decoder's
        prediction for every patch (visible ones included), de-normalised with
        each input patch's own mean and variance.
        """
        image_patches = self.patchify(images)
        latent, mask, idx_unshuffle = self.encode(images, mask_ratio)
        pred_patches = self.decode(latent, idx_unshuffle)
        n = images.shape[0]
        grid = self.image_dim // self.patch_dim
        mask = mask.unsqueeze(-1)

        masked_images = (mask * image_patches).reshape(n, grid, grid, self.patch_dim, self.patch_dim, 3)
        masked_images = torch.einsum("nhwpqc -> nchpwq", masked_images).reshape(n, 3, self.image_dim, self.image_dim)

        mean = image_patches.mean(dim=-1, keepdim=True)
        var = image_patches.var(dim=-1, keepdim=True)
        reconstructed = pred_patches * (var + 1.e-6) ** .5 + mean
        reconstructed = reconstructed.reshape(n, grid, grid, self.patch_dim, self.patch_dim, 3)
        reconstructed = torch.einsum("nhwpqc -> nchpwq", reconstructed).reshape(n, 3, self.image_dim, self.image_dim)
        return masked_images, reconstructed

    def forward(self, images, mask_ratio=0.0):
        """Encode, decode and return the masked-patch loss (see ``loss``)."""
        latent, mask, idx_unshuffle = self.encode(images, mask_ratio)
        pred_patches = self.decode(latent, idx_unshuffle)
        return self.loss(images, pred_patches, mask)

    def embeddings(self, images):
        """Return the encoder's class-token output (N, encoder_embed_dim) on unmasked images."""
        tokens = self.patch_embedding(images) + self.encoder_position_embedding[:, 1:]
        n = tokens.shape[0]
        cls = (self.cls_token + self.encoder_position_embedding[:, :1]).expand(n, -1, -1)
        out = torch.cat([cls, tokens], dim=1)
        for layer in self.encoder_layers:
            out = layer(out)
        out = self.encoder_layernorm(out)
        return out[:, 0]
class MAE_Classifier(nn.Module):
    """BatchNorm + linear classifier on top of an MAE's class-token embedding.

    Args:
        mae: module exposing ``embeddings(images)`` that returns a
            (N, embed_dim) tensor, e.g. ``MAE_Timm``.
        embed_dim: width of those embeddings.
        num_classes: number of output logits.
        fine_tune: if False (default) the MAE parameters are frozen and the
            MAE is kept in eval mode even when this module is switched to
            training mode, so its dropout stays off (linear probing). If True
            the MAE stays trainable and follows this module's train/eval mode.
    """

    def __init__(self, mae, embed_dim, num_classes, fine_tune=False):
        super().__init__()
        self.mae = mae
        self.fine_tune = fine_tune
        self.mae.eval()
        self.bn = nn.BatchNorm1d(embed_dim)
        self.linear = nn.Linear(embed_dim, num_classes)
        if not fine_tune:
            for parameter in self.mae.parameters():
                parameter.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen MAE in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        undo the ``self.mae.eval()`` call made in ``__init__``.
        """
        super().train(mode)
        if not self.fine_tune:
            self.mae.eval()
        return self

    def forward(self, images):
        """Return ``(latent, logits)``: batch-normalised embeddings and class logits."""
        mae_embeds = self.mae.embeddings(images)
        latent = self.bn(mae_embeds)
        return latent, self.linear(latent)
def get_vit_tiny(num_classes=10, patch_dim=4, image_dim=32):
    """Build a ViT-Tiny classifier: 12 layers, 3 heads, width 192, MLP 4x width, dropout 0.1."""
    depth, heads, width = 12, 3, 192
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=depth, num_heads=heads,
               embed_dim=width, encoder_mlp_hidden=4 * width, num_classes=num_classes, dropout=0.1)
def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
    """Build a ViT-Small classifier: 12 layers, 6 heads, width 384, MLP 4x width, dropout 0.1."""
    depth, heads, width = 12, 6, 384
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=depth, num_heads=heads,
               embed_dim=width, encoder_mlp_hidden=4 * width, num_classes=num_classes, dropout=0.1)
def get_vit_base(num_classes=10, patch_dim=4, image_dim=32):
    """Build a ViT-Base classifier: 12 layers, 12 heads, width 768, MLP 4x width, dropout 0.1."""
    depth, heads, width = 12, 12, 768
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=depth, num_heads=heads,
               embed_dim=width, encoder_mlp_hidden=4 * width, num_classes=num_classes, dropout=0.1)
def get_vit_large(num_classes=10, patch_dim=4, image_dim=32):
    """Build a ViT-Large classifier: 24 layers, 16 heads, width 1024, MLP 4x width, dropout 0.1."""
    depth, heads, width = 24, 16, 1024
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=depth, num_heads=heads,
               embed_dim=width, encoder_mlp_hidden=4 * width, num_classes=num_classes, dropout=0.1)
def get_vit_huge(num_classes=10, patch_dim=4, image_dim=32):
    """Build a ViT-Huge classifier: 32 layers, 16 heads, width 1280, MLP 4x width, dropout 0.1."""
    depth, heads, width = 32, 16, 1280
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=depth, num_heads=heads,
               embed_dim=width, encoder_mlp_hidden=4 * width, num_classes=num_classes, dropout=0.1)
def get_vit_tiny_timm(num_classes=10, patch_dim=4, image_dim=32):
    """Build timm's VisionTransformer at Tiny size (12 layers, 3 heads, width 192)."""
    regularisation = dict(drop_rate=0.1, attn_drop_rate=0.1)
    return VisionTransformer(
        img_size=image_dim,
        patch_size=patch_dim,
        num_classes=num_classes,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=False,
        **regularisation,
    )
def get_vit_small_timm(num_classes=10, patch_dim=4, image_dim=32):
    """Build timm's VisionTransformer at Small size (12 layers, 6 heads, width 384)."""
    regularisation = dict(drop_rate=0.1, attn_drop_rate=0.1)
    return VisionTransformer(
        img_size=image_dim,
        patch_size=patch_dim,
        num_classes=num_classes,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        **regularisation,
    )
def get_vit_base_timm(num_classes=10, patch_dim=4, image_dim=32):
    """Build timm's VisionTransformer at Base size (12 layers, 12 heads, width 768)."""
    regularisation = dict(drop_rate=0.1, attn_drop_rate=0.1)
    return VisionTransformer(
        img_size=image_dim,
        patch_size=patch_dim,
        num_classes=num_classes,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        **regularisation,
    )
def get_vit_large_timm(num_classes=10, patch_dim=4, image_dim=32):
    """Build timm's VisionTransformer at Large size (24 layers, 16 heads, width 1024)."""
    regularisation = dict(drop_rate=0.1, attn_drop_rate=0.1)
    return VisionTransformer(
        img_size=image_dim,
        patch_size=patch_dim,
        num_classes=num_classes,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        **regularisation,
    )
def get_vit_huge_timm(num_classes=10, patch_dim=4, image_dim=32):
    """Build timm's VisionTransformer at Huge size (32 layers, 16 heads, width 1280)."""
    regularisation = dict(drop_rate=0.1, attn_drop_rate=0.1)
    return VisionTransformer(
        img_size=image_dim,
        patch_size=patch_dim,
        num_classes=num_classes,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        **regularisation,
    )
def get_mae_tiny(patch_dim=2, image_dim=32):
    """Build a tiny MAE: 12-layer/3-head/192-wide encoder, 4-layer/4-head/192-wide decoder."""
    encoder = dict(encoder_num_layers=12, encoder_num_heads=3, encoder_embed_dim=192)
    decoder = dict(decoder_num_layers=4, decoder_num_heads=4, decoder_embed_dim=192)
    return MAE_Timm(patch_dim=patch_dim, image_dim=image_dim, mlp_ratio=4, dropout=0.1,
                    **encoder, **decoder)
def get_mae_small(patch_dim=2, image_dim=32):
    """Build a small MAE: 12-layer/6-head/384-wide encoder, 4-layer/8-head/256-wide decoder."""
    encoder = dict(encoder_num_layers=12, encoder_num_heads=6, encoder_embed_dim=384)
    decoder = dict(decoder_num_layers=4, decoder_num_heads=8, decoder_embed_dim=256)
    return MAE_Timm(patch_dim=patch_dim, image_dim=image_dim, mlp_ratio=4, dropout=0.1,
                    **encoder, **decoder)
def get_mae_base(patch_dim=2, image_dim=32):
    """Build a base MAE: 12-layer/12-head/768-wide encoder, 8-layer/16-head/512-wide decoder."""
    encoder = dict(encoder_num_layers=12, encoder_num_heads=12, encoder_embed_dim=768)
    decoder = dict(decoder_num_layers=8, decoder_num_heads=16, decoder_embed_dim=512)
    return MAE_Timm(patch_dim=patch_dim, image_dim=image_dim, mlp_ratio=4, dropout=0.1,
                    **encoder, **decoder)
def get_mae_large(patch_dim=2, image_dim=32):
    """Build a large MAE: 24-layer/16-head/1024-wide encoder, 8-layer/16-head/512-wide decoder."""
    encoder = dict(encoder_num_layers=24, encoder_num_heads=16, encoder_embed_dim=1024)
    decoder = dict(decoder_num_layers=8, decoder_num_heads=16, decoder_embed_dim=512)
    return MAE_Timm(patch_dim=patch_dim, image_dim=image_dim, mlp_ratio=4, dropout=0.1,
                    **encoder, **decoder)
def get_mae_huge(patch_dim=2, image_dim=32):
    """Build a huge MAE: 32-layer/16-head/1280-wide encoder, 8-layer/16-head/512-wide decoder."""
    encoder = dict(encoder_num_layers=32, encoder_num_heads=16, encoder_embed_dim=1280)
    decoder = dict(decoder_num_layers=8, decoder_num_heads=16, decoder_embed_dim=512)
    return MAE_Timm(patch_dim=patch_dim, image_dim=image_dim, mlp_ratio=4, dropout=0.1,
                    **encoder, **decoder)