-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpreprocessing.py
More file actions
181 lines (153 loc) · 7.01 KB
/
preprocessing.py
File metadata and controls
181 lines (153 loc) · 7.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import random
import cv2
import numpy as np
import skimage.io as io
from pycocotools.coco import COCO
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def get_class_name(class_id, cats):
    """Return the category name matching ``class_id``.

    Args:
        class_id: Integer COCO category id to look up.
        cats: List of COCO category dicts, each with 'id' and 'name' keys.

    Returns:
        The matching category's name, or None when the id is not found.
    """
    # Iterate the categories directly instead of indexing with range(len(...)).
    for cat in cats:
        if cat['id'] == class_id:
            return cat['name']
    return None
def filter_dataset(folder, classes=None, mode='train'):
    """Load and filter COCO images for the requested classes.

    Args:
        folder: Dataset root; annotations are expected at
            ``{folder}/annotations/instances_{mode}.json``.
        classes: Optional list of class names. When None, all images are used.
        mode: Dataset split name, e.g. 'train' or 'val'.

    Returns:
        Tuple of (shuffled unique image dicts, dataset size, COCO api object).
    """
    # Initialize COCO api for instance annotations.
    ann_file = '{}/annotations/instances_{}.json'.format(folder, mode)
    coco = COCO(ann_file)

    images = []
    if classes is not None:
        # Collect images for each requested class; an image containing
        # several of the classes appears once per class here.
        for class_name in classes:
            cat_ids = coco.getCatIds(catNms=class_name)
            img_ids = coco.getImgIds(catIds=cat_ids)
            images += coco.loadImgs(img_ids)
    else:
        img_ids = coco.getImgIds()
        images = coco.loadImgs(img_ids)

    # Deduplicate in O(n) by image id instead of the original O(n^2)
    # `not in list` scan (loadImgs yields one dict per id, so keying on
    # 'id' is equivalent). Dict preserves insertion order.
    unique_images = list({img['id']: img for img in images}.values())
    random.shuffle(unique_images)
    return unique_images, len(unique_images), coco
def segment_to_polygon(segmentation):
    """Flatten a COCO polygon segmentation into (x, y) tuples.

    Args:
        segmentation: List of flat coordinate lists, each shaped
            [x0, y0, x1, y1, ...]; all parts are concatenated.

    Returns:
        List of (x, y) coordinate tuples in input order.
    """
    return [
        (x, y)
        for part in segmentation
        for x, y in zip(part[::2], part[1::2])
    ]
def center_crop(img, polygon, input_image_size):
    """Crop ``img`` around the centroid of ``polygon``.

    The crop window is ``input_image_size`` wide/tall, centered on the
    polygon's vertex centroid, and clamped at the top/left image edges
    (numpy slicing already clamps the bottom/right).

    Args:
        img: Image array indexed as [row, col, ...].
        polygon: Non-empty sequence of (x, y) vertex tuples.
        input_image_size: (width, height) of the desired crop.

    Returns:
        The cropped view of ``img``.
    """
    # Vertex centroid of the polygon.
    n = len(polygon)
    cx = int(sum(px for px, _ in polygon) / n)
    cy = int(sum(py for _, py in polygon) / n)

    half_w = input_image_size[0] / 2
    half_h = input_image_size[1] / 2

    # Clamp only the low edges; slicing handles overruns on the high edges.
    left = max(int(cx - half_w), 0)
    right = int(cx + half_w)
    top = max(int(cy - half_h), 0)
    bottom = int(cy + half_h)

    return img[top:bottom, left:right]
def get_image(image_obj, img_folder, input_image_size, polygon):
    """Load, center-crop, resize, and normalize a training image.

    Args:
        image_obj: COCO image dict containing 'file_name'.
        img_folder: Directory holding the image files.
        input_image_size: (width, height) target size for cv2.resize.
        polygon: (x, y) vertices used to center the crop.

    Returns:
        Float image in [0, 1] with exactly 3 channels.
    """
    # Read and scale pixel values into [0, 1].
    train_img = io.imread(img_folder + '/' + image_obj['file_name']) / 255.0

    # Crop around the polygon centroid, then resize; fall back to a plain
    # resize when the crop fails (e.g. degenerate crop window).
    try:
        cropped = center_crop(train_img, polygon, input_image_size)
        train_img = cv2.resize(cropped, input_image_size)
    except (AssertionError, TypeError):
        train_img = cv2.resize(train_img, input_image_size)

    if train_img.ndim == 3 and train_img.shape[2] == 3:
        # Already a 3-channel RGB image.
        return train_img
    # Grayscale: replicate the single channel into 3 channels.
    return np.stack((train_img,) * 3, axis=-1)
def get_binary_masks(image_obj, coco, cat_ids, input_image_size):
    """Build per-annotation binary masks for one image.

    Args:
        image_obj: COCO image dict containing 'id'.
        coco: Initialized COCO api object.
        cat_ids: Category ids used to filter annotations.
        input_image_size: (width, height) target size for cv2.resize.

    Returns:
        Tuple (masks, labels, polygons), one entry per annotation:
        resized binary mask, category name, and (x, y) polygon vertices.
    """
    masks, labels, polygons = [], [], []

    ann_ids = coco.getAnnIds(image_obj['id'], catIds=cat_ids, iscrowd=None)
    for ann_id in ann_ids:
        ann = coco.loadAnns(ann_id)[0]
        mask = coco.annToMask(ann)
        polygon = segment_to_polygon(ann['segmentation'])

        # Crop the mask the same way the image is cropped so the two stay
        # aligned; fall back to a plain resize if the crop fails.
        try:
            resized_mask = cv2.resize(
                center_crop(mask, polygon, input_image_size),
                input_image_size,
            )
        except (AssertionError, TypeError):
            resized_mask = cv2.resize(mask, input_image_size)

        masks.append(resized_mask)
        labels.append(coco.loadCats(ann['category_id'])[0]['name'])
        polygons.append(polygon)

    return masks, labels, polygons
