-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathautomated_multi_class_annotation.py
More file actions
292 lines (273 loc) · 12.7 KB
/
automated_multi_class_annotation.py
File metadata and controls
292 lines (273 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import cv2
import dlib
import argparse
import os
import sys
import numpy as np
from dialogue_box import *
# ---- command line arguments ----
# Flag names, defaults and required-ness are the script's CLI interface and
# are kept unchanged.
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_video", help="The relative path of the video", required=True)
parser.add_argument("-s", "--save_path", help="The path to save the images and annotations", required=True)
parser.add_argument("-n", "--save_every", type=int, default=10,
                    help="Only every n-th frame is saved as an annotation", required=False)
parser.add_argument("-o", "--start_number", type=int, default=0,
                    help="Starting number of the annotations naming sequence", required=False)
# NOTE: this argument is required, so the old default="object" could never be
# used by argparse; the dead default has been removed.
parser.add_argument("-c", "--classes", required=True, help="the class names separated by ','")
parser.add_argument("-w", "--width", type=int, required=False, default=800, help="the width of the window")
parser.add_argument("--small_object", type=float, default=0.1,
                    help="threshold for small object, the lower the value, the smaller the object")
parser.add_argument("--frame_delay", type=int, default=1,
                    help="delay between two consecutive frames in ms")
parser.add_argument("--start_frame", default=0, type=int, required=False, help="starting frame in the video")
parser.add_argument("--skip_frames", default=1, type=int, required=False, help="number of frames to skip")
args = parser.parse_args()
print("Note: Please try to set the number of object trackers as per your system configuration")

# ---- global state shared between the mouse callback and the main loop ----
frame = None                    # current video frame (BGR, resized for display)
dragging = False                # True while the left mouse button is held down
temp_start_point = []           # [x, y] where the current drag started
tracking = False                # True while dlib trackers are driving the boxes
window_name = "Automated Labelling"

# current rectangles: class name -> list of [x1, y1, x2, y2] boxes
all_bounding_boxes = {}
all_current_position = {}

# class name -> list of dlib correlation tracker objects (one per box)
idx_trackers = {}

# drawing colors (BGR) used by the UI
color = {"track": (0, 255, 0),
         "text": (0, 0, 0),
         "line": (150, 255, 180)}

# parse and save the command line arguments
save_counter = args.start_number
save_every = args.save_every
input_vid = args.input_video
save_path = args.save_path
opencv_window_width = args.width
small_thresh = args.small_object
time_delay = args.frame_delay
start_pos = args.start_frame
skip_frames = args.skip_frames

# skip_frames must evenly divide both save_every and start_number; otherwise
# the frame counter can step over the save points and frames are never saved.
# (The original message said "factorial", which is the wrong term for this
# divisibility requirement.)
if save_every % skip_frames != 0 or save_counter % skip_frames != 0:
    raise AssertionError(
        "\n1. please make sure skip_frames evenly divides save_every_frame(n) to avoid errors while saving images\n" +
        "for example: n=10 and skip_frames=2 will work but n=10 and skip_frames=3 will not work very well\n"
        "2. Please make sure skip_frames evenly divides start_number(o)\n")
# ---- classes and output directory setup ----
# parse the class names from the command line ("a, b,c" -> ["a", "b", "c"]);
# empty entries (e.g. "--classes ','" or a trailing comma) are dropped so the
# emptiness check below can actually fire (str.split always returns >= 1
# element, so the original len(classes) == 0 check was unreachable)
classes = [w.strip() for w in args.classes.split(',') if w.strip()]
# check that at least one class is present
if len(classes) == 0:
    raise AssertionError("The classes should not be empty!")
# register the classes with the dialog-box module (used for the dropdown)
set_classes(classes)
# refuse to overwrite an existing annotation session
if os.path.exists(save_path):
    raise AssertionError("The save path already exists please enter a new path")
else:
    # make sure the path does not include a / at the end
    if save_path.endswith("/"):
        save_path = save_path[:-1]
    # create the required directories
    os.mkdir(save_path)
    os.mkdir(os.path.join(save_path, "images"))
    os.mkdir(os.path.join(save_path, "annotations"))
# class name -> YOLO class id, in command-line order
class_idx = {k: i for i, k in enumerate(classes)}
# save the classes in classes.txt so the dataset is self-describing
with open(os.path.join(save_path, "annotations", "classes.txt"), 'w') as f:
    f.write("\n".join(classes))
# this is used to prevent single click being detected as an object
def area_of(points, shape, thresh=None):
    """Return True if the box covers more than ``thresh`` percent of the frame.

    Used to reject accidental single clicks that would otherwise create a
    tiny (or zero-area) bounding box.

    Args:
        points: [x1, y1, x2, y2] with x1 <= x2 and y1 <= y2.
        shape: the frame's (height, width, channels) tuple.
        thresh: minimum box area as a percentage of the frame area; defaults
            to the module-level ``small_thresh`` parsed from the command line
            (backward compatible with the original global-only behavior).

    Returns:
        bool: True when the box area exceeds the threshold.
    """
    if thresh is None:
        # fall back to the command-line value; no `global` needed for a read
        thresh = small_thresh
    h, w, _ = shape
    x1, y1, x2, y2 = points
    return (y2 - y1) * (x2 - x1) > (w * h) * thresh / 100
def get_points_order(pts):
    """Normalize [x1, y1, x2, y2] so the first point is the top-left corner.

    The user may drag a rectangle in any direction; this returns the same
    coordinates reordered as [min_x, min_y, max_x, max_y].
    """
    xa, ya, xb, yb = pts
    left, right = sorted((xa, xb))
    top, bottom = sorted((ya, yb))
    return [left, top, right, bottom]
# the call back function of cv2 window
def draw_annotation(event, x, y, flags, params):
    """OpenCV mouse callback: draw, create and delete bounding boxes.

    Left-drag draws a new box (its class is chosen from a dialog-box dropdown
    on release); a right-button double-click deletes the smallest box under
    the cursor.  Reads/writes the module-level annotation state
    (``all_bounding_boxes``, ``dragging``, ``temp_start_point``, ...).

    NOTE(review): this file was recovered from an indentation-stripped copy;
    the nesting below is reconstructed from the code's semantics — verify
    against the upstream source.
    """
    global dragging, temp_start_point, frame, tracking, save_counter, classes
    # this temporary frame is used for drawing rectangles so that the main
    # frame is not affected
    temp_frame = frame.copy()
    # draw all the bounding boxes and corresponding class labels
    for key, val in all_bounding_boxes.items():
        for v in val:
            x1, y1, x2, y2 = v
            temp_box = temp_frame[y1:y2, x1:x2]
            cv2.rectangle(temp_frame, (x1, y1), (x2, y2), (0, 255, 0))
            cv2.putText(temp_frame, key, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color["text"], thickness=2)
            # highlight the box under the cursor by tinting its blue channel
            if x1 < x < x2 and y1 < y < y2:
                temp_frame[y1:y2, x1:x2, 0] = np.array([[190] * temp_box.shape[1]] * temp_box.shape[0])
    # while annotating (not tracking), show a full-frame crosshair at the
    # cursor and refresh the window; during tracking the main loop draws
    if not tracking:
        h, w, _ = temp_frame.shape
        cv2.line(temp_frame, (x, 0), (x, h), color["line"])
        cv2.line(temp_frame, (0, y), (w, y), color["line"])
        cv2.imshow(window_name, temp_frame)
    # if a rectangle is right-double-clicked, the bounding box is deleted
    if event == cv2.EVENT_RBUTTONDBLCLK and not tracking:
        delete_these = []
        # iterate over bounding boxes to identify every box under the click
        for key, val in all_bounding_boxes.items():
            for i, v in enumerate(val):
                x1, y1, x2, y2 = v
                # check click condition
                if x1 < x < x2 and y1 < y < y2:
                    delete_these.append([key, i, (y2 - y1) * (x2 - x1)])
        # sort so that the smallest area item is selected
        delete_these = sorted(delete_these, key=lambda inp: inp[2])
        # delete only the first (smallest) match; the break also prevents
        # index invalidation from multiple deletions in one event
        for key, i, _ in delete_these:
            del all_bounding_boxes[key][i]
            break
    # left button down: start dragging a new box from (x, y)
    if event == cv2.EVENT_LBUTTONDOWN and not tracking:
        print("recog lbutton down")
        dragging = True
        temp_start_point = [x, y]
    # left button up: drawing the bounding box is done
    elif event == cv2.EVENT_LBUTTONUP and not tracking:
        # disable drawing
        dragging = False
        # normalize so that points = [min_x, min_y, max_x, max_y]
        points = [temp_start_point[0], temp_start_point[1], x, y]
        points = get_points_order(points)
        # discard boxes below the small-object threshold (accidental clicks)
        if not area_of(points, temp_frame.shape):
            pass
        else:
            temp_start_point = []
            # get the class from a dialog box dropdown
            cls = select_class_name(classes)
            # append bounding box to all bounding boxes
            if cls not in all_bounding_boxes:
                all_bounding_boxes[cls] = []
            all_bounding_boxes[cls].append(points)
    # any other event (mouse move): live-preview the rectangle being dragged
    elif not tracking:
        if dragging:
            if len(temp_start_point) > 0:
                cv2.rectangle(temp_frame, tuple(temp_start_point), (x, y), (255, 0, 0))
                cv2.imshow(window_name, temp_frame)
# the main function
def main():
    """Run the annotation UI loop.

    Plays the video, lets the user pause ('p') to draw/delete boxes via the
    mouse callback, tracks the boxes with dlib correlation trackers while
    playing, and writes every ``save_every``-th frame as a JPEG plus a
    YOLO-format label file under ``save_path``.  Press 'q' to quit.

    NOTE(review): this file was recovered from an indentation-stripped copy;
    the nesting below — especially the save/advance section at the bottom —
    is reconstructed and should be verified against the upstream source.
    """
    # create a window which will display the UI
    cv2.namedWindow(window_name)
    # add the callback function
    cv2.setMouseCallback(window_name, draw_annotation)
    # define the global variables
    global frame, tracking, idx_trackers, start_pos, save_counter, input_vid, save_every, save_path, opencv_window_width
    global skip_frames, time_delay
    # initialize video capture
    cap = cv2.VideoCapture(input_vid)
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_pos)
    print(input_vid)
    # read the first frame
    ret, frame = cap.read()
    H, W, _ = frame.shape
    # define resize ratio so that the display width equals opencv_window_width
    resize_ratio = opencv_window_width / W
    timer = 0       # cv2.waitKey delay: 0 = block (paused), time_delay = play
    paused = True   # start paused so the user can draw the initial boxes
    # start the main loop
    while ret:
        current_points = []  # [x1, y1, x2, y2, class] per tracked box this frame
        # resize the frame so that the width = opencv_window_width px
        # (frame is re-read each iteration below, so the ratio applies once)
        frame = cv2.resize(frame, None, fx=resize_ratio, fy=resize_ratio)
        h, w, _ = frame.shape
        # drawing copy so the original frame is not affected
        temp_frame = frame.copy()
        assigned = False  # set once at least one tracker produced a position
        # update object positions with tracking
        if not paused and tracking:
            # dlib expects RGB while OpenCV delivers BGR
            tracker_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # iterate over the trackers to get the updated positions
            for key, dlib_trackers in idx_trackers.items():
                for ind, dlib_tracker in enumerate(dlib_trackers):
                    dlib_tracker.update(tracker_rgb)
                    # get the position of the tracked object in this frame
                    pos = dlib_tracker.get_position()
                    x1 = int(pos.left())
                    y1 = int(pos.top())
                    x2 = int(pos.right())
                    y2 = int(pos.bottom())
                    # update the shared box positions used by the callback
                    all_bounding_boxes[key][ind] = [x1, y1, x2, y2]
                    # draw the bounding box on the frame
                    cv2.rectangle(temp_frame, (x1, y1), (x2, y2), color["track"])
                    # draw the class name
                    cv2.putText(temp_frame, key, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                color["text"], thickness=2)
                    assigned = True
                    # add the annotation and class name to current_points
                    current_points.append([x1, y1, x2, y2, key])
            # update the window with the newly tracked frame
            cv2.imshow(window_name, temp_frame)
        key_press = cv2.waitKey(timer) & 0xFF
        if key_press == ord('q'):
            break
        # pause the next frame or play from the next frame
        elif key_press == ord('p'):
            if not paused:
                # 'p' while playing: pause and stop tracking
                timer = 0
                paused = True
                print("paused at: ", save_counter)
                cv2.imshow(window_name, frame)
                assigned = False
                tracking = False
            else:
                # 'p' while paused: resume playback and (re)start tracking
                timer = time_delay
                paused = False
                tracking = True
                # delete all trackers; they are reinitialized below from the
                # (possibly user-edited) bounding boxes
                delete_trackers = [w for w in idx_trackers.keys()]
                for tracker in delete_trackers:
                    del idx_trackers[tracker]
                tracker_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # initialize trackers for all bounding boxes
                for key, val in all_bounding_boxes.items():
                    idx_trackers[key] = []
                    for x1, y1, x2, y2 in val:
                        rect = dlib.rectangle(x1, y1, x2, y2)
                        idx_trackers[key].append(dlib.correlation_tracker())
                        idx_trackers[key][-1].start_track(tracker_rgb, rect)
        # check if it is not paused and the objects are being tracked
        if not paused and tracking:
            # check if the objects were in fact tracked this iteration
            if assigned:
                # check if the counter satisfies the save condition
                if save_counter % save_every == 0:
                    if len(current_points) > 0:
                        lines = []
                        cv2.imwrite(save_path + "/images/" + str(save_counter) + ".jpg", frame)
                        # convert each box to normalized YOLO format:
                        # "class_id center_x center_y width height" (relative)
                        for x1, y1, x2, y2, cls in current_points:
                            aw = x2 - x1
                            ah = y2 - y1
                            cx = (x1 + x2) / 2
                            cy = (y1 + y2) / 2
                            ax = cx / w
                            ay = cy / h
                            aw /= w
                            ah /= h
                            lines.append(str(class_idx[cls]) + " %0.6f %0.6f %0.6f %0.6f" % (ax, ay, aw, ah))
                        with open(save_path + "/annotations/" + str(save_counter) + ".txt", 'w') as f:
                            f.write("\n".join(lines))
            # increment the save counter
            # NOTE(review): save_counter is also incremented inside the read
            # loop below, so while tracking it advances by skip_frames + 1 per
            # iteration — this looks inconsistent with the module-level
            # save_every % skip_frames == 0 precondition; confirm the intended
            # nesting against the upstream source.
            save_counter += 1
        # read the next frame (skipping skip_frames - 1 intermediate frames)
        for _ in range(skip_frames):
            ret, frame = cap.read()
            save_counter += 1

if __name__ == "__main__":
    main()