Skip to content

Commit 10d9e02

Browse files
committed
Improved YOLO integration into combined detector. Uploaded yaml file that was missed. Removed YOLO node that was missed.
1 parent 40bedbe commit 10d9e02

3 files changed

Lines changed: 195 additions & 288 deletions

File tree

GEMstack/onboard/perception/combined_detection.py

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -231,16 +231,20 @@ def __init__(
231231
camera_name,
232232
camera_calib_file,
233233
iou_threshold: float = 0.1,
234+
score_threshold: float = 0.4,
234235
merge_mode: str = "Average",
235236
enable_tracking: bool = True,
236237
use_start_frame: bool = True,
238+
slop = 0.1,
239+
max_bb_buffer_size = 10,
240+
visualize = True,
237241
**kwargs
238242
):
239243
self.vehicle_interface = vehicle_interface
240244
self.tracked_agents: Dict[str, AgentState] = {}
241245
self.ped_counter = 0
242-
self.latest_yolo_bbxs: Optional[BoundingBoxArray] = None
243246
self.latest_pp_bbxs: Optional[BoundingBoxArray] = None
247+
self.latest_yolo_bbxs: Optional[BoundingBoxArray] = None
244248
self.start_pose_abs: Optional[ObjectPose] = None
245249
self.start_time: Optional[float] = None
246250

