-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
99 lines (71 loc) · 2.44 KB
/
model.py
File metadata and controls
99 lines (71 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.layers import Activation, BatchNormalization, Concatenate
class ConvBlock(tf.keras.layers.Layer):
    """Two stacked Conv2D(3x3) -> BatchNormalization -> ReLU stages.

    Both convolutions use ``padding='same'`` with the default stride, so the
    spatial size of the input is preserved and the output has ``n_filters``
    channels.

    Args:
        n_filters: Number of filters in each of the two convolutions.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``,
            ``trainable``) forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, n_filters, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can reproduce the constructor arguments.
        self.n_filters = n_filters
        self.conv1 = Conv2D(n_filters, 3, padding='same')
        self.conv2 = Conv2D(n_filters, 3, padding='same')
        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()
        # A ReLU activation has no weights, so one instance is reused after
        # both batch-norm stages.
        self.activation = Activation('relu')

    def call(self, inputs):
        """Apply conv -> BN -> ReLU twice and return the result."""
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)
        return x

    def get_config(self):
        """Return the layer config, including ``n_filters``.

        Some tf.keras versions refuse to serialize a layer whose ``__init__``
        takes extra arguments unless ``get_config`` is overridden.
        """
        config = super().get_config()
        config.update({'n_filters': self.n_filters})
        return config
class EncoderBlock(tf.keras.layers.Layer):
    """U-Net encoder stage: a ConvBlock followed by 2x2 max pooling.

    ``call`` returns ``(features, pooled)``. ``features`` is the ConvBlock
    output, which the U-Net passes to a decoder stage as its skip tensor.
    ``pooled`` is that output after ``MaxPooling2D((2, 2))``, which feeds
    the next, deeper stage.

    Args:
        n_filters: Number of filters used by the inner ConvBlock.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``,
            ``trainable``) forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, n_filters, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can reproduce the constructor arguments.
        self.n_filters = n_filters
        self.conv_blk = ConvBlock(n_filters)
        self.pool = MaxPooling2D((2, 2))

    def call(self, inputs):
        """Return ``(features, pooled)`` for ``inputs``."""
        x = self.conv_blk(inputs)
        p = self.pool(x)
        return x, p

    def get_config(self):
        """Return the layer config, including ``n_filters``.

        Some tf.keras versions refuse to serialize a layer whose ``__init__``
        takes extra arguments unless ``get_config`` is overridden.
        """
        config = super().get_config()
        config.update({'n_filters': self.n_filters})
        return config
class DecoderBlock(tf.keras.layers.Layer):
    """U-Net decoder stage: upsample, concatenate the skip tensor, ConvBlock.

    The input is upsampled 2x by a stride-2 ``Conv2DTranspose`` and then
    concatenated with ``skip`` along the last axis (the ``Concatenate``
    default). The result is refined by a ConvBlock.

    Args:
        n_filters: Filters for both the transposed convolution and the
            ConvBlock.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``,
            ``trainable``) forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, n_filters, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can reproduce the constructor arguments.
        self.n_filters = n_filters
        self.up = Conv2DTranspose(n_filters, (2, 2), strides=2, padding='same')
        # Build the concatenation layer once here. Keras sublayers belong in
        # __init__; constructing one inside call() creates a fresh, untracked
        # layer on every invocation/trace.
        self.concat = Concatenate()
        self.conv_blk = ConvBlock(n_filters)

    def call(self, inputs, skip):
        """Upsample ``inputs``, join it with ``skip``, and refine.

        Args:
            inputs: Tensor from the previous (deeper) stage.
            skip: Encoder features. Their spatial size must equal that of the
                upsampled ``inputs``, or the concatenation fails.
        """
        x = self.up(inputs)
        x = self.concat([x, skip])
        return self.conv_blk(x)

    def get_config(self):
        """Return the layer config, including ``n_filters``.

        Some tf.keras versions refuse to serialize a layer whose ``__init__``
        takes extra arguments unless ``get_config`` is overridden.
        """
        config = super().get_config()
        config.update({'n_filters': self.n_filters})
        return config
class UNET(tf.keras.Model):
    """Four-level U-Net encoder/decoder with skip connections.

    Four EncoderBlocks each halve the spatial size with 2x2 max pooling. A
    ConvBlock bridge follows. Four DecoderBlocks each double the size again
    and concatenate the matching encoder features. A final 1x1 Conv2D
    produces ``n_classes`` output channels per pixel.

    Because of the four pool/upsample pairs, the input height and width must
    be divisible by 16; otherwise the decoder's skip concatenation fails on a
    spatial-size mismatch.

    Args:
        n_classes: Number of output channels. With 1 the head uses a sigmoid;
            otherwise it uses a softmax over the last axis.
        filters: Filter counts of the four encoder stages, shallowest first.
            The decoder mirrors them in reverse. Defaults to the original
            ``(16, 32, 48, 64)``.
        bridge_filters: Filter count of the bottleneck ConvBlock
            (default 128, as originally hard-coded).
        **kwargs: Standard ``tf.keras.Model`` arguments such as ``name``.

    Raises:
        ValueError: If ``n_classes`` < 1 or ``filters`` does not have exactly
            four entries.
    """

    def __init__(self, n_classes, *, filters=(16, 32, 48, 64),
                 bridge_filters=128, **kwargs):
        filters = tuple(filters)
        if len(filters) != 4:
            raise ValueError(
                f"filters must have exactly 4 entries, got {len(filters)}")
        if n_classes < 1:
            raise ValueError(f"n_classes must be >= 1, got {n_classes}")
        super().__init__(**kwargs)
        # Stored so get_config() can reproduce the constructor arguments.
        self.n_classes = n_classes
        self.filters = filters
        self.bridge_filters = bridge_filters

        # Encoder (attribute names and creation order are unchanged so that
        # existing saved weights still map onto the same sublayers).
        self.e1 = EncoderBlock(filters[0])
        self.e2 = EncoderBlock(filters[1])
        self.e3 = EncoderBlock(filters[2])
        self.e4 = EncoderBlock(filters[3])
        # Bridge
        self.b = ConvBlock(bridge_filters)
        # Decoder: mirror of the encoder widths, deepest first.
        self.d1 = DecoderBlock(filters[3])
        self.d2 = DecoderBlock(filters[2])
        self.d3 = DecoderBlock(filters[1])
        self.d4 = DecoderBlock(filters[0])
        # Output head: sigmoid for a single channel, softmax otherwise.
        activation = 'sigmoid' if n_classes == 1 else 'softmax'
        # NOTE(review): `outputs` may shadow the `outputs` attribute Keras
        # manages on Model. The name is kept because object-based checkpoints
        # key sublayers by attribute name; consider renaming in a breaking
        # release.
        self.outputs = Conv2D(n_classes, 1, padding='same',
                              activation=activation)

    def call(self, inputs):
        """Run the encoder, bridge, decoder, and output head on ``inputs``."""
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        b = self.b(p4)
        # Each decoder stage consumes the encoder features at its resolution.
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        outputs = self.outputs(d4)
        return outputs

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model.

        Built by hand rather than via ``super().get_config()``, because some
        tf.keras versions raise ``NotImplementedError`` from
        ``Model.get_config`` for subclassed models.
        """
        return {
            'name': self.name,
            'trainable': self.trainable,
            'n_classes': self.n_classes,
            'filters': list(self.filters),
            'bridge_filters': self.bridge_filters,
        }