diff --git a/src/tracking_car/carla_tracking_ros/.gitignore b/src/tracking_car/carla_tracking_ros/.gitignore new file mode 100644 index 0000000000..ea4e295337 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/.gitignore @@ -0,0 +1,38 @@ +# ROS编译文件 +build/ +devel/ +install/ + +# Python缓存 +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# 日志和临时文件 +*.log +*.pid +*.sqlite +*.sqlite3 + +# 编辑器文件 +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# 系统文件 +.DS_Store +Thumbs.db + +# 备份文件 +*.bak +*.backup +*.tmp + +# 大文件 +*.bag +*.mp4 +*.avi diff --git a/src/tracking_car/carla_tracking_ros/CMakeLists.txt b/src/tracking_car/carla_tracking_ros/CMakeLists.txt new file mode 100644 index 0000000000..4963f232df --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.0.2) +project(carla_tracking_ros) + +find_package(catkin REQUIRED COMPONENTS + roscpp + rospy + std_msgs + cv_bridge + image_transport +) + +catkin_python_setup() + +catkin_package( + CATKIN_DEPENDS roscpp rospy std_msgs cv_bridge image_transport +) + +install(DIRECTORY launch/ + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}/launch +) + +install(DIRECTORY config/ + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}/config +) + +install(FILES requirements.txt + DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +) diff --git a/src/tracking_car/carla_tracking_ros/config/config.yaml b/src/tracking_car/carla_tracking_ros/config/config.yaml new file mode 100644 index 0000000000..5e1d9064b0 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/config/config.yaml @@ -0,0 +1,90 @@ +# ======================== CARLA连接配置 ======================== +host: "localhost" +port: 2000 +timeout: 20.0 + +# ======================== 传感器配置 ======================== +img_width: 1280 +img_height: 960 +fov: 110 +sensor_tick: 0.05 +use_lidar: true + +# ======================== LiDAR配置 ======================== +lidar_channels: 32 +lidar_range: 100.0 +lidar_points_per_second: 500000 + +# 
======================== 检测模型配置 ======================== +yolo_model: "yolov8n.pt" +conf_thres: 0.1 +iou_thres: 0.3 +device: "cpu" +yolo_imgsz_max: 640 + +# ======================== 跟踪算法配置 ======================== +max_age: 5 +min_hits: 3 +kf_dt: 0.05 +max_speed: 50.0 + +# ======================== 可视化配置 ======================== +window_width: 1280 +window_height: 720 +display_fps: 30 + +# ======================== 轨迹历史配置 ======================== +track_history_len: 20 + +# ======================== 行为分析配置 ======================== +stop_speed_thresh: 0.8 +stop_frames_thresh: 5 +overtake_speed_ratio: 1.5 +overtake_dist_thresh: 60.0 +lane_change_thresh: 0.4 +brake_accel_thresh: 2.0 +turn_angle_thresh: 15.0 +danger_dist_thresh: 15.0 +predict_frames: 10 + +# ======================== 天气与NPC配置 ======================== +default_weather: "clear" +num_npcs: 20 + +# ======================== 视角控制配置 ======================== +view: + default_mode: "satellite" + satellite_height: 50.0 + behind_distance: 10.0 + first_person_height: 1.6 + +# ======================== 热重载配置 ======================== +hot_reload: + enabled: true # 是否启用热重载 + check_interval: 100 # 每多少帧检查一次配置文件 + safe_keys_only: true # 是否只更新安全参数(避免重启连接) + +# 可热重载的安全参数列表(这些参数可以安全更新而不需要重启) +hot_reload_safe_keys: + - "conf_thres" # 检测置信度阈值 + - "iou_thres" # IOU阈值 + - "display_fps" # 显示帧率 + - "max_age" # 跟踪最大丢失帧数 + - "min_hits" # 跟踪最小匹配次数 + - "adaptive_fps" # 是否启用自适应帧率 + - "min_fps" # 最小FPS + - "max_fps" # 最大FPS + - "fov" # 相机视野 + - "yolo_imgsz_max" # YOLO输入尺寸 + - "stop_speed_thresh" # 停车速度阈值 + - "danger_dist_thresh" # 危险距离阈值 + - "weather" # 天气设置 + +# ======================== 交通标志检测配置 ======================== +enable_sign_detection: false # 是否启用交通标志检测 + +traffic_sign: + enabled: true # 检测器是否启用 + show_signs: true # 是否显示检测框 + enable_actions: false # 是否触发动作(警告等) + conf_threshold: 0.5 # 置信度阈值 \ No newline at end of file diff --git a/src/tracking_car/carla_tracking_ros/launch/main.launch 
b/src/tracking_car/carla_tracking_ros/launch/main.launch new file mode 100644 index 0000000000..0fcf5e8f11 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/launch/main.launch @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/tracking_car/carla_tracking_ros/package.xml b/src/tracking_car/carla_tracking_ros/package.xml new file mode 100644 index 0000000000..b614e5516e --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/package.xml @@ -0,0 +1,31 @@ + + + carla_tracking_ros + 1.0.0 + CARLA Multi-Object Tracking System - ROS Integration + + tjl + MIT + + https://github.com/yourusername/carla_tracking_ros + tjl + + catkin + + roscpp + rospy + std_msgs + cv_bridge + image_transport + + + message_generation + message_runtime + + + rosunit + + + + + diff --git a/src/tracking_car/carla_tracking_ros/requirements.txt b/src/tracking_car/carla_tracking_ros/requirements.txt new file mode 100644 index 0000000000..27cc8e8623 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/requirements.txt @@ -0,0 +1,17 @@ +# 核心依赖 +# 可选依赖 + +carla>=0.9.14 +loguru>=0.7.0 +matplotlib>=3.7.0 +numba>=0.57.0 +numpy>=1.24.0 +open3d>=0.17.0 +opencv-python>=4.8.0 +pandas>=2.0.0 +psutil>=5.9.0 +pyyaml>=6.0.0 +scikit-learn>=1.2.0 +scipy>=1.10.0 +torch>=2.0.0 +ultralytics>=8.0.0 diff --git a/src/tracking_car/carla_tracking_ros/scripts/__init__.py b/src/tracking_car/carla_tracking_ros/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tracking_car/carla_tracking_ros/scripts/main.py b/src/tracking_car/carla_tracking_ros/scripts/main.py new file mode 100644 index 0000000000..a7d504b607 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/scripts/main.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +""" +测试版 - CARLA多目标跟踪系统(ROS版) +简化版本用于测试ROS封装 +""" + +import sys +import os +import time +import argparse + +# ======================== ROS支持 ======================== +try: + import rospy + from sensor_msgs.msg import Image + from 
std_msgs.msg import String, Float32 + from cv_bridge import CvBridge + import cv2 + import numpy as np + ROS_AVAILABLE = True + print("✅ ROS模块导入成功") +except ImportError as e: + ROS_AVAILABLE = False + print(f"⚠️ ROS模块导入失败: {e}") + +# ======================== 简单日志类 ======================== +class SimpleLogger: + def info(self, msg): print(f"[INFO] {msg}") + def warning(self, msg): print(f"[WARN] {msg}") + def error(self, msg): print(f"[ERROR] {msg}") + +logger = SimpleLogger() + +# ======================== 模拟传感器类 ======================== +class MockSensorManager: + def __init__(self): + logger.info("初始化模拟传感器管理器") + + def get_image(self): + # 创建一个简单的测试图像 + img = np.zeros((480, 640, 3), dtype=np.uint8) + cv2.putText(img, 'CARLA Tracking ROS', (50, 240), + cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) + return img + + def get_detections(self): + # 模拟检测结果 + return [ + {'id': 1, 'bbox': [100, 100, 200, 200], 'class': 'vehicle'}, + {'id': 2, 'bbox': [300, 150, 400, 250], 'class': 'pedestrian'} + ] + +# ======================== 模拟跟踪器类 ======================== +class MockTracker: + def __init__(self): + logger.info("初始化模拟多目标跟踪器") + self.track_id = 0 + + def update(self, detections): + self.track_id += 1 + return [{'id': self.track_id, 'detection': d} for d in detections] + +# ======================== ROS发布器类 ======================== +class ROSPublisher: + def __init__(self): + if ROS_AVAILABLE: + self.bridge = CvBridge() + self.image_pub = rospy.Publisher('/carla/camera', Image, queue_size=10) + self.detection_pub = rospy.Publisher('/carla/detections', String, queue_size=10) + self.status_pub = rospy.Publisher('/carla/status', String, queue_size=10) + logger.info("ROS发布器初始化完成") + else: + logger.warning("ROS不可用,发布器未初始化") + + def publish_image(self, cv_image): + if ROS_AVAILABLE and hasattr(self, 'image_pub'): + try: + ros_image = self.bridge.cv2_to_imgmsg(cv_image, "bgr8") + self.image_pub.publish(ros_image) + return True + except Exception as e: + logger.warning(f"发布图像失败: 
{e}") + return False + + def publish_detection(self, detections): + if ROS_AVAILABLE and hasattr(self, 'detection_pub'): + det_str = str(detections) + self.detection_pub.publish(det_str) + return True + return False + + def publish_status(self, status): + if ROS_AVAILABLE and hasattr(self, 'status_pub'): + self.status_pub.publish(status) + return True + return False + +# ======================== 主函数 ======================== +def main(): + parser = argparse.ArgumentParser(description='CARLA多目标跟踪系统ROS测试版') + parser.add_argument('--mode', choices=['simulation', 'test'], default='test', + help='运行模式:simulation(仿真)或 test(测试)') + parser.add_argument('--duration', type=int, default=10, + help='运行时长(秒)') + parser.add_argument('--rate', type=float, default=2.0, + help='发布频率(Hz)') + parser.add_argument('--no-ros', action='store_true', + help='禁用ROS功能') + + args = parser.parse_args() + + # 初始化ROS + if ROS_AVAILABLE and not args.no_ros: + rospy.init_node('carla_tracking_system') + logger.info("ROS节点初始化: carla_tracking_system") + + # 创建组件 + sensor_manager = MockSensorManager() + tracker = MockTracker() + ros_publisher = ROSPublisher() if not args.no_ros else None + + logger.info(f"启动CARLA多目标跟踪系统ROS测试版") + logger.info(f"模式: {args.mode}, 时长: {args.duration}秒, 频率: {args.rate}Hz") + logger.info(f"ROS支持: {'启用' if ros_publisher else '禁用'}") + + # 主循环 + start_time = time.time() + frame_count = 0 + + try: + while time.time() - start_time < args.duration: + # 获取数据 + image = sensor_manager.get_image() + detections = sensor_manager.get_detections() + tracks = tracker.update(detections) + + # 发布ROS消息 + if ros_publisher: + ros_publisher.publish_image(image) + ros_publisher.publish_detection(detections) + ros_publisher.publish_status(f"Frame {frame_count}: {len(tracks)} tracks") + + # 显示信息 + logger.info(f"帧 {frame_count}: 检测到 {len(detections)} 个目标, 跟踪 {len(tracks)} 个轨迹") + + frame_count += 1 + time.sleep(1.0 / args.rate) + + except KeyboardInterrupt: + logger.info("收到退出信号") + except 
Exception as e: + logger.error(f"运行错误: {e}") + + logger.info(f"系统运行完成,共处理 {frame_count} 帧") + logger.info(f"平均帧率: {frame_count / args.duration:.2f} Hz") + logger.info("系统关闭") + +if __name__ == '__main__': + main() diff --git a/src/tracking_car/carla_tracking_ros/scripts/sensors.py b/src/tracking_car/carla_tracking_ros/scripts/sensors.py new file mode 100644 index 0000000000..e4ba38fda4 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/scripts/sensors.py @@ -0,0 +1,945 @@ +""" +sensors.py - CARLA传感器管理 +包含:相机、LiDAR传感器封装和管理 +""" + +import random # 添加随机模块支持 +import carla +import cv2 +import numpy as np +import queue +import threading +import sys +import time + +# 配置日志 +try: + from loguru import logger +except ImportError: + # 使用标准logging作为回退 + import logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + logger = logging.getLogger(__name__) + +import open3d as o3d +from sklearn.cluster import DBSCAN + +class CameraManager: + """相机管理器""" + + def __init__(self, world, ego_vehicle, config): + """ + 初始化相机 + + Args: + world: CARLA世界对象 + ego_vehicle: 自车对象 + config: 配置字典 + """ + self.world = world + self.ego_vehicle = ego_vehicle + self.config = config + self.camera = None + self.image_queue = queue.Queue(maxsize=2) + self.current_image = None + self.frame_count = 0 + + def setup(self): + """设置相机""" + try: + # 获取相机蓝图 + camera_bp = self.world.get_blueprint_library().find('sensor.camera.rgb') + + # 设置相机属性 + camera_bp.set_attribute('image_size_x', str(self.config.get('img_width', 640))) + camera_bp.set_attribute('image_size_y', str(self.config.get('img_height', 480))) + camera_bp.set_attribute('fov', str(self.config.get('fov', 90))) + camera_bp.set_attribute('sensor_tick', str(self.config.get('sensor_tick', 0.05))) + + # 设置相机位置(车顶前方) + camera_transform = carla.Transform( + carla.Location(x=2.5, z=2.5), + carla.Rotation(pitch=-5) # 略微向下倾斜 + ) + + # 生成相机 + self.camera = self.world.spawn_actor( + camera_bp, + camera_transform, 
+ attach_to=self.ego_vehicle + ) + + # 绑定回调函数 + self.camera.listen(self._camera_callback) + + logger.info(f"✅ 相机初始化成功 (ID: {self.camera.id})") + return True + + except Exception as e: + logger.error(f"❌ 相机初始化失败: {e}") + return False + + def _camera_callback(self, image): + """相机数据回调函数""" + try: + # 将原始数据转换为numpy数组 + array = np.frombuffer(image.raw_data, dtype=np.uint8) + array = array.reshape((image.height, image.width, 4)) + + # 提取RGB通道(去掉alpha通道) + rgb_array = array[:, :, :3] + + # 轻微高斯模糊减少噪声 + rgb_array = cv2.GaussianBlur(rgb_array, (3, 3), 0) + + # 更新当前图像 + self.current_image = rgb_array + + # 放入队列(如果队列已满,丢弃最旧的数据) + if self.image_queue.full(): + try: + self.image_queue.get_nowait() + except queue.Empty: + pass + + self.image_queue.put(rgb_array.copy()) + self.frame_count += 1 + + except Exception as e: + logger.warning(f"相机回调错误: {e}") + + def get_image(self, timeout=0.1): + """ + 获取最新图像 + + Args: + timeout: 超时时间(秒) + + Returns: + np.ndarray or None: 图像数据 + """ + try: + # 首先尝试从队列获取最新图像 + image = self.image_queue.get(timeout=timeout) + # 清空队列中的旧图像 + while not self.image_queue.empty(): + try: + self.image_queue.get_nowait() + except queue.Empty: + break + return image + except queue.Empty: + # 如果队列为空,返回当前图像 + return self.current_image + + def get_current_image(self): + """获取当前图像(不阻塞)""" + return self.current_image + + def destroy(self): + """销毁相机""" + if self.camera and self.camera.is_alive: + try: + self.camera.stop() + self.camera.destroy() + logger.info("✅ 相机已销毁") + except Exception as e: + logger.warning(f"销毁相机失败: {e}") + self.camera = None + +class LiDARManager: + """LiDAR管理器""" + + def __init__(self, world, ego_vehicle, config): + """ + 初始化LiDAR + + Args: + world: CARLA世界对象 + ego_vehicle: 自车对象 + config: 配置字典 + """ + self.world = world + self.ego_vehicle = ego_vehicle + self.config = config + self.lidar = None + self.pointcloud_queue = queue.Queue(maxsize=2) + self.current_pointcloud = None + self.current_transform = None + + def setup(self): + """设置LiDAR""" 
+ try: + if not self.config.get('use_lidar', True): + logger.info("LiDAR被禁用") + return True + + # 获取LiDAR蓝图 + lidar_bp = self.world.get_blueprint_library().find('sensor.lidar.ray_cast') + + # 设置LiDAR属性 + lidar_bp.set_attribute('channels', str(self.config.get('lidar_channels', 32))) + lidar_bp.set_attribute('range', str(self.config.get('lidar_range', 100.0))) + lidar_bp.set_attribute('points_per_second', + str(self.config.get('lidar_points_per_second', 500000))) + lidar_bp.set_attribute('rotation_frequency', str(self.config.get('rotation_frequency', 20))) + lidar_bp.set_attribute('sensor_tick', str(self.config.get('sensor_tick', 0.05))) + + # 设置LiDAR位置(车顶中央) + lidar_transform = carla.Transform( + carla.Location(x=0.0, z=2.5), + carla.Rotation() + ) + + # 生成LiDAR + self.lidar = self.world.spawn_actor( + lidar_bp, + lidar_transform, + attach_to=self.ego_vehicle + ) + + # 绑定回调函数 + self.lidar.listen(self._lidar_callback) + + logger.info(f"✅ LiDAR初始化成功 (ID: {self.lidar.id})") + return True + + except Exception as e: + logger.error(f"❌ LiDAR初始化失败: {e}") + return False + + def _lidar_callback(self, pointcloud): + """LiDAR数据回调函数""" + try: + # 将原始数据转换为numpy数组 + points = np.frombuffer(pointcloud.raw_data, dtype=np.float32) + points = points.reshape(-1, 4)[:, :3] # 只取xyz,忽略反射强度 + + # 过滤地面点(简单的高度过滤) + ground_mask = points[:, 2] < -1.0 + filtered_points = points[~ground_mask] + + # 更新当前点云 + self.current_pointcloud = filtered_points + self.current_transform = pointcloud.transform + + # 放入队列(如果队列已满,丢弃最旧的数据) + if self.pointcloud_queue.full(): + try: + self.pointcloud_queue.get_nowait() + except queue.Empty: + pass + + self.pointcloud_queue.put((filtered_points.copy(), pointcloud.transform)) + + except Exception as e: + logger.warning(f"LiDAR回调错误: {e}") + + def get_pointcloud(self, timeout=0.1): + """ + 获取最新点云数据 + + Args: + timeout: 超时时间(秒) + + Returns: + tuple: (points, transform) 或 (None, None) + """ + try: + points, transform = self.pointcloud_queue.get(timeout=timeout) + # 
清空队列中的旧数据 + while not self.pointcloud_queue.empty(): + try: + self.pointcloud_queue.get_nowait() + except queue.Empty: + break + return points, transform + except queue.Empty: + # 如果队列为空,返回当前点云 + return self.current_pointcloud, self.current_transform + + def detect_objects(self, min_points=30): + """ + 从点云中检测物体 + + Args: + min_points: 最小点数阈值 + + Returns: + list: 检测到的物体列表 + """ + if self.current_pointcloud is None or len(self.current_pointcloud) < min_points: + return [] + + try: + # 使用DBSCAN聚类 + clustering = DBSCAN(eps=0.8, min_samples=30).fit(self.current_pointcloud[:, :2]) + + objects = [] + for label in set(clustering.labels_): + if label == -1: # 忽略噪声点 + continue + + # 获取该聚类的点 + cluster_points = self.current_pointcloud[clustering.labels_ == label] + + if len(cluster_points) < min_points: + continue + + # 计算3D边界框 + min_coords = cluster_points.min(axis=0) + max_coords = cluster_points.max(axis=0) + center = (min_coords + max_coords) / 2 + size = max_coords - min_coords + + # 估计物体类型(基于尺寸) + obj_type = self._estimate_object_type(size) + + objects.append({ + 'bbox_3d': [*min_coords, *max_coords], # [x_min, y_min, z_min, x_max, y_max, z_max] + 'center': center.tolist(), + 'size': size.tolist(), + 'num_points': len(cluster_points), + 'type': obj_type, + 'label': label + }) + + return objects + + except Exception as e: + logger.warning(f"LiDAR物体检测失败: {e}") + return [] + + def _estimate_object_type(self, size): + """根据尺寸估计物体类型""" + length, width, height = size + + # 简单的大小分类 + if height > 2.5: + return "truck" + elif width > 2.0: + return "bus" + elif length > 4.0: + return "car" + else: + return "unknown" + + def get_open3d_pointcloud(self): + """ + 获取Open3D格式的点云 + + Returns: + o3d.geometry.PointCloud or None: Open3D点云对象 + """ + if self.current_pointcloud is None or len(self.current_pointcloud) == 0: + return None + + try: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(self.current_pointcloud) + + # 根据高度着色(低处蓝色,高处红色) + z_min = 
self.current_pointcloud[:, 2].min() + z_max = self.current_pointcloud[:, 2].max() + z_range = max(z_max - z_min, 1e-6) + + colors = np.zeros((len(self.current_pointcloud), 3)) + normalized_z = (self.current_pointcloud[:, 2] - z_min) / z_range + colors[:, 0] = normalized_z # 红色通道(高处) + colors[:, 2] = 1 - normalized_z # 蓝色通道(低处) + + pcd.colors = o3d.utility.Vector3dVector(colors) + + return pcd + + except Exception as e: + logger.warning(f"创建Open3D点云失败: {e}") + return None + + def destroy(self): + """销毁LiDAR""" + if self.lidar and self.lidar.is_alive: + try: + self.lidar.stop() + self.lidar.destroy() + logger.info("✅ LiDAR已销毁") + except Exception as e: + logger.warning(f"销毁LiDAR失败: {e}") + self.lidar = None + +class SpectatorManager: + """CARLA视角管理器 - 提供卫星视角跟随""" + + def __init__(self, world, ego_vehicle, config): + """ + 初始化视角管理器 + + Args: + world: CARLA世界对象 + ego_vehicle: 自车对象 + config: 配置字典 + """ + self.world = world + self.ego_vehicle = ego_vehicle + self.config = config + self.spectator = None + view_config = config.get('view', {}) + self.view_mode = view_config.get('default_mode', 'satellite') + self.view_height = view_config.get('satellite_height', 50.0) + self.follow_distance = view_config.get('behind_distance', 10.0) + self.first_person_height = view_config.get('first_person_height', 1.6) + + def setup(self): + """设置视角""" + try: + # 获取世界中的观察者(spectator) + self.spectator = self.world.get_spectator() + + # 设置初始视角 + self._set_satellite_view() + + logger.info(f"✅ 视角管理器初始化成功") + return True + + except Exception as e: + logger.error(f"❌ 视角管理器初始化失败: {e}") + return False + + def _set_satellite_view(self): + """设置卫星视角""" + if not self.ego_vehicle or not self.ego_vehicle.is_alive: + return + + try: + # 获取车辆位置和方向 + vehicle_transform = self.ego_vehicle.get_transform() + vehicle_location = vehicle_transform.location + vehicle_rotation = vehicle_transform.rotation + + # 设置卫星视角位置(车辆正上方) + spectator_location = carla.Location( + x=vehicle_location.x, + y=vehicle_location.y, 
+ z=vehicle_location.z + self.view_height + ) + + # 设置视角朝向(俯瞰车辆) + spectator_rotation = carla.Rotation( + pitch=-90, # 向下看 + yaw=vehicle_rotation.yaw, + roll=0 + ) + + # 应用变换 + spectator_transform = carla.Transform( + spectator_location, + spectator_rotation + ) + self.spectator.set_transform(spectator_transform) + + except Exception as e: + logger.warning(f"设置卫星视角失败: {e}") + + def _set_behind_view(self): + """设置后方跟随视角""" + if not self.ego_vehicle or not self.ego_vehicle.is_alive: + return + + try: + # 获取车辆位置和方向 + vehicle_transform = self.ego_vehicle.get_transform() + vehicle_location = vehicle_transform.location + vehicle_rotation = vehicle_transform.rotation + + # 计算后方偏移位置 + import math + yaw_rad = math.radians(vehicle_rotation.yaw) + + # 车辆后方偏移 + behind_x = vehicle_location.x - self.follow_distance * math.cos(yaw_rad) + behind_y = vehicle_location.y - self.follow_distance * math.sin(yaw_rad) + + # 设置摄像机位置(稍高于车辆) + spectator_location = carla.Location( + x=behind_x, + y=behind_y, + z=vehicle_location.z + 3.0 + ) + + # 设置视角朝向(看向车辆) + # 计算朝向车辆的旋转 + dx = vehicle_location.x - spectator_location.x + dy = vehicle_location.y - spectator_location.y + target_yaw = math.degrees(math.atan2(dy, dx)) + + spectator_rotation = carla.Rotation( + pitch=-15, # 略微向下 + yaw=target_yaw, + roll=0 + ) + + # 应用变换 + spectator_transform = carla.Transform( + spectator_location, + spectator_rotation + ) + self.spectator.set_transform(spectator_transform) + + except Exception as e: + logger.warning(f"设置后方视角失败: {e}") + + def _set_first_person_view(self): + """设置第一人称视角""" + if not self.ego_vehicle or not self.ego_vehicle.is_alive: + return + + try: + # 获取车辆变换 + vehicle_transform = self.ego_vehicle.get_transform() + + # 稍微调整位置(从驾驶员视角) + location = carla.Location( + x=vehicle_transform.location.x, + y=vehicle_transform.location.y, + z=vehicle_transform.location.z + self.first_person_height # 改为使用配置 + ) + + # 使用车辆的方向 + rotation = carla.Rotation( + pitch=vehicle_transform.rotation.pitch, + 
yaw=vehicle_transform.rotation.yaw, + roll=vehicle_transform.rotation.roll + ) + + # 应用变换 + spectator_transform = carla.Transform( + location, + rotation + ) + self.spectator.set_transform(spectator_transform) + + except Exception as e: + logger.warning(f"设置第一人称视角失败: {e}") + + def set_view_mode(self, mode): + """ + 设置视角模式 + + Args: + mode: 视角模式 ('satellite', 'behind', 'first_person') + """ + if mode not in ['satellite', 'behind', 'first_person']: + logger.warning(f"未知的视角模式: {mode}") + return + + self.view_mode = mode + + def update(self): + """更新视角""" + if not self.ego_vehicle or not self.ego_vehicle.is_alive: + return + + try: + if self.view_mode == 'satellite': + self._set_satellite_view() + elif self.view_mode == 'behind': + self._set_behind_view() + elif self.view_mode == 'first_person': + self._set_first_person_view() + + except Exception as e: + logger.debug(f"更新视角失败: {e}") + + def cycle_view_mode(self): + """循环切换视角模式""" + modes = ['satellite', 'behind', 'first_person'] + current_index = modes.index(self.view_mode) if self.view_mode in modes else 0 + next_index = (current_index + 1) % len(modes) + self.view_mode = modes[next_index] + logger.info(f"切换到 {self.view_mode} 视角") + + def destroy(self): + """销毁视角管理器""" + self.spectator = None + logger.info("✅ 视角管理器已销毁") + +class SensorManager: + """传感器管理器(统一管理所有传感器)""" + + def __init__(self, world, ego_vehicle, config): + """ + 初始化传感器管理器 + + Args: + world: CARLA世界对象 + ego_vehicle: 自车对象 + config: 配置字典 + """ + self.world = world + self.ego_vehicle = ego_vehicle + self.config = config + + self.camera_manager = None + self.lidar_manager = None + self.spectator_manager = None # 新增视角管理器 + self.is_setup = False + + def setup(self): + """设置所有传感器""" + logger.info("正在初始化传感器...") + + # 初始化相机 + self.camera_manager = CameraManager(self.world, self.ego_vehicle, self.config) + camera_success = self.camera_manager.setup() + + # 初始化LiDAR + lidar_success = True + if self.config.get('use_lidar', True): + self.lidar_manager = 
LiDARManager(self.world, self.ego_vehicle, self.config) + lidar_success = self.lidar_manager.setup() + else: + logger.info("LiDAR功能已禁用") + + # 初始化视角管理器 + self.spectator_manager = SpectatorManager(self.world, self.ego_vehicle, self.config) + spectator_success = self.spectator_manager.setup() + + self.is_setup = camera_success and lidar_success and spectator_success + + if self.is_setup: + logger.info("✅ 所有传感器初始化完成") + else: + logger.warning("⚠️ 传感器初始化不完全") + + return self.is_setup + + # 在get_sensor_data方法中添加视角更新 + def get_sensor_data(self, timeout=0.05): + """ + 获取所有传感器数据 + + Args: + timeout: 超时时间(秒) + + Returns: + dict: 传感器数据字典 + """ + data = { + 'image': None, + 'pointcloud': None, + 'lidar_transform': None, + 'lidar_objects': [], + 'timestamp': time.time() + } + + # 获取相机图像 + if self.camera_manager: + data['image'] = self.camera_manager.get_image(timeout=timeout) + + # 获取LiDAR数据 + if self.lidar_manager: + points, transform = self.lidar_manager.get_pointcloud(timeout=timeout) + data['pointcloud'] = points + data['lidar_transform'] = transform + + # 检测物体 + if points is not None: + data['lidar_objects'] = self.lidar_manager.detect_objects() + + # 更新视角 + if self.spectator_manager: + self.spectator_manager.update() + + return data + + # 添加视角控制方法 + def set_view_mode(self, mode): + """设置视角模式""" + if self.spectator_manager: + self.spectator_manager.set_view_mode(mode) + + def cycle_view_mode(self): + """循环切换视角模式""" + if self.spectator_manager: + self.spectator_manager.cycle_view_mode() + + def destroy(self): + """销毁所有传感器""" + logger.info("正在销毁传感器...") + + if self.camera_manager: + self.camera_manager.destroy() + + if self.lidar_manager: + self.lidar_manager.destroy() + + if self.spectator_manager: + self.spectator_manager.destroy() + + logger.info("✅ 所有传感器已销毁") + +def create_ego_vehicle(world, config, spawn_points=None): + """ + 创建自车 + + Args: + world: CARLA世界对象 + config: 配置字典 + spawn_points: 可选的自定义生成点列表 + + Returns: + carla.Vehicle or None: 自车对象 + """ + try: + # 获取生成点 + 
if spawn_points is None: + spawn_points = world.get_map().get_spawn_points() + + if not spawn_points: + logger.error("❌ 无可用生成点") + return None + + logger.info(f"找到 {len(spawn_points)} 个生成点") + + # 选择车辆蓝图 + vehicle_bp = None + vehicle_filter = config.get('ego_vehicle_filter', 'vehicle.tesla.model3') + + # 尝试首选车辆 + blueprint_library = world.get_blueprint_library() + for bp in blueprint_library.filter(vehicle_filter): + if int(bp.get_attribute('number_of_wheels')) == 4: + vehicle_bp = bp + logger.info(f"找到车辆蓝图: {bp.id}") + break + + # 如果没找到,选择任意四轮车辆 + if vehicle_bp is None: + logger.info("首选车辆未找到,尝试其他四轮车辆...") + for bp in blueprint_library.filter('vehicle.*'): + if int(bp.get_attribute('number_of_wheels')) == 4: + vehicle_bp = bp + logger.info(f"使用备用车辆蓝图: {bp.id}") + break + + if vehicle_bp is None: + logger.error("❌ 找不到合适的车辆蓝图") + return None + + # 设置车辆颜色 + color = config.get('ego_vehicle_color', '255,0,0') + vehicle_bp.set_attribute('color', color) + + # 尝试生成车辆 - 改进的碰撞避免策略 + max_attempts = config.get('spawn_max_attempts', 20) + + for attempt in range(max_attempts): + # 随机选择生成点 + if attempt < len(spawn_points): + spawn_point = spawn_points[attempt] + else: + # 选择随机一个生成点 + import random + spawn_point = random.choice(spawn_points) + + # 随机偏移位置以避免碰撞 + spawn_point.location.x += random.uniform(-3, 3) + spawn_point.location.y += random.uniform(-3, 3) + + logger.info(f"尝试生成自车 (尝试 {attempt + 1}/{max_attempts}) " + f"位置: x={spawn_point.location.x:.1f}, y={spawn_point.location.y:.1f}") + + # 设置生成点的高度为地面以上0.5米 + spawn_point.location.z += 0.5 + + ego_vehicle = world.try_spawn_actor(vehicle_bp, spawn_point) + + if ego_vehicle is not None: + logger.info(f"✅ 自车生成成功 (尝试 {attempt + 1}/{max_attempts})") + logger.info(f" 位置: ({spawn_point.location.x:.1f}, {spawn_point.location.y:.1f}, {spawn_point.location.z:.1f})") + + # 等待一小段时间让车辆稳定 + world.tick() + + # 设置自动驾驶 + try: + ego_vehicle.set_autopilot(True, 8000) + logger.info("✅ 自车自动驾驶已启用") + except Exception as e: + 
logger.warning(f"设置自动驾驶失败: {e}") + try: + ego_vehicle.set_autopilot(True) + logger.info("✅ 自车自动驾驶已启用(备用方法)") + except: + logger.warning("无法设置自动驾驶,车辆将保持静止") + + return ego_vehicle + else: + logger.debug(f"生成失败,尝试下一个位置...") + + logger.error(f"❌ 经过 {max_attempts} 次尝试后仍无法生成自车") + logger.info("建议:") + logger.info("1. 重新启动CARLA服务器") + logger.info("2. 在CARLA中手动清理场景中的车辆") + logger.info("3. 尝试不同的生成点") + + return None + + except Exception as e: + logger.error(f"❌ 创建自车失败: {e}") + import traceback + traceback.print_exc() + return None + +def spawn_npc_vehicles(world, config, count=None): + """ + 生成NPC车辆 + + Args: + world: CARLA世界对象 + config: 配置字典 + count: NPC数量(默认使用配置中的值) + + Returns: + int: 成功生成的NPC数量 + """ + try: + if count is None: + count = config.get('num_npcs', 20) + + spawn_points = world.get_map().get_spawn_points() + if not spawn_points: + logger.warning("无可用生成点,无法生成NPC") + return 0 + + # 过滤合适的车辆蓝图(四轮车辆,排除特殊车辆) + vehicle_bps = [] + for bp in world.get_blueprint_library().filter('vehicle.*'): + if int(bp.get_attribute('number_of_wheels')) == 4: + # 排除特殊车辆 + if not bp.id.endswith(('firetruck', 'ambulance', 'police', 'charger')): + vehicle_bps.append(bp) + + if not vehicle_bps: + logger.warning("找不到合适的NPC车辆蓝图") + return 0 + + spawned_count = 0 + used_spawn_points = set() + + for i in range(min(count * 3, len(spawn_points))): # 最多尝试3倍数量 + if spawned_count >= count: + break + + spawn_point = spawn_points[i] + + # 检查是否已使用该位置 + position_key = (round(spawn_point.location.x, 1), + round(spawn_point.location.y, 1)) + + if position_key in used_spawn_points: + continue + + # 随机选择车辆蓝图 + vehicle_bp = random.choice(vehicle_bps) + + # 尝试生成 + npc = world.try_spawn_actor(vehicle_bp, spawn_point) + + if npc is not None: + used_spawn_points.add(position_key) + spawned_count += 1 + + # 设置自动驾驶 + try: + npc.set_autopilot(True, 8000) + except: + try: + npc.set_autopilot(True) + except: + pass + + logger.info(f"✅ 成功生成 {spawned_count}/{count} 个NPC车辆") + return spawned_count + + except 
Exception as e: + logger.error(f"生成NPC车辆失败: {e}") + return 0 + +def clear_all_actors(world, exclude_ids=None): + """ + 清理所有演员(车辆和传感器) + + Args: + world: CARLA世界对象 + exclude_ids: 要排除的演员ID列表 + """ + try: + exclude_ids = set(exclude_ids) if exclude_ids else set() + + actors = world.get_actors() + + # 按类型分组清理 + vehicle_actors = [] + sensor_actors = [] + + for actor in actors: + if actor.id in exclude_ids: + continue + + if actor.type_id.startswith('vehicle.'): + vehicle_actors.append(actor) + elif actor.type_id.startswith('sensor.'): + sensor_actors.append(actor) + + # 先清理传感器 + logger.info(f"清理 {len(sensor_actors)} 个传感器...") + for sensor in sensor_actors: + try: + if sensor.is_alive: + sensor.destroy() + except: + pass + + # 再清理车辆 + logger.info(f"清理 {len(vehicle_actors)} 个车辆...") + batch_size = 10 + for i in range(0, len(vehicle_actors), batch_size): + batch = vehicle_actors[i:i+batch_size] + for vehicle in batch: + try: + if vehicle.is_alive: + vehicle.destroy() + except: + pass + + logger.info("✅ 清理完成") + + except Exception as e: + logger.warning(f"清理演员时出错: {e}") + +def test_sensor_manager(): + """测试传感器管理器""" + print("=" * 50) + print("测试 sensors.py...") + print("=" * 50) + + # 模拟配置 + test_config = { + 'img_width': 640, + 'img_height': 480, + 'fov': 90, + 'sensor_tick': 0.05, + 'use_lidar': True, + 'lidar_channels': 32, + 'lidar_range': 100.0, + 'lidar_points_per_second': 500000, + } + + print("✅ sensors.py 结构测试通过") + print("注:完整测试需要CARLA环境") + + return True + +if __name__ == "__main__": + test_sensor_manager() \ No newline at end of file diff --git a/src/tracking_car/carla_tracking_ros/scripts/sign_detector.py b/src/tracking_car/carla_tracking_ros/scripts/sign_detector.py new file mode 100644 index 0000000000..bbd8a82f18 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/scripts/sign_detector.py @@ -0,0 +1,438 @@ +""" +sign_detector.py - 轻量级交通标志识别模块 +最小化集成,不破坏现有结构 +""" + +import cv2 +import numpy as np +import torch +from typing import List, Dict, Any, Optional 
+import time + +try: + from loguru import logger +except ImportError: + import logging + logger = logging.getLogger(__name__) + +class TrafficSignDetector: + """轻量级交通标志检测器""" + + def __init__(self, config=None): + """ + 初始化 + Args: + config: 可选配置,默认使用内置简单配置 + """ + # 默认配置 + self.config = config or { + 'enabled': True, + 'conf_threshold': 0.5, + 'show_signs': True, + 'enable_actions': False, # 默认不触发动作,只显示 + } + + # 尝试加载YOLO模型,如果失败则使用简单的颜色检测 + self.model = None + self.use_yolo = False + + try: + from ultralytics import YOLO + # 尝试加载预训练模型(可以用通用物体检测模型) + self.model = YOLO('yolov8n.pt') # 使用现有的YOLO模型 + self.use_yolo = True + logger.info("✅ 使用YOLO进行标志检测") + except Exception as e: + logger.warning(f"无法加载YOLO模型,使用简单颜色检测: {e}") + self._init_simple_detector() + + # 简单的标志颜色检测 + self.sign_colors = { + 'red': { # 停车、禁止类标志 + 'lower': np.array([0, 50, 50]), + 'upper': np.array([10, 255, 255]) + }, + 'blue': { # 指示类标志 + 'lower': np.array([100, 50, 50]), + 'upper': np.array([130, 255, 255]) + }, + 'yellow': { # 警告类标志 + 'lower': np.array([20, 100, 100]), + 'upper': np.array([30, 255, 255]) + } + } + + # 标志形状模板(可选) + self.shape_templates = self._load_shape_templates() + + # 检测历史 + self.detected_signs_history = [] + + logger.info("✅ 轻量级标志检测器初始化完成") + + def _init_simple_detector(self): + """初始化简单检测器""" + # 加载简单的形状模板 + self.templates = { + 'triangle': self._create_triangle_mask(), + 'circle': self._create_circle_mask(), + 'octagon': self._create_octagon_mask(), # 停车标志 + 'square': self._create_square_mask() + } + + def detect(self, image: np.ndarray, ego_speed: float = 0.0) -> List[Dict]: + """ + 检测图像中的交通标志 + + Args: + image: 输入图像 + ego_speed: 自车速度(用于距离估计) + + Returns: + List[Dict]: 检测到的标志列表 + """ + if not self.config.get('enabled', True): + return [] + + signs = [] + + try: + # 方法1:如果YOLO可用,使用YOLO检测 + if self.use_yolo and self.model is not None: + signs = self._detect_with_yolo(image, ego_speed) + + # 方法2:否则使用简单的颜色+形状检测 + else: + signs = self._detect_with_color(image, ego_speed) + + # 
def _detect_with_yolo(self, image: np.ndarray, ego_speed: float) -> List[Dict]:
    """Detect sign-like objects with YOLO, keeping only relevant COCO classes."""
    # COCO classes that correspond to traffic signs / signals.
    relevant_classes = ('stop sign', 'traffic light', 'parking meter')

    predictions = self.model.predict(
        image,
        conf=self.config.get('conf_threshold', 0.5),
        verbose=False
    )

    signs: List[Dict] = []
    for prediction in predictions:
        boxes = prediction.boxes
        if boxes is None:
            continue
        for box in boxes:
            class_id = int(box.cls[0])
            class_name = prediction.names[class_id]
            if class_name not in relevant_classes:
                continue

            xyxy = box.xyxy[0].cpu().numpy()
            sign_info = {
                'bbox': xyxy.tolist(),
                'confidence': float(box.conf[0]),
                'type': self._infer_sign_type(class_name, xyxy, image),
                'class_id': class_id,
                'class_name': class_name,
                'timestamp': time.time()
            }
            # Rough range estimate from the box size.
            sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed)
            signs.append(sign_info)

    return signs
