Skip to content

Commit 51d44de

Browse files
committed
Add potential performance improvements to map-based state estimation
1 parent 6545522 commit 51d44de

2 files changed

Lines changed: 406 additions & 13 deletions

File tree

Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
from dataclasses import replace
2+
import math
3+
from typing import List
4+
from ...utils import settings
5+
from ...mathutils import transforms
6+
from ...state.vehicle import VehicleState,VehicleGearEnum
7+
from ...state.physical_object import ObjectFrameEnum,ObjectPose,convert_xyhead
8+
from ...knowledge.vehicle.geometry import front2steer,steer2front
9+
from ...mathutils.signal import OnlineLowPassFilter
10+
from ..interface.gem import GEMInterface
11+
from ..component import Component
12+
from ..interface.gem import GNSSReading
13+
14+
import numpy as np
15+
import open3d as o3d
16+
import copy
17+
import time
18+
import argparse
19+
import os
20+
import glob
21+
from scipy.spatial.transform import Rotation as R
22+
23+
def load_map(map_file):
    """Load a .ply point-cloud map file.

    Args:
        map_file: Path to a .ply file readable by Open3D.

    Returns:
        An open3d.geometry.PointCloud on success, or None on failure.
        (Bug fix: the error path previously returned a (None, None) tuple
        while the success path returned a single object, so callers could
        not test the result uniformly.)
    """
    try:
        print(map_file)
        map_pcd = o3d.io.read_point_cloud(map_file)
        # Note: the original also computed the map centroid here, but the
        # value was never used or returned, so it has been removed.
        return map_pcd
    except Exception as e:
        print(f"Error loading map: {e}")
        return None
37+
38+
def load_lidar_scan(points):
    """Build an Open3D point cloud from a raw lidar point array.

    Only the first three columns (x, y, z) of *points* are used; any extra
    columns (e.g. intensity) are ignored.

    Returns:
        An open3d.geometry.PointCloud, or None if construction fails.
    """
    try:
        # Open3D requires a C-contiguous float64 (N, 3) array.
        xyz = np.ascontiguousarray(points[:, :3], dtype=np.float64)
        scan_pcd = o3d.geometry.PointCloud()
        scan_pcd.points = o3d.utility.Vector3dVector(xyz)
        return scan_pcd
    except Exception as e:
        print(f"Error loading scan: {e}")
        return None
61+
62+
def remove_floor_ceiling(pcd, z_min=-0.5, z_max=2.5):
    """Strip floor and ceiling points, keeping z_min < z < z_max.

    This focuses registration on walls and other vertical structure.
    Colors are carried over when the input cloud has them.
    """
    pts = np.asarray(pcd.points)
    keep = (pts[:, 2] > z_min) & (pts[:, 2] < z_max)

    filtered = o3d.geometry.PointCloud()
    filtered.points = o3d.utility.Vector3dVector(pts[keep])
    if pcd.has_colors():
        filtered.colors = o3d.utility.Vector3dVector(np.asarray(pcd.colors)[keep])
    return filtered
73+
74+
def extract_structural_features(pcd, voxel_size):
    """Extract planar structure (walls, floors) via iterative RANSAC.

    The cloud is voxel-downsampled, normals are estimated, and up to six
    dominant planes are peeled off one at a time.  The union of the found
    planes is returned; if no plane is found, the downsampled cloud itself
    is returned instead.
    """
    pcd_down = pcd.voxel_down_sample(voxel_size)
    pcd_down.estimate_normals(
        o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size*2, max_nn=30))

    structural_pcd = o3d.geometry.PointCloud()
    plane_count = 0
    remainder = copy.deepcopy(pcd_down)
    for _ in range(6):  # extract at most the top 6 planes
        # Stop when too few points remain to segment reliably.
        if len(np.asarray(remainder.points)) < 100:
            break
        plane_model, inliers = remainder.segment_plane(
            distance_threshold=voxel_size*2, ransac_n=3, num_iterations=1000)
        # A tiny inlier set means no meaningful plane is left.
        if len(inliers) < 50:
            break
        structural_pcd += remainder.select_by_index(inliers)
        plane_count += 1
        remainder = remainder.select_by_index(inliers, invert=True)

    if plane_count == 0:
        return pcd_down
    return structural_pcd
