-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
85 lines (63 loc) · 2.51 KB
/
predict.py
File metadata and controls
85 lines (63 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import json
import logging
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
from PIL import Image
def parse_args():
    '''Parse the command-line arguments of the flower classifier.

    Positional arguments:
        input_image: path of the image to classify (required).
        model: path of the saved model; optional, falls back to
            './ryn_model.h5' when omitted.

    Options:
        --top_k: number of top predictions to report (int, default 5).
        --category_names: JSON file with the category names used to label
            predictions (default 'label_map.json').

    Returns:
        argparse.Namespace with the attributes ``input_image``, ``model``,
        ``top_k`` and ``category_names``.
    '''
    parser = argparse.ArgumentParser(description='Find a flower type')
    # Input Command line parameter. Reads input image path.
    parser.add_argument('input_image', action='store',
                        type=str, help='input image path')
    # nargs='?' is what makes the default reachable: a plain positional is
    # always required, so argparse would never fall back to its default.
    parser.add_argument('model', action='store', nargs='?',
                        default='./ryn_model.h5', type=str,
                        help='model for predict')
    parser.add_argument('--top_k', dest="top_k", action='store', default=5,
                        type=int, help='number of top predictions')
    parser.add_argument('--category_names', action='store',
                        default='label_map.json', type=str,
                        dest="category_names",
                        help='category names for prediction labels')
    return parser.parse_args()
def load_model(model_path):
    '''Return the Keras model stored at model_path.

    The model is loaded with compile=False, and 'KerasLayer' is registered
    as a custom object resolving to hub.KerasLayer.
    '''
    hub_layers = {'KerasLayer': hub.KerasLayer}
    return tf.keras.models.load_model(
        model_path, custom_objects=hub_layers, compile=False)
def load_class_names(label_map):
    '''Load the class-name mapping used to label predictions.

    Args:
        label_map: path of a JSON file.

    Returns:
        The decoded JSON document (whatever json.load produces for it).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    '''
    # Pin the encoding: without it open() uses the locale default, which can
    # mis-decode non-ASCII class names on platforms that are not UTF-8.
    with open(label_map, 'r', encoding='utf-8') as f:
        return json.load(f)
def process_image(image, image_size=224):
    '''Cast, resize and rescale an image for the model.

    Args:
        image: image data accepted by tf.cast / tf.image.resize
            (presumably a height x width x channels array - confirm
            against callers).
        image_size: side length, in pixels, of the square output.
            Defaults to 224, the size this script previously hard-coded.

    Returns:
        The image as float32, resized to (image_size, image_size) and
        divided by 255.
    '''
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, (image_size, image_size))
    # Assumes 8-bit pixel values, so dividing by 255 maps them into [0, 1].
    image /= 255
    return image
def predict(image_path, model, class_names, top_k):
    '''Predict the top_k classes with the highest score for an image file.

    Args:
        image_path: path of the image to classify.
        model: object whose predict(batch) returns one row of per-class
            scores per input image (presumably probabilities - confirm
            against the saved model).
        class_names: mapping from str(class index + 1) to a class name.
        top_k: maximum number of results to return; values <= 0 yield an
            empty list.

    Returns:
        A list of (score, class_name) tuples sorted from highest to lowest
        score, containing at most top_k entries.
    '''
    # The context manager closes the file handle deterministically, and
    # convert('RGB') gives grayscale, palette and RGBA images the same
    # three-channel layout before they reach the model.
    with Image.open(image_path) as img:
        image = np.asarray(img.convert('RGB'))
    image = process_image(image)
    # The model is fed a batch, so add a leading batch axis of size 1.
    image = np.expand_dims(image, axis=0)

    # predict
    prob_predicts = model.predict(image)[0].tolist()

    # find top k; ties on score are broken by the higher class index first
    prob_and_class = sorted(
        ((prob, index) for index, prob in enumerate(prob_predicts)),
        reverse=True)

    # Slice before the name lookup so a gap in class_names only matters for
    # classes that are actually reported. max() stops a negative top_k from
    # slicing from the end of the list.
    best = prob_and_class[:max(top_k, 0)]

    # Model output indices are 0-based while the label map keys are the
    # 1-based index as a string, hence index + 1.
    return [(prob, class_names[str(index + 1)]) for prob, index in best]
def main():
    '''Run the command-line workflow: parse arguments, load the model and
    the class names, then print the top-k predictions for the input image.
    '''
    cli = parse_args()
    classifier = load_model(cli.model)
    labels = load_class_names(cli.category_names)

    header = '\nThis is top {} probability of this flower tend to be :\n'.format(
        cli.top_k)
    print(header, predict(cli.input_image, classifier, labels, cli.top_k))
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()