diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 329e1bdf76..a95bae7885 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -86,6 +86,7 @@ #define DEFAULT_DTC 50.0f // Ignore candidates beyond this range #define STOP_LINE_EXTENSION_FACTOR 1.5f #define RED_LIGHT_HEADING_THRESHOLD (M_PI / 4.0f) +#define NUPLAN_BEHIND_COS_THRESHOLD -0.8660254f // cos(150 degrees) // TTC default value when no vehicle ahead #define DEFAULT_TTC 5.0f @@ -100,6 +101,14 @@ #define MULTI_LANE_FULL_SCORE_TIME 3.4f // seconds #define MULTI_LANE_HALF_SCORE_TIME 5.7f // seconds +// nuPlan-style collision classification +#define COLLISION_TYPE_NONE 0 +#define COLLISION_TYPE_STOPPED_EGO 1 +#define COLLISION_TYPE_STOPPED_TRACK 2 +#define COLLISION_TYPE_ACTIVE_REAR 3 +#define COLLISION_TYPE_ACTIVE_FRONT 4 +#define COLLISION_TYPE_ACTIVE_LATERAL 5 + // Collision/Infraction behaviors #define STOP_AGENT 1 #define REMOVE_AGENT 2 @@ -2007,7 +2016,7 @@ static bool check_line_intersection(float p1[2], float p2[2], float q1[2], float return (s >= 0 && s <= 1 && t >= 0 && t <= 1); } -static void compute_agent_corners(Agent *agent, float corners[4][2]) { +static void compute_agent_corners(const Agent *agent, float corners[4][2]) { static const float offsets[4][2] = {{1, 1}, {1, -1}, {-1, -1}, {-1, 1}}; float half_length = agent->sim_length / 2.0f; float half_width = agent->sim_width / 2.0f; @@ -2020,6 +2029,37 @@ static void compute_agent_corners(Agent *agent, float corners[4][2]) { } } +// helpers for nuPlan's at-fault collision, you need to know if the front bumper was involed +static bool point_inside_agent_box(float x, float y, const Agent *agent) { + float dx = x - agent->sim_x; + float dy = y - agent->sim_y; + float local_long = dx * agent->cos_heading + dy * agent->sin_heading; + float local_lat = -dx * agent->sin_heading + dy * agent->cos_heading; + return fabsf(local_long) <= 0.5f * agent->sim_length && fabsf(local_lat) <= 0.5f * agent->sim_width; +} + +static bool segment_intersects_agent_box(float p1[2], float p2[2], const Agent *agent) { + if (point_inside_agent_box(p1[0], p1[1], agent) || point_inside_agent_box(p2[0], p2[1], agent)) { + return true; + } + + float corners[4][2]; + compute_agent_corners(agent, corners); + for (int i = 0; i < 4; i++) { + int next = (i + 1) % 4; + if (check_line_intersection(p1, p2, corners[i], corners[next])) { + return true; + } + } + return false; +} + +static bool front_bumper_intersects_agent(const Agent *ego, const Agent *other) { + float corners[4][2]; + compute_agent_corners(ego, corners); + return segment_intersects_agent_box(corners[0], corners[1], other); +} + static bool check_agent_corners_cross_stop_line(float corners[4][2], TrafficControlElement *traffic) { float sl_dx = traffic->stop_line[3] - traffic->stop_line[0]; float sl_dy = traffic->stop_line[4] - traffic->stop_line[1]; @@ -2280,30 +2320,68 @@ static int collision_check(Drive *env, int agent_idx) { return car_collided_with_index; } +static bool is_agent_behind_rear_axle(const Agent *ego, const Agent *other) { + float rear_x = ego->sim_x - 0.5f * ego->sim_length * ego->cos_heading; + float rear_y = ego->sim_y - 0.5f * ego->sim_length * ego->sin_heading; + float dx = other->sim_x - rear_x; + float dy = other->sim_y - rear_y; + float dist = sqrtf(dx * dx + dy * dy); + if (dist < 1e-6f) { + return false; + } + + float dot = (ego->cos_heading * dx + ego->sin_heading * dy) / dist; + return dot < NUPLAN_BEHIND_COS_THRESHOLD; +} + +static bool is_agent_cleanly_in_lane(const Agent *agent) { + if (agent->current_lane_idx == -1) { + return false; + } + + float edge_dist = fabsf(agent->metrics_array[LANE_DIST_IDX]) + 0.5f * agent->sim_width; + return edge_dist <= MULTI_LANE_THRESHOLD; +} + +static int classify_collision_type(const Agent *ego, const Agent *other) { + if (ego->sim_speed <= AGENT_STOPPED_SPEED_THRESHOLD) { + return COLLISION_TYPE_STOPPED_EGO; + } + + if (other->sim_speed <= AGENT_STOPPED_SPEED_THRESHOLD) { + return COLLISION_TYPE_STOPPED_TRACK; + } + + if (is_agent_behind_rear_axle(ego, other)) { + return COLLISION_TYPE_ACTIVE_REAR; + } + + if (front_bumper_intersects_agent(ego, other)) { + return COLLISION_TYPE_ACTIVE_FRONT; + } + + return COLLISION_TYPE_ACTIVE_LATERAL; +} + // Classify whether a collision is at-fault for the ego agent. static bool is_at_fault_collision(Drive *env, int agent_idx, int other_idx) { Agent *agent = &env->agents[agent_idx]; Agent *other = &env->agents[other_idx]; + int collision_type = classify_collision_type(agent, other); - // Rule 1: Collision with stopped vehicle = always at-fault. - if (other->sim_speed < AGENT_STOPPED_SPEED_THRESHOLD) { + if (collision_type == COLLISION_TYPE_STOPPED_TRACK) { return true; } - // Rule 2: If ego is stopped = never at-fault. - if (agent->sim_speed < AGENT_STOPPED_SPEED_THRESHOLD) { - return false; + + if (collision_type == COLLISION_TYPE_ACTIVE_FRONT) { + return true; } - // Rule 3: Rear-bumper collision = not at-fault. - // Check if the other car hit our rear using the heading-aligned relative position. - float dx = other->sim_x - agent->sim_x; - float dy = other->sim_y - agent->sim_y; - float dot = dx * agent->cos_heading + dy * agent->sin_heading; - if (dot < 0) { - return false; + if (collision_type == COLLISION_TYPE_ACTIVE_LATERAL) { + return !is_agent_cleanly_in_lane(agent); } - return true; + return false; } static inline void ttc_update_min_result(Agent *ego, int other_idx, float closing_speed, float ttc) {