Skip to content

Commit 8c82061

Browse files
committed
WIP: pedestrian ids
1 parent 6978bb1 commit 8c82061

1 file changed

Lines changed: 184 additions & 49 deletions

File tree

GEMstack/onboard/perception/pedestrian_detection.py

Lines changed: 184 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import time
12
from ...state import AllState,VehicleState,ObjectPose,ObjectFrameEnum,AgentState,AgentEnum,AgentActivityEnum
23
from ..interface.gem import GEMInterface
34
from ..component import Component
@@ -37,6 +38,11 @@ def __init__(self,vehicle_interface : GEMInterface):
3738
self.vehicle_interface = vehicle_interface
3839
self.detector = YOLO(os.getcwd()+'/GEMstack/knowledge/detection/yolov8n.pt') # change to get model value from sys arg
3940
self.last_person_boxes = []
41+
42+
self.disappeared_timers = {}
43+
self.last_frame_time = None
44+
self.max_disappear_time = 1.0
45+
4046
self.pedestrians = {}
4147
self.visualization = True # Set this to true for visualization, later change to get value from sys arg
4248
self.confidence = 0.7
@@ -91,89 +97,218 @@ def image_callback(self, image : Image): #image : cv2.Mat):
9197
# Use Image directly for GEM Car
9298
# track_result = self.detector.track(source=image, classes=self.classes_to_detect, persist=True, conf=self.confidence)
9399

100+
# Compute time delta (dt) since last frame
101+
current_time = time.time()
102+
if self.last_frame_time is None:
103+
self.last_frame_time = current_time
104+
dt = current_time - self.last_frame_time
105+
self.last_frame_time = current_time
106+
94107
# Convert to CV2 format for RosBag
95108
bridge = CvBridge()
96-
image = bridge.imgmsg_to_cv2(image, "bgr8")
97-
track_result = self.detector.track(source=image, classes=self.classes_to_detect, persist=True, conf=self.confidence)
109+
cv_image = bridge.imgmsg_to_cv2(image, "bgr8")
110+
# track_result = self.detector.track(source=image, classes=self.classes_to_detect, persist=True, conf=self.confidence)
98111

99-
self.last_person_boxes = []
100-
boxes = track_result[0].boxes
112+
# Run YOLO tracking
113+
track_result = self.detector.track(
114+
source=cv_image,
115+
classes=self.classes_to_detect,
116+
persist=True,
117+
conf=self.confidence
118+
)
101119

102-
# Used for visualization
103-
if(self.visualization):
104-
label_text = "Pedestrian "
105-
font = cv2.FONT_HERSHEY_SIMPLEX
106-
font_scale = 0.5
107-
font_color = (255, 255, 255) # White text
108-
outline_color = (0, 0, 0) # Black outline
109-
line_type = 1
110-
text_thickness = 2 # Text thickness
111-
outline_thickness = 1 # Thickness of the text outline
120+
# YOLOv8 returns a 'Boxes' object
121+
# We'll gather the IDs we see in this frame:
122+
current_ids = set()
123+
124+
if len(track_result) > 0:
125+
boxes = track_result[0].boxes
126+
else:
127+
boxes = []
128+
129+
# self.last_person_boxes = []
130+
# boxes = track_result[0].boxes
131+
132+
# # Used for visualization
133+
# if(self.visualization):
134+
# label_text = "Pedestrian "
135+
# font = cv2.FONT_HERSHEY_SIMPLEX
136+
# font_scale = 0.5
137+
# font_color = (255, 255, 255) # White text
138+
# outline_color = (0, 0, 0) # Black outline
139+
# line_type = 1
140+
# text_thickness = 2 # Text thickness
141+
# outline_thickness = 1 # Thickness of the text outline
112142

113143
# Unpacking detected box dimensions into x,y,w,h
114144
for box in boxes:
115145

116146
xywh = box.xywh[0].tolist()
117147
self.last_person_boxes.append(xywh)
118148
x, y, w, h = xywh
119-
id = box.id.item()
149+
# id = box.id.item()
120150

