Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ inactive_agent_threshold = 0.4
init_steps = 0
; options: "control_vehicles", "control_agents", "control_tracks_to_predict", "control_sdc_only"
control_mode = "control_vehicles"
; Controller used by agent 0, the canonical SDC/target.
; options: "static", "policy", "replay", "idm"
sdc_controller = "policy"
; Controller used by non-SDC vehicles.
; options: "static", "policy", "replay", "idm"
non_sdc_controller = "policy"
; Controller used by non-vehicle agents. "auto" follows non_sdc_controller, except that "idm" falls back to "replay".
; options: "auto", "static", "policy", "replay", "idm"
non_vehicle_controller = "auto"
Comment on lines +71 to +77
; Options: "create_all_valid", "create_only_controlled"
init_mode = "create_all_valid"
; Enable computation of evaluation-only metrics
Expand Down
23 changes: 23 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,20 @@ static PyObject *my_get(PyObject *dict, Env *env) {
}
Py_DECREF(tmp);

tmp = PyLong_FromLong(a->controller);
if (!tmp) {
Py_DECREF(agent);
Py_DECREF(agents_list);
return NULL;
}
if (PyDict_SetItemString(agent, "controller", tmp) < 0) {
Py_DECREF(tmp);
Py_DECREF(agent);
Py_DECREF(agents_list);
return NULL;
}
Py_DECREF(tmp);

