Skip to content

Commit cfd3994

Browse files
combined_detection.py add merge, fusing 3d bboxes
1 parent 75dfafb commit cfd3994

1 file changed

Lines changed: 195 additions & 66 deletions

File tree

GEMstack/onboard/perception/combined_detection.py

Lines changed: 195 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,72 @@
88
import time
99
import os
1010
import yaml
11+
from typing import Dict, List, Optional, Tuple
12+
import numpy as np
13+
from scipy.spatial.transform import Rotation as R
14+
1115

1216
from jsk_recognition_msgs.msg import BoundingBox, BoundingBoxArray
1317

1418

15-
class CombinedDetector3D(Component):
16-
"""
17-
Fuses the boxes in the lists of bounding boxes published by YoloNode and
18-
PointPillarsNode with late sensor fusion.
19-
TODO: SUBSCRIBE TO BOUNDING BOX LISTS AND PERFORM LATE SENSOR FUSION IN THIS FILE.
20-
TODO: MODIFY YAML FILE FOR THE CONTROL TEAM'S BASIC PATH PLANNING CODE?
19+
# TODO: Import IOU from IOU funcs in SensorFusion?

# Reuse eval funcs?
def calculate_3d_iou(box1: "BoundingBox", box2: "BoundingBox") -> float:
    """Return the axis-aligned 3D IoU of two jsk ``BoundingBox`` messages.

    Both boxes are assumed to be expressed in the same (vehicle) frame.

    NOTE(review): orientation is ignored — each box is treated as axis-aligned
    using its center position and dimensions. This is adequate for upright
    pedestrians/vehicles; swap in an oriented-IoU (e.g. SensorFusion's eval
    funcs, per the TODO above) if yawed boxes must match precisely.

    Returns:
        IoU in [0, 1]; 0.0 for disjoint or degenerate (zero-volume) boxes.
    """
    intersection = 1.0
    for axis in ("x", "y", "z"):
        c1 = getattr(box1.pose.position, axis)
        c2 = getattr(box2.pose.position, axis)
        d1 = getattr(box1.dimensions, axis)
        d2 = getattr(box2.dimensions, axis)
        lo = max(c1 - d1 / 2.0, c2 - d2 / 2.0)
        hi = min(c1 + d1 / 2.0, c2 + d2 / 2.0)
        if hi <= lo:
            return 0.0  # no overlap along this axis => disjoint boxes
        intersection *= hi - lo
    vol1 = box1.dimensions.x * box1.dimensions.y * box1.dimensions.z
    vol2 = box2.dimensions.x * box2.dimensions.y * box2.dimensions.z
    union = vol1 + vol2 - intersection
    if union <= 0.0:
        return 0.0  # degenerate boxes
    return intersection / union
24+
25+
def merge_boxes(box1: "BoundingBox", box2: "BoundingBox") -> "BoundingBox":
    """Fuse two matched detections of the same object into a single box.

    Heuristics:
      - Average position and dimensions.
      - Average orientation (sign-aligned, renormalized quaternion mean).
      - Keep the highest confidence score (``value`` field).
      - Keep ``box1``'s label and header.

    NOTE(review): label-specific logic (e.g. trusting the camera label over
    the lidar label per class) is still TODO.
    """
    merged_box = BoundingBox()
    merged_box.header = box1.header  # Use header from one input

    # Average position.
    merged_box.pose.position.x = (box1.pose.position.x + box2.pose.position.x) / 2.0
    merged_box.pose.position.y = (box1.pose.position.y + box2.pose.position.y) / 2.0
    merged_box.pose.position.z = (box1.pose.position.z + box2.pose.position.z) / 2.0

    # Average orientation. (An earlier revision copied box1's orientation
    # verbatim despite its "Avg orientation" comment.) Summing sign-aligned
    # quaternions and renormalizing is a valid mean for nearby orientations;
    # fall back to box1's orientation if the sum is degenerate.
    q1 = box1.pose.orientation
    q2 = box2.pose.orientation
    dot = q1.x * q2.x + q1.y * q2.y + q1.z * q2.z + q1.w * q2.w
    sign = 1.0 if dot >= 0.0 else -1.0  # q and -q encode the same rotation
    sx = q1.x + sign * q2.x
    sy = q1.y + sign * q2.y
    sz = q1.z + sign * q2.z
    sw = q1.w + sign * q2.w
    norm = (sx * sx + sy * sy + sz * sz + sw * sw) ** 0.5
    if norm > 1e-9:
        merged_box.pose.orientation.x = sx / norm
        merged_box.pose.orientation.y = sy / norm
        merged_box.pose.orientation.z = sz / norm
        merged_box.pose.orientation.w = sw / norm
    else:
        merged_box.pose.orientation = box1.pose.orientation

    # Average dimensions.
    merged_box.dimensions.x = (box1.dimensions.x + box2.dimensions.x) / 2.0
    merged_box.dimensions.y = (box1.dimensions.y + box2.dimensions.y) / 2.0
    merged_box.dimensions.z = (box1.dimensions.z + box2.dimensions.z) / 2.0

    merged_box.value = max(box1.value, box2.value)  # Max score
    merged_box.label = box1.label  # Label from first box

    return merged_box
2447

25-
Supports multiple cameras; each camera’s intrinsics and extrinsics are
26-
loaded from a single YAML calibration file via plain PyYAML.
27-
"""
2848

