From 566f3cebdcee507ecaf94699134e4fde378aedbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wa=C3=ABl=20Doulazmi?= Date: Sun, 24 May 2026 02:54:25 +0200 Subject: [PATCH 1/3] Introduce a bid refacto around the notion of controller --- pufferlib/config/ocean/drive.ini | 9 ++++ pufferlib/ocean/drive/binding.c | 23 +++++++++++ pufferlib/ocean/drive/datatypes.h | 1 + pufferlib/ocean/drive/drive.h | 69 +++++++++++++++++++++++++++---- pufferlib/ocean/drive/drive.py | 39 +++++++++++++++++ pufferlib/ocean/env_binding.h | 3 ++ 6 files changed, 136 insertions(+), 8 deletions(-) diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 571b2ea582..1a32d80849 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -66,6 +66,15 @@ inactive_agent_threshold = 0.4 init_steps = 0 ; options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only" control_mode = "control_vehicles" +; Controller used by agent 0, the canonical SDC/target. +; options: "static", "policy", "replay" +sdc_controller = "policy" +; Controller used by non-SDC vehicles. +; options: "static", "policy", "replay" +non_sdc_controller = "policy" +; Controller used by non-vehicle agents. "auto" follows non_sdc_controller. +; options: "auto", "static", "policy", "replay" +non_vehicle_controller = "auto" ; Options: "created_all_valid", "create_only_controlled" init_mode = "create_all_valid" ; Enable computation of evaluation-only metrics diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index f67243a32d..c830249375 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -770,6 +770,20 @@ static PyObject *my_get(PyObject *dict, Env *env) { } Py_DECREF(tmp); + tmp = PyLong_FromLong(a->controller); + if (!tmp) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "controller", tmp) < 0) { + Py_DECREF(tmp); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(tmp); + tmp = PyLong_FromLong(a->mark_as_expert); if (!tmp) { Py_DECREF(agent); @@ -1569,6 +1583,9 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) { int s_map_counter = starting_map_counter; int init_mode = unpack(kwargs, "init_mode"); int control_mode = unpack(kwargs, "control_mode"); + int sdc_controller = unpack(kwargs, "sdc_controller"); + int non_sdc_controller = unpack(kwargs, "non_sdc_controller"); + int non_vehicle_controller = unpack(kwargs, "non_vehicle_controller"); int simulation_mode = unpack(kwargs, "simulation_mode"); int init_steps = unpack(kwargs, "init_steps"); int seed = unpack(kwargs, "seed"); @@ -1705,6 +1722,9 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) { Drive *env = calloc(1, sizeof(Drive)); env->init_mode = init_mode; env->control_mode = control_mode; + env->sdc_controller = sdc_controller; + env->non_sdc_controller = non_sdc_controller; + env->non_vehicle_controller = non_vehicle_controller; env->simulation_mode = simulation_mode; env->init_steps = init_steps; env->num_max_agents = max_agents_per_env; @@ -1827,6 +1847,9 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->timestep = init_steps; env->init_mode = (int) unpack(kwargs, "init_mode"); env->control_mode = (int) unpack(kwargs, "control_mode"); + env->sdc_controller = (int) unpack(kwargs, "sdc_controller"); + env->non_sdc_controller = (int) unpack(kwargs, "non_sdc_controller"); + env->non_vehicle_controller = (int) unpack(kwargs, "non_vehicle_controller"); env->simulation_mode = (int) unpack(kwargs, "simulation_mode"); env->reward_conditioning = (bool) unpack(kwargs, "reward_conditioning"); env->reward_randomization = (bool) unpack(kwargs, "reward_randomization"); diff --git a/pufferlib/ocean/drive/datatypes.h b/pufferlib/ocean/drive/datatypes.h index 2a391fb744..6a6f964506 100644 --- a/pufferlib/ocean/drive/datatypes.h +++ b/pufferlib/ocean/drive/datatypes.h @@ -223,6 +223,7 @@ struct Agent { int num_goals_reached; int active_agent; int mark_as_expert; + int controller; float cumulative_displacement; int displacement_sample_count; float path_progression; diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 329e1bdf76..f47f9f20a8 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -64,6 +64,11 @@ #define CONTROL_WOSAC 2 #define CONTROL_SDC_ONLY 3 +// Controller modes +#define CONTROLLER_STATIC 0 +#define CONTROLLER_POLICY 1 +#define CONTROLLER_REPLAY 2 + // Simulation modes #define SIMULATION_GIGAFLOW 0 #define SIMULATION_REPLAY 1 @@ -370,6 +375,9 @@ struct Drive { int *tracks_to_predict; int init_mode; int control_mode; + int sdc_controller; + int non_sdc_controller; + int non_vehicle_controller; int simulation_mode; int termination_mode; float inactive_agent_threshold; @@ -3136,6 +3144,28 @@ static bool should_control_agent(Drive *env, int agent_idx) { return agent->route_length != 0; } +static int resolve_agent_controller(Drive *env, int agent_idx, int is_active, int replay_by_default) { + if (replay_by_default) { + return CONTROLLER_REPLAY; + } + + Agent *agent = &env->agents[agent_idx]; + int requested_controller = CONTROLLER_STATIC; + if (agent_idx == EGO_IDX) { + requested_controller = env->sdc_controller; + } else if (agent->type == VEHICLE) { + requested_controller = env->non_sdc_controller; + } else { + requested_controller = env->non_vehicle_controller; + } + + if (requested_controller == CONTROLLER_POLICY && !is_active) { + return CONTROLLER_STATIC; + } + + return requested_controller; +} + void set_active_agents(Drive *env) { // Initialize env->active_agent_count = 0; // Policy-controlled agents @@ -3170,6 +3200,8 @@ void set_active_agents(Drive *env) { for (int i = 0; i < successfully_created; i++) { env->active_agent_indices[i] = active_agent_indices[i]; + env->agents[active_agent_indices[i]].controller + = resolve_agent_controller(env, active_agent_indices[i], 1, 0); } free(active_agent_indices); @@ -3222,12 +3254,15 @@ void set_active_agents(Drive *env) { active_agent_indices[env->active_agent_count] = i; env->active_agent_count++; env->agents[i].active_agent = 1; + env->agents[i].controller = resolve_agent_controller(env, i, 1, 0); } else if (is_log_replay || env->init_mode != INIT_ONLY_CONTROLLABLE_AGENTS) { - // In log-replay mode, all non-controlled agents become expert_static static_agent_indices[env->static_agent_count] = i; env->static_agent_count++; env->agents[i].active_agent = 0; - if (is_log_replay || env->agents[i].mark_as_expert == 1 || env->active_agent_count == env->num_max_agents) { + int replay_by_default + = is_log_replay || env->agents[i].mark_as_expert == 1 || env->active_agent_count == env->num_max_agents; + env->agents[i].controller = resolve_agent_controller(env, i, 0, replay_by_default); + if (env->agents[i].controller == CONTROLLER_REPLAY) { expert_static_agent_indices[env->expert_static_agent_count] = i; env->expert_static_agent_count++; env->agents[i].mark_as_expert = 1; @@ -4895,6 +4930,25 @@ static void move_dynamics(Drive *env, int action_idx, int agent_idx) { return; } +static void move_agent_with_controller(Drive *env, int action_idx, int agent_idx) { + Agent *agent = &env->agents[agent_idx]; + + if (agent->controller == CONTROLLER_STATIC) { + return; + } + + if (agent->controller == CONTROLLER_REPLAY) { + if (env->simulation_mode == SIMULATION_REPLAY) { + move_expert(env, env->actions, agent_idx); + } + return; + } + + if (agent->controller == CONTROLLER_POLICY && action_idx >= 0) { + move_dynamics(env, action_idx, agent_idx); + } +} + static inline void sample_erratic_flags(Drive *env, Agent *agent) { agent->is_blind_partner = (env->partner_blindness_prob > 0.0f && random_uniform(0.0f, 1.0f) < env->partner_blindness_prob) ? 1 : 0; @@ -5025,18 +5079,17 @@ void c_step(Drive *env) { env->timestep++; // -> 1. Apply actions and move agents - // Move static experts - for (int i = 0; i < env->expert_static_agent_count; i++) { - int expert_idx = env->expert_static_agent_indices[i]; - move_expert(env, env->actions, expert_idx); + // Move background agents according to their per-agent controller. + for (int i = 0; i < env->static_agent_count; i++) { + int background_idx = env->static_agent_indices[i]; + move_agent_with_controller(env, -1, background_idx); } // Move active agents with policy actions for (int i = 0; i < env->active_agent_count; i++) { env->logs[i].score = 0.0f; env->logs[i].episode_length += 1; int agent_idx = env->active_agent_indices[i]; - move_dynamics(env, i, agent_idx); - // move_expert(env, env->actions, agent_idx); + move_agent_with_controller(env, i, agent_idx); } // -> 2. Compute metrics and rewards diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index d49563ebf6..951e7f789b 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -75,6 +75,9 @@ def __init__( num_eval_scenarios=16, init_mode="create_all_valid", control_mode="control_vehicles", + sdc_controller="policy", + non_sdc_controller="policy", + non_vehicle_controller="auto", map_dir=None, target_type="static", reward_conditioning=False, @@ -234,6 +237,9 @@ def __init__( self.init_steps = init_steps self.init_mode_str = init_mode self.control_mode_str = control_mode + self.sdc_controller_str = sdc_controller + self.non_sdc_controller_str = non_sdc_controller + self.non_vehicle_controller_str = non_vehicle_controller self.simulation_mode_str = simulation_mode self.map_dir = map_dir self.map_files = sorted(os.path.join(map_dir, f) for f in os.listdir(map_dir) if f.endswith(".bin")) @@ -258,6 +264,30 @@ def __init__( "control_mode must be one of 'control_vehicles', 'control_agents', 'control_wosac', or " f"'control_sdc_only'. Got: {self.control_mode_str}" ) + + controller_values = { + "static": binding.CONTROLLER_STATIC, + "policy": binding.CONTROLLER_POLICY, + "replay": binding.CONTROLLER_REPLAY, + } + controller_options = "'static', 'policy', or 'replay'" + if self.sdc_controller_str not in controller_values: + raise ValueError(f"sdc_controller must be one of {controller_options}. Got: {self.sdc_controller_str}") + if self.non_sdc_controller_str not in controller_values: + raise ValueError( + f"non_sdc_controller must be one of {controller_options}. Got: {self.non_sdc_controller_str}" + ) + if self.non_vehicle_controller_str == "auto": + self.non_vehicle_controller_str = self.non_sdc_controller_str + elif self.non_vehicle_controller_str not in controller_values: + raise ValueError( + f"non_vehicle_controller must be 'auto' or one of {controller_options}. " + f"Got: {self.non_vehicle_controller_str}" + ) + self.sdc_controller = controller_values[self.sdc_controller_str] + self.non_sdc_controller = controller_values[self.non_sdc_controller_str] + self.non_vehicle_controller = controller_values[self.non_vehicle_controller_str] + if self.init_mode_str == "create_all_valid": self.init_mode = 0 elif self.init_mode_str == "create_only_controlled": @@ -315,6 +345,9 @@ def __init__( eval_mode=self.eval_mode, init_mode=self.init_mode, control_mode=self.control_mode, + sdc_controller=self.sdc_controller, + non_sdc_controller=self.non_sdc_controller, + non_vehicle_controller=self.non_vehicle_controller, simulation_mode=self.simulation_mode, init_steps=self.init_steps, seed=self.random_seed, @@ -404,6 +437,9 @@ def _env_init_kwargs(self, map_file, max_agents): "init_steps": self.init_steps, "init_mode": self.init_mode, "control_mode": self.control_mode, + "sdc_controller": self.sdc_controller, + "non_sdc_controller": self.non_sdc_controller, + "non_vehicle_controller": self.non_vehicle_controller, "simulation_mode": self.simulation_mode, "reward_conditioning": self.reward_conditioning, "reward_randomization": self.reward_randomization, @@ -493,6 +529,9 @@ def step(self, actions): eval_mode=self.eval_mode, init_mode=self.init_mode, control_mode=self.control_mode, + sdc_controller=self.sdc_controller, + non_sdc_controller=self.non_sdc_controller, + non_vehicle_controller=self.non_vehicle_controller, simulation_mode=self.simulation_mode, init_steps=self.init_steps, map_files=self.map_files, diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 4822895dfb..e91d4073c3 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -1268,6 +1268,9 @@ PyMODINIT_FUNC PyInit_binding(void) { PyModule_AddIntConstant(m, "NUM_REWARD_COEFS", NUM_REWARD_COEFS); PyModule_AddIntConstant(m, "TARGET_STATIC", TARGET_STATIC); PyModule_AddIntConstant(m, "TARGET_DYNAMIC", TARGET_DYNAMIC); + PyModule_AddIntConstant(m, "CONTROLLER_STATIC", CONTROLLER_STATIC); + PyModule_AddIntConstant(m, "CONTROLLER_POLICY", CONTROLLER_POLICY); + PyModule_AddIntConstant(m, "CONTROLLER_REPLAY", CONTROLLER_REPLAY); PyObject_SetAttrString(m, "MULTI_LANE_FULL_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_FULL_SCORE_TIME)); PyObject_SetAttrString(m, "MULTI_LANE_HALF_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_HALF_SCORE_TIME)); From c53b9487197a9b44d6947dcb5a709d17a21fb5d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wa=C3=ABl=20Doulazmi?= Date: Sun, 24 May 2026 03:02:44 +0200 Subject: [PATCH 2/3] Add IDM implementation --- pufferlib/config/ocean/drive.ini | 8 +- pufferlib/ocean/drive/drive.h | 8 + pufferlib/ocean/drive/drive.py | 8 +- pufferlib/ocean/drive/idm.h | 619 +++++++++++++++++++++++++++++++ pufferlib/ocean/env_binding.h | 1 + 5 files changed, 638 insertions(+), 6 deletions(-) create mode 100644 pufferlib/ocean/drive/idm.h diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 1a32d80849..fee7e530ee 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -67,13 +67,13 @@ init_steps = 0 ; options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only" control_mode = "control_vehicles" ; Controller used by agent 0, the canonical SDC/target. -; options: "static", "policy", "replay" +; options: "static", "policy", "replay", "idm" sdc_controller = "policy" ; Controller used by non-SDC vehicles. -; options: "static", "policy", "replay" +; options: "static", "policy", "replay", "idm" non_sdc_controller = "policy" -; Controller used by non-vehicle agents. "auto" follows non_sdc_controller. -; options: "auto", "static", "policy", "replay" +; Controller used by non-vehicle agents. "auto" follows non_sdc_controller unless it is "idm", then it uses "replay". +; options: "auto", "static", "policy", "replay", "idm" non_vehicle_controller = "auto" ; Options: "created_all_valid", "create_only_controlled" init_mode = "create_all_valid" diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index f47f9f20a8..d07f5d1d35 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -68,6 +68,7 @@ #define CONTROLLER_STATIC 0 #define CONTROLLER_POLICY 1 #define CONTROLLER_REPLAY 2 +#define CONTROLLER_IDM 3 // Simulation modes #define SIMULATION_GIGAFLOW 0 @@ -4930,6 +4931,8 @@ static void move_dynamics(Drive *env, int action_idx, int agent_idx) { return; } +#include "idm.h" + static void move_agent_with_controller(Drive *env, int action_idx, int agent_idx) { Agent *agent = &env->agents[agent_idx]; @@ -4937,6 +4940,11 @@ static void move_agent_with_controller(Drive *env, int action_idx, int agent_idx return; } + if (agent->controller == CONTROLLER_IDM) { + move_idm(env, agent_idx); + return; + } + if (agent->controller == CONTROLLER_REPLAY) { if (env->simulation_mode == SIMULATION_REPLAY) { move_expert(env, env->actions, agent_idx); diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index 951e7f789b..3f67f0ddeb 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -269,8 +269,9 @@ def __init__( "static": binding.CONTROLLER_STATIC, "policy": binding.CONTROLLER_POLICY, "replay": binding.CONTROLLER_REPLAY, + "idm": binding.CONTROLLER_IDM, } - controller_options = "'static', 'policy', or 'replay'" + controller_options = "'static', 'policy', 'replay', or 'idm'" if self.sdc_controller_str not in controller_values: raise ValueError(f"sdc_controller must be one of {controller_options}. Got: {self.sdc_controller_str}") if self.non_sdc_controller_str not in controller_values: @@ -278,7 +279,10 @@ def __init__( f"non_sdc_controller must be one of {controller_options}. Got: {self.non_sdc_controller_str}" ) if self.non_vehicle_controller_str == "auto": - self.non_vehicle_controller_str = self.non_sdc_controller_str + if self.non_sdc_controller_str == "idm": + self.non_vehicle_controller_str = "replay" + else: + self.non_vehicle_controller_str = self.non_sdc_controller_str elif self.non_vehicle_controller_str not in controller_values: raise ValueError( f"non_vehicle_controller must be 'auto' or one of {controller_options}. " diff --git a/pufferlib/ocean/drive/idm.h b/pufferlib/ocean/drive/idm.h new file mode 100644 index 0000000000..e7652fc994 --- /dev/null +++ b/pufferlib/ocean/drive/idm.h @@ -0,0 +1,619 @@ +#ifndef PUFFERLIB_OCEAN_DRIVE_IDM_H +#define PUFFERLIB_OCEAN_DRIVE_IDM_H + +#define IDM_MINIMUM_LEAD_DISTANCE 0.1f +#define IDM_MIN_SPACING 2.0f +#define IDM_SAFE_TIME_HEADWAY 2.0f +#define NUM_ACCELERATION_VALUES ((int) (sizeof(ACCELERATION_VALUES) / sizeof(ACCELERATION_VALUES[0]))) +#define IDM_MAX_ACCEL ACCELERATION_VALUES[NUM_ACCELERATION_VALUES - 1] +#define IDM_MAX_DECEL (-ACCELERATION_VALUES[0]) +#define IDM_DELTA 4.0f +#define IDM_LOOKAHEAD_TIME 5.0f +#define IDM_MIN_LOOKAHEAD 20.0f +#define IDM_MAX_LOOKAHEAD 80.0f +#define IDM_BBOX_MARGIN 0.05f +#define IDM_DEFAULT_DESIRED_SPEED 15.0f +#define IDM_ROUTE_SAMPLE_DS 1.0f +#define IDM_MAX_CANDIDATES 64 + +typedef struct { + int has_leader; + int leader_agent_idx; + int is_traffic_light; + float gap; + float leader_speed; +} IDMLeader; + +typedef struct { + int valid; + int route_idx; + int lane_idx; + int segment_idx; + float t; + float dist_sq; +} IDMLaneProjection; + +static inline IDMLeader idm_no_leader(void) { + IDMLeader leader = {0}; + leader.leader_agent_idx = -1; + leader.gap = INFINITY; + return leader; +} + +static inline void idm_update_best_leader( + IDMLeader *best, + int leader_agent_idx, + int is_traffic_light, + float gap, + float leader_speed) { + if (gap < 0.0f) { + gap = IDM_MINIMUM_LEAD_DISTANCE; + } + if (gap >= best->gap) { + return; + } + + best->has_leader = 1; + best->leader_agent_idx = leader_agent_idx; + best->is_traffic_light = is_traffic_light; + best->gap = fmaxf(gap, IDM_MINIMUM_LEAD_DISTANCE); + best->leader_speed = fmaxf(0.0f, leader_speed); +} + +static inline int idm_check_z_overlap(const Agent *a, const Agent *b) { + float a_bottom = a->sim_z; + float a_top = a->sim_z + a->sim_height; + float b_bottom = b->sim_z; + float b_top = b->sim_z + b->sim_height; + return !(a_top < b_bottom || b_top < a_bottom); +} + +static int idm_traffic_light_controls_lane(TrafficControlElement *traffic, int lane_idx) { + if (lane_idx == -1 || traffic->num_controlled_lanes <= 0) { + return 0; + } + for (int i = 0; i < traffic->num_controlled_lanes; i++) { + if (traffic->controlled_lanes[i] == lane_idx) { + return 1; + } + } + return 0; +} + +static inline int idm_is_stop_light_obstacle_state(int state) { + return state == TRAFFIC_CONTROL_STATE_RED || state == TRAFFIC_CONTROL_STATE_YELLOW; +} + +static IDMLaneProjection idm_project_to_route_lanes(Drive *env, Agent *agent); +static float idm_lane_segment_length(RoadMapElement *lane, int seg_idx); + +static inline void idm_agent_corners(const Agent *agent, float corners[4][2]) { + const float offsets[4][2] = {{1, 1}, {1, -1}, {-1, -1}, {-1, 1}}; + float half_length = 0.5f * agent->sim_length; + float half_width = 0.5f * agent->sim_width; + for (int i = 0; i < 4; i++) { + corners[i][0] = agent->sim_x + offsets[i][0] * half_length * agent->cos_heading + - offsets[i][1] * half_width * agent->sin_heading; + corners[i][1] = agent->sim_y + offsets[i][0] * half_length * agent->sin_heading + + offsets[i][1] * half_width * agent->cos_heading; + } +} + +static int idm_collect_route_candidates(Drive *env, int ego_idx, float lookahead, int *candidates, int max_candidates) { + Agent *ego = &env->agents[ego_idx]; + int count = 0; + + for (int i = 0; i < env->num_agents && count < max_candidates; i++) { + int other_idx = -1; + if (i < env->active_agent_count) { + other_idx = env->active_agent_indices[i]; + } else { + other_idx = env->static_agent_indices[i - env->active_agent_count]; + } + if (other_idx == -1 || other_idx == ego_idx) { + continue; + } + + Agent *other = &env->agents[other_idx]; + if (other->removed || other->sim_x == INVALID_POSITION || other->sim_valid == 0) { + continue; + } + if (!idm_check_z_overlap(ego, other)) { + continue; + } + + float dx = other->sim_x - ego->sim_x; + float dy = other->sim_y - ego->sim_y; + float max_dist = lookahead + 0.5f * ego->sim_length + 0.5f * other->sim_length + 5.0f + 2.0f * IDM_BBOX_MARGIN; + if (dx * dx + dy * dy > max_dist * max_dist) { + continue; + } + + candidates[count++] = other_idx; + } + + return count; +} + +static inline Agent idm_make_sample_agent(const Agent *ego, float x, float y, float z, float heading) { + Agent sample = *ego; + sample.sim_x = x; + sample.sim_y = y; + sample.sim_z = z; + sample.sim_heading = normalize_heading(heading); + sample.cos_heading = cosf(sample.sim_heading); + sample.sin_heading = sinf(sample.sim_heading); + sample.sim_length = ego->sim_length + 2.0f * IDM_BBOX_MARGIN; + sample.sim_width = ego->sim_width + 2.0f * IDM_BBOX_MARGIN; + sample.removed = 0; + sample.sim_valid = 1; + return sample; +} + +static int idm_sample_hits_agent(const Agent *sample, Agent *other) { + if (!idm_check_z_overlap(sample, other)) { + return 0; + } + + float dx = other->sim_x - sample->sim_x; + float dy = other->sim_y - sample->sim_y; + float local_radius = 0.5f * sample->sim_length + 0.5f * other->sim_length + sample->sim_width + other->sim_width + + 1.0f + 2.0f * IDM_BBOX_MARGIN; + if (dx * dx + dy * dy > local_radius * local_radius) { + return 0; + } + + Agent sample_expanded = *sample; + Agent other_expanded = *other; + other_expanded.sim_length = other->sim_length + 2.0f * IDM_BBOX_MARGIN; + other_expanded.sim_width = other->sim_width + 2.0f * IDM_BBOX_MARGIN; + return check_obb_collision(&sample_expanded, &other_expanded); +} + +static int idm_sample_hits_red_light(Drive *env, Agent *sample, int lane_idx) { + float corners[4][2]; + idm_agent_corners(sample, corners); + + for (int i = 0; i < env->num_traffic_elements; i++) { + TrafficControlElement *traffic = &env->traffic_elements[i]; + if (traffic->type != TRAFFIC_CONTROL_TYPE_TRAFFIC_LIGHT) { + continue; + } + if (!idm_traffic_light_controls_lane(traffic, lane_idx)) { + continue; + } + if (env->timestep < 0 || env->timestep >= traffic->state_length || traffic->states == NULL) { + continue; + } + if (!idm_is_stop_light_obstacle_state(traffic->states[env->timestep])) { + continue; + } + + float mid_x = 0.5f * (traffic->stop_line[0] + traffic->stop_line[3]); + float mid_y = 0.5f * (traffic->stop_line[1] + traffic->stop_line[4]); + float dx = sample->sim_x - mid_x; + float dy = sample->sim_y - mid_y; + if (dx * dx + dy * dy > TRAFFIC_LIGHT_DISTANCE_THRESHOLD * TRAFFIC_LIGHT_DISTANCE_THRESHOLD) { + continue; + } + + float heading_diff = compute_heading_diff(sample->sim_heading, traffic->heading); + if (fabsf(heading_diff) > RED_LIGHT_HEADING_THRESHOLD) { + continue; + } + + float sl_dx = traffic->stop_line[3] - traffic->stop_line[0]; + float sl_dy = traffic->stop_line[4] - traffic->stop_line[1]; + float ext = (STOP_LINE_EXTENSION_FACTOR - 1.0f) * 0.5f; + float ext_p1[2] = {traffic->stop_line[0] - ext * sl_dx, traffic->stop_line[1] - ext * sl_dy}; + float ext_p2[2] = {traffic->stop_line[3] + ext * sl_dx, traffic->stop_line[4] + ext * sl_dy}; + + for (int k = 0; k < 4; k++) { + if (k == 2) { + continue; + } + int next = (k + 1) % 4; + if (check_line_intersection(corners[k], corners[next], ext_p1, ext_p2)) { + return 1; + } + } + } + + return 0; +} + +static inline void idm_update_sample_agent_pose(Agent *sample, RoadMapElement *lane, int seg_idx, float t) { + sample->sim_x = lane->x[seg_idx] + t * (lane->x[seg_idx + 1] - lane->x[seg_idx]); + sample->sim_y = lane->y[seg_idx] + t * (lane->y[seg_idx + 1] - lane->y[seg_idx]); + sample->sim_z = lane->z[seg_idx] + t * (lane->z[seg_idx + 1] - lane->z[seg_idx]); + sample->sim_heading = normalize_heading(lane->headings[seg_idx]); + sample->cos_heading = cosf(sample->sim_heading); + sample->sin_heading = sinf(sample->sim_heading); +} + +static IDMLeader idm_find_leader_by_route_boxes(Drive *env, int ego_idx) { + Agent *ego = &env->agents[ego_idx]; + IDMLeader no_leader = idm_no_leader(); + IDMLaneProjection projection = idm_project_to_route_lanes(env, ego); + if (!projection.valid) { + return no_leader; + } + + float speed = fmaxf(0.0f, ego->sim_speed_signed); + float lookahead = clip(speed * IDM_LOOKAHEAD_TIME, IDM_MIN_LOOKAHEAD, IDM_MAX_LOOKAHEAD); + int candidates[IDM_MAX_CANDIDATES]; + int num_candidates = idm_collect_route_candidates(env, ego_idx, lookahead, candidates, IDM_MAX_CANDIDATES); + + Agent sample = idm_make_sample_agent(ego, ego->sim_x, ego->sim_y, ego->sim_z, ego->sim_heading); + float next_sample_s = IDM_ROUTE_SAMPLE_DS; + float traveled_s = 0.0f; + int route_idx = projection.route_idx; + int seg_idx = projection.segment_idx; + float t = projection.t; + + while (route_idx < ego->route_length && next_sample_s <= lookahead + 1e-4f) { + int lane_idx = ego->route[route_idx]; + if (lane_idx < 0 || lane_idx >= env->num_road_elements) { + break; + } + RoadMapElement *lane = &env->road_elements[lane_idx]; + if (lane->segment_length < 2) { + break; + } + + while (seg_idx < lane->segment_length - 1 && next_sample_s <= lookahead + 1e-4f) { + float seg_len = idm_lane_segment_length(lane, seg_idx); + if (seg_len < 1e-6f) { + seg_idx++; + t = 0.0f; + continue; + } + + float remaining = (1.0f - t) * seg_len; + if (traveled_s + remaining + 1e-4f < next_sample_s) { + traveled_s += remaining; + seg_idx++; + t = 0.0f; + continue; + } + + float sample_t = t + (next_sample_s - traveled_s) / seg_len; + sample_t = clip(sample_t, 0.0f, 1.0f); + idm_update_sample_agent_pose(&sample, lane, seg_idx, sample_t); + + if (idm_sample_hits_red_light(env, &sample, lane_idx)) { + idm_update_best_leader(&no_leader, -1, 1, next_sample_s, 0.0f); + return no_leader; + } + + IDMLeader best_at_sample = idm_no_leader(); + for (int i = 0; i < num_candidates; i++) { + int other_idx = candidates[i]; + Agent *other = &env->agents[other_idx]; + if (!idm_sample_hits_agent(&sample, other)) { + continue; + } + float leader_speed = other->sim_vx * sample.cos_heading + other->sim_vy * sample.sin_heading; + idm_update_best_leader(&best_at_sample, other_idx, 0, next_sample_s, leader_speed); + } + if (best_at_sample.has_leader) { + return best_at_sample; + } + + next_sample_s += IDM_ROUTE_SAMPLE_DS; + } + + route_idx++; + seg_idx = 0; + t = 0.0f; + } + + return no_leader; +} + +static inline float idm_lane_speed_limit(Drive *env, int lane_idx) { + if (lane_idx < 0 || lane_idx >= env->num_road_elements) { + return 0.0f; + } + return env->road_elements[lane_idx].speed_limit; +} + +static float idm_desired_speed(Drive *env, Agent *agent) { + float desired_speed = idm_lane_speed_limit(env, agent->current_lane_idx); + + if (desired_speed <= 0.0f && agent->route != NULL && agent->route_length > 0) { + int route_idx = agent->current_route_index; + if (route_idx < 0) { + route_idx = 0; + } else if (route_idx >= agent->route_length) { + route_idx = agent->route_length - 1; + } + desired_speed = idm_lane_speed_limit(env, agent->route[route_idx]); + } + + if (desired_speed <= 0.0f) { + desired_speed = IDM_DEFAULT_DESIRED_SPEED; + } + + return clip(desired_speed, 1.0f, MAX_SPEED); +} + +static float idm_compute_acceleration(Drive *env, Agent *agent, IDMLeader leader) { + float current_speed = fmaxf(0.0f, agent->sim_speed_signed); + float desired_speed = idm_desired_speed(env, agent); + float speed_ratio = current_speed / desired_speed; + float free_road_term = powf(speed_ratio, IDM_DELTA); + float leader_term = 0.0f; + + if (leader.has_leader) { + float s_star = IDM_MIN_SPACING + + fmaxf(0.0f, + current_speed * IDM_SAFE_TIME_HEADWAY + + current_speed * (current_speed - leader.leader_speed) + / (2.0f * sqrtf(IDM_MAX_ACCEL * IDM_MAX_DECEL))); + float lead_dist = fmaxf(leader.gap, IDM_MINIMUM_LEAD_DISTANCE); + leader_term = (s_star / lead_dist) * (s_star / lead_dist); + } + + return IDM_MAX_ACCEL * (1.0f - free_road_term - leader_term); +} + +static IDMLaneProjection idm_project_to_route_lanes(Drive *env, Agent *agent) { + IDMLaneProjection best = {0}; + best.route_idx = 0; + best.lane_idx = -1; + best.segment_idx = 0; + best.t = 0.0f; + best.dist_sq = INFINITY; + + if (agent->route == NULL || agent->route_length <= 0) { + return best; + } + + int start_route = agent->current_route_index - 1; + if (start_route < 0) { + start_route = 0; + } + int end_route = agent->current_route_index + 4; + if (end_route > agent->route_length) { + end_route = agent->route_length; + } + + for (int pass = 0; pass < 2; pass++) { + for (int route_idx = start_route; route_idx < end_route; route_idx++) { + int lane_idx = agent->route[route_idx]; + if (lane_idx < 0 || lane_idx >= env->num_road_elements) { + continue; + } + RoadMapElement *lane = &env->road_elements[lane_idx]; + if (lane->segment_length < 2) { + continue; + } + for (int seg_idx = 0; seg_idx < lane->segment_length - 1; seg_idx++) { + float dx = lane->x[seg_idx + 1] - lane->x[seg_idx]; + float dy = lane->y[seg_idx + 1] - lane->y[seg_idx]; + float dz = lane->z[seg_idx + 1] - lane->z[seg_idx]; + float seg_len_sq = dx * dx + dy * dy + dz * dz; + if (seg_len_sq < 1e-6f) { + continue; + } + + float ax = agent->sim_x - lane->x[seg_idx]; + float ay = agent->sim_y - lane->y[seg_idx]; + float az = agent->sim_z - lane->z[seg_idx]; + float t = (ax * dx + ay * dy + az * dz) / seg_len_sq; + t = clip(t, 0.0f, 1.0f); + + float px = lane->x[seg_idx] + t * dx; + float py = lane->y[seg_idx] + t * dy; + float pz = lane->z[seg_idx] + t * dz; + float err_x = agent->sim_x - px; + float err_y = agent->sim_y - py; + float err_z = agent->sim_z - pz; + float dist_sq = err_x * err_x + err_y * err_y + err_z * err_z; + + if (dist_sq < best.dist_sq) { + best.valid = 1; + best.route_idx = route_idx; + best.lane_idx = lane_idx; + best.segment_idx = seg_idx; + best.t = t; + best.dist_sq = dist_sq; + } + } + } + + if (best.valid) { + break; + } + + start_route = 0; + end_route = agent->route_length; + } + + return best; +} + +static float idm_lane_segment_length(RoadMapElement *lane, int seg_idx) { + float dx = lane->x[seg_idx + 1] - lane->x[seg_idx]; + float dy = lane->y[seg_idx + 1] - lane->y[seg_idx]; + float dz = lane->z[seg_idx + 1] - lane->z[seg_idx]; + return sqrtf(dx * dx + dy * dy + dz * dz); +} + +static int idm_set_pose_on_lane_segment( + Drive *env, + Agent *agent, + int route_idx, + int lane_idx, + int seg_idx, + float t, + float *old_heading_out) { + if (lane_idx < 0 || lane_idx >= env->num_road_elements) { + return 0; + } + RoadMapElement *lane = &env->road_elements[lane_idx]; + if (seg_idx < 0 || seg_idx >= lane->segment_length - 1) { + return 0; + } + t = clip(t, 0.0f, 1.0f); + + if (old_heading_out != NULL) { + *old_heading_out = agent->sim_heading; + } + + agent->sim_x = lane->x[seg_idx] + t * (lane->x[seg_idx + 1] - lane->x[seg_idx]); + agent->sim_y = lane->y[seg_idx] + t * (lane->y[seg_idx + 1] - lane->y[seg_idx]); + agent->sim_z = lane->z[seg_idx] + t * (lane->z[seg_idx + 1] - lane->z[seg_idx]); + agent->sim_heading = normalize_heading(lane->headings[seg_idx]); + agent->cos_heading = cosf(agent->sim_heading); + agent->sin_heading = sinf(agent->sim_heading); + agent->current_route_index = route_idx; + agent->current_lane_idx = lane_idx; + agent->current_lane_geometry_idx = seg_idx; + return 1; +} + +static int idm_refresh_route_at_lane_end(Drive *env, int agent_idx, int lane_idx) { + Agent *agent = &env->agents[agent_idx]; + if (lane_idx == -1) { + lane_idx = agent->current_lane_idx; + } + if (lane_idx == -1) { + return 0; + } + + if (!compute_new_route(env, agent_idx, lane_idx)) { + return 0; + } + compute_goals(env, agent_idx); + return 1; +} + +static int idm_advance_along_route_lanes(Drive *env, int agent_idx, float distance, float *old_heading_out) { + Agent *agent = &env->agents[agent_idx]; + if (distance <= 0.0f) { + if (old_heading_out != NULL) { + *old_heading_out = agent->sim_heading; + } + return 1; + } + + for (int attempt = 0; attempt < 4; attempt++) { + IDMLaneProjection projection = idm_project_to_route_lanes(env, agent); + if (!projection.valid) { + return 0; + } + + int route_idx = projection.route_idx; + int seg_idx = projection.segment_idx; + float t = projection.t; + int lane_idx = projection.lane_idx; + + while (route_idx < agent->route_length) { + lane_idx = agent->route[route_idx]; + if (lane_idx < 0 || lane_idx >= env->num_road_elements) { + return 0; + } + RoadMapElement *lane = &env->road_elements[lane_idx]; + if (lane->segment_length < 2) { + return 0; + } + + while (seg_idx < lane->segment_length - 1) { + float seg_len = idm_lane_segment_length(lane, seg_idx); + if (seg_len < 1e-6f) { + seg_idx++; + t = 0.0f; + continue; + } + + float remaining = (1.0f - t) * seg_len; + if (distance <= remaining) { + float next_t = t + distance / seg_len; + return idm_set_pose_on_lane_segment( + env, + agent, + route_idx, + lane_idx, + seg_idx, + next_t, + old_heading_out); + } + + distance -= remaining; + seg_idx++; + t = 0.0f; + } + + route_idx++; + seg_idx = 0; + t = 0.0f; + } + + if (!idm_refresh_route_at_lane_end(env, agent_idx, lane_idx)) { + return 0; + } + } + + return 0; +} + +static void idm_move_with_leader(Drive *env, int agent_idx, IDMLeader leader) { + Agent *agent = &env->agents[agent_idx]; + + if (agent->removed) { + invalidate_agent(agent); + return; + } + + if (agent->stopped || agent->sim_x == INVALID_POSITION) { + agent->sim_vx = 0.0f; + agent->sim_vy = 0.0f; + agent->yaw_rate = 0.0f; + agent->sim_speed = 0.0f; + agent->sim_speed_signed = 0.0f; + agent->a_long = 0.0f; + agent->a_lat = 0.0f; + agent->jerk_long = 0.0f; + agent->jerk_lat = 0.0f; + agent->steering_angle = 0.0f; + return; + } + + float old_a_long = agent->a_long; + float accel = idm_compute_acceleration(env, agent, leader); + accel = clip(accel, -IDM_MAX_DECEL, IDM_MAX_ACCEL); + + float current_speed = fmaxf(0.0f, agent->sim_speed_signed); + float new_speed = current_speed + accel * env->dt; + if (new_speed < 0.0f) { + new_speed = 0.0f; + } + accel = (new_speed - current_speed) / env->dt; + + float old_heading = agent->sim_heading; + float distance = new_speed * env->dt; + if (!idm_advance_along_route_lanes(env, agent_idx, distance, &old_heading)) { + agent->stopped = 1; + new_speed = 0.0f; + accel = (new_speed - current_speed) / env->dt; + } + agent->sim_vx = new_speed * agent->cos_heading; + agent->sim_vy = new_speed * agent->sin_heading; + agent->yaw_rate = compute_heading_diff(agent->sim_heading, old_heading) / env->dt; + agent->jerk_long = (accel - old_a_long) / env->dt; + float new_a_lat = new_speed * agent->yaw_rate; + agent->jerk_lat = (new_a_lat - agent->a_lat) / env->dt; + agent->a_long = accel; + agent->a_lat = new_a_lat; + agent->steering_angle = 0.0f; + update_agent_speed(agent); +} + +static void move_idm(Drive *env, int agent_idx) { + IDMLeader leader = idm_find_leader_by_route_boxes(env, agent_idx); + idm_move_with_leader(env, agent_idx, leader); +} + +#endif diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index e91d4073c3..26aecd78e0 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -1271,6 +1271,7 @@ PyMODINIT_FUNC PyInit_binding(void) { PyModule_AddIntConstant(m, "CONTROLLER_STATIC", CONTROLLER_STATIC); PyModule_AddIntConstant(m, "CONTROLLER_POLICY", CONTROLLER_POLICY); PyModule_AddIntConstant(m, "CONTROLLER_REPLAY", CONTROLLER_REPLAY); + PyModule_AddIntConstant(m, "CONTROLLER_IDM", CONTROLLER_IDM); PyObject_SetAttrString(m, "MULTI_LANE_FULL_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_FULL_SCORE_TIME)); PyObject_SetAttrString(m, "MULTI_LANE_HALF_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_HALF_SCORE_TIME)); From 5883697b5f029d7520db854558ac1562f6bbbbcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wa=C3=ABl=20Doulazmi?= Date: Sun, 24 May 2026 03:18:04 +0200 Subject: [PATCH 3/3] Match the same constants used in nuPlan and extend footprint at intersections --- pufferlib/ocean/drive/idm.h | 94 ++++++++++++++++++++++++++++++++----- 1 file changed, 83 insertions(+), 11 deletions(-) diff --git a/pufferlib/ocean/drive/idm.h b/pufferlib/ocean/drive/idm.h index e7652fc994..272c673489 100644 --- a/pufferlib/ocean/drive/idm.h +++ b/pufferlib/ocean/drive/idm.h @@ -2,17 +2,16 @@ #define PUFFERLIB_OCEAN_DRIVE_IDM_H #define IDM_MINIMUM_LEAD_DISTANCE 0.1f -#define IDM_MIN_SPACING 2.0f -#define IDM_SAFE_TIME_HEADWAY 2.0f -#define NUM_ACCELERATION_VALUES ((int) (sizeof(ACCELERATION_VALUES) / sizeof(ACCELERATION_VALUES[0]))) -#define IDM_MAX_ACCEL ACCELERATION_VALUES[NUM_ACCELERATION_VALUES - 1] -#define IDM_MAX_DECEL (-ACCELERATION_VALUES[0]) +#define IDM_MIN_SPACING 1.0f +#define IDM_SAFE_TIME_HEADWAY 1.5f +#define IDM_MAX_ACCEL 1.0f +#define IDM_MAX_DECEL 3.0f #define IDM_DELTA 4.0f #define IDM_LOOKAHEAD_TIME 5.0f #define IDM_MIN_LOOKAHEAD 20.0f -#define IDM_MAX_LOOKAHEAD 80.0f +#define IDM_MAX_LOOKAHEAD 40.0f #define IDM_BBOX_MARGIN 0.05f -#define IDM_DEFAULT_DESIRED_SPEED 15.0f +#define IDM_DEFAULT_DESIRED_SPEED 10.0f #define IDM_ROUTE_SAMPLE_DS 1.0f #define IDM_MAX_CANDIDATES 64 @@ -68,6 +67,10 @@ static inline int idm_check_z_overlap(const Agent *a, const Agent *b) { return !(a_top < b_bottom || b_top < a_bottom); } +static inline float idm_projected_footprint_length(const Agent *agent) { + return 0.5f * agent->sim_length + fmaxf(0.0f, agent->sim_speed_signed) * IDM_SAFE_TIME_HEADWAY; +} + static int idm_traffic_light_controls_lane(TrafficControlElement *traffic, int lane_idx) { if (lane_idx == -1 || traffic->num_controlled_lanes <= 0) { return 0; @@ -124,7 +127,8 @@ static int idm_collect_route_candidates(Drive *env, int ego_idx, float lookahead float dx = other->sim_x - ego->sim_x; float dy = other->sim_y - ego->sim_y; - float max_dist = lookahead + 0.5f * ego->sim_length + 0.5f * other->sim_length + 5.0f + 2.0f * IDM_BBOX_MARGIN; + float max_dist = lookahead + 0.5f * ego->sim_length + idm_projected_footprint_length(other) + 5.0f + + 2.0f * IDM_BBOX_MARGIN; if (dx * dx + dy * dy > max_dist * max_dist) { continue; } @@ -150,7 +154,7 @@ static inline Agent idm_make_sample_agent(const Agent *ego, float x, float y, fl return sample; } -static int idm_sample_hits_agent(const Agent *sample, Agent *other) { +static int idm_boxes_overlap(const Agent *sample, const Agent *other) { if (!idm_check_z_overlap(sample, other)) { return 0; } @@ -169,7 +173,6 @@ static int idm_sample_hits_agent(const Agent *sample, Agent *other) { other_expanded.sim_width = other->sim_width + 2.0f * IDM_BBOX_MARGIN; return check_obb_collision(&sample_expanded, &other_expanded); } - static int idm_sample_hits_red_light(Drive *env, Agent *sample, int lane_idx) { float corners[4][2]; idm_agent_corners(sample, corners); @@ -231,6 +234,75 @@ static inline void idm_update_sample_agent_pose(Agent *sample, RoadMapElement *l sample->sin_heading = sinf(sample->sim_heading); } +static int idm_set_projected_agent_pose(Drive *env, Agent *agent, IDMLaneProjection projection, float distance) { + int route_idx = projection.route_idx; + int seg_idx = projection.segment_idx; + float t = projection.t; + + while (route_idx < agent->route_length) { + int lane_idx = agent->route[route_idx]; + if (lane_idx < 0 || lane_idx >= env->num_road_elements) { + return 0; + } + RoadMapElement *lane = &env->road_elements[lane_idx]; + if (lane->segment_length < 2) { + return 0; + } + + while (seg_idx < lane->segment_length - 1) { + float seg_len = idm_lane_segment_length(lane, seg_idx); + if (seg_len < 1e-6f) { + seg_idx++; + t = 0.0f; + continue; + } + + float remaining = (1.0f - t) * seg_len; + if (distance <= remaining) { + float next_t = t + distance / seg_len; + idm_update_sample_agent_pose(agent, lane, seg_idx, clip(next_t, 0.0f, 1.0f)); + return 1; + } + + distance -= remaining; + seg_idx++; + t = 0.0f; + } + + route_idx++; + seg_idx = 0; + t = 0.0f; + } + + return 0; +} + +static int idm_sample_hits_projected_agent(Drive *env, const Agent *sample, int other_idx) { + Agent *other = &env->agents[other_idx]; + if (idm_boxes_overlap(sample, other)) { + return 1; + } + + IDMLaneProjection projection = idm_project_to_route_lanes(env, other); + if (!projection.valid) { + return 0; + } + + Agent projected = *other; + float end_s = idm_projected_footprint_length(other); + for (float s = IDM_ROUTE_SAMPLE_DS; s <= end_s + 1e-4f; s += IDM_ROUTE_SAMPLE_DS) { + projected = *other; + if (!idm_set_projected_agent_pose(env, &projected, projection, s)) { + return 0; + } + if (idm_boxes_overlap(sample, &projected)) { + return 1; + } + } + + return 0; +} + static IDMLeader idm_find_leader_by_route_boxes(Drive *env, int ego_idx) { Agent *ego = &env->agents[ego_idx]; IDMLeader no_leader = idm_no_leader(); @@ -290,7 +362,7 @@ static IDMLeader idm_find_leader_by_route_boxes(Drive *env, int ego_idx) { for (int i = 0; i < num_candidates; i++) { int other_idx = candidates[i]; Agent *other = &env->agents[other_idx]; - if (!idm_sample_hits_agent(&sample, other)) { + if (!idm_sample_hits_projected_agent(env, &sample, other_idx)) { continue; } float leader_speed = other->sim_vx * sample.cos_heading + other->sim_vy * sample.sin_heading;