-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdbf_car.py
More file actions
executable file
·125 lines (73 loc) · 2.04 KB
/
dbf_car.py
File metadata and controls
executable file
·125 lines (73 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/python
'''
File: dbf_car.py
Author: MOUSTAID TARIK
Description: This is an MNIST dataset trainer using a deep belief network.
characters are white and background is black.
image sizes are 32x20 (=640).
'''
import sys
import numpy as np
import cv2
from mlxtend.data import loadlocal_mnist
from sklearn.cross_validation import train_test_split
from sklearn.metrics import classification_report
from sklearn import datasets
from nolearn.dbn import DBN
# Build the training set from the per-sample text files under data/.
# Files are named "<digit>_<index>.txt"; the digit in the name is the label.
# Each file's values are flattened into one row of s (= 32*20) elements.
s = 32 * 20  # kept at module level: the prediction code further down reads it
dataset = []
datalabel = []
# old dataset from 0_0 to 0_19, now to 0_28
for i in range(0, 10):
    for j in range(0, 29):
        name = "data/" + str(i) + "_" + str(j) + ".txt"
        # Context manager closes each handle as soon as it is parsed, instead
        # of leaving 290 files open until the interpreter collects them.
        with open(name) as f:
            lal = np.loadtxt(f)
        # reshape raises if the file does not hold exactly s values.
        dataset.append(lal.reshape(1, s).tolist()[0])
        datalabel.append(i)
# Labels become a column vector (one row per sample); samples a 2-D array.
datalabel = np.asarray([datalabel]).transpose()
dataset = np.asarray(dataset)
# Single-argument print(...) produces the same output on Python 2 and 3.
print(dataset.shape)
print(datalabel.shape)
#
#
#
#
## train the Deep Belief Network with 640 input units, 700 hidden nodes, 10 output units (one for
## each possible digit label, from zero to nine)
#
# Layer sizes passed to DBN: 640 -> 700 -> 10.
layer_sizes = [640, 700, 10]
dbn = DBN(
    layer_sizes,
    learn_rates=0.3,
    learn_rate_decays=0.9,
    epochs=50,
    verbose=1,
)
# Fit on the samples and labels assembled above, then announce readiness.
dbn.fit(dataset, datalabel)
print("trained ! ready to predict!")
# interactively read image paths from the user and print the network's
# prediction for each image
#### predicting
# Interactive prediction loop.
# Prompts for an image path until the user enters "q". Each image is converted
# to grayscale, resized to 32 rows x 20 columns if needed, binarised at
# threshold 100, scaled to 0.0/1.0, flattened to a single row and passed to
# dbn.predict; the prediction is printed.
while True:
    dst = raw_input("image to test? 'q' to quit:\n")
    if dst == "q":
        break
    try:
        img = cv2.imread(dst)
        if img is None:
            # cv2.imread reports a missing or undecodable file by returning
            # None rather than raising, so check for it explicitly.
            print("error reading image..")
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if img.shape != (32, 20):
            # cv2.resize takes (width, height), hence (20, 32) for 32x20.
            img = cv2.resize(img, (20, 32))
        _, img = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY)
        img = img / 255.0
        print(img.shape)
        # The image is always 32x20 here, so -1 resolves to 640 without
        # relying on a variable left over from the data-loading loop.
        img = img.reshape(1, -1).astype(np.float32)
        # prediction:
        pred = dbn.predict(img)
        print(pred)
    except Exception as err:
        # Catch Exception, not a bare except, so Ctrl-C and SystemExit still
        # end the program; report the real cause instead of a generic message.
        print("error processing image: %s" % (err,))