-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathSwinTransformer.py
More file actions
486 lines (413 loc) · 20.7 KB
/
SwinTransformer.py
File metadata and controls
486 lines (413 loc) · 20.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
# Source: https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Conv2D, LayerNormalization, GlobalAveragePooling1D
# NOTE(review): the return value of this call is discarded, so the line appears
# to have no effect beyond querying whether eager execution is enabled.
tf.executing_eagerly()
# Named model presets, keyed by model name. Each entry provides the input image
# size, attention window size, base embedding width, and the per-stage block
# counts and head counts.
# NOTE(review): window_size is a single int here, while SwinTransformerBlock
# indexes window_size[0] / window_size[1]; a consumer has to expand it to a pair.
CFGS = {
    'swin_tiny_224': {
        'input_size': (224, 224),
        'window_size': 4,
        'embed_dim': 96,
        'depths': [2, 2, 6, 2],
        'num_heads': [3, 6, 12, 24],
    },
    'swin_small_224': {
        'input_size': (224, 224),
        'window_size': 4,
        'embed_dim': 96,
        'depths': [2, 2, 18, 2],
        'num_heads': [3, 6, 12, 24],
    },
    'swin_base_224': {
        'input_size': (224, 224),
        'window_size': 4,
        'embed_dim': 128,
        'depths': [2, 2, 18, 2],
        'num_heads': [4, 8, 16, 32],
    },
    'swin_base_384': {
        'input_size': (384, 384),
        'window_size': 8,
        'embed_dim': 128,
        'depths': [2, 2, 18, 2],
        'num_heads': [4, 8, 16, 32],
    },
    'swin_large_224': {
        'input_size': (224, 224),
        'window_size': 4,
        'embed_dim': 192,
        'depths': [2, 2, 18, 2],
        'num_heads': [6, 12, 24, 48],
    },
    'swin_large_384': {
        'input_size': (384, 384),
        'window_size': 8,
        'embed_dim': 192,
        'depths': [2, 2, 18, 2],
        'num_heads': [6, 12, 24, 48],
    },
}
class Mlp(tf.keras.layers.Layer):
    """Two-layer feed-forward network: Dense -> GELU -> Dropout -> Dense -> Dropout.

    Args:
        in_features: Fallback width used when hidden_features / out_features are falsy.
        hidden_features: Width of the first Dense layer.
        out_features: Width of the second Dense layer.
        drop: Dropout rate; the same Dropout layer is applied after both Dense layers.
        prefix: Name prefix for the two Dense layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., prefix=''):
        super().__init__()
        output_units = out_features if out_features else in_features
        hidden_units = hidden_features if hidden_features else in_features
        self.fc1 = Dense(hidden_units, name=f'{prefix}/mlp/fc1')
        self.fc2 = Dense(output_units, name=f'{prefix}/mlp/fc2')
        self.drop = Dropout(drop)

    def call(self, x):
        """Apply the feed-forward network to x and return the result."""
        hidden = tf.keras.activations.gelu(self.fc1(x))
        hidden = self.drop(hidden)
        projected = self.fc2(hidden)
        return self.drop(projected)
def window_partition(x, window_size):
    """Cut a [B, H, W, C] tensor into non-overlapping windows.

    Args:
        x: 4-D tensor whose H, W and C dimensions are statically known.
        window_size: Indexable pair (window height, window width).

    Returns:
        Tensor of shape [B * (H // wh) * (W // ww), wh, ww, C].
    """
    _, height, width, channels = x.get_shape().as_list()
    win_h, win_w = window_size[0], window_size[1]
    # Expose the window grid as separate axes, then bring the two grid axes
    # next to each other so that each window is contiguous.
    grid = tf.reshape(x, shape=[-1, height // win_h, win_h, width // win_w, win_w, channels])
    grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
    return tf.reshape(grid, shape=[-1, win_h, win_w, channels])
def window_reverse(windows, window_size, H, W, C):
    """Reassemble windows into a [B, H, W, C] tensor (inverse of window_partition).

    Args:
        windows: Tensor reshapeable to [-1, wh, ww, C], ordered as produced by
            window_partition (row-major over the window grid).
        window_size: Indexable pair (window height wh, window width ww).
        H: Height of the reconstructed feature map.
        W: Width of the reconstructed feature map.
        C: Channel count.

    Returns:
        Tensor of shape [B, H, W, C].
    """
    win_h, win_w = window_size[0], window_size[1]
    # The window grid has H // wh rows and W // ww columns, and each window is
    # wh x ww. The previous code divided W by the window height and used the
    # window width for both window axes, which only works for square windows.
    x = tf.reshape(windows, shape=[-1, H // win_h, W // win_w, win_h, win_w, C])
    # Interleave grid and in-window axes back into (rows, cols) order.
    x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
    x = tf.reshape(x, shape=[-1, H, W, C])
    return x
class WindowAttention(tf.keras.layers.Layer):
    """Multi-head self-attention inside a window, with a learned relative position bias.

    Args:
        dim: Channel count of the input tokens; split evenly across the heads.
        window_size: Indexable pair (window height, window width).
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection uses a bias.
        qk_scale: Query scaling factor; a falsy value selects head_dim ** -0.5.
        attn_drop: Dropout rate applied to the attention weights.
        proj_drop: Dropout rate applied after the output projection.
        prefix: Name prefix for the sub-layers and variables.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True,
                 qk_scale=None, attn_drop=0., proj_drop=0., prefix=''):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.prefix = prefix
        self.qkv = Dense(dim * 3, use_bias=qkv_bias,
                         name=f'{self.prefix}/attn/qkv')
        self.attn_drop = Dropout(attn_drop)
        self.proj = Dense(dim, name=f'{self.prefix}/attn/proj')
        self.proj_drop = Dropout(proj_drop)

    def build(self, input_shape):
        """Create the relative position bias table and its lookup index."""
        win_h, win_w = self.window_size[0], self.window_size[1]
        # One bias per head for every possible (dy, dx) offset inside a window.
        # The name is passed by keyword because the position of `name` in
        # add_weight's signature differs between Keras versions.
        self.relative_position_bias_table = self.add_weight(
            name=f'{self.prefix}/attn/relative_position_bias_table',
            shape=((2 * win_h - 1) * (2 * win_w - 1), self.num_heads),
            initializer=tf.initializers.Zeros(), trainable=True)

        # For every pair of positions in the window, compute the row index
        # into the bias table that corresponds to their relative offset.
        coords_h = np.arange(win_h)
        coords_w = np.arange(win_w)
        coords = np.stack(np.meshgrid(coords_h, coords_w, indexing='ij'))
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        # Shift offsets to start at 0, then flatten (dy, dx) into one index.
        relative_coords[:, :, 0] += win_h - 1
        relative_coords[:, :, 1] += win_w - 1
        relative_coords[:, :, 0] *= 2 * win_w - 1
        relative_position_index = relative_coords.sum(-1)
        self.relative_position_index = tf.Variable(
            initial_value=tf.convert_to_tensor(relative_position_index),
            trainable=False, name=f'{self.prefix}/attn/relative_position_index')
        self.built = True

    def call(self, x, mask=None):
        """Apply windowed attention.

        Args:
            x: Tensor of shape [num_windows * B, N, C] with N = window height * width
                and statically known N and C.
            mask: Optional tensor of shape [nW, N, N] added to the attention logits
                of the nW windows of each image, or None.

        Returns:
            Tensor of shape [num_windows * B, N, C].
        """
        _, N, C = x.get_shape().as_list()
        num_tokens = self.window_size[0] * self.window_size[1]

        # Fused projection, then split into q, k, v of shape [B_, heads, N, C // heads].
        qkv = tf.transpose(
            tf.reshape(self.qkv(x), shape=[-1, N, 3, self.num_heads, C // self.num_heads]),
            perm=[2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = q @ tf.transpose(k, perm=[0, 1, 3, 2])

        # Look up the per-head bias for every token pair: [heads, N, N].
        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            tf.reshape(self.relative_position_index, shape=[-1]))
        relative_position_bias = tf.reshape(
            relative_position_bias, shape=[num_tokens, num_tokens, -1])
        relative_position_bias = tf.transpose(relative_position_bias, perm=[2, 0, 1])
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.get_shape()[0]
            # Broadcast the per-window mask over batch and heads; cast to the
            # logits' dtype rather than assuming float32.
            window_mask = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), attn.dtype)
            attn = tf.reshape(attn, shape=[-1, nW, self.num_heads, N, N]) + window_mask
            attn = tf.reshape(attn, shape=[-1, self.num_heads, N, N])
        attn = tf.nn.softmax(attn, axis=-1)

        attn = self.attn_drop(attn)

        x = tf.transpose(attn @ v, perm=[0, 2, 1, 3])
        x = tf.reshape(x, shape=[-1, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def drop_path(inputs, drop_prob, is_training):
    """Randomly zero whole samples of a batch and rescale the survivors.

    Args:
        inputs: Tensor whose first axis indexes samples; its rank must be
            statically known.
        drop_prob: Probability of dropping a sample. None or 0 disables dropping.
        is_training: Falsy values (including None) disable dropping.

    Returns:
        `inputs` unchanged when dropping is disabled; otherwise a tensor of the
        same shape in which each sample is either zeroed or divided by
        (1 - drop_prob).
    """
    if (not is_training) or drop_prob is None or drop_prob == 0.:
        return inputs

    keep_prob = 1.0 - drop_prob

    # One random draw per sample, broadcast over all remaining axes. The batch
    # size is read with tf.shape because the static shape's first entry is
    # None when the batch dimension is not fixed.
    rank = len(inputs.shape)
    shape = tf.concat(
        [tf.shape(inputs)[:1], tf.ones([rank - 1], dtype=tf.int32)], axis=0)
    random_tensor = keep_prob + tf.random.uniform(shape, dtype=inputs.dtype)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    binary_tensor = tf.floor(random_tensor)
    output = tf.math.divide(inputs, keep_prob) * binary_tensor
    return output
class DropPath(tf.keras.layers.Layer):
    """Layer wrapper around drop_path.

    Args:
        drop_prob: Probability of dropping a sample; None means no dropping.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def call(self, x, training=None):
        """Apply drop_path to x using the stored probability and the training flag."""
        # The constructor default is None; forwarding it unchanged would make
        # drop_path compute 1.0 - None when training is enabled.
        rate = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, rate, training)
class SwinTransformerBlock(tf.keras.layers.Layer):
    """One Swin block: (shifted) window attention followed by an MLP, both residual.

    Args:
        dim: Channel count of the tokens.
        input_resolution: Pair (H, W); the input must contain H * W tokens.
        num_heads: Number of attention heads.
        window_size: Pair (window height, window width). It is copied, so the
            caller's sequence is never modified.
        shift_size: Cyclic shift applied along both spatial axes before
            partitioning; 0 disables shifting.
        mlp_ratio: Hidden width of the MLP as a multiple of dim.
        qkv_bias: Forwarded to WindowAttention.
        qk_scale: Forwarded to WindowAttention.
        drop: Dropout rate for the attention projection and the MLP.
        attn_drop: Dropout rate for the attention weights.
        drop_path_prob: Stochastic-depth rate for both residual branches.
        norm_layer: Normalization layer class, called with epsilon and name.
        prefix: Name prefix for the sub-layers and variables.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=[4, 5], shift_size=0, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path_prob=0., norm_layer=LayerNormalization, prefix=''):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        # Work on a private copy: the adjustment below used to assign into the
        # caller's list, which is shared between blocks and may be the default
        # argument object itself.
        window_size = list(window_size)
        if min(input_resolution) <= window_size[1]:
            # The feature map is not larger than the window: use one window
            # covering the smaller side and skip shifting.
            shift_size = 0
            window_size = [min(input_resolution), min(input_resolution)]
        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < min(self.window_size), "shift_size must in 0-window_size"
        self.prefix = prefix

        self.norm1 = norm_layer(epsilon=1e-5, name=f'{self.prefix}/norm1')
        self.attn = WindowAttention(dim, window_size=(self.window_size[0], self.window_size[1]),
                                    num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                    attn_drop=attn_drop, proj_drop=drop, prefix=self.prefix)
        self.drop_path = DropPath(drop_path_prob if drop_path_prob > 0. else 0.)
        self.norm2 = norm_layer(epsilon=1e-5, name=f'{self.prefix}/norm2')
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       drop=drop, prefix=self.prefix)

    def build(self, input_shape):
        """Precompute the attention mask used when the windows are shifted."""
        if self.shift_size > 0:
            H, W = self.input_resolution
            win_h, win_w = self.window_size[0], self.window_size[1]
            shift = self.shift_size
            img_mask = np.zeros([1, H, W, 1])
            # Label the regions that end up in the same window after the
            # cyclic shift but come from different parts of the image. Rows
            # are split by the window height and columns by the window width;
            # the previous code mixed the two, which mislabels rows whenever
            # the window is not square.
            h_slices = (slice(0, -win_h),
                        slice(-win_h, -shift),
                        slice(-shift, None))
            w_slices = (slice(0, -win_w),
                        slice(-win_w, -shift),
                        slice(-shift, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            img_mask = tf.convert_to_tensor(img_mask)
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size[0] * self.window_size[1]])
            # Token pairs with different region labels get a large negative
            # logit offset; pairs from the same region get 0.
            attn_mask = tf.expand_dims(
                mask_windows, axis=1) - tf.expand_dims(mask_windows, axis=2)
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(
                initial_value=attn_mask, trainable=False, name=f'{self.prefix}/attn_mask')
        else:
            self.attn_mask = None

        self.built = True

    def call(self, x):
        """Run the block on x of shape [B, H * W, C] and return the same shape."""
        H, W = self.input_resolution
        B, L, C = x.get_shape().as_list()
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=[-1, H, W, C])

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2])
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=[-1, self.window_size[0] * self.window_size[1], C])

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # merge windows
        attn_windows = tf.reshape(
            attn_windows, shape=[-1, self.window_size[0], self.window_size[1], C])
        shifted_x = window_reverse(attn_windows, self.window_size, H, W, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            x = tf.roll(shifted_x, shift=[
                self.shift_size, self.shift_size], axis=[1, 2])
        else:
            x = shifted_x
        x = tf.reshape(x, shape=[-1, H * W, C])

        # residual connections around attention and the feed-forward network
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
class PatchMerging(tf.keras.layers.Layer):
    """Downsample by merging each 2x2 group of tokens into one.

    The four tokens of every 2x2 neighbourhood are concatenated along the
    channel axis (4 * C), normalized, and projected to 2 * dim channels.

    Args:
        input_resolution: Pair (H, W) of the incoming token grid.
        dim: Base channel count; the output has 2 * dim channels.
        norm_layer: Normalization layer class, called with epsilon and name.
        prefix: Name prefix for the sub-layers.
    """

    def __init__(self, input_resolution, dim, norm_layer=LayerNormalization, prefix=''):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = Dense(2 * dim, use_bias=False, name=f'{prefix}/downsample/reduction')
        self.norm = norm_layer(epsilon=1e-5, name=f'{prefix}/downsample/norm')

    def call(self, x):
        """Map x of shape [B, H * W, C] to [B, (H // 2) * (W // 2), 2 * dim]."""
        H, W = self.input_resolution
        _, L, C = x.get_shape().as_list()
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = tf.reshape(x, shape=[-1, H, W, C])
        # Strided views of the four positions in each 2x2 cell, in the order
        # (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1).
        quadrants = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = tf.concat(quadrants, axis=-1)
        merged = tf.reshape(merged, shape=[-1, (H // 2) * (W // 2), 4 * C])

        return self.reduction(self.norm(merged))
class BasicLayer(tf.keras.layers.Layer):
    """One stage: `depth` SwinTransformerBlocks followed by an optional downsample.

    Even-indexed blocks use no shift; odd-indexed blocks shift by half of the
    smaller window side.

    Args:
        dim: Channel count of the tokens in this stage.
        input_resolution: Pair (H, W) of the token grid in this stage.
        depth: Number of blocks.
        num_heads: Attention heads per block.
        window_size: Pair (window height, window width) passed to every block.
        mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, norm_layer: Forwarded
            to every block.
        drop_path_prob: Either one rate for all blocks or a list with one rate
            per block.
        downsample: Optional layer class instantiated with
            (input_resolution, dim=, norm_layer=, prefix=).
        use_checkpoint: Stored on the layer; not otherwise used in this class.
        prefix: Name prefix for the blocks and the downsample layer.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path_prob=0., norm_layer=LayerNormalization, downsample=None, use_checkpoint=False, prefix=''):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        rate_per_block = isinstance(drop_path_prob, list)
        self.blocks = []
        for idx in range(depth):
            is_shifted = idx % 2 != 0
            self.blocks.append(SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=min(window_size) // 2 if is_shifted else 0,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path_prob=drop_path_prob[idx] if rate_per_block else drop_path_prob,
                norm_layer=norm_layer,
                prefix=f'{prefix}/blocks{idx}'))

        self.downsample = None
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer, prefix=prefix)

    def call(self, x):
        """Run all blocks, then downsample if configured.

        Returns:
            Tuple (output, feature). `feature` is the block output before
            downsampling, or None when the stage has no downsample layer.
        """
        for blk in self.blocks:
            x = blk(x)

        if self.downsample is None:
            return x, None

        feature = x
        return self.downsample(x), feature
class PatchEmbed(tf.keras.layers.Layer):
    """Split an image into non-overlapping patches and embed each one.

    Args:
        img_size: Pair (height, width) the input images must have.
        patch_size: Pair (patch height, patch width); used as both kernel size
            and stride of the projection convolution.
        in_chans: Stored on the layer; not otherwise used in this class.
        embed_dim: Number of output channels per patch.
        norm_layer: Optional normalization layer class applied to the
            embeddings, called with epsilon and name.
    """

    def __init__(self, img_size=(224, 224), patch_size=(4, 4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__(name='patch_embed')
        patches_resolution = [img_size[0] // patch_size[0],
                              img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = Conv2D(embed_dim, kernel_size=patch_size,
                           strides=patch_size, name='proj')
        if norm_layer is not None:
            self.norm = norm_layer(epsilon=1e-5, name='norm')
        else:
            self.norm = None

    def call(self, x):
        """Map images [B, H, W, C] to patch tokens [B, num_patches, embed_dim]."""
        B, H, W, C = x.get_shape().as_list()
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        # The number of patch columns depends on the patch *width*; the
        # previous code divided W by the patch height, which yields a wrong
        # token count for non-square patches.
        x = tf.reshape(
            x, shape=[-1, (H // self.patch_size[0]) * (W // self.patch_size[1]), self.embed_dim])
        if self.norm is not None:
            x = self.norm(x)
        return x
class SwinTransformerModel(tf.keras.Model):
    """Swin Transformer backbone that also exposes intermediate stage features.

    Args:
        model_name: Name given to the Keras model.
        include_top: Whether to add a Dense classification head.
        img_size: Pair (height, width) of the input images.
        patch_size: Pair (patch height, patch width).
        in_chans: Forwarded to PatchEmbed.
        num_classes: Output width of the classification head.
        embed_dim: Channel count of the first stage; doubled at every later stage.
        depths: Number of blocks per stage.
        num_heads: Number of attention heads per stage.
        window_size: Pair (window height, window width), or a single int for a
            square window.
        mlp_ratio, qkv_bias, qk_scale: Forwarded to every stage.
        drop_rate: Dropout rate after the embedding and inside the blocks.
        attn_drop_rate: Dropout rate for attention weights.
        drop_path_rate: Maximum stochastic-depth rate; block rates increase
            linearly from 0 to this value over all blocks.
        norm_layer: Normalization layer class, called with epsilon and name.
        ape: Whether to add a learned absolute position embedding.
        patch_norm: Whether PatchEmbed normalizes its output.
        use_checkpoint: Forwarded to every stage.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, model_name='swin_large_384', include_top=False,
                 img_size=(224, 224), patch_size=(4, 4), in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=[4, 5], mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=LayerNormalization, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__(name=model_name)

        self.include_top = include_top

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # The blocks index window_size[0] / window_size[1]; accept a plain int
        # as shorthand for a square window.
        if isinstance(window_size, int):
            window_size = [window_size, window_size]

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        self.basic_layers = []

        # absolute position embedding
        if self.ape:
            # name is passed by keyword because its position in add_weight's
            # signature differs between Keras versions.
            self.absolute_pos_embed = self.add_weight(
                name='absolute_pos_embed',
                shape=(1, num_patches, embed_dim),
                initializer=tf.initializers.Zeros())

        self.pos_drop = Dropout(drop_rate)

        # stochastic depth: one rate per block, increasing linearly
        dpr = list(np.linspace(0., drop_path_rate, sum(depths)))

        # build layers
        for i_layer in range(self.num_layers):
            block = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               # Each stage gets its own copy so that nothing
                               # downstream can alter the caller's (or the
                               # default) list.
                               window_size=list(window_size),
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path_prob=dpr[sum(depths[:i_layer]):sum(
                                   depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (
                                   i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               prefix=f'layers{i_layer}')
            self.basic_layers.append(block)

        self.norm = norm_layer(epsilon=1e-5, name='norm')
        self.avgpool = GlobalAveragePooling1D()
        self.features = []
        if self.include_top:
            self.head = Dense(num_classes, name='head')
        else:
            self.head = None

    def forward_features(self, x):
        """Run the backbone and return the pooled feature vector.

        As a side effect, `self.features` is replaced with the pre-downsample
        outputs of every stage except the last, for this call only.
        """
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        # Collect into a fresh list on every call; appending to a list created
        # once in __init__ made it grow with stale tensors on each forward pass.
        features = []
        for i, layer in enumerate(self.basic_layers):
            x, y = layer(x)
            if i < self.num_layers - 1:
                features.append(y)
        self.features = features

        x = self.norm(x)
        x = self.avgpool(x)
        return x

    def call(self, x):
        """Return (output, features).

        `output` is the pooled feature vector, passed through the head when
        include_top is set; `features` are the stage outputs collected by
        forward_features.
        """
        x = self.forward_features(x)
        if self.include_top:
            x = self.head(x)
        return x, self.features
def SwinTransformer(model_name='swin_large_384', num_classes=1000, include_top=True, pretrained=True, use_tpu=False, cfgs=CFGS):
    """Build a SwinTransformerModel from a named config and optionally load weights.

    Args:
        model_name: Key into `cfgs`; also used as the model and checkpoint name.
        num_classes: Output width of the classification head.
        include_top: Whether the model gets a classification head.
        pretrained: True downloads the released checkpoint for `model_name`;
            a path loads that checkpoint (a directory is expected to contain
            `<model_name>.ckpt`); a falsy value skips weight loading.
        use_tpu: Load the checkpoint through the local host's I/O device.
        cfgs: Mapping from model name to a config dict.

    Returns:
        Tuple (net, features): the built model and the intermediate feature
        tensors returned by calling it once on a symbolic input.
    """
    cfg = cfgs[model_name]

    # The model indexes window_size[0] / window_size[1]; configs may store a
    # single int for square windows, so expand it to a pair here.
    window_size = cfg['window_size']
    if isinstance(window_size, int):
        window_size = [window_size, window_size]
    else:
        window_size = list(window_size)

    # The constructor returns one model object; it cannot be unpacked into
    # (net, features) as the previous code attempted.
    net = SwinTransformerModel(
        model_name=model_name, include_top=include_top, num_classes=num_classes, img_size=cfg['input_size'],
        window_size=window_size, embed_dim=cfg['embed_dim'], depths=cfg['depths'], num_heads=cfg['num_heads']
    )
    # Calling the model once creates its variables so weights can be loaded;
    # the call returns (output, features).
    _, features = net(tf.keras.Input(shape=(cfg['input_size'][0], cfg['input_size'][1], 3)))

    if pretrained is True:
        url = f'https://github.com/rishigami/Swin-Transformer-TF/releases/download/v0.1-tf-swin-weights/{model_name}.tgz'
        pretrained_ckpt = tf.keras.utils.get_file(
            model_name, url, untar=True)
    else:
        pretrained_ckpt = pretrained

    if pretrained_ckpt:
        if tf.io.gfile.isdir(pretrained_ckpt):
            pretrained_ckpt = f'{pretrained_ckpt}/{model_name}.ckpt'

        if use_tpu:
            load_locally = tf.saved_model.LoadOptions(
                experimental_io_device='/job:localhost')
            net.load_weights(pretrained_ckpt, options=load_locally)
        else:
            net.load_weights(pretrained_ckpt)

    return net, features