x, y, w, h = cv2.boundingRect(contour) + + # 计算形状特征 + shape = self._detect_shape(contour) + + # 推断标志类型 + sign_type = self._infer_sign_type_from_color(color_name, shape) + + if sign_type: + sign_info = { + 'bbox': [x, y, x + w, y + h], + 'confidence': min(0.9, area / 1000.0), # 简单置信度 + 'type': sign_type, + 'color': color_name, + 'shape': shape, + 'timestamp': time.time() + } + + # 估算距离 + sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed) + + signs.append(sign_info) + + return signs + + def _infer_sign_type(self, class_name: str, bbox: np.ndarray, image: np.ndarray) -> str: + """根据检测结果推断标志类型""" + type_map = { + 'stop sign': 'stop', + 'traffic light': 'traffic_light', + 'parking meter': 'parking', + } + + # 如果是stop sign,进一步确认(检查是否是八边形) + if class_name == 'stop sign': + # 提取ROI检查形状 + x1, y1, x2, y2 = map(int, bbox) + roi = image[y1:y2, x1:x2] + if self._is_octagon_shape(roi): + return 'stop' + + return type_map.get(class_name, 'unknown') + + def _infer_sign_type_from_color(self, color: str, shape: str) -> Optional[str]: + """根据颜色和形状推断标志类型""" + # 简单的推断规则 + if color == 'red' and shape == 'octagon': + return 'stop' + elif color == 'red' and shape == 'circle': + return 'no_entry' + elif color == 'yellow' and shape == 'triangle': + return 'warning' + elif color == 'blue' and shape == 'circle': + return 'mandatory' + elif color == 'blue' and shape == 'square': + return 'information' + + return None + + def _detect_shape(self, contour) -> str: + """检测轮廓形状""" + approx = cv2.approxPolyDP(contour, 0.04 * cv2.arcLength(contour, True), True) + num_sides = len(approx) + + if num_sides == 3: + return 'triangle' + elif num_sides == 4: + # 判断是正方形还是长方形 + x, y, w, h = cv2.boundingRect(contour) + aspect_ratio = float(w) / h + if 0.8 <= aspect_ratio <= 1.2: + return 'square' + else: + return 'rectangle' + elif 8 <= num_sides <= 12: # 八边形 + return 'octagon' + else: + # 计算圆形度 + area = cv2.contourArea(contour) + perimeter = cv2.arcLength(contour, True) + 
circularity = 4 * np.pi * area / (perimeter * perimeter) + if circularity > 0.7: + return 'circle' + + return 'unknown' + + def _estimate_distance(self, bbox: List[float], ego_speed: float) -> float: + """估算标志距离(简化版本)""" + # 基于边界框高度估算距离 + # 假设标志的标准高度为0.5米,焦距为1000像素 + x1, y1, x2, y2 = bbox + bbox_height = y2 - y1 + + if bbox_height <= 0: + return 100.0 # 默认远距离 + + # 简化距离公式:距离 ∝ 1/高度 + distance = 500.0 / bbox_height + + # 根据速度调整(运动模糊) + if ego_speed > 10: + distance *= (1 + ego_speed / 100.0) + + return min(distance, 100.0) # 最大100米 + + def _non_max_suppression(self, signs: List[Dict], iou_threshold: float = 0.5) -> List[Dict]: + """简单的非极大值抑制""" + if len(signs) <= 1: + return signs + + # 按置信度排序 + sorted_signs = sorted(signs, key=lambda x: x['confidence'], reverse=True) + keep = [] + + while sorted_signs: + # 取置信度最高的 + best = sorted_signs.pop(0) + keep.append(best) + + # 移除与best重叠度高的 + sorted_signs = [ + sign for sign in sorted_signs + if self._iou(best['bbox'], sign['bbox']) < iou_threshold + ] + + return keep + + def _iou(self, box1: List[float], box2: List[float]) -> float: + """计算IoU""" + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + inter_area = max(0, x2 - x1) * max(0, y2 - y1) + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + + union_area = box1_area + box2_area - inter_area + + return inter_area / union_area if union_area > 0 else 0.0 + + def _trigger_actions(self, signs: List[Dict], ego_speed: float): + """触发简单动作(可选)""" + for sign in signs: + sign_type = sign.get('type', '') + distance = sign.get('distance', 100.0) + + if sign_type == 'stop' and distance < 20: + logger.warning(f"🛑 前方 {distance:.1f}米有停车标志") + + elif 'speed_limit' in sign_type and distance < 30: + try: + limit = int(sign_type.split('_')[-1]) + if ego_speed * 3.6 > limit + 5: # m/s转km/h + logger.warning(f"📏 前方限速{limit}km/h,当前{ego_speed*3.6:.0f}km/h") + except: + 
def draw_signs(self, image: np.ndarray, signs: List[Dict]) -> np.ndarray:
    """Return a copy of *image* with boxes and labels for each detected sign."""
    if not self.config.get('show_signs', True):
        # Visualization disabled: hand back the frame untouched.
        return image

    annotated = image.copy()

    # BGR colors per sign category; grey for anything unrecognized.
    palette = {
        'stop': (0, 0, 255),
        'warning': (0, 165, 255),
        'traffic_light': (0, 255, 0),
        'speed_limit': (0, 255, 255),
        'unknown': (128, 128, 128)
    }

    for sign in signs:
        sign_type = sign.get('type', 'unknown')
        color = palette.get(sign_type, (128, 128, 128))
        x1, y1, x2, y2 = map(int, sign['bbox'])

        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        # Label: "<type> <conf>" plus distance when known.
        caption = f"{sign_type} {sign.get('confidence', 0.0):.2f}"
        distance = sign.get('distance', 0.0)
        if distance > 0:
            caption += f" {distance:.1f}m"
        cv2.putText(annotated, caption, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return annotated
@dataclass
class Detection:
    """A single detector output: bounding box, score and class label."""
    bbox: np.ndarray          # [x1, y1, x2, y2] in pixels
    confidence: float         # detector score in [0, 1]
    class_id: int             # COCO class index
    class_name: str = "Unknown"

    def __post_init__(self):
        # Normalize the box to float32 and cache its pixel area.
        self.bbox = np.array(self.bbox, dtype=np.float32)
        width = self.bbox[2] - self.bbox[0]
        height = self.bbox[3] - self.bbox[1]
        self.area = width * height
def update(self, bbox):
    """
    Correct the state estimate with an observed box.

    Args:
        bbox: observed [x1, y1, x2, y2].

    Returns:
        np.ndarray: corrected [x1, y1, x2, y2].
    """
    measurement = np.array(bbox, dtype=np.float32)

    # Innovation covariance; fall back to the pseudo-inverse when singular.
    S = self.H @ self.P @ self.H.T + self.R
    try:
        S_inv = np.linalg.inv(S)
    except np.linalg.LinAlgError:
        S_inv = np.linalg.pinv(S)

    # Kalman gain and innovation.
    K = self.P @ self.H.T @ S_inv
    residual = measurement - self.H @ self.x

    # Fold the innovation into state and covariance.
    self.x = self.x + K @ residual
    identity = np.eye(self.state_dim, dtype=np.float32)
    self.P = (identity - K @ self.H) @ self.P

    self.first_update = False
    return self.x[:4].copy()
def _update_history(self):
    """Record the current box center and bound all history buffers."""
    self.track_history.append(bbox_center(self.bbox))

    # Bound the position history to the configured window.
    limit = self.config.get('track_history_len', 20)
    while len(self.track_history) > limit:
        self.track_history.pop(0)

    # Bound the derived speed / acceleration histories.
    while len(self.speed_history) > 10:
        self.speed_history.pop(0)
    while len(self.acceleration_history) > 10:
        self.acceleration_history.pop(0)
_calculate_speed(self) -> float: + """ + 计算当前速度 + + Returns: + float: 速度(像素/秒) + """ + if len(self.track_history) < 2: + return 0.0 + + # 计算最后两帧的位移 + prev_cx, prev_cy = self.track_history[-2] + curr_cx, curr_cy = self.track_history[-1] + + dx = curr_cx - prev_cx + dy = curr_cy - prev_cy + distance = np.sqrt(dx**2 + dy**2) + + # 计算速度 + speed = distance / self.kf.dt + + # 更新速度历史 + self.speed_history.append(speed) + + # 计算加速度 + if len(self.speed_history) >= 2: + acceleration = (self.speed_history[-1] - self.speed_history[-2]) / self.kf.dt + self.acceleration_history.append(acceleration) + + return speed + + def _calculate_heading(self) -> float: + """ + 计算当前航向角 + + Returns: + float: 航向角(度) + """ + if len(self.track_history) < 3: + return 0.0 + + # 使用最近三帧计算航向 + cx1, cy1 = self.track_history[-3] + cx2, cy2 = self.track_history[-1] + + dx = cx2 - cx1 + dy = cy2 - cy1 + + # 计算角度(弧度转度) + angle = np.degrees(np.arctan2(dy, dx)) + + return angle + + def _analyze_behavior(self, ego_center: Optional[Tuple[float, float]] = None): + """ + 分析目标行为 + + Args: + ego_center: 自车中心点坐标 + """ + # 计算基本状态 + speed = self._calculate_speed() + heading = self._calculate_heading() + + # 1. 停车检测 + stop_speed_thresh = self.config.get('stop_speed_thresh', 1.0) + stop_frames_thresh = self.config.get('stop_frames_thresh', 5) + + if speed < stop_speed_thresh: + self.stop_frames += 1 + self.is_stopped = self.stop_frames >= stop_frames_thresh + else: + self.stop_frames = 0 + self.is_stopped = False + + # 2. 
超车检测 + overtake_speed_ratio = self.config.get('overtake_speed_ratio', 1.5) + overtake_dist_thresh = self.config.get('overtake_dist_thresh', 50.0) + + if ego_center and len(self.track_history) >= 2: + curr_cx, curr_cy = self.track_history[-1] + distance = np.sqrt((curr_cx - ego_center[0])**2 + (curr_cy - ego_center[1])**2) + + if distance < overtake_dist_thresh: + ego_speed = getattr(self, 'ego_speed', 0.0) + if speed > ego_speed * overtake_speed_ratio: + self.overtake_frames += 1 + self.is_overtaking = self.overtake_frames >= 3 + else: + self.overtake_frames = 0 + self.is_overtaking = False + else: + self.overtake_frames = 0 + self.is_overtaking = False + + # 3. 变道检测 + lane_change_thresh = self.config.get('lane_change_thresh', 0.5) + + if len(self.track_history) >= 5: + # 计算横向位移 + lateral_displacements = [] + for i in range(1, min(5, len(self.track_history))): + lateral_displacements.append( + abs(self.track_history[-i][0] - self.track_history[-i-1][0]) + ) + + avg_lateral = np.mean(lateral_displacements) if lateral_displacements else 0.0 + + if avg_lateral > lane_change_thresh: + self.lane_change_frames += 1 + self.is_lane_changing = self.lane_change_frames >= 3 + else: + self.lane_change_frames = 0 + self.is_lane_changing = False + + # 4. 刹车/加速检测 + brake_accel_thresh = self.config.get('brake_accel_thresh', 2.0) + + if len(self.acceleration_history) >= 3: + avg_accel = np.mean(self.acceleration_history[-3:]) + + if avg_accel < -brake_accel_thresh: + self.brake_frames += 1 + self.is_braking = self.brake_frames >= 2 + self.is_accelerating = False + elif avg_accel > brake_accel_thresh: + self.is_accelerating = True + self.is_braking = False + self.brake_frames = 0 + else: + self.is_braking = False + self.is_accelerating = False + self.brake_frames = 0 + + # 5. 
def _predict_trajectory(self):
    """Roll a cloned Kalman filter forward to estimate future center points."""
    horizon = self.config.get('predict_frames', 10)
    self.predicted_trajectory = []

    # Need a minimally established track before extrapolating.
    if len(self.track_history) < 5:
        return

    # Clone the filter state so forecasting does not disturb live tracking.
    lookahead = KalmanFilter(dt=self.kf.dt, max_speed=self.kf.max_speed)
    lookahead.x = self.kf.x.copy()
    lookahead.P = self.kf.P.copy()

    self.predicted_trajectory = [
        bbox_center(lookahead.predict()) for _ in range(horizon)
    ]
def get_behavior_string(self) -> str:
    """Human-readable summary of the currently active behavior flags."""
    # Flag order matters: it fixes the order of labels in the output.
    flag_labels = (
        (self.is_stopped, "停车"),
        (self.is_overtaking, "超车"),
        (self.is_lane_changing, "变道"),
        (self.is_braking, "刹车"),
        (self.is_accelerating, "加速"),
        (self.is_turning, "转弯"),
        (self.is_dangerous, "危险"),
    )
    active = [label for flag, label in flag_labels if flag]
    return "|".join(active) if active else "正常"
def detect(self, image: np.ndarray) -> List[Detection]:
    """
    Run YOLO on one frame and return vehicle detections only.

    Args:
        image: BGR image array.

    Returns:
        List[Detection]: car/bus/truck detections; empty on invalid input
        or inference failure.
    """
    if not valid_img(image):
        return []

    try:
        height, width = image.shape[:2]

        # Scale the longer side down to imgsz_max, then round each side
        # up to a multiple of 32 as YOLO requires.
        ratio = min(self.imgsz_max / width, self.imgsz_max / height)
        infer_w = (int(width * ratio) + 31) // 32 * 32
        infer_h = (int(height * ratio) + 31) // 32 * 32

        results = self.model.predict(
            image,
            conf=self.conf_thres,
            iou=self.iou_thres,
            imgsz=(infer_h, infer_w),
            device=self.device,
            verbose=False,
            agnostic_nms=True
        )

        detections: List[Detection] = []
        for result in results:
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue
            for box in boxes:
                class_id = int(box.cls[0])
                if class_id not in self.vehicle_classes:
                    continue  # keep vehicle classes only

                xyxy = box.xyxy[0].cpu().numpy()
                confidence = float(box.conf[0])
                # Drop degenerate or zero-confidence boxes.
                if xyxy[2] <= xyxy[0] or xyxy[3] <= xyxy[1] or confidence <= 0:
                    continue

                detections.append(Detection(
                    bbox=xyxy,
                    confidence=confidence,
                    class_id=class_id,
                    class_name=self.vehicle_classes[class_id]
                ))

        return detections

    except Exception as e:
        logger.error(f"❌ YOLO检测失败: {e}")
        return []
+ self.ego_speed = 0.0 + + logger.info("✅ SORT跟踪器初始化完成") + + def update(self, detections: List[Detection], + ego_center: Optional[Tuple[float, float]] = None, + lidar_detections: Optional[List[Dict]] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + 更新跟踪器 + + Args: + detections: 检测结果列表 + ego_center: 自车中心点坐标 + lidar_detections: LiDAR检测结果(可选) + + Returns: + Tuple: (边界框数组, ID数组, 类别数组) + """ + # 更新自车信息 + if ego_center: + self.ego_center = ego_center + + # 如果没有检测结果,只进行预测 + if not detections: + # 预测所有现有目标 + for track in self.tracks: + track.predict() + + # 移除丢失时间过长的目标 + self.tracks = [t for t in self.tracks if t.time_since_update <= self.max_age] + + # 返回空结果 + return np.array([]), np.array([]), np.array([]) + + # 预测现有目标的位置 + for track in self.tracks: + track.predict() + + # 创建匹配成本矩阵 + if self.tracks: + # 计算IoU矩阵 + iou_matrix = np.zeros((len(detections), len(self.tracks)), dtype=np.float32) + + for i, det in enumerate(detections): + for j, track in enumerate(self.tracks): + iou_matrix[i, j] = iou(det.bbox, track.bbox) + + # 将IoU转换为成本(1 - IoU) + cost_matrix = 1.0 - iou_matrix + + # 使用匈牙利算法进行匹配 + try: + det_indices, track_indices = linear_sum_assignment(cost_matrix) + except ValueError: + det_indices, track_indices = [], [] + + # 根据阈值过滤匹配 + matched_pairs = [] + unmatched_detections = set(range(len(detections))) + unmatched_tracks = set(range(len(self.tracks))) + + for det_idx, track_idx in zip(det_indices, track_indices): + if iou_matrix[det_idx, track_idx] >= self.iou_threshold: + matched_pairs.append((det_idx, track_idx)) + unmatched_detections.discard(det_idx) + unmatched_tracks.discard(track_idx) + else: + matched_pairs = [] + unmatched_detections = set(range(len(detections))) + unmatched_tracks = set() + + # 更新匹配的目标 + for det_idx, track_idx in matched_pairs: + track = self.tracks[track_idx] + track.ego_speed = self.ego_speed # 传递自车速度用于行为分析 + track.update(detections[det_idx], self.ego_center) + + # 为未匹配的检测创建新目标 + for det_idx in unmatched_detections: + 
def get_tracks_info(self) -> List[Dict[str, Any]]:
    """
    Snapshot of every confirmed, currently-tracked target.

    Returns:
        List[Dict]: one info dict per track that has reached min_hits
        and is in the TRACKED state.
    """
    confirmed = (
        t for t in self.tracks
        if t.hits >= self.min_hits and t.state == TrackState.TRACKED
    )
    return [
        {
            'track_id': t.track_id,
            'bbox': t.bbox.tolist(),
            'class_id': t.class_id,
            'class_name': t.class_name,
            'confidence': t.confidence,
            'speed': t._calculate_speed(),
            'behavior': t.get_behavior_string(),
            'age': t.age,
            'hits': t.hits,
            'is_stopped': t.is_stopped,
            'is_overtaking': t.is_overtaking,
            'is_dangerous': t.is_dangerous,
        }
        for t in confirmed
    ]
def run(self):
    """Thread main loop: pull frames, run detection, publish (image, detections)."""
    while self.running:
        try:
            image = self.input_queue.get(timeout=1.0)
        except queue.Empty:
            continue

        try:
            if not valid_img(image):
                self._publish((image, []))
                continue

            detections = self.detector.detect(image)
            self._publish((image, detections))
            self.processed_count += 1

        except Exception as e:
            logger.error(f"❌ 检测线程出错: {e}")
            self._publish((None, []))

def _publish(self, item):
    """
    Put *item* on the output queue without ever blocking the thread.

    BUGFIX: the original invalid-image and exception paths called a plain
    blocking put() with no full-queue eviction, so a full output queue
    could deadlock the detection thread. All paths now evict the oldest
    entry when full and use a non-blocking put.
    """
    if self.output_queue.full():
        try:
            self.output_queue.get_nowait()
        except queue.Empty:
            pass
    try:
        self.output_queue.put_nowait(item)
    except queue.Full:
        # Racing producer refilled the queue between eviction and put;
        # drop the item rather than block.
        pass
def valid_img(img):
    """
    Return True if *img* is a non-empty 3-channel (H, W, 3) color image.

    Args:
        img: candidate image (numpy array or None).

    Returns:
        bool: whether the array is usable as a color image.
    """
    if img is None:
        return False
    # Must be rank-3 with exactly 3 channels and contain at least one pixel.
    return len(img.shape) == 3 and img.shape[2] == 3 and img.size > 0
def resize_with_padding(image, target_size, color=(114, 114, 114)):
    """
    Letterbox *image* into *target_size*, preserving aspect ratio.

    Args:
        image: input image (H, W, 3).
        target_size: (width, height) of the output canvas.
        color: fill color for the padded border.

    Returns:
        tuple: (padded_image, scale, (dx, dy)) where (dx, dy) is the
        top-left offset of the image inside the canvas.
    """
    src_h, src_w = image.shape[:2]
    dst_w, dst_h = target_size

    # Uniform scale that fits the image inside the canvas.
    scale = min(dst_w / src_w, dst_h / src_h)
    out_w = int(src_w * scale)
    out_h = int(src_h * scale)

    if scale != 1:
        image = cv2.resize(image, (out_w, out_h), interpolation=cv2.INTER_LINEAR)

    # Center the (possibly resized) image on a solid-color canvas.
    canvas = np.full((dst_h, dst_w, 3), color, dtype=np.uint8)
    dx = (dst_w - out_w) // 2
    dy = (dst_h - out_h) // 2
    canvas[dy:dy + out_h, dx:dx + out_w] = image

    return canvas, scale, (dx, dy)
iou_numpy(box1_np, box2_np) + +@njit +def iou_batch(boxes1, boxes2): + """ + 批量计算IoU矩阵 + + Args: + boxes1: (N, 4) 边界框数组 + boxes2: (M, 4) 边界框数组 + + Returns: + np.ndarray: (N, M) IoU矩阵 + """ + N = boxes1.shape[0] + M = boxes2.shape[0] + iou_matrix = np.zeros((N, M), dtype=np.float32) + + for i in range(N): + for j in range(M): + iou_matrix[i, j] = iou_numpy(boxes1[i], boxes2[j]) + + return iou_matrix + +def bbox_center(bbox): + """ + 计算边界框中心点 + + Args: + bbox: [x1, y1, x2, y2] 边界框 + + Returns: + tuple: (cx, cy) 中心点坐标 + """ + return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2) + +def bbox_area(bbox): + """ + 计算边界框面积 + + Args: + bbox: [x1, y1, x2, y2] 边界框 + + Returns: + float: 边界框面积 + """ + return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1]) + +def bbox_aspect_ratio(bbox): + """ + 计算边界框宽高比 + + Args: + bbox: [x1, y1, x2, y2] 边界框 + + Returns: + float: 宽高比(宽/高) + """ + width = max(0.1, bbox[2] - bbox[0]) + height = max(0.1, bbox[3] - bbox[1]) + return width / height + +class FPSCounter: + """ + FPS计数器 + """ + + def __init__(self, window_size=15): + """ + Args: + window_size: 滑动窗口大小 + """ + self.window_size = window_size + self.timestamps = [] + self.fps = 0.0 + self.avg_fps = 0.0 + self.fps_history = [] + + def update(self): + """ + 更新FPS计数 + + Returns: + float: 当前FPS + """ + self.timestamps.append(time.time()) + + if len(self.timestamps) > self.window_size: + self.timestamps.pop(0) + + if len(self.timestamps) >= 2: + self.fps = (len(self.timestamps) - 1) / (self.timestamps[-1] - self.timestamps[0]) + self.fps_history.append(self.fps) + + if len(self.fps_history) > 100: + self.fps_history.pop(0) + + self.avg_fps = np.mean(self.fps_history) if self.fps_history else self.fps + + return self.fps + + def reset(self): + """重置计数器""" + self.timestamps = [] + self.fps = 0.0 + self.fps_history = [] + self.avg_fps = 0.0 + +class PerformanceMonitor: + """ + 性能监控器 + """ + + def __init__(self): + self.frame_count = 0 + self.start_time = time.time() + self.frame_times = 
[] + self.detection_times = [] + self.tracking_times = [] + + def start_frame(self): + """开始新帧计时""" + self.frame_start = time.time() + + def end_frame(self): + """结束帧计时""" + frame_time = time.time() - self.frame_start + self.frame_times.append(frame_time) + self.frame_count += 1 + + # 保留最近100帧的计时 + if len(self.frame_times) > 100: + self.frame_times.pop(0) + + def record_detection_time(self, dt): + """记录检测时间""" + self.detection_times.append(dt) + if len(self.detection_times) > 100: + self.detection_times.pop(0) + + def record_tracking_time(self, dt): + """记录跟踪时间""" + self.tracking_times.append(dt) + if len(self.tracking_times) > 100: + self.tracking_times.pop(0) + + def get_stats(self): + """获取性能统计""" + stats = { + 'total_frames': self.frame_count, + 'total_time': time.time() - self.start_time, + 'avg_fps': len(self.frame_times) / sum(self.frame_times) if self.frame_times else 0, + 'avg_frame_time': np.mean(self.frame_times) * 1000 if self.frame_times else 0, + 'avg_detection_time': np.mean(self.detection_times) * 1000 if self.detection_times else 0, + 'avg_tracking_time': np.mean(self.tracking_times) * 1000 if self.tracking_times else 0, + } + return stats + + def print_stats(self): + """打印性能统计""" + stats = self.get_stats() + logger.info(f"总帧数: {stats['total_frames']}") + logger.info(f"总时间: {stats['total_time']:.1f}s") + logger.info(f"平均FPS: {stats['avg_fps']:.1f}") + logger.info(f"平均帧时间: {stats['avg_frame_time']:.1f}ms") + logger.info(f"平均检测时间: {stats['avg_detection_time']:.1f}ms") + logger.info(f"平均跟踪时间: {stats['avg_tracking_time']:.1f}ms") + +def create_output_dir(base_dir="outputs"): + """ + 创建输出目录 + + Args: + base_dir: 基础目录名 + + Returns: + str: 创建的目录路径 + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = os.path.join(base_dir, timestamp) + + # 创建目录 + os.makedirs(output_dir, exist_ok=True) + os.makedirs(os.path.join(output_dir, "screenshots"), exist_ok=True) + os.makedirs(os.path.join(output_dir, "logs"), exist_ok=True) + + 
logger.info(f"创建输出目录: {output_dir}") + return output_dir + +def save_image(image, path, create_dir=True): + """ + 保存图像 + + Args: + image: 要保存的图像 + path: 保存路径 + create_dir: 是否创建目录 + + Returns: + bool: 是否保存成功 + """ + if not valid_img(image): + logger.warning(f"无效图像,无法保存到 {path}") + return False + + try: + if create_dir: + os.makedirs(os.path.dirname(path), exist_ok=True) + + cv2.imwrite(path, image) + logger.debug(f"图像已保存: {path}") + return True + + except Exception as e: + logger.error(f"保存图像失败 {path}: {e}") + return False + +def load_yaml_config(path): + """ + 加载YAML配置文件 + + Args: + path: 配置文件路径 + + Returns: + dict: 配置字典 + """ + if not YAML_AVAILABLE: + logger.error("无法加载YAML配置: PyYAML未安装") + return {} + + try: + with open(path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + logger.info(f"配置文件加载成功: {path}") + return config if config else {} + except FileNotFoundError: + logger.warning(f"配置文件不存在: {path}") + return {} + except Exception as e: + logger.error(f"加载配置文件失败 {path}: {e}") + return {} + +def save_yaml_config(config, path): + """ + 保存配置到YAML文件 + + Args: + config: 配置字典 + path: 保存路径 + """ + if not YAML_AVAILABLE: + logger.error("无法保存YAML配置: PyYAML未安装") + return + + try: + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w', encoding='utf-8') as f: + yaml.dump(config, f, default_flow_style=False, indent=2) + logger.debug(f"配置已保存: {path}") + except Exception as e: + logger.error(f"保存配置失败 {path}: {e}") + +def draw_bbox(image, bbox, color=(255, 0, 0), thickness=2, label=None): + """ + 在图像上绘制单个边界框 + + Args: + image: 输入图像 + bbox: [x1, y1, x2, y2] 边界框 + color: 颜色 (B, G, R) + thickness: 线宽 + label: 标签文本 + + Returns: + np.ndarray: 绘制后的图像 + """ + if not valid_img(image): + return image + + x1, y1, x2, y2 = map(int, bbox) + + # 检查坐标有效性 + if x1 >= x2 or y1 >= y2: + return image + + # 绘制边界框 + cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness) + + # 绘制标签 + if label: + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.5 + + # 获取文本尺寸 + 
(text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, thickness) + + # 绘制标签背景 + cv2.rectangle(image, (x1, y1 - text_height - 5), + (x1 + text_width, y1), color, -1) + + # 绘制文本 + cv2.putText(image, label, (x1, y1 - 5), + font, font_scale, (255, 255, 255), thickness) + + return image + +def draw_trajectory(image, points, color=(0, 255, 0), thickness=2, max_points=20): + """ + 在图像上绘制轨迹 + + Args: + image: 输入图像 + points: 轨迹点列表 [(x1, y1), (x2, y2), ...] + color: 轨迹颜色 + thickness: 线宽 + max_points: 最大显示点数 + + Returns: + np.ndarray: 绘制后的图像 + """ + if not valid_img(image) or len(points) < 2: + return image + + # 限制轨迹点数量 + points = points[-max_points:] + + # 绘制轨迹线 + for i in range(1, len(points)): + pt1 = (int(points[i-1][0]), int(points[i-1][1])) + pt2 = (int(points[i][0]), int(points[i][1])) + + # 检查点是否有效 + if 0 <= pt1[0] < image.shape[1] and 0 <= pt1[1] < image.shape[0] and \ + 0 <= pt2[0] < image.shape[1] and 0 <= pt2[1] < image.shape[0]: + cv2.line(image, pt1, pt2, color, thickness) + + return image + +def draw_info_panel(image, info_dict, position="top_left"): + """ + 在图像上绘制信息面板 + + Args: + image: 输入图像 + info_dict: 信息字典 {key: value} + position: 位置 ("top_left", "top_right", "bottom_left", "bottom_right") + + Returns: + np.ndarray: 绘制后的图像 + """ + if not valid_img(image): + return image + + h, w = image.shape[:2] + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + thickness = 1 + line_height = 25 + + # 确定起始位置 + if position == "top_left": + x, y = 10, 30 + elif position == "top_right": + x, y = w - 200, 30 + elif position == "bottom_left": + x, y = 10, h - 30 - len(info_dict) * line_height + elif position == "bottom_right": + x, y = w - 200, h - 30 - len(info_dict) * line_height + else: + x, y = 10, 30 + + # 绘制信息背景 + bg_height = len(info_dict) * line_height + 10 + cv2.rectangle(image, (x - 5, y - 25), (x + 190, y + bg_height - 20), (0, 0, 0), -1) + + # 绘制标题 + cv2.putText(image, "SYSTEM INFO", (x, y - 5), font, 0.7, (0, 255, 0), thickness) + + # 
绘制信息项 + for i, (key, value) in enumerate(info_dict.items()): + text = f"{key}: {value}" + cv2.putText(image, text, (x, y + (i + 1) * line_height), + font, font_scale, (255, 255, 255), thickness) + + return image + +def run_self_tests(): + """运行自测试""" + print("=" * 50) + print("运行 utils.py 自测试...") + print("=" * 50) + + tests_passed = 0 + tests_failed = 0 + + # 测试 1: valid_img + try: + test_img = np.zeros((100, 100, 3), dtype=np.uint8) + assert valid_img(test_img) == True, "valid_img应该返回True" + assert valid_img(None) == False, "valid_img(None)应该返回False" + assert valid_img(np.zeros((100, 100), dtype=np.uint8)) == False, "灰度图应该返回False" + print("✅ valid_img测试通过") + tests_passed += 1 + except AssertionError as e: + print(f"❌ valid_img测试失败: {e}") + tests_failed += 1 + + # 测试 2: clip_box + try: + bbox = [10, 10, 200, 200] + clipped = clip_box(bbox, (150, 150)) + expected = [10, 10, 149, 149] # 索引从0开始,所以是149不是150 + assert np.allclose(clipped[:2], expected[:2]), f"clip_box坐标错误: {clipped[:2]} != {expected[:2]}" + assert clipped[2] <= 149 and clipped[3] <= 149, "clip_box应该限制在图像范围内" + print("✅ clip_box测试通过") + tests_passed += 1 + except AssertionError as e: + print(f"❌ clip_box测试失败: {e}") + tests_failed += 1 + + # 测试 3: iou (兼容性版本) + try: + box1 = [0, 0, 10, 10] + box2 = [5, 5, 15, 15] + iou_val = iou(box1, box2) + expected_iou = 25 / (100 + 100 - 25) # (5x5)/(100+100-25) = 25/175 ≈ 0.1429 + assert abs(iou_val - expected_iou) < 0.001, f"iou计算错误: {iou_val} != {expected_iou}" + + # 测试numpy数组版本 + box1_np = np.array(box1, dtype=np.float32) + box2_np = np.array(box2, dtype=np.float32) + iou_val_np = iou_numpy(box1_np, box2_np) + assert abs(iou_val_np - expected_iou) < 0.001, f"iou_numpy计算错误" + + print("✅ iou测试通过") + tests_passed += 1 + except AssertionError as e: + print(f"❌ iou测试失败: {e}") + tests_failed += 1 + + # 测试 4: make_div + try: + assert make_div(100) == 128, "make_div(100)应该返回128" + assert make_div(128) == 128, "make_div(128)应该返回128" + assert make_div(129) == 160, 
"make_div(129)应该返回160" + assert make_div(0, 32) == 0, "make_div(0)应该返回0" + print("✅ make_div测试通过") + tests_passed += 1 + except AssertionError as e: + print(f"❌ make_div测试失败: {e}") + tests_failed += 1 + + # 测试 5: FPSCounter (修复的测试) + try: + fps_counter = FPSCounter(window_size=3) + + # 第一次update会初始化但不会计算FPS(需要至少2个时间点) + fps1 = fps_counter.update() + time.sleep(0.05) # 等待50ms + + # 第二次update才会计算FPS + fps2 = fps_counter.update() + time.sleep(0.05) + + fps3 = fps_counter.update() + + # 现在应该有FPS值了 + assert fps3 > 0, f"FPS应该大于0,当前: {fps3}" + assert fps_counter.fps > 0, f"内部FPS应该大于0" + + print(f"✅ FPSCounter测试通过 (FPS: {fps3:.1f})") + tests_passed += 1 + except Exception as e: + print(f"❌ FPSCounter测试失败: {e}") + import traceback + traceback.print_exc() + tests_failed += 1 + + # 测试 6: 可视化函数 + try: + test_img = np.zeros((100, 100, 3), dtype=np.uint8) + # 测试draw_bbox + result1 = draw_bbox(test_img.copy(), [10, 10, 50, 50], label="test") + assert result1.shape == test_img.shape, "draw_bbox应该返回相同尺寸的图像" + + # 测试draw_trajectory + points = [(20, 20), (30, 30), (40, 40)] + result2 = draw_trajectory(test_img.copy(), points) + assert result2.shape == test_img.shape, "draw_trajectory应该返回相同尺寸的图像" + + # 测试draw_info_panel + info = {"FPS": "30.0", "Objects": "5"} + result3 = draw_info_panel(test_img.copy(), info) + assert result3.shape == test_img.shape, "draw_info_panel应该返回相同尺寸的图像" + + print("✅ 可视化函数测试通过") + tests_passed += 1 + except Exception as e: + print(f"❌ 可视化函数测试失败: {e}") + import traceback + traceback.print_exc() + tests_failed += 1 + + # 测试 7: bbox工具函数 + try: + bbox = [10, 20, 50, 80] + center = bbox_center(bbox) + area = bbox_area(bbox) + aspect = bbox_aspect_ratio(bbox) + + assert center == (30.0, 50.0), f"中心点计算错误: {center}" + assert area == 40 * 60, f"面积计算错误: {area}" + assert abs(aspect - 40/60) < 0.001, f"宽高比计算错误: {aspect}" + + print("✅ bbox工具函数测试通过") + tests_passed += 1 + except Exception as e: + print(f"❌ bbox工具函数测试失败: {e}") + tests_failed += 1 + + print("=" * 50) + 
print(f"测试结果: {tests_passed}通过, {tests_failed}失败") + + if tests_failed == 0: + print("🎉 所有测试通过!") + else: + print("⚠️ 有测试失败,请检查") + + return tests_failed == 0 + +if __name__ == "__main__": + # 运行自测试 + success = run_self_tests() + + if success: + print("\nutils.py 可以安全使用") + sys.exit(0) + else: + print("\n⚠️ utils.py 有测试失败,请修复") + sys.exit(1) \ No newline at end of file diff --git a/src/tracking_car/carla_tracking_ros/setup.py b/src/tracking_car/carla_tracking_ros/setup.py new file mode 100644 index 0000000000..8d4fe919f9 --- /dev/null +++ b/src/tracking_car/carla_tracking_ros/setup.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +from distutils.core import setup +from catkin_pkg.python_setup import generate_distutils_setup + +d = generate_distutils_setup( + packages=[], + package_dir={}, + scripts=[ + 'scripts/main.py', + 'scripts/sensors.py', + 'scripts/tracker.py', + 'scripts/utils.py' + ] +) + +setup(**d) diff --git a/src/tracking_car/config.yaml b/src/tracking_car/config.yaml index 79759fe126..5e1d9064b0 100644 --- a/src/tracking_car/config.yaml +++ b/src/tracking_car/config.yaml @@ -78,4 +78,13 @@ hot_reload_safe_keys: - "yolo_imgsz_max" # YOLO输入尺寸 - "stop_speed_thresh" # 停车速度阈值 - "danger_dist_thresh" # 危险距离阈值 - - "weather" # 天气设置 \ No newline at end of file + - "weather" # 天气设置 + +# ======================== 交通标志检测配置 ======================== +enable_sign_detection: false # 是否启用交通标志检测 + +traffic_sign: + enabled: true # 检测器是否启用 + show_signs: true # 是否显示检测框 + enable_actions: false # 是否触发动作(警告等) + conf_threshold: 0.5 # 置信度阈值 \ No newline at end of file diff --git a/src/tracking_car/main.py b/src/tracking_car/main.py index 06b02aea15..2329ae04c6 100644 --- a/src/tracking_car/main.py +++ b/src/tracking_car/main.py @@ -1102,6 +1102,13 @@ def _draw_info_panel(self, image, track_count): f"L: PointCloud | V: View Mode" # 添加点云提示 ] + status_lines = [ + f"Tracked Objects: {track_count}", + f"ESC: Exit | W: Weather | S: Screenshot", + f"P: Pause | T: Stats | M: Color Legend", + 
f"L: PointCloud | V: View | O: SignDetect", # 修改这一行 + ] + # Draw status information font = cv2.FONT_HERSHEY_SIMPLEX for i, line in enumerate(status_lines): @@ -1312,6 +1319,10 @@ def __init__(self, config): self.image_queue = None self.result_queue = None + # 添加交通标志检测器 + self.traffic_sign_detector = None + self.enable_sign_detection = config.get('enable_sign_detection', False) + logger.info("[OK] Tracking system initialized (Color ID encoding + Independent statistics window)") def initialize(self): @@ -1380,7 +1391,18 @@ def initialize(self): import traceback traceback.print_exc() return False - + + if self.enable_sign_detection: + try: + from sign_detector import TrafficSignDetector + self.traffic_sign_detector = TrafficSignDetector( + config.get('traffic_sign', {}) + ) + logger.info("✅ 交通标志检测器已启用") + except Exception as e: + logger.warning(f"交通标志检测器初始化失败: {e}") + self.traffic_sign_detector = None + def _setup_detection_thread(self): """Setup detection thread""" try: @@ -1784,6 +1806,23 @@ def run(self): if self.frame_count % 100 == 0: self._print_status(stats_data) + # ============ 新增:交通标志检测 ============ + detected_signs = [] + if self.traffic_sign_detector and self.enable_sign_detection: + try: + # 检测标志 + detected_signs = self.traffic_sign_detector.detect( + image, + ego_speed=self.ego_vehicle.get_velocity().length() if self.ego_vehicle else 0.0 + ) + + # 在图像上绘制标志 + if self.traffic_sign_detector.config.get('show_signs', True): + image = self.traffic_sign_detector.draw_signs(image, detected_signs) + + except Exception as e: + logger.debug(f"标志检测失败: {e}") + except KeyboardInterrupt: logger.info("[STOP] User interrupted program") except Exception as e: @@ -1860,6 +1899,12 @@ def _handle_keyboard_input(self, key): # 新增:R键手动重载配置 elif key == ord('r') or key == ord('R'): self._force_reload_config() + + # 添加:O键切换标志检测 + elif key == ord('o') or key == ord('O'): + self.enable_sign_detection = not self.enable_sign_detection + status = "开启" if self.enable_sign_detection else 
"关闭" + logger.info(f"🚦 交通标志检测: {status}") def _control_frame_rate(self, current_fps): """自适应帧率控制(简单版)""" diff --git a/src/tracking_car/sign_detector.py b/src/tracking_car/sign_detector.py new file mode 100644 index 0000000000..bbd8a82f18 --- /dev/null +++ b/src/tracking_car/sign_detector.py @@ -0,0 +1,438 @@ +""" +sign_detector.py - 轻量级交通标志识别模块 +最小化集成,不破坏现有结构 +""" + +import cv2 +import numpy as np +import torch +from typing import List, Dict, Any, Optional +import time + +try: + from loguru import logger +except ImportError: + import logging + logger = logging.getLogger(__name__) + +class TrafficSignDetector: + """轻量级交通标志检测器""" + + def __init__(self, config=None): + """ + 初始化 + Args: + config: 可选配置,默认使用内置简单配置 + """ + # 默认配置 + self.config = config or { + 'enabled': True, + 'conf_threshold': 0.5, + 'show_signs': True, + 'enable_actions': False, # 默认不触发动作,只显示 + } + + # 尝试加载YOLO模型,如果失败则使用简单的颜色检测 + self.model = None + self.use_yolo = False + + try: + from ultralytics import YOLO + # 尝试加载预训练模型(可以用通用物体检测模型) + self.model = YOLO('yolov8n.pt') # 使用现有的YOLO模型 + self.use_yolo = True + logger.info("✅ 使用YOLO进行标志检测") + except Exception as e: + logger.warning(f"无法加载YOLO模型,使用简单颜色检测: {e}") + self._init_simple_detector() + + # 简单的标志颜色检测 + self.sign_colors = { + 'red': { # 停车、禁止类标志 + 'lower': np.array([0, 50, 50]), + 'upper': np.array([10, 255, 255]) + }, + 'blue': { # 指示类标志 + 'lower': np.array([100, 50, 50]), + 'upper': np.array([130, 255, 255]) + }, + 'yellow': { # 警告类标志 + 'lower': np.array([20, 100, 100]), + 'upper': np.array([30, 255, 255]) + } + } + + # 标志形状模板(可选) + self.shape_templates = self._load_shape_templates() + + # 检测历史 + self.detected_signs_history = [] + + logger.info("✅ 轻量级标志检测器初始化完成") + + def _init_simple_detector(self): + """初始化简单检测器""" + # 加载简单的形状模板 + self.templates = { + 'triangle': self._create_triangle_mask(), + 'circle': self._create_circle_mask(), + 'octagon': self._create_octagon_mask(), # 停车标志 + 'square': self._create_square_mask() + } + + def detect(self, 
image: np.ndarray, ego_speed: float = 0.0) -> List[Dict]: + """ + 检测图像中的交通标志 + + Args: + image: 输入图像 + ego_speed: 自车速度(用于距离估计) + + Returns: + List[Dict]: 检测到的标志列表 + """ + if not self.config.get('enabled', True): + return [] + + signs = [] + + try: + # 方法1:如果YOLO可用,使用YOLO检测 + if self.use_yolo and self.model is not None: + signs = self._detect_with_yolo(image, ego_speed) + + # 方法2:否则使用简单的颜色+形状检测 + else: + signs = self._detect_with_color(image, ego_speed) + + # 过滤重复的标志(简单的NMS) + signs = self._non_max_suppression(signs) + + # 更新历史 + if signs: + self.detected_signs_history = signs[:10] # 保留最近10个 + + # 简单的动作触发(可选的) + if self.config.get('enable_actions', False): + self._trigger_actions(signs, ego_speed) + + return signs + + except Exception as e: + logger.error(f"标志检测失败: {e}") + return [] + + def _detect_with_yolo(self, image: np.ndarray, ego_speed: float) -> List[Dict]: + """使用YOLO检测""" + results = self.model.predict( + image, + conf=self.config.get('conf_threshold', 0.5), + verbose=False + ) + + signs = [] + + for result in results: + if result.boxes is not None: + for box in result.boxes: + # 过滤出可能是交通标志的类别(YOLO COCO数据集中的类别) + class_id = int(box.cls[0]) + class_name = result.names[class_id] + + # 只保留相关类别 + if class_name in ['stop sign', 'traffic light', 'parking meter']: + bbox = box.xyxy[0].cpu().numpy() + confidence = float(box.conf[0]) + + # 推断标志类型 + sign_type = self._infer_sign_type(class_name, bbox, image) + + sign_info = { + 'bbox': bbox.tolist(), + 'confidence': confidence, + 'type': sign_type, + 'class_id': class_id, + 'class_name': class_name, + 'timestamp': time.time() + } + + # 估算距离(基于边界框大小) + sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed) + + signs.append(sign_info) + + return signs + + def _detect_with_color(self, image: np.ndarray, ego_speed: float) -> List[Dict]: + """使用颜色检测""" + # 转换为HSV颜色空间 + hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + signs = [] + + # 检测每种标志颜色 + for color_name, color_range in 
self.sign_colors.items(): + # 创建颜色掩码 + mask = cv2.inRange(hsv, color_range['lower'], color_range['upper']) + + # 形态学操作去除噪声 + kernel = np.ones((5, 5), np.uint8) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + + # 查找轮廓 + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours: + area = cv2.contourArea(contour) + + # 过滤太小的区域 + if area < 100: + continue + + # 获取边界框 + x, y, w, h = cv2.boundingRect(contour) + + # 计算形状特征 + shape = self._detect_shape(contour) + + # 推断标志类型 + sign_type = self._infer_sign_type_from_color(color_name, shape) + + if sign_type: + sign_info = { + 'bbox': [x, y, x + w, y + h], + 'confidence': min(0.9, area / 1000.0), # 简单置信度 + 'type': sign_type, + 'color': color_name, + 'shape': shape, + 'timestamp': time.time() + } + + # 估算距离 + sign_info['distance'] = self._estimate_distance(sign_info['bbox'], ego_speed) + + signs.append(sign_info) + + return signs + + def _infer_sign_type(self, class_name: str, bbox: np.ndarray, image: np.ndarray) -> str: + """根据检测结果推断标志类型""" + type_map = { + 'stop sign': 'stop', + 'traffic light': 'traffic_light', + 'parking meter': 'parking', + } + + # 如果是stop sign,进一步确认(检查是否是八边形) + if class_name == 'stop sign': + # 提取ROI检查形状 + x1, y1, x2, y2 = map(int, bbox) + roi = image[y1:y2, x1:x2] + if self._is_octagon_shape(roi): + return 'stop' + + return type_map.get(class_name, 'unknown') + + def _infer_sign_type_from_color(self, color: str, shape: str) -> Optional[str]: + """根据颜色和形状推断标志类型""" + # 简单的推断规则 + if color == 'red' and shape == 'octagon': + return 'stop' + elif color == 'red' and shape == 'circle': + return 'no_entry' + elif color == 'yellow' and shape == 'triangle': + return 'warning' + elif color == 'blue' and shape == 'circle': + return 'mandatory' + elif color == 'blue' and shape == 'square': + return 'information' + + return None + + def _detect_shape(self, contour) -> str: + """检测轮廓形状""" + approx = 
cv2.approxPolyDP(contour, 0.04 * cv2.arcLength(contour, True), True) + num_sides = len(approx) + + if num_sides == 3: + return 'triangle' + elif num_sides == 4: + # 判断是正方形还是长方形 + x, y, w, h = cv2.boundingRect(contour) + aspect_ratio = float(w) / h + if 0.8 <= aspect_ratio <= 1.2: + return 'square' + else: + return 'rectangle' + elif 8 <= num_sides <= 12: # 八边形 + return 'octagon' + else: + # 计算圆形度 + area = cv2.contourArea(contour) + perimeter = cv2.arcLength(contour, True) + circularity = 4 * np.pi * area / (perimeter * perimeter) + if circularity > 0.7: + return 'circle' + + return 'unknown' + + def _estimate_distance(self, bbox: List[float], ego_speed: float) -> float: + """估算标志距离(简化版本)""" + # 基于边界框高度估算距离 + # 假设标志的标准高度为0.5米,焦距为1000像素 + x1, y1, x2, y2 = bbox + bbox_height = y2 - y1 + + if bbox_height <= 0: + return 100.0 # 默认远距离 + + # 简化距离公式:距离 ∝ 1/高度 + distance = 500.0 / bbox_height + + # 根据速度调整(运动模糊) + if ego_speed > 10: + distance *= (1 + ego_speed / 100.0) + + return min(distance, 100.0) # 最大100米 + + def _non_max_suppression(self, signs: List[Dict], iou_threshold: float = 0.5) -> List[Dict]: + """简单的非极大值抑制""" + if len(signs) <= 1: + return signs + + # 按置信度排序 + sorted_signs = sorted(signs, key=lambda x: x['confidence'], reverse=True) + keep = [] + + while sorted_signs: + # 取置信度最高的 + best = sorted_signs.pop(0) + keep.append(best) + + # 移除与best重叠度高的 + sorted_signs = [ + sign for sign in sorted_signs + if self._iou(best['bbox'], sign['bbox']) < iou_threshold + ] + + return keep + + def _iou(self, box1: List[float], box2: List[float]) -> float: + """计算IoU""" + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + inter_area = max(0, x2 - x1) * max(0, y2 - y1) + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + + union_area = box1_area + box2_area - inter_area + + return inter_area / union_area if union_area > 0 else 0.0 + + def 
_trigger_actions(self, signs: List[Dict], ego_speed: float): + """触发简单动作(可选)""" + for sign in signs: + sign_type = sign.get('type', '') + distance = sign.get('distance', 100.0) + + if sign_type == 'stop' and distance < 20: + logger.warning(f"🛑 前方 {distance:.1f}米有停车标志") + + elif 'speed_limit' in sign_type and distance < 30: + try: + limit = int(sign_type.split('_')[-1]) + if ego_speed * 3.6 > limit + 5: # m/s转km/h + logger.warning(f"📏 前方限速{limit}km/h,当前{ego_speed*3.6:.0f}km/h") + except: + pass + + def draw_signs(self, image: np.ndarray, signs: List[Dict]) -> np.ndarray: + """在图像上绘制检测到的标志""" + if not self.config.get('show_signs', True): + return image + + result = image.copy() + + for sign in signs: + bbox = sign['bbox'] + sign_type = sign.get('type', 'unknown') + confidence = sign.get('confidence', 0.0) + distance = sign.get('distance', 0.0) + + x1, y1, x2, y2 = map(int, bbox) + + # 根据标志类型选择颜色 + color_map = { + 'stop': (0, 0, 255), # 红色 + 'warning': (0, 165, 255), # 橙色 + 'traffic_light': (0, 255, 0), # 绿色 + 'speed_limit': (0, 255, 255), # 黄色 + 'unknown': (128, 128, 128) # 灰色 + } + + color = color_map.get(sign_type, (128, 128, 128)) + + # 绘制边界框 + cv2.rectangle(result, (x1, y1), (x2, y2), color, 2) + + # 绘制标签 + label = f"{sign_type} {confidence:.2f}" + if distance > 0: + label += f" {distance:.1f}m" + + cv2.putText(result, label, (x1, y1 - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) + + return result + + def get_signs_info(self) -> Dict: + """获取标志检测统计信息""" + return { + 'total_signs': len(self.detected_signs_history), + 'recent_signs': self.detected_signs_history[-5:] if self.detected_signs_history else [], + 'enabled': self.config.get('enabled', True) + } + + +def test_detector(): + """测试检测器""" + detector = TrafficSignDetector() + + # 测试图片 + test_image = np.zeros((480, 640, 3), dtype=np.uint8) + + # 模拟一个红色八边形(停车标志) + center = (320, 240) + radius = 50 + points = [] + for i in range(8): + angle = 2 * np.pi * i / 8 + x = center[0] + radius * np.cos(angle) + y = 
center[1] + radius * np.sin(angle) + points.append((int(x), int(y))) + + cv2.fillPoly(test_image, [np.array(points)], (0, 0, 255)) + + # 检测 + signs = detector.detect(test_image) + + print(f"检测到 {len(signs)} 个标志") + for sign in signs: + print(f" 类型: {sign.get('type')}, 置信度: {sign.get('confidence'):.2f}") + + # 绘制 + result = detector.draw_signs(test_image, signs) + + return len(signs) > 0 + + +if __name__ == "__main__": + success = test_detector() + if success: + print("✅ 标志检测器测试通过") + else: + print("⚠️ 测试未检测到标志") \ No newline at end of file