Skip to content

Commit 920fcc5

Browse files
committed
Initial map-based localization as component
1 parent b6a67d2 commit 920fcc5

2 files changed

Lines changed: 454 additions & 0 deletions

File tree

Lines changed: 397 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,397 @@
1+
from dataclasses import replace
2+
import math
3+
from typing import List
4+
from ...utils import settings
5+
from ...mathutils import transforms
6+
from ...state.vehicle import VehicleState,VehicleGearEnum
7+
from ...state.physical_object import ObjectFrameEnum,ObjectPose,convert_xyhead
8+
from ...knowledge.vehicle.geometry import front2steer,steer2front
9+
from ...mathutils.signal import OnlineLowPassFilter
10+
from ..interface.gem import GEMInterface
11+
from ..component import Component
12+
from ..interface.gem import GNSSReading
13+
14+
import numpy as np
15+
import open3d as o3d
16+
import copy
17+
import time
18+
import argparse
19+
import os
20+
import glob
21+
from scipy.spatial.transform import Rotation as R
22+
23+
def load_map(map_file):
    """Load a .ply map file.

    Args:
        map_file: Path to a .ply point cloud file readable by open3d.

    Returns:
        The loaded open3d PointCloud, or None if loading failed.
    """
    try:
        print(map_file)
        map_pcd = o3d.io.read_point_cloud(map_file)
        return map_pcd
    except Exception as e:
        print(f"Error loading map: {e}")
        # Bug fix: the failure path previously returned (None, None) while the
        # success path returned a single value; callers assign the result to a
        # single variable, so failure must also be a single None.
        return None
37+
38+
def load_lidar_scan(points):
    """Build an open3d point cloud from an (N, >=3) numpy lidar scan array.

    Only the first three columns (x, y, z) are used; any intensity column
    is ignored.

    Returns:
        An open3d PointCloud, or None if conversion failed.
    """
    try:
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(points[:, :3])
        return cloud
    except Exception as e:
        print(f"Error loading scan: {e}")
        return None
60+
61+
def remove_floor_ceiling(pcd, z_min=-0.5, z_max=2.5):
    """Keep only points whose z lies strictly inside (z_min, z_max).

    Drops floor and ceiling returns so downstream registration focuses on
    walls and other vertical structure. Colors, when present, are carried
    over for the surviving points.
    """
    xyz = np.asarray(pcd.points)
    keep = (xyz[:, 2] > z_min) & (xyz[:, 2] < z_max)

    trimmed = o3d.geometry.PointCloud()
    trimmed.points = o3d.utility.Vector3dVector(xyz[keep])
    if pcd.has_colors():
        trimmed.colors = o3d.utility.Vector3dVector(np.asarray(pcd.colors)[keep])
    return trimmed
72+
73+
def extract_structural_features(pcd, voxel_size):
    """Extract planar structure (e.g. walls) from a point cloud.

    Downsamples the cloud, estimates normals, then iteratively segments up
    to six RANSAC planes. Segmentation stops early when fewer than 100
    points remain or a plane has fewer than 50 inliers.

    Returns:
        A cloud containing all planar points, or the downsampled cloud when
        no plane was found.
    """
    downsampled = pcd.voxel_down_sample(voxel_size)
    downsampled.estimate_normals(
        o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30))

    remaining = copy.deepcopy(downsampled)
    planes = []
    for _ in range(6):  # extract at most the 6 dominant planes
        if len(np.asarray(remaining.points)) < 100:
            break
        _, inlier_idx = remaining.segment_plane(distance_threshold=voxel_size * 2,
                                                ransac_n=3,
                                                num_iterations=1000)
        if len(inlier_idx) < 50:
            break
        planes.append(remaining.select_by_index(inlier_idx))
        remaining = remaining.select_by_index(inlier_idx, invert=True)

    # Fall back to the plain downsampled cloud when nothing planar was found.
    if not planes:
        return downsampled

    combined = o3d.geometry.PointCloud()
    for plane in planes:
        combined += plane
    return combined
