-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkNN_classifier.py
More file actions
65 lines (50 loc) · 2.04 KB
/
kNN_classifier.py
File metadata and controls
65 lines (50 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import sqlite3
import math
import libGeneral
'''
Implementation of the k-nearest-neighbor (kNN) classifier.
'''
# File name of the SQLite database that contains the TRAINING table read by classify().
DATABASE_NAME = "ex4.db"
# Shared module-level connection; it is opened as a side effect of importing this
# module and is the connection classify() obtains its cursor from.
connection = sqlite3.connect(DATABASE_NAME)
def classify(classifyVector, n):
    """Classify a word vector by voting among its nearest training rows.

    Every row of the TRAINING table with FOR_TESTING = 0 is loaded, its
    WORD_VECTOR is parsed and normalised through libGeneral, and a distance
    to the (normalised) ``classifyVector`` is computed:

    * for keys present in both vectors: the absolute difference of the values;
    * for keys present in only one vector: that vector's value as-is.

    The row ids with the lowest distances are selected with
    ``libGeneral.getLowestValues(distances, n)`` and each selected row casts
    one vote for its CLASS.

    Args:
        classifyVector: dictionary mapping keys to numeric weights for the
            item to classify (presumably words -> weights; confirm against
            the caller). It is passed through
            ``libGeneral.normalizeDictionary`` before comparison.
        n: forwarded to ``libGeneral.getLowestValues``; presumably the number
            of nearest neighbours that take part in the vote.

    Returns:
        A two-element list ``[maxKey, maxValue]`` as produced by
        ``libGeneral.getKeyFromMaxValueFromDictionary`` and
        ``libGeneral.getMaxValueFromDictionary`` on the vote table, i.e.
        presumably the winning class name and its vote count.
    """
    sql = "SELECT ID, WORD_VECTOR, CLASS FROM TRAINING WHERE FOR_TESTING = 0;"

    # Release the cursor even when the query fails.
    cursor = connection.cursor()
    try:
        cursor.execute(sql)
        rows = cursor.fetchall()
    finally:
        cursor.close()

    # The query vector does not depend on the training row, so normalise it
    # once instead of once per row. It is skipped when there are no rows so
    # that an empty training set never touches the vector, as before.
    if rows:
        classifyVector = libGeneral.normalizeDictionary(classifyVector)

    distanceById = {}
    classById = {}
    for rowId, vectorString, classname in rows:
        classById[rowId] = classname
        trainingVector = libGeneral.normalizeDictionary(
            libGeneral.makeDictionaryFromString(vectorString)
        )

        distance = 0.0
        # Keys of the training vector: compare with the query where possible,
        # otherwise the whole training weight counts as distance.
        for key, trainingValue in trainingVector.items():
            if key in classifyVector:
                distance += math.fabs(classifyVector[key] - trainingValue)
            else:
                distance += trainingValue
        # Keys that only the query vector has.
        for key, classifyValue in classifyVector.items():
            if key not in trainingVector:
                distance += classifyValue
        distanceById[rowId] = distance

    # One vote per selected neighbour for the class of that training row.
    nearest = libGeneral.getLowestValues(distanceById, n)
    votes = {}
    for rowId in nearest:
        classname = classById[rowId]
        votes[classname] = votes.get(classname, 0) + 1

    # NOTE: the original carried a disabled step that multiplied each vote
    # count by an a-priori value from
    # libGeneral.calculateAPrioriDictionary(connection); it stays disabled.

    maxValue = libGeneral.getMaxValueFromDictionary(votes)
    maxKey = libGeneral.getKeyFromMaxValueFromDictionary(votes)
    return [maxKey, maxValue]