109+
110+
def prepare_scan_for_global_registration(scan_pcd, map_pcd, scale_ratio=None):
    """Scale and translate a scan so it roughly overlaps the map bounds.

    Args:
        scan_pcd: The lidar scan point cloud.
        map_pcd: The reference map point cloud.
        scale_ratio: Scan-to-map scale.  If None, a ratio is derived from
            the bounding boxes of the two clouds.

    Returns:
        A new point cloud with the scaled/translated scan points.
    """
    # Map bounding box
    map_points = np.asarray(map_pcd.points)
    map_min = np.min(map_points, axis=0)
    map_max = np.max(map_points, axis=0)

    # Scan bounding box
    scan_points = np.asarray(scan_pcd.points)
    scan_min = np.min(scan_points, axis=0)
    scan_max = np.max(scan_points, axis=0)

    # Bug fix: use an explicit None check.  The old `if not scale_ratio:`
    # truthiness test also triggered on a caller-supplied 0.0 and would
    # silently replace it with the derived value.
    if scale_ratio is None:
        map_range = map_max - map_min
        scan_range = scan_max - scan_min
        scale_ratio = np.min(map_range / scan_range) * 0.8  # use 80% of map size

    # Shift the scan's minimum corner onto the map's, then scale.
    scaled_points = (scan_points - scan_min) * scale_ratio + map_min

    aligned_scan = o3d.geometry.PointCloud()
    aligned_scan.points = o3d.utility.Vector3dVector(scaled_points)

    print(f"Dynamic scaling ratio: {scale_ratio}")
    return aligned_scan
137+
138+
def preprocess_point_cloud(pcd, voxel_size, radius_normal=None, radius_feature=None):
    """Downsample a cloud and compute FPFH features for registration.

    Args:
        pcd: Input point cloud.
        voxel_size: Downsampling resolution.
        radius_normal: Normal-estimation radius; defaults to 5x voxel_size.
        radius_feature: FPFH feature radius; defaults to 10x voxel_size.

    Returns:
        (downsampled_cloud, fpfh_features) tuple.
    """
    pcd_down = pcd.voxel_down_sample(voxel_size)

    # Large radii relative to the voxel size capture building-scale features.
    # Bug fix: explicit None checks — the old `x or default` form also
    # replaced a deliberately passed 0.0 with the default.
    if radius_normal is None:
        radius_normal = voxel_size * 5.0
    if radius_feature is None:
        radius_feature = voxel_size * 10.0

    pcd_down.estimate_normals(
        o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=50))

    pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
        pcd_down,
        o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100))

    return pcd_down, pcd_fpfh
154+
155+
def execute_global_registration(source_down, target_down, source_fpfh, target_fpfh,
                                voxel_size, max_iterations=1000000):
    """RANSAC-based global registration on FPFH feature correspondences.

    Returns the Open3D RegistrationResult (transformation, fitness, rmse).
    """
    # Generous correspondence threshold for a first, coarse alignment.
    distance_threshold = voxel_size * 15

    estimation = o3d.pipelines.registration.TransformationEstimationPointToPoint(False)
    checkers = [
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold),
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnNormal(0.5),
    ]
    criteria = o3d.pipelines.registration.RANSACConvergenceCriteria(max_iterations, 500)

    return o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
        source_down, target_down, source_fpfh, target_fpfh, True,
        distance_threshold, estimation, 4, checkers, criteria)
171+
172+
def execute_fast_global_registration(source_down, target_down, source_fpfh, target_fpfh, voxel_size):
    """Fast Global Registration (FGR) on FPFH features.

    On an FGR RuntimeError, returns a dummy result with an identity
    transformation and zero fitness so callers can still compare results.
    """
    distance_threshold = voxel_size * 0.5
    print(f":: Apply fast global registration with distance threshold {distance_threshold:.3f}")

    try:
        option = o3d.pipelines.registration.FastGlobalRegistrationOption(
            maximum_correspondence_distance=distance_threshold)
        return o3d.pipelines.registration.registration_fast_based_on_feature_matching(
            source_down, target_down, source_fpfh, target_fpfh, option)
    except RuntimeError as e:
        print(f"Error in FGR: {e}")
        fallback = o3d.pipelines.registration.RegistrationResult()
        fallback.transformation = np.identity(4)
        fallback.fitness = 0.0
        fallback.inlier_rmse = 0.0
        fallback.correspondence_set = []
        return fallback
