-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
99 lines (89 loc) · 3.22 KB
/
predict.py
File metadata and controls
99 lines (89 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import sys
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pyscreenshot as ImageGrab
def main():
    """Command-line entry point: classify one image and display the result.

    check_args() validates (or fills in defaults for) sys.argv, after which
    the positions are read as: [1] dataset name, [2] Keras model file,
    [3] PNG image, [4] expected class index.
    """
    class_names = check_args()
    model_path, image_path, label_arg = sys.argv[2], sys.argv[3], sys.argv[4]

    print(f"--Load Model {model_path}--")
    model = tf.keras.models.load_model(model_path)

    print(f"--Load Image {image_path}--")
    pixels = plt.imread(image_path)
    # Images whose largest value exceeds 1 are treated as 0-255 and rescaled to [0, 1].
    if np.amax(pixels) > 1:
        pixels = pixels / 255
    # Invert intensities (0 <-> 1); presumably so dark-on-light drawings match
    # light-on-dark training images — TODO confirm against the training pipeline.
    pixels = 1 - pixels

    print(f"--Predict as Class {label_arg}--")
    predict(model, class_names, pixels, int(label_arg))
def predict(model, class_names, img, true_label):
    """Classify a single image with *model*, print the scores and plot them.

    Args:
        model: Object with a Keras-style ``predict`` method; it is called on a
            batch containing just *img*.
        class_names: Display names indexed by class label.
        img: One preprocessed image (main() passes the array from ``plt.imread``).
        true_label: Index of the expected class, used to colour the plot.
    """
    # Wrap the single image in a leading batch axis of length 1.
    batch = np.array([img])
    # The model returns one row of scores per batch entry; keep the only row.
    scores = model.predict(batch)[0]
    print("Prediction: ", scores)

    best = np.argmax(scores, axis=-1)
    print("Prediction label: ", best)

    plot(class_names, scores, true_label, best, batch[0])
    plt.show()
def _dataset_class_names(dataset):
    """Return the ten class display names for *dataset*, exiting if it is unknown.

    Prints a banner naming the dataset. For anything other than "MNIST" or
    "notMNIST" an error is printed and the process exits with status 2.
    """
    if dataset == "MNIST":
        print("--Dataset MNIST--")
        # MNIST classes are the handwritten digits 0-9.
        return [str(digit) for digit in range(10)]
    if dataset == "notMNIST":
        print("--Dataset notMNIST--")
        # notMNIST classes are the letters A-J.
        return ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
    print(f"Choose MNIST or notMNIST, not {dataset}")
    sys.exit(2)


def check_args():
    """Validate the command-line arguments and return the dataset's class names.

    Expected usage::

        python predict.py <MNIST,notMNIST> <model.h5> <image.png> <class index>

    With no arguments at all, notMNIST defaults are written into ``sys.argv``
    and the class index is read interactively from stdin. On any invalid
    argument a message is printed and the process exits with:

    * 1 - wrong number of arguments
    * 2 - unknown dataset name
    * 3 - wrong file extension, or class index not an integer in 0-9
    * 4 - image is not a 2-D (greyscale) 28x28 array

    Returns:
        list[str]: The ten class display names for the chosen dataset.
    """
    if len(sys.argv) == 1:
        print("No arguments so using defaults")
        # strip() so stray whitespace typed at the prompt is not rejected below.
        sys.argv = ["predict.py", "notMNIST", "notMNIST.h5", "image.png",
                    input("Predict the class of image.png(index):").strip()]
    if len(sys.argv) != 5:
        print("Usage python predict.py <MNIST,notMNIST> <model.h5> <image.png> <prediction class index>")
        sys.exit(1)

    _, dataset, model_path, image_path, label_arg = sys.argv
    class_names = _dataset_class_names(dataset)

    if not model_path.endswith(".h5"):
        print(f"{model_path} is not a h5 extension")
        sys.exit(3)
    if not image_path.endswith(".png"):
        print(f"{image_path} is not a png extension")
        sys.exit(3)

    # Validate the class index before reading the image from disk.
    # isdecimal() (unlike isdigit()) rejects characters such as "²" that
    # int() cannot parse, so int() here and in main() never raises.
    if not label_arg.isdecimal() or not 0 <= int(label_arg) <= 9:
        print(f"{label_arg} is not an integer (0-9)")
        sys.exit(3)

    img = plt.imread(image_path)
    # A colour PNG comes back with a trailing channel axis; only 2-D is accepted.
    if img.ndim != 2:
        print("Image is not grey scale!")
        sys.exit(4)
    if img.shape != (28, 28):
        print("Image is not 28 by 28!")
        sys.exit(4)
    return class_names
def plot(class_names, prediction, true_label, predicted_label, img):
    """Draw the image with its predicted label beside a bar chart of class scores.

    Left subplot: *img* in a binary colormap, captioned
    ``"<predicted> <confidence>% (<true>)"`` in blue when the prediction is
    correct and red otherwise. Right subplot: one bar per class, with the
    predicted class's bar red and the true class's bar blue (blue wins when
    they coincide). The figure is not shown here; the caller calls
    ``plt.show()``.

    Args:
        class_names: Display names, one per class, indexed by label.
        prediction: Per-class scores, same length as *class_names*;
            presumably softmax probabilities since the y-axis is fixed to
            [0, 1] — TODO confirm against the model.
        true_label: Index of the expected class.
        predicted_label: Index of the predicted class. If ``None`` it is
            taken as the argmax of *prediction*.
        img: The image to display.
    """
    # Honour the caller's label; only derive it when none was supplied.
    if predicted_label is None:
        predicted_label = int(np.argmax(prediction))
    num_classes = len(class_names)

    plt.figure(figsize=(6, 3))

    # Left: the image itself, captioned with predicted vs. true class.
    plt.subplot(1, 2, 1)
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)
    color = "blue" if predicted_label == true_label else "red"
    # Report the score of the label actually shown, so the caption is consistent.
    confidence = 100 * prediction[predicted_label]
    plt.xlabel(
        f"{class_names[predicted_label]} {confidence:2.0f}% ({class_names[true_label]})",
        color=color)

    # Right: per-class scores as bars.
    plt.subplot(1, 2, 2)
    plt.grid(False)
    plt.xticks(range(num_classes))
    plt.yticks([])
    bars = plt.bar(class_names, prediction, color="#777777")
    plt.ylim([0, 1])
    bars[predicted_label].set_color('red')
    # Colour the true class last so it stays blue when it equals the prediction.
    bars[true_label].set_color('blue')
# Script entry point: main() runs only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()