|
| 1 | +import time |
1 | 2 | from ...state import AllState,VehicleState,ObjectPose,ObjectFrameEnum,AgentState,AgentEnum,AgentActivityEnum |
2 | 3 | from ..interface.gem import GEMInterface |
3 | 4 | from ..component import Component |
@@ -37,6 +38,11 @@ def __init__(self,vehicle_interface : GEMInterface): |
37 | 38 | self.vehicle_interface = vehicle_interface |
38 | 39 | self.detector = YOLO(os.getcwd()+'/GEMstack/knowledge/detection/yolov8n.pt') # change to get model value from sys arg |
39 | 40 | self.last_person_boxes = [] |
| 41 | + |
| 42 | + self.disappeared_timers = {} |
| 43 | + self.last_frame_time = None |
| 44 | + self.max_disappear_time = 1.0 |
| 45 | + |
40 | 46 | self.pedestrians = {} |
41 | 47 | self.visualization = True # Set this to true for visualization, later change to get value from sys arg |
42 | 48 | self.confidence = 0.7 |
@@ -91,89 +97,218 @@ def image_callback(self, image : Image): #image : cv2.Mat): |
91 | 97 | # Use Image directly for GEM Car |
92 | 98 | # track_result = self.detector.track(source=image, classes=self.classes_to_detect, persist=True, conf=self.confidence) |
93 | 99 |
|
| 100 | + # Compute time delta (dt) since last frame |
| 101 | + current_time = time.time() |
| 102 | + if self.last_frame_time is None: |
| 103 | + self.last_frame_time = current_time |
| 104 | + dt = current_time - self.last_frame_time |
| 105 | + self.last_frame_time = current_time |
| 106 | + |
94 | 107 | # Convert to CV2 format for RosBag |
95 | 108 | bridge = CvBridge() |
96 | | - image = bridge.imgmsg_to_cv2(image, "bgr8") |
97 | | - track_result = self.detector.track(source=image, classes=self.classes_to_detect, persist=True, conf=self.confidence) |
| 109 | + cv_image = bridge.imgmsg_to_cv2(image, "bgr8") |
| 110 | + # track_result = self.detector.track(source=image, classes=self.classes_to_detect, persist=True, conf=self.confidence) |
98 | 111 |
|
99 | | - self.last_person_boxes = [] |
100 | | - boxes = track_result[0].boxes |
| 112 | + # Run YOLO tracking |
| 113 | + track_result = self.detector.track( |
| 114 | + source=cv_image, |
| 115 | + classes=self.classes_to_detect, |
| 116 | + persist=True, |
| 117 | + conf=self.confidence |
| 118 | + ) |
101 | 119 |
|
102 | | - # Used for visualization |
103 | | - if(self.visualization): |
104 | | - label_text = "Pedestrian " |
105 | | - font = cv2.FONT_HERSHEY_SIMPLEX |
106 | | - font_scale = 0.5 |
107 | | - font_color = (255, 255, 255) # White text |
108 | | - outline_color = (0, 0, 0) # Black outline |
109 | | - line_type = 1 |
110 | | - text_thickness = 2 # Text thickness |
111 | | - outline_thickness = 1 # Thickness of the text outline |
| 120 | + # YOLOv8 returns a 'Boxes' object |
| 121 | + # We'll gather the IDs we see in this frame: |
| 122 | + current_ids = set() |
| 123 | + |
| 124 | + if len(track_result) > 0: |
| 125 | + boxes = track_result[0].boxes |
| 126 | + else: |
| 127 | + boxes = [] |
| 128 | + |
| 129 | + # self.last_person_boxes = [] |
| 130 | + # boxes = track_result[0].boxes |
| 131 | + |
| 132 | + # # Used for visualization |
| 133 | + # if(self.visualization): |
| 134 | + # label_text = "Pedestrian " |
| 135 | + # font = cv2.FONT_HERSHEY_SIMPLEX |
| 136 | + # font_scale = 0.5 |
| 137 | + # font_color = (255, 255, 255) # White text |
| 138 | + # outline_color = (0, 0, 0) # Black outline |
| 139 | + # line_type = 1 |
| 140 | + # text_thickness = 2 # Text thickness |
| 141 | + # outline_thickness = 1 # Thickness of the text outline |
112 | 142 |
|
113 | 143 | # Unpacking box dimensions detected into x,y,w,h |
114 | 144 | for box in boxes: |
115 | 145 |
|
116 | 146 | xywh = box.xywh[0].tolist() |
117 | 147 | self.last_person_boxes.append(xywh) |
118 | 148 | x, y, w, h = xywh |
119 | | - id = box.id.item() |
| 149 | + # id = box.id.item() |
120 | 150 |
|
121 | | - # Stores AgentState in a dict, can be removed if not required |
122 | | - pose = ObjectPose(t=0,x=x,y=y,z=0,yaw=0,pitch=0,roll=0,frame=ObjectFrameEnum.CURRENT) |
123 | | - dims = (w,h,0) |
124 | | - if(id not in self.pedestrians.keys()): |
125 | | - self.pedestrians[id] = AgentState(pose=pose,dimensions=dims,outline=None,type=AgentEnum.PEDESTRIAN,activity=AgentActivityEnum.MOVING,velocity=(0,0,0),yaw_rate=0) |
| 151 | + # YOLO assigned ID |
| 152 | + track_id = box.id.item() |
| 153 | + current_ids.add(track_id) |
| 154 | + |
| 155 | + # Check if we have seen this ID before |
| 156 | + if track_id not in self.pedestrians: |
| 157 | + # It's a new ID, create a new AgentState |
| 158 | + self.pedestrians[track_id] = box_to_fake_agent(xywh, velocity=(0,0,0)) |
126 | 159 | else: |
127 | | - self.pedestrians[id].pose = pose |
128 | | - self.pedestrians[id].dims = dims |
| 160 | + # Update existing |
| 161 | + old_pose = self.pedestrians[track_id].pose |
| 162 | + |
| 163 | + # Calculate new center |
| 164 | + new_center_x = x + w/2 |
| 165 | + new_center_y = y + h/2 |
| 166 | + |
| 167 | + # Compute velocity from old pose -> new pose (pixels/sec or similar) |
| 168 | + if dt > 0: |
| 169 | + vx = (new_center_x - old_pose.x) / dt |
| 170 | + vy = (new_center_y - old_pose.y) / dt |
| 171 | + else: |
| 172 | + vx = 0 |
| 173 | + vy = 0 |
| 174 | + |
| 175 | + # Update the AgentState |
| 176 | + self.pedestrians[track_id].pose.x = new_center_x |
| 177 | + self.pedestrians[track_id].pose.y = new_center_y |
| 178 | + self.pedestrians[track_id].dimensions = (w, h, 0) |
| 179 | + self.pedestrians[track_id].velocity = (vx, vy, 0) |
| 180 | + |
| 181 | + # Reset the disappeared timer for this ID |
| 182 | + self.disappeared_timers[track_id] = current_time |
| 183 | + |
| 184 | + # Stores AgentState in a dict, can be removed if not required |
| 185 | + # pose = ObjectPose(t=0,x=x,y=y,z=0,yaw=0,pitch=0,roll=0,frame=ObjectFrameEnum.CURRENT) |
| 186 | + # dims = (w,h,0) |
| 187 | + # if(id not in self.pedestrians.keys()): |
| 188 | + # self.pedestrians[id] = AgentState(pose=pose,dimensions=dims,outline=None,type=AgentEnum.PEDESTRIAN,activity=AgentActivityEnum.MOVING,velocity=(0,0,0),yaw_rate=0) |
| 189 | + # else: |
| 190 | + # self.pedestrians[id].pose = pose |
| 191 | + # self.pedestrians[id].dims = dims |
129 | 192 |
|
130 | 193 | # Used for visualization |
131 | | - if(self.visualization): |
132 | | - # Draw bounding box |
133 | | - cv2.rectangle(image, (int(x - w / 2), int(y - h / 2)), (int(x + w / 2), int(y + h / 2)), (255, 0, 255), 3) |
| 194 | + # if(self.visualization): |
| 195 | + # # Draw bounding box |
| 196 | + # cv2.rectangle(image, (int(x - w / 2), int(y - h / 2)), (int(x + w / 2), int(y + h / 2)), (255, 0, 255), 3) |
134 | 197 |
|
135 | | - # Define text label |
136 | | - x = int(x - w / 2) |
137 | | - y = int(y - h / 2) |
138 | | - label = label_text + str(id) + " : " + str(round(box.conf.item(), 2)) |
| 198 | + # # Define text label |
| 199 | + # x = int(x - w / 2) |
| 200 | + # y = int(y - h / 2) |
| 201 | + # label = label_text + str(id) + " : " + str(round(box.conf.item(), 2)) |
139 | 202 |
|
140 | | - # Get text size |
141 | | - text_size, baseline = cv2.getTextSize(label, font, font_scale, line_type) |
142 | | - text_w, text_h = text_size |
| 203 | + # # Get text size |
| 204 | + # text_size, baseline = cv2.getTextSize(label, font, font_scale, line_type) |
| 205 | + # text_w, text_h = text_size |
143 | 206 |
|
144 | | - # Position text above the bounding box |
145 | | - text_x = x |
146 | | - text_y = y - 10 if y - 10 > 10 else y + h + text_h |
| 207 | + # # Position text above the bounding box |
| 208 | + # text_x = x |
| 209 | + # text_y = y - 10 if y - 10 > 10 else y + h + text_h |
147 | 210 |
|
148 | | - # Draw text outline for better visibility, uncomment for outline |
149 | | - # for dx, dy in [(-1, -1), (-1, 1), (1, -1), (1, 1)]: |
150 | | - # cv2.putText(image, label, (text_x + dx, text_y - baseline + dy), font, font_scale, outline_color, outline_thickness) |
| 211 | + # # Draw text outline for better visibility, uncomment for outline |
| 212 | + # # for dx, dy in [(-1, -1), (-1, 1), (1, -1), (1, 1)]: |
| 213 | + # # cv2.putText(image, label, (text_x + dx, text_y - baseline + dy), font, font_scale, outline_color, outline_thickness) |
151 | 214 |
|
152 | | - # Draw main text on top of the outline |
153 | | - cv2.putText(image, label, (text_x, text_y - baseline), font, font_scale, font_color, text_thickness) |
| 215 | + # # Draw main text on top of the outline |
| 216 | + # cv2.putText(image, label, (text_x, text_y - baseline), font, font_scale, font_color, text_thickness) |
154 | 217 |
|
155 | 218 |
|
156 | 219 | # Used for visualization |
157 | 220 | if(self.visualization): |
158 | | - ros_img = bridge.cv2_to_imgmsg(image, 'bgr8') |
| 221 | + self._visualize(cv_image, boxes) |
| 222 | + |
| 223 | + # Publish the annotated image |
| 224 | + ros_img = bridge.cv2_to_imgmsg(cv_image, 'bgr8') |
159 | 225 | self.pub_image.publish(ros_img) |
160 | 226 |
|
| 227 | + # Now handle "disappeared" IDs: |
| 228 | + # If an ID hasn't appeared for > max_disappear_time, remove it. |
| 229 | + to_remove = [] |
| 230 | + for pid, last_seen_time in self.disappeared_timers.items(): |
| 231 | + if (current_time - last_seen_time) > self.max_disappear_time: |
| 232 | + to_remove.append(pid) |
| 233 | + |
| 234 | + for pid in to_remove: |
| 235 | + self.disappeared_timers.pop(pid, None) |
| 236 | + self.pedestrians.pop(pid, None) |
| 237 | + |
161 | 238 | #uncomment if you want to debug the detector... |
162 | 239 | # print(self.last_person_boxes) |
163 | 240 | # print(self.pedestrians.keys()) |
164 | 241 | #for bb in self.last_person_boxes: |
165 | 242 | # x,y,w,h = bb |
166 | 243 | # cv2.rectangle(image, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), (255, 0, 255), 3) |
167 | 244 | #cv2.imwrite("pedestrian_detections.png",image) |
| 245 | + |
| 246 | + def _visualize(self, image, boxes): |
| 247 | + """ |
| 248 | + Overlays bounding boxes and labels on the image for debugging/visualization. |
| 249 | + """ |
| 250 | + label_text = "Pedestrian " |
| 251 | + font = cv2.FONT_HERSHEY_SIMPLEX |
| 252 | + font_scale = 0.5 |
| 253 | + font_color = (255, 255, 255) # White text |
| 254 | + line_type = 1 |
| 255 | + text_thickness = 2 |
| 256 | + |
| 257 | + for box in boxes: |
| 258 | + xywh = box.xywh[0].tolist() |
| 259 | + x, y, w, h = xywh |
| 260 | + track_id = box.id.item() |
| 261 | + conf = box.conf.item() |
| 262 | + |
| 263 | + # Draw bounding box |
| 264 | + x1, y1 = int(x - w/2), int(y - h/2) |
| 265 | + x2, y2 = int(x + w/2), int(y + h/2) |
| 266 | + cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 255), 3) |
| 267 | + |
| 268 | + # Construct label |
| 269 | + label = f"{label_text}{track_id} : {conf:.2f}" |
| 270 | + |
| 271 | + # Position text above the bounding box |
| 272 | + text_size, baseline = cv2.getTextSize(label, font, font_scale, line_type) |
| 273 | + text_w, text_h = text_size |
| 274 | + text_x = x1 |
| 275 | + text_y = y1 - 10 if y1 - 10 > 10 else y2 + text_h |
| 276 | + |
| 277 | + # Draw text |
| 278 | + cv2.putText( |
| 279 | + image, |
| 280 | + label, |
| 281 | + (text_x, text_y - baseline), |
| 282 | + font, |
| 283 | + font_scale, |
| 284 | + font_color, |
| 285 | + text_thickness |
| 286 | + ) |
168 | 287 |
|
169 | | - def update(self, vehicle : VehicleState) -> Dict[str,AgentState]: |
170 | | - res = {} |
171 | | - for i,b in enumerate(self.last_person_boxes): |
172 | | - x,y,w,h = b |
173 | | - res['pedestrian'+str(i)] = box_to_fake_agent(b) |
174 | | - if len(res) > 0: |
175 | | - print("Detected",len(res),"pedestrians") |
176 | | - return res |
| 288 | + def update(self, vehicle: VehicleState) -> Dict[str, AgentState]: |
| 289 | + """ |
| 290 | + Called at the rate specified by self.rate(). |
| 291 | + Returns a dictionary of {agent_name: AgentState} for all currently tracked pedestrians. |
| 292 | + """ |
| 293 | + # You can name them 'pedestrian0', 'pedestrian1', etc. based on their YOLO ID |
| 294 | + agents = {} |
| 295 | + for pid, agent_state in self.pedestrians.items(): |
| 296 | + agents[f"pedestrian_{pid}"] = agent_state |
| 297 | + |
| 298 | + # If you want to see the console output: |
| 299 | + if len(agents) > 0: |
| 300 | + print(f"Currently tracking {len(agents)} pedestrians.") |
| 301 | + return agents |
| 302 | + |
| 303 | + |
| 304 | + # def update(self, vehicle : VehicleState) -> Dict[str,AgentState]: |
| 305 | + # res = {} |
| 306 | + # for i,b in enumerate(self.last_person_boxes): |
| 307 | + # x,y,w,h = b |
| 308 | + # res['pedestrian'+str(i)] = box_to_fake_agent(b) |
| 309 | + # if len(res) > 0: |
| 310 | + # print("Detected",len(res),"pedestrians") |
| 311 | + # return res |
177 | 312 |
|
178 | 313 |
|
179 | 314 | class FakePedestrianDetector2D(Component): |
|
0 commit comments