def augment_data(gen, aug_generator_args, seed=None):
    """Apply identical keras augmentations to paired image/masked-image batches.

    Both tensors of a pair are pushed through the SAME ImageDataGenerator
    with the same per-batch seed, so they receive identical transforms
    (spatial and photometric alike) and stay aligned.

    Args:
        gen: Generator yielding (img, masked_img, label) batches with
            pixel values in [0, 1].
        aug_generator_args: Keyword arguments for ImageDataGenerator.
        seed: Optional seed for reproducible augmentation streams.

    Yields:
        (augmented img, augmented masked_img, label) tuples in [0, 1].
    """
    # Initialize the image data generator with the args provided.
    # NOTE: the original code also built a brightness-free copy of the args
    # for the "mask" stream but never used it — that dead code is removed
    # here. The second tensor is a masked IMAGE (image * mask), and sharing
    # one generator + seed guarantees both streams draw identical
    # transform parameters.
    image_gen = ImageDataGenerator(**aug_generator_args)

    np.random.seed(seed if seed is not None else np.random.choice(range(9999)))
    for img, mask, label in gen:
        batch_seed = np.random.choice(range(9999))
        # Keep the seeds synchronized, otherwise the augmentation of the
        # images would differ from the augmentation of the masked images.
        g_x = image_gen.flow(255 * img,
                             batch_size=img.shape[0],
                             seed=batch_seed,
                             shuffle=False)
        g_x_masked = image_gen.flow(255 * mask,
                                    batch_size=img.shape[0],
                                    seed=batch_seed,
                                    shuffle=False)
        yield next(g_x) / 255.0, next(g_x_masked) / 255.0, label
def data_generator(images, classes, coco, folder, input_image_size=(224, 224, 3),
                   batch_size=4, mode='train'):
    """Yield endless (img, img_masked, label) training batches.

    Walks ``images`` sequentially in batches; when the next batch would run
    past the dataset it resets to the start and reshuffles ``images``
    in place.

    Args:
        images: List of COCO image dicts (mutated by random.shuffle).
        classes: Class names used to look up category ids.
        coco: Initialized COCO api object.
        folder: Dataset root; images live in ``{folder}/images/{mode}``.
        input_image_size: (width, height[, channels]); only the first two
            entries are used, channel count is fixed at 3.
        batch_size: Number of images per yielded batch.
        mode: Split name used to build the image folder path.

    Yields:
        (img, img_masked, label): two float arrays of shape
        (batch_size, H, W, 3) and a list of ``batch_size`` label names
        (entries stay None for images with no matching annotation).
    """
    img_folder = '{}/images/{}'.format(folder, mode)
    dataset_size = len(images)
    cat_ids = coco.getCatIds(catNms=classes)
    # Keep only the spatial dims; channels are hard-coded to 3 below.
    input_image_size = input_image_size[:2]
    c = 0
    while True:
        img = np.zeros((batch_size, input_image_size[0], input_image_size[1], 3)).astype('float')
        img_masked = np.zeros((batch_size, input_image_size[0], input_image_size[1], 3)).astype('float')
        label = [None]*batch_size
        for i in range(c, c + batch_size):  # initially from 0 to batch_size, when c = 0
            image_obj = images[i]
            # Retrieve and mask image
            masks, labels, polygons = get_binary_masks(image_obj, coco, cat_ids, input_image_size)
            for train_mask, train_label, polygon in zip(masks, labels, polygons):
                # Broadcast the (H, W) mask across the 3 color channels.
                train_mask = train_mask[:, :, np.newaxis]
                train_img = get_image(image_obj, img_folder, input_image_size, polygon)
                train_img_masked = train_img * train_mask
                # Add to respective batch sized arrays.
                # NOTE(review): every annotation of the same image writes the
                # same slot i - c, so only the image's LAST annotation ends up
                # in the batch — confirm this is intended.
                img[i - c] = train_img
                img_masked[i - c] = train_img_masked
                label[i - c] = train_label
        c += batch_size
        # Reset and reshuffle once the NEXT batch would overrun the dataset;
        # note a final partial tail of images is skipped each epoch.
        if c + batch_size >= dataset_size:
            c = 0
            random.shuffle(images)
        yield img, img_masked, label
def dataloader(classes, data_dir, input_image_size, batch_size, mode):
    """Build the complete augmented data pipeline for one dataset split.

    Args:
        classes: Class names to keep, or None for all classes.
        data_dir: Dataset root directory.
        input_image_size: Target image size passed to the generator.
        batch_size: Number of samples per batch.
        mode: Split name, e.g. 'train' or 'val'.

    Returns:
        Generator of augmented (img, img_masked, label) batches.
    """
    images, _, coco = filter_dataset(data_dir, classes, mode)
    raw_batches = data_generator(images, classes, coco, data_dir,
                                 input_image_size, batch_size, mode)

    # Mild spatial jitter plus a small brightness range; geometry-preserving
    # enough that masks stay usable after augmentation.
    augmentation_config = {
        'featurewise_center': False,
        'samplewise_center': False,
        'rotation_range': 5,
        'width_shift_range': 0.01,
        'height_shift_range': 0.01,
        'brightness_range': (0.9, 1.1),
        'shear_range': 0.01,
        'zoom_range': [1, 1.25],
        'horizontal_flip': True,
        'vertical_flip': False,
        'fill_mode': 'reflect',
        'data_format': 'channels_last',
    }
    return augment_data(raw_batches, augmentation_config)