108+
109+
def prepare_scan_for_global_registration(scan_pcd, map_pcd, scale_ratio=None):
    """Scale and translate a scan so it roughly overlays the map.

    When scale_ratio is falsy, a ratio is derived from the two bounding
    boxes (80% of the tightest per-axis map/scan ratio). The scan's minimum
    corner is mapped onto the map's minimum corner.

    Returns:
        A new PointCloud with the transformed scan points.
    """
    map_xyz = np.asarray(map_pcd.points)
    map_lo = np.min(map_xyz, axis=0)
    map_hi = np.max(map_xyz, axis=0)

    scan_xyz = np.asarray(scan_pcd.points)
    scan_lo = np.min(scan_xyz, axis=0)
    scan_hi = np.max(scan_xyz, axis=0)

    if not scale_ratio:
        # Derive a ratio so the scan spans ~80% of the map extent.
        scale_ratio = np.min((map_hi - map_lo) / (scan_hi - scan_lo)) * 0.8

    shifted = (scan_xyz - scan_lo) * scale_ratio + map_lo

    aligned = o3d.geometry.PointCloud()
    aligned.points = o3d.utility.Vector3dVector(shifted)

    print(f"Dynamic scaling ratio: {scale_ratio}")
    return aligned
136+
137+
def preprocess_point_cloud(pcd, voxel_size, radius_normal=None, radius_feature=None):
    """Downsample a cloud and compute FPFH features for registration.

    Radii default to building-scale multiples of the voxel size (5x for
    normals, 10x for features) when not supplied.

    Returns:
        (downsampled_cloud, fpfh_features) tuple.
    """
    down = pcd.voxel_down_sample(voxel_size)

    # Falsy (None/0) radii fall back to the voxel-size-derived defaults.
    radius_normal = radius_normal if radius_normal else voxel_size * 5.0
    radius_feature = radius_feature if radius_feature else voxel_size * 10.0

    down.estimate_normals(
        o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=50))

    fpfh = o3d.pipelines.registration.compute_fpfh_feature(
        down,
        o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100))

    return down, fpfh
153+
154+
def execute_global_registration(source_down, target_down, source_fpfh, target_fpfh,
                                voxel_size, max_iterations=1000000):
    """RANSAC feature-matching registration between two downsampled clouds.

    Uses a generous distance threshold (15x voxel size) since this is the
    coarse initial alignment stage.

    Returns:
        The open3d RegistrationResult.
    """
    distance_threshold = voxel_size * 15  # generous on purpose: coarse stage

    checkers = [
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold),
        o3d.pipelines.registration.CorrespondenceCheckerBasedOnNormal(0.5),
    ]
    return o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
        source_down, target_down, source_fpfh, target_fpfh, True,
        distance_threshold,
        o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
        4,
        checkers,
        o3d.pipelines.registration.RANSACConvergenceCriteria(max_iterations, 500))
170+
171+
def execute_fast_global_registration(source_down, target_down, source_fpfh, target_fpfh, voxel_size):
    """Fast Global Registration (FGR) between two downsampled clouds.

    Returns:
        The open3d RegistrationResult; on RuntimeError, a dummy result with
        an identity transformation and zero fitness so callers can still
        compare fitness values safely.
    """
    distance_threshold = voxel_size * 0.5
    print(f":: Apply fast global registration with distance threshold {distance_threshold:.3f}")

    try:
        return o3d.pipelines.registration.registration_fast_based_on_feature_matching(
            source_down, target_down, source_fpfh, target_fpfh,
            o3d.pipelines.registration.FastGlobalRegistrationOption(
                maximum_correspondence_distance=distance_threshold))
    except RuntimeError as e:
        print(f"Error in FGR: {e}")
        # Fall back to a do-nothing result so the caller's fitness comparison
        # simply ignores FGR.
        fallback = o3d.pipelines.registration.RegistrationResult()
        fallback.transformation = np.identity(4)
        fallback.fitness = 0.0
        fallback.inlier_rmse = 0.0
        fallback.correspondence_set = []
        return fallback
191+
192+
def multi_scale_icp(source, target, voxel_sizes=(2.0, 1.0, 0.5), max_iterations=(50, 30, 14),
                    initial_transform=None):
    """Perform multi-scale (coarse-to-fine) ICP for robust alignment.

    Args:
        source: Source point cloud to align.
        target: Target point cloud.
        voxel_sizes: Downsampling voxel size per scale, coarse to fine.
        max_iterations: ICP iteration cap per scale (zipped with voxel_sizes;
            the shorter sequence determines the number of scales).
        initial_transform: Optional 4x4 seed transform; defaults to identity.

    Returns:
        The refined 4x4 transformation matrix from the final scale.
    """
    # Bug fix: the defaults were mutable lists and a np.eye(4) evaluated once
    # at definition time; tuples and a None sentinel avoid shared-state bugs
    # while remaining backward compatible.
    if initial_transform is None:
        initial_transform = np.eye(4)

    print("Running multi-scale ICP...")
    current_transform = initial_transform

    for i, (voxel_size, max_iter) in enumerate(zip(voxel_sizes, max_iterations)):
        print(f"ICP Scale {i+1}/{len(voxel_sizes)}: voxel_size={voxel_size}, max_iterations={max_iter}")

        # Downsample based on current voxel size
        source_down = source.voxel_down_sample(voxel_size)
        target_down = target.voxel_down_sample(voxel_size)

        # Estimate normals if not already computed
        source_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size*2, max_nn=30))
        target_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size*2, max_nn=30))

        # Use appropriate distance threshold based on scale
        distance_threshold = max(0.5, voxel_size * 2)

        # Run ICP
        result = o3d.pipelines.registration.registration_icp(
            source_down, target_down, distance_threshold, current_transform,
            o3d.pipelines.registration.TransformationEstimationPointToPoint(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iter))

        current_transform = result.transformation
        print(f"  Scale {i+1} result - Fitness: {result.fitness:.4f}, RMSE: {result.inlier_rmse:.4f}")

    return current_transform
