Skip to content

Commit abc0663

Browse files
committed
Update ROS callback function
1 parent 31b5cac commit abc0663

1 file changed

Lines changed: 70 additions & 92 deletions

File tree

homework/pedestrian_detection.py

Lines changed: 70 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from sklearn.cluster import DBSCAN
1010
from scipy.spatial.transform import Rotation as R
1111
import rospy
12-
from sensor_msgs.msg import PointCloud2
12+
from sensor_msgs.msg import PointCloud2, Image
1313
import sensor_msgs.point_cloud2 as pc2
1414
import struct, ctypes
15-
15+
from message_filters import Subscriber, ApproximateTimeSynchronizer
1616

1717
# ----- Helper Functions -----
1818
def match_existing_pedestrian(
@@ -37,7 +37,6 @@ def match_existing_pedestrian(
3737

3838
return best_agent_id
3939

40-
4140
def compute_velocity(old_pose: ObjectPose, new_pose: ObjectPose, dt: float) -> tuple:
4241
"""
4342
Returns a (vx, vy, vz) velocity in the same frame as old_pose/new_pose.
@@ -55,19 +54,14 @@ def extract_roi_box(lidar_pc, center, half_extents):
5554
mask = np.all((lidar_pc >= lower) & (lidar_pc <= upper), axis=1)
5655
return lidar_pc[mask]
5756

58-
5957
def pc2_to_numpy(pc2_msg, want_rgb=False):
    """Convert a ROS PointCloud2 message to an (N, 3) float32 array of x, y, z.

    Only points with x > 0 and z < 2.5 (in the coordinates stored in the
    message) are kept; NaN points are skipped while reading.

    Args:
        pc2_msg: The sensor_msgs/PointCloud2 message to convert.
        want_rgb: Accepted for interface compatibility; currently ignored.

    Returns:
        A float32 array of shape (N, 3). N is 0 when the message holds no
        valid points or none pass the filter.
    """
    # Read only the coordinate fields so extra per-point channels are never
    # materialized in Python.
    gen = pc2.read_points(pc2_msg, field_names=("x", "y", "z"), skip_nans=True)
    # reshape(-1, 3) keeps the array 2-D even for an empty cloud; without it
    # an empty list becomes a 1-D array and the column indexing below raises.
    pts = np.array(list(gen), dtype=np.float32).reshape(-1, 3)
    mask = (pts[:, 0] > 0) & (pts[:, 2] < 2.5)
    return pts[mask]
6864

69-
70-
7165
def backproject_pixel(u, v, K):
7266
"""Backprojects pixel (u,v) into a normalized 3D ray (camera coordinates)."""
7367
cx, cy = K[0, 2], K[1, 2]
@@ -77,28 +71,23 @@ def backproject_pixel(u, v, K):
7771
ray_dir = np.array([x, y, 1.0])
7872
return ray_dir / np.linalg.norm(ray_dir)
7973

80-
8174
def find_human_center_on_ray(lidar_pc, ray_origin, ray_direction,
8275
t_min, t_max, t_step,
8376
distance_threshold, min_points, ransac_threshold):
8477
"""
8578
Pre-filter the point cloud to only include points near the ray, then sweep along the ray.
86-
For each candidate along the ray, compute the centroid of nearby points and return that as the refined candidate.
8779
Returns (refined_candidate, None, None) if found; otherwise, (None, None, None).
8880
"""
89-
# Pre-filter: compute distance from each point to the ray.
90-
vecs = lidar_pc - ray_origin # Vectors from origin to points.
91-
proj_lengths = np.dot(vecs, ray_direction) # Projection lengths.
81+
vecs = lidar_pc - ray_origin
82+
proj_lengths = np.dot(vecs, ray_direction)
9283
proj_points = ray_origin + np.outer(proj_lengths, ray_direction)
9384
dists_to_ray = np.linalg.norm(lidar_pc - proj_points, axis=1)
9485
near_ray_mask = dists_to_ray < distance_threshold
9586
filtered_pc = lidar_pc[near_ray_mask]
9687

97-
# If too few points remain, return None.
9888
if filtered_pc.shape[0] < min_points:
9989
return None, None, None
10090

101-
# Sweep along the ray using the filtered point cloud.
10291
t_values = np.arange(t_min, t_max, t_step)
10392
for t in t_values:
10493
candidate = ray_origin + t * ray_direction
@@ -109,13 +98,11 @@ def find_human_center_on_ray(lidar_pc, ray_origin, ray_direction,
10998
return refined_candidate, None, None
11099
return None, None, None
111100

112-
113101
def extract_roi(pc, center, roi_radius):
    """Return the rows of ``pc`` whose Euclidean distance to ``center`` is strictly below ``roi_radius``."""
    offsets = pc - center
    inside_sphere = np.linalg.norm(offsets, axis=1) < roi_radius
    return pc[inside_sphere]
117105

118-
119106
def refine_cluster(roi_points, center, eps=0.2, min_samples=10):
120107
"""Refine a cluster using DBSCAN and return the cluster closest to center."""
121108
clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(roi_points)
@@ -126,7 +113,6 @@ def refine_cluster(roi_points, center, eps=0.2, min_samples=10):
126113
best_cluster = min(valid_clusters, key=lambda c: np.linalg.norm(np.mean(c, axis=0) - center))
127114
return best_cluster
128115

129-
130116
def remove_ground_by_min_range(cluster, z_range=0.05):
131117
"""
132118
Remove ground points by finding the minimum z value in the cluster and eliminating
@@ -138,7 +124,6 @@ def remove_ground_by_min_range(cluster, z_range=0.05):
138124
filtered = cluster[cluster[:, 2] > (min_z + z_range)]
139125
return filtered
140126

141-
142127
def get_bounding_box_center_and_dimensions(points):
143128
"""
144129
Compute the bounding box center and dimensions (max - min) for the given points.
@@ -151,7 +136,6 @@ def get_bounding_box_center_and_dimensions(points):
151136
dimensions = max_vals - min_vals
152137
return center, dimensions
153138

154-
155139
def create_circle_line_set(center, radius, num_points=50):
156140
"""
157141
Create a LineSet representing a circle (in the X-Y plane) with given center and radius.
@@ -171,7 +155,6 @@ def create_circle_line_set(center, radius, num_points=50):
171155
line_set.colors = o3d.utility.Vector3dVector([[0, 1, 0] for _ in range(len(lines))])
172156
return line_set
173157

174-
175158
def create_ray_line_set(start, end):
176159
"""
177160
Create a LineSet representing a ray from 'start' to 'end' (colored yellow).
@@ -184,7 +167,6 @@ def create_ray_line_set(start, end):
184167
line_set.colors = o3d.utility.Vector3dVector([[1, 1, 0]])
185168
return line_set
186169

187-
188170
def visualize_geometries(geometries, window_name="Open3D", width=800, height=600, point_size=5.0):
189171
"""Utility to visualize a list of Open3D geometries."""
190172
vis = o3d.visualization.Visualizer()
@@ -196,76 +178,86 @@ def visualize_geometries(geometries, window_name="Open3D", width=800, height=600
196178
vis.run()
197179
vis.destroy_window()
198180

199-
200-
# ----- Pedestrian Detector 2D (with 3D fusion) -----
201-
181+
# ----- Pedestrian Detector 2D (with 3D fusion and synchronized callbacks) -----
202182
class PedestrianDetector2D(Component):
203-
"""Detects pedestrians using YOLO and LiDAR to estimate 3D pose."""
183+
"""
184+
Detects pedestrians using YOLO and LiDAR to estimate 3D pose.
185+
This version uses message_filters to synchronize the image and LiDAR data.
186+
The synchronized callback stores the latest sensor data, and heavy processing is done in update().
187+
"""
204188

205189
def __init__(self, vehicle_interface: GEMInterface):
    """Store the vehicle interface and reset tracking and sensor state.

    Args:
        vehicle_interface: Interface object; update() queries it for the
            current time.
    """
    self.vehicle_interface = vehicle_interface

    # Tracking state.
    # current_agents: the dict returned by the most recent update().
    # tracked_agents: agent id -> (agent state, timestamp string).
    # pedestrian_counter: suffix used when minting new "pedestrian<N>" ids.
    self.current_agents, self.tracked_agents = {}, {}
    self.pedestrian_counter = 0

    # Most recent synchronized camera/LiDAR pair; None until the first
    # synchronized callback fires.
    self.latest_image = self.latest_lidar = None
211197

212-
def rate(self):
198+
def rate(self) -> float:
    """Return this component's rate (presumably update calls per second — confirm against Component)."""
    component_rate = 4.0
    return component_rate
214200

215-
def state_inputs(self):
201+
def state_inputs(self) -> list:
    """Return the names of the state items this component consumes."""
    required_inputs = ['vehicle']
    return required_inputs
217203

218-
def state_outputs(self):
204+
def state_outputs(self) -> list:
    """Return the names of the state items this component produces."""
    produced_outputs = ['agents']
    return produced_outputs
220206

221207
def initialize(self):
    """Create the synchronized subscriptions, load YOLO and set calibration.

    The camera and LiDAR topics are paired with an approximate-time
    synchronizer (queue of 10, slop of 0.1) that delivers both messages to
    ``synchronized_callback``. The calibration matrices below are hard-coded.
    """
    # Pair the two sensor streams instead of subscribing to each separately.
    image_topic, cloud_topic = '/oak/rgb/image_raw', '/ouster/points'
    self.rgb_sub = Subscriber(image_topic, Image)
    self.lidar_sub = Subscriber(cloud_topic, PointCloud2)
    self.sync = synchronizer = ApproximateTimeSynchronizer(
        [self.rgb_sub, self.lidar_sub],
        queue_size=10,
        slop=0.1,
    )
    synchronizer.registerCallback(self.synchronized_callback)

    # YOLO detector weights.
    self.detector = YOLO('../../knowledge/detection/yolov8n.pt')

    # Homogeneous transform used by update() to move LiDAR-frame results
    # into the vehicle frame.
    lidar_to_vehicle = np.array([
        [0.99939639, 0.02547917, 0.023615, 1.1],
        [-0.02530848, 0.99965156, -0.00749882, 0.03773583],
        [-0.02379784, 0.00689664, 0.999693, 1.95320223],
        [0., 0., 0., 1.],
    ])
    self.T_l2v = lidar_to_vehicle

    # Camera intrinsic matrix.
    intrinsics = np.array([
        [684.83331299, 0., 573.37109375],
        [0., 684.60968018, 363.70092773],
        [0., 0., 1.],
    ])
    self.K = intrinsics

    # LiDAR-to-camera transform and its inverse.
    lidar_to_camera = np.array([
        [0.001090, -0.999489, -0.031941, 0.149698],
        [-0.007664, 0.031932, -0.999461, -0.397813],
        [0.999970, 0.001334, -0.007625, -0.691405],
        [0.000000, 0.000000, 0.000000, 1.000000],
    ])
    self.T_l2c = lidar_to_camera
    camera_to_lidar = np.linalg.inv(lidar_to_camera)
    self.T_c2l = camera_to_lidar
    # Rotation block and translation column of the camera-to-LiDAR transform.
    self.R_c2l = camera_to_lidar[:3, :3]
    self.camera_origin_in_lidar = camera_to_lidar[:3, 3]
244-
self.tracked_agents = {}
245-
246-
# For generating new IDs
247-
self.pedestrian_counter = 0
248233

249-
def lidar_callback(self, lidar_msg: PointCloud2):
250-
"""Convert ROS PointCloud2 to numpy array and store it."""
251-
self.lidar_pc = pc2_to_numpy(lidar_msg, want_rgb=False)
252-
def image_callback(self, image: cv2.Mat):
253-
results = self.detector(image, conf=0.5, classes=[0])
254-
boxes = np.array(results[0].boxes.xywh.cpu()) # Format: [center_x, center_y, w, h]
255-
self.last_person_boxes = boxes
234+
def synchronized_callback(self, image_msg, lidar_msg):
    """Store the latest time-synchronized camera frame and LiDAR cloud.

    Called when an image and a LiDAR message arrive within the configured
    slop. The image message is decoded into an (H, W, 3) uint8 BGR array so
    update() can pass it straight to the detector; previously the raw ROS
    message was stored, which is not an array. Frames with an unsupported
    encoding are dropped and the previously stored pair is kept.

    Args:
        image_msg: sensor_msgs/Image with an 8-bit encoding
            (bgr8, rgb8, bgra8, rgba8 or mono8).
        lidar_msg: sensor_msgs/PointCloud2 from the LiDAR.
    """
    channels_by_encoding = {'bgr8': 3, 'rgb8': 3, 'bgra8': 4, 'rgba8': 4, 'mono8': 1}
    encoding = image_msg.encoding.lower()
    channels = channels_by_encoding.get(encoding)
    if channels is None:
        rospy.logwarn_throttle(
            5.0, f"Unsupported image encoding '{image_msg.encoding}'; dropping frame.")
        return

    height, width = image_msg.height, image_msg.width
    # Rows may be padded: `step` is the full row length in bytes, so view the
    # buffer row by row and trim the padding before splitting into channels.
    rows = np.frombuffer(image_msg.data, dtype=np.uint8).reshape(height, image_msg.step)
    frame = rows[:, :width * channels].reshape(height, width, channels)

    if encoding in ('rgb8', 'rgba8'):
        frame = frame[:, :, [2, 1, 0]]  # reorder to BGR (drops alpha if present)
    elif encoding == 'bgra8':
        frame = frame[:, :, :3]  # drop alpha
    elif encoding == 'mono8':
        frame = np.repeat(frame, 3, axis=2)  # replicate gray into 3 channels

    lidar_points = pc2_to_numpy(lidar_msg, want_rgb=False)

    # Assign both only after both conversions succeeded, so the stored image
    # and cloud always come from the same synchronized pair. The copy makes
    # the frame contiguous, writable and independent of the message buffer.
    self.latest_image = np.array(frame, order='C')
    self.latest_lidar = lidar_points
256243

257244
def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
258-
agents = {}
259-
if self.lidar_pc is None:
260-
print("NOT FOUND")
261-
return agents
245+
# Process only if synchronized sensor data is available
246+
if self.latest_image is None or self.latest_lidar is None:
247+
rospy.logwarn("Synchronized sensor data not available; skipping update.")
248+
return {}
262249

263250
current_time = self.vehicle_interface.time()
264-
lidar_pc = self.lidar_pc.copy()
265-
for i, box in enumerate(self.last_person_boxes):
266-
251+
# Run YOLO inference on the latest synchronized image
252+
results = self.detector(self.latest_image, conf=0.5, classes=[0])
253+
boxes = np.array(results[0].boxes.xywh.cpu()) # Format: [center_x, center_y, w, h]
254+
255+
agents = {}
256+
lidar_pc = self.latest_lidar.copy()
257+
258+
for i, box in enumerate(boxes):
267259
cx, cy, w, h = box
268-
# --- same LiDAR + bounding box logic as before ---
260+
# Convert pixel center to a ray in LiDAR frame
269261
ray_dir_cam = backproject_pixel(cx, cy, self.K)
270262
ray_dir_lidar = self.R_c2l @ ray_dir_cam
271263
ray_dir_lidar /= np.linalg.norm(ray_dir_lidar)
@@ -282,11 +274,7 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
282274
physical_width = (w * d) / self.K[0, 0]
283275
physical_height = (h * d) / self.K[1, 1]
284276
depth_margin = physical_width
285-
half_extents = np.array([
286-
0.4,
287-
0.4,
288-
1.25 * physical_height / 2
289-
])
277+
half_extents = np.array([0.4, 0.4, 1.25 * physical_height / 2])
290278

291279
roi_points = extract_roi_box(lidar_pc, intersection, half_extents)
292280
if roi_points.shape[0] < 10:
@@ -309,7 +297,7 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
309297
dims = tuple(obb.extent)
310298
R_lidar = obb.R.copy()
311299

312-
# transform to vehicle frame
300+
# Transform refined center to vehicle frame
313301
refined_center_lidar_hom = np.array([refined_center[0],
314302
refined_center[1],
315303
refined_center[2],
@@ -319,14 +307,13 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
319307

320308
R_vehicle = self.T_l2v[:3, :3] @ R_lidar
321309
euler_angles_vehicle = R.from_matrix(R_vehicle).as_euler('zyx', degrees=True)
322-
yaw, pitch, roll = euler_angles_vehicle # rename for clarity
310+
yaw, pitch, roll = euler_angles_vehicle
311+
refined_center = refined_center_vehicle # Use vehicle frame for output
323312

324-
refined_center = refined_center_vehicle # override to use vehicle frame
325-
# dims remains the same; orientation is now (yaw, pitch, roll)
326-
print(f"Detected human in vehicle frame - Pose (yaw, pitch, roll): {euler_angles_vehicle}")
327-
print(f"Bounding box center (vehicle frame): {refined_center_vehicle}, Dimensions: {dims}")
313+
rospy.loginfo(f"Detected human in vehicle frame - Pose: {euler_angles_vehicle}, "
314+
f"Center: {refined_center_vehicle}, Dimensions: {dims}")
328315

329-
# -- CREATE the new pose in the vehicle frame --
316+
# Create new pose in the vehicle frame
330317
new_pose = ObjectPose(
331318
t=current_time,
332319
x=refined_center[0],
@@ -338,7 +325,7 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
338325
frame=ObjectFrameEnum.CURRENT
339326
)
340327

341-
# 1) Attempt to match with an existing pedestrian
328+
# Attempt to match with an existing pedestrian
342329
existing_id = match_existing_pedestrian(
343330
new_center=np.array([new_pose.x, new_pose.y, new_pose.z]),
344331
new_dims=dims,
@@ -347,8 +334,6 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
347334
)
348335

349336
if existing_id is not None:
350-
# if False:
351-
# 2) Update existing agent
352337
old_agent_state, old_time = self.tracked_agents[existing_id]
353338
dt = float(current_time) - float(old_time)
354339
vx, vy, vz = compute_velocity(old_agent_state.pose, new_pose, dt)
@@ -364,9 +349,7 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
364349
)
365350
agents[existing_id] = updated_agent
366351
self.tracked_agents[existing_id] = (updated_agent, str(current_time))
367-
368352
else:
369-
# 3) Create a new agent
370353
agent_id = f"pedestrian{self.pedestrian_counter}"
371354
self.pedestrian_counter += 1
372355

@@ -382,14 +365,12 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
382365
agents[agent_id] = new_agent
383366
self.tracked_agents[agent_id] = (new_agent, str(current_time))
384367

368+
self.current_agents = agents
385369
return agents
386370

387-
388371
# ----- Fake Pedestrian Detector 2D (unchanged) -----
389-
390372
class FakePedestrianDetector2D(Component):
391373
"""Triggers a pedestrian detection at some random time ranges."""
392-
393374
def __init__(self, vehicle_interface: GEMInterface):
394375
self.vehicle_interface = vehicle_interface
395376
self.times = [(5.0, 20.0), (30.0, 35.0)]
@@ -412,14 +393,12 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
412393
for times in self.times:
413394
if t >= times[0] and t <= times[1]:
414395
res['pedestrian0'] = box_to_fake_agent((0, 0, 0, 0))
415-
print("Detected a pedestrian")
396+
rospy.loginfo("Detected a pedestrian")
416397
return res
417398

418-
419399
def box_to_fake_agent(box):
420400
"""Creates a fake agent state from an (x,y,w,h) bounding box.
421-
422-
The location and size are just 2D approximations.
401+
The location and size are just 2D approximations.
423402
"""
424403
x, y, w, h = box
425404
pose = ObjectPose(t=0, x=x + w / 2, y=y + h / 2, z=0, yaw=0, pitch=0, roll=0, frame=ObjectFrameEnum.CURRENT)
@@ -428,7 +407,6 @@ def box_to_fake_agent(box):
428407
type=AgentEnum.PEDESTRIAN, activity=AgentActivityEnum.MOVING,
429408
velocity=(0, 0, 0), yaw_rate=0)
430409

431-
432410
if __name__ == '__main__':
433411
# This module is meant to be used within the vehicle interface context.
434412
# For testing standalone, you may create a fake vehicle state and call update().

0 commit comments

Comments
 (0)