Skip to content

Commit dab9d18

Browse files
Finish eval_3d_bbox_performance. Added mAP scores, plot_pr_curves(), evaluate_detector()
1 parent 89939c1 commit dab9d18

1 file changed

Lines changed: 228 additions & 8 deletions

File tree

GEMstack/onboard/perception/sensorFusion/eval_3d_bbox_performance.py

Lines changed: 228 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
import traceback
1111

12-
###### Configuration ######
12+
###### Bbox coord mapping ######
1313
IDX_X, IDX_Y, IDX_Z = 0, 1, 2
1414
IDX_L, IDX_W, IDX_H = 3, 4, 5
1515
IDX_YAW = 6
@@ -49,8 +49,8 @@ def parse_box_line(line, is_gt=False):
4949
# Read rotation_y (yaw) from KITTI standard index 14
5050
rot_y = float(parts[14])
5151

52-
# --- Convert to standardized internal format ---
53-
# Internal Standardized format: [cx, cy, cz, l, w, h, yaw, class_label, score]
52+
# Convert to internal format
53+
# Internal format: [cx, cy, cz, l, w, h, yaw, class_label, score]
5454

5555
# Use KITTI loc_x and loc_z directly for internal cx, cz
5656
cx = loc_x
@@ -122,11 +122,9 @@ def calculate_3d_iou(box1, box2):
122122
Returns:
123123
float: The 3D IoU value.
124124
125-
*** PLACEHOLDER IMPLEMENTATION ***
126-
This function needs a proper implementation for ROTATED 3D boxes.
127-
The current version calculates IoU based on axis-aligned bounding boxes
128-
derived from the dimensions and centers, which is inaccurate for rotated boxes.
125+
Doesn't consider yaw
129126
"""
127+
130128
####### Simple Axis-Aligned Bounding Box (AABB) IoU #######
131129
def get_aabb_corners(box):
132130
# Uses internal format where cy is geometric center
@@ -167,7 +165,229 @@ def get_aabb_corners(box):
167165
iou = max(0.0, min(iou, 1.0))
168166
return iou
169167

170-
# TODO: calc mAP
168+
169+
def calculate_ap(precision, recall):
    """Compute Average Precision with the PASCAL VOC "all-point" method.

    Args:
        precision: Sequence (list/tuple/ndarray) of precision values, ordered
            by descending detection confidence (as produced by
            ``evaluate_detector``).
        recall: Matching sequence of recall values (monotonically
            non-decreasing).

    Returns:
        float: Area under the monotonically-decreasing PR envelope, or 0.0
        for empty or invalid input.
    """
    # Defensive: invalid input types yield 0.0 rather than raising.
    if not isinstance(precision, (list, tuple, np.ndarray)) or not isinstance(recall, (list, tuple, np.ndarray)):
        return 0.0
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)

    if precision.size == 0 or recall.size == 0:
        return 0.0

    # Add boundary points: the curve spans recall 0 -> 1, with 0 precision
    # pinned at both ends for the interpolation below.
    recall = np.concatenate(([0.0], recall, [1.0]))
    precision = np.concatenate(([0.0], precision, [0.0]))

    # Make precision monotonically non-increasing (removes the characteristic
    # "jiggles" of raw PR curves). Vectorized right-to-left running maximum.
    precision = np.maximum.accumulate(precision[::-1])[::-1]

    # Sum rectangle areas only where recall actually changes
    # (avoids redundant zero-width rectangles).
    idx = np.where(recall[1:] != recall[:-1])[0] + 1
    return float(np.sum((recall[idx] - recall[idx - 1]) * precision[idx]))
194+
195+
def evaluate_detector(gt_boxes_all_samples, pred_boxes_all_samples, classes, iou_threshold):
    """Evaluates a single detector's predictions against ground truth.

    Standard VOC-style matching: per class, predictions are sorted by
    descending confidence and greedily matched (by 3D IoU) to the
    highest-overlap unmatched GT box in the same sample.

    Args:
        gt_boxes_all_samples (dict): sample_id -> {class_label: [gt_box, ...]}.
        pred_boxes_all_samples (dict): sample_id -> {class_label: [pred_box, ...]}.
            Boxes use the internal format [cx, cy, cz, l, w, h, yaw, class, score].
        classes (iterable): class labels to evaluate.
        iou_threshold (float): minimum 3D IoU for a prediction to count as TP.

    Returns:
        dict: class_label -> {'precision': ndarray, 'recall': ndarray,
              'ap': float, 'num_gt': int, 'num_pred': int}.
    """
    results_by_class = {}
    # Fixed ordering over the UNION of sample ids: predictions for a sample
    # with no GT entry must still be counted as false positives (using only
    # the GT keys would silently drop them and inflate precision).
    sample_ids = list(dict.fromkeys(
        list(gt_boxes_all_samples.keys()) + list(pred_boxes_all_samples.keys())))

    for cls in classes:
        all_pred_boxes_cls = []
        pred_sample_indices = []  # Index into sample_ids for each prediction
        num_gt_cls = 0

        # Collect all GTs and predictions for this class across samples.
        for i, sample_id in enumerate(sample_ids):
            gt_boxes = gt_boxes_all_samples.get(sample_id, {}).get(cls, [])
            pred_boxes = pred_boxes_all_samples.get(sample_id, {}).get(cls, [])
            num_gt_cls += len(gt_boxes)
            for box in pred_boxes:
                all_pred_boxes_cls.append(box)
                pred_sample_indices.append(i)

        # Keep only predictions carrying a usable numeric confidence score.
        scores = []
        valid_indices = []
        for idx, box in enumerate(all_pred_boxes_cls):
            if len(box) > IDX_SCORE and isinstance(box[IDX_SCORE], (int, float)):
                scores.append(-box[IDX_SCORE])  # Negate: argsort ascending -> descending confidence
                valid_indices.append(idx)
            else:
                warnings.warn(f"Class {cls}: Prediction missing score or invalid score type. Excluding from evaluation: {box}")

        if not valid_indices:
            # No (scoreable) predictions for this class at all.
            results_by_class[cls] = {
                'precision': np.array([]),
                'recall': np.array([]),
                'ap': 0.0,
                'num_gt': num_gt_cls,
                'num_pred': 0
            }
            continue

        all_pred_boxes_cls = [all_pred_boxes_cls[i] for i in valid_indices]
        pred_sample_indices = [pred_sample_indices[i] for i in valid_indices]

        # Sort detections by confidence (descending).
        order = np.argsort(scores)
        all_pred_boxes_cls = [all_pred_boxes_cls[i] for i in order]
        pred_sample_indices = [pred_sample_indices[i] for i in order]

        tp = np.zeros(len(all_pred_boxes_cls))
        fp = np.zeros(len(all_pred_boxes_cls))
        # gt_matched[sample_idx][gt_idx] -> True once a GT box is claimed.
        gt_matched = defaultdict(lambda: defaultdict(bool))

        # Greedy matching in confidence order.
        for det_idx, pred_box in enumerate(all_pred_boxes_cls):
            sample_idx = pred_sample_indices[det_idx]
            gt_boxes = gt_boxes_all_samples.get(sample_ids[sample_idx], {}).get(cls, [])

            if not gt_boxes:  # No GT for this class in this sample
                fp[det_idx] = 1
                continue

            best_iou = -1.0
            best_gt_idx = -1  # Index relative to gt_boxes for this sample/class
            for gt_idx, gt_box in enumerate(gt_boxes):
                # Explicit class check (belt-and-suspenders; inputs are pre-filtered).
                if pred_box[IDX_CLASS] == gt_box[IDX_CLASS]:
                    iou = calculate_3d_iou(pred_box, gt_box)
                    if iou > best_iou:
                        best_iou = iou
                        best_gt_idx = gt_idx

            # TP only if IoU clears the threshold AND the GT box is not
            # already claimed by a higher-scored prediction.
            if best_iou >= iou_threshold and not gt_matched[sample_idx][best_gt_idx]:
                tp[det_idx] = 1
                gt_matched[sample_idx][best_gt_idx] = True
            else:
                fp[det_idx] = 1

        # Cumulative precision/recall along the confidence-sorted detections.
        tp_cumsum = np.cumsum(tp)
        fp_cumsum = np.cumsum(fp)

        # Guard against division by zero when there is no GT for this class.
        recall = tp_cumsum / num_gt_cls if num_gt_cls > 0 else np.zeros_like(tp_cumsum, dtype=float)
        denominator = tp_cumsum + fp_cumsum
        precision = np.divide(tp_cumsum, denominator,
                              out=np.zeros_like(tp_cumsum, dtype=float),
                              where=denominator != 0)

        results_by_class[cls] = {
            'precision': precision,
            'recall': recall,
            'ap': calculate_ap(precision, recall),
            'num_gt': num_gt_cls,
            'num_pred': len(all_pred_boxes_cls)  # Predictions with valid scores only
        }

    return results_by_class
316+
317+
318+
def plot_pr_curves(results_all_detectors, classes, output_dir):
    """Plots Precision-Recall curves for each class.

    One figure per class is saved to ``output_dir`` as ``pr_curve_<cls>.png``,
    with one curve per detector (labelled with its AP). Classes for which no
    detector produced predictions are skipped with a log message.

    Args:
        results_all_detectors (dict): detector_name -> results_by_class, as
            returned by ``evaluate_detector``.
        classes (iterable): class labels to plot.
        output_dir (str): directory for the output PNGs (created if missing).
    """
    if not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir)
        except OSError as e:
            print(f"[LOG] Error creating output directory {output_dir} for plots: {e}")
            return  # Cannot save plots

    for cls in classes:
        plt.figure(figsize=(10, 7))
        any_results_for_class = False  # Save the figure only if a detector had results

        for detector_name, results_by_class in results_all_detectors.items():
            if cls not in results_by_class:
                # Detector has no entry for this class at all (e.g. its label
                # files failed parsing); nothing to draw for it.
                continue

            res = results_by_class[cls]
            if res['num_pred'] > 0:
                precision = res['precision']
                recall = res['recall']
                ap = res['ap']

                if recall.size > 0 and precision.size > 0:
                    # Prepend a point so the curve starts nicely at recall=0.
                    plot_recall = np.concatenate(([0.], recall))
                    plot_precision = np.concatenate(([precision[0]], precision))
                    plt.plot(plot_recall, plot_precision, marker='.', markersize=4, linestyle='-', label=f'{detector_name} (AP={ap:.3f})')
                else:
                    # num_pred > 0 but P/R arrays are empty: flag with a marker.
                    plt.plot([0], [0], marker='s', markersize=5, linestyle='', label=f'{detector_name} (AP={ap:.3f}, No P/R data?)')
                any_results_for_class = True
            elif res['num_gt'] > 0:
                # GT exists but the detector made no predictions for this class.
                plt.plot([0], [0], marker='x', markersize=6, linestyle='', label=f'{detector_name} (No Pred, GT={res["num_gt"]})')
            # else: no GT and no predictions -> nothing to draw.

        if any_results_for_class:
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title(f'Precision-Recall Curve for Class: {cls}')
            plt.legend(loc='lower left')
            plt.grid(True)
            plt.xlim([-0.05, 1.05])
            plt.ylim([-0.05, 1.05])
            plot_path = os.path.join(output_dir, f'pr_curve_{cls}.png')
            try:
                plt.savefig(plot_path)
                print(f"[LOG] Generated PR curve: {plot_path}")
            except Exception as e:
                print(f"[LOG] Error saving PR curve plot for class '{cls}': {e}")
            finally:
                plt.close()  # Close the figure regardless of save success
        else:
            # Distinguish "no predictions for existing GT" from "no GT at all".
            num_gt_total = sum(results_by_class.get(cls, {}).get('num_gt', 0) for results_by_class in results_all_detectors.values())
            if num_gt_total > 0:
                print(f" Skipping PR plot for class '{cls}': No predictions found across detectors (GT={num_gt_total}).")
            else:
                print(f" Skipping PR plot for class '{cls}': No ground truth found.")
            plt.close()  # Close the empty figure
171391

172392
def main():
173393
parser = argparse.ArgumentParser(description='Evaluate N 3D Object Detectors using KITTI format labels.')

0 commit comments

Comments
 (0)