224+
225+
def transform_to_pose(transformation_matrix):
    """Decompose a 4x4 homogeneous transform into translation and RPY.

    Args:
        transformation_matrix: 4x4 homogeneous transformation matrix.

    Returns:
        (x, y, z, roll, pitch, yaw) where the angles are in degrees using
        the 'xyz' Euler convention.
    """
    translation = transformation_matrix[:3, 3]

    # Copy the rotation block: scipy rejects the read-only views that
    # registration results can hand back.
    rot = R.from_matrix(np.array(transformation_matrix[:3, :3], copy=True))
    roll, pitch, yaw = rot.as_euler('xyz', degrees=True)

    return translation[0], translation[1], translation[2], roll, pitch, yaw
237+
238+
class MapBasedStateEstimator(Component):
    """Estimates the vehicle state by registering lidar scans against a
    prebuilt point cloud map.

    Each update runs global registration (RANSAC and FGR, keeping the better
    fit) followed by multi-scale ICP refinement, and derives the vehicle pose
    from the resulting transform. The GNSS reading is only used here for a
    comparison printout at the end of update().
    """
    def __init__(self, map_fn : str, vehicle_interface : GEMInterface):
        self.vehicle_interface = vehicle_interface
        if 'top_lidar' not in vehicle_interface.sensors():
            # Bug fix: this message previously said "GNSS sensor not available".
            raise RuntimeError("Lidar sensor not available")
        vehicle_interface.subscribe_sensor('top_lidar',self.lidar_callback,np.ndarray)
        self.map_based_pose = None   # last ObjectPose estimated from map registration
        self.map_based_speed = None  # speed derived from consecutive map poses
        self.points = None           # latest raw lidar scan (N x >=3 numpy array)
        self.map = load_map(map_fn)  # may be None if loading failed

        # TODO: Change these to be some form of variable
        self.map_scale_ratio = 1.0
        self.voxel_size = 1.0
        self.scan_voxel_size = 0.5

        if 'gnss' not in vehicle_interface.sensors():
            raise RuntimeError("GNSS sensor not available")
        vehicle_interface.subscribe_sensor('gnss',self.gnss_callback,GNSSReading)
        self.gnss_pose = None
        self.gnss_speed = None  # bug fix: was only ever set inside gnss_callback
        self.location = settings.get('vehicle.calibration.gnss_location')[:2]
        self.yaw_offset = settings.get('vehicle.calibration.gnss_yaw')
        self.speed_filter = OnlineLowPassFilter(1.2, 30, 4)
        self.status = None

    # Get GNSS information
    def gnss_callback(self, reading : GNSSReading):
        self.gnss_pose = reading.pose
        self.gnss_speed = reading.speed
        self.status = reading.status

    # Get lidar information
    def lidar_callback(self, reading : np.ndarray):
        self.points = reading

    def rate(self):
        # NOTE(review): this returns a period of 1/30 s; confirm whether the
        # Component API expects a rate (Hz) or a period here.
        return 1/30

    def state_outputs(self) -> List[str]:
        return ['vehicle']

    def healthy(self):
        # Healthy once at least one map-based pose has been produced.
        return self.map_based_pose is not None

    def update(self) -> VehicleState:
        """Run one localization cycle against the map.

        Returns None until both a lidar scan and a GNSS pose have been
        received; otherwise returns the VehicleState built from the
        map-registration pose.
        """
        if self.points is None:
            return

        scan_time = self.vehicle_interface.time()
        print("Initialized", self.vehicle_interface.time() - scan_time)

        # Convert the raw scan array into a point cloud.
        scan_pcd = load_lidar_scan(self.points)

        print("Load scan", self.vehicle_interface.time() - scan_time)

        # Scale and translate the scan to match map scale and center.
        scaled_scan_pcd = prepare_scan_for_global_registration(
            scan_pcd,
            self.map,
            self.map_scale_ratio
        )

        print("Prepare scan", self.vehicle_interface.time() - scan_time)

        # NOTE(review): these are named "factors" but are passed to
        # preprocess_point_cloud as the absolute radius_normal/radius_feature
        # arguments (radii in meters), not multiplied by the voxel size —
        # confirm intent before tuning.
        normal_radius_factor = 2.0
        feature_radius_factor = 5.0

        map_down, map_fpfh = preprocess_point_cloud(
            self.map, self.voxel_size, normal_radius_factor, feature_radius_factor)

        print("Process map", self.vehicle_interface.time() - scan_time)

        scan_down, scan_fpfh = preprocess_point_cloud(
            scaled_scan_pcd, self.scan_voxel_size, normal_radius_factor, feature_radius_factor)

        print("Process scan", self.vehicle_interface.time() - scan_time)

        # Global registration: RANSAC first, then FGR; keep the better fitness.
        ransac_result = execute_global_registration(
            scan_down, map_down, scan_fpfh, map_fpfh, self.voxel_size)
        transformation = ransac_result.transformation
        print("RANSAC", self.vehicle_interface.time() - scan_time)

        fgr_result = execute_fast_global_registration(
            scan_down, map_down, scan_fpfh, map_fpfh, self.voxel_size)
        if fgr_result.fitness > ransac_result.fitness:
            transformation = fgr_result.transformation
        print("Fast", self.vehicle_interface.time() - scan_time)

        # Refine with multi-scale ICP seeded by the global alignment.
        final_transformation = multi_scale_icp(
            scan_down, map_down,
            voxel_sizes=[2.0, 1.0, 0.5],
            max_iterations=[100, 50, 25],
            initial_transform=transformation)
        print("ICP", self.vehicle_interface.time() - scan_time)

        # Extract position and orientation (angles in degrees, 'xyz' order).
        x, y, z, roll, pitch, yaw = transform_to_pose(final_transformation)
        # TODO: Estimate speed
        if self.map_based_pose is not None:  # bug fix: was `!= None`
            translation = np.array([x - self.map_based_pose.x, y - self.map_based_pose.y, z - self.map_based_pose.z])
            dt = scan_time - self.map_based_pose.t
            # Robustness fix: guard against a zero/negative time step, which
            # previously raised ZeroDivisionError.
            if dt > 0:
                self.map_based_speed = np.linalg.norm(translation) / dt
        # NOTE(review): transform_to_pose returns degrees — confirm ObjectPose
        # expects degrees (the GNSS heading elsewhere is in radians).
        self.map_based_pose = ObjectPose(ObjectFrameEnum.GLOBAL, scan_time, x, y, z, yaw, pitch, roll)

        readings = self.vehicle_interface.get_reading()
        raw = readings.to_state(self.map_based_pose)

        print("Extraction", self.vehicle_interface.time() - scan_time)
        print(x, y, z, yaw)

        # Use the map-derived speed when available; 0 before the second scan.
        if self.map_based_speed is not None:  # bug fix: was `!= None`
            raw.v = self.map_based_speed
        else:
            raw.v = 0.0

        # NOTE(review): the remainder only prints a GNSS-derived pose for
        # comparison, yet a missing GNSS pose makes update() return None —
        # confirm whether the map-based state should be returned regardless.
        if self.gnss_pose is None:
            return
        #TODO: figure out what this status means
        #print("INS status",self.status)

        # vehicle gnss heading (yaw) in radians
        # vehicle x, y position in fixed local frame, in meters
        # reference point is located at the center of GNSS antennas
        localxy = transforms.rotate2d(self.location,-self.yaw_offset)
        gnss_xyhead_inv = (-localxy[0],-localxy[1],-self.yaw_offset)
        center_xyhead = self.gnss_pose.apply_xyhead(gnss_xyhead_inv)
        vehicle_pose_global = replace(self.gnss_pose,
                                      t=self.vehicle_interface.time(),
                                      x=center_xyhead[0],
                                      y=center_xyhead[1],
                                      yaw=center_xyhead[2])

        print(vehicle_pose_global.x, vehicle_pose_global.y, vehicle_pose_global.z, vehicle_pose_global.yaw)

        return raw

0 commit comments

Comments
 (0)