|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from dataclasses import dataclass |
4 | | -from typing import Iterator, Optional |
| 4 | +from typing import Iterator |
5 | 5 |
|
6 | 6 | import cv2 |
7 | 7 | import numpy as np |
@@ -94,3 +94,87 @@ def process_stream(self, source: str, display: bool = False) -> None: |
94 | 94 | except Exception: # noqa: BLE001 |
95 | 95 | logger.warning("Failed to destroy windows (likely headless environment)") |
96 | 96 |
|
| 97 | + |
class VideoProcessor:
    """Batch video processor using YOLOv8 + ByteTrack via supervision.

    Reads frames from an input video, runs detection, keeps only the
    configured class ids, tracks the remaining detections, draws boxes and
    labels, and writes the annotated frames to an output video.
    """

    # Class ids kept by default: person (0), car (2), truck (7).
    DEFAULT_CLASS_IDS = (0, 2, 7)

    # Class-level default so filtering also works on an instance that was
    # created without running ``__init__``.
    _allowed_class_ids = DEFAULT_CLASS_IDS

    # Used only when the class-name mapping cannot be read from the model.
    _FALLBACK_CLASS_NAMES = {
        0: "person",
        1: "bicycle",
        2: "car",
        3: "motorcycle",
        5: "bus",
        7: "truck",
    }

    def __init__(
        self,
        model_path: str = "yolov8m.pt",
        *,
        allowed_class_ids: tuple[int, ...] = DEFAULT_CLASS_IDS,
    ) -> None:
        """Load the model and create the tracker and annotators.

        Args:
            model_path: Weights file or model name passed to ``YOLO``.
            allowed_class_ids: Class ids that survive filtering. Defaults to
                person (0), car (2) and truck (7).
        """
        # Lazy imports to avoid importing heavy deps where not needed
        from ultralytics import YOLO  # type: ignore[import-not-found]
        import supervision as sv  # type: ignore[import-not-found]

        self._sv = sv
        self._model = YOLO(model_path)
        self._allowed_class_ids = tuple(int(cid) for cid in allowed_class_ids)

        # NOTE(review): the tracker is created once and never reset here, so
        # its state carries over between process_video() calls - confirm
        # whether it should be reset per video.
        self._tracker = sv.ByteTrack()

        # NOTE(review): BoundingBoxAnnotator may be deprecated in newer
        # supervision releases - confirm against the pinned version.
        self._box_annotator = sv.BoundingBoxAnnotator()
        self._label_annotator = sv.LabelAnnotator()

        # Class id -> name mapping, taken from the model when available.
        try:
            names = self._model.model.names  # type: ignore[attr-defined]
            if not isinstance(names, dict):
                # Normalise a sequence of names so ``.get`` always works.
                names = dict(enumerate(names))
            self._class_names = names
        except Exception:  # noqa: BLE001
            logger.warning("Could not read class names from model; using fallback mapping")
            self._class_names = dict(self._FALLBACK_CLASS_NAMES)

        logger.info("Initialized VideoProcessor with model: {}", model_path)

    def _filter_detections(self, detections: "np.ndarray | object") -> "object":
        """Return only the detections whose class id is allowed.

        Args:
            detections: A ``supervision.Detections`` instance; it is indexed
                with a boolean mask.

        Returns:
            The subset of ``detections`` whose ``class_id`` is in the allowed
            set. When ``class_id`` is ``None`` nothing can be matched, so an
            empty selection is returned.

        Raises:
            TypeError: If ``detections`` is not a ``supervision.Detections``.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``-O``.
        if not isinstance(detections, self._sv.Detections):
            raise TypeError(
                f"expected supervision.Detections, got {type(detections).__name__}"
            )

        class_ids = detections.class_id
        if class_ids is None:
            # np.isin(None, ...) would give a 0-d mask that cannot index.
            mask = np.zeros(len(detections), dtype=bool)
        else:
            mask = np.isin(class_ids, self._allowed_class_ids)
        return detections[mask]

    def _build_labels(self, tracked: "object") -> list:
        """Build one ``"<class> #<track id> <confidence>"`` label per detection.

        Missing arrays fall back to class id -1, confidence 0.0 and
        track id -1; unknown class ids are rendered as their number.
        """
        count = len(tracked)
        # Resolve the optional arrays once instead of once per detection.
        class_ids = tracked.class_id if tracked.class_id is not None else [-1] * count
        confidences = tracked.confidence if tracked.confidence is not None else [0.0] * count
        track_ids = tracked.tracker_id if tracked.tracker_id is not None else [-1] * count

        labels = []
        for raw_class, raw_conf, raw_track in zip(class_ids, confidences, track_ids):
            class_id = int(raw_class)
            class_name = self._class_names.get(class_id, str(class_id))
            labels.append(f"{class_name} #{int(raw_track)} {float(raw_conf):.2f}")
        return labels

    def process_video(self, input_path: str, output_path: str) -> None:
        """Detect, track and annotate every frame of a video.

        Args:
            input_path: Path of the video to read.
            output_path: Path of the annotated video to write.
        """
        # Reuse the module loaded in __init__ rather than importing it again.
        sv = self._sv

        video_info = sv.VideoInfo.from_video_path(input_path)
        frames = sv.get_video_frames_generator(input_path)

        # Use a broadly supported codec for MP4 writing in headless envs
        with sv.VideoSink(output_path, video_info, codec="mp4v") as sink:
            for frame in frames:
                # Inference
                result = self._model(frame, verbose=False)[0]
                detections = self._filter_detections(
                    sv.Detections.from_ultralytics(result)
                )

                # Tracking
                tracked = self._tracker.update_with_detections(detections)

                # Annotate a copy so the source frame stays untouched.
                annotated = self._box_annotator.annotate(
                    scene=frame.copy(),
                    detections=tracked,
                )
                annotated = self._label_annotator.annotate(
                    scene=annotated,
                    detections=tracked,
                    labels=self._build_labels(tracked),
                )

                sink.write_frame(annotated)
| 180 | + |
0 commit comments