Skip to content

Commit a0ed900

Browse files
committed
Update pedestrian_detection.py
1 parent 24a72c0 commit a0ed900

1 file changed

Lines changed: 83 additions & 20 deletions

File tree

GEMstack/onboard/perception/pedestrian_detection.py

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,79 @@
1616
from cv_bridge import CvBridge
1717
import time
1818

19+
20+
def find_human_center_in_bbox(lidar_pc, bbox, T_l2c, K, camera_origin, eps=0.15, min_samples=10):
    """Estimate a pedestrian's 3D center from LiDAR points inside an image bbox.

    Projects every LiDAR point into the image plane, keeps those that fall
    inside the bounding box, clusters them with DBSCAN, and returns the center
    of the cluster whose centroid is closest to the camera origin (the nearest
    cluster is assumed to be the detected person rather than background).

    Args:
        lidar_pc (np.ndarray): LiDAR points in LiDAR frame, shape (N, 3).
        bbox (tuple): Bounding box in image coordinates (cx, cy, w, h).
        T_l2c (np.ndarray): 4x4 transformation matrix from LiDAR to camera frame.
        K (np.ndarray): 3x3 camera intrinsic matrix.
        camera_origin (np.ndarray): The camera's origin in LiDAR frame, shape (3,).
        eps (float): DBSCAN neighborhood radius (meters, LiDAR frame).
        min_samples (int): DBSCAN min_samples parameter.

    Returns:
        tuple: (refined_candidate, best_cluster, None) where refined_candidate
        is the (3,) mean of the chosen cluster in LiDAR frame and best_cluster
        holds its member points; (None, None, None) when no point projects
        into the box or no cluster survives DBSCAN. The trailing None keeps
        compatibility with the previous ray-based API.
    """
    cx, cy, w, h = bbox
    x_min, x_max = cx - w / 2, cx + w / 2
    y_min, y_max = cy - h / 2, cy + h / 2

    lidar_pc = np.asarray(lidar_pc, dtype=float)
    if lidar_pc.size == 0:
        return None, None, None

    # Vectorized LiDAR -> camera transform via homogeneous coordinates;
    # avoids a per-point Python loop over a potentially large cloud.
    ones = np.ones((lidar_pc.shape[0], 1))
    pts_cam = (T_l2c @ np.hstack((lidar_pc, ones)).T).T[:, :3]

    # Discard points behind the image plane before dividing by depth.
    in_front = pts_cam[:, 2] > 0
    pts_cam = pts_cam[in_front]
    candidates = lidar_pc[in_front]
    if candidates.shape[0] == 0:
        return None, None, None

    # Pinhole projection with intrinsics K, then keep points inside the bbox.
    u = K[0, 0] * pts_cam[:, 0] / pts_cam[:, 2] + K[0, 2]
    v = K[1, 1] * pts_cam[:, 1] / pts_cam[:, 2] + K[1, 2]
    in_box = (u >= x_min) & (u <= x_max) & (v >= y_min) & (v <= y_max)
    valid_points = candidates[in_box]
    if valid_points.shape[0] == 0:
        return None, None, None

    # Cluster the surviving points in 3D; DBSCAN labels noise as -1.
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(valid_points).labels_

    # Single pass: pick the cluster whose centroid is nearest the camera.
    best_cluster = None
    best_distance = float('inf')
    for label in set(labels):
        if label == -1:  # noise
            continue
        cluster_points = valid_points[labels == label]
        distance = np.linalg.norm(cluster_points.mean(axis=0) - camera_origin)
        if distance < best_distance:
            best_distance = distance
            best_cluster = cluster_points

    if best_cluster is None:
        return None, None, None

    refined_candidate = np.mean(best_cluster, axis=0)
    return refined_candidate, best_cluster, None
1992
# ----- Helper Functions -----
2093
def match_existing_pedestrian(
2194
new_center: np.ndarray,
@@ -266,15 +339,12 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
266339

267340
for i, box in enumerate(boxes):
268341
cx, cy, w, h = box
269-
# Convert pixel center to a ray in LiDAR frame
270-
ray_dir_cam = backproject_pixel(cx, cy, self.K)
271-
ray_dir_lidar = self.R_c2l @ ray_dir_cam
272-
ray_dir_lidar /= np.linalg.norm(ray_dir_lidar)
273-
274-
intersection, _, _ = find_human_center_on_ray(
275-
lidar_pc, self.camera_origin_in_lidar, ray_dir_lidar,
276-
t_min=0.4, t_max=20.0, t_step=0.1,
277-
distance_threshold=0.3, min_points=5, ransac_threshold=0.05
342+
343+
# Instead of using the ray-based approach, use the bounding box method.
344+
bbox = (cx, cy, w, h)
345+
intersection, cluster, _ = find_human_center_in_bbox(
346+
lidar_pc, bbox, self.T_l2c, self.K, self.camera_origin_in_lidar,
347+
eps=0.15, min_samples=10
278348
)
279349
if intersection is None:
280350
continue
@@ -284,6 +354,7 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
284354
physical_height = (h * d) / self.K[1, 1]
285355
half_extents = np.array([0.4, 0.4, 1.25 * physical_height / 2])
286356

357+
# (Optional) You can still extract an ROI and refine the cluster if needed:
287358
roi_points = extract_roi_box(lidar_pc, intersection, half_extents)
288359
if roi_points.shape[0] < 10:
289360
refined_cluster = roi_points
@@ -316,20 +387,12 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
316387
R_vehicle = self.T_l2v[:3, :3] @ R_lidar
317388
euler_angles_vehicle = R.from_matrix(R_vehicle).as_euler('zyx', degrees=False)
318389
yaw, pitch, roll = euler_angles_vehicle
319-
refined_center = refined_center_vehicle # Use vehicle frame for output
390+
refined_center = refined_center_vehicle
320391
vehicle_state = vehicle.to_frame(ObjectFrameEnum.GLOBAL)
321392
curr_x = vehicle_state.pose.x
322393
curr_y = vehicle_state.pose.y
323-
# curr_yaw = vehicle.pose.yaw
324-
# curr_pitch = vehicle.pose.pitch
325-
# curr_roll = vehicle.pose.roll
326-
# Note: Ensure refined_center has enough dimensions before adding these values.
327394
refined_center[0] += curr_x
328395
refined_center[1] += curr_y
329-
# euler_angles_vehicle[0] += curr_yaw
330-
# # If refined_center should only be 3D, remove or adjust the following lines:
331-
# euler_angles_vehicle[1] += curr_pitch
332-
# euler_angles_vehicle[2] += curr_roll
333396

334397
# Create new pose in the vehicle frame
335398
new_pose = ObjectPose(
@@ -343,7 +406,7 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
343406
frame=ObjectFrameEnum.CURRENT
344407
)
345408

346-
# Attempt to match with an existing pedestrian
409+
# Match with an existing pedestrian if possible
347410
existing_id = match_existing_pedestrian(
348411
new_center=np.array([new_pose.x, new_pose.y, new_pose.z]),
349412
new_dims=dims,
@@ -385,7 +448,7 @@ def update(self, vehicle: VehicleState) -> Dict[str, AgentState]:
385448

386449
# Remove stale agents that haven't been updated for more than 3 seconds.
387450
stale_ids = [agent_id for agent_id, agent in self.tracked_agents.items()
388-
if current_time - agent.pose.t > 3.0]
451+
if current_time - agent.pose.t > 5.0]
389452
for agent_id in stale_ids:
390453
rospy.loginfo(f"Removing stale agent: {agent_id}")
391454
del self.tracked_agents[agent_id]

0 commit comments

Comments
 (0)