diff --git a/src/tracking_car/carla_tracking_ros/.gitignore b/src/tracking_car/carla_tracking_ros/.gitignore
new file mode 100644
index 0000000000..ea4e295337
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/.gitignore
@@ -0,0 +1,38 @@
+# ROS编译文件
+build/
+devel/
+install/
+
+# Python缓存
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+
+# 日志和临时文件
+*.log
+*.pid
+*.sqlite
+*.sqlite3
+
+# 编辑器文件
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# 系统文件
+.DS_Store
+Thumbs.db
+
+# 备份文件
+*.bak
+*.backup
+*.tmp
+
+# 大文件
+*.bag
+*.mp4
+*.avi
diff --git a/src/tracking_car/carla_tracking_ros/CMakeLists.txt b/src/tracking_car/carla_tracking_ros/CMakeLists.txt
new file mode 100644
index 0000000000..4963f232df
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/CMakeLists.txt
@@ -0,0 +1,28 @@
+cmake_minimum_required(VERSION 3.0.2)
+project(carla_tracking_ros)
+
+find_package(catkin REQUIRED COMPONENTS
+ roscpp
+ rospy
+ std_msgs
+ cv_bridge
+ image_transport
+)
+
+catkin_python_setup()
+
+catkin_package(
+ CATKIN_DEPENDS roscpp rospy std_msgs cv_bridge image_transport
+)
+
+install(DIRECTORY launch/
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}/launch
+)
+
+install(DIRECTORY config/
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}/config
+)
+
+install(FILES requirements.txt
+ DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}
+)
diff --git a/src/tracking_car/carla_tracking_ros/config/config.yaml b/src/tracking_car/carla_tracking_ros/config/config.yaml
new file mode 100644
index 0000000000..5e1d9064b0
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/config/config.yaml
@@ -0,0 +1,90 @@
+# ======================== CARLA连接配置 ========================
+host: "localhost"
+port: 2000
+timeout: 20.0
+
+# ======================== 传感器配置 ========================
+img_width: 1280
+img_height: 960
+fov: 110
+sensor_tick: 0.05
+use_lidar: true
+
+# ======================== LiDAR配置 ========================
+lidar_channels: 32
+lidar_range: 100.0
+lidar_points_per_second: 500000
+
+# ======================== 检测模型配置 ========================
+yolo_model: "yolov8n.pt"
+conf_thres: 0.1
+iou_thres: 0.3
+device: "cpu"
+yolo_imgsz_max: 640
+
+# ======================== 跟踪算法配置 ========================
+max_age: 5
+min_hits: 3
+kf_dt: 0.05
+max_speed: 50.0
+
+# ======================== 可视化配置 ========================
+window_width: 1280
+window_height: 720
+display_fps: 30
+
+# ======================== 轨迹历史配置 ========================
+track_history_len: 20
+
+# ======================== 行为分析配置 ========================
+stop_speed_thresh: 0.8
+stop_frames_thresh: 5
+overtake_speed_ratio: 1.5
+overtake_dist_thresh: 60.0
+lane_change_thresh: 0.4
+brake_accel_thresh: 2.0
+turn_angle_thresh: 15.0
+danger_dist_thresh: 15.0
+predict_frames: 10
+
+# ======================== 天气与NPC配置 ========================
+default_weather: "clear"
+num_npcs: 20
+
+# ======================== 视角控制配置 ========================
+view:
+ default_mode: "satellite"
+ satellite_height: 50.0
+ behind_distance: 10.0
+ first_person_height: 1.6
+
+# ======================== 热重载配置 ========================
+hot_reload:
+ enabled: true # 是否启用热重载
+ check_interval: 100 # 每多少帧检查一次配置文件
+ safe_keys_only: true # 是否只更新安全参数(避免重启连接)
+
+# 可热重载的安全参数列表(这些参数可以安全更新而不需要重启)
+hot_reload_safe_keys:
+ - "conf_thres" # 检测置信度阈值
+ - "iou_thres" # IOU阈值
+ - "display_fps" # 显示帧率
+ - "max_age" # 跟踪最大丢失帧数
+ - "min_hits" # 跟踪最小匹配次数
+ - "adaptive_fps" # 是否启用自适应帧率
+ - "min_fps" # 最小FPS
+ - "max_fps" # 最大FPS
+ - "fov" # 相机视野
+ - "yolo_imgsz_max" # YOLO输入尺寸
+ - "stop_speed_thresh" # 停车速度阈值
+ - "danger_dist_thresh" # 危险距离阈值
+ - "weather" # 天气设置
+
+# ======================== 交通标志检测配置 ========================
+enable_sign_detection: false # 是否启用交通标志检测
+
+traffic_sign:
+ enabled: true # 检测器是否启用
+ show_signs: true # 是否显示检测框
+ enable_actions: false # 是否触发动作(警告等)
+ conf_threshold: 0.5 # 置信度阈值
\ No newline at end of file
diff --git a/src/tracking_car/carla_tracking_ros/launch/main.launch b/src/tracking_car/carla_tracking_ros/launch/main.launch
new file mode 100644
index 0000000000..0fcf5e8f11
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/launch/main.launch
@@ -0,0 +1,34 @@
+<?xml version="1.0"?>
+<!-- NOTE(review): the original launch XML was lost when markup was stripped;
+     this is a minimal reconstruction for the tracking node. Verify against
+     the upstream repository before relying on it. -->
+<launch>
+
+  <!-- CARLA connection arguments -->
+  <arg name="host" default="localhost" />
+  <arg name="port" default="2000" />
+
+  <!-- Load the full system configuration -->
+  <rosparam command="load"
+            file="$(find carla_tracking_ros)/config/config.yaml" />
+
+  <param name="host" value="$(arg host)" />
+  <param name="port" value="$(arg port)" />
+
+  <!-- Main multi-object tracking node -->
+  <node pkg="carla_tracking_ros"
+        type="main.py"
+        name="carla_tracking_system"
+        output="screen" />
+
+</launch>
diff --git a/src/tracking_car/carla_tracking_ros/package.xml b/src/tracking_car/carla_tracking_ros/package.xml
new file mode 100644
index 0000000000..b614e5516e
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/package.xml
@@ -0,0 +1,31 @@
+<?xml version="1.0"?>
+<package format="2">
+  <name>carla_tracking_ros</name>
+  <version>1.0.0</version>
+  <description>CARLA Multi-Object Tracking System - ROS Integration</description>
+
+  <maintainer email="tjl@todo.todo">tjl</maintainer>
+  <license>MIT</license>
+
+  <url type="repository">https://github.com/yourusername/carla_tracking_ros</url>
+  <author>tjl</author>
+
+  <buildtool_depend>catkin</buildtool_depend>
+
+  <depend>roscpp</depend>
+  <depend>rospy</depend>
+  <depend>std_msgs</depend>
+  <depend>cv_bridge</depend>
+  <depend>image_transport</depend>
+
+  <!-- optional message generation support -->
+  <build_depend>message_generation</build_depend>
+  <exec_depend>message_runtime</exec_depend>
+
+  <!-- testing -->
+  <test_depend>rosunit</test_depend>
+
+  <export>
+  </export>
+</package>
+
diff --git a/src/tracking_car/carla_tracking_ros/requirements.txt b/src/tracking_car/carla_tracking_ros/requirements.txt
new file mode 100644
index 0000000000..27cc8e8623
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/requirements.txt
@@ -0,0 +1,17 @@
+# Core and optional dependencies (single alphabetized list).
+# NOTE(review): split into "core" vs "optional" groups if optionality matters.
+
+carla>=0.9.14
+loguru>=0.7.0
+matplotlib>=3.7.0
+numba>=0.57.0
+numpy>=1.24.0
+open3d>=0.17.0
+opencv-python>=4.8.0
+pandas>=2.0.0
+psutil>=5.9.0
+pyyaml>=6.0.0
+scikit-learn>=1.2.0
+scipy>=1.10.0
+torch>=2.0.0
+ultralytics>=8.0.0
diff --git a/src/tracking_car/carla_tracking_ros/scripts/__init__.py b/src/tracking_car/carla_tracking_ros/scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/tracking_car/carla_tracking_ros/scripts/main.py b/src/tracking_car/carla_tracking_ros/scripts/main.py
new file mode 100644
index 0000000000..a7d504b607
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/scripts/main.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+"""
+测试版 - CARLA多目标跟踪系统(ROS版)
+简化版本用于测试ROS封装
+"""
+
+import sys
+import os
+import time
+import argparse
+
+# ======================== ROS支持 ========================
+try:
+ import rospy
+ from sensor_msgs.msg import Image
+ from std_msgs.msg import String, Float32
+ from cv_bridge import CvBridge
+ import cv2
+ import numpy as np
+ ROS_AVAILABLE = True
+ print("✅ ROS模块导入成功")
+except ImportError as e:
+ ROS_AVAILABLE = False
+ print(f"⚠️ ROS模块导入失败: {e}")
+
+# ======================== 简单日志类 ========================
+class SimpleLogger:
+ def info(self, msg): print(f"[INFO] {msg}")
+ def warning(self, msg): print(f"[WARN] {msg}")
+ def error(self, msg): print(f"[ERROR] {msg}")
+
+logger = SimpleLogger()
+
+# ======================== 模拟传感器类 ========================
+class MockSensorManager:
+ def __init__(self):
+ logger.info("初始化模拟传感器管理器")
+
+ def get_image(self):
+ # 创建一个简单的测试图像
+ img = np.zeros((480, 640, 3), dtype=np.uint8)
+ cv2.putText(img, 'CARLA Tracking ROS', (50, 240),
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+ return img
+
+ def get_detections(self):
+ # 模拟检测结果
+ return [
+ {'id': 1, 'bbox': [100, 100, 200, 200], 'class': 'vehicle'},
+ {'id': 2, 'bbox': [300, 150, 400, 250], 'class': 'pedestrian'}
+ ]
+
+# ======================== 模拟跟踪器类 ========================
+class MockTracker:
+ def __init__(self):
+ logger.info("初始化模拟多目标跟踪器")
+ self.track_id = 0
+
+ def update(self, detections):
+ self.track_id += 1
+ return [{'id': self.track_id, 'detection': d} for d in detections]
+
+# ======================== ROS发布器类 ========================
+class ROSPublisher:
+ def __init__(self):
+ if ROS_AVAILABLE:
+ self.bridge = CvBridge()
+ self.image_pub = rospy.Publisher('/carla/camera', Image, queue_size=10)
+ self.detection_pub = rospy.Publisher('/carla/detections', String, queue_size=10)
+ self.status_pub = rospy.Publisher('/carla/status', String, queue_size=10)
+ logger.info("ROS发布器初始化完成")
+ else:
+ logger.warning("ROS不可用,发布器未初始化")
+
+ def publish_image(self, cv_image):
+ if ROS_AVAILABLE and hasattr(self, 'image_pub'):
+ try:
+ ros_image = self.bridge.cv2_to_imgmsg(cv_image, "bgr8")
+ self.image_pub.publish(ros_image)
+ return True
+ except Exception as e:
+ logger.warning(f"发布图像失败: {e}")
+ return False
+
+ def publish_detection(self, detections):
+ if ROS_AVAILABLE and hasattr(self, 'detection_pub'):
+ det_str = str(detections)
+ self.detection_pub.publish(det_str)
+ return True
+ return False
+
+ def publish_status(self, status):
+ if ROS_AVAILABLE and hasattr(self, 'status_pub'):
+ self.status_pub.publish(status)
+ return True
+ return False
+
+# ======================== 主函数 ========================
+def main():
+ parser = argparse.ArgumentParser(description='CARLA多目标跟踪系统ROS测试版')
+ parser.add_argument('--mode', choices=['simulation', 'test'], default='test',
+ help='运行模式:simulation(仿真)或 test(测试)')
+ parser.add_argument('--duration', type=int, default=10,
+ help='运行时长(秒)')
+ parser.add_argument('--rate', type=float, default=2.0,
+ help='发布频率(Hz)')
+ parser.add_argument('--no-ros', action='store_true',
+ help='禁用ROS功能')
+
+ args = parser.parse_args()
+
+ # 初始化ROS
+ if ROS_AVAILABLE and not args.no_ros:
+ rospy.init_node('carla_tracking_system')
+ logger.info("ROS节点初始化: carla_tracking_system")
+
+ # 创建组件
+ sensor_manager = MockSensorManager()
+ tracker = MockTracker()
+ ros_publisher = ROSPublisher() if not args.no_ros else None
+
+ logger.info(f"启动CARLA多目标跟踪系统ROS测试版")
+ logger.info(f"模式: {args.mode}, 时长: {args.duration}秒, 频率: {args.rate}Hz")
+ logger.info(f"ROS支持: {'启用' if ros_publisher else '禁用'}")
+
+ # 主循环
+ start_time = time.time()
+ frame_count = 0
+
+ try:
+ while time.time() - start_time < args.duration:
+ # 获取数据
+ image = sensor_manager.get_image()
+ detections = sensor_manager.get_detections()
+ tracks = tracker.update(detections)
+
+ # 发布ROS消息
+ if ros_publisher:
+ ros_publisher.publish_image(image)
+ ros_publisher.publish_detection(detections)
+ ros_publisher.publish_status(f"Frame {frame_count}: {len(tracks)} tracks")
+
+ # 显示信息
+ logger.info(f"帧 {frame_count}: 检测到 {len(detections)} 个目标, 跟踪 {len(tracks)} 个轨迹")
+
+ frame_count += 1
+ time.sleep(1.0 / args.rate)
+
+ except KeyboardInterrupt:
+ logger.info("收到退出信号")
+ except Exception as e:
+ logger.error(f"运行错误: {e}")
+
+ logger.info(f"系统运行完成,共处理 {frame_count} 帧")
+ logger.info(f"平均帧率: {frame_count / args.duration:.2f} Hz")
+ logger.info("系统关闭")
+
+if __name__ == '__main__':
+ main()
diff --git a/src/tracking_car/carla_tracking_ros/scripts/sensors.py b/src/tracking_car/carla_tracking_ros/scripts/sensors.py
new file mode 100644
index 0000000000..e4ba38fda4
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/scripts/sensors.py
@@ -0,0 +1,945 @@
+"""
+sensors.py - CARLA传感器管理
+包含:相机、LiDAR传感器封装和管理
+"""
+
+import random # 添加随机模块支持
+import carla
+import cv2
+import numpy as np
+import queue
+import threading
+import sys
+import time
+
+# 配置日志
+try:
+ from loguru import logger
+except ImportError:
+ # 使用标准logging作为回退
+ import logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+import open3d as o3d
+from sklearn.cluster import DBSCAN
+
+class CameraManager:
+ """相机管理器"""
+
+ def __init__(self, world, ego_vehicle, config):
+ """
+ 初始化相机
+
+ Args:
+ world: CARLA世界对象
+ ego_vehicle: 自车对象
+ config: 配置字典
+ """
+ self.world = world
+ self.ego_vehicle = ego_vehicle
+ self.config = config
+ self.camera = None
+ self.image_queue = queue.Queue(maxsize=2)
+ self.current_image = None
+ self.frame_count = 0
+
+ def setup(self):
+ """设置相机"""
+ try:
+ # 获取相机蓝图
+ camera_bp = self.world.get_blueprint_library().find('sensor.camera.rgb')
+
+ # 设置相机属性
+ camera_bp.set_attribute('image_size_x', str(self.config.get('img_width', 640)))
+ camera_bp.set_attribute('image_size_y', str(self.config.get('img_height', 480)))
+ camera_bp.set_attribute('fov', str(self.config.get('fov', 90)))
+ camera_bp.set_attribute('sensor_tick', str(self.config.get('sensor_tick', 0.05)))
+
+ # 设置相机位置(车顶前方)
+ camera_transform = carla.Transform(
+ carla.Location(x=2.5, z=2.5),
+ carla.Rotation(pitch=-5) # 略微向下倾斜
+ )
+
+ # 生成相机
+ self.camera = self.world.spawn_actor(
+ camera_bp,
+ camera_transform,
+ attach_to=self.ego_vehicle
+ )
+
+ # 绑定回调函数
+ self.camera.listen(self._camera_callback)
+
+ logger.info(f"✅ 相机初始化成功 (ID: {self.camera.id})")
+ return True
+
+ except Exception as e:
+ logger.error(f"❌ 相机初始化失败: {e}")
+ return False
+
+ def _camera_callback(self, image):
+ """相机数据回调函数"""
+ try:
+ # 将原始数据转换为numpy数组
+ array = np.frombuffer(image.raw_data, dtype=np.uint8)
+ array = array.reshape((image.height, image.width, 4))
+
+ # 提取RGB通道(去掉alpha通道)
+ rgb_array = array[:, :, :3]
+
+ # 轻微高斯模糊减少噪声
+ rgb_array = cv2.GaussianBlur(rgb_array, (3, 3), 0)
+
+ # 更新当前图像
+ self.current_image = rgb_array
+
+ # 放入队列(如果队列已满,丢弃最旧的数据)
+ if self.image_queue.full():
+ try:
+ self.image_queue.get_nowait()
+ except queue.Empty:
+ pass
+
+ self.image_queue.put(rgb_array.copy())
+ self.frame_count += 1
+
+ except Exception as e:
+ logger.warning(f"相机回调错误: {e}")
+
+ def get_image(self, timeout=0.1):
+ """
+ 获取最新图像
+
+ Args:
+ timeout: 超时时间(秒)
+
+ Returns:
+ np.ndarray or None: 图像数据
+ """
+ try:
+ # 首先尝试从队列获取最新图像
+ image = self.image_queue.get(timeout=timeout)
+ # 清空队列中的旧图像
+ while not self.image_queue.empty():
+ try:
+ self.image_queue.get_nowait()
+ except queue.Empty:
+ break
+ return image
+ except queue.Empty:
+ # 如果队列为空,返回当前图像
+ return self.current_image
+
+ def get_current_image(self):
+ """获取当前图像(不阻塞)"""
+ return self.current_image
+
+ def destroy(self):
+ """销毁相机"""
+ if self.camera and self.camera.is_alive:
+ try:
+ self.camera.stop()
+ self.camera.destroy()
+ logger.info("✅ 相机已销毁")
+ except Exception as e:
+ logger.warning(f"销毁相机失败: {e}")
+ self.camera = None
+
+class LiDARManager:
+ """LiDAR管理器"""
+
+ def __init__(self, world, ego_vehicle, config):
+ """
+ 初始化LiDAR
+
+ Args:
+ world: CARLA世界对象
+ ego_vehicle: 自车对象
+ config: 配置字典
+ """
+ self.world = world
+ self.ego_vehicle = ego_vehicle
+ self.config = config
+ self.lidar = None
+ self.pointcloud_queue = queue.Queue(maxsize=2)
+ self.current_pointcloud = None
+ self.current_transform = None
+
+ def setup(self):
+ """设置LiDAR"""
+ try:
+ if not self.config.get('use_lidar', True):
+ logger.info("LiDAR被禁用")
+ return True
+
+ # 获取LiDAR蓝图
+ lidar_bp = self.world.get_blueprint_library().find('sensor.lidar.ray_cast')
+
+ # 设置LiDAR属性
+ lidar_bp.set_attribute('channels', str(self.config.get('lidar_channels', 32)))
+ lidar_bp.set_attribute('range', str(self.config.get('lidar_range', 100.0)))
+ lidar_bp.set_attribute('points_per_second',
+ str(self.config.get('lidar_points_per_second', 500000)))
+ lidar_bp.set_attribute('rotation_frequency', str(self.config.get('rotation_frequency', 20)))
+ lidar_bp.set_attribute('sensor_tick', str(self.config.get('sensor_tick', 0.05)))
+
+ # 设置LiDAR位置(车顶中央)
+ lidar_transform = carla.Transform(
+ carla.Location(x=0.0, z=2.5),
+ carla.Rotation()
+ )
+
+ # 生成LiDAR
+ self.lidar = self.world.spawn_actor(
+ lidar_bp,
+ lidar_transform,
+ attach_to=self.ego_vehicle
+ )
+
+ # 绑定回调函数
+ self.lidar.listen(self._lidar_callback)
+
+ logger.info(f"✅ LiDAR初始化成功 (ID: {self.lidar.id})")
+ return True
+
+ except Exception as e:
+ logger.error(f"❌ LiDAR初始化失败: {e}")
+ return False
+
+ def _lidar_callback(self, pointcloud):
+ """LiDAR数据回调函数"""
+ try:
+ # 将原始数据转换为numpy数组
+ points = np.frombuffer(pointcloud.raw_data, dtype=np.float32)
+ points = points.reshape(-1, 4)[:, :3] # 只取xyz,忽略反射强度
+
+ # 过滤地面点(简单的高度过滤)
+ ground_mask = points[:, 2] < -1.0
+ filtered_points = points[~ground_mask]
+
+ # 更新当前点云
+ self.current_pointcloud = filtered_points
+ self.current_transform = pointcloud.transform
+
+ # 放入队列(如果队列已满,丢弃最旧的数据)
+ if self.pointcloud_queue.full():
+ try:
+ self.pointcloud_queue.get_nowait()
+ except queue.Empty:
+ pass
+
+ self.pointcloud_queue.put((filtered_points.copy(), pointcloud.transform))
+
+ except Exception as e:
+ logger.warning(f"LiDAR回调错误: {e}")
+
+ def get_pointcloud(self, timeout=0.1):
+ """
+ 获取最新点云数据
+
+ Args:
+ timeout: 超时时间(秒)
+
+ Returns:
+ tuple: (points, transform) 或 (None, None)
+ """
+ try:
+ points, transform = self.pointcloud_queue.get(timeout=timeout)
+ # 清空队列中的旧数据
+ while not self.pointcloud_queue.empty():
+ try:
+ self.pointcloud_queue.get_nowait()
+ except queue.Empty:
+ break
+ return points, transform
+ except queue.Empty:
+ # 如果队列为空,返回当前点云
+ return self.current_pointcloud, self.current_transform
+
+ def detect_objects(self, min_points=30):
+ """
+ 从点云中检测物体
+
+ Args:
+ min_points: 最小点数阈值
+
+ Returns:
+ list: 检测到的物体列表
+ """
+ if self.current_pointcloud is None or len(self.current_pointcloud) < min_points:
+ return []
+
+ try:
+ # 使用DBSCAN聚类
+ clustering = DBSCAN(eps=0.8, min_samples=30).fit(self.current_pointcloud[:, :2])
+
+ objects = []
+ for label in set(clustering.labels_):
+ if label == -1: # 忽略噪声点
+ continue
+
+ # 获取该聚类的点
+ cluster_points = self.current_pointcloud[clustering.labels_ == label]
+
+ if len(cluster_points) < min_points:
+ continue
+
+ # 计算3D边界框
+ min_coords = cluster_points.min(axis=0)
+ max_coords = cluster_points.max(axis=0)
+ center = (min_coords + max_coords) / 2
+ size = max_coords - min_coords
+
+ # 估计物体类型(基于尺寸)
+ obj_type = self._estimate_object_type(size)
+
+ objects.append({
+ 'bbox_3d': [*min_coords, *max_coords], # [x_min, y_min, z_min, x_max, y_max, z_max]
+ 'center': center.tolist(),
+ 'size': size.tolist(),
+ 'num_points': len(cluster_points),
+ 'type': obj_type,
+ 'label': label
+ })
+
+ return objects
+
+ except Exception as e:
+ logger.warning(f"LiDAR物体检测失败: {e}")
+ return []
+
+ def _estimate_object_type(self, size):
+ """根据尺寸估计物体类型"""
+ length, width, height = size
+
+ # 简单的大小分类
+ if height > 2.5:
+ return "truck"
+ elif width > 2.0:
+ return "bus"
+ elif length > 4.0:
+ return "car"
+ else:
+ return "unknown"
+
+ def get_open3d_pointcloud(self):
+ """
+ 获取Open3D格式的点云
+
+ Returns:
+ o3d.geometry.PointCloud or None: Open3D点云对象
+ """
+ if self.current_pointcloud is None or len(self.current_pointcloud) == 0:
+ return None
+
+ try:
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(self.current_pointcloud)
+
+ # 根据高度着色(低处蓝色,高处红色)
+ z_min = self.current_pointcloud[:, 2].min()
+ z_max = self.current_pointcloud[:, 2].max()
+ z_range = max(z_max - z_min, 1e-6)
+
+ colors = np.zeros((len(self.current_pointcloud), 3))
+ normalized_z = (self.current_pointcloud[:, 2] - z_min) / z_range
+ colors[:, 0] = normalized_z # 红色通道(高处)
+ colors[:, 2] = 1 - normalized_z # 蓝色通道(低处)
+
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+
+ return pcd
+
+ except Exception as e:
+ logger.warning(f"创建Open3D点云失败: {e}")
+ return None
+
+ def destroy(self):
+ """销毁LiDAR"""
+ if self.lidar and self.lidar.is_alive:
+ try:
+ self.lidar.stop()
+ self.lidar.destroy()
+ logger.info("✅ LiDAR已销毁")
+ except Exception as e:
+ logger.warning(f"销毁LiDAR失败: {e}")
+ self.lidar = None
+
+class SpectatorManager:
+ """CARLA视角管理器 - 提供卫星视角跟随"""
+
+ def __init__(self, world, ego_vehicle, config):
+ """
+ 初始化视角管理器
+
+ Args:
+ world: CARLA世界对象
+ ego_vehicle: 自车对象
+ config: 配置字典
+ """
+ self.world = world
+ self.ego_vehicle = ego_vehicle
+ self.config = config
+ self.spectator = None
+ view_config = config.get('view', {})
+ self.view_mode = view_config.get('default_mode', 'satellite')
+ self.view_height = view_config.get('satellite_height', 50.0)
+ self.follow_distance = view_config.get('behind_distance', 10.0)
+ self.first_person_height = view_config.get('first_person_height', 1.6)
+
+ def setup(self):
+ """设置视角"""
+ try:
+ # 获取世界中的观察者(spectator)
+ self.spectator = self.world.get_spectator()
+
+ # 设置初始视角
+ self._set_satellite_view()
+
+ logger.info(f"✅ 视角管理器初始化成功")
+ return True
+
+ except Exception as e:
+ logger.error(f"❌ 视角管理器初始化失败: {e}")
+ return False
+
+ def _set_satellite_view(self):
+ """设置卫星视角"""
+ if not self.ego_vehicle or not self.ego_vehicle.is_alive:
+ return
+
+ try:
+ # 获取车辆位置和方向
+ vehicle_transform = self.ego_vehicle.get_transform()
+ vehicle_location = vehicle_transform.location
+ vehicle_rotation = vehicle_transform.rotation
+
+ # 设置卫星视角位置(车辆正上方)
+ spectator_location = carla.Location(
+ x=vehicle_location.x,
+ y=vehicle_location.y,
+ z=vehicle_location.z + self.view_height
+ )
+
+ # 设置视角朝向(俯瞰车辆)
+ spectator_rotation = carla.Rotation(
+ pitch=-90, # 向下看
+ yaw=vehicle_rotation.yaw,
+ roll=0
+ )
+
+ # 应用变换
+ spectator_transform = carla.Transform(
+ spectator_location,
+ spectator_rotation
+ )
+ self.spectator.set_transform(spectator_transform)
+
+ except Exception as e:
+ logger.warning(f"设置卫星视角失败: {e}")
+
+ def _set_behind_view(self):
+ """设置后方跟随视角"""
+ if not self.ego_vehicle or not self.ego_vehicle.is_alive:
+ return
+
+ try:
+ # 获取车辆位置和方向
+ vehicle_transform = self.ego_vehicle.get_transform()
+ vehicle_location = vehicle_transform.location
+ vehicle_rotation = vehicle_transform.rotation
+
+ # 计算后方偏移位置
+ import math
+ yaw_rad = math.radians(vehicle_rotation.yaw)
+
+ # 车辆后方偏移
+ behind_x = vehicle_location.x - self.follow_distance * math.cos(yaw_rad)
+ behind_y = vehicle_location.y - self.follow_distance * math.sin(yaw_rad)
+
+ # 设置摄像机位置(稍高于车辆)
+ spectator_location = carla.Location(
+ x=behind_x,
+ y=behind_y,
+ z=vehicle_location.z + 3.0
+ )
+
+ # 设置视角朝向(看向车辆)
+ # 计算朝向车辆的旋转
+ dx = vehicle_location.x - spectator_location.x
+ dy = vehicle_location.y - spectator_location.y
+ target_yaw = math.degrees(math.atan2(dy, dx))
+
+ spectator_rotation = carla.Rotation(
+ pitch=-15, # 略微向下
+ yaw=target_yaw,
+ roll=0
+ )
+
+ # 应用变换
+ spectator_transform = carla.Transform(
+ spectator_location,
+ spectator_rotation
+ )
+ self.spectator.set_transform(spectator_transform)
+
+ except Exception as e:
+ logger.warning(f"设置后方视角失败: {e}")
+
+ def _set_first_person_view(self):
+ """设置第一人称视角"""
+ if not self.ego_vehicle or not self.ego_vehicle.is_alive:
+ return
+
+ try:
+ # 获取车辆变换
+ vehicle_transform = self.ego_vehicle.get_transform()
+
+ # 稍微调整位置(从驾驶员视角)
+ location = carla.Location(
+ x=vehicle_transform.location.x,
+ y=vehicle_transform.location.y,
+ z=vehicle_transform.location.z + self.first_person_height # 改为使用配置
+ )
+
+ # 使用车辆的方向
+ rotation = carla.Rotation(
+ pitch=vehicle_transform.rotation.pitch,
+ yaw=vehicle_transform.rotation.yaw,
+ roll=vehicle_transform.rotation.roll
+ )
+
+ # 应用变换
+ spectator_transform = carla.Transform(
+ location,
+ rotation
+ )
+ self.spectator.set_transform(spectator_transform)
+
+ except Exception as e:
+ logger.warning(f"设置第一人称视角失败: {e}")
+
+ def set_view_mode(self, mode):
+ """
+ 设置视角模式
+
+ Args:
+ mode: 视角模式 ('satellite', 'behind', 'first_person')
+ """
+ if mode not in ['satellite', 'behind', 'first_person']:
+ logger.warning(f"未知的视角模式: {mode}")
+ return
+
+ self.view_mode = mode
+
+ def update(self):
+ """更新视角"""
+ if not self.ego_vehicle or not self.ego_vehicle.is_alive:
+ return
+
+ try:
+ if self.view_mode == 'satellite':
+ self._set_satellite_view()
+ elif self.view_mode == 'behind':
+ self._set_behind_view()
+ elif self.view_mode == 'first_person':
+ self._set_first_person_view()
+
+ except Exception as e:
+ logger.debug(f"更新视角失败: {e}")
+
+ def cycle_view_mode(self):
+ """循环切换视角模式"""
+ modes = ['satellite', 'behind', 'first_person']
+ current_index = modes.index(self.view_mode) if self.view_mode in modes else 0
+ next_index = (current_index + 1) % len(modes)
+ self.view_mode = modes[next_index]
+ logger.info(f"切换到 {self.view_mode} 视角")
+
+ def destroy(self):
+ """销毁视角管理器"""
+ self.spectator = None
+ logger.info("✅ 视角管理器已销毁")
+
+class SensorManager:
+ """传感器管理器(统一管理所有传感器)"""
+
+ def __init__(self, world, ego_vehicle, config):
+ """
+ 初始化传感器管理器
+
+ Args:
+ world: CARLA世界对象
+ ego_vehicle: 自车对象
+ config: 配置字典
+ """
+ self.world = world
+ self.ego_vehicle = ego_vehicle
+ self.config = config
+
+ self.camera_manager = None
+ self.lidar_manager = None
+ self.spectator_manager = None # 新增视角管理器
+ self.is_setup = False
+
+ def setup(self):
+ """设置所有传感器"""
+ logger.info("正在初始化传感器...")
+
+ # 初始化相机
+ self.camera_manager = CameraManager(self.world, self.ego_vehicle, self.config)
+ camera_success = self.camera_manager.setup()
+
+ # 初始化LiDAR
+ lidar_success = True
+ if self.config.get('use_lidar', True):
+ self.lidar_manager = LiDARManager(self.world, self.ego_vehicle, self.config)
+ lidar_success = self.lidar_manager.setup()
+ else:
+ logger.info("LiDAR功能已禁用")
+
+ # 初始化视角管理器
+ self.spectator_manager = SpectatorManager(self.world, self.ego_vehicle, self.config)
+ spectator_success = self.spectator_manager.setup()
+
+ self.is_setup = camera_success and lidar_success and spectator_success
+
+ if self.is_setup:
+ logger.info("✅ 所有传感器初始化完成")
+ else:
+ logger.warning("⚠️ 传感器初始化不完全")
+
+ return self.is_setup
+
+ # 在get_sensor_data方法中添加视角更新
+ def get_sensor_data(self, timeout=0.05):
+ """
+ 获取所有传感器数据
+
+ Args:
+ timeout: 超时时间(秒)
+
+ Returns:
+ dict: 传感器数据字典
+ """
+ data = {
+ 'image': None,
+ 'pointcloud': None,
+ 'lidar_transform': None,
+ 'lidar_objects': [],
+ 'timestamp': time.time()
+ }
+
+ # 获取相机图像
+ if self.camera_manager:
+ data['image'] = self.camera_manager.get_image(timeout=timeout)
+
+ # 获取LiDAR数据
+ if self.lidar_manager:
+ points, transform = self.lidar_manager.get_pointcloud(timeout=timeout)
+ data['pointcloud'] = points
+ data['lidar_transform'] = transform
+
+ # 检测物体
+ if points is not None:
+ data['lidar_objects'] = self.lidar_manager.detect_objects()
+
+ # 更新视角
+ if self.spectator_manager:
+ self.spectator_manager.update()
+
+ return data
+
+ # 添加视角控制方法
+ def set_view_mode(self, mode):
+ """设置视角模式"""
+ if self.spectator_manager:
+ self.spectator_manager.set_view_mode(mode)
+
+ def cycle_view_mode(self):
+ """循环切换视角模式"""
+ if self.spectator_manager:
+ self.spectator_manager.cycle_view_mode()
+
+ def destroy(self):
+ """销毁所有传感器"""
+ logger.info("正在销毁传感器...")
+
+ if self.camera_manager:
+ self.camera_manager.destroy()
+
+ if self.lidar_manager:
+ self.lidar_manager.destroy()
+
+ if self.spectator_manager:
+ self.spectator_manager.destroy()
+
+ logger.info("✅ 所有传感器已销毁")
+
+def create_ego_vehicle(world, config, spawn_points=None):
+ """
+ 创建自车
+
+ Args:
+ world: CARLA世界对象
+ config: 配置字典
+ spawn_points: 可选的自定义生成点列表
+
+ Returns:
+ carla.Vehicle or None: 自车对象
+ """
+ try:
+ # 获取生成点
+ if spawn_points is None:
+ spawn_points = world.get_map().get_spawn_points()
+
+ if not spawn_points:
+ logger.error("❌ 无可用生成点")
+ return None
+
+ logger.info(f"找到 {len(spawn_points)} 个生成点")
+
+ # 选择车辆蓝图
+ vehicle_bp = None
+ vehicle_filter = config.get('ego_vehicle_filter', 'vehicle.tesla.model3')
+
+ # 尝试首选车辆
+ blueprint_library = world.get_blueprint_library()
+ for bp in blueprint_library.filter(vehicle_filter):
+ if int(bp.get_attribute('number_of_wheels')) == 4:
+ vehicle_bp = bp
+ logger.info(f"找到车辆蓝图: {bp.id}")
+ break
+
+ # 如果没找到,选择任意四轮车辆
+ if vehicle_bp is None:
+ logger.info("首选车辆未找到,尝试其他四轮车辆...")
+ for bp in blueprint_library.filter('vehicle.*'):
+ if int(bp.get_attribute('number_of_wheels')) == 4:
+ vehicle_bp = bp
+ logger.info(f"使用备用车辆蓝图: {bp.id}")
+ break
+
+ if vehicle_bp is None:
+ logger.error("❌ 找不到合适的车辆蓝图")
+ return None
+
+ # 设置车辆颜色
+ color = config.get('ego_vehicle_color', '255,0,0')
+ vehicle_bp.set_attribute('color', color)
+
+ # 尝试生成车辆 - 改进的碰撞避免策略
+ max_attempts = config.get('spawn_max_attempts', 20)
+
+ for attempt in range(max_attempts):
+ # 随机选择生成点
+ if attempt < len(spawn_points):
+ spawn_point = spawn_points[attempt]
+ else:
+ # 选择随机一个生成点
+ import random
+ spawn_point = random.choice(spawn_points)
+
+ # 随机偏移位置以避免碰撞
+ spawn_point.location.x += random.uniform(-3, 3)
+ spawn_point.location.y += random.uniform(-3, 3)
+
+ logger.info(f"尝试生成自车 (尝试 {attempt + 1}/{max_attempts}) "
+ f"位置: x={spawn_point.location.x:.1f}, y={spawn_point.location.y:.1f}")
+
+ # 设置生成点的高度为地面以上0.5米
+ spawn_point.location.z += 0.5
+
+ ego_vehicle = world.try_spawn_actor(vehicle_bp, spawn_point)
+
+ if ego_vehicle is not None:
+ logger.info(f"✅ 自车生成成功 (尝试 {attempt + 1}/{max_attempts})")
+ logger.info(f" 位置: ({spawn_point.location.x:.1f}, {spawn_point.location.y:.1f}, {spawn_point.location.z:.1f})")
+
+ # 等待一小段时间让车辆稳定
+ world.tick()
+
+ # 设置自动驾驶
+ try:
+ ego_vehicle.set_autopilot(True, 8000)
+ logger.info("✅ 自车自动驾驶已启用")
+ except Exception as e:
+ logger.warning(f"设置自动驾驶失败: {e}")
+ try:
+ ego_vehicle.set_autopilot(True)
+ logger.info("✅ 自车自动驾驶已启用(备用方法)")
+ except:
+ logger.warning("无法设置自动驾驶,车辆将保持静止")
+
+ return ego_vehicle
+ else:
+ logger.debug(f"生成失败,尝试下一个位置...")
+
+ logger.error(f"❌ 经过 {max_attempts} 次尝试后仍无法生成自车")
+ logger.info("建议:")
+ logger.info("1. 重新启动CARLA服务器")
+ logger.info("2. 在CARLA中手动清理场景中的车辆")
+ logger.info("3. 尝试不同的生成点")
+
+ return None
+
+ except Exception as e:
+ logger.error(f"❌ 创建自车失败: {e}")
+ import traceback
+ traceback.print_exc()
+ return None
+
+def spawn_npc_vehicles(world, config, count=None):
+ """
+ 生成NPC车辆
+
+ Args:
+ world: CARLA世界对象
+ config: 配置字典
+ count: NPC数量(默认使用配置中的值)
+
+ Returns:
+ int: 成功生成的NPC数量
+ """
+ try:
+ if count is None:
+ count = config.get('num_npcs', 20)
+
+ spawn_points = world.get_map().get_spawn_points()
+ if not spawn_points:
+ logger.warning("无可用生成点,无法生成NPC")
+ return 0
+
+ # 过滤合适的车辆蓝图(四轮车辆,排除特殊车辆)
+ vehicle_bps = []
+ for bp in world.get_blueprint_library().filter('vehicle.*'):
+ if int(bp.get_attribute('number_of_wheels')) == 4:
+ # 排除特殊车辆
+ if not bp.id.endswith(('firetruck', 'ambulance', 'police', 'charger')):
+ vehicle_bps.append(bp)
+
+ if not vehicle_bps:
+ logger.warning("找不到合适的NPC车辆蓝图")
+ return 0
+
+ spawned_count = 0
+ used_spawn_points = set()
+
+ for i in range(min(count * 3, len(spawn_points))): # 最多尝试3倍数量
+ if spawned_count >= count:
+ break
+
+ spawn_point = spawn_points[i]
+
+ # 检查是否已使用该位置
+ position_key = (round(spawn_point.location.x, 1),
+ round(spawn_point.location.y, 1))
+
+ if position_key in used_spawn_points:
+ continue
+
+ # 随机选择车辆蓝图
+ vehicle_bp = random.choice(vehicle_bps)
+
+ # 尝试生成
+ npc = world.try_spawn_actor(vehicle_bp, spawn_point)
+
+ if npc is not None:
+ used_spawn_points.add(position_key)
+ spawned_count += 1
+
+ # 设置自动驾驶
+ try:
+ npc.set_autopilot(True, 8000)
+ except:
+ try:
+ npc.set_autopilot(True)
+ except:
+ pass
+
+ logger.info(f"✅ 成功生成 {spawned_count}/{count} 个NPC车辆")
+ return spawned_count
+
+ except Exception as e:
+ logger.error(f"生成NPC车辆失败: {e}")
+ return 0
+
+def clear_all_actors(world, exclude_ids=None):
+ """
+ 清理所有演员(车辆和传感器)
+
+ Args:
+ world: CARLA世界对象
+ exclude_ids: 要排除的演员ID列表
+ """
+ try:
+ exclude_ids = set(exclude_ids) if exclude_ids else set()
+
+ actors = world.get_actors()
+
+ # 按类型分组清理
+ vehicle_actors = []
+ sensor_actors = []
+
+ for actor in actors:
+ if actor.id in exclude_ids:
+ continue
+
+ if actor.type_id.startswith('vehicle.'):
+ vehicle_actors.append(actor)
+ elif actor.type_id.startswith('sensor.'):
+ sensor_actors.append(actor)
+
+ # 先清理传感器
+ logger.info(f"清理 {len(sensor_actors)} 个传感器...")
+ for sensor in sensor_actors:
+ try:
+ if sensor.is_alive:
+ sensor.destroy()
+ except:
+ pass
+
+ # 再清理车辆
+ logger.info(f"清理 {len(vehicle_actors)} 个车辆...")
+ batch_size = 10
+ for i in range(0, len(vehicle_actors), batch_size):
+ batch = vehicle_actors[i:i+batch_size]
+ for vehicle in batch:
+ try:
+ if vehicle.is_alive:
+ vehicle.destroy()
+ except:
+ pass
+
+ logger.info("✅ 清理完成")
+
+ except Exception as e:
+ logger.warning(f"清理演员时出错: {e}")
+
+def test_sensor_manager():
+ """测试传感器管理器"""
+ print("=" * 50)
+ print("测试 sensors.py...")
+ print("=" * 50)
+
+ # 模拟配置
+ test_config = {
+ 'img_width': 640,
+ 'img_height': 480,
+ 'fov': 90,
+ 'sensor_tick': 0.05,
+ 'use_lidar': True,
+ 'lidar_channels': 32,
+ 'lidar_range': 100.0,
+ 'lidar_points_per_second': 500000,
+ }
+
+ print("✅ sensors.py 结构测试通过")
+ print("注:完整测试需要CARLA环境")
+
+ return True
+
+if __name__ == "__main__":
+ test_sensor_manager()
\ No newline at end of file
diff --git a/src/tracking_car/carla_tracking_ros/scripts/sign_detector.py b/src/tracking_car/carla_tracking_ros/scripts/sign_detector.py
new file mode 100644
index 0000000000..bbd8a82f18
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/scripts/sign_detector.py
@@ -0,0 +1,438 @@
+"""
+sign_detector.py - 轻量级交通标志识别模块
+最小化集成,不破坏现有结构
+"""
+
+import cv2
+import numpy as np
+import torch
+from typing import List, Dict, Any, Optional
+import time
+
+try:
+ from loguru import logger
+except ImportError:
+ import logging
+ logger = logging.getLogger(__name__)
+
class TrafficSignDetector:
    """
    Lightweight traffic-sign detector.

    Uses a pretrained YOLO model when ultralytics is available, and falls
    back to a simple HSV-color + contour-shape heuristic otherwise.

    Fix: __init__ previously called self._load_shape_templates(), and the
    fallback path called self._create_*_mask() / self._is_octagon_shape(),
    none of which were defined — constructing the detector always raised
    AttributeError. Those helpers are now implemented.
    """

    def __init__(self, config=None):
        """
        Initialize the detector.

        Args:
            config: optional configuration dict; built-in defaults are
                used when omitted.
        """
        # Default configuration.
        self.config = config or {
            'enabled': True,
            'conf_threshold': 0.5,
            'show_signs': True,
            'enable_actions': False,  # display only by default, no actions
        }

        # Prefer YOLO; fall back to color/shape detection when unavailable.
        self.model = None
        self.use_yolo = False

        try:
            from ultralytics import YOLO
            # Generic pretrained model (COCO classes include stop sign etc.).
            self.model = YOLO('yolov8n.pt')
            self.use_yolo = True
            logger.info("✅ 使用YOLO进行标志检测")
        except Exception as e:
            logger.warning(f"无法加载YOLO模型,使用简单颜色检测: {e}")
            self._init_simple_detector()

        # HSV ranges for the dominant sign colors.
        self.sign_colors = {
            'red': {  # stop / prohibition signs
                'lower': np.array([0, 50, 50]),
                'upper': np.array([10, 255, 255])
            },
            'blue': {  # mandatory / information signs
                'lower': np.array([100, 50, 50]),
                'upper': np.array([130, 255, 255])
            },
            'yellow': {  # warning signs
                'lower': np.array([20, 100, 100]),
                'upper': np.array([30, 255, 255])
            }
        }

        # Shape templates (reserved for optional template matching).
        self.shape_templates = self._load_shape_templates()

        # Recent detections.
        self.detected_signs_history = []

        logger.info("✅ 轻量级标志检测器初始化完成")

    def _init_simple_detector(self):
        """Build the shape-mask set used by the fallback detector."""
        self.templates = {
            'triangle': self._create_triangle_mask(),
            'circle': self._create_circle_mask(),
            'octagon': self._create_octagon_mask(),  # stop sign
            'square': self._create_square_mask()
        }

    # ---- shape-mask helpers (previously referenced but missing) ----

    def _create_triangle_mask(self, size: int = 64) -> np.ndarray:
        """Binary mask of an upward-pointing filled triangle."""
        mask = np.zeros((size, size), dtype=np.uint8)
        pts = np.array([[size // 2, 2], [2, size - 3], [size - 3, size - 3]])
        cv2.fillPoly(mask, [pts], 255)
        return mask

    def _create_circle_mask(self, size: int = 64) -> np.ndarray:
        """Binary mask of a filled circle."""
        mask = np.zeros((size, size), dtype=np.uint8)
        cv2.circle(mask, (size // 2, size // 2), size // 2 - 2, 255, -1)
        return mask

    def _create_octagon_mask(self, size: int = 64) -> np.ndarray:
        """Binary mask of a regular filled octagon (stop-sign outline)."""
        mask = np.zeros((size, size), dtype=np.uint8)
        center, radius = size // 2, size // 2 - 2
        pts = np.array([
            [int(center + radius * np.cos(2 * np.pi * k / 8)),
             int(center + radius * np.sin(2 * np.pi * k / 8))]
            for k in range(8)
        ])
        cv2.fillPoly(mask, [pts], 255)
        return mask

    def _create_square_mask(self, size: int = 64) -> np.ndarray:
        """Binary mask of a filled square."""
        mask = np.zeros((size, size), dtype=np.uint8)
        cv2.rectangle(mask, (2, 2), (size - 3, size - 3), 255, -1)
        return mask

    def _load_shape_templates(self) -> Dict[str, np.ndarray]:
        """Return the shape templates used for optional template matching."""
        return {
            'triangle': self._create_triangle_mask(),
            'circle': self._create_circle_mask(),
            'octagon': self._create_octagon_mask(),
            'square': self._create_square_mask(),
        }

    def _is_octagon_shape(self, roi: np.ndarray) -> bool:
        """Return True when the largest contour in roi looks octagonal."""
        if roi is None or roi.size == 0:
            return False
        gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return False
        largest = max(contours, key=cv2.contourArea)
        return self._detect_shape(largest) == 'octagon'

    def detect(self, image: np.ndarray, ego_speed: float = 0.0) -> List[Dict]:
        """
        Detect traffic signs in an image.

        Args:
            image: input BGR image.
            ego_speed: ego-vehicle speed (m/s, used for distance estimation).

        Returns:
            List[Dict]: detected signs (bbox, type, confidence, distance...).
        """
        if not self.config.get('enabled', True):
            return []

        signs = []

        try:
            # Strategy 1: YOLO when available.
            if self.use_yolo and self.model is not None:
                signs = self._detect_with_yolo(image, ego_speed)
            # Strategy 2: simple color + shape heuristic.
            else:
                signs = self._detect_with_color(image, ego_speed)

            # Drop near-duplicate boxes (simple NMS).
            signs = self._non_max_suppression(signs)

            # Keep a short history of recent detections.
            if signs:
                self.detected_signs_history = signs[:10]

            # Optional action triggering (disabled by default).
            if self.config.get('enable_actions', False):
                self._trigger_actions(signs, ego_speed)

            return signs

        except Exception as e:
            logger.error(f"标志检测失败: {e}")
            return []

    def _detect_with_yolo(self, image: np.ndarray, ego_speed: float) -> List[Dict]:
        """Detect signs with the YOLO model, filtered to sign-like classes."""
        results = self.model.predict(
            image,
            conf=self.config.get('conf_threshold', 0.5),
            verbose=False
        )

        signs = []

        for result in results:
            if result.boxes is not None:
                for box in result.boxes:
                    # COCO class id/name of this detection.
                    class_id = int(box.cls[0])
                    class_name = result.names[class_id]

                    # Keep only traffic-sign-related COCO classes.
                    if class_name in ['stop sign', 'traffic light', 'parking meter']:
                        bbox = box.xyxy[0].cpu().numpy()
                        confidence = float(box.conf[0])

                        # Map the COCO class to a sign type.
                        sign_type = self._infer_sign_type(class_name, bbox, image)

                        sign_info = {
                            'bbox': bbox.tolist(),
                            'confidence': confidence,
                            'type': sign_type,
                            'class_id': class_id,
                            'class_name': class_name,
                            'timestamp': time.time()
                        }

                        # Rough distance from bounding-box size.
                        sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed)

                        signs.append(sign_info)

        return signs

    def _detect_with_color(self, image: np.ndarray, ego_speed: float) -> List[Dict]:
        """Detect signs via HSV color segmentation + contour shapes."""
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        signs = []

        for color_name, color_range in self.sign_colors.items():
            # Color mask for this sign family.
            mask = cv2.inRange(hsv, color_range['lower'], color_range['upper'])

            # Morphological open/close to suppress noise.
            kernel = np.ones((5, 5), np.uint8)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            for contour in contours:
                area = cv2.contourArea(contour)

                # Ignore tiny blobs.
                if area < 100:
                    continue

                x, y, w, h = cv2.boundingRect(contour)

                # Classify the contour's shape.
                shape = self._detect_shape(contour)

                # Color + shape -> sign type (None when ambiguous).
                sign_type = self._infer_sign_type_from_color(color_name, shape)

                if sign_type:
                    sign_info = {
                        'bbox': [x, y, x + w, y + h],
                        'confidence': min(0.9, area / 1000.0),  # heuristic score
                        'type': sign_type,
                        'color': color_name,
                        'shape': shape,
                        'timestamp': time.time()
                    }

                    sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed)

                    signs.append(sign_info)

        return signs

    def _infer_sign_type(self, class_name: str, bbox: np.ndarray, image: np.ndarray) -> str:
        """Map a COCO class name (plus ROI shape check) to a sign type."""
        type_map = {
            'stop sign': 'stop',
            'traffic light': 'traffic_light',
            'parking meter': 'parking',
        }

        # For stop signs, optionally confirm the octagonal outline.
        if class_name == 'stop sign':
            x1, y1, x2, y2 = map(int, bbox)
            roi = image[y1:y2, x1:x2]
            if self._is_octagon_shape(roi):
                return 'stop'

        return type_map.get(class_name, 'unknown')

    def _infer_sign_type_from_color(self, color: str, shape: str) -> Optional[str]:
        """Infer a sign type from its dominant color and contour shape."""
        if color == 'red' and shape == 'octagon':
            return 'stop'
        elif color == 'red' and shape == 'circle':
            return 'no_entry'
        elif color == 'yellow' and shape == 'triangle':
            return 'warning'
        elif color == 'blue' and shape == 'circle':
            return 'mandatory'
        elif color == 'blue' and shape == 'square':
            return 'information'

        return None

    def _detect_shape(self, contour) -> str:
        """Classify a contour as triangle/square/rectangle/octagon/circle."""
        approx = cv2.approxPolyDP(contour, 0.04 * cv2.arcLength(contour, True), True)
        num_sides = len(approx)

        if num_sides == 3:
            return 'triangle'
        elif num_sides == 4:
            # Near-unit aspect ratio -> square.
            x, y, w, h = cv2.boundingRect(contour)
            aspect_ratio = float(w) / h
            if 0.8 <= aspect_ratio <= 1.2:
                return 'square'
            else:
                return 'rectangle'
        elif 8 <= num_sides <= 12:  # octagon (with approximation slack)
            return 'octagon'
        else:
            # Circularity = 4*pi*A / P^2, ~1 for a perfect circle.
            area = cv2.contourArea(contour)
            perimeter = cv2.arcLength(contour, True)
            circularity = 4 * np.pi * area / (perimeter * perimeter)
            if circularity > 0.7:
                return 'circle'

        return 'unknown'

    def _estimate_distance(self, bbox: List[float], ego_speed: float) -> float:
        """
        Roughly estimate the sign distance from bounding-box height.

        NOTE(review): the 500.0 constant bundles an assumed sign height
        and focal length — calibrate against the actual camera.
        """
        x1, y1, x2, y2 = bbox
        bbox_height = y2 - y1

        if bbox_height <= 0:
            return 100.0  # default: far away

        # Pinhole approximation: distance inversely proportional to height.
        distance = 500.0 / bbox_height

        # Inflate slightly at speed (motion blur widens boxes).
        if ego_speed > 10:
            distance *= (1 + ego_speed / 100.0)

        return min(distance, 100.0)  # cap at 100 m

    def _non_max_suppression(self, signs: List[Dict], iou_threshold: float = 0.5) -> List[Dict]:
        """Greedy NMS over sign dicts by confidence."""
        if len(signs) <= 1:
            return signs

        sorted_signs = sorted(signs, key=lambda x: x['confidence'], reverse=True)
        keep = []

        while sorted_signs:
            # Keep the most confident remaining sign.
            best = sorted_signs.pop(0)
            keep.append(best)

            # Drop everything that overlaps it too much.
            sorted_signs = [
                sign for sign in sorted_signs
                if self._iou(best['bbox'], sign['bbox']) < iou_threshold
            ]

        return keep

    def _iou(self, box1: List[float], box2: List[float]) -> float:
        """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        inter_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

        union_area = box1_area + box2_area - inter_area

        return inter_area / union_area if union_area > 0 else 0.0

    def _trigger_actions(self, signs: List[Dict], ego_speed: float):
        """Log warnings for nearby stop/speed-limit signs (optional)."""
        for sign in signs:
            sign_type = sign.get('type', '')
            distance = sign.get('distance', 100.0)

            if sign_type == 'stop' and distance < 20:
                logger.warning(f"🛑 前方 {distance:.1f}米有停车标志")

            elif 'speed_limit' in sign_type and distance < 30:
                try:
                    limit = int(sign_type.split('_')[-1])
                    if ego_speed * 3.6 > limit + 5:  # m/s -> km/h
                        logger.warning(f"📏 前方限速{limit}km/h,当前{ego_speed*3.6:.0f}km/h")
                except (ValueError, IndexError):
                    # Narrowed from bare except: only malformed sign types.
                    pass

    def draw_signs(self, image: np.ndarray, signs: List[Dict]) -> np.ndarray:
        """Return a copy of image with detected signs drawn on it."""
        if not self.config.get('show_signs', True):
            return image

        result = image.copy()

        for sign in signs:
            bbox = sign['bbox']
            sign_type = sign.get('type', 'unknown')
            confidence = sign.get('confidence', 0.0)
            distance = sign.get('distance', 0.0)

            x1, y1, x2, y2 = map(int, bbox)

            # BGR colors per sign type.
            color_map = {
                'stop': (0, 0, 255),           # red
                'warning': (0, 165, 255),      # orange
                'traffic_light': (0, 255, 0),  # green
                'speed_limit': (0, 255, 255),  # yellow
                'unknown': (128, 128, 128)     # gray
            }

            color = color_map.get(sign_type, (128, 128, 128))

            cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)

            # Label: type, score, and distance when known.
            label = f"{sign_type} {confidence:.2f}"
            if distance > 0:
                label += f" {distance:.1f}m"

            cv2.putText(result, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        return result

    def get_signs_info(self) -> Dict:
        """Return summary statistics about recent sign detections."""
        return {
            'total_signs': len(self.detected_signs_history),
            'recent_signs': self.detected_signs_history[-5:] if self.detected_signs_history else [],
            'enabled': self.config.get('enabled', True)
        }
+
+
def test_detector():
    """Run one detection pass over a synthetic stop-sign image."""
    detector = TrafficSignDetector()

    # Blank canvas with a filled red octagon as a stop-sign stand-in.
    canvas = np.zeros((480, 640, 3), dtype=np.uint8)
    cx, cy, radius = 320, 240, 50
    octagon = np.array(
        [(int(cx + radius * np.cos(2 * np.pi * k / 8)),
          int(cy + radius * np.sin(2 * np.pi * k / 8)))
         for k in range(8)]
    )
    cv2.fillPoly(canvas, [octagon], (0, 0, 255))

    # Detect and report.
    detections = detector.detect(canvas)

    print(f"检测到 {len(detections)} 个标志")
    for sign in detections:
        print(f" 类型: {sign.get('type')}, 置信度: {sign.get('confidence'):.2f}")

    # Exercise the drawing path as well.
    detector.draw_signs(canvas, detections)

    return len(detections) > 0
+
+
+if __name__ == "__main__":
+ success = test_detector()
+ if success:
+ print("✅ 标志检测器测试通过")
+ else:
+ print("⚠️ 测试未检测到标志")
\ No newline at end of file
diff --git a/src/tracking_car/carla_tracking_ros/scripts/tracker.py b/src/tracking_car/carla_tracking_ros/scripts/tracker.py
new file mode 100644
index 0000000000..a15d34dd96
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/scripts/tracker.py
@@ -0,0 +1,966 @@
+"""
+tracker.py - 目标检测与跟踪核心算法
+包含:YOLO检测器、卡尔曼滤波、SORT跟踪器、行为分析
+"""
+
+import numpy as np
+import cv2
+import torch
+import queue
+import threading
+import time
+import sys
+import os
+import queue
+
+# 配置日志
+try:
+ from loguru import logger
+except ImportError:
+ # 使用标准logging作为回退
+ import logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+from ultralytics import YOLO
+from scipy.optimize import linear_sum_assignment
+from dataclasses import dataclass
+from typing import List, Tuple, Optional, Dict, Any
+
+# 导入utils模块中的工具函数
+try:
+ from utils import iou, iou_numpy, clip_box, valid_img, bbox_center
+except ImportError:
+ # 如果在同一目录下,可以直接导入
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+ from utils import iou, iou_numpy, clip_box, valid_img, bbox_center
+
+# ======================== 数据结构 ========================
+
@dataclass
class Detection:
    """
    One detector output: an axis-aligned box with score and class label.

    Attributes:
        bbox: [x1, y1, x2, y2], coerced to a float32 ndarray on init.
        confidence: detection score.
        class_id: COCO class index.
        class_name: human-readable class label.
    """
    bbox: np.ndarray  # [x1, y1, x2, y2]
    confidence: float
    class_id: int
    class_name: str = "Unknown"

    def __post_init__(self):
        # Normalize the box dtype and cache its area for matching.
        self.bbox = np.array(self.bbox, dtype=np.float32)
        x1, y1, x2, y2 = self.bbox
        self.area = (x2 - x1) * (y2 - y1)
+
+
@dataclass
class TrackState:
    """
    Track lifecycle states, used as plain string constants.

    NOTE(review): every member is a class-level constant, so the
    @dataclass decorator has no effect here; enum.Enum would be the
    idiomatic choice, but callers compare against these attributes
    directly, so the representation is kept unchanged.
    """
    NEW = "new"          # just created, not yet confirmed
    TRACKED = "tracked"  # matched to a detection this frame
    LOST = "lost"        # missed at least one frame
    REMOVED = "removed"  # scheduled for deletion
+
+
+# ======================== 卡尔曼滤波器 ========================
+
class KalmanFilter:
    """
    Constant-velocity Kalman filter for bounding-box tracking.

    State (8-dim):       [x1, y1, x2, y2, vx1, vy1, vx2, vy2]
    Measurement (4-dim): [x1, y1, x2, y2]
    """

    def __init__(self, dt=0.05, max_speed=50.0):
        """
        Build the filter matrices.

        Args:
            dt: time step between frames, in seconds.
            max_speed: speed (px/s) at which process noise saturates.
        """
        self.dt = dt
        self.max_speed = max_speed

        self.state_dim = 8
        self.measure_dim = 4

        # Transition F: position += velocity * dt.
        self.F = np.eye(self.state_dim, dtype=np.float32)
        rows = np.arange(4)
        self.F[rows, rows + 4] = dt

        # Measurement H: observe the four box coordinates only.
        self.H = np.eye(self.measure_dim, self.state_dim, dtype=np.float32)

        # Process noise Q: velocities less certain than positions.
        self.Q = np.diag([1.0] * 4 + [5.0] * 4).astype(np.float32)

        # Measurement noise R and initial state covariance P.
        self.R = 5.0 * np.eye(self.measure_dim, dtype=np.float32)
        self.P = 50.0 * np.eye(self.state_dim, dtype=np.float32)

        # State vector (positions + velocities), initially at rest.
        self.x = np.zeros(self.state_dim, dtype=np.float32)

        # True until the first measurement update arrives.
        self.first_update = True

    def init(self, bbox):
        """
        Seed the positional part of the state.

        Args:
            bbox: initial box [x1, y1, x2, y2].
        """
        self.x[:4] = bbox
        self.first_update = True

    def predict(self):
        """
        Advance state and covariance one time step.

        Returns:
            np.ndarray: the predicted box [x1, y1, x2, y2].
        """
        self.x = self.F.dot(self.x)
        self.P = self.F.dot(self.P).dot(self.F.T) + self.Q
        return self.x[:4].copy()

    def update(self, bbox):
        """
        Fuse a measured box into the state.

        Args:
            bbox: observed box [x1, y1, x2, y2].

        Returns:
            np.ndarray: the corrected box [x1, y1, x2, y2].
        """
        measurement = np.array(bbox, dtype=np.float32)

        # Innovation covariance and Kalman gain.
        innovation_cov = self.H.dot(self.P).dot(self.H.T) + self.R
        try:
            gain = self.P.dot(self.H.T).dot(np.linalg.inv(innovation_cov))
        except np.linalg.LinAlgError:
            # Singular covariance: fall back to the pseudo-inverse.
            gain = self.P.dot(self.H.T).dot(np.linalg.pinv(innovation_cov))

        # Residual between measurement and predicted observation.
        residual = measurement - self.H.dot(self.x)

        # State and covariance correction.
        self.x = self.x + gain.dot(residual)
        identity = np.eye(self.state_dim, dtype=np.float32)
        self.P = (identity - gain.dot(self.H)).dot(self.P)

        self.first_update = False

        return self.x[:4].copy()

    def update_noise(self, speed):
        """
        Scale process noise with the target's estimated speed.

        Args:
            speed: estimated speed in px/s (clipped at max_speed).
        """
        ratio = min(1.0, speed / self.max_speed)
        np.fill_diagonal(
            self.Q,
            [1.0 + 4.0 * ratio] * 4 + [5.0 + 20.0 * ratio] * 4,
        )
+
+
+# ======================== 跟踪目标 ========================
+
class TrackedObject:
    """
    A single tracked target: a Kalman-filtered bounding box plus motion
    history and rule-based behavior analysis (stop / overtake / lane
    change / brake / accelerate / turn / danger flags).
    """

    def __init__(self, track_id: int, bbox: np.ndarray,
                 img_shape: Tuple[int, int], config: Dict[str, Any]):
        """
        Initialize a tracked target.

        Args:
            track_id: unique track identifier.
            bbox: initial bounding box [x1, y1, x2, y2].
            img_shape: image size as (height, width).
            config: configuration dict (thresholds, history lengths).
        """
        self.track_id = track_id
        self.img_shape = img_shape
        self.config = config

        # Kalman filter for box smoothing / prediction.
        self.kf = KalmanFilter(
            dt=config.get('kf_dt', 0.05),
            max_speed=config.get('max_speed', 50.0)
        )

        # Clip the initial box into the image before seeding the filter.
        self.bbox = clip_box(bbox.astype(np.float32), img_shape)
        self.kf.init(self.bbox)

        # Motion history buffers.
        self.track_history: List[Tuple[float, float]] = []  # [(cx, cy), ...]
        self.speed_history: List[float] = []  # speeds, px/s
        self.acceleration_history: List[float] = []  # accelerations, px/s^2

        # Lifecycle bookkeeping.
        self.state = TrackState.NEW
        self.age = 0  # frames since creation
        self.time_since_update = 0  # frames since last matched detection
        self.hits = 1  # number of matched detections
        self.total_frames = 0  # total tracked frames

        # Latest associated detection info.
        self.class_id: Optional[int] = None
        self.class_name: str = "Unknown"
        self.confidence: float = 0.0

        # Behavior flags, derived in _analyze_behavior().
        self.is_stopped = False
        self.is_overtaking = False
        self.is_lane_changing = False
        self.is_braking = False
        self.is_accelerating = False
        self.is_turning = False
        self.is_dangerous = False

        # Consecutive-frame counters backing the flags above.
        self.stop_frames = 0
        self.overtake_frames = 0
        self.lane_change_frames = 0
        self.brake_frames = 0
        self.turn_frames = 0

        # Predicted future trajectory (centre points).
        self.predicted_trajectory: List[Tuple[float, float]] = []

        # Seed the history with the initial position.
        self._update_history()

    def _update_history(self):
        """Append the current box centre to the (bounded) history buffers."""
        cx, cy = bbox_center(self.bbox)
        self.track_history.append((cx, cy))

        # Cap the position history length.
        max_len = self.config.get('track_history_len', 20)
        if len(self.track_history) > max_len:
            self.track_history.pop(0)

        # Cap the speed history length.
        if len(self.speed_history) > 10:
            self.speed_history.pop(0)

        # Cap the acceleration history length.
        if len(self.acceleration_history) > 10:
            self.acceleration_history.pop(0)

    def _calculate_speed(self) -> float:
        """
        Estimate current speed from the last two centre positions.

        Side effects: appends to speed_history and (when two speeds are
        available) to acceleration_history.

        Returns:
            float: speed in pixels per second.
        """
        if len(self.track_history) < 2:
            return 0.0

        # Displacement between the two most recent centres.
        prev_cx, prev_cy = self.track_history[-2]
        curr_cx, curr_cy = self.track_history[-1]

        dx = curr_cx - prev_cx
        dy = curr_cy - prev_cy
        distance = np.sqrt(dx**2 + dy**2)

        # Displacement-per-frame -> px/s using the filter's dt.
        speed = distance / self.kf.dt

        self.speed_history.append(speed)

        # Finite-difference acceleration from consecutive speeds.
        if len(self.speed_history) >= 2:
            acceleration = (self.speed_history[-1] - self.speed_history[-2]) / self.kf.dt
            self.acceleration_history.append(acceleration)

        return speed

    def _calculate_heading(self) -> float:
        """
        Estimate the current heading from positions 3 frames apart.

        Returns:
            float: heading angle in degrees (image coordinates).
        """
        if len(self.track_history) < 3:
            return 0.0

        # Direction over the last three recorded positions.
        cx1, cy1 = self.track_history[-3]
        cx2, cy2 = self.track_history[-1]

        dx = cx2 - cx1
        dy = cy2 - cy1

        # atan2 in radians, converted to degrees.
        angle = np.degrees(np.arctan2(dy, dx))

        return angle

    def _analyze_behavior(self, ego_center: Optional[Tuple[float, float]] = None):
        """
        Derive behavior flags from the motion history.

        All thresholds come from self.config and operate in image-space
        pixels. NOTE(review): ego_center is assumed to be the ego
        vehicle's position in the same pixel coordinates — confirm
        against the caller; also `if ego_center:` treats (0, 0) as
        "no ego", which is a truthiness pitfall.

        Args:
            ego_center: ego-vehicle centre point, or None.
        """
        # Refresh speed/heading estimates (speed also updates histories).
        speed = self._calculate_speed()
        heading = self._calculate_heading()

        # 1. Stopped: speed below threshold for several consecutive frames.
        stop_speed_thresh = self.config.get('stop_speed_thresh', 1.0)
        stop_frames_thresh = self.config.get('stop_frames_thresh', 5)

        if speed < stop_speed_thresh:
            self.stop_frames += 1
            self.is_stopped = self.stop_frames >= stop_frames_thresh
        else:
            self.stop_frames = 0
            self.is_stopped = False

        # 2. Overtaking: nearby and clearly faster than the ego vehicle.
        overtake_speed_ratio = self.config.get('overtake_speed_ratio', 1.5)
        overtake_dist_thresh = self.config.get('overtake_dist_thresh', 50.0)

        if ego_center and len(self.track_history) >= 2:
            curr_cx, curr_cy = self.track_history[-1]
            distance = np.sqrt((curr_cx - ego_center[0])**2 + (curr_cy - ego_center[1])**2)

            if distance < overtake_dist_thresh:
                # ego_speed is injected by the tracker before update().
                ego_speed = getattr(self, 'ego_speed', 0.0)
                if speed > ego_speed * overtake_speed_ratio:
                    self.overtake_frames += 1
                    self.is_overtaking = self.overtake_frames >= 3
                else:
                    self.overtake_frames = 0
                    self.is_overtaking = False
            else:
                self.overtake_frames = 0
                self.is_overtaking = False

        # 3. Lane change: sustained lateral (x) displacement.
        lane_change_thresh = self.config.get('lane_change_thresh', 0.5)

        if len(self.track_history) >= 5:
            # Average per-frame lateral displacement over recent frames.
            lateral_displacements = []
            for i in range(1, min(5, len(self.track_history))):
                lateral_displacements.append(
                    abs(self.track_history[-i][0] - self.track_history[-i-1][0])
                )

            avg_lateral = np.mean(lateral_displacements) if lateral_displacements else 0.0

            if avg_lateral > lane_change_thresh:
                self.lane_change_frames += 1
                self.is_lane_changing = self.lane_change_frames >= 3
            else:
                self.lane_change_frames = 0
                self.is_lane_changing = False

        # 4. Braking / accelerating: mean recent acceleration vs threshold.
        brake_accel_thresh = self.config.get('brake_accel_thresh', 2.0)

        if len(self.acceleration_history) >= 3:
            avg_accel = np.mean(self.acceleration_history[-3:])

            if avg_accel < -brake_accel_thresh:
                self.brake_frames += 1
                self.is_braking = self.brake_frames >= 2
                self.is_accelerating = False
            elif avg_accel > brake_accel_thresh:
                self.is_accelerating = True
                self.is_braking = False
                self.brake_frames = 0
            else:
                self.is_braking = False
                self.is_accelerating = False
                self.brake_frames = 0

        # 5. Turning: heading change between calls exceeds threshold.
        #    NOTE(review): the difference is not wrapped to [-180, 180],
        #    so a heading crossing the +/-180 boundary reads as a large
        #    spurious change — consider angle wrapping.
        turn_angle_thresh = self.config.get('turn_angle_thresh', 15.0)

        if len(self.track_history) >= 3:
            if hasattr(self, '_prev_heading'):
                heading_change = abs(heading - self._prev_heading)
                if heading_change > turn_angle_thresh:
                    self.turn_frames += 1
                    self.is_turning = self.turn_frames >= 2
                else:
                    self.turn_frames = 0
                    self.is_turning = False
            self._prev_heading = heading

        # 6. Danger: within a pixel-distance threshold of the ego vehicle.
        danger_dist_thresh = self.config.get('danger_dist_thresh', 10.0)

        if ego_center:
            curr_cx, curr_cy = self.track_history[-1]
            distance = np.sqrt((curr_cx - ego_center[0])**2 + (curr_cy - ego_center[1])**2)
            self.is_dangerous = distance < danger_dist_thresh

        # 7. Refresh the predicted future trajectory.
        self._predict_trajectory()

    def _predict_trajectory(self):
        """Roll a copy of the Kalman filter forward to predict future centres."""
        predict_frames = self.config.get('predict_frames', 10)
        self.predicted_trajectory = []

        if len(self.track_history) < 5:
            return

        # Temporary filter so prediction does not disturb the live state.
        temp_kf = KalmanFilter(
            dt=self.kf.dt,
            max_speed=self.kf.max_speed
        )
        temp_kf.x = self.kf.x.copy()
        temp_kf.P = self.kf.P.copy()

        # Predict the next `predict_frames` centre positions.
        for _ in range(predict_frames):
            predicted_bbox = temp_kf.predict()
            predicted_center = bbox_center(predicted_bbox)
            self.predicted_trajectory.append(predicted_center)

    def predict(self) -> np.ndarray:
        """
        Predict the box for the next frame (no measurement).

        NOTE: the predicted centre is appended to track_history, so speed
        estimates include purely-predicted motion during missed frames.

        Returns:
            np.ndarray: predicted box [x1, y1, x2, y2], clipped to image.
        """
        # Adapt process noise to the current speed estimate.
        if len(self.track_history) >= 2:
            speed = self._calculate_speed()
            self.kf.update_noise(speed)

        # Kalman time update, then clip into the image.
        self.bbox = self.kf.predict()
        self.bbox = clip_box(self.bbox, self.img_shape)

        # Age bookkeeping; a track unmatched for >1 frame becomes LOST.
        self._update_history()
        self.age += 1
        self.time_since_update += 1
        self.total_frames += 1

        if self.time_since_update > 1:
            self.state = TrackState.LOST

        return self.bbox

    def update(self, detection: Detection, ego_center: Optional[Tuple[float, float]] = None):
        """
        Correct the track with a matched detection.

        Args:
            detection: the associated detection.
            ego_center: ego-vehicle centre point (for behavior analysis).
        """
        # Kalman measurement update, clipped into the image.
        self.bbox = self.kf.update(detection.bbox)
        self.bbox = clip_box(self.bbox, self.img_shape)

        # Adopt the detection's class/confidence.
        self.class_id = detection.class_id
        self.class_name = detection.class_name
        self.confidence = detection.confidence

        # Mark the track as freshly matched.
        self._update_history()
        self.hits += 1
        self.time_since_update = 0
        self.state = TrackState.TRACKED

        # Re-derive the behavior flags.
        self._analyze_behavior(ego_center)

    def get_behavior_string(self) -> str:
        """Return a '|'-joined label of all active behavior flags."""
        behaviors = []
        if self.is_stopped:
            behaviors.append("停车")
        if self.is_overtaking:
            behaviors.append("超车")
        if self.is_lane_changing:
            behaviors.append("变道")
        if self.is_braking:
            behaviors.append("刹车")
        if self.is_accelerating:
            behaviors.append("加速")
        if self.is_turning:
            behaviors.append("转弯")
        if self.is_dangerous:
            behaviors.append("危险")

        return "|".join(behaviors) if behaviors else "正常"
+
+
+# ======================== YOLO检测器 ========================
+
class YOLODetector:
    """
    YOLOv8-based detector restricted to vehicle classes (car/bus/truck).
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the detector from a configuration dict.

        Args:
            config: configuration (model path, thresholds, device,
                max inference size, optional quantization flag).
        """
        self.config = config

        # Model / inference configuration.
        model_path = config.get('yolo_model', 'yolov8n.pt')
        self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
        self.conf_thres = config.get('conf_thres', 0.5)
        self.iou_thres = config.get('iou_thres', 0.3)
        self.imgsz_max = config.get('yolo_imgsz_max', 320)
        self.quantize = config.get('yolo_quantize', False)

        # COCO id -> label map; everything else is discarded.
        self.vehicle_classes = {2: "Car", 5: "Bus", 7: "Truck"}

        # Load (and warm up) the model; raises if loading fails.
        self.model = self._load_model(model_path)

        logger.info(f"✅ YOLO检测器初始化完成 (设备: {self.device}, 模型: {model_path})")

    def _load_model(self, model_path: str):
        """Load and warm up the YOLO model; re-raises on failure."""
        try:
            model = YOLO(model_path)

            if self.quantize and self.device == "cuda":
                # NOTE(review): ultralytics' YOLO does not document a
                # public quantize() method — verify this call before
                # enabling yolo_quantize in the config.
                model = model.quantize()

            # Warm-up inference so the first real frame is not slow.
            dummy_input = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            _ = model.predict(dummy_input, verbose=False, device=self.device)

            return model

        except Exception as e:
            logger.error(f"❌ 加载YOLO模型失败: {e}")
            raise

    def detect(self, image: np.ndarray) -> List[Detection]:
        """
        Run vehicle detection on a single image.

        Args:
            image: BGR image (H, W, 3).

        Returns:
            List[Detection]: vehicle detections; empty on invalid input
            or inference failure.
        """
        if not valid_img(image):
            return []

        try:
            # Scale the inference size down to imgsz_max, keeping aspect.
            h, w = image.shape[:2]
            resize_ratio = min(self.imgsz_max / w, self.imgsz_max / h)
            new_w = int(w * resize_ratio)
            new_h = int(h * resize_ratio)

            # Round up to multiples of 32 (YOLO stride requirement;
            # duplicates utils.make_div, kept inline here).
            new_w = (new_w + 31) // 32 * 32
            new_h = (new_h + 31) // 32 * 32

            # Inference.
            results = self.model.predict(
                image,
                conf=self.conf_thres,
                iou=self.iou_thres,
                imgsz=(new_h, new_w),
                device=self.device,
                verbose=False,
                agnostic_nms=True
            )

            detections = []

            for result in results:
                if result.boxes is not None and len(result.boxes) > 0:
                    for box in result.boxes:
                        # Box coordinates, score, and class id.
                        xyxy = box.xyxy[0].cpu().numpy()
                        confidence = float(box.conf[0])
                        class_id = int(box.cls[0])

                        # Keep vehicle classes only.
                        if class_id in self.vehicle_classes:
                            # Discard degenerate or zero-confidence boxes.
                            if xyxy[2] > xyxy[0] and xyxy[3] > xyxy[1] and confidence > 0:
                                detection = Detection(
                                    bbox=xyxy,
                                    confidence=confidence,
                                    class_id=class_id,
                                    class_name=self.vehicle_classes[class_id]
                                )
                                detections.append(detection)

            return detections

        except Exception as e:
            logger.error(f"❌ YOLO检测失败: {e}")
            return []
+
+
+# ======================== SORT跟踪器 ========================
+
class SORTTracker:
    """
    SORT (Simple Online and Realtime Tracking) tracker: IoU-based
    Hungarian matching of detections to Kalman-predicted tracks.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the tracker.

        Args:
            config: configuration dict (max_age, min_hits, iou_thres,
                image size, and TrackedObject thresholds).
        """
        self.config = config

        # Track lifecycle parameters.
        self.max_age = config.get('max_age', 5)
        self.min_hits = config.get('min_hits', 3)
        self.iou_threshold = config.get('iou_thres', 0.3)

        # Image geometry.
        self.img_height = config.get('img_height', 480)
        self.img_width = config.get('img_width', 640)
        self.img_shape = (self.img_height, self.img_width)

        # Live tracks and the next id to assign.
        self.tracks: List[TrackedObject] = []
        self.next_track_id = 1

        # Ego-vehicle info (defaults to the image centre).
        self.ego_center = (self.img_width // 2, self.img_height // 2)
        self.ego_speed = 0.0

        logger.info("✅ SORT跟踪器初始化完成")

    def update(self, detections: List[Detection],
               ego_center: Optional[Tuple[float, float]] = None,
               lidar_detections: Optional[List[Dict]] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run one tracking step: predict, match, update, create, prune.

        Args:
            detections: this frame's detections.
            ego_center: ego-vehicle centre (pixel coords); kept from the
                previous frame when None. NOTE(review): `if ego_center:`
                treats (0, 0) as absent — truthiness pitfall.
            lidar_detections: currently unused (reserved for fusion).

        Returns:
            Tuple: (boxes Nx4, track ids N, class ids N) for confirmed
            tracks only (hits >= min_hits and state TRACKED).
        """
        # Refresh the ego position if provided.
        if ego_center:
            self.ego_center = ego_center

        # No detections: predict-only step, then prune stale tracks.
        if not detections:
            # Advance every existing track.
            for track in self.tracks:
                track.predict()

            # Drop tracks unmatched for longer than max_age frames.
            self.tracks = [t for t in self.tracks if t.time_since_update <= self.max_age]

            # Nothing to report this frame.
            return np.array([]), np.array([]), np.array([])

        # Advance all tracks to this frame before matching.
        for track in self.tracks:
            track.predict()

        # Build the association cost matrix (detections x tracks).
        if self.tracks:
            # Pairwise IoU between detections and predicted boxes.
            iou_matrix = np.zeros((len(detections), len(self.tracks)), dtype=np.float32)

            for i, det in enumerate(detections):
                for j, track in enumerate(self.tracks):
                    iou_matrix[i, j] = iou(det.bbox, track.bbox)

            # Hungarian algorithm minimizes cost, so use 1 - IoU.
            cost_matrix = 1.0 - iou_matrix

            # Optimal assignment.
            try:
                det_indices, track_indices = linear_sum_assignment(cost_matrix)
            except ValueError:
                det_indices, track_indices = [], []

            # Keep only assignments above the IoU threshold.
            matched_pairs = []
            unmatched_detections = set(range(len(detections)))
            unmatched_tracks = set(range(len(self.tracks)))

            for det_idx, track_idx in zip(det_indices, track_indices):
                if iou_matrix[det_idx, track_idx] >= self.iou_threshold:
                    matched_pairs.append((det_idx, track_idx))
                    unmatched_detections.discard(det_idx)
                    unmatched_tracks.discard(track_idx)
        else:
            # No existing tracks: every detection is unmatched.
            matched_pairs = []
            unmatched_detections = set(range(len(detections)))
            unmatched_tracks = set()

        # NOTE(review): unmatched_tracks is computed but never used; lost
        # tracks are handled implicitly via time_since_update.

        # Measurement-update the matched tracks.
        for det_idx, track_idx in matched_pairs:
            track = self.tracks[track_idx]
            track.ego_speed = self.ego_speed  # ego speed for behavior analysis
            track.update(detections[det_idx], self.ego_center)

        # Spawn new tracks for unmatched detections.
        for det_idx in unmatched_detections:
            new_track = TrackedObject(
                track_id=self.next_track_id,
                bbox=detections[det_idx].bbox,
                img_shape=self.img_shape,
                config=self.config
            )
            new_track.update(detections[det_idx], self.ego_center)
            self.tracks.append(new_track)
            self.next_track_id += 1

        # Prune tracks unmatched for longer than max_age frames.
        self.tracks = [t for t in self.tracks if t.time_since_update <= self.max_age]

        # Report only confirmed tracks (enough hits, currently matched).
        active_tracks = [t for t in self.tracks if t.hits >= self.min_hits and t.state == TrackState.TRACKED]

        if not active_tracks:
            return np.array([]), np.array([]), np.array([])

        # Pack the confirmed tracks into arrays.
        boxes = np.array([t.bbox for t in active_tracks])
        ids = np.array([t.track_id for t in active_tracks])
        classes = np.array([t.class_id if t.class_id is not None else -1 for t in active_tracks])

        return boxes, ids, classes

    def get_tracks_info(self) -> List[Dict[str, Any]]:
        """
        Summarize all confirmed tracks.

        Returns:
            List[Dict]: per-track info (id, box, class, speed, behavior,
            lifecycle counters, behavior flags).
        """
        tracks_info = []

        for track in self.tracks:
            if track.hits >= self.min_hits and track.state == TrackState.TRACKED:
                info = {
                    'track_id': track.track_id,
                    'bbox': track.bbox.tolist(),
                    'class_id': track.class_id,
                    'class_name': track.class_name,
                    'confidence': track.confidence,
                    # NOTE: _calculate_speed has side effects on the
                    # track's speed/acceleration histories.
                    'speed': track._calculate_speed(),
                    'behavior': track.get_behavior_string(),
                    'age': track.age,
                    'hits': track.hits,
                    'is_stopped': track.is_stopped,
                    'is_overtaking': track.is_overtaking,
                    'is_dangerous': track.is_dangerous,
                }
                tracks_info.append(info)

        return tracks_info

    def reset(self):
        """Drop all tracks and restart id assignment from 1."""
        self.tracks = []
        self.next_track_id = 1
        logger.info("✅ 跟踪器已重置")
+
+
+# ======================== 检测线程 ========================
+
class DetectionThread(threading.Thread):
    """
    Background detection worker.

    Pulls frames from input_queue, runs the YOLO detector on them, and
    publishes (image, detections) tuples on output_queue. When the
    consumer lags, the oldest queued result is dropped so the worker
    never blocks on a full queue (the previous error/invalid-image paths
    used a blocking put() that could hang the thread forever).
    """

    def __init__(self, detector: YOLODetector, input_queue: queue.Queue,
                 output_queue: queue.Queue, maxsize: int = 2):
        """
        Initialize the detection thread.

        Args:
            detector: YOLO detector instance.
            input_queue: queue of input images.
            output_queue: queue receiving (image, detections) results.
            maxsize: nominal queue capacity (kept for API compatibility).
        """
        super().__init__(daemon=True)
        self.detector = detector
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.running = True
        self.processed_count = 0

        logger.info("✅ 检测线程初始化完成")

    def _publish(self, item):
        """Non-blocking best-effort put: evict the oldest result if full."""
        if self.output_queue.full():
            try:
                self.output_queue.get_nowait()
            except queue.Empty:
                pass
        try:
            self.output_queue.put_nowait(item)
        except queue.Full:
            # A consumer/producer raced us between eviction and put;
            # dropping one frame beats blocking the detection loop.
            pass

    def run(self):
        """Main loop: detect frames until stop() is called."""
        while self.running:
            try:
                # Wait (bounded) for the next frame.
                image = self.input_queue.get(timeout=1.0)

                if not valid_img(image):
                    self._publish((image, []))
                    continue

                # Run detection and publish the result.
                detections = self.detector.detect(image)
                self._publish((image, detections))
                self.processed_count += 1

            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"❌ 检测线程出错: {e}")
                self._publish((None, []))

    def stop(self):
        """Signal the loop to exit (thread is daemonic; join() optional)."""
        self.running = False
        logger.info("🛑 检测线程已停止")
+
+
+# ======================== 测试函数 ========================
+
def test_tracker():
    """Offline smoke test for the tracking components (no CARLA needed)."""
    divider = "=" * 50
    print(divider)
    print("测试 tracker.py...")
    print(divider)

    # Representative configuration for the components under test.
    cfg = {
        'yolo_model': 'yolov8n.pt',
        'conf_thres': 0.5,
        'iou_thres': 0.3,
        'max_age': 5,
        'min_hits': 3,
        'kf_dt': 0.05,
        'max_speed': 50.0,
        'img_width': 640,
        'img_height': 480,
        'track_history_len': 20,
        'stop_speed_thresh': 1.0,
        'stop_frames_thresh': 5,
        'overtake_speed_ratio': 1.5,
        'overtake_dist_thresh': 50.0,
        'lane_change_thresh': 0.5,
        'brake_accel_thresh': 2.0,
        'turn_angle_thresh': 15.0,
        'danger_dist_thresh': 10.0,
        'predict_frames': 10,
    }

    # 1. Data structures.
    print("1. 测试数据结构...")
    box = np.array([100, 100, 200, 200], dtype=np.float32)
    det = Detection(bbox=box, confidence=0.9, class_id=2, class_name="Car")
    assert det.confidence == 0.9
    assert det.class_id == 2
    print(" ✅ Detection数据结构测试通过")

    # 2. Kalman filter predict/update round trip.
    print("2. 测试卡尔曼滤波器...")
    kalman = KalmanFilter(dt=0.05)
    kalman.init(box)
    assert len(kalman.predict()) == 4
    assert len(kalman.update(box + 10)) == 4
    print(" ✅ 卡尔曼滤波器测试通过")

    # 3. Tracked object lifecycle.
    print("3. 测试跟踪目标...")
    target = TrackedObject(
        track_id=1,
        bbox=box,
        img_shape=(480, 640),
        config=cfg
    )
    target.update(det, ego_center=(320, 240))
    assert target.track_id == 1
    assert target.class_id == 2
    target.predict()
    print(" ✅ 跟踪目标测试通过")

    # 4. SORT tracker single step.
    print("4. 测试SORT跟踪器...")
    sort_tracker = SORTTracker(cfg)
    boxes, ids, classes = sort_tracker.update([det])
    assert len(boxes) >= 0  # a brand-new track may not be confirmed yet
    print(" ✅ SORT跟踪器测试通过")

    print(divider)
    print("✅ tracker.py 所有测试通过")
    print("注:完整测试需要YOLO模型文件")

    return True
+
+
+if __name__ == "__main__":
+ test_tracker()
\ No newline at end of file
diff --git a/src/tracking_car/carla_tracking_ros/scripts/utils.py b/src/tracking_car/carla_tracking_ros/scripts/utils.py
new file mode 100644
index 0000000000..0ddf17f2a4
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/scripts/utils.py
@@ -0,0 +1,689 @@
+"""
+utils.py - 通用工具函数
+包含:图像处理、几何计算、性能监控、文件操作等工具函数
+"""
+
+import cv2
+import numpy as np
+import time
+import os
+import sys
+from numba import njit
+from datetime import datetime
+
# Logger configuration: prefer loguru, fall back to the stdlib logging module
# so the module stays usable without the optional dependency.
try:
    from loguru import logger
except ImportError:
    import logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

# Optional dependency: YAML load/save helpers degrade gracefully when
# PyYAML is not installed (they log and return empty/no-op).
try:
    import yaml
    YAML_AVAILABLE = True
except ImportError:
    YAML_AVAILABLE = False
    logger.warning("PyYAML未安装,配置文件功能将受限")
+
def valid_img(img):
    """Return True when *img* is a non-empty 3-channel (BGR) image.

    Args:
        img: candidate image (numpy array or None).

    Returns:
        bool: whether the image can safely be processed/drawn on.
    """
    if img is None:
        return False
    return len(img.shape) == 3 and img.shape[2] == 3 and img.size > 0
+
def clip_box(bbox, img_shape):
    """Clip a bounding box so it lies fully inside the image.

    Args:
        bbox: [x1, y1, x2, y2] box coordinates.
        img_shape: (height, width, ...) image shape.

    Returns:
        np.ndarray: float32 [x1, y1, x2, y2] with every coordinate inside
        [0, w-1] x [0, h-1].
    """
    h, w = img_shape[:2]
    # Clamp the top-left corner first.
    x1 = max(0, min(bbox[0], w - 1))
    y1 = max(0, min(bbox[1], h - 1))
    # Anchor the bottom-right corner on the *clipped* corner.
    # FIX: the previous code used raw bbox[0]+1 / bbox[1]+1 as the lower
    # bound, so a box starting beyond the image edge produced x2/y2 values
    # outside the image.
    x2 = min(w - 1, max(x1 + 1, bbox[2]))
    y2 = min(h - 1, max(y1 + 1, bbox[3]))
    return np.array([x1, y1, x2, y2], dtype=np.float32)
+
def make_div(x, d=32):
    """Round *x* up to the nearest multiple of *d* (YOLO stride alignment).

    Args:
        x: value to align.
        d: stride / divisor (default 32).

    Returns:
        int: smallest multiple of *d* that is >= x.
    """
    # Ceil-division written with negated floor division.
    return -(-x // d) * d
+
def resize_with_padding(image, target_size, color=(114, 114, 114)):
    """Letterbox resize: scale *image* to fit *target_size* while keeping its
    aspect ratio, then pad the remainder with *color*.

    Args:
        image: input BGR image.
        target_size: (width, height) of the output canvas.
        color: BGR fill color of the padded border.

    Returns:
        tuple: (padded_image, scale, (dx, dy)) where (dx, dy) is the
        top-left offset of the resized image inside the canvas.
    """
    src_h, src_w = image.shape[:2]
    dst_w, dst_h = target_size

    # Uniform scale that fits the source inside the target canvas.
    scale = min(dst_w / src_w, dst_h / src_h)
    out_w = int(src_w * scale)
    out_h = int(src_h * scale)

    if scale != 1:
        image = cv2.resize(image, (out_w, out_h), interpolation=cv2.INTER_LINEAR)

    # Canvas pre-filled with the border color; paste the image centered.
    canvas = np.full((dst_h, dst_w, 3), color, dtype=np.uint8)
    dx = (dst_w - out_w) // 2
    dy = (dst_h - out_h) // 2
    canvas[dy:dy + out_h, dx:dx + out_w] = image

    return canvas, scale, (dx, dy)
+
@njit
def iou_numpy(box1, box2):
    """
    Intersection-over-Union of two axis-aligned boxes (numba-jitted).

    Args:
        box1: np.array([x1, y1, x2, y2])
        box2: np.array([x1, y1, x2, y2])

    Returns:
        float: IoU in [0, 1]; 0.0 when the union area is not positive.
    """
    # Intersection rectangle corners.
    ix1 = max(box1[0], box2[0])
    iy1 = max(box1[1], box2[1])
    ix2 = min(box1[2], box2[2])
    iy2 = min(box1[3], box2[3])

    # Negative extents clamp to zero, so disjoint boxes yield IoU 0.
    ia = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    a1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    a2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    ua = a1 + a2 - ia

    return ia / ua if ua > 0 else 0.0
+
def iou(box1, box2):
    """IoU of two [x1, y1, x2, y2] boxes; accepts lists or numpy arrays.

    Args:
        box1: [x1, y1, x2, y2] or np.ndarray.
        box2: [x1, y1, x2, y2] or np.ndarray.

    Returns:
        float: IoU value in [0, 1].
    """
    # Normalize inputs to float32 arrays before dispatching to the
    # jitted implementation.
    return iou_numpy(
        np.asarray(box1, dtype=np.float32),
        np.asarray(box2, dtype=np.float32),
    )
+
@njit
def iou_batch(boxes1, boxes2):
    """
    Compute the pairwise IoU matrix between two sets of boxes (numba-jitted).

    Args:
        boxes1: (N, 4) array of [x1, y1, x2, y2] boxes.
        boxes2: (M, 4) array of [x1, y1, x2, y2] boxes.

    Returns:
        np.ndarray: (N, M) float32 matrix with iou_matrix[i, j] = IoU(boxes1[i], boxes2[j]).
    """
    N = boxes1.shape[0]
    M = boxes2.shape[0]
    iou_matrix = np.zeros((N, M), dtype=np.float32)

    # O(N*M) nested loop; fine under numba for the small track/detection
    # counts this tracker handles.
    for i in range(N):
        for j in range(M):
            iou_matrix[i, j] = iou_numpy(boxes1[i], boxes2[j])

    return iou_matrix
+
def bbox_center(bbox):
    """Return the (cx, cy) midpoint of an [x1, y1, x2, y2] box.

    Args:
        bbox: [x1, y1, x2, y2] box.

    Returns:
        tuple: (cx, cy) center coordinates.
    """
    cx = (bbox[0] + bbox[2]) / 2
    cy = (bbox[1] + bbox[3]) / 2
    return (cx, cy)
+
def bbox_area(bbox):
    """Return the area of an [x1, y1, x2, y2] box (0 for degenerate boxes).

    Args:
        bbox: [x1, y1, x2, y2] box.

    Returns:
        float: area; negative extents are floored at 0.
    """
    width = max(0, bbox[2] - bbox[0])
    height = max(0, bbox[3] - bbox[1])
    return width * height
+
def bbox_aspect_ratio(bbox):
    """Return width/height of an [x1, y1, x2, y2] box.

    Both sides are floored at 0.1 so degenerate boxes never divide by zero.

    Args:
        bbox: [x1, y1, x2, y2] box.

    Returns:
        float: aspect ratio (width / height).
    """
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    w = w if w > 0.1 else 0.1
    h = h if h > 0.1 else 0.1
    return w / h
+
class FPSCounter:
    """Sliding-window FPS counter with a long-run average."""

    def __init__(self, window_size=15):
        """
        Args:
            window_size: number of recent timestamps used for the
                instantaneous FPS estimate.
        """
        self.window_size = window_size
        self.timestamps = []
        self.fps = 0.0
        self.avg_fps = 0.0
        self.fps_history = []

    def update(self):
        """Record a frame timestamp and return the instantaneous FPS.

        Returns:
            float: current FPS (0.0 until at least two samples exist).
        """
        self.timestamps.append(time.time())

        # Keep only the most recent window_size timestamps.
        while len(self.timestamps) > self.window_size:
            self.timestamps.pop(0)

        if len(self.timestamps) < 2:
            # Not enough samples yet - keep returning the last value.
            return self.fps

        span = self.timestamps[-1] - self.timestamps[0]
        self.fps = (len(self.timestamps) - 1) / span

        # Track a bounded history for the smoothed average.
        self.fps_history.append(self.fps)
        while len(self.fps_history) > 100:
            self.fps_history.pop(0)
        self.avg_fps = np.mean(self.fps_history) if self.fps_history else self.fps

        return self.fps

    def reset(self):
        """Clear all samples and zero the counters."""
        self.timestamps = []
        self.fps = 0.0
        self.fps_history = []
        self.avg_fps = 0.0
+
class PerformanceMonitor:
    """Aggregates per-frame / per-stage timings and derives throughput stats."""

    # Number of most-recent samples kept per timing series.
    _WINDOW = 100

    def __init__(self):
        self.frame_count = 0
        self.start_time = time.time()
        self.frame_times = []
        self.detection_times = []
        self.tracking_times = []
        # Set by start_frame(); None means "no frame in progress".
        # FIX: previously unset, so end_frame() without a prior
        # start_frame() raised AttributeError.
        self.frame_start = None

    def start_frame(self):
        """Mark the beginning of a new frame."""
        self.frame_start = time.time()

    def end_frame(self):
        """Record the current frame's duration; no-op if no frame started."""
        if self.frame_start is None:
            return
        self.frame_times.append(time.time() - self.frame_start)
        # Consume the start so a stray second end_frame() is ignored.
        self.frame_start = None
        self.frame_count += 1

        # Keep only the most recent samples.
        if len(self.frame_times) > self._WINDOW:
            self.frame_times.pop(0)

    def record_detection_time(self, dt):
        """Record one detection duration (seconds)."""
        self.detection_times.append(dt)
        if len(self.detection_times) > self._WINDOW:
            self.detection_times.pop(0)

    def record_tracking_time(self, dt):
        """Record one tracking duration (seconds)."""
        self.tracking_times.append(dt)
        if len(self.tracking_times) > self._WINDOW:
            self.tracking_times.pop(0)

    def get_stats(self):
        """Return aggregate statistics (times in ms, fps in Hz)."""
        window_time = sum(self.frame_times)
        stats = {
            'total_frames': self.frame_count,
            'total_time': time.time() - self.start_time,
            # FIX: guard the zero-length window (coarse clock resolution
            # could previously trigger ZeroDivisionError).
            'avg_fps': len(self.frame_times) / window_time if window_time > 0 else 0,
            'avg_frame_time': np.mean(self.frame_times) * 1000 if self.frame_times else 0,
            'avg_detection_time': np.mean(self.detection_times) * 1000 if self.detection_times else 0,
            'avg_tracking_time': np.mean(self.tracking_times) * 1000 if self.tracking_times else 0,
        }
        return stats

    def print_stats(self):
        """Log the aggregate statistics via the module logger."""
        stats = self.get_stats()
        logger.info(f"总帧数: {stats['total_frames']}")
        logger.info(f"总时间: {stats['total_time']:.1f}s")
        logger.info(f"平均FPS: {stats['avg_fps']:.1f}")
        logger.info(f"平均帧时间: {stats['avg_frame_time']:.1f}ms")
        logger.info(f"平均检测时间: {stats['avg_detection_time']:.1f}ms")
        logger.info(f"平均跟踪时间: {stats['avg_tracking_time']:.1f}ms")
+
def create_output_dir(base_dir="outputs"):
    """Create a timestamped run directory with screenshots/ and logs/ subdirs.

    Args:
        base_dir: parent directory for all run outputs.

    Returns:
        str: path of the newly created run directory.
    """
    run_dir = os.path.join(base_dir, datetime.now().strftime("%Y%m%d_%H%M%S"))

    # Create the run directory and its standard subdirectories.
    for sub in ("screenshots", "logs"):
        os.makedirs(os.path.join(run_dir, sub), exist_ok=True)

    logger.info(f"创建输出目录: {run_dir}")
    return run_dir
+
def save_image(image, path, create_dir=True):
    """Save *image* to *path* via cv2.imwrite.

    Args:
        image: BGR image to save.
        path: destination file path.
        create_dir: create the parent directory when it does not exist.

    Returns:
        bool: True on success, False otherwise.
    """
    if not valid_img(image):
        logger.warning(f"无效图像,无法保存到 {path}")
        return False

    try:
        if create_dir:
            parent = os.path.dirname(path)
            # FIX: os.makedirs('') raises FileNotFoundError, so saving to a
            # bare filename in the current directory always failed.
            if parent:
                os.makedirs(parent, exist_ok=True)

        cv2.imwrite(path, image)
        logger.debug(f"图像已保存: {path}")
        return True

    except Exception as e:
        logger.error(f"保存图像失败 {path}: {e}")
        return False
+
def load_yaml_config(path):
    """Load a YAML config file, returning {} on any failure.

    Args:
        path: configuration file path.

    Returns:
        dict: parsed configuration, or {} when missing/unreadable/empty.
    """
    if not YAML_AVAILABLE:
        logger.error("无法加载YAML配置: PyYAML未安装")
        return {}

    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except FileNotFoundError:
        logger.warning(f"配置文件不存在: {path}")
        return {}
    except Exception as e:
        logger.error(f"加载配置文件失败 {path}: {e}")
        return {}

    logger.info(f"配置文件加载成功: {path}")
    # safe_load returns None for an empty file - normalize to {}.
    return data if data else {}
+
def save_yaml_config(config, path):
    """Persist *config* to *path* as YAML (no-op when PyYAML is unavailable).

    Args:
        config: configuration dict to serialize.
        path: destination file path.
    """
    if not YAML_AVAILABLE:
        logger.error("无法保存YAML配置: PyYAML未安装")
        return

    try:
        parent = os.path.dirname(path)
        # FIX: os.makedirs('') raises FileNotFoundError, which made saving
        # to a bare filename in the current directory impossible.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(config, f, default_flow_style=False, indent=2)
        logger.debug(f"配置已保存: {path}")
    except Exception as e:
        logger.error(f"保存配置失败 {path}: {e}")
+
def draw_bbox(image, bbox, color=(255, 0, 0), thickness=2, label=None):
    """Draw one bounding box (and optional label) on *image* in place.

    Args:
        image: BGR image to draw on.
        bbox: [x1, y1, x2, y2] box.
        color: (B, G, R) box color.
        thickness: line thickness in pixels.
        label: optional text drawn above the box.

    Returns:
        np.ndarray: the same image with the box drawn.
    """
    if not valid_img(image):
        return image

    left, top, right, bottom = (int(v) for v in bbox)

    # Degenerate box - nothing to draw.
    if left >= right or top >= bottom:
        return image

    cv2.rectangle(image, (left, top), (right, bottom), color, thickness)

    if label:
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5

        (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)

        # Filled strip behind the text so the white label stays readable.
        cv2.rectangle(image, (left, top - text_h - 5),
                      (left + text_w, top), color, -1)
        cv2.putText(image, label, (left, top - 5),
                    font, font_scale, (255, 255, 255), thickness)

    return image
+
def draw_trajectory(image, points, color=(0, 255, 0), thickness=2, max_points=20):
    """Draw a polyline through the most recent *max_points* trajectory points.

    Args:
        image: BGR image to draw on.
        points: list of (x, y) points, oldest first.
        color: line color.
        thickness: line thickness.
        max_points: number of trailing points to render.

    Returns:
        np.ndarray: the same image with the trajectory drawn.
    """
    if not valid_img(image) or len(points) < 2:
        return image

    h, w = image.shape[:2]
    recent = points[-max_points:]

    def _inside(pt):
        # Both segment endpoints must lie inside the image bounds.
        return 0 <= pt[0] < w and 0 <= pt[1] < h

    for prev, cur in zip(recent, recent[1:]):
        p1 = (int(prev[0]), int(prev[1]))
        p2 = (int(cur[0]), int(cur[1]))
        if _inside(p1) and _inside(p2):
            cv2.line(image, p1, p2, color, thickness)

    return image
+
def draw_info_panel(image, info_dict, position="top_left"):
    """Render a black info panel with "key: value" rows on *image*.

    Args:
        image: BGR image to draw on.
        info_dict: {label: value} rows to display.
        position: "top_left", "top_right", "bottom_left" or "bottom_right";
            unknown values fall back to top-left.

    Returns:
        np.ndarray: the same image with the panel drawn.
    """
    if not valid_img(image):
        return image

    h, w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 1
    line_height = 25

    # Anchor of the panel's first text row.
    anchors = {
        "top_left": (10, 30),
        "top_right": (w - 200, 30),
        "bottom_left": (10, h - 30 - len(info_dict) * line_height),
        "bottom_right": (w - 200, h - 30 - len(info_dict) * line_height),
    }
    x, y = anchors.get(position, (10, 30))

    # Solid black backdrop behind the title and rows.
    bg_height = len(info_dict) * line_height + 10
    cv2.rectangle(image, (x - 5, y - 25), (x + 190, y + bg_height - 20), (0, 0, 0), -1)

    cv2.putText(image, "SYSTEM INFO", (x, y - 5), font, 0.7, (0, 255, 0), thickness)

    for row, (key, value) in enumerate(info_dict.items(), start=1):
        cv2.putText(image, f"{key}: {value}", (x, y + row * line_height),
                    font, font_scale, (255, 255, 255), thickness)

    return image
+
def run_self_tests():
    """Run the utils.py self-test suite.

    Returns:
        bool: True when every test passed.
    """
    print("=" * 50)
    print("运行 utils.py 自测试...")
    print("=" * 50)

    tests_passed = 0
    tests_failed = 0

    # Test 1: valid_img
    try:
        test_img = np.zeros((100, 100, 3), dtype=np.uint8)
        assert valid_img(test_img) == True, "valid_img应该返回True"
        assert valid_img(None) == False, "valid_img(None)应该返回False"
        assert valid_img(np.zeros((100, 100), dtype=np.uint8)) == False, "灰度图应该返回False"
        print("✅ valid_img测试通过")
        tests_passed += 1
    except AssertionError as e:
        print(f"❌ valid_img测试失败: {e}")
        tests_failed += 1

    # Test 2: clip_box
    try:
        bbox = [10, 10, 200, 200]
        clipped = clip_box(bbox, (150, 150))
        expected = [10, 10, 149, 149]  # zero-based indices, hence 149 rather than 150
        assert np.allclose(clipped[:2], expected[:2]), f"clip_box坐标错误: {clipped[:2]} != {expected[:2]}"
        assert clipped[2] <= 149 and clipped[3] <= 149, "clip_box应该限制在图像范围内"
        print("✅ clip_box测试通过")
        tests_passed += 1
    except AssertionError as e:
        print(f"❌ clip_box测试失败: {e}")
        tests_failed += 1

    # Test 3: iou (compatibility wrapper)
    try:
        box1 = [0, 0, 10, 10]
        box2 = [5, 5, 15, 15]
        iou_val = iou(box1, box2)
        expected_iou = 25 / (100 + 100 - 25)  # (5x5)/(100+100-25) = 25/175 ≈ 0.1429
        assert abs(iou_val - expected_iou) < 0.001, f"iou计算错误: {iou_val} != {expected_iou}"

        # Also exercise the numpy-array (jitted) version directly.
        box1_np = np.array(box1, dtype=np.float32)
        box2_np = np.array(box2, dtype=np.float32)
        iou_val_np = iou_numpy(box1_np, box2_np)
        assert abs(iou_val_np - expected_iou) < 0.001, f"iou_numpy计算错误"

        print("✅ iou测试通过")
        tests_passed += 1
    except AssertionError as e:
        print(f"❌ iou测试失败: {e}")
        tests_failed += 1

    # Test 4: make_div
    try:
        assert make_div(100) == 128, "make_div(100)应该返回128"
        assert make_div(128) == 128, "make_div(128)应该返回128"
        assert make_div(129) == 160, "make_div(129)应该返回160"
        assert make_div(0, 32) == 0, "make_div(0)应该返回0"
        print("✅ make_div测试通过")
        tests_passed += 1
    except AssertionError as e:
        print(f"❌ make_div测试失败: {e}")
        tests_failed += 1

    # Test 5: FPSCounter
    try:
        fps_counter = FPSCounter(window_size=3)

        # The first update() only records a timestamp; FPS needs >= 2 samples.
        fps1 = fps_counter.update()
        time.sleep(0.05)  # wait 50 ms

        # From the second update() onward an FPS value can be computed.
        fps2 = fps_counter.update()
        time.sleep(0.05)

        fps3 = fps_counter.update()

        # By now a non-zero FPS must be available.
        assert fps3 > 0, f"FPS应该大于0,当前: {fps3}"
        assert fps_counter.fps > 0, f"内部FPS应该大于0"

        print(f"✅ FPSCounter测试通过 (FPS: {fps3:.1f})")
        tests_passed += 1
    except Exception as e:
        print(f"❌ FPSCounter测试失败: {e}")
        import traceback
        traceback.print_exc()
        tests_failed += 1

    # Test 6: drawing helpers
    try:
        test_img = np.zeros((100, 100, 3), dtype=np.uint8)
        # draw_bbox
        result1 = draw_bbox(test_img.copy(), [10, 10, 50, 50], label="test")
        assert result1.shape == test_img.shape, "draw_bbox应该返回相同尺寸的图像"

        # draw_trajectory
        points = [(20, 20), (30, 30), (40, 40)]
        result2 = draw_trajectory(test_img.copy(), points)
        assert result2.shape == test_img.shape, "draw_trajectory应该返回相同尺寸的图像"

        # draw_info_panel
        info = {"FPS": "30.0", "Objects": "5"}
        result3 = draw_info_panel(test_img.copy(), info)
        assert result3.shape == test_img.shape, "draw_info_panel应该返回相同尺寸的图像"

        print("✅ 可视化函数测试通过")
        tests_passed += 1
    except Exception as e:
        print(f"❌ 可视化函数测试失败: {e}")
        import traceback
        traceback.print_exc()
        tests_failed += 1

    # Test 7: bbox helper functions
    try:
        bbox = [10, 20, 50, 80]
        center = bbox_center(bbox)
        area = bbox_area(bbox)
        aspect = bbox_aspect_ratio(bbox)

        assert center == (30.0, 50.0), f"中心点计算错误: {center}"
        assert area == 40 * 60, f"面积计算错误: {area}"
        assert abs(aspect - 40/60) < 0.001, f"宽高比计算错误: {aspect}"

        print("✅ bbox工具函数测试通过")
        tests_passed += 1
    except Exception as e:
        print(f"❌ bbox工具函数测试失败: {e}")
        tests_failed += 1

    print("=" * 50)
    print(f"测试结果: {tests_passed}通过, {tests_failed}失败")

    if tests_failed == 0:
        print("🎉 所有测试通过!")
    else:
        print("⚠️ 有测试失败,请检查")

    return tests_failed == 0
+
if __name__ == "__main__":
    # Run the self-test suite; the process exit code reflects the result so
    # this can be used from CI or shell scripts.
    success = run_self_tests()

    if success:
        print("\nutils.py 可以安全使用")
        sys.exit(0)
    else:
        print("\n⚠️ utils.py 有测试失败,请修复")
        sys.exit(1)
\ No newline at end of file
diff --git a/src/tracking_car/carla_tracking_ros/setup.py b/src/tracking_car/carla_tracking_ros/setup.py
new file mode 100644
index 0000000000..8d4fe919f9
--- /dev/null
+++ b/src/tracking_car/carla_tracking_ros/setup.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+from distutils.core import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
# Catkin-managed distutils setup; arguments are merged with metadata from
# package.xml by generate_distutils_setup.
# NOTE(review): packages/package_dir are empty, so the listed scripts are
# installed as standalone executables rather than as an importable package.
d = generate_distutils_setup(
    packages=[],
    package_dir={},
    scripts=[
        'scripts/main.py',
        'scripts/sensors.py',
        'scripts/tracker.py',
        'scripts/utils.py'
    ]
)

setup(**d)
diff --git a/src/tracking_car/config.yaml b/src/tracking_car/config.yaml
index 79759fe126..5e1d9064b0 100644
--- a/src/tracking_car/config.yaml
+++ b/src/tracking_car/config.yaml
@@ -78,4 +78,13 @@ hot_reload_safe_keys:
- "yolo_imgsz_max" # YOLO输入尺寸
- "stop_speed_thresh" # 停车速度阈值
- "danger_dist_thresh" # 危险距离阈值
- - "weather" # 天气设置
\ No newline at end of file
+ - "weather" # 天气设置
+
+# ======================== 交通标志检测配置 ========================
+enable_sign_detection: false # 是否启用交通标志检测
+
+traffic_sign:
+ enabled: true # 检测器是否启用
+ show_signs: true # 是否显示检测框
+ enable_actions: false # 是否触发动作(警告等)
+ conf_threshold: 0.5 # 置信度阈值
\ No newline at end of file
diff --git a/src/tracking_car/main.py b/src/tracking_car/main.py
index 06b02aea15..2329ae04c6 100644
--- a/src/tracking_car/main.py
+++ b/src/tracking_car/main.py
@@ -1102,6 +1102,13 @@ def _draw_info_panel(self, image, track_count):
f"L: PointCloud | V: View Mode" # 添加点云提示
]
+ status_lines = [
+ f"Tracked Objects: {track_count}",
+ f"ESC: Exit | W: Weather | S: Screenshot",
+ f"P: Pause | T: Stats | M: Color Legend",
+ f"L: PointCloud | V: View | O: SignDetect", # 修改这一行
+ ]
+
# Draw status information
font = cv2.FONT_HERSHEY_SIMPLEX
for i, line in enumerate(status_lines):
@@ -1312,6 +1319,10 @@ def __init__(self, config):
self.image_queue = None
self.result_queue = None
+ # 添加交通标志检测器
+ self.traffic_sign_detector = None
+ self.enable_sign_detection = config.get('enable_sign_detection', False)
+
logger.info("[OK] Tracking system initialized (Color ID encoding + Independent statistics window)")
def initialize(self):
@@ -1380,7 +1391,18 @@ def initialize(self):
import traceback
traceback.print_exc()
return False
-
+
+ if self.enable_sign_detection:
+ try:
+ from sign_detector import TrafficSignDetector
+ self.traffic_sign_detector = TrafficSignDetector(
+ config.get('traffic_sign', {})
+ )
+ logger.info("✅ 交通标志检测器已启用")
+ except Exception as e:
+ logger.warning(f"交通标志检测器初始化失败: {e}")
+ self.traffic_sign_detector = None
+
def _setup_detection_thread(self):
"""Setup detection thread"""
try:
@@ -1784,6 +1806,23 @@ def run(self):
if self.frame_count % 100 == 0:
self._print_status(stats_data)
+ # ============ 新增:交通标志检测 ============
+ detected_signs = []
+ if self.traffic_sign_detector and self.enable_sign_detection:
+ try:
+ # 检测标志
+ detected_signs = self.traffic_sign_detector.detect(
+ image,
+ ego_speed=self.ego_vehicle.get_velocity().length() if self.ego_vehicle else 0.0
+ )
+
+ # 在图像上绘制标志
+ if self.traffic_sign_detector.config.get('show_signs', True):
+ image = self.traffic_sign_detector.draw_signs(image, detected_signs)
+
+ except Exception as e:
+ logger.debug(f"标志检测失败: {e}")
+
except KeyboardInterrupt:
logger.info("[STOP] User interrupted program")
except Exception as e:
@@ -1860,6 +1899,12 @@ def _handle_keyboard_input(self, key):
# 新增:R键手动重载配置
elif key == ord('r') or key == ord('R'):
self._force_reload_config()
+
+ # 添加:O键切换标志检测
+ elif key == ord('o') or key == ord('O'):
+ self.enable_sign_detection = not self.enable_sign_detection
+ status = "开启" if self.enable_sign_detection else "关闭"
+ logger.info(f"🚦 交通标志检测: {status}")
def _control_frame_rate(self, current_fps):
"""自适应帧率控制(简单版)"""
diff --git a/src/tracking_car/sign_detector.py b/src/tracking_car/sign_detector.py
new file mode 100644
index 0000000000..bbd8a82f18
--- /dev/null
+++ b/src/tracking_car/sign_detector.py
@@ -0,0 +1,438 @@
+"""
+sign_detector.py - 轻量级交通标志识别模块
+最小化集成,不破坏现有结构
+"""
+
+import cv2
+import numpy as np
+import torch
+from typing import List, Dict, Any, Optional
+import time
+
+try:
+ from loguru import logger
+except ImportError:
+ import logging
+ logger = logging.getLogger(__name__)
+
class TrafficSignDetector:
    """Lightweight traffic-sign detector.

    Prefers a YOLO model (ultralytics) when available; otherwise falls back
    to a simple HSV-color + contour-shape heuristic. Designed for minimal
    integration without touching the rest of the pipeline.
    """

    def __init__(self, config=None):
        """
        Args:
            config: optional config dict; a built-in minimal default is used
                when omitted.
        """
        # Default configuration.
        self.config = config or {
            'enabled': True,
            'conf_threshold': 0.5,
            'show_signs': True,
            'enable_actions': False,  # display only by default; no actions
        }

        # Try to load a YOLO model; on failure fall back to color detection.
        self.model = None
        self.use_yolo = False

        try:
            from ultralytics import YOLO
            # Reuse the generic pretrained model already used elsewhere.
            self.model = YOLO('yolov8n.pt')
            self.use_yolo = True
            logger.info("✅ 使用YOLO进行标志检测")
        except Exception as e:
            logger.warning(f"无法加载YOLO模型,使用简单颜色检测: {e}")
            self._init_simple_detector()

        # HSV ranges for the color-based fallback.
        # NOTE(review): OpenCV red hue also wraps into ~170-180; only the
        # low band is covered here - confirm whether the high band matters.
        self.sign_colors = {
            'red': {  # stop / prohibition signs
                'lower': np.array([0, 50, 50]),
                'upper': np.array([10, 255, 255])
            },
            'blue': {  # mandatory / information signs
                'lower': np.array([100, 50, 50]),
                'upper': np.array([130, 255, 255])
            },
            'yellow': {  # warning signs
                'lower': np.array([20, 100, 100]),
                'upper': np.array([30, 255, 255])
            }
        }

        # Shape templates. FIX: _load_shape_templates() was called here but
        # never defined, so every construction raised AttributeError.
        self.shape_templates = self._load_shape_templates()

        # Recently detected signs.
        self.detected_signs_history = []

        logger.info("✅ 轻量级标志检测器初始化完成")

    def _init_simple_detector(self):
        """Prepare the color/shape fallback detector."""
        # FIX: the _create_*_mask helpers were referenced but undefined;
        # they are implemented below and shared with shape_templates.
        self.templates = self._load_shape_templates()

    def _load_shape_templates(self):
        """Build binary template masks for the sign shapes we recognize."""
        return {
            'triangle': self._create_triangle_mask(),
            'circle': self._create_circle_mask(),
            'octagon': self._create_octagon_mask(),  # stop sign
            'square': self._create_square_mask()
        }

    @staticmethod
    def _polygon_mask(points, size=64):
        """Rasterize a polygon into a size x size uint8 mask (255 inside)."""
        mask = np.zeros((size, size), dtype=np.uint8)
        cv2.fillPoly(mask, [np.asarray(points, dtype=np.int32)], 255)
        return mask

    def _create_triangle_mask(self, size=64):
        """Upward-pointing triangle template (warning signs)."""
        return self._polygon_mask(
            [(size // 2, 4), (4, size - 4), (size - 4, size - 4)], size)

    def _create_circle_mask(self, size=64):
        """Filled circle template (prohibition / mandatory signs)."""
        mask = np.zeros((size, size), dtype=np.uint8)
        cv2.circle(mask, (size // 2, size // 2), size // 2 - 4, 255, -1)
        return mask

    def _create_octagon_mask(self, size=64):
        """Regular octagon template (stop sign)."""
        c = size / 2.0
        r = size / 2.0 - 4
        pts = [(int(c + r * np.cos(np.pi / 8 + k * np.pi / 4)),
                int(c + r * np.sin(np.pi / 8 + k * np.pi / 4)))
               for k in range(8)]
        return self._polygon_mask(pts, size)

    def _create_square_mask(self, size=64):
        """Axis-aligned square template (information signs)."""
        return self._polygon_mask(
            [(4, 4), (size - 4, 4), (size - 4, size - 4), (4, size - 4)], size)

    def _is_octagon_shape(self, roi):
        """Return True when *roi*'s dominant contour looks octagonal.

        FIX: this helper was called from _infer_sign_type but never defined.
        """
        if roi is None or roi.size == 0:
            return False
        gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return False
        largest = max(contours, key=cv2.contourArea)
        return self._detect_shape(largest) == 'octagon'

    def detect(self, image: np.ndarray, ego_speed: float = 0.0) -> List[Dict]:
        """
        Detect traffic signs in *image*.

        Args:
            image: input BGR image.
            ego_speed: ego-vehicle speed (used for distance estimation).

        Returns:
            List[Dict]: detected signs (bbox, confidence, type, distance, ...).
        """
        if not self.config.get('enabled', True):
            return []

        signs = []

        try:
            # Method 1: use YOLO when it loaded successfully.
            if self.use_yolo and self.model is not None:
                signs = self._detect_with_yolo(image, ego_speed)

            # Method 2: otherwise use the simple color + shape detection.
            else:
                signs = self._detect_with_color(image, ego_speed)

            # Remove duplicates (simple NMS).
            signs = self._non_max_suppression(signs)

            # Update history.
            if signs:
                self.detected_signs_history = signs[:10]  # keep the 10 most recent

            # Optional simple action triggering.
            if self.config.get('enable_actions', False):
                self._trigger_actions(signs, ego_speed)

            return signs

        except Exception as e:
            logger.error(f"标志检测失败: {e}")
            return []

    def _detect_with_yolo(self, image: np.ndarray, ego_speed: float) -> List[Dict]:
        """Detect signs with the YOLO model (COCO classes)."""
        results = self.model.predict(
            image,
            conf=self.config.get('conf_threshold', 0.5),
            verbose=False
        )

        signs = []

        for result in results:
            if result.boxes is not None:
                for box in result.boxes:
                    # Keep only COCO classes that correspond to traffic signs.
                    class_id = int(box.cls[0])
                    class_name = result.names[class_id]

                    if class_name in ['stop sign', 'traffic light', 'parking meter']:
                        bbox = box.xyxy[0].cpu().numpy()
                        confidence = float(box.conf[0])

                        # Infer the sign type from the class and appearance.
                        sign_type = self._infer_sign_type(class_name, bbox, image)

                        sign_info = {
                            'bbox': bbox.tolist(),
                            'confidence': confidence,
                            'type': sign_type,
                            'class_id': class_id,
                            'class_name': class_name,
                            'timestamp': time.time()
                        }

                        # Rough distance from the bounding-box size.
                        sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed)

                        signs.append(sign_info)

        return signs

    def _detect_with_color(self, image: np.ndarray, ego_speed: float) -> List[Dict]:
        """Detect signs with the HSV color + contour-shape heuristic."""
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        signs = []

        # Scan every sign color band.
        for color_name, color_range in self.sign_colors.items():
            mask = cv2.inRange(hsv, color_range['lower'], color_range['upper'])

            # Morphological open/close to suppress noise.
            kernel = np.ones((5, 5), np.uint8)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            for contour in contours:
                area = cv2.contourArea(contour)

                # Skip tiny blobs.
                if area < 100:
                    continue

                x, y, w, h = cv2.boundingRect(contour)

                # Shape feature of the contour.
                shape = self._detect_shape(contour)

                # Infer the sign type from color + shape.
                sign_type = self._infer_sign_type_from_color(color_name, shape)

                if sign_type:
                    sign_info = {
                        'bbox': [x, y, x + w, y + h],
                        'confidence': min(0.9, area / 1000.0),  # crude confidence proxy
                        'type': sign_type,
                        'color': color_name,
                        'shape': shape,
                        'timestamp': time.time()
                    }

                    sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed)

                    signs.append(sign_info)

        return signs

    def _infer_sign_type(self, class_name: str, bbox: np.ndarray, image: np.ndarray) -> str:
        """Map a YOLO class name (plus appearance checks) to a sign type."""
        type_map = {
            'stop sign': 'stop',
            'traffic light': 'traffic_light',
            'parking meter': 'parking',
        }

        # For 'stop sign', additionally confirm the octagonal shape.
        if class_name == 'stop sign':
            x1, y1, x2, y2 = map(int, bbox)
            roi = image[y1:y2, x1:x2]
            if self._is_octagon_shape(roi):
                return 'stop'

        return type_map.get(class_name, 'unknown')

    def _infer_sign_type_from_color(self, color: str, shape: str) -> Optional[str]:
        """Infer a sign type from its dominant color and contour shape."""
        if color == 'red' and shape == 'octagon':
            return 'stop'
        elif color == 'red' and shape == 'circle':
            return 'no_entry'
        elif color == 'yellow' and shape == 'triangle':
            return 'warning'
        elif color == 'blue' and shape == 'circle':
            return 'mandatory'
        elif color == 'blue' and shape == 'square':
            return 'information'

        return None

    def _detect_shape(self, contour) -> str:
        """Classify a contour as triangle/square/rectangle/octagon/circle."""
        approx = cv2.approxPolyDP(contour, 0.04 * cv2.arcLength(contour, True), True)
        num_sides = len(approx)

        if num_sides == 3:
            return 'triangle'
        elif num_sides == 4:
            # Distinguish square from rectangle by aspect ratio.
            x, y, w, h = cv2.boundingRect(contour)
            aspect_ratio = float(w) / h
            if 0.8 <= aspect_ratio <= 1.2:
                return 'square'
            else:
                return 'rectangle'
        elif 8 <= num_sides <= 12:  # octagon-ish
            return 'octagon'
        else:
            # Fall back to a circularity measure.
            area = cv2.contourArea(contour)
            perimeter = cv2.arcLength(contour, True)
            circularity = 4 * np.pi * area / (perimeter * perimeter)
            if circularity > 0.7:
                return 'circle'

        return 'unknown'

    def _estimate_distance(self, bbox: List[float], ego_speed: float) -> float:
        """Estimate the sign distance from its box height (very rough).

        Assumes a pinhole relation distance ∝ 1/height with a fixed constant.
        """
        x1, y1, x2, y2 = bbox
        bbox_height = y2 - y1

        if bbox_height <= 0:
            return 100.0  # default: far away

        # Simplified: distance ∝ 1/height.
        distance = 500.0 / bbox_height

        # Inflate slightly at speed (motion blur makes boxes smaller).
        if ego_speed > 10:
            distance *= (1 + ego_speed / 100.0)

        return min(distance, 100.0)  # cap at 100 m

    def _non_max_suppression(self, signs: List[Dict], iou_threshold: float = 0.5) -> List[Dict]:
        """Greedy NMS over the detected signs, keeping highest confidence."""
        if len(signs) <= 1:
            return signs

        sorted_signs = sorted(signs, key=lambda x: x['confidence'], reverse=True)
        keep = []

        while sorted_signs:
            # Take the most confident remaining sign ...
            best = sorted_signs.pop(0)
            keep.append(best)

            # ... and drop everything that overlaps it too much.
            sorted_signs = [
                sign for sign in sorted_signs
                if self._iou(best['bbox'], sign['bbox']) < iou_threshold
            ]

        return keep

    def _iou(self, box1: List[float], box2: List[float]) -> float:
        """Intersection-over-Union of two [x1, y1, x2, y2] boxes."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        inter_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

        union_area = box1_area + box2_area - inter_area

        return inter_area / union_area if union_area > 0 else 0.0

    def _trigger_actions(self, signs: List[Dict], ego_speed: float):
        """Emit warnings for nearby relevant signs (optional feature)."""
        for sign in signs:
            sign_type = sign.get('type', '')
            distance = sign.get('distance', 100.0)

            if sign_type == 'stop' and distance < 20:
                logger.warning(f"🛑 前方 {distance:.1f}米有停车标志")

            elif 'speed_limit' in sign_type and distance < 30:
                try:
                    limit = int(sign_type.split('_')[-1])
                    if ego_speed * 3.6 > limit + 5:  # m/s -> km/h
                        logger.warning(f"📏 前方限速{limit}km/h,当前{ego_speed*3.6:.0f}km/h")
                except ValueError:
                    # FIX: was a bare `except:`; only a non-numeric suffix
                    # is expected to fail here.
                    pass

    def draw_signs(self, image: np.ndarray, signs: List[Dict]) -> np.ndarray:
        """Draw detected signs on a copy of *image* and return it."""
        if not self.config.get('show_signs', True):
            return image

        result = image.copy()

        for sign in signs:
            bbox = sign['bbox']
            sign_type = sign.get('type', 'unknown')
            confidence = sign.get('confidence', 0.0)
            distance = sign.get('distance', 0.0)

            x1, y1, x2, y2 = map(int, bbox)

            # Per-type box colors (BGR).
            color_map = {
                'stop': (0, 0, 255),            # red
                'warning': (0, 165, 255),       # orange
                'traffic_light': (0, 255, 0),   # green
                'speed_limit': (0, 255, 255),   # yellow
                'unknown': (128, 128, 128)      # gray
            }

            color = color_map.get(sign_type, (128, 128, 128))

            cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)

            # Label: type, confidence and (when available) distance.
            label = f"{sign_type} {confidence:.2f}"
            if distance > 0:
                label += f" {distance:.1f}m"

            cv2.putText(result, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        return result

    def get_signs_info(self) -> Dict:
        """Return summary statistics about recent detections."""
        return {
            'total_signs': len(self.detected_signs_history),
            'recent_signs': self.detected_signs_history[-5:] if self.detected_signs_history else [],
            'enabled': self.config.get('enabled', True)
        }
+
+
def test_detector():
    """Smoke-test the detector on a synthetic stop-sign-like image.

    Returns:
        bool: True when at least one sign was detected.
    """
    detector = TrafficSignDetector()

    # Black canvas with a red octagon in the middle (mimics a stop sign).
    test_image = np.zeros((480, 640, 3), dtype=np.uint8)

    center = (320, 240)
    radius = 50
    points = [
        (int(center[0] + radius * np.cos(2 * np.pi * k / 8)),
         int(center[1] + radius * np.sin(2 * np.pi * k / 8)))
        for k in range(8)
    ]
    cv2.fillPoly(test_image, [np.array(points)], (0, 0, 255))

    # Run detection.
    signs = detector.detect(test_image)

    print(f"检测到 {len(signs)} 个标志")
    for sign in signs:
        print(f"  类型: {sign.get('type')}, 置信度: {sign.get('confidence'):.2f}")

    # Exercise the drawing path as well (result intentionally unused).
    result = detector.draw_signs(test_image, signs)

    return len(signs) > 0
+
+
if __name__ == "__main__":
    # Manual smoke test: detect a synthetic stop sign and report the outcome.
    success = test_detector()
    if success:
        print("✅ 标志检测器测试通过")
    else:
        print("⚠️ 测试未检测到标志")
\ No newline at end of file