-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathKNN.py
More file actions
59 lines (49 loc) · 2.21 KB
/
KNN.py
File metadata and controls
59 lines (49 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from KNN_FOR_MINST_DATABSE.nearest_neighbor import *
from KNN_FOR_MINST_DATABSE.utils import *
class KNearestNeighbor:
    """k-nearest-neighbour classifier using L1 (Manhattan) distance.

    Training is lazy: the constructor only stores the training data, and
    every call to ``predict`` compares each query row against all stored
    training rows, then takes a majority vote over the ``k`` closest labels.
    """

    def __init__(self, X, y, k):
        """Store the training set.

        Args:
            X: 2-D array of training samples, one sample per row.
            y: 1-D array of class labels aligned with the rows of ``X``.
               Labels must be non-negative integers because voting uses
               ``np.bincount``.
            k: number of nearest neighbours that vote on each prediction.
        """
        self.ytr = y
        self.Xtr = X
        self.k = k

    @staticmethod
    def _as_signed(a):
        """Return ``a`` as an array whose element-wise subtraction cannot wrap.

        Unsigned-integer (and bool) arrays are widened to the smallest type
        that can also hold negative values (e.g. uint8 -> int16); arrays of
        any other dtype are returned unchanged.
        """
        a = np.asarray(a)
        if a.dtype.kind in "bu":
            return a.astype(np.promote_types(a.dtype, np.int8))
        return a

    def predict(self, X):
        """Predict a label for every row of ``X``.

        Args:
            X: 2-D array of query samples with the same number of columns as
               the training data.

        Returns:
            1-D array of predicted labels, using the training labels' dtype.
            Among equally frequent labels the smallest one wins, and among
            equally distant training samples the one with the lower index
            is preferred.

        Raises:
            ValueError: if ``k`` is not between 1 and the number of training
                samples.
        """
        # Widen unsigned data first: with e.g. uint8 pixels, 10 - 20 would
        # wrap to 246 and corrupt the L1 distance.
        train = self._as_signed(self.Xtr)
        queries = self._as_signed(X)
        labels = np.asarray(self.ytr)

        num_train = train.shape[0]
        k = self.k
        if not 1 <= k <= num_train:
            raise ValueError(
                f"k must be between 1 and the number of training samples "
                f"({num_train}), got {k}"
            )

        num_test = queries.shape[0]
        Ypred = np.zeros(num_test, dtype=labels.dtype)
        for i in range(num_test):
            distances = np.sum(np.abs(train - queries[i, :]), axis=1)
            # Indices into the *unmodified* distance array, so they stay
            # aligned with `labels`. A stable sort keeps the lowest training
            # index first among equal distances (same tie rule as argmin).
            nearest = np.argsort(distances, kind="stable")[:k]
            Ypred[i] = np.argmax(np.bincount(labels[nearest]))
        return Ypred
def test_nearest_neighbor_classifier(slices):
    """Evaluate the nearest-neighbour classifier on a prefix of the test split.

    Loads the default split and the 't10k' split through ``load_mnist``,
    fits ``NearestNeighbor`` on the full default split, classifies the first
    ``slices`` images of the 't10k' split, and prints the accuracy.
    (``load_mnist``, ``mnist_path`` and ``NearestNeighbor`` come from the
    star imports; presumably ``load_mnist`` returns an (images, labels)
    pair of arrays — verify against the utils module.)

    Args:
        slices: number of leading test images to classify.

    Returns:
        The printed accuracy: the fraction of the classified test images
        whose prediction matches the true label.
    """
    train_images, train_labels = load_mnist(mnist_path)
    test_images, test_labels = load_mnist(mnist_path, kind='t10k')
    classifier = NearestNeighbor(train_images, train_labels)
    predicted = classifier.predict(test_images[:slices])
    accuracy = np.mean(predicted == test_labels[:slices])
    print(accuracy)
    return accuracy
def test_k_nearest_neighbor_classifier(k, slices):
    """Evaluate a k-nearest-neighbour classifier on a prefix of the test split.

    Loads the default split and the 't10k' split through ``load_mnist``,
    fits ``KNearestNeighbor`` with the given ``k`` on the full default
    split, classifies the first ``slices`` images of the 't10k' split, and
    prints the accuracy. (``load_mnist`` and ``mnist_path`` come from the
    star imports; presumably ``load_mnist`` returns an (images, labels)
    pair of arrays — verify against the utils module.)

    Args:
        k: number of neighbours that vote on each prediction.
        slices: number of leading test images to classify.

    Returns:
        The printed accuracy: the fraction of the classified test images
        whose prediction matches the true label.
    """
    train_images, train_labels = load_mnist(mnist_path)
    test_images, test_labels = load_mnist(mnist_path, kind='t10k')
    classifier = KNearestNeighbor(train_images, train_labels, k)
    predicted = classifier.predict(test_images[:slices])
    accuracy = np.mean(predicted == test_labels[:slices])
    print(accuracy)
    return accuracy
if __name__ == '__main__':
    # N is the number of test images classified (the first N images of the
    # 't10k' split) — NOT the training-set size; both classifiers are always
    # fitted on the full default split. Larger N takes longer, so consider
    # N = 30 for a first run. Each call prints its classification accuracy.
    N = 300
    for k in (3, 5, 7, 9):
        test_k_nearest_neighbor_classifier(k, N)
    test_nearest_neighbor_classifier(N)