-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdetection_worker.py
More file actions
117 lines (93 loc) · 4.1 KB
/
detection_worker.py
File metadata and controls
117 lines (93 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import threading
import cv2
import torch
import torch.nn as nn
import timm
import logging
from Detection.detector import load_model, detect_trucks
from deep_sort_realtime.deepsort_tracker import DeepSort
from Detection.db import init_db, is_already_saved, save_illegal_vehicle
from Detection.utils import match_with_track
import onnxruntime
from torchvision import transforms
from PIL import Image
import numpy as np
from PyQt5.QtCore import QObject, pyqtSignal
from datetime import datetime
import uuid
class WorkerSignals(QObject):
    """Qt signal container used to notify the GUI thread from worker threads."""
    # Emitted with the file path (str) of a newly saved vehicle image;
    # DetectionWorker fires it via its on_save_callback after a DB save succeeds.
    image_saved = pyqtSignal(str)
class DetectionWorker(threading.Thread):
    """Background thread that detects, tracks, and classifies trucks on a
    CCTV stream, saving vehicles classified as 'illegal' to the database.

    Parameters
    ----------
    stream_url : str
        Source passed to ``cv2.VideoCapture`` (stream URL or file path).
    cctvname : str
        Human-readable camera name, used in log output and DB records.
    signal_handler : WorkerSignals, optional
        Qt signal holder; its ``image_saved`` signal is emitted with the
        saved image path after a successful save.
    """

    # Give up after this many consecutive failed reads.  A dead stream keeps
    # cap.isOpened() True while read() returns False, so without a bound the
    # run() loop would busy-spin on `continue` forever.
    MAX_READ_FAILURES = 100

    def __init__(self, stream_url, cctvname, signal_handler=None):
        super().__init__()
        self.stream_url = stream_url
        self.cctvname = cctvname
        self.running = True
        self.signals = signal_handler  # PyQt signal handler (WorkerSignals or None)
        # Short per-session prefix so track IDs stay unique across restarts.
        self.session_uid = str(uuid.uuid4())[:8]
        # Load the YOLOv8 detector.  NOTE: verify the model path.
        self.model = load_model("Detection/model/yolov8_n.pt").to("cuda")
        logging.getLogger("ultralytics").setLevel(logging.ERROR)
        # Binary legal/illegal classification model.  NOTE: verify the model path.
        self.onnx_session = onnxruntime.InferenceSession(
            'Detection/model/final_classification.onnx',
            providers=['CUDAExecutionProvider']
        )
        self.onnx_input_name = self.onnx_session.get_inputs()[0].name
        # Preprocessing matching the classifier's training setup (224x224,
        # normalized to [-1, 1]).
        self.onnx_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        self.tracker = DeepSort(max_age=10, n_init=3)

    def run(self):
        """Main loop: read frames, detect and track trucks, classify each
        confirmed track's crop, and persist illegal vehicles to the DB.

        Runs until ``stop()`` is called, the capture closes, or
        ``MAX_READ_FAILURES`` consecutive reads fail.
        """
        conn, cursor = init_db()
        cap = cv2.VideoCapture(self.stream_url)
        print(f"[{self.cctvname}] 스트림 시작")

        def notify_image_saved(image_path):
            # Forward the saved-image path to the GUI thread via the Qt signal.
            # Hoisted out of the frame loop: the closure only captures `self`,
            # so there is no need to redefine it per detection.
            if self.signals:
                self.signals.image_saved.emit(image_path)

        read_failures = 0
        try:
            while self.running and cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    # BUGFIX: a bare `continue` busy-spun forever once the
                    # stream died; bail out after a bounded failure streak.
                    read_failures += 1
                    if read_failures >= self.MAX_READ_FAILURES:
                        print(f"[{self.cctvname}] stream read failed repeatedly; stopping")
                        break
                    continue
                read_failures = 0
                # 1. Truck detection
                truck_boxes = detect_trucks(self.model, frame)
                # 2. Tracking
                tracks = self.tracker.update_tracks(truck_boxes, frame=frame)
                for track in tracks:
                    if not track.is_confirmed():
                        continue
                    # 3. Session-unique tracking ID
                    track_id = track.track_id
                    unique_id = f"{self.session_uid}_{track_id}"
                    # 4. Crop the tracked bounding box.
                    # BUGFIX: clamp to frame bounds — to_ltrb() can return
                    # negative coordinates, and negative slice indices wrap
                    # around in numpy, producing a wrong crop.
                    x1, y1, x2, y2 = map(int, track.to_ltrb())
                    h, w = frame.shape[:2]
                    x1, y1 = max(0, x1), max(0, y1)
                    x2, y2 = min(w, x2), min(h, y2)
                    roi = frame[y1:y2, x1:x2]
                    if roi.size == 0:
                        continue
                    # 5. Classify the crop
                    label = self.classify_onnx(roi)
                    print('label:', label)
                    print('is_already_saved:', is_already_saved(cursor, unique_id))
                    # 6. Persist to DB (at most once per unique track)
                    if label == 'illegal' and not is_already_saved(cursor, unique_id):
                        print(f"[{self.cctvname}] 🚨 불법 차량 저장 (ID: {track_id})")
                        try:
                            save_illegal_vehicle(frame, track, unique_id, cursor, conn, self.cctvname, on_save_callback=notify_image_saved)
                        except Exception as e:
                            # Best-effort save: log and keep processing frames.
                            print('[❌ 저장 실패]', e)
        finally:
            cap.release()
            conn.close()
            print(f"[{self.cctvname}] 스트림 종료")

    def classify_onnx(self, image):
        """Classify a BGR crop as ``'illegal'`` or ``'legal'``.

        The ONNX model emits a single logit; a sigmoid converts it to a
        probability.  NOTE(review): ``prob < 0.5`` mapping to 'illegal'
        implies class 1 means 'legal' — confirm against the training
        label mapping.
        """
        pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        input_tensor = self.onnx_transform(pil_img).unsqueeze(0).numpy()
        output = self.onnx_session.run(None, {self.onnx_input_name: input_tensor})
        logit = output[0][0][0]
        prob = 1 / (1 + np.exp(-logit))
        return 'illegal' if prob < 0.5 else 'legal'

    def stop(self):
        """Request the worker loop to exit after the current iteration."""
        self.running = False