forked from MashiMaroLjc/YOLO
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
122 lines (98 loc) · 3.21 KB
/
utils.py
File metadata and controls
122 lines (98 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# coding:utf-8
from mxnet import nd
from mxnet import image
import numpy as np
def find_range(location, s):
    """
    Map a bounding box to the flat index of the grid cell that holds its center.

    :param location: iterable of four numbers (x_min, y_min, x_max, y_max).
        The grid math divides by 1/s, so coordinates are presumably
        normalized to [0, 1] -- confirm against the caller.
    :param s: number of grid cells along each side of the image
    :return: ``row * s + column`` of the cell containing the box center, or
        -1 when any coordinate is negative
    """
    x_min, y_min, x_max, y_max = location
    if x_min < 0 or y_min < 0 or x_max < 0 or y_max < 0:
        return -1

    x_center = 0.5 * (x_max + x_min)
    y_center = 0.5 * (y_max + y_min)
    cell_size = 1. / s

    # A center lying exactly on the far edge (e.g. 1.0) divides out to s,
    # which is one past the last valid row/column. Clamp so the returned
    # index always addresses one of the s * s cells.
    row = min(int(y_center / cell_size), s - 1)
    column = min(int(x_center / cell_size), s - 1)
    return row * s + column
def translate_locat(x_min, y_min, x_max, y_max):
    """
    Convert a corner-format box into center/size format.

    :param x_min: left edge of the box
    :param y_min: top edge of the box
    :param x_max: right edge of the box
    :param y_max: bottom edge of the box
    :return: tuple (x_center, y_center, w, h)
    """
    return (
        0.5 * (x_max + x_min),
        0.5 * (y_max + y_min),
        x_max - x_min,
        y_max - y_min,
    )
def translate_y(label, s, b, c):
    """
    Encode a batch of single-box labels as flat per-cell target vectors.

    Each output row is the concatenation of three flattened grids, in this
    order: class one-hot (s*s*c values), box confidence (s*s*b values) and
    box coordinates (s*s*b*4 values).

    :param label: object with ``asnumpy()`` returning a 2-D array whose
        column 0 is the class id and whose remaining four columns are
        (x_min, y_min, x_max, y_max)
    :param s: number of grid cells along each side
    :param b: number of box slots per cell
    :param c: number of class columns; column 0 marks "no object", which is
        why incoming class ids are shifted up by one
    :return: numpy array of shape (batch, s * s * (b * 5 + c))
    """
    raw = label.asnumpy()
    class_ids = raw[:, 0] + 1
    boxes = raw[:, 1:]
    cells = s * s
    targets = np.zeros(shape=[len(label), cells * (b * 5 + c)])

    for i, box in enumerate(boxes):
        class_grid = np.zeros(shape=(cells, c))
        conf_grid = np.zeros(shape=(cells, b))
        box_grid = np.zeros(shape=(cells, b, 4))

        cell = find_range(box, s)
        if cell == -1:
            # No usable box for this sample: every cell is background.
            class_grid[:, 0] = 1
        else:
            # Every cell except the responsible one is background.
            class_grid[np.arange(cells) != cell, 0] = 1
            class_grid[cell][int(class_ids[i])] = 1

            x, y, w, h = translate_locat(*box)
            w, h = np.sqrt(w), np.sqrt(h)
            # Keep only the center's offset inside its own cell.
            cell_size = 1 / s
            x, y = round(x % cell_size, 4), round(y % cell_size, 4)

            # Fill the first box slot of the cell that is not yet taken.
            for slot in range(b):
                if conf_grid[cell][slot] != 1:
                    conf_grid[cell][slot] = 1
                    box_grid[cell][slot] = [x, y, w, h]
                    break

        targets[i] = np.concatenate(
            [
                class_grid.reshape((cells * c,)),
                conf_grid.reshape((cells * b,)),
                box_grid.reshape((cells * b * 4,)),
            ],
            axis=0,
        )
    return targets
def deal_output(y: nd.NDArray, s, b, c):
    """
    Split a flat per-sample vector into class, confidence and box parts.

    Along axis 1 the input is read as s*s*c class values, then s*s*b
    confidence values, then everything that remains as box coordinates.

    :param y: 2-D array indexed as (sample, flat values)
    :param s: number of grid cells along each side
    :param b: number of box slots per cell
    :param c: number of class columns per cell
    :return: tuple (label, preds, location) where label is reshaped to
        (-1, s*s, c), preds is left as the flat s*s*b slice, and location
        is reshaped to (-1, s*s, b, 4)
    """
    cells = s * s
    class_end = cells * c
    conf_end = class_end + cells * b

    label = nd.reshape(y[:, 0:class_end], shape=(-1, cells, c))
    # The confidence slice is returned flat, exactly as sliced.
    preds = y[:, class_end:conf_end]
    location = nd.reshape(y[:, conf_end:], shape=(-1, cells, b, 4))
    return label, preds, location
def process_image(fname, data_shape, rgb_mean, rgb_std):
    """
    Load an image file and turn it into a normalized single-item batch.

    :param fname: path of the image file to read
    :param data_shape: target size, passed as both arguments of the resize
    :param rgb_mean: value subtracted from the float32 image
    :param rgb_std: value the mean-subtracted image is divided by
    :return: tuple (batch, im) where batch is the resized, normalized image
        with axes reordered by (2, 0, 1) and a leading axis added, and im
        is the decoded image before resizing
    """
    with open(fname, 'rb') as handle:
        raw_bytes = handle.read()

    im = image.imdecode(raw_bytes)
    resized = image.imresize(im, data_shape, data_shape)
    normalized = (resized.astype('float32') - rgb_mean) / rgb_std
    batch = normalized.transpose((2, 0, 1)).expand_dims(axis=0)
    return batch, im