121-
# Stores AgentState in a dict, can be removed if not required
122-
pose = ObjectPose(t=0,x=x,y=y,z=0,yaw=0,pitch=0,roll=0,frame=ObjectFrameEnum.CURRENT)
123-
dims = (w,h,0)
124-
if(id not in self.pedestrians.keys()):
125-
self.pedestrians[id] = AgentState(pose=pose,dimensions=dims,outline=None,type=AgentEnum.PEDESTRIAN,activity=AgentActivityEnum.MOVING,velocity=(0,0,0),yaw_rate=0)
151+
# YOLO assigned ID
152+
track_id = box.id.item()
153+
current_ids.add(track_id)
154+
155+
# Check if we have seen this ID before
156+
if track_id not in self.pedestrians:
157+
# It's a new ID, create a new AgentState
158+
self.pedestrians[track_id] = box_to_fake_agent(xywh, velocity=(0,0,0))
126159
else:
127-
self.pedestrians[id].pose = pose
128-
self.pedestrians[id].dims = dims
160+
# Update existing
161+
old_pose = self.pedestrians[track_id].pose
162+
163+
# Calculate new center
164+
new_center_x = x + w/2
165+
new_center_y = y + h/2
166+
167+
# Compute velocity from old pose -> new pose (pixels/sec or similar)
168+
if dt > 0:
169+
vx = (new_center_x - old_pose.x) / dt
170+
vy = (new_center_y - old_pose.y) / dt
171+
else:
172+
vx = 0
173+
vy = 0
174+
175+
# Update the AgentState
176+
self.pedestrians[track_id].pose.x = new_center_x
177+
self.pedestrians[track_id].pose.y = new_center_y
178+
self.pedestrians[track_id].dimensions = (w, h, 0)
179+
self.pedestrians[track_id].velocity = (vx, vy, 0)
180+
181+
# Reset the disappeared timer for this ID
182+
self.disappeared_timers[track_id] = current_time
183+
184+
# Stores AgentState in a dict, can be removed if not required
185+
# pose = ObjectPose(t=0,x=x,y=y,z=0,yaw=0,pitch=0,roll=0,frame=ObjectFrameEnum.CURRENT)
186+
# dims = (w,h,0)
187+
# if(id not in self.pedestrians.keys()):
188+
# self.pedestrians[id] = AgentState(pose=pose,dimensions=dims,outline=None,type=AgentEnum.PEDESTRIAN,activity=AgentActivityEnum.MOVING,velocity=(0,0,0),yaw_rate=0)
189+
# else:
190+
# self.pedestrians[id].pose = pose
191+
# self.pedestrians[id].dims = dims
129192

130193
# Used for visualization
131-
if(self.visualization):
132-
# Draw bounding box
133-
cv2.rectangle(image, (int(x - w / 2), int(y - h / 2)), (int(x + w / 2), int(y + h / 2)), (255, 0, 255), 3)
194+
# if(self.visualization):
195+
# # Draw bounding box
196+
# cv2.rectangle(image, (int(x - w / 2), int(y - h / 2)), (int(x + w / 2), int(y + h / 2)), (255, 0, 255), 3)
134197

135-
# Define text label
136-
x = int(x - w / 2)
137-
y = int(y - h / 2)
138-
label = label_text + str(id) + " : " + str(round(box.conf.item(), 2))
198+
# # Define text label
199+
# x = int(x - w / 2)
200+
# y = int(y - h / 2)
201+
# label = label_text + str(id) + " : " + str(round(box.conf.item(), 2))
139202

140-
# Get text size
141-
text_size, baseline = cv2.getTextSize(label, font, font_scale, line_type)
142-
text_w, text_h = text_size
203+
# # Get text size
204+
# text_size, baseline = cv2.getTextSize(label, font, font_scale, line_type)
205+
# text_w, text_h = text_size
143206

144-
# Position text above the bounding box
145-
text_x = x
146-
text_y = y - 10 if y - 10 > 10 else y + h + text_h
207+
# # Position text above the bounding box
208+
# text_x = x
209+
# text_y = y - 10 if y - 10 > 10 else y + h + text_h
147210

148-
# Draw text outline for better visibility, uncomment for outline
149-
# for dx, dy in [(-1, -1), (-1, 1), (1, -1), (1, 1)]:
150-
# cv2.putText(image, label, (text_x + dx, text_y - baseline + dy), font, font_scale, outline_color, outline_thickness)
211+
# # Draw text outline for better visibility, uncomment for outline
212+
# # for dx, dy in [(-1, -1), (-1, 1), (1, -1), (1, 1)]:
213+
# # cv2.putText(image, label, (text_x + dx, text_y - baseline + dy), font, font_scale, outline_color, outline_thickness)
151214

