-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattempt2.py
More file actions
150 lines (97 loc) · 4.27 KB
/
attempt2.py
File metadata and controls
150 lines (97 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 9 16:45:50 2017
@author: astvd3
"""
from __future__ import division,print_function
from utils import *
import os, json
from glob import glob
import numpy as np
np.set_printoptions(precision=4, linewidth=100)
from matplotlib import pyplot as plt
path="GTSRB/"

# Carve a validation set out of the training images.
# NOTE(review): the globs below are relative to the current working directory while the
# destinations are built from `path`; this presumably expects the script to be launched
# from inside the training-image directory (one sub-directory per class) — confirm.
g = glob('*')
for d in g:
    os.mkdir(path+'/valid/'+d)
# Use the same one-level '<class>/<image>.ppm' layout as the directories created just
# above. The previous '*/*/*.ppm' pattern produced paths two directories deep whose
# destination directories were never created, so os.rename could not succeed.
g = glob('*/*.ppm')
shuf = np.random.permutation(g)
for i in range(500):
    os.rename(shuf[i], path+'/valid/' + shuf[i])

from shutil import copyfile
# In case you want to test it first with a small dataset: build sample/train and
# sample/valid with the same per-class directory layout.
g = glob('*')
for d in g:
    os.mkdir(path+'/sample/train/'+d)
    os.mkdir(path+'/sample/valid/'+d)
# Draw both sample splits from ONE permutation so they are disjoint. Shuffling the same
# file list twice (as before) let an image land in both sample/train and sample/valid,
# which leaks training data into the sample validation score.
g = glob('*/*.ppm')
shuf = np.random.permutation(g)
for i in range(800):
    copyfile(shuf[i], path+'/sample/train/' + shuf[i])
for i in range(800, 1200):
    copyfile(shuf[i], path+'/sample/valid/' + shuf[i])

# To convert images to jpeg (leave this block commented out for most cases).
# It needs an `Image` name (presumably PIL's) in scope, which this file does not import.
#g=glob('*/*.ppm')
#for i in range(0,len(g)):
#    im=Image.open(g[i])
#    im.save(g[i][:-4]+'.jpg')
batch_size=64
# Directory iterators and label/filename metadata (helpers come from utils).
# `batches` / `val_batches` are not used further in this script.
batches = get_batches(path+'train', batch_size=batch_size)
val_batches = get_batches(path+'valid', batch_size=batch_size*2, shuffle=False)
(val_classes, trn_classes, val_labels, trn_labels,
    val_filenames, filenames, test_filenames) = get_classes(path)

from vgg16bn import Vgg16BN
# Batch-normalised VGG from utils with 43 outputs (presumably one per GTSRB class).
model = vgg_ft_bn(43)

# Load all images into memory and cache them under results/ so later runs can
# skip straight to load_array.
trn = get_data(path+'train')
val = get_data(path+'valid')
save_array(path+'results/trn.dat', trn)
save_array(path+'results/val.dat', val)
trn = load_array(path+'results/trn.dat')
val = load_array(path+'results/val.dat')

model.compile(optimizer=Adam(1e-3),
              loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(trn, trn_labels, batch_size=batch_size, nb_epoch=3,
          validation_data=(val, val_labels))
# Checkpoint the fine-tuned weights. (A former `model.optimizer.lr=1e-3` here set the
# same rate after training had finished and had no effect, so it was dropped.)
model.save_weights(path+'results/ft1.h5')
model.load_weights(path+'results/ft1.h5')

# Keep the convolutional part of the network (split_at is from utils; presumably it
# splits after the last Convolution2D layer — TODO confirm) and precompute its outputs
# so the dense head can be trained on cached features.
conv_layers,fc_layers = split_at(model, Convolution2D)
conv_model = Sequential(conv_layers)
conv_feat = conv_model.predict(trn)
conv_val_feat = conv_model.predict(val)
save_array(path+'results/conv_feat.dat', conv_feat)
save_array(path+'results/conv_val_feat.dat', conv_val_feat)
conv_feat = load_array(path+'results/conv_feat.dat')
conv_val_feat = load_array(path+'results/conv_val_feat.dat')
# A bare `conv_feat.shape` expression only displays in a notebook; print it instead.
print(conv_feat.shape)
def get_bn_layers(p, num_classes=43, input_shape=None):
    """Build the batch-normalised dense head trained on precomputed conv features.

    Args:
        p: Base dropout rate; the three Dropout layers use p/4, p and p/2.
        num_classes: Units in the final softmax layer (default 43, as before).
        input_shape: Shape of one conv-feature sample without the batch axis. When
            None, it is read from the module-level
            ``conv_layers[-1].output_shape[1:]`` at call time, as before.

    Returns:
        A list of Keras layers to pass to ``Sequential(...)``.
    """
    if input_shape is None:
        input_shape = conv_layers[-1].output_shape[1:]
    return [
        MaxPooling2D(input_shape=input_shape),
        # axis=1: normalise over dimension 1 of the features — presumably the
        # channel axis of channels-first (Theano-ordered) tensors; TODO confirm.
        BatchNormalization(axis=1),
        Dropout(p/4),
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p/2),
        Dense(num_classes, activation='softmax')
    ]
p=0.6
bn_model = Sequential(get_bn_layers(p))
bn_model.compile(Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
bn_model.fit(conv_feat, trn_labels, batch_size=batch_size, nb_epoch=3,
             validation_data=(conv_val_feat, val_labels))
# Second round at a much lower learning rate. Assigning a float to
# `bn_model.optimizer.lr` after fit() has already built the training function does
# not change the rate training actually uses, so recompile with a new optimizer
# instead. compile() keeps the learned layer weights; the Adam moment state restarts.
bn_model.compile(Adam(lr=3e-6), loss='categorical_crossentropy', metrics=['accuracy'])
bn_model.fit(conv_feat, trn_labels, batch_size=batch_size, nb_epoch=3,
             validation_data=(conv_val_feat, val_labels))
bn_model.save_weights(path+'results/conv_512_6.h5')
# evaluate() returns the loss and metrics; print them so the script shows the result.
print(bn_model.evaluate(conv_val_feat, val_labels))
# Reload from the file that was just written. The previous 'models/' path is never
# written by this script, so it would fail or replace the fresh weights with stale ones.
bn_model.load_weights(path+'results/conv_512_6.h5')
# Human-readable names for the 43 GTSRB classes, indexed by class id 0..42.
# NOTE(review): assumes predicted indices follow the sorted class-directory order
# (00000 ... 00042) — confirm against how get_batches/get_classes assign labels.
# Class 6 is "end of 80 km/h limit" and class 32 is "end of all speed and passing
# limits" in GTSRB; the earlier 'below 80' / 'No Parking' labels were wrong.
signs=['20','30','50','60','70','80','End of 80','100','120','No passing','No overtaking by heavy vehicles','Right of way at next crossroad','priority road','Give way','stop','No vehicles','No vehicles over 3.5Tons','No entry','General Caution','Dangerous Curve to left','Dangerous Curve to right','Double curves to left','Bumpy road','slippery road','road narrows on the right','Roadworks','Traffic Signal','Pedestrians','Children Crossing','Bicycle Crossing','Road freezes easily and is then slippery','Wild Animals Crossing','End of all speed and passing limits','Turn Right','Turn Left','Ahead only','Straight or Right','Straight or Left','Keep Right','Keep Left','Roundabout','End of No Overtaking','End of No Overtaking by Heavy vehicles' ]
def show_pred(i):
    """Display training image ``i`` and print the BN head's predicted sign name.

    Args:
        i: Index into the module-level ``trn`` images and ``conv_feat`` features.
    """
    # Slice rather than index so predict() still receives a batch of one sample.
    probs = bn_model.predict(conv_feat[i:i + 1])
    # Move axis 0 to position 2 before plotting (presumably channels-first to
    # channels-last for imshow — TODO confirm).
    picture = np.rollaxis(trn[i], 0, 3)
    plt.imshow(picture)
    best = np.argmax(probs)
    print(signs[best])