2727import argparse
2828import logging
2929from pathlib import Path
30+ from typing import Optional , Tuple , Union
3031
32+ import numpy as np
3133from PIL import Image
3234
3335from handmotion .data .feature_extractor import FeatureExtractor
3941logger = logging .getLogger (__name__ )
4042
4143
42- def predict_image (classifier , extractor , image_path ):
44+ def _extract_and_predict (
45+ classifier : HandGestureClassifier ,
46+ image : Union [Image .Image , str , Path ],
47+ image_format : str = "rgb" ,
48+ ) -> Tuple [Optional [str ], Optional [np .ndarray ], Optional [np .ndarray ]]:
4349 """
44- Predict gesture for a single image .
50+ Internal helper to extract features and get predictions .
4551
4652 Args:
4753 classifier: Loaded HandGestureClassifier
48- extractor: FeatureExtractor instance
49- image_path: Path to image file
54+ image: PIL Image, or path to image file
55+ image_format: "rgb" or "bgr" (default: "rgb")
5056
5157 Returns:
52- Tuple of (predicted_label, confidence_dict ) or (None, None) if no hand detected
58+ Tuple of (prediction, probabilities, class_names ) or (None, None, None) if no hand detected
5359 """
54- # Load image
55- image = Image .open (image_path )
56- if image .mode != "RGB" :
57- image = image .convert ("RGB" )
60+ # Load image if path provided
61+ if isinstance (image , (str , Path )):
62+ image = Image .open (image )
63+ if image .mode != "RGB" :
64+ image = image .convert ("RGB" )
65+ image_format = "rgb" # PIL images are always RGB
66+
67+ # Initialize feature extractor
68+ extractor = FeatureExtractor ()
5869
5970 # Extract features
60- features_dict = extractor .extract (image , image_format = "rgb" )
71+ features_dict = extractor .extract (image , image_format = image_format )
6172 if features_dict is None :
62- logger .warning (f"No hand detected in { image_path } " )
63- return None , None
73+ return None , None , None
6474
6575 # Prepare features for classifier
6676 landmarks = features_dict ["landmarks" ].reshape (1 , - 1 ) # (1, 63)
@@ -70,9 +80,71 @@ def predict_image(classifier, extractor, image_path):
7080 # Predict
7181 prediction = classifier .predict (features )[0 ]
7282 probabilities = classifier .predict_proba (features )[0 ]
83+ class_names = classifier .label_encoder .classes_
84+
85+ return prediction , probabilities , class_names
86+
87+
88+ def predict_image (
89+ classifier : HandGestureClassifier ,
90+ image : Union [Image .Image , str , Path ],
91+ threshold : float = 0.5 ,
92+ image_format : str = "rgb" ,
93+ ) -> Tuple [Optional [str ], Optional [float ]]:
94+ """
95+ Predict gesture from an image.
96+
97+ Args:
98+ classifier: Loaded HandGestureClassifier
99+ image: PIL Image, or path to image file
100+ threshold: Minimum confidence to return prediction (0.0-1.0).
101+ If confidence is below threshold, returns (None, None)
102+ image_format: "rgb" or "bgr" (default: "rgb"). Only used if image is PIL Image.
103+
104+ Returns:
105+ Tuple of (label, confidence) or (None, None) if:
106+ - No hand detected
107+ - Confidence below threshold
108+ """
109+ prediction , probabilities , class_names = _extract_and_predict (classifier , image , image_format )
110+ if prediction is None or probabilities is None or class_names is None :
111+ logger .debug ("No hand detected in image" )
112+ return None , None
113+
114+ # Get confidence for predicted class
115+ pred_idx = list (class_names ).index (prediction )
116+ confidence = float (probabilities [pred_idx ])
117+
118+ # Check threshold
119+ if confidence < threshold :
120+ logger .debug (f"Confidence { confidence :.4f} below threshold { threshold } for { prediction } " )
121+ return None , None
122+
123+ return prediction , confidence
124+
125+
126+ def predict_image_with_proba (
127+ classifier : HandGestureClassifier ,
128+ image : Union [Image .Image , str , Path ],
129+ image_format : str = "rgb" ,
130+ ) -> Tuple [Optional [str ], Optional [dict ]]:
131+ """
132+ Predict gesture from an image with full probability distribution.
133+
134+ Args:
135+ classifier: Loaded HandGestureClassifier
136+ image: PIL Image, or path to image file
137+ image_format: "rgb" or "bgr" (default: "rgb")
138+
139+ Returns:
140+ Tuple of (predicted_label, confidence_dict) or (None, None) if no hand detected
141+ """
142+ prediction , probabilities , class_names = _extract_and_predict (classifier , image , image_format )
143+ if prediction is None or probabilities is None or class_names is None :
144+ logger .warning ("No hand detected in image" )
145+ return None , None
73146
74147 # Get class names and create confidence dict
75- class_names = classifier .label_encoder .classes_
76148 confidence_dict = {class_names [i ]: float (prob ) for i , prob in enumerate (probabilities )}
77149
78150 return prediction , confidence_dict
@@ -116,25 +188,26 @@ def main():
116188 classifier = HandGestureClassifier ()
117189 classifier .load (model_path )
118190
119- # Initialize feature extractor
120- extractor = FeatureExtractor ()
121-
122191 # Predict
123192 logger .info (f"Processing image: { image_path } " )
124- prediction , confidence = predict_image (classifier , extractor , image_path )
125-
126- if prediction is None :
127- print ("No hand detected in image." )
128- return
129-
130- # Print results
131- print (f"\n Prediction: { prediction } " )
132- print (f"Confidence: { confidence [prediction ]:.4f} " )
133-
134193 if args .show_proba :
194+ prediction , confidence = predict_image_with_proba (classifier , image_path )
195+ if prediction is None :
196+ print ("No hand detected in image." )
197+ return
198+ assert confidence is not None
199+ print (f"\n Prediction: { prediction } " )
200+ print (f"Confidence: { confidence [prediction ]:.4f} " )
135201 print ("\n All class probabilities:" )
136202 for class_name , prob in sorted (confidence .items (), key = lambda x : x [1 ], reverse = True ):
137203 print (f" { class_name } : { prob :.4f} " )
204+ else :
205+ prediction , confidence = predict_image (classifier , image_path )
206+ if prediction is None :
207+ print ("No hand detected in image." )
208+ return
209+ print (f"\n Prediction: { prediction } " )
210+ print (f"Confidence: { confidence :.4f} " )
138211
139212
140213if __name__ == "__main__" :
0 commit comments