152-
# Draw main text on top of the outline
153-
cv2.putText(image, label, (text_x, text_y - baseline), font, font_scale, font_color, text_thickness)
215+
# # Draw main text on top of the outline
216+
# cv2.putText(image, label, (text_x, text_y - baseline), font, font_scale, font_color, text_thickness)
154217

155218

156219
# Used for visualization
157220
if(self.visualization):
158-
ros_img = bridge.cv2_to_imgmsg(image, 'bgr8')
221+
self._visualize(cv_image, boxes)
222+
223+
# Publish the annotated image
224+
ros_img = bridge.cv2_to_imgmsg(cv_image, 'bgr8')
159225
self.pub_image.publish(ros_img)
160226

227+
# Now handle "disappeared" IDs:
228+
# If an ID hasn't appeared for > max_disappear_time, remove it.
229+
to_remove = []
230+
for pid, last_seen_time in self.disappeared_timers.items():
231+
if (current_time - last_seen_time) > self.max_disappear_time:
232+
to_remove.append(pid)
233+
234+
for pid in to_remove:
235+
self.disappeared_timers.pop(pid, None)
236+
self.pedestrians.pop(pid, None)
237+
161238
#uncomment if you want to debug the detector...
162239
# print(self.last_person_boxes)
163240
# print(self.pedestrians.keys())
164241
#for bb in self.last_person_boxes:
165242
# x,y,w,h = bb
166243
# cv2.rectangle(image, (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), (255, 0, 255), 3)
167244
#cv2.imwrite("pedestrian_detections.png",image)
245+
246+
def _visualize(self, image, boxes):
247+
"""
248+
Overlays bounding boxes and labels on the image for debugging/visualization.
249+
"""
250+
label_text = "Pedestrian "
251+
font = cv2.FONT_HERSHEY_SIMPLEX
252+
font_scale = 0.5
253+
font_color = (255, 255, 255) # White text
254+
line_type = 1
255+
text_thickness = 2
256+
257+
for box in boxes:
258+
xywh = box.xywh[0].tolist()
259+
x, y, w, h = xywh
260+
track_id = box.id.item()
261+
conf = box.conf.item()
262+
263+
# Draw bounding box
264+
x1, y1 = int(x - w/2), int(y - h/2)
265+
x2, y2 = int(x + w/2), int(y + h/2)
266+
cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 255), 3)
267+
268+
# Construct label
269+
label = f"{label_text}{track_id} : {conf:.2f}"
270+
271+
# Position text above the bounding box
272+
text_size, baseline = cv2.getTextSize(label, font, font_scale, line_type)
273+
text_w, text_h = text_size
274+
text_x = x1
275+
text_y = y1 - 10 if y1 - 10 > 10 else y2 + text_h
276+
277+
# Draw text
278+
cv2.putText(
279+
image,
280+
label,
281+
(text_x, text_y - baseline),
282+
font,
283+
font_scale,
284+
font_color,
285+
text_thickness
286+
)
168287

169-
def update(self, vehicle : VehicleState) -> Dict[str,AgentState]:
170-
res = {}
171-
for i,b in enumerate(self.last_person_boxes):
172-
x,y,w,h = b
173-
res['pedestrian'+str(i)] = box_to_fake_agent(b)
174-
if len(res) > 0:
175-
print("Detected",len(res),"pedestrians")
176-
return res
288+
def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
289+
"""
290+
Called at the rate specified by self.rate().
291+
Returns a dictionary of {agent_name: AgentState} for all currently tracked pedestrians.
292+
"""
293+
# You can name them 'pedestrian0', 'pedestrian1', etc. based on their YOLO ID
294+
agents = {}
295+
for pid, agent_state in self.pedestrians.items():
296+
agents[f"pedestrian_{pid}"] = agent_state
297+
298+
# If you want to see the console output:
299+
if len(agents) > 0:
300+
print(f"Currently tracking {len(agents)} pedestrians.")
301+
return agents
302+
303+
304+
# def update(self, vehicle : VehicleState) -> Dict[str,AgentState]:
305+
# res = {}
306+
# for i,b in enumerate(self.last_person_boxes):
307+
# x,y,w,h = b
308+
# res['pedestrian'+str(i)] = box_to_fake_agent(b)
309+
# if len(res) > 0:
310+
# print("Detected",len(res),"pedestrians")
311+
# return res
177312

178313

179314
class FakePedestrianDetector2D(Component):

0 commit comments

Comments
 (0)