49+
class CombinedDetector3D(Component):
    """Late sensor fusion of YOLO and PointPillars 3D bounding boxes.

    Subscribes to the two ``BoundingBoxArray`` topics (time-synchronized),
    matches boxes across detectors by 3D IoU, merges matched pairs via
    ``merge_boxes``, keeps unmatched boxes as-is, and converts everything to
    ``AgentState`` objects in either the START or CURRENT vehicle frame.
    Both publishers are expected to emit boxes already in the vehicle frame.
    """

    def __init__(
        self,
        vehicle_interface: GEMInterface,
        enable_tracking: bool = True,
        use_start_frame: bool = True,
        iou_threshold: float = 0.1,
        **kwargs
    ):
        self.vehicle_interface = vehicle_interface
        # Agents returned by update(); will hold persistent IDs once
        # _update_tracking is implemented.
        self.tracked_agents: Dict[str, AgentState] = {}
        self.ped_counter = 0
        # Latest synchronized detections; None until the first callback fires.
        self.latest_yolo_bbxs: Optional[BoundingBoxArray] = None
        self.latest_pp_bbxs: Optional[BoundingBoxArray] = None
        # Absolute vehicle pose latched on the first update() — defines the
        # START frame origin when use_start_frame is enabled.
        self.start_pose_abs: Optional[ObjectPose] = None
        self.start_time: Optional[float] = None

        self.enable_tracking = enable_tracking
        self.use_start_frame = use_start_frame
        self.iou_threshold = iou_threshold

        self.yolo_topic = "/yolo_boxes"
        self.pp_topic = "/pointpillars_boxes"

        rospy.loginfo(f"CombinedDetector3D Initialized. Subscribing to '{self.yolo_topic}' and '{self.pp_topic}'.")

    def rate(self) -> float:
        # Component update rate in Hz.
        return 8.0

    def state_inputs(self) -> list:
        # NOTE(review): update() expects an AllState and reads state.vehicle,
        # but this declares only 'vehicle' — confirm the framework delivers
        # the object update() actually needs for this declaration.
        return ['vehicle']

    def state_outputs(self) -> list:
        return ['agents']

    def initialize(self):
        """Set up the time-synchronized subscribers for both detectors."""
        self.yolo_sub = Subscriber(self.yolo_topic, BoundingBoxArray)
        self.pp_sub = Subscriber(self.pp_topic, BoundingBoxArray)

        queue_size = 10
        slop = 0.1  # max header-stamp difference (s) for a YOLO/PP pair

        self.sync = ApproximateTimeSynchronizer(
            [self.yolo_sub, self.pp_sub],
            queue_size=queue_size,
            slop=slop
        )
        self.sync.registerCallback(self.synchronized_callback)
        rospy.loginfo("CombinedDetector3D Subscribers Initialized.")

    def synchronized_callback(self, yolo_bbxs_msg: BoundingBoxArray, pp_bbxs_msg: BoundingBoxArray):
        """Cache the latest time-matched pair of bounding-box lists."""
        self.latest_yolo_bbxs = yolo_bbxs_msg
        self.latest_pp_bbxs = pp_bbxs_msg

    def update(self, state: AllState) -> Dict[str, AgentState]:
        """Fuse the latest detections and return the current agent dict.

        Returns {} until data from BOTH sensors has arrived.
        """
        vehicle = state.vehicle
        current_time = self.vehicle_interface.time()

        # Snapshot the latest messages so the ROS callback can't swap them
        # out mid-update.
        yolo_bbx_array = self.latest_yolo_bbxs
        pp_bbx_array = self.latest_pp_bbxs

        if yolo_bbx_array is None or pp_bbx_array is None:
            return {}

        if self.start_time is None:
            self.start_time = current_time

        # Latch the START frame origin on the first update with vehicle data.
        if self.use_start_frame and self.start_pose_abs is None:
            self.start_pose_abs = vehicle.pose
            rospy.loginfo("CombinedDetector3D latched start pose.")

        current_frame_agents = self._fuse_bounding_boxes(yolo_bbx_array, pp_bbx_array, vehicle, current_time)

        if self.enable_tracking:
            self._update_tracking(current_frame_agents)
        else:
            self.tracked_agents = current_frame_agents  # NOTE: No deepcopy

        return self.tracked_agents

    def _fuse_bounding_boxes(self,
                             yolo_bbx_array: BoundingBoxArray,
                             pp_bbx_array: BoundingBoxArray,
                             vehicle_state: VehicleState,
                             current_time: float
                             ) -> Dict[str, AgentState]:
        """Match the two box lists by IoU, merge matches, keep leftovers,
        and convert every resulting box to an AgentState.

        Boxes are assumed to already be in the current vehicle frame.
        Returns a dict keyed by per-frame temporary IDs ("FrameDet_i");
        _update_tracking is responsible for persistent IDs.
        """
        current_agents_in_frame: Dict[str, AgentState] = {}
        yolo_boxes: List[BoundingBox] = yolo_bbx_array.boxes
        pp_boxes: List[BoundingBox] = pp_bbx_array.boxes

        output_frame_enum = ObjectFrameEnum.START if self.use_start_frame else ObjectFrameEnum.CURRENT

        matched_yolo_indices = set()
        matched_pp_indices = set()
        fused_boxes_list: List[BoundingBox] = []

        # Greedy NxM matching: each YOLO box takes the best still-unmatched
        # PP box above the IoU threshold. Can optimize from NxM loop
        # (e.g. Hungarian assignment) if box counts grow.
        for i, yolo_box in enumerate(yolo_boxes):
            best_match_j = -1
            best_iou = -1.0
            for j, pp_box in enumerate(pp_boxes):
                if j in matched_pp_indices:  # Skip already matched PP boxes
                    continue

                iou = calculate_3d_iou(yolo_box, pp_box)

                if iou > self.iou_threshold and iou > best_iou:
                    best_iou = iou
                    best_match_j = j

            if best_match_j != -1:
                rospy.logdebug(f"Matched YOLO box {i} with PP box {best_match_j} (IoU: {best_iou:.3f})")
                matched_yolo_indices.add(i)
                matched_pp_indices.add(best_match_j)
                merged = merge_boxes(yolo_box, pp_boxes[best_match_j])
                fused_boxes_list.append(merged)

        # Unmatched boxes are kept with their original confidence
        # (stored in each box's 'value' field).
        for i, yolo_box in enumerate(yolo_boxes):
            if i not in matched_yolo_indices:
                fused_boxes_list.append(yolo_box)
                rospy.logdebug(f"Kept unmatched YOLO box {i}")

        for j, pp_box in enumerate(pp_boxes):
            if j not in matched_pp_indices:
                fused_boxes_list.append(pp_box)
                rospy.logdebug(f"Kept unmatched PP box {j}")

        # The vehicle->START transform does not depend on the box, so compute
        # it once here instead of once per fused box (it was loop-invariant).
        T_vehicle_to_start = None
        if self.use_start_frame and self.start_pose_abs is not None:
            vehicle_pose_in_start_frame = vehicle_state.pose.to_frame(
                ObjectFrameEnum.START, vehicle_state.pose, self.start_pose_abs
            )
            T_vehicle_to_start = pose_to_matrix(vehicle_pose_in_start_frame)

        # Convert every fused box to an AgentState.
        for i, box in enumerate(fused_boxes_list):
            try:
                # Position/orientation in the current vehicle frame.
                pos_x = box.pose.position.x
                pos_y = box.pose.position.y
                pos_z = box.pose.position.z
                quat_x = box.pose.orientation.x
                quat_y = box.pose.orientation.y
                quat_z = box.pose.orientation.z
                quat_w = box.pose.orientation.w
                # 'zyx' intrinsic Euler order yields (yaw, pitch, roll).
                yaw, pitch, roll = R.from_quat([quat_x, quat_y, quat_z, quat_w]).as_euler('zyx', degrees=False)

                # Transform the position into the START frame if configured.
                if T_vehicle_to_start is not None:
                    object_pose_current_h = np.array([[pos_x], [pos_y], [pos_z], [1.0]])
                    object_pose_start_h = T_vehicle_to_start @ object_pose_current_h
                    final_x, final_y, final_z = object_pose_start_h[:3, 0]
                else:
                    final_x, final_y, final_z = pos_x, pos_y, pos_z

                final_pose = ObjectPose(
                    t=current_time, x=final_x, y=final_y, z=final_z,
                    yaw=yaw, pitch=pitch, roll=roll, frame=output_frame_enum
                )
                dims = (box.dimensions.x, box.dimensions.y, box.dimensions.z)
                # Mapping based on the integer label from the BoundingBox msg.
                # NOTE(review): only label 0 -> PEDESTRIAN is handled; refine
                # once the full label set of both detectors is confirmed.
                agent_type = AgentEnum.PEDESTRIAN if box.label == 0 else AgentEnum.UNKNOWN
                activity = AgentActivityEnum.UNKNOWN  # Placeholder

                # Temporary per-frame id; _update_tracking assigns
                # persistent IDs.
                temp_agent_id = f"FrameDet_{i}"

                current_agents_in_frame[temp_agent_id] = AgentState(
                    pose=final_pose, dimensions=dims, outline=None, type=agent_type,
                    activity=activity, velocity=(0.0, 0.0, 0.0), yaw_rate=0.0
                    # score=box.value # score
                )
            except Exception as e:
                # Best-effort: a malformed box must not drop the whole frame.
                rospy.logwarn(f"Failed to convert final BoundingBox {i} to AgentState: {e}")
                continue

        return current_agents_in_frame

    def _update_tracking(self, current_frame_agents: Dict[str, AgentState]):
        """Associate current-frame detections with persisted agents.

        TODO tracking:
          - Match 'current_frame_agents' to 'self.tracked_agents'.
            Use position (already in the correct START or CURRENT frame),
            maybe size/type; needs a matching algorithm (nearest neighbor
            within a radius, Hungarian, ...).
          - For matched pairs: update the existing agent in
            'self.tracked_agents' (smooth pose, update timestamp).
          - For unmatched current agents: new detections — assign a
            persistent ID (e.g. f"Ped_{self.ped_counter}"), increment
            self.ped_counter, add to 'self.tracked_agents'.
          - For tracked agents not seen this frame: increment a
            missed-frames counter / check timestamp and drop after too long
            (e.g. > 1 second).
        """
        # Placeholder: pass detections through without tracking.
        self.tracked_agents = current_frame_agents
241+
242+
243+
115244
# Fake 2D Combined Detector for testing purposes
116245
# TODO FIX THIS
117246
class FakeCombinedDetector2D(Component):
@@ -151,4 +280,4 @@ def box_to_fake_agent(box):
151280

152281

153282
if __name__ == '__main__':
154-
pass
283+
pass

0 commit comments

Comments
 (0)