@@ -261,8 +265,8 @@ def __init__(
261265
self.bridge = CvBridge()
262266
self.camera_name = camera_name
263267
self.camera_front = (self.camera_name == 'front')
264-
self.score_threshold = 0.4
265-
self.debug = True
268+
self.score_threshold = score_threshold
269+
self.visualize = visualize
266270

267271
# Load camera intrinsics/extrinsics from YAML
268272
with open(camera_calib_file, 'r') as f:
@@ -283,7 +287,10 @@ def __init__(
283287
self.undistort_map1 = None
284288
self.undistort_map2 = None
285289
self.camera_front = (self.camera_name == 'front')
286-
290+
self.yolo_buffer: List[BoundingBoxArray] = []
291+
self.pp_buffer: List[BoundingBoxArray] = []
292+
self.max_bb_buffer_size = max_bb_buffer_size
293+
self.slop = slop
287294

288295
def rate(self) -> float:
289296
return 8.0
@@ -324,20 +331,52 @@ def initialize(self):
324331
self.sync.registerCallback(self.synchronized_yolo_callback)
325332

326333
self.yolo_sub = Subscriber(self.yolo_topic, BoundingBoxArray)
327-
self.pp_sub = Subscriber(self.pp_topic, BoundingBoxArray)
334+
rospy.Subscriber(self.pp_topic, BoundingBoxArray, self.store_pp_array, queue_size=10)
328335
self.pub_fused = rospy.Publisher("/fused_boxes", BoundingBoxArray, queue_size=1)
329336

330-
queue_size = 10
331-
slop = 0.1
332-
333-
self.sync = ApproximateTimeSynchronizer(
334-
[self.yolo_sub, self.pp_sub],
335-
queue_size=queue_size,
336-
slop=slop
337-
)
338-
self.sync.registerCallback(self.synchronized_callback)
339337
rospy.loginfo("CombinedDetector3D Subscribers Initialized.")
340338

339+
def add_bb_array(self, bb_buffer: List[BoundingBoxArray], add_arr: BoundingBoxArray):
340+
bb_buffer.append(add_arr)
341+
342+
# If buffer exceeds max size, remove the YOLO bounding box array with the oldest timestamp
343+
if len(bb_buffer) > self.max_bb_buffer_size:
344+
oldest_index = min(range(len(bb_buffer)),
345+
key=lambda i: bb_buffer[i].header.stamp.to_sec())
346+
del bb_buffer[oldest_index]
347+
348+
def get_bb_arrays_sorted(self, bb_buffer: List[BoundingBoxArray]) -> List[BoundingBoxArray]:
349+
return sorted(bb_buffer, key=lambda curr_bb_arr: curr_bb_arr.header.stamp.to_sec())
350+
351+
def store_pp_array(self, bbxs_msg: BoundingBoxArray):
352+
self.add_bb_array(bb_buffer=self.pp_buffer, add_arr=bbxs_msg) # Store the boxes array for later comparison to YOLO bounding box array
353+
self.try_match()
354+
355+
def store_yolo_array(self, bbxs_msg: BoundingBoxArray):
356+
self.add_bb_array(bb_buffer=self.yolo_buffer, add_arr=bbxs_msg) # Store the boxes array for later comparison to PointPillars bounding box array
357+
self.try_match()
358+
359+
def try_match(self):
360+
matched = []
361+
362+
for i, yolo_bbxs in enumerate(self.yolo_buffer):
363+
time_yolo = yolo_bbxs.header.stamp.to_sec()
364+
365+
for j, pp_bbxs in enumerate(self.pp_buffer):
366+
time_pp = pp_bbxs.header.stamp.to_sec()
367+
368+
if abs(time_yolo - time_pp) <= self.slop:
369+
matched_pair = (copy.deepcopy(yolo_bbxs), copy.deepcopy(pp_bbxs))
370+
matched.append((i, j, matched_pair))
371+
break # We only want one match
372+
373+
# Remove the match messages
374+
for i, j, (yolo_bbxs, pp_bbxs) in reversed(matched):
375+
del self.yolo_buffer[i]
376+
del self.pp_buffer[j]
377+
self.latest_pp_bbxs = pp_bbxs
378+
self.latest_yolo_bbxs = yolo_bbxs
379+
341380
def synchronized_yolo_callback(self, image_msg, lidar_msg):
342381
"""Process synchronized RGB and LiDAR messages to detect pedestrians."""
343382
rospy.loginfo("Received synchronized RGB and LiDAR messages")
@@ -482,9 +521,13 @@ def synchronized_yolo_callback(self, image_msg, lidar_msg):
482521
f"{refined_center_vehicle[1]:.2f}, {refined_center_vehicle[2]:.2f}) "
483522
f"with score {conf_scores[i]:.2f}")
484523

485-
# Publish the bounding boxes
486-
rospy.loginfo(f"Publishing {len(boxes.boxes)} person bounding boxes")
487-
self.pub_yolo.publish(boxes)
524+
525+
self.store_yolo_array(bbxs_msg=boxes) # Store the boxes array for later comparison to PointPillars bounding box array
526+
527+
# Publish the bounding boxes for visualization
528+
if self.visualize:
529+
rospy.loginfo(f"Publishing {len(boxes.boxes)} person bounding boxes")
530+
self.pub_yolo.publish(boxes)
488531

489532
def undistort_image(self, image, K, D):
490533
"""Undistort an image using the camera calibration parameters."""
@@ -503,11 +546,6 @@ def undistort_image(self, image, K, D):
503546
# print('--------undistort', end-start)
504547
return undistorted, newK
505548

506-
def synchronized_callback(self, yolo_bbxs_msg: BoundingBoxArray, pp_bbxs_msg: BoundingBoxArray):
507-
"""Callback for synchronized YOLO and PointPillars messages."""
508-
self.latest_yolo_bbxs = yolo_bbxs_msg
509-
self.latest_pp_bbxs = pp_bbxs_msg
510-
511549
def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
512550
"""Update function called by the GEMstack pipeline."""
513551
current_time = self.vehicle_interface.time()
@@ -613,7 +651,8 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
613651
agents[agent_id] = new_agent
614652
self.tracked_agents[agent_id] = new_agent
615653

616-
self.pub_fused.publish(fused_bb_array)
654+
if self.visualize:
655+
self.pub_fused.publish(fused_bb_array)
617656

618657
stale_ids = [agent_id for agent_id, agent in self.tracked_agents.items()
619658
if current_time - agent.pose.t > 5.0]

0 commit comments

Comments
 (0)