192+
193+
def multi_scale_icp(source, target, voxel_sizes=(2.0, 1.0, 0.5), max_iterations=(50, 30, 14),
                    initial_transform=None):
    """Coarse-to-fine point-to-point ICP alignment of source onto target.

    Args:
        source, target: Open3D point clouds.
        voxel_sizes: Downsampling resolutions, coarsest first.
        max_iterations: Per-scale iteration caps, paired with voxel_sizes.
        initial_transform: 4x4 seed transform; identity when None.

    Returns:
        The refined 4x4 transformation matrix.

    Bug fix: the defaults were mutable objects (lists and a shared
    np.eye(4) created once at import time); they are now immutable
    tuples / None, which is backward-compatible for all callers.
    """
    print("Running multi-scale ICP...")
    current_transform = np.eye(4) if initial_transform is None else initial_transform

    for i, (voxel_size, max_iter) in enumerate(zip(voxel_sizes, max_iterations)):
        print(f"ICP Scale {i+1}/{len(voxel_sizes)}: voxel_size={voxel_size}, max_iterations={max_iter}")

        # Downsample both clouds at the current resolution.
        source_down = source.voxel_down_sample(voxel_size)
        target_down = target.voxel_down_sample(voxel_size)

        # Recompute normals at this scale (downsampling invalidates them).
        source_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size*2, max_nn=30))
        target_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size*2, max_nn=30))

        # Correspondence distance scales with resolution, floored at 0.5.
        distance_threshold = max(0.5, voxel_size * 2)

        result = o3d.pipelines.registration.registration_icp(
            source_down, target_down, distance_threshold, current_transform,
            o3d.pipelines.registration.TransformationEstimationPointToPoint(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iter))

        current_transform = result.transformation
        print(f" Scale {i+1} result - Fitness: {result.fitness:.4f}, RMSE: {result.inlier_rmse:.4f}")

    return current_transform
225+
226+
def transform_to_pose(transformation_matrix):
    """Decompose a 4x4 homogeneous transform into translation and Euler angles.

    Returns:
        (x, y, z, roll, pitch, yaw) where the angles are in degrees,
        extracted in 'xyz' order.
    """
    # Translation is the last column of the upper 3x4 block.
    x, y, z = transformation_matrix[:3, 3]

    # Copy the rotation block: scipy needs a writable array, and slices of
    # some transform results are read-only views.
    rot = np.array(transformation_matrix[:3, :3], copy=True)
    roll, pitch, yaw = R.from_matrix(rot).as_euler('xyz', degrees=True)

    return x, y, z, roll, pitch, yaw