tmp = PyLong_FromLong(a->mark_as_expert);
if (!tmp) {
Py_DECREF(agent);
Expand Down Expand Up @@ -1569,6 +1583,9 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
int s_map_counter = starting_map_counter;
int init_mode = unpack(kwargs, "init_mode");
int control_mode = unpack(kwargs, "control_mode");
int sdc_controller = unpack(kwargs, "sdc_controller");
int non_sdc_controller = unpack(kwargs, "non_sdc_controller");
int non_vehicle_controller = unpack(kwargs, "non_vehicle_controller");
int simulation_mode = unpack(kwargs, "simulation_mode");
int init_steps = unpack(kwargs, "init_steps");
int seed = unpack(kwargs, "seed");
Expand Down Expand Up @@ -1705,6 +1722,9 @@ static PyObject *my_shared(PyObject *self, PyObject *args, PyObject *kwargs) {
Drive *env = calloc(1, sizeof(Drive));
env->init_mode = init_mode;
env->control_mode = control_mode;
env->sdc_controller = sdc_controller;
env->non_sdc_controller = non_sdc_controller;
env->non_vehicle_controller = non_vehicle_controller;
env->simulation_mode = simulation_mode;
env->init_steps = init_steps;
env->num_max_agents = max_agents_per_env;
Expand Down Expand Up @@ -1827,6 +1847,9 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->timestep = init_steps;
env->init_mode = (int) unpack(kwargs, "init_mode");
env->control_mode = (int) unpack(kwargs, "control_mode");
env->sdc_controller = (int) unpack(kwargs, "sdc_controller");
env->non_sdc_controller = (int) unpack(kwargs, "non_sdc_controller");
env->non_vehicle_controller = (int) unpack(kwargs, "non_vehicle_controller");
env->simulation_mode = (int) unpack(kwargs, "simulation_mode");
env->reward_conditioning = (bool) unpack(kwargs, "reward_conditioning");
env->reward_randomization = (bool) unpack(kwargs, "reward_randomization");
Expand Down
1 change: 1 addition & 0 deletions pufferlib/ocean/drive/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ struct Agent {
int num_goals_reached;
int active_agent;
int mark_as_expert;
int controller;
float cumulative_displacement;
int displacement_sample_count;
float path_progression;
Expand Down
77 changes: 69 additions & 8 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@
#define CONTROL_WOSAC 2
#define CONTROL_SDC_ONLY 3

// Controller modes: per-agent value stored in Agent.controller; it selects how
// move_agent_with_controller advances the agent on each step.
// STATIC: the agent is left where it is.
#define CONTROLLER_STATIC 0
// POLICY: the agent is moved by move_dynamics when a non-negative action index is supplied.
#define CONTROLLER_POLICY 1
// REPLAY: the agent is moved by move_expert, but only while simulation_mode is SIMULATION_REPLAY.
#define CONTROLLER_REPLAY 2
// IDM: the agent is moved by move_idm.
#define CONTROLLER_IDM 3

// Simulation modes
#define SIMULATION_GIGAFLOW 0
#define SIMULATION_REPLAY 1
Expand Down Expand Up @@ -370,6 +376,9 @@ struct Drive {
int *tracks_to_predict;
int init_mode;
int control_mode;
int sdc_controller;
int non_sdc_controller;
int non_vehicle_controller;
int simulation_mode;
int termination_mode;
float inactive_agent_threshold;
Expand Down Expand Up @@ -3136,6 +3145,28 @@ static bool should_control_agent(Drive *env, int agent_idx) {
return agent->route_length != 0;
}

// Choose the CONTROLLER_* value for one agent from the env-level settings.
//   env:               environment holding the agents and the three controller settings.
//   agent_idx:         index into env->agents.
//   is_active:         zero turns a requested CONTROLLER_POLICY into CONTROLLER_STATIC.
//   replay_by_default: non-zero forces CONTROLLER_REPLAY and bypasses every setting.
// Returns the resolved controller; the caller is responsible for storing it.
static int resolve_agent_controller(Drive *env, int agent_idx, int is_active, int replay_by_default) {
    // Forced replay wins over any configured controller.
    if (replay_by_default) {
        return CONTROLLER_REPLAY;
    }

    // Agent EGO_IDX uses the SDC setting; everyone else is split by agent type.
    int selected;
    if (agent_idx == EGO_IDX) {
        selected = env->sdc_controller;
    } else {
        selected = (env->agents[agent_idx].type == VEHICLE) ? env->non_sdc_controller
                                                            : env->non_vehicle_controller;
    }

    // A policy request for an agent that is not active falls back to static.
    int policy_but_inactive = (selected == CONTROLLER_POLICY) && !is_active;
    return policy_but_inactive ? CONTROLLER_STATIC : selected;
}

void set_active_agents(Drive *env) {
// Initialize
env->active_agent_count = 0; // Policy-controlled agents
Expand Down Expand Up @@ -3170,6 +3201,8 @@ void set_active_agents(Drive *env) {

for (int i = 0; i < successfully_created; i++) {
env->active_agent_indices[i] = active_agent_indices[i];
env->agents[active_agent_indices[i]].controller
= resolve_agent_controller(env, active_agent_indices[i], 1, 0);
}
free(active_agent_indices);

Expand Down Expand Up @@ -3222,12 +3255,15 @@ void set_active_agents(Drive *env) {
active_agent_indices[env->active_agent_count] = i;
env->active_agent_count++;
env->agents[i].active_agent = 1;
env->agents[i].controller = resolve_agent_controller(env, i, 1, 0);
} else if (is_log_replay || env->init_mode != INIT_ONLY_CONTROLLABLE_AGENTS) {
// In log-replay mode, all non-controlled agents become expert_static
static_agent_indices[env->static_agent_count] = i;
env->static_agent_count++;
env->agents[i].active_agent = 0;
if (is_log_replay || env->agents[i].mark_as_expert == 1 || env->active_agent_count == env->num_max_agents) {
int replay_by_default
= is_log_replay || env->agents[i].mark_as_expert == 1 || env->active_agent_count == env->num_max_agents;
env->agents[i].controller = resolve_agent_controller(env, i, 0, replay_by_default);
if (env->agents[i].controller == CONTROLLER_REPLAY) {
expert_static_agent_indices[env->expert_static_agent_count] = i;
env->expert_static_agent_count++;
env->agents[i].mark_as_expert = 1;
Expand Down Expand Up @@ -4895,6 +4931,32 @@ static void move_dynamics(Drive *env, int action_idx, int agent_idx) {
return;
}

#include "idm.h"

// Advance one agent for the current step according to its Agent.controller value.
//   env:        environment that owns the agent, the action buffer and simulation_mode.
//   action_idx: action slot forwarded to move_dynamics for policy-driven agents;
//               a negative value makes a policy-controlled agent a no-op.
//   agent_idx:  index into env->agents.
static void move_agent_with_controller(Drive *env, int action_idx, int agent_idx) {
    switch (env->agents[agent_idx].controller) {
    case CONTROLLER_IDM:
        move_idm(env, agent_idx);
        break;

    case CONTROLLER_REPLAY:
        // Replay agents only move while the simulation itself is in replay mode;
        // in any other simulation mode they are left untouched.
        // NOTE(review): confirm this gating is intended for every caller.
        if (env->simulation_mode == SIMULATION_REPLAY) {
            move_expert(env, env->actions, agent_idx);
        }
        break;

    case CONTROLLER_POLICY:
        if (action_idx >= 0) {
            move_dynamics(env, action_idx, agent_idx);
        }
        break;

    case CONTROLLER_STATIC:
    default:
        // Static agents, and any unrecognized controller value, are not moved.
        break;
    }
}

static inline void sample_erratic_flags(Drive *env, Agent *agent) {
agent->is_blind_partner
= (env->partner_blindness_prob > 0.0f && random_uniform(0.0f, 1.0f) < env->partner_blindness_prob) ? 1 : 0;
Expand Down Expand Up @@ -5025,18 +5087,17 @@ void c_step(Drive *env) {
env->timestep++;

// -> 1. Apply actions and move agents
// Move static experts
for (int i = 0; i < env->expert_static_agent_count; i++) {
int expert_idx = env->expert_static_agent_indices[i];
move_expert(env, env->actions, expert_idx);
// Move background agents according to their per-agent controller.
for (int i = 0; i < env->static_agent_count; i++) {
int background_idx = env->static_agent_indices[i];
move_agent_with_controller(env, -1, background_idx);
}
// Move active agents with policy actions
for (int i = 0; i < env->active_agent_count; i++) {
env->logs[i].score = 0.0f;
env->logs[i].episode_length += 1;
int agent_idx = env->active_agent_indices[i];
move_dynamics(env, i, agent_idx);
// move_expert(env, env->actions, agent_idx);
move_agent_with_controller(env, i, agent_idx);
}

// -> 2. Compute metrics and rewards
Expand Down
43 changes: 43 additions & 0 deletions pufferlib/ocean/drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def __init__(
num_eval_scenarios=16,
init_mode="create_all_valid",
control_mode="control_vehicles",
sdc_controller="policy",
non_sdc_controller="policy",
non_vehicle_controller="auto",
map_dir=None,
target_type="static",
reward_conditioning=False,
Expand Down Expand Up @@ -234,6 +237,9 @@ def __init__(
self.init_steps = init_steps
self.init_mode_str = init_mode
self.control_mode_str = control_mode
self.sdc_controller_str = sdc_controller
self.non_sdc_controller_str = non_sdc_controller
self.non_vehicle_controller_str = non_vehicle_controller
self.simulation_mode_str = simulation_mode
self.map_dir = map_dir
self.map_files = sorted(os.path.join(map_dir, f) for f in os.listdir(map_dir) if f.endswith(".bin"))
Expand All @@ -258,6 +264,34 @@ def __init__(
"control_mode must be one of 'control_vehicles', 'control_agents', 'control_wosac', or "
f"'control_sdc_only'. Got: {self.control_mode_str}"
)

controller_values = {
"static": binding.CONTROLLER_STATIC,
"policy": binding.CONTROLLER_POLICY,
"replay": binding.CONTROLLER_REPLAY,
"idm": binding.CONTROLLER_IDM,
}
controller_options = "'static', 'policy', 'replay', or 'idm'"
if self.sdc_controller_str not in controller_values:
raise ValueError(f"sdc_controller must be one of {controller_options}. Got: {self.sdc_controller_str}")
if self.non_sdc_controller_str not in controller_values:
raise ValueError(
f"non_sdc_controller must be one of {controller_options}. Got: {self.non_sdc_controller_str}"
)
if self.non_vehicle_controller_str == "auto":
if self.non_sdc_controller_str == "idm":
self.non_vehicle_controller_str = "replay"
else:
self.non_vehicle_controller_str = self.non_sdc_controller_str
elif self.non_vehicle_controller_str not in controller_values:
raise ValueError(
f"non_vehicle_controller must be 'auto' or one of {controller_options}. "
f"Got: {self.non_vehicle_controller_str}"
)
self.sdc_controller = controller_values[self.sdc_controller_str]
self.non_sdc_controller = controller_values[self.non_sdc_controller_str]
self.non_vehicle_controller = controller_values[self.non_vehicle_controller_str]
Comment on lines +268 to +293

if self.init_mode_str == "create_all_valid":
self.init_mode = 0
elif self.init_mode_str == "create_only_controlled":
Expand Down Expand Up @@ -315,6 +349,9 @@ def __init__(
eval_mode=self.eval_mode,
init_mode=self.init_mode,
control_mode=self.control_mode,
sdc_controller=self.sdc_controller,
non_sdc_controller=self.non_sdc_controller,
non_vehicle_controller=self.non_vehicle_controller,
simulation_mode=self.simulation_mode,
init_steps=self.init_steps,
seed=self.random_seed,
Expand Down Expand Up @@ -404,6 +441,9 @@ def _env_init_kwargs(self, map_file, max_agents):
"init_steps": self.init_steps,
"init_mode": self.init_mode,
"control_mode": self.control_mode,
"sdc_controller": self.sdc_controller,
"non_sdc_controller": self.non_sdc_controller,
"non_vehicle_controller": self.non_vehicle_controller,
"simulation_mode": self.simulation_mode,
"reward_conditioning": self.reward_conditioning,
"reward_randomization": self.reward_randomization,
Expand Down Expand Up @@ -493,6 +533,9 @@ def step(self, actions):
eval_mode=self.eval_mode,
init_mode=self.init_mode,
control_mode=self.control_mode,
sdc_controller=self.sdc_controller,
non_sdc_controller=self.non_sdc_controller,
non_vehicle_controller=self.non_vehicle_controller,
simulation_mode=self.simulation_mode,
init_steps=self.init_steps,
map_files=self.map_files,
Expand Down
Loading
Loading