|
9 | 9 | import warnings |
10 | 10 | import traceback |
11 | 11 |
|
12 | | -###### Configuration ###### |
| 12 | +###### Bbox coord mapping ###### |
13 | 13 | IDX_X, IDX_Y, IDX_Z = 0, 1, 2 |
14 | 14 | IDX_L, IDX_W, IDX_H = 3, 4, 5 |
15 | 15 | IDX_YAW = 6 |
@@ -49,8 +49,8 @@ def parse_box_line(line, is_gt=False): |
49 | 49 | # Read rotation_y (yaw) from KITTI standard index 14 |
50 | 50 | rot_y = float(parts[14]) |
51 | 51 |
|
52 | | - # --- Convert to standardized internal format --- |
53 | | - # Internal Standardized format: [cx, cy, cz, l, w, h, yaw, class_label, score] |
| 52 | + # Convert to internal format |
| 53 | + # Internal format: [cx, cy, cz, l, w, h, yaw, class_label, score] |
54 | 54 |
|
55 | 55 | # Use KITTI loc_x and loc_z directly for internal cx, cz |
56 | 56 | cx = loc_x |
@@ -122,11 +122,9 @@ def calculate_3d_iou(box1, box2): |
122 | 122 | Returns: |
123 | 123 | float: The 3D IoU value. |
124 | 124 |
|
125 | | - *** PLACEHOLDER IMPLEMENTATION *** |
126 | | - This function needs a proper implementation for ROTATED 3D boxes. |
127 | | - The current version calculates IoU based on axis-aligned bounding boxes |
128 | | - derived from the dimensions and centers, which is inaccurate for rotated boxes. |
| 125 | + Doesn't consider yaw |
129 | 126 | """ |
| 127 | + |
130 | 128 | ####### Simple Axis-Aligned Bounding Box (AABB) IoU ####### |
131 | 129 | def get_aabb_corners(box): |
132 | 130 | # Uses internal format where cy is geometric center |
@@ -167,7 +165,229 @@ def get_aabb_corners(box): |
167 | 165 | iou = max(0.0, min(iou, 1.0)) |
168 | 166 | return iou |
169 | 167 |
|
170 | | -# TODO: calc mAP |
| 168 | + |
| 169 | +def calculate_ap(precision, recall): |
| 170 | + """Calculates Average Precision using the PASCAL VOC method (area under monotonic PR curve).""" |
| 171 | + # Convert to numpy arrays first for safety |
| 172 | + if not isinstance(precision, (list, np.ndarray)) or not isinstance(recall, (list, np.ndarray)): |
| 173 | + return 0.0 |
| 174 | + precision = np.array(precision) |
| 175 | + recall = np.array(recall) |
| 176 | + |
| 177 | + if precision.size == 0 or recall.size == 0: |
| 178 | + return 0.0 |
| 179 | + |
| 180 | + # Prepend/Append points for interpolation boundaries |
| 181 | + recall = np.concatenate(([0.], recall, [1.0])) |
| 182 | + precision = np.concatenate(([0.], precision, [0.])) # Start with 0 precision at recall 0, end with 0 at recall 1 |
| 183 | + |
| 184 | + # Make precision monotonically decreasing (handles PR curve 'jiggles') |
| 185 | + for i in range(len(precision) - 2, -1, -1): |
| 186 | + precision[i] = max(precision[i], precision[i+1]) |
| 187 | + |
| 188 | + # Find indices where recall changes (avoids redundant calculations) |
| 189 | + indices = np.where(recall[1:] != recall[:-1])[0] + 1 |
| 190 | + |
| 191 | + # Compute AP using the area under the curve (sum of rectangle areas) |
| 192 | + ap = np.sum((recall[indices] - recall[indices-1]) * precision[indices]) |
| 193 | + return ap |
| 194 | + |
def evaluate_detector(gt_boxes_all_samples, pred_boxes_all_samples, classes, iou_threshold):
    """Evaluates a single detector's predictions against ground truth.

    Performs greedy, confidence-ordered matching of predictions to GT boxes
    per class (standard VOC-style protocol), then derives cumulative
    precision/recall and AP per class.

    Args:
        gt_boxes_all_samples: dict mapping sample_id -> {class_name: [boxes]}
            in the internal box format (class label at IDX_CLASS).
        pred_boxes_all_samples: same layout as the GT dict; each prediction
            box additionally carries a confidence score at IDX_SCORE.
        classes: iterable of class names to evaluate.
        iou_threshold: minimum 3D IoU (see calculate_3d_iou) for a
            prediction to count as a true positive.

    Returns:
        dict: class -> {'precision': ndarray, 'recall': ndarray, 'ap': float,
        'num_gt': int, 'num_pred': int}, where num_pred counts only
        predictions with a valid numeric score.

    NOTE(review): samples are enumerated from the GT dict's keys only, so
    predictions for a sample_id absent from gt_boxes_all_samples are silently
    ignored rather than counted as false positives — confirm callers always
    provide a (possibly empty) GT entry for every sample.
    """
    results_by_class = {}
    sample_ids = list(gt_boxes_all_samples.keys()) # Get fixed order of sample IDs

    for cls in classes:
        all_pred_boxes_cls = []
        num_gt_cls = 0
        pred_sample_indices = [] # Store index from sample_ids for each prediction

        # Collect all GTs and Preds for this class across samples
        for i, sample_id in enumerate(sample_ids):
            # Use .get() with default empty dict/list for safety
            gt_boxes = gt_boxes_all_samples.get(sample_id, {}).get(cls, [])
            pred_boxes = pred_boxes_all_samples.get(sample_id, {}).get(cls, [])

            num_gt_cls += len(gt_boxes)
            for box in pred_boxes:
                all_pred_boxes_cls.append(box)
                pred_sample_indices.append(i) # Store the original sample index

        if not all_pred_boxes_cls: # Handle case with no predictions for this class
            results_by_class[cls] = {
                'precision': np.array([]), # Use empty numpy arrays
                'recall': np.array([]),
                'ap': 0.0,
                'num_gt': num_gt_cls,
                'num_pred': 0
            }
            continue # Skip to next class

        # Sort detections by confidence score (descending)
        # Ensure scores exist and are numeric before sorting
        scores = []
        valid_indices_for_sorting = []
        for idx, box in enumerate(all_pred_boxes_cls):
            if len(box) > IDX_SCORE and isinstance(box[IDX_SCORE], (int, float)):
                scores.append(-box[IDX_SCORE]) # Use negative score for descending sort with argsort
                valid_indices_for_sorting.append(idx)
            else:
                warnings.warn(f"Class {cls}: Prediction missing score or invalid score type. Excluding from evaluation: {box}")

        if not valid_indices_for_sorting: # If filtering removed all boxes
            results_by_class[cls] = {'precision': np.array([]),'recall': np.array([]),'ap': 0.0,'num_gt': num_gt_cls,'num_pred': 0}
            continue

        # Filter lists based on valid scores
        all_pred_boxes_cls = [all_pred_boxes_cls[i] for i in valid_indices_for_sorting]
        pred_sample_indices = [pred_sample_indices[i] for i in valid_indices_for_sorting]
        # Scores list is already built correctly

        # Get the sorted order based on scores
        sorted_indices = np.argsort(scores) # argsort sorts ascending on negative scores -> descending order of original scores

        # Reorder the lists based on sorted scores
        all_pred_boxes_cls = [all_pred_boxes_cls[i] for i in sorted_indices]
        pred_sample_indices = [pred_sample_indices[i] for i in sorted_indices]


        tp = np.zeros(len(all_pred_boxes_cls))
        fp = np.zeros(len(all_pred_boxes_cls))
        # Track matched GTs per sample: gt_matched[sample_idx][gt_box_idx] = True/False
        gt_matched = defaultdict(lambda: defaultdict(bool)) # Indexed by sample_idx, then gt_idx

        # Match predictions greedily in descending score order; a GT box can
        # satisfy at most one prediction per sample.
        for det_idx, pred_box in enumerate(all_pred_boxes_cls):
            sample_idx = pred_sample_indices[det_idx] # Get the original sample index (0 to num_samples-1)
            sample_id = sample_ids[sample_idx] # Get the sample_id string using the index
            gt_boxes = gt_boxes_all_samples.get(sample_id, {}).get(cls, [])

            best_iou = -1.0
            best_gt_idx = -1 # Index relative to gt_boxes for this sample/class

            if not gt_boxes: # No GT for this class in this specific sample
                fp[det_idx] = 1
                continue

            for gt_idx, gt_box in enumerate(gt_boxes):
                # Explicitly check class match (belt-and-suspenders)
                if pred_box[IDX_CLASS] == gt_box[IDX_CLASS]:
                    iou = calculate_3d_iou(pred_box, gt_box)
                    if iou > best_iou:
                        best_iou = iou
                        best_gt_idx = gt_idx
                # else: # Should not happen if inputs are correctly filtered by class
                #     pass


            if best_iou >= iou_threshold:
                # Check if this GT box was already matched *in this sample*
                if not gt_matched[sample_idx].get(best_gt_idx, False):
                    tp[det_idx] = 1
                    gt_matched[sample_idx][best_gt_idx] = True # Mark as matched for this sample
                else:
                    fp[det_idx] = 1 # Matched a GT box already covered by a higher-scored prediction
            else:
                fp[det_idx] = 1 # Did not match any available GT box with sufficient IoU

        # Calculate precision/recall
        fp_cumsum = np.cumsum(fp)
        tp_cumsum = np.cumsum(tp)

        # Avoid division by zero if num_gt_cls is 0
        recall = tp_cumsum / num_gt_cls if num_gt_cls > 0 else np.zeros_like(tp_cumsum, dtype=float)

        # Avoid division by zero if no predictions were made or matched (tp + fp = 0)
        denominator = tp_cumsum + fp_cumsum
        precision = np.divide(tp_cumsum, denominator, out=np.zeros_like(tp_cumsum, dtype=float), where=denominator!=0)


        ap = calculate_ap(precision, recall)

        results_by_class[cls] = {
            'precision': precision, # Store as numpy arrays
            'recall': recall,
            'ap': ap,
            'num_gt': num_gt_cls,
            'num_pred': len(all_pred_boxes_cls) # Number of predictions *with valid scores*
        }

    return results_by_class
| 316 | + |
| 317 | + |
def plot_pr_curves(results_all_detectors, classes, output_dir):
    """Draw one precision-recall figure per class, overlaying every detector.

    Args:
        results_all_detectors: dict detector_name -> per-class results as
            produced by evaluate_detector (keys 'precision', 'recall', 'ap',
            'num_gt', 'num_pred').
        classes: iterable of class names to plot.
        output_dir: directory where 'pr_curve_<class>.png' files are written;
            created if missing.

    Side effects: writes PNG files and prints progress/skip messages; no
    return value.

    Fixes vs. previous version: removed the dead local `detector_names`
    (assigned but never used) and the redundant `precision.size > 0`
    conditional inside a branch that already guarantees it.
    """
    if not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir)
        except OSError as e:
            print(f"[LOG] Error creating output directory {output_dir} for plots: {e}")
            return  # Cannot save plots

    for cls in classes:
        plt.figure(figsize=(10, 7))
        any_results_for_class = False  # Track if any detector had results for this class

        for detector_name, results_by_class in results_all_detectors.items():
            res = results_by_class.get(cls)
            if res is None:
                # Class absent from this detector's results (e.g. all files
                # failed parsing) — leave the figure untouched.
                continue

            if res['num_pred'] > 0:  # Detector made predictions for this class
                precision = res['precision']
                recall = res['recall']
                ap = res['ap']

                if recall.size > 0 and precision.size > 0:
                    # Prepend a point at recall=0 so the curve starts at the axis.
                    plot_recall = np.concatenate(([0.], recall))
                    plot_precision = np.concatenate(([precision[0]], precision))
                    plt.plot(plot_recall, plot_precision, marker='.', markersize=4, linestyle='-', label=f'{detector_name} (AP={ap:.3f})')
                else:
                    # num_pred > 0 but P/R arrays ended up empty — flag it on the plot.
                    plt.plot([0], [0], marker='s', markersize=5, linestyle='', label=f'{detector_name} (AP={ap:.3f}, No P/R data?)')
                any_results_for_class = True
            elif res['num_gt'] > 0:
                # GT exists but no predictions were made for it; marker only —
                # deliberately does NOT mark the class as having results.
                num_gt = res['num_gt']
                plt.plot([0], [0], marker='x', markersize=6, linestyle='', label=f'{detector_name} (No Pred, GT={num_gt})')

        if any_results_for_class:
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title(f'Precision-Recall Curve for Class: {cls}')
            plt.legend(loc='lower left')
            plt.grid(True)
            plt.xlim([-0.05, 1.05])
            plt.ylim([-0.05, 1.05])
            plot_path = os.path.join(output_dir, f'pr_curve_{cls}.png')
            try:
                plt.savefig(plot_path)
                print(f"[LOG] Generated PR curve: {plot_path}")
            except Exception as e:
                print(f"[LOG] Error saving PR curve plot for class '{cls}': {e}")
            finally:
                plt.close()  # Close the figure regardless of save success
        else:
            # Explain why nothing was plotted: no predictions vs. no GT at all.
            num_gt_total = sum(results_by_class.get(cls, {}).get('num_gt', 0) for results_by_class in results_all_detectors.values())
            if num_gt_total > 0:
                print(f"  Skipping PR plot for class '{cls}': No predictions found across detectors (GT={num_gt_total}).")
            else:
                print(f"  Skipping PR plot for class '{cls}': No ground truth found.")
            plt.close()  # Close the empty figure
171 | 391 |
|
172 | 392 | def main(): |
173 | 393 | parser = argparse.ArgumentParser(description='Evaluate N 3D Object Detectors using KITTI format labels.') |
|
0 commit comments