Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 92 additions & 14 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
#define DEFAULT_DTC 50.0f // Ignore candidates beyond this range
#define STOP_LINE_EXTENSION_FACTOR 1.5f
#define RED_LIGHT_HEADING_THRESHOLD (M_PI / 4.0f)
#define NUPLAN_BEHIND_COS_THRESHOLD -0.8660254f // cos(150 degrees)

// TTC default value when no vehicle ahead
#define DEFAULT_TTC 5.0f
Expand All @@ -100,6 +101,14 @@
#define MULTI_LANE_FULL_SCORE_TIME 3.4f // seconds
#define MULTI_LANE_HALF_SCORE_TIME 5.7f // seconds

// nuPlan-style collision classification
#define COLLISION_TYPE_NONE 0
#define COLLISION_TYPE_STOPPED_EGO 1
#define COLLISION_TYPE_STOPPED_TRACK 2
#define COLLISION_TYPE_ACTIVE_REAR 3
#define COLLISION_TYPE_ACTIVE_FRONT 4
#define COLLISION_TYPE_ACTIVE_LATERAL 5

// Collision/Infraction behaviors
#define STOP_AGENT 1
#define REMOVE_AGENT 2
Expand Down Expand Up @@ -2007,7 +2016,7 @@ static bool check_line_intersection(float p1[2], float p2[2], float q1[2], float
return (s >= 0 && s <= 1 && t >= 0 && t <= 1);
}

static void compute_agent_corners(Agent *agent, float corners[4][2]) {
static void compute_agent_corners(const Agent *agent, float corners[4][2]) {
static const float offsets[4][2] = {{1, 1}, {1, -1}, {-1, -1}, {-1, 1}};
float half_length = agent->sim_length / 2.0f;
float half_width = agent->sim_width / 2.0f;
Expand All @@ -2020,6 +2029,37 @@ static void compute_agent_corners(Agent *agent, float corners[4][2]) {
}
}

// helpers for nuPlan's at-fault collision, you need to know if the front bumper was involed
static bool point_inside_agent_box(float x, float y, const Agent *agent) {
float dx = x - agent->sim_x;
float dy = y - agent->sim_y;
float local_long = dx * agent->cos_heading + dy * agent->sin_heading;
float local_lat = -dx * agent->sin_heading + dy * agent->cos_heading;
return fabsf(local_long) <= 0.5f * agent->sim_length && fabsf(local_lat) <= 0.5f * agent->sim_width;
}

static bool segment_intersects_agent_box(float p1[2], float p2[2], const Agent *agent) {
if (point_inside_agent_box(p1[0], p1[1], agent) || point_inside_agent_box(p2[0], p2[1], agent)) {
return true;
}

float corners[4][2];
compute_agent_corners(agent, corners);
for (int i = 0; i < 4; i++) {
int next = (i + 1) % 4;
if (check_line_intersection(p1, p2, corners[i], corners[next])) {
return true;
}
}
return false;
}

static bool front_bumper_intersects_agent(const Agent *ego, const Agent *other) {
float corners[4][2];
compute_agent_corners(ego, corners);
return segment_intersects_agent_box(corners[0], corners[1], other);
}

static bool check_agent_corners_cross_stop_line(float corners[4][2], TrafficControlElement *traffic) {
float sl_dx = traffic->stop_line[3] - traffic->stop_line[0];
float sl_dy = traffic->stop_line[4] - traffic->stop_line[1];
Expand Down Expand Up @@ -2280,30 +2320,68 @@ static int collision_check(Drive *env, int agent_idx) {
return car_collided_with_index;
}

static bool is_agent_behind_rear_axle(const Agent *ego, const Agent *other) {
float rear_x = ego->sim_x - 0.5f * ego->sim_length * ego->cos_heading;
float rear_y = ego->sim_y - 0.5f * ego->sim_length * ego->sin_heading;
float dx = other->sim_x - rear_x;
float dy = other->sim_y - rear_y;
float dist = sqrtf(dx * dx + dy * dy);
if (dist < 1e-6f) {
return false;
}

float dot = (ego->cos_heading * dx + ego->sin_heading * dy) / dist;
return dot < NUPLAN_BEHIND_COS_THRESHOLD;
}

static bool is_agent_cleanly_in_lane(const Agent *agent) {
if (agent->current_lane_idx == -1) {
return false;
}

float edge_dist = fabsf(agent->metrics_array[LANE_DIST_IDX]) + 0.5f * agent->sim_width;
return edge_dist <= MULTI_LANE_THRESHOLD;
}

static int classify_collision_type(const Agent *ego, const Agent *other) {
if (ego->sim_speed <= AGENT_STOPPED_SPEED_THRESHOLD) {
return COLLISION_TYPE_STOPPED_EGO;
}

if (other->sim_speed <= AGENT_STOPPED_SPEED_THRESHOLD) {
return COLLISION_TYPE_STOPPED_TRACK;
}

if (is_agent_behind_rear_axle(ego, other)) {
return COLLISION_TYPE_ACTIVE_REAR;
}

if (front_bumper_intersects_agent(ego, other)) {
return COLLISION_TYPE_ACTIVE_FRONT;
}

return COLLISION_TYPE_ACTIVE_LATERAL;
}

// Classify whether a collision is at-fault for the ego agent.
static bool is_at_fault_collision(Drive *env, int agent_idx, int other_idx) {
Agent *agent = &env->agents[agent_idx];
Agent *other = &env->agents[other_idx];
int collision_type = classify_collision_type(agent, other);

Comment on lines 2367 to 2371
// Rule 1: Collision with stopped vehicle = always at-fault.
if (other->sim_speed < AGENT_STOPPED_SPEED_THRESHOLD) {
if (collision_type == COLLISION_TYPE_STOPPED_TRACK) {
return true;
}
// Rule 2: If ego is stopped = never at-fault.
if (agent->sim_speed < AGENT_STOPPED_SPEED_THRESHOLD) {
return false;

if (collision_type == COLLISION_TYPE_ACTIVE_FRONT) {
return true;
}

// Rule 3: Rear-bumper collision = not at-fault.
// Check if the other car hit our rear using the heading-aligned relative position.
float dx = other->sim_x - agent->sim_x;
float dy = other->sim_y - agent->sim_y;
float dot = dx * agent->cos_heading + dy * agent->sin_heading;
if (dot < 0) {
return false;
if (collision_type == COLLISION_TYPE_ACTIVE_LATERAL) {
return !is_agent_cleanly_in_lane(agent);
}

return true;
return false;
}

static inline void ttc_update_min_result(Agent *ego, int other_idx, float closing_speed, float ttc) {
Expand Down
Loading