-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsvm.py
More file actions
77 lines (71 loc) · 2.23 KB
/
svm.py
File metadata and controls
77 lines (71 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from sklearn.svm import SVC
from sklearn import svm
from evaluation import *
import os
def _readFeatureMeans(path, eventNb, keyframeNb):
    """Parse a feature file into one eventNb-long averaged vector per line.

    Each non-blank line is a bracketed, ', '-separated list of numbers such as
    ``[0.1, 0.2, 0.3, 0.4]``. The element at position ``i`` is added to bucket
    ``i % eventNb``. Every bucket is then divided by ``keyframeNb``, not by the
    number of elements that actually fell into it.
    NOTE(review): this presumably assumes the flattened list is laid out as
    keyframeNb consecutive groups of eventNb values. Confirm against the
    feature writer.
    """
    X = []
    with open(path, 'r') as fIn:
        for line in fIn:
            line = line.replace('[', '').replace(']', '').replace('\n', '')
            # Skip blank lines (e.g. a trailing newline) instead of crashing
            # on float('').
            if not line.strip():
                continue
            sums = [0.0] * eventNb
            # Use the element's real position: list.index() would return the
            # first occurrence of a repeated value and misassign duplicates.
            for position, elem in enumerate(line.split(', ')):
                sums[position % eventNb] += float(elem)
            X.append([total / keyframeNb for total in sums])
    return X


def _readLabels(path):
    """Return one float label per non-blank line of *path*."""
    with open(path, 'r') as fIn:
        return [float(line) for line in fIn if line.strip()]


def svmTraining(dataFolder, trainFile, testFile, trainLabels, testLabels, eventNb, keyframeNb, sampleNb):
    """Train an SVM on averaged keyframe features and evaluate it on a test set.

    File names are concatenated directly onto ``dataFolder``, so it must end
    with a path separator.

    Parameters:
        dataFolder: directory prefix prepended to every file name.
        trainFile / testFile: feature files, one bracketed ', '-separated
            vector per line (see ``_readFeatureMeans``).
        trainLabels: training label file, one number per line.
        testLabels: test label file, passed to ``getConfusionMatrix``.
        eventNb: number of per-event buckets each vector is folded into.
        keyframeNb: divisor applied to each bucket sum.
        sampleNb: currently unused; kept for interface compatibility.

    Side effects:
        Writes the integer test predictions, one per line, to
        ``dataFolder + 'testClassification.txt'``. Prints the confusion matrix
        and both MAP values.

    Returns:
        ``(MAP2, MAP3)``: the MAP ignoring class index 2, and the MAP over all
        3 classes. If an ``IOError`` occurs, prints 'file not found' and
        returns ``(None, None)``.
    """
    try:
        # Training
        X = _readFeatureMeans(dataFolder + trainFile, eventNb, keyframeNb)
        Y = _readLabels(dataFolder + trainLabels)
        svmClass = svm.SVC().fit(X, Y)
        # Testing (test labels are read by getConfusionMatrix itself)
        XTest = _readFeatureMeans(dataFolder + testFile, eventNb, keyframeNb)
        prediction = svmClass.predict(XTest)
        # Writing the results; 'with' guarantees the file is closed
        with open(dataFolder + 'testClassification.txt', 'w') as svmRes:
            for elem in prediction:
                svmRes.write(str(int(elem)))
                svmRes.write('\n')
        # Building the confusion matrix
        cMat = getConfusionMatrix(3, dataFolder, testLabels, 'testClassification.txt')
        print('Confusion matrix of the SVM method = \n' + str(cMat))
        # Process Mean Average Precision
        # taking the 3rd class into account (empty set: ignore no class)
        MAP3 = getMAP(cMat, 3, set())
        print('MAP3 = ' + str(MAP3))
        # ignoring the 3rd class (index 2)
        MAP2 = getMAP(cMat, 3, {2})
        print('MAP2 = ' + str(MAP2))
    except IOError:
        print('file not found')
        # MAP2/MAP3 may be unbound here; return an explicit sentinel instead
        # of raising UnboundLocalError.
        return (None, None)
    return (MAP2, MAP3)