238+
239+
class MapBasedStateEstimator(Component):
    """Estimates vehicle state by registering lidar scans against a prior map.

    Each update registers the latest lidar scan to a preloaded .ply map
    (RANSAC / Fast Global Registration for the first global alignment,
    multi-scale ICP for refinement) and converts the resulting transform
    into a vehicle pose.  GNSS readings are also subscribed and logged for
    comparison.
    """
    def __init__(self, map_fn: str, vehicle_interface: GEMInterface):
        self.vehicle_interface = vehicle_interface
        if 'top_lidar' not in vehicle_interface.sensors():
            # Bug fix: the message previously said "GNSS sensor not available"
            # even though this check is for the lidar.
            raise RuntimeError("top_lidar sensor not available")
        vehicle_interface.subscribe_sensor('top_lidar', self.lidar_callback, np.ndarray)
        self.map_based_pose = None   # last pose estimated from registration
        self.map_based_speed = None  # speed derived from consecutive poses
        self.points = None           # latest raw lidar point array
        self.map = load_map(map_fn)

        # TODO: Change these to be some form of variable (e.g. settings)
        self.map_scale_ratio = 1.0
        self.voxel_size = 1.0
        self.scan_voxel_size = 0.5
        self.normal_radius_factor = 2.0
        self.feature_radius_factor = 5.0

        # Bug fix: these values are *factors* of the voxel size (per their
        # names), not absolute radii, so multiply before passing them to
        # preprocess_point_cloud, whose parameters are radii.
        self.map_down, self.map_fpfh = preprocess_point_cloud(
            self.map, self.voxel_size,
            self.voxel_size * self.normal_radius_factor,
            self.voxel_size * self.feature_radius_factor)

        if 'gnss' not in vehicle_interface.sensors():
            raise RuntimeError("GNSS sensor not available")
        vehicle_interface.subscribe_sensor('gnss', self.gnss_callback, GNSSReading)
        self.gnss_pose = None
        self.gnss_speed = None  # bug fix: was only ever set in gnss_callback
        self.location = settings.get('vehicle.calibration.gnss_location')[:2]
        self.yaw_offset = settings.get('vehicle.calibration.gnss_yaw')
        self.speed_filter = OnlineLowPassFilter(1.2, 30, 4)
        self.status = None
        # Identity until the first global registration succeeds.
        self.transformation = np.identity(4)

    def gnss_callback(self, reading: GNSSReading):
        """Cache the latest GNSS pose, speed and status."""
        self.gnss_pose = reading.pose
        self.gnss_speed = reading.speed
        self.status = reading.status

    def lidar_callback(self, reading: np.ndarray):
        """Cache the latest raw lidar point array."""
        self.points = reading

    def rate(self):
        # Registration is expensive; run at 1 Hz.
        return 1

    def state_outputs(self) -> List[str]:
        return ['vehicle']

    def healthy(self):
        # Healthy once at least one map-based pose has been produced.
        return self.map_based_pose is not None

    def update(self) -> VehicleState:
        if self.points is None:
            return

        scan_time = self.vehicle_interface.time()
        print("Initialized", self.vehicle_interface.time() - scan_time)

        # Build a point cloud from the raw scan.
        scan_pcd = load_lidar_scan(self.points)
        print("Load scan", self.vehicle_interface.time() - scan_time)

        # Scale and translate the scan to match map scale and center.
        scaled_scan_pcd = prepare_scan_for_global_registration(
            scan_pcd,
            self.map,
            self.map_scale_ratio
        )
        print("Prepare scan", self.vehicle_interface.time() - scan_time)

        # Bug fix: pass radii (factor * voxel size), not raw factors — see
        # the matching note in __init__.
        scan_down, scan_fpfh = preprocess_point_cloud(
            scaled_scan_pcd, self.scan_voxel_size,
            self.scan_voxel_size * self.normal_radius_factor,
            self.scan_voxel_size * self.feature_radius_factor)
        print("Process scan", self.vehicle_interface.time() - scan_time)

        # Global registration: only run until a first alignment is found.
        # Bug fix: `self.transformation == np.identity(4)` compares
        # elementwise and raises "truth value of an array is ambiguous"
        # when used in an if-statement; use np.array_equal instead.
        if np.array_equal(self.transformation, np.identity(4)):
            # RANSAC feature matching
            ransac_result = execute_global_registration(
                scan_down, self.map_down, scan_fpfh, self.map_fpfh, self.voxel_size)
            self.transformation = ransac_result.transformation
            print("RANSAC", self.vehicle_interface.time() - scan_time)

            # Fast Global Registration; keep whichever fits better.
            fgr_result = execute_fast_global_registration(
                scan_down, self.map_down, scan_fpfh, self.map_fpfh, self.voxel_size)
            if fgr_result.fitness > ransac_result.fitness:
                self.transformation = fgr_result.transformation
            print("Fast", self.vehicle_interface.time() - scan_time)

        # Refine with ICP, warm-started from the previous transform.
        icp_transformation = multi_scale_icp(
            scan_down, self.map_down,
            voxel_sizes=[2.0, 1.0, 0.5],
            max_iterations=[100, 50, 25],
            initial_transform=self.transformation)
        print("ICP", self.vehicle_interface.time() - scan_time)
        self.transformation = icp_transformation

        # Extract position and orientation.
        # NOTE(review): transform_to_pose returns angles in *degrees*;
        # confirm ObjectPose expects degrees rather than radians.
        x, y, z, roll, pitch, yaw = transform_to_pose(icp_transformation)
        # TODO: Estimate speed with a proper filter
        if self.map_based_pose is not None:
            translation = np.array([x - self.map_based_pose.x,
                                    y - self.map_based_pose.y,
                                    z - self.map_based_pose.z])
            self.map_based_speed = np.linalg.norm(translation) / (scan_time - self.map_based_pose.t)
        self.map_based_pose = ObjectPose(ObjectFrameEnum.GLOBAL, scan_time, x, y, z, yaw, pitch, roll)

        readings = self.vehicle_interface.get_reading()
        raw = readings.to_state(self.map_based_pose)

        print("Extraction", self.vehicle_interface.time() - scan_time)
        print(x, y, z, yaw)

        # Speed from consecutive map poses (0 until two poses exist).
        if self.map_based_speed is not None:
            raw.v = self.map_based_speed
        else:
            raw.v = 0.0

        if self.gnss_pose is None:
            return
        # TODO: figure out what this status means
        # print("INS status",self.status)

        # vehicle gnss heading (yaw) in radians
        # vehicle x, y position in fixed local frame, in meters
        # reference point is located at the center of GNSS antennas
        localxy = transforms.rotate2d(self.location, -self.yaw_offset)
        gnss_xyhead_inv = (-localxy[0], -localxy[1], -self.yaw_offset)
        center_xyhead = self.gnss_pose.apply_xyhead(gnss_xyhead_inv)
        vehicle_pose_global = replace(self.gnss_pose,
                                      t=self.vehicle_interface.time(),
                                      x=center_xyhead[0],
                                      y=center_xyhead[1],
                                      yaw=center_xyhead[2])

        # Logged for comparison against the map-based estimate above.
        print(vehicle_pose_global.x, vehicle_pose_global.y, vehicle_pose_global.z, vehicle_pose_global.yaw)

        return raw

0 commit comments

Comments
 (0)