-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMlp.py
More file actions
53 lines (43 loc) · 1.39 KB
/
Mlp.py
File metadata and controls
53 lines (43 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from keras import layers
import tensorflow as tf
from tensorflow import keras
import numpy as np
class Mlp:
    """Fully connected (MLP) classifier for image-shaped inputs.

    Build the network with :meth:`model_create`; the resulting ``keras.Model``
    is stored on ``self.model`` and its input tensor on ``self.inputs``.
    """

    def __init__(self):
        # This was previously misspelled ``__int__`` (the int-conversion hook),
        # so none of these attributes existed on a freshly constructed instance.
        self.dims = None  # input shape given to keras.Input (set by model_create)
        self.num_classes = None  # number of target classes (set by model_create)
        self.model = None  # built keras.Model (set by model_create)
        self.inputs = None  # functional-API input tensor (set by model_create)

    def model_create(
        self,
        dims,
        num_classes,
        *,
        hidden_sizes=(128, 256, 512, 728, 1024),
        dropout_rate=0.5,
    ):
        """Build the MLP, store it on ``self.model``, print its summary and return it.

        Architecture: random horizontal flip + random rotation augmentation,
        rescaling by 1/255, flatten, then one ``Dense(relu)`` +
        ``BatchNormalization`` pair per entry of ``hidden_sizes``, dropout, and
        a final classification layer.

        Args:
            dims: Input shape passed to ``keras.Input(shape=...)``, excluding
                the batch dimension. Presumably an image shape such as
                ``(height, width, channels)`` -- TODO confirm with callers.
            num_classes: Number of target classes. ``2`` yields a single
                sigmoid output unit; any other value yields ``num_classes``
                softmax units.
            hidden_sizes: Widths of the hidden Dense layers, in order.
                Defaults to the original hard-coded sizes.
            dropout_rate: Dropout rate applied before the output layer.
                Defaults to the original value of 0.5.

        Returns:
            The built ``keras.Model`` (also stored on ``self.model``).
        """
        self.dims = dims
        self.num_classes = num_classes
        self.inputs = keras.Input(shape=self.dims)

        # Image augmentation block. RandomRotation's factor is a fraction of
        # 2*pi, so 0.1 allows rotations of up to +/-10% of a full turn.
        data_augmentation = keras.Sequential(
            [
                layers.RandomFlip("horizontal"),
                layers.RandomRotation(0.1),
            ]
        )
        x = data_augmentation(self.inputs)

        # Divide by 255; assumes raw pixel values in [0, 255] -- TODO confirm.
        x = layers.Rescaling(1.0 / 255)(x)

        # Flatten each sample to a vector of length prod(dims). This replaces
        # Reshape([-1, prod(dims)]) + keras.backend.squeeze(axis=1), which gave
        # the same result but breaks on Keras 3 (no keras.backend.squeeze).
        x = layers.Flatten()(x)

        for size in hidden_sizes:
            x = layers.Dense(size, activation="relu")(x)
            x = layers.BatchNormalization()(x)

        if self.num_classes == 2:
            # Binary problem: a single sigmoid unit.
            activation, units = "sigmoid", 1
        else:
            activation, units = "softmax", self.num_classes

        x = layers.Dropout(dropout_rate)(x)
        outputs = layers.Dense(units, activation=activation)(x)
        self.model = keras.Model(self.inputs, outputs)
        self.